// blob: b9403588c96c0e9ebcfa5b938e3db7a4871f317a [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/target/MetalSPIRV/MSLToMetalLib.h"
#include "compiler/plugins/target/MetalSPIRV/MetalTargetPlatform.h"
#include "compiler/plugins/target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/metal_executable_def_builder.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Target/SPIRV/Serialization.h"
namespace mlir::iree_compiler::IREE::HAL {
namespace {
// Options controlling how the Metal SPIR-V HAL target produces executables.
// Each field is backed by a command-line flag registered in bindOptions().
struct MetalSPIRVOptions {
  // Apple platform the generated shaders are compiled for (macOS by default).
  MetalTargetPlatform targetPlatform = MetalTargetPlatform::macOS;

  // When true (the default) MSL is compiled into a .metallib before being
  // embedded; when false the MSL source code is embedded instead.
  bool compileToMetalLib = true;

  // Registers the flags that populate the fields above with |binder|.
  void bindOptions(OptionsBinder &binder) {
    static llvm::cl::OptionCategory category("MetalSPIRV HAL Target");

    // Flag: iree-metal-target-platform
    binder.opt<MetalTargetPlatform>(
        "iree-metal-target-platform", targetPlatform, llvm::cl::cat(category),
        llvm::cl::desc("Apple platform to target"),
        llvm::cl::values(
            clEnumValN(MetalTargetPlatform::macOS, "macos", "macOS platform"),
            clEnumValN(MetalTargetPlatform::iOS, "ios", "iOS platform"),
            clEnumValN(MetalTargetPlatform::iOSSimulator, "ios-simulator",
                       "iOS simulator platform")));

    // Flag: iree-metal-compile-to-metallib
    binder.opt<bool>(
        "iree-metal-compile-to-metallib", compileToMetalLib,
        llvm::cl::cat(category),
        llvm::cl::desc(
            "Compile to .metallib and embed in IREE deployable flatbuffer "
            "if true; otherwise stop at and embed MSL source code"));
  }
};
} // namespace
// TODO: MetalOptions for choosing the Metal version.
// HAL target device for Metal. It produces the default "metal" device target
// attribute, sourcing its executable targets from the "metal-spirv" backend.
class MetalTargetDevice : public TargetDevice {
public:
  MetalTargetDevice(const MetalSPIRVOptions &options) : options(options) {}

  // Returns a device target attribute with device ID "metal", an empty
  // configuration dictionary, and the default executable targets reported by
  // the "metal-spirv" backend looked up in |targetRegistry|.
  IREE::HAL::DeviceTargetAttr
  getDefaultDeviceTarget(MLIRContext *context,
                         const TargetRegistry &targetRegistry) const override {
    Builder attrBuilder(context);
    auto deviceConfig = attrBuilder.getDictionaryAttr({});

    // If we had multiple target environments we would generate one target attr
    // per environment, with each setting its own environment attribute.
    SmallVector<IREE::HAL::ExecutableTargetAttr> defaultTargets;
    auto backend = targetRegistry.getTargetBackend("metal-spirv");
    backend->getDefaultExecutableTargets(context, "metal", deviceConfig,
                                         defaultTargets);

    auto deviceID = attrBuilder.getStringAttr("metal");
    return IREE::HAL::DeviceTargetAttr::get(context, deviceID, deviceConfig,
                                            defaultTargets);
  }

private:
  // Borrowed, not owned: the referenced options must outlive this device.
  const MetalSPIRVOptions &options;
};
// HAL target backend for "metal-spirv". Executables are code generated to
// SPIR-V, cross compiled to Metal Shading Language (MSL), optionally compiled
// to a .metallib, and packaged in the "metal-msl-fb" FlatBuffer format.
class MetalSPIRVTargetBackend : public TargetBackend {
public:
  MetalSPIRVTargetBackend(const MetalSPIRVOptions &options)
      : options(options) {}

  // Returns the legacy default device ID for this backend: "metal".
  std::string getLegacyDefaultDeviceID() const override { return "metal"; }

  // Appends the single default executable target of this backend to
  // |executableTargetAttrs|. |deviceID| and |deviceConfigAttr| are not used.
  void getDefaultExecutableTargets(
      MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
      SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
      const override {
    executableTargetAttrs.push_back(getExecutableTarget(context));
  }

  // Builds the executable target attribute with backend "metal-spirv" and
  // format "metal-msl-fb". The configuration dictionary carries the GPU target
  // details when GPU::getMetalTargetDetails() provides them and is empty
  // otherwise.
  IREE::HAL::ExecutableTargetAttr
  getExecutableTarget(MLIRContext *context) const {
    Builder b(context);
    SmallVector<NamedAttribute, 1> configItems;
    if (auto target = GPU::getMetalTargetDetails(context)) {
      addConfigGPUTarget(context, target, configItems);
    }
    return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
        b.getStringAttr("metal-spirv"), b.getStringAttr("metal-msl-fb"),
        b.getDictionaryAttr(configItems));
  }

  // Registers the dialects this backend depends on with |registry|.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<gpu::GPUDialect, IREE::Codegen::IREECodegenDialect,
                    IREE::Flow::FlowDialect, spirv::SPIRVDialect,
                    IREE::GPU::IREEGPUDialect>();
  }

  // Adds the SPIR-V codegen configuration pipeline to |passManager|.
  // |targetAttr| is not used.
  void
  buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
                                 OpPassManager &passManager) override {
    buildSPIRVCodegenConfigurationPassPipeline(passManager);
  }

  // Adds the SPIR-V codegen translation pipeline to |passManager|.
  // |targetAttr| is not used.
  void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
                                    OpPassManager &passManager) override {
    buildSPIRVCodegenPassPipeline(passManager);
  }

  // Serializes |variantOp| into a Metal executable FlatBuffer and creates a
  // hal.executable.binary op holding it using |executableBuilder|.
  //
  // Steps:
  //   1. Serialize the spirv.module into a SPIR-V binary.
  //   2. Cross compile the binary to MSL once per SPIR-V entry point.
  //   3. Optionally compile each MSL shader to a .metallib (macOS hosts only).
  //   4. Pack libraries and pipeline metadata into the FlatBuffer.
  //   5. Attach the FlatBuffer to a new executable binary op.
  //
  // Intermediates are dumped when the matching dump paths in |serOptions| are
  // set. Returns failure, after emitting an error on |variantOp|, when the
  // inner module has no spirv.module, when SPIR-V serialization fails, when
  // cross compilation fails, or when .metallib compilation fails.
  LogicalResult serializeExecutable(const SerializationOptions &serOptions,
                                    IREE::HAL::ExecutableVariantOp variantOp,
                                    OpBuilder &executableBuilder) override {
    ModuleOp innerModuleOp = variantOp.getInnerModule();

    // TODO: rework this to compile all modules into the same metallib and
    // source the entry points from them. Or use a linking tool (metal-ar) to
    // link the compiled metallibs together. If we were not using spirv-cross
    // we'd never do it like this with one module per function.
    //
    // Currently this is _really_ bad because it doesn't support linking like
    // the Vulkan SPIR-V target: that allows multiple spirv::ModuleOps so we
    // at least only have a single HAL executable; this should all be reworked
    // to have multiple SPIR-V modules in a single executable and then even if
    // passing through spirv-cross independently should link the resulting
    // metallibs together.
    //
    // Only the first spirv.module is serialized. Check that one exists before
    // dereferencing the iterator: dereferencing begin() of an empty range is
    // undefined behavior.
    auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
    if (spvModuleOps.empty()) {
      return variantOp.emitError()
             << "expected a spirv.module op in the inner module";
    }
    auto spvModuleOp = *spvModuleOps.begin();
    if (!serOptions.dumpIntermediatesPath.empty()) {
      std::string assembly;
      llvm::raw_string_ostream os(assembly);
      spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
      dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName,
                     variantOp.getName(), ".mlir", assembly);
    }

    // 1. Serialize the spirv::ModuleOp into binary format.
    SmallVector<uint32_t, 0> spvBinary;
    if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
      return variantOp.emitError() << "failed to serialize spirv.module";
    }
    if (!serOptions.dumpIntermediatesPath.empty()) {
      dumpDataToPath<uint32_t>(serOptions.dumpIntermediatesPath,
                               serOptions.dumpBaseName, variantOp.getName(),
                               ".spv", spvBinary);
    }

    // The runtime uses ordinals instead of names but Metal requires function
    // names for constructing pipeline states. Get an ordered list of the entry
    // point names.
    SmallVector<StringRef, 8> spirvEntryPointNames;
    spvModuleOp.walk([&](spirv::EntryPointOp exportOp) {
      spirvEntryPointNames.push_back(exportOp.getFn());
    });

    // 2. Cross compile SPIR-V to MSL source code.
    SmallVector<MetalShader, 2> mslShaders;
    SmallVector<std::string, 2> mslEntryPointNames;
    mslShaders.reserve(spirvEntryPointNames.size());
    mslEntryPointNames.reserve(spirvEntryPointNames.size());
    // We can use ArrayRef here given spvBinary reserves 0 bytes on stack. The
    // view is identical for every entry point so it is built once up front.
    ArrayRef<uint32_t> spvData(spvBinary.data(), spvBinary.size());
    for (const auto &entryPoint : spirvEntryPointNames) {
      std::optional<std::pair<MetalShader, std::string>> msl =
          crossCompileSPIRVToMSL(options.targetPlatform, spvData, entryPoint);
      if (!msl) {
        return variantOp.emitError()
               << "failed to cross compile SPIR-V to Metal shader";
      }
      mslShaders.push_back(std::move(msl->first));
      mslEntryPointNames.push_back(std::move(msl->second));
    }
    if (!serOptions.dumpBinariesPath.empty()) {
      // One .metal file per shader, suffixed with the shader index.
      for (auto shader : llvm::enumerate(mslShaders)) {
        dumpDataToPath(
            serOptions.dumpBinariesPath, serOptions.dumpBaseName,
            (variantOp.getName() + std::to_string(shader.index())).str(),
            ".metal", shader.value().source);
      }
    }

    // 3. Compile MSL to MTLLibrary.
    // Entries stay null unless compilation below fills them in; a null entry
    // causes the MSL source to be embedded in step 4 instead.
    SmallVector<std::unique_ptr<llvm::MemoryBuffer>> metallibs;
    metallibs.resize(mslShaders.size());
    if (options.compileToMetalLib) {
      // We need to use offline Metal shader compilers.
      // TODO(#14048): The toolchain can also exist on other platforms. Probe
      // the PATH instead.
      auto hostTriple = llvm::Triple(llvm::sys::getProcessTriple());
      if (hostTriple.isMacOSX()) {
        for (auto [i, shader, entryPoint] :
             llvm::zip_equal(llvm::seq(mslShaders.size()), mslShaders,
                             mslEntryPointNames)) {
          std::unique_ptr<llvm::MemoryBuffer> lib = compileMSLToMetalLib(
              options.targetPlatform, shader.source, entryPoint);
          if (!lib) {
            return variantOp.emitError()
                   << "failed to compile to MTLLibrary from MSL:\n\n"
                   << shader.source << "\n\n";
          }
          metallibs[i] = std::move(lib);
        }
      }
    }

    // 4. Pack the MTLLibrary and metadata into a FlatBuffer.
    FlatbufferBuilder builder;
    iree_hal_metal_ExecutableDef_start_as_root(builder);

    // Attach embedded source file contents.
    auto sourceFilesRef = createSourceFilesVec(
        serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

    // Each library may provide multiple functions so we encode them
    // independently.
    SmallVector<iree_hal_metal_LibraryDef_ref_t> libraryRefs;
    for (auto [shader, metallib] : llvm::zip_equal(mslShaders, metallibs)) {
      // Source is embedded when there is no compiled library to fall back on
      // or when the debug level is above 1.
      const bool embedSource = !metallib || serOptions.debugLevel > 1;
      iree_hal_metal_MSLSourceDef_ref_t sourceRef = 0;
      if (embedSource) {
        // TODO: pull this from an attribute?
        // https://developer.apple.com/documentation/metal/mtllanguageversion
        unsigned version = 196608; // MTLLanguageVersion3_0
        auto sourceStrRef = builder.createString(shader.source);
        sourceRef =
            iree_hal_metal_MSLSourceDef_create(builder, version, sourceStrRef);
      }
      flatbuffers_string_ref_t metallibRef = 0;
      if (metallib) {
        metallibRef = flatbuffers_string_create(
            builder, metallib->getBufferStart(), metallib->getBufferSize());
      }
      iree_hal_metal_LibraryDef_start(builder);
      iree_hal_metal_LibraryDef_source_add(builder, sourceRef);
      iree_hal_metal_LibraryDef_metallib_add(builder, metallibRef);
      libraryRefs.push_back(iree_hal_metal_LibraryDef_end(builder));
    }
    auto librariesRef = builder.createOffsetVecDestructive(libraryRefs);

    // Generate optional per-export debug information.
    // May be empty if no debug information was requested.
    auto exportOps = llvm::to_vector_of<IREE::HAL::ExecutableExportOp>(
        variantOp.getExportOps());
    auto exportDebugInfos =
        createExportDefs(serOptions.debugLevel, exportOps, builder);

    // One pipeline per shader; pipeline |i| references library |i| and is
    // paired positionally with export op |i|.
    SmallVector<iree_hal_metal_PipelineDef_ref_t> pipelineRefs;
    for (auto [i, shader, entryPoint, exportOp] :
         llvm::zip_equal(llvm::seq(mslShaders.size()), mslShaders,
                         mslEntryPointNames, exportOps)) {
      auto entryPointRef = builder.createString(entryPoint);
      iree_hal_metal_ThreadgroupSize_t threadgroupSize = {
          shader.threadgroupSize.x,
          shader.threadgroupSize.y,
          shader.threadgroupSize.z,
      };
      auto layoutAttr = exportOp.getLayoutAttr();
      uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
      SmallVector<iree_hal_metal_BindingBits_enum_t> bindingFlags;
      for (auto bindingAttr : layoutAttr.getBindings()) {
        iree_hal_metal_BindingBits_enum_t flags = 0;
        // Read-only descriptors are flagged as immutable bindings.
        if (allEnumBitsSet(bindingAttr.getFlags(),
                           IREE::HAL::DescriptorFlags::ReadOnly)) {
          flags |= iree_hal_metal_BindingBits_IMMUTABLE;
        }
        bindingFlags.push_back(flags);
      }
      auto bindingFlagsRef = iree_hal_metal_BindingBits_vec_create(
          builder, bindingFlags.data(), bindingFlags.size());
      iree_hal_metal_PipelineDef_start(builder);
      iree_hal_metal_PipelineDef_library_ordinal_add(builder, i);
      iree_hal_metal_PipelineDef_entry_point_add(builder, entryPointRef);
      iree_hal_metal_PipelineDef_threadgroup_size_add(builder,
                                                      &threadgroupSize);
      // TODO: embed additional metadata on threadgroup info if available.
      // iree_hal_metal_PipelineDef_max_threads_per_threadgroup_add(builder, 0);
      // iree_hal_metal_PipelineDef_threadgroup_size_aligned_add(builder,
      //                                                         false);
      iree_hal_metal_PipelineDef_constant_count_add(builder, constantCount);
      iree_hal_metal_PipelineDef_binding_flags_add(builder, bindingFlagsRef);
      iree_hal_metal_PipelineDef_debug_info_add(builder, exportDebugInfos[i]);
      pipelineRefs.push_back(iree_hal_metal_PipelineDef_end(builder));
    }
    auto pipelinesRef = builder.createOffsetVecDestructive(pipelineRefs);

    iree_hal_metal_ExecutableDef_pipelines_add(builder, pipelinesRef);
    iree_hal_metal_ExecutableDef_libraries_add(builder, librariesRef);
    iree_hal_metal_ExecutableDef_source_files_add(builder, sourceFilesRef);
    iree_hal_metal_ExecutableDef_end_as_root(builder);

    // 5. Add the binary data to the target executable.
    auto binaryOp = IREE::HAL::ExecutableBinaryOp::create(
        executableBuilder, variantOp.getLoc(), variantOp.getSymName(),
        variantOp.getTarget().getFormat(),
        builder.getBufferAttr(executableBuilder.getContext()));
    binaryOp.setMimeTypeAttr(
        executableBuilder.getStringAttr("application/x-flatbuffers"));
    return success();
  }

private:
  // Borrowed, not owned: the referenced options must outlive this backend.
  const MetalSPIRVOptions &options;
};
// Plugin session that registers the Metal HAL target device and the
// "metal-spirv" HAL target backend. The session is activated by default
// (PluginActivationPolicy::DefaultActivated).
//
// NOTE(review): `options` is not declared in this struct; presumably it is a
// MetalSPIRVOptions member inherited from PluginSession - confirm against
// iree/compiler/PluginAPI/Client.h.
struct MetalSPIRVSession
    : public PluginSession<MetalSPIRVSession, MetalSPIRVOptions,
                           PluginActivationPolicy::DefaultActivated> {
  // Registers a factory for the "metal" target device with |targets|.
  void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
    // #hal.device.target<"metal", ...
    // Capture `this` explicitly: the factory reads `options` through it, and
    // implicit capture of `this` via [=] is deprecated in C++20. The created
    // device keeps a reference to the options, so the session must outlive it.
    targets.add("metal", [this]() {
      return std::make_shared<MetalTargetDevice>(options);
    });
  }

  // Registers a factory for the "metal-spirv" target backend with |targets|.
  void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
    // #hal.executable.target<"metal-spirv", ...
    // Explicit `this` capture for the same reason as above; the created
    // backend also keeps a reference to the options.
    targets.add("metal-spirv", [this]() {
      return std::make_shared<MetalSPIRVTargetBackend>(options);
    });
  }
};
} // namespace mlir::iree_compiler::IREE::HAL
// Expands IREE_DEFINE_COMPILER_OPTION_FLAGS for MetalSPIRVOptions. The type is
// named with full qualification because this sits outside the HAL namespace.
// NOTE(review): the macro body is not visible in this file; presumably it
// defines the global command-line flag storage populated through
// MetalSPIRVOptions::bindOptions - confirm against the macro definition.
IREE_DEFINE_COMPILER_OPTION_FLAGS(
mlir::iree_compiler::IREE::HAL::MetalSPIRVOptions);
// C-linkage registration hook for this plugin. Registers MetalSPIRVSession
// with |registrar| under the plugin ID "hal_target_metal_spirv" and always
// returns true.
extern "C" bool iree_register_compiler_plugin_hal_target_metal_spirv(
    mlir::iree_compiler::PluginRegistrar *registrar) {
  using mlir::iree_compiler::IREE::HAL::MetalSPIRVSession;
  registrar->registerPlugin<MetalSPIRVSession>("hal_target_metal_spirv");
  return true;
}