blob: 728a95287ec9c07c1170fb1151f01e1003d73e33 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/metal_executable_def_builder.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Target/SPIRV/Serialization.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Returns the Metal SPIR-V target options. This function currently reads no
// flags; it hands back a default-constructed options object.
MetalSPIRVTargetOptions getMetalSPIRVTargetOptionsFromFlags() {
  MetalSPIRVTargetOptions defaultOptions;
  return defaultOptions;
}
// TODO(antiagainst): provide a proper target environment for Metal.
//
// Builds the SPIR-V target environment used when lowering for Metal: SPIR-V
// 1.0 with only the Shader capability and the
// SPV_KHR_storage_buffer_storage_class extension, unknown vendor/device, and
// MLIR's default resource limits.
static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {
  const spirv::Capability capabilities[] = {spirv::Capability::Shader};
  const spirv::Extension extensions[] = {
      spirv::Extension::SPV_KHR_storage_buffer_storage_class};
  auto versionCapsExts = spirv::VerCapExtAttr::get(
      spirv::Version::V_1_0, capabilities, extensions, context);

  auto resourceLimits = spirv::getDefaultResourceLimits(context);
  return spirv::TargetEnvAttr::get(
      versionCapsExts, spirv::Vendor::Unknown, spirv::DeviceType::Unknown,
      spirv::TargetEnvAttr::kUnknownDeviceID, resourceLimits);
}
// Target backend that lowers executables to SPIR-V, cross-compiles each
// SPIR-V entry point to Metal Shading Language (MSL) source, and packs the
// sources plus per-entry-point threadgroup sizes into a Metal executable
// flatbuffer.
class MetalSPIRVTargetBackend : public SPIRVTargetBackend {
 public:
  // Codegen is configured with default SPIRVCodegenOptions; |options| is
  // stored in options_ but not otherwise consulted by this class.
  explicit MetalSPIRVTargetBackend(MetalSPIRVTargetOptions options)
      : SPIRVTargetBackend(SPIRVCodegenOptions()),
        options_(std::move(options)) {}

  // NOTE: we could vary this based on the options such as 'metal-v2'.
  std::string name() const override { return "metal_spirv"; }
  std::string filter_pattern() const override { return "metal*"; }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<spirv::SPIRVDialect>();
  }

  // Declares target ops for |sourceOp| using the Metal SPIR-V target env.
  void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
                        IREE::HAL::ExecutableOp executableOp) override {
    declareTargetOpsForEnv(sourceOp, executableOp,
                           getMetalTargetEnv(sourceOp.getContext()));
  }

  // Serializes the first spv.module nested in |targetOp|'s inner module into
  // a Metal executable flatbuffer. The result is attached as an
  // IREE::HAL::ExecutableBinaryOp created through |executableBuilder|.
  // Emits an error on |targetOp| and returns failure in four cases: no
  // spv.module is present, the entry point schedule holds a non-string entry,
  // SPIR-V serialization fails, or MSL cross-compilation fails for any entry
  // point.
  LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
                                    OpBuilder &executableBuilder) override {
    ModuleOp innerModuleOp = targetOp.getInnerModule();
    auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
    // Dereferencing begin() of an empty range is undefined behavior; report
    // the malformed input instead of crashing.
    if (spvModuleOps.begin() == spvModuleOps.end()) {
      return targetOp.emitError()
             << "expected a spv.module inside the target's inner module";
    }
    auto spvModuleOp = *spvModuleOps.begin();

    // The runtime uses ordinals instead of names but Metal requires function
    // names for constructing pipeline states. Get an ordered list of the entry
    // point names.
    SmallVector<StringRef, 8> entryPointNames;
    if (auto scheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
            iree_compiler::getEntryPointScheduleAttrName())) {
      // We have multiple entry points in this module. Make sure the order
      // specified in the schedule attribute is respected.
      for (Attribute entryPoint : scheduleAttr) {
        auto nameAttr = entryPoint.dyn_cast<StringAttr>();
        if (!nameAttr) {
          return targetOp.emitError()
                 << "expected entry point schedule to contain only strings";
        }
        entryPointNames.push_back(nameAttr.getValue());
      }
    } else {
      spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
        entryPointNames.push_back(entryPointOp.fn());
      });
    }

    // 1. Serialize the spirv::ModuleOp into binary format.
    SmallVector<uint32_t, 0> spvBinary;
    if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
      return targetOp.emitError() << "failed to serialize spv.module";
    }

    // 2. Cross compile SPIR-V to MSL source code, one shader per entry point.
    // spvBinary outlives every use of this view, so a non-owning reference is
    // safe.
    auto spvBinaryRef = llvm::makeArrayRef(spvBinary.data(), spvBinary.size());
    llvm::SmallVector<MetalShader, 2> mslShaders;
    mslShaders.reserve(entryPointNames.size());
    for (StringRef entryPoint : entryPointNames) {
      llvm::Optional<MetalShader> mslShader =
          crossCompileSPIRVToMSL(spvBinaryRef, entryPoint);
      if (!mslShader) {
        return targetOp.emitError()
               << "failed to cross compile SPIR-V to Metal shader";
      }
      mslShaders.push_back(std::move(*mslShader));
    }

    // 3. Compile MSL to MTLLibrary.
    // TODO(antiagainst): provide the option to compile the shaders into a
    // library and embed in the flatbuffer. Metal provides APIs for compiling
    // shader sources into a MTLLibrary at run-time, but does not provide
    // a way to serialize the generated MTLLibrary. The only way available is
    // to use command-line tools like `metal` and `metallib`. Likely we need
    // to invoke them in C++.

    // 4. Pack the MTLLibrary and metadata into a flatbuffer. The vectors of
    // shader sources, threadgroup sizes, and entry point names are
    // index-aligned: element i of each describes entry point i.
    FlatbufferBuilder builder;
    auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
        mslShaders, [&](const MetalShader &shader) { return shader.source; }));

    iree_MetalThreadgroupSize_vec_start(builder);
    for (auto &shader : mslShaders) {
      iree_MetalThreadgroupSize_vec_push_create(
          builder, shader.threadgroupSize.x, shader.threadgroupSize.y,
          shader.threadgroupSize.z);
    }
    auto threadgroupSizesRef = iree_MetalThreadgroupSize_vec_end(builder);

    auto entryPointNamesRef = builder.createStringVec(entryPointNames);

    iree_MetalExecutableDef_start_as_root(builder);
    iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef);
    iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef);
    iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef);
    iree_MetalExecutableDef_end_as_root(builder);

    // 5. Add the binary data to the target executable.
    executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
        targetOp.getLoc(),
        static_cast<uint32_t>(IREE::HAL::ExecutableFormat::Metal),
        builder.getBufferAttr(executableBuilder.getContext()));

    return success();
  }

 protected:
  MetalSPIRVTargetOptions options_;
};
// Registers the "metal-spirv" target backend factory. The registration object
// is a function-local static, so only the first call registers anything. Later
// calls (and the |queryOptions| they pass) have no effect. The factory invokes
// |queryOptions| each time it creates a backend.
void registerMetalSPIRVTargetBackends(
    std::function<MetalSPIRVTargetOptions()> queryOptions) {
  static TargetBackendRegistration registration(
      "metal-spirv", [queryOptions]() {
        MetalSPIRVTargetOptions options = queryOptions();
        return std::make_unique<MetalSPIRVTargetBackend>(std::move(options));
      });
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir