// blob: 90ecbe06b9c838d8e6cbcc23fc4dea97e53a0d4f [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/target/WebGPUSPIRV/SPIRVToWGSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/WGSL/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/webgpu_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "spirv-tools/libspirv.hpp"
namespace mlir::iree_compiler::IREE::HAL {
namespace {
/// Options for the WebGPU SPIR-V HAL target that can be bound to
/// command-line flags.
struct WebGPUSPIRVOptions {
  /// When set, symbol names and debug info are emitted into the serialized
  /// SPIR-V (see the target backend's serializeExecutable).
  bool debugSymbols = true;

  /// Binds each option in this struct to its flag on `binder`.
  void bindOptions(OptionsBinder &binder) {
    // Function-local static so that all flags share one category instance.
    static llvm::cl::OptionCategory webgpuCategory("WebGPU HAL Target");

    binder.opt<bool>(
        "iree-webgpu-debug-symbols", debugSymbols,
        llvm::cl::cat(webgpuCategory),
        llvm::cl::desc(
            "Include debug information like variable names in outputs."));
  }
};
// TODO: WebGPUOptions for choosing the version/extensions/etc.
/// Target device registered under the "webgpu" device ID.
class WebGPUTargetDevice : public TargetDevice {
public:
  WebGPUTargetDevice(const WebGPUSPIRVOptions &options) : options(options) {}

  /// Builds the default `#hal.device.target` attribute for "webgpu".
  ///
  /// The device configuration dictionary is empty, and the executable targets
  /// are whatever the "webgpu-spirv" backend in `targetRegistry` reports as
  /// its defaults.
  IREE::HAL::DeviceTargetAttr
  getDefaultDeviceTarget(MLIRContext *context,
                         const TargetRegistry &targetRegistry) const override {
    Builder attrBuilder(context);
    auto emptyConfig = attrBuilder.getDictionaryAttr({});

    // With several target environments there would be one executable target
    // attribute per environment, each carrying its own environment attribute.
    SmallVector<IREE::HAL::ExecutableTargetAttr> defaultTargets;
    auto spirvBackend = targetRegistry.getTargetBackend("webgpu-spirv");
    spirvBackend->getDefaultExecutableTargets(context, "webgpu", emptyConfig,
                                              defaultTargets);

    auto deviceID = attrBuilder.getStringAttr("webgpu");
    return IREE::HAL::DeviceTargetAttr::get(context, deviceID, emptyConfig,
                                            defaultTargets);
  }

private:
  // Borrowed options; not owned by this object.
  const WebGPUSPIRVOptions &options;
};
class WebGPUSPIRVTargetBackend : public TargetBackend {
public:
WebGPUSPIRVTargetBackend(const WebGPUSPIRVOptions &options)
: options(options) {}
std::string getLegacyDefaultDeviceID() const override { return "webgpu"; }
void getDefaultExecutableTargets(
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
const override {
executableTargetAttrs.push_back(getExecutableTarget(context));
}
IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context) const {
Builder b(context);
SmallVector<NamedAttribute, 1> configItems;
if (auto target = GPU::getWebGPUTargetDetails(context)) {
addConfigGPUTarget(context, target, configItems);
}
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
b.getStringAttr("webgpu-spirv"), b.getStringAttr("webgpu-wgsl-fb"),
b.getDictionaryAttr(configItems));
}
// TODO(scotttodd): Prune FlowDialect dep when WGSLReplacePushConstantsPass
// does not use the Flow dialect (TranslateExecutables calls this
// function and _does not_ query which passes are used by the dynamic
// pipeline created by buildTranslationPassPipeline)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Codegen::IREECodegenDialect, IREE::Flow::FlowDialect,
spirv::SPIRVDialect, gpu::GPUDialect,
IREE::GPU::IREEGPUDialect>();
}
void
buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) override {
buildSPIRVCodegenConfigurationPassPipeline(passManager);
}
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) override {
buildSPIRVCodegenPassPipeline(passManager);
// Prepare SPIR-V for WebGPU by expanding or removing unsupported ops.
// For example,
// * WGSL does not support extended multiplication:
// https://github.com/gpuweb/gpuweb/issues/1565, so we lower to
// regular multiplication
// * WGSL does not support NaN or infinities:
// https://www.w3.org/TR/WGSL/#floating-point-evaluation
passManager.nest<ModuleOp>().nest<spirv::ModuleOp>().addPass(
spirv::createSPIRVWebGPUPreparePass());
}
LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
ModuleOp innerModuleOp = variantOp.getInnerModule();
auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
if (!llvm::hasSingleElement(spirvModuleOps)) {
// TODO(#7824): Implement linking / shader module combining and relax this
return variantOp.emitError()
<< "should only contain exactly one spirv.module op";
}
auto spvModuleOp = *spirvModuleOps.begin();
if (!serOptions.dumpIntermediatesPath.empty()) {
std::string assembly;
llvm::raw_string_ostream os(assembly);
spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName,
variantOp.getName(), ".mlir", assembly);
}
// The schema expects each shader module to have entry points named "dN",
// where N is the entry point ordinal.
// For each executable entry point op, rename the entry point symbol using
// that convention and keep track of the mapping between entry point
// ordinals to which shader module they reference.
auto exportOps = llvm::to_vector(variantOp.getExportOps());
llvm::SmallVector<uint32_t> entryPointOrdinals(exportOps.size());
SymbolTableCollection symbolTable;
SymbolUserMap symbolUsers(symbolTable, variantOp);
for (auto exportOp : exportOps) {
auto entryPointFunc = dyn_cast<spirv::FuncOp>(
SymbolTable::lookupSymbolIn(spvModuleOp, exportOp.getSymName()));
std::string symbolName = llvm::formatv("d{}", exportOp.getOrdinal());
mlir::StringAttr nameAttr =
mlir::StringAttr::get(variantOp->getContext(), symbolName);
symbolUsers.replaceAllUsesWith(entryPointFunc, nameAttr);
exportOp.setName(symbolName); // Same symbol reference? Not in table?
SymbolTable::setSymbolName(entryPointFunc, symbolName);
// We only have one shader module right now, so all point to index 0.
// TODO(#7824): Support multiple shader modules per executable.
uint64_t ordinal =
exportOp.getOrdinal().value_or(APInt(64, 0)).getZExtValue();
entryPointOrdinals[ordinal] = 0;
}
// Serialize the spirv::ModuleOp into binary format.
SmallVector<uint32_t, 0> spvBinary;
spirv::SerializationOptions spirvSerializationOptions;
spirvSerializationOptions.emitSymbolName = options.debugSymbols;
spirvSerializationOptions.emitDebugInfo = options.debugSymbols;
if (failed(spirv::serialize(spvModuleOp, spvBinary,
spirvSerializationOptions)) ||
spvBinary.empty()) {
return variantOp.emitError() << "failed to serialize spirv.module";
}
if (!serOptions.dumpIntermediatesPath.empty()) {
dumpDataToPath<uint32_t>(serOptions.dumpIntermediatesPath,
serOptions.dumpBaseName, variantOp.getName(),
".spv", spvBinary);
// Disassemble the shader and save that too.
// Note: this should match what getWebGPUTargetEnv used.
// TODO(scotttodd): Query spirv env from the executable variant?
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
std::string spvDisassembled;
if (spirvTools.Disassemble(
spvBinary.data(), spvBinary.size(), &spvDisassembled,
SPV_BINARY_TO_TEXT_OPTION_INDENT |
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)) {
dumpDataToPath(serOptions.dumpIntermediatesPath,
serOptions.dumpBaseName, variantOp.getName(), ".spvasm",
spvDisassembled);
} else {
llvm::errs() << "Failed to disassemble SPIR-V binary\n";
}
}
// Compile SPIR-V to WGSL source code.
auto wgsl = compileSPIRVToWGSL(spvBinary);
if (!wgsl.has_value()) {
// TODO(scotttodd): restructure branching and write disassembled SPIR-V
// to stderr / an error diagnostic (don't want to
// disassemble if successful + option not set, also
// don't want to disassemble twice :P)
return variantOp.emitError()
<< "failed to compile SPIR-V to WGSL. Consider inspecting the "
"shader program using -iree-hal-dump-executable-intermediates.";
}
if (!serOptions.dumpBinariesPath.empty()) {
dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName,
variantOp.getName(), ".wgsl", wgsl.value());
}
// Pack the WGSL and metadata into a FlatBuffer.
FlatbufferBuilder builder;
iree_hal_webgpu_ExecutableDef_start_as_root(builder);
// Attach embedded source file contents.
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);
iree_hal_webgpu_ShaderModuleDef_start(builder);
auto wgslRef = builder.createString(wgsl.value());
iree_hal_webgpu_ShaderModuleDef_wgsl_source_add(builder, wgslRef);
// TODO(scotttodd): populate source map
auto shaderModuleRef = iree_hal_webgpu_ShaderModuleDef_end(builder);
auto shaderModulesVec = iree_hal_webgpu_ShaderModuleDef_vec_create(
builder, &shaderModuleRef, /*len=*/1);
iree_hal_webgpu_ExecutableDef_shader_modules_add(builder, shaderModulesVec);
auto entryPointsRef = flatbuffers_uint32_vec_create(
builder, entryPointOrdinals.data(), entryPointOrdinals.size());
iree_hal_webgpu_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_webgpu_ExecutableDef_source_files_add(builder, sourceFilesRef);
iree_hal_webgpu_ExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
auto binaryOp = IREE::HAL::ExecutableBinaryOp::create(
executableBuilder, variantOp.getLoc(), variantOp.getSymName(),
variantOp.getTarget().getFormat(),
builder.getBufferAttr(executableBuilder.getContext()));
binaryOp.setMimeTypeAttr(
executableBuilder.getStringAttr("application/x-flatbuffers"));
return success();
}
private:
const WebGPUSPIRVOptions &options;
};
/// Plugin session that registers the WebGPU target device and the
/// "webgpu-spirv" target backend. Uses the DefaultActivated policy.
struct WebGPUSPIRVSession
    : public PluginSession<WebGPUSPIRVSession, WebGPUSPIRVOptions,
                           PluginActivationPolicy::DefaultActivated> {
  /// Registers the factory for the "webgpu" target device.
  void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
    // #hal.device.target<"webgpu", ...
    auto makeDevice = [this]() {
      return std::make_shared<WebGPUTargetDevice>(this->options);
    };
    targets.add("webgpu", makeDevice);
  }

  /// Registers the factory for the "webgpu-spirv" target backend.
  void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
    // #hal.executable.target<"webgpu-spirv", ...
    auto makeBackend = [this]() {
      return std::make_shared<WebGPUSPIRVTargetBackend>(this->options);
    };
    targets.add("webgpu-spirv", makeBackend);
  }
};
} // namespace
} // namespace mlir::iree_compiler::IREE::HAL
// Invoked at global scope with the fully qualified options type.
// NOTE(review): presumably this defines the global flag storage that backs
// WebGPUSPIRVOptions::bindOptions - confirm against the macro's definition.
IREE_DEFINE_COMPILER_OPTION_FLAGS(
mlir::iree_compiler::IREE::HAL::WebGPUSPIRVOptions);
/// C-linkage entry point that registers the WebGPU SPIR-V session with
/// `registrar` under the plugin ID "hal_target_webgpu_spirv".
/// Always returns true.
extern "C" bool iree_register_compiler_plugin_hal_target_webgpu_spirv(
    mlir::iree_compiler::PluginRegistrar *registrar) {
  using mlir::iree_compiler::IREE::HAL::WebGPUSPIRVSession;
  registrar->registerPlugin<WebGPUSPIRVSession>("hal_target_webgpu_spirv");
  return true;
}