blob: dd39ebde4e2304b884690f5baea11fce563d1269 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "./SPIRVToWGSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/WGSL/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/wgsl_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "spirv-tools/libspirv.hpp"
namespace mlir::iree_compiler::IREE::HAL {
namespace {
struct WebGPUSPIRVOptions {
bool debugSymbols = true;
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("WebGPU HAL Target");
binder.opt<bool>(
"iree-webgpu-debug-symbols", debugSymbols, llvm::cl::cat(category),
llvm::cl::desc(
"Include debug information like variable names in outputs."));
}
};
// TODO(scotttodd): provide a proper target environment for WebGPU.
static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {
// TODO(scotttodd): find list of SPIR-V extensions supported by WebGPU/WGSL
auto triple = spirv::VerCapExtAttr::get(
spirv::Version::V_1_0, {spirv::Capability::Shader},
{spirv::Extension::SPV_KHR_storage_buffer_storage_class}, context);
return spirv::TargetEnvAttr::get(
triple, spirv::getDefaultResourceLimits(context),
spirv::ClientAPI::WebGPU, spirv::Vendor::Unknown,
spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
}
class WebGPUSPIRVTargetBackend : public TargetBackend {
public:
WebGPUSPIRVTargetBackend(const WebGPUSPIRVOptions &options)
: options(options) {}
// NOTE: we could vary this based on the options such as 'webgpu-v2'.
std::string name() const override { return "webgpu"; }
// TODO(scotttodd): Prune FlowDialect dep when WGSLReplacePushConstantsPass
// does not use the Flow dialect (TranslateExecutables calls this
// function and _does not_ query which passes are used by the dynamic
// pipeline created by buildTranslationPassPipeline)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Codegen::IREECodegenDialect, IREE::Flow::FlowDialect,
spirv::SPIRVDialect, gpu::GPUDialect>();
}
IREE::HAL::DeviceTargetAttr
getDefaultDeviceTarget(MLIRContext *context) const override {
Builder b(context);
SmallVector<NamedAttribute> configItems;
auto configAttr = b.getDictionaryAttr(configItems);
// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs;
targetAttrs.push_back(
getExecutableTarget(context, getWebGPUTargetEnv(context)));
return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("webgpu"),
configAttr, targetAttrs);
}
void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
OpPassManager &passManager) override {
// For now we disable configuration if the variant has external object
// files.
if (variantOp.isExternal())
return;
buildSPIRVCodegenConfigurationPassPipeline(passManager);
}
void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
OpPassManager &passManager) override {
// For now we disable translation if the variant has external object files.
if (variantOp.isExternal())
return;
// WebGPU does not support push constants (yet?), so replace loads from
// push constants with loads from uniform buffers.
// The corresponding runtime code must perform similar emulation, based
// on the push constant count listed in the executable layout.
passManager.nest<ModuleOp>().nest<func::FuncOp>().addPass(
createWGSLReplacePushConstantsPass());
// From WGSL spec, "Floating Point Evaluation"
// (https://www.w3.org/TR/WGSL/#floating-point-evaluation):
// - Implementations may assume that NaNs and infinities are not present at
// runtime.
// - In such an implementation, when an evaluation would produce an
// infinity or a NaN, an undefined value of the target type is produced
// instead.
// So WebGPU effectively assumes fast math mode. We also don't have reliable
// ways to check whether a floating point number is NaN or infinity.
// Therefore, just let the SPIR-V CodeGen to avoid generating guards w.r.t.
// NaN and infinity.
buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/true);
// WGSL does not support extended multiplication:
// https://github.com/gpuweb/gpuweb/issues/1565. Make sure to lower it to
// regular multiplication before we convert SPIR-V to WGSL.
passManager.nest<ModuleOp>().nest<spirv::ModuleOp>().addPass(
spirv::createSPIRVWebGPUPreparePass());
}
LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
ModuleOp innerModuleOp = variantOp.getInnerModule();
auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
if (!llvm::hasSingleElement(spirvModuleOps)) {
// TODO(#7824): Implement linking / shader module combining and relax this
return variantOp.emitError()
<< "should only contain exactly one spirv.module op";
}
auto spvModuleOp = *spirvModuleOps.begin();
if (!serOptions.dumpIntermediatesPath.empty()) {
std::string assembly;
llvm::raw_string_ostream os(assembly);
spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName,
variantOp.getName(), ".mlir", assembly);
}
// The schema expects each shader module to have entry points named "dN",
// where N is the entry point ordinal.
// For each executable entry point op, rename the entry point symbol using
// that convention and keep track of the mapping between entry point
// ordinals to which shader module they reference.
auto exportOps = llvm::to_vector(variantOp.getExportOps());
llvm::SmallVector<uint32_t> entryPointOrdinals(exportOps.size());
SymbolTableCollection symbolTable;
SymbolUserMap symbolUsers(symbolTable, variantOp);
for (auto exportOp : exportOps) {
auto entryPointFunc = dyn_cast<spirv::FuncOp>(
SymbolTable::lookupSymbolIn(spvModuleOp, exportOp.getSymName()));
std::string symbolName = llvm::formatv("d{0}", exportOp.getOrdinal());
mlir::StringAttr nameAttr =
mlir::StringAttr::get(variantOp->getContext(), symbolName);
symbolUsers.replaceAllUsesWith(entryPointFunc, nameAttr);
exportOp.setName(symbolName); // Same symbol reference? Not in table?
SymbolTable::setSymbolName(entryPointFunc, symbolName);
// We only have one shader module right now, so all point to index 0.
// TODO(#7824): Support multiple shader modules per executable.
uint64_t ordinal =
exportOp.getOrdinal().value_or(APInt(64, 0)).getZExtValue();
entryPointOrdinals[ordinal] = 0;
}
// Serialize the spirv::ModuleOp into binary format.
SmallVector<uint32_t, 0> spvBinary;
spirv::SerializationOptions spirvSerializationOptions;
spirvSerializationOptions.emitSymbolName = options.debugSymbols;
spirvSerializationOptions.emitDebugInfo = options.debugSymbols;
if (failed(spirv::serialize(spvModuleOp, spvBinary,
spirvSerializationOptions)) ||
spvBinary.empty()) {
return variantOp.emitError() << "failed to serialize spirv.module";
}
if (!serOptions.dumpIntermediatesPath.empty()) {
dumpDataToPath<uint32_t>(serOptions.dumpIntermediatesPath,
serOptions.dumpBaseName, variantOp.getName(),
".spv", spvBinary);
// Disassemble the shader and save that too.
// Note: this should match what getWebGPUTargetEnv used.
// TODO(scotttodd): Query spirv env from the executable variant?
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
std::string spvDisassembled;
if (spirvTools.Disassemble(
spvBinary.data(), spvBinary.size(), &spvDisassembled,
SPV_BINARY_TO_TEXT_OPTION_INDENT |
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)) {
dumpDataToPath(serOptions.dumpIntermediatesPath,
serOptions.dumpBaseName, variantOp.getName(), ".spvasm",
spvDisassembled);
} else {
llvm::errs() << "Failed to disassemble SPIR-V binary\n";
}
}
// Compile SPIR-V to WGSL source code.
auto wgsl = compileSPIRVToWGSL(spvBinary);
if (!wgsl.has_value()) {
// TODO(scotttodd): restructure branching and write disassembled SPIR-V
// to stderr / an error diagnostic (don't want to
// disassemble if successful + option not set, also
// don't want to disassemble twice :P)
return variantOp.emitError()
<< "failed to compile SPIR-V to WGSL. Consider inspecting the "
"shader program using -iree-hal-dump-executable-intermediates.";
}
if (!serOptions.dumpBinariesPath.empty()) {
dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName,
variantOp.getName(), ".wgsl", wgsl.value());
}
// Pack the WGSL and metadata into a FlatBuffer.
FlatbufferBuilder builder;
iree_hal_wgsl_ExecutableDef_start_as_root(builder);
iree_hal_wgsl_ShaderModuleDef_start(builder);
auto wgslRef = builder.createString(wgsl.value());
iree_hal_wgsl_ShaderModuleDef_code_add(builder, wgslRef);
// TODO(scotttodd): populate source map
auto shaderModuleRef = iree_hal_wgsl_ShaderModuleDef_end(builder);
auto shaderModulesVec = iree_hal_wgsl_ShaderModuleDef_vec_create(
builder, &shaderModuleRef, /*len=*/1);
iree_hal_wgsl_ExecutableDef_shader_modules_add(builder, shaderModulesVec);
auto entryPointsRef = flatbuffers_uint32_vec_create(
builder, entryPointOrdinals.data(), entryPointOrdinals.size());
iree_hal_wgsl_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_wgsl_ExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
auto binaryOp = executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
variantOp.getLoc(), variantOp.getSymName(),
variantOp.getTarget().getFormat(),
builder.getBufferAttr(executableBuilder.getContext()));
binaryOp.setMimeTypeAttr(
executableBuilder.getStringAttr("application/x-flatbuffers"));
return success();
}
private:
IREE::HAL::ExecutableTargetAttr
getExecutableTarget(MLIRContext *context,
spirv::TargetEnvAttr targetEnv) const {
Builder b(context);
SmallVector<NamedAttribute> configItems;
configItems.emplace_back(b.getStringAttr(spirv::getTargetEnvAttrName()),
targetEnv);
auto configAttr = b.getDictionaryAttr(configItems);
return IREE::HAL::ExecutableTargetAttr::get(
context, b.getStringAttr("webgpu-spirv"),
b.getStringAttr("webgpu-wgsl-fb"), configAttr);
}
const WebGPUSPIRVOptions &options;
};
struct WebGPUSPIRVSession
: public PluginSession<WebGPUSPIRVSession, WebGPUSPIRVOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
auto backendFactory = [=]() {
return std::make_shared<WebGPUSPIRVTargetBackend>(options);
};
// #hal.device.target<"webgpu", ...
targets.add("webgpu", backendFactory);
// #hal.executable.target<"webgpu-spirv", ...
targets.add("webgpu-spirv", backendFactory);
}
};
} // namespace
} // namespace mlir::iree_compiler::IREE::HAL
IREE_DEFINE_COMPILER_OPTION_FLAGS(
mlir::iree_compiler::IREE::HAL::WebGPUSPIRVOptions);
extern "C" bool iree_register_compiler_plugin_hal_target_webgpu_spirv(
mlir::iree_compiler::PluginRegistrar *registrar) {
registrar->registerPlugin<mlir::iree_compiler::IREE::HAL::WebGPUSPIRVSession>(
"hal_target_webgpu_spirv");
return true;
}