blob: 0d23fd679c44964753d8c5b8c981725a9f95aecf [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"

#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "third_party/spirv_cross/spirv_msl.hpp"
#define DEBUG_TYPE "spirv-to-msl"
namespace mlir {
namespace iree_compiler {
namespace {
// A SPIRV-Cross MSL compiler extended with the queries needed to build a
// MetalShader: the entry point's workgroup size, the (set, binding) pairs of
// the buffer resources in the module, and the MSL compilation options.
class SPIRVToMSLCompiler : public SPIRV_CROSS_NAMESPACE::CompilerMSL {
 public:
  using CompilerMSL::CompilerMSL;

  // Returns the workgroup size recorded for the GLCompute entry point named
  // `entryName`. Returns {0, 0, 0} when the entry point's workgroup size has a
  // non-zero `constant` field; that case is not handled yet (see the TODO).
  MetalShader::ThreadGroupSize getWorkgroupSizeForEntryPoint(
      StringRef entryName) {
    const auto& entryPoint = get_entry_point(
        entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute);
    const auto& workgroupSize = entryPoint.workgroup_size;
    // TODO(antiagainst): support specialization constant.
    if (workgroupSize.constant != 0) return {0, 0, 0};
    return {workgroupSize.x, workgroupSize.y, workgroupSize.z};
  }

  // A struct containing a resource descriptor's information.
  struct Descriptor {
    uint32_t set;
    uint32_t binding;

    Descriptor(uint32_t s, uint32_t b) : set(s), binding(b) {}

    // Orders descriptors by set number first and binding number second.
    friend bool operator<(const Descriptor& l, const Descriptor& r) {
      return std::tie(l.set, l.binding) < std::tie(r.set, r.binding);
    }
  };

  // Returns all resource buffer descriptors' set and binding number pairs in
  // increasing order.
  //
  // Only variables in the Uniform and StorageBuffer storage classes are
  // recorded. Function, Private, Workgroup, and Input variables are skipped.
  // Any other storage class is treated as unreachable.
  std::vector<Descriptor> getBufferSetBindingPairs() {
    std::vector<Descriptor> descriptors;

    // Iterate over all variables in the SPIR-V blob.
    ir.for_each_typed_id<SPIRV_CROSS_NAMESPACE::SPIRVariable>(
        [&](uint32_t id, SPIRV_CROSS_NAMESPACE::SPIRVariable& var) {
          switch (var.storage) {
            // Non-interface variables. We don't care.
            case spv::StorageClassFunction:
            case spv::StorageClassPrivate:
            case spv::StorageClassWorkgroup:
            // Builtin variables. We don't care either.
            case spv::StorageClassInput:
              return;

            // Buffer resources: record where they are bound.
            case spv::StorageClassUniform:
            case spv::StorageClassStorageBuffer: {
              uint32_t setNo = get_decoration(id, spv::DecorationDescriptorSet);
              uint32_t bindingNo = get_decoration(id, spv::DecorationBinding);
              descriptors.emplace_back(setNo, bindingNo);
              return;
            }

            default:
              // TODO(antiagainst): push constant
              llvm_unreachable(
                  "unsupported storage class in SPIRVToMSLCompiler");
          }
        });

    llvm::sort(descriptors);
    return descriptors;
  }

  // Returns the options used for cross compilation: macOS platform, MSL 2.0,
  // with Metal argument buffers enabled.
  Options getCompilationOptions() {
    // TODO(antiagainst): fill out the following according to the Metal GPU
    // family.
    SPIRVToMSLCompiler::Options spvCrossOptions;
    spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS;
    spvCrossOptions.msl_version =
        SPIRVToMSLCompiler::Options::make_msl_version(2, 0);
    // Enable using Metal argument buffers. It is more akin to Vulkan descriptor
    // sets, which is how IREE HAL models resource bindings and mappings.
    spvCrossOptions.argument_buffers = true;
    return spvCrossOptions;
  }
};
} // namespace
// Cross compiles the SPIR-V module in `spvBinary` into Metal Shading Language
// source code for the GLCompute entry point named `entryPoint`.
//
// Returns the MSL source together with the entry point's workgroup size on
// success. Returns llvm::None when the module cannot be handled:
//   * a buffer resource lives in a descriptor set other than 0, or
//   * any dimension of the entry point's workgroup size is zero.
llvm::Optional<MetalShader> crossCompileSPIRVToMSL(
    llvm::ArrayRef<uint32_t> spvBinary, StringRef entryPoint) {
  SPIRVToMSLCompiler spvCrossCompiler(spvBinary.data(), spvBinary.size());

  // All spirv-cross operations work on the current entry point. It should be
  // set right after the cross compiler construction.
  spvCrossCompiler.set_entry_point(
      entryPoint.str(), spv::ExecutionModel::ExecutionModelGLCompute);

  // Explicitly set the argument buffer index for each SPIR-V resource variable.
  auto descriptors = spvCrossCompiler.getBufferSetBindingPairs();
  for (const auto& descriptor : descriptors) {
    if (descriptor.set != 0) {
      // The descriptor set number comes from the input module, so report the
      // unsupported case as a failure instead of marking it unreachable
      // (which is undefined behavior in builds without assertions).
      LLVM_DEBUG(
          llvm::dbgs()
          << "multiple descriptor set unimplemented in SPIRVToMSLCompiler\n");
      return llvm::None;
    }

    SPIRV_CROSS_NAMESPACE::MSLResourceBinding binding = {};
    binding.stage = spv::ExecutionModelGLCompute;
    binding.desc_set = descriptor.set;
    binding.binding = descriptor.binding;
    // We only interact with buffers in IREE.
    binding.msl_buffer = descriptor.binding;
    spvCrossCompiler.add_msl_resource_binding(binding);
  }

  auto spvCrossOptions = spvCrossCompiler.getCompilationOptions();
  spvCrossCompiler.set_msl_options(spvCrossOptions);

  std::string mslSource = spvCrossCompiler.compile();
  LLVM_DEBUG(llvm::dbgs()
             << "Cross compiled Metal Shading Language source code:\n-----\n"
             << mslSource << "\n-----\n");

  // A zero in any dimension means the workgroup size is unavailable (for
  // example, getWorkgroupSizeForEntryPoint() returns {0, 0, 0} for sizes it
  // does not support), so the shader cannot be dispatched.
  auto workgroupSize =
      spvCrossCompiler.getWorkgroupSizeForEntryPoint(entryPoint);
  if (!workgroupSize.x || !workgroupSize.y || !workgroupSize.z) {
    return llvm::None;
  }

  return MetalShader{std::move(mslSource), workgroupSize};
}
} // namespace iree_compiler
} // namespace mlir