| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "compiler/plugins/target/WebGPUSPIRV/SPIRVToWGSL.h" |
| #include "iree/compiler/Codegen/Common/Passes.h" |
| #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" |
| #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h" |
| #include "iree/compiler/Codegen/SPIRV/Passes.h" |
| #include "iree/compiler/Codegen/Utils/GPUUtils.h" |
| #include "iree/compiler/Codegen/WGSL/Passes.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" |
| #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" |
| #include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h" |
| #include "iree/compiler/PluginAPI/Client.h" |
| #include "iree/compiler/Utils/FlatbufferUtils.h" |
| #include "iree/schemas/webgpu_executable_def_builder.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
| #include "mlir/Target/SPIRV/Serialization.h" |
| |
| namespace mlir::iree_compiler::IREE::HAL { |
| namespace { |
| |
| struct WebGPUSPIRVOptions { |
| bool debugSymbols = true; |
| |
| void bindOptions(OptionsBinder &binder) { |
| static llvm::cl::OptionCategory category("WebGPU HAL Target"); |
| binder.opt<bool>( |
| "iree-webgpu-debug-symbols", debugSymbols, llvm::cl::cat(category), |
| llvm::cl::desc( |
| "Include debug information like variable names in outputs.")); |
| } |
| }; |
| |
| // TODO: WebGPUOptions for choosing the version/extensions/etc. |
| class WebGPUTargetDevice final : public TargetDevice { |
| public: |
| WebGPUTargetDevice(const WebGPUSPIRVOptions & /*options*/) {} |
| |
| IREE::HAL::DeviceTargetAttr |
| getDefaultDeviceTarget(MLIRContext *context, |
| const TargetRegistry &targetRegistry) const final { |
| Builder b(context); |
| auto configAttr = b.getDictionaryAttr({}); |
| |
| // If we had multiple target environments we would generate one target attr |
| // per environment, with each setting its own environment attribute. |
| SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; |
| targetRegistry.getTargetBackend("webgpu-spirv") |
| ->getDefaultExecutableTargets(context, "webgpu", configAttr, |
| executableTargetAttrs); |
| |
| return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("webgpu"), |
| configAttr, executableTargetAttrs); |
| } |
| }; |
| |
| class WebGPUSPIRVTargetBackend final : public TargetBackend { |
| public: |
| WebGPUSPIRVTargetBackend(const WebGPUSPIRVOptions &options) |
| : options(options) {} |
| |
| std::string getLegacyDefaultDeviceID() const final { return "webgpu"; } |
| |
| void getDefaultExecutableTargets( |
| MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr, |
| SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs) |
| const final { |
| executableTargetAttrs.push_back(getExecutableTarget(context)); |
| } |
| |
| IREE::HAL::ExecutableTargetAttr |
| getExecutableTarget(MLIRContext *context) const { |
| Builder b(context); |
| SmallVector<NamedAttribute, 1> configItems; |
| if (auto target = GPU::getWebGPUTargetDetails(context)) { |
| addConfigGPUTarget(context, target, configItems); |
| } |
| |
| return b.getAttr<IREE::HAL::ExecutableTargetAttr>( |
| b.getStringAttr("webgpu-spirv"), b.getStringAttr("webgpu-wgsl-fb"), |
| b.getDictionaryAttr(configItems)); |
| } |
| |
| // TODO(scotttodd): Prune FlowDialect dep when WGSLReplacePushConstantsPass |
| // does not use the Flow dialect (TranslateExecutables calls this |
| // function and _does not_ query which passes are used by the dynamic |
| // pipeline created by buildTranslationPassPipeline) |
| void getDependentDialects(DialectRegistry ®istry) const final { |
| registry.insert<IREE::Codegen::IREECodegenDialect, IREE::Flow::FlowDialect, |
| spirv::SPIRVDialect, gpu::GPUDialect, |
| IREE::GPU::IREEGPUDialect>(); |
| } |
| |
| void |
| buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr, |
| OpPassManager &passManager) final { |
| buildCodegenConfigurationPreProcessingPassPipeline(passManager); |
| buildSPIRVCodegenConfigurationPassPipeline(passManager.nest<ModuleOp>()); |
| } |
| |
| void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr, |
| OpPassManager &passManager) final { |
| buildSPIRVCodegenPassPipeline(passManager.nest<ModuleOp>()); |
| buildCodegenTranslationPostProcessingPassPipeline(passManager); |
| |
| // Prepare SPIR-V for WebGPU by expanding or removing unsupported ops. |
| // For example, |
| // * WGSL does not support extended multiplication: |
| // https://github.com/gpuweb/gpuweb/issues/1565, so we lower to |
| // regular multiplication |
| // * WGSL does not support NaN or infinities: |
| // https://www.w3.org/TR/WGSL/#floating-point-evaluation |
| passManager.nest<ModuleOp>().nest<spirv::ModuleOp>().addPass( |
| spirv::createSPIRVWebGPUPreparePass()); |
| } |
| |
| LogicalResult serializeExecutable(const SerializationOptions &serOptions, |
| IREE::HAL::ExecutableVariantOp variantOp, |
| OpBuilder &executableBuilder) final { |
| ModuleOp innerModuleOp = variantOp.getInnerModule(); |
| auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>(); |
| if (!llvm::hasSingleElement(spirvModuleOps)) { |
| // TODO(#7824): Implement linking / shader module combining and relax this |
| return variantOp.emitError() |
| << "should only contain exactly one spirv.module op"; |
| } |
| |
| auto spvModuleOp = *spirvModuleOps.begin(); |
| if (!serOptions.dumpIntermediatesPath.empty()) { |
| std::string assembly; |
| llvm::raw_string_ostream os(assembly); |
| spvModuleOp.print(os, OpPrintingFlags().useLocalScope()); |
| dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName, |
| variantOp.getName(), ".mlir", assembly); |
| } |
| |
| // The schema expects each shader module to have entry points named "dN", |
| // where N is the entry point ordinal. For each executable entry point op, |
| // rename the entry point symbol using that convention. |
| auto exportOps = llvm::to_vector(variantOp.getExportOps()); |
| SymbolTableCollection symbolTable; |
| SymbolUserMap symbolUsers(symbolTable, variantOp); |
| for (auto exportOp : exportOps) { |
| auto ordinalAttr = exportOp.getOrdinalAttr(); |
| if (!ordinalAttr) { |
| return mlir::emitError(exportOp.getLoc()) |
| << "could not compile WebGPU binary: export op is missing " |
| "ordinal"; |
| } |
| int64_t ordinal = ordinalAttr.getInt(); |
| auto entryPointFunc = dyn_cast<spirv::FuncOp>( |
| SymbolTable::lookupSymbolIn(spvModuleOp, exportOp.getSymName())); |
| |
| std::string symbolName = llvm::formatv("d{}", ordinal); |
| mlir::StringAttr nameAttr = |
| mlir::StringAttr::get(variantOp->getContext(), symbolName); |
| |
| symbolUsers.replaceAllUsesWith(entryPointFunc, nameAttr); |
| exportOp.setName(symbolName); // Same symbol reference? Not in table? |
| SymbolTable::setSymbolName(entryPointFunc, symbolName); |
| } |
| |
| // Serialize the spirv::ModuleOp into binary format. |
| SmallVector<uint32_t, 0> spvBinary; |
| spirv::SerializationOptions spirvSerializationOptions; |
| spirvSerializationOptions.emitSymbolName = options.debugSymbols; |
| spirvSerializationOptions.emitDebugInfo = options.debugSymbols; |
| if (failed(spirv::serialize(spvModuleOp, spvBinary, |
| spirvSerializationOptions)) || |
| spvBinary.empty()) { |
| return variantOp.emitError() << "failed to serialize spirv.module"; |
| } |
| if (!serOptions.dumpIntermediatesPath.empty()) { |
| dumpDataToPath<uint32_t>(serOptions.dumpIntermediatesPath, |
| serOptions.dumpBaseName, variantOp.getName(), |
| ".spv", spvBinary); |
| } |
| |
| // Compile SPIR-V to WGSL source code. |
| auto wgsl = compileSPIRVToWGSL(spvBinary); |
| if (!wgsl.has_value()) { |
| return variantOp.emitError() |
| << "failed to compile SPIR-V to WGSL; see the preceding Tint " |
| "diagnostics. Use -iree-hal-dump-executable-intermediates to " |
| "capture the serialized SPIR-V module."; |
| } |
| if (!serOptions.dumpBinariesPath.empty()) { |
| dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName, |
| variantOp.getName(), ".wgsl", wgsl.value()); |
| } |
| |
| // Pack the WGSL and metadata into a FlatBuffer. |
| FlatbufferBuilder builder; |
| iree_hal_webgpu_ExecutableDef_start_as_root(builder); |
| |
| // Attach embedded source file contents. |
| auto sourceFilesRef = createSourceFilesVec( |
| serOptions.debugLevel, variantOp.getSourcesAttr(), builder); |
| |
| iree_hal_webgpu_ShaderModuleDef_start(builder); |
| auto wgslRef = builder.createString(wgsl.value()); |
| iree_hal_webgpu_ShaderModuleDef_wgsl_source_add(builder, wgslRef); |
| // TODO(scotttodd): populate source map |
| auto shaderModuleRef = iree_hal_webgpu_ShaderModuleDef_end(builder); |
| |
| auto shaderModulesVec = iree_hal_webgpu_ShaderModuleDef_vec_create( |
| builder, &shaderModuleRef, /*len=*/1); |
| iree_hal_webgpu_ExecutableDef_shader_modules_add(builder, shaderModulesVec); |
| |
| // Generate optional per-export debug information. |
| // May be empty if no debug information was requested. |
| auto exportDebugInfos = |
| createExportDefs(serOptions.debugLevel, exportOps, builder); |
| |
| SmallVector<iree_hal_webgpu_ExportDef_ref_t> exportRefs; |
| exportRefs.resize(exportOps.size(), 0); |
| for (auto exportOp : exportOps) { |
| auto ordinalAttr = exportOp.getOrdinalAttr(); |
| if (!ordinalAttr) { |
| return mlir::emitError(exportOp.getLoc()) |
| << "could not compile WebGPU binary: export op is missing " |
| "ordinal"; |
| } |
| int64_t ordinal = ordinalAttr.getInt(); |
| |
| auto entryPointRef = builder.createString(exportOp.getName()); |
| |
| iree_hal_webgpu_WorkgroupSize_t workgroupSize = {0}; |
| if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { |
| auto workgroupSizeDims = workgroupSizeAttr->getValue(); |
| workgroupSize.x = cast<IntegerAttr>(workgroupSizeDims[0]).getInt(); |
| workgroupSize.y = cast<IntegerAttr>(workgroupSizeDims[1]).getInt(); |
| workgroupSize.z = cast<IntegerAttr>(workgroupSizeDims[2]).getInt(); |
| } |
| |
| auto layoutAttr = exportOp.getLayoutAttr(); |
| uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants()); |
| SmallVector<iree_hal_webgpu_BindingBits_enum_t> bindingFlags; |
| for (auto bindingAttr : layoutAttr.getBindings()) { |
| iree_hal_webgpu_BindingBits_enum_t flags = 0; |
| if (allEnumBitsSet(bindingAttr.getFlags(), |
| IREE::HAL::DescriptorFlags::ReadOnly)) { |
| flags |= iree_hal_webgpu_BindingBits_READ_ONLY; |
| } |
| if (allEnumBitsSet(bindingAttr.getFlags(), |
| IREE::HAL::DescriptorFlags::Indirect)) { |
| flags |= iree_hal_webgpu_BindingBits_INDIRECT; |
| } |
| bindingFlags.push_back(flags); |
| } |
| auto bindingFlagsRef = iree_hal_webgpu_BindingBits_vec_create( |
| builder, bindingFlags.data(), bindingFlags.size()); |
| |
| iree_hal_webgpu_ExportDef_start(builder); |
| // We only have one shader module right now, so all point to index 0. |
| // TODO(#7824): Support multiple shader modules per executable. |
| iree_hal_webgpu_ExportDef_shader_module_ordinal_add(builder, 0); |
| iree_hal_webgpu_ExportDef_entry_point_add(builder, entryPointRef); |
| iree_hal_webgpu_ExportDef_workgroup_size_add(builder, &workgroupSize); |
| iree_hal_webgpu_ExportDef_constant_count_add(builder, constantCount); |
| iree_hal_webgpu_ExportDef_binding_flags_add(builder, bindingFlagsRef); |
| iree_hal_webgpu_ExportDef_debug_info_add(builder, |
| exportDebugInfos[ordinal]); |
| exportRefs[ordinal] = iree_hal_webgpu_ExportDef_end(builder); |
| } |
| auto exportsRef = builder.createOffsetVecDestructive(exportRefs); |
| |
| iree_hal_webgpu_ExecutableDef_exports_add(builder, exportsRef); |
| iree_hal_webgpu_ExecutableDef_source_files_add(builder, sourceFilesRef); |
| |
| iree_hal_webgpu_ExecutableDef_end_as_root(builder); |
| |
| // Add the binary data to the target executable. |
| auto binaryOp = IREE::HAL::ExecutableBinaryOp::create( |
| executableBuilder, variantOp.getLoc(), variantOp.getSymName(), |
| variantOp.getTarget().getFormat(), |
| builder.getHeaderPrefixedBufferAttr( |
| executableBuilder.getContext(), |
| /*magic=*/iree_hal_webgpu_ExecutableDef_file_identifier, |
| /*version=*/0)); |
| binaryOp.setMimeTypeAttr( |
| executableBuilder.getStringAttr("application/x-flatbuffers")); |
| |
| return success(); |
| } |
| |
| private: |
| const WebGPUSPIRVOptions &options; |
| }; |
| |
| struct WebGPUSPIRVSession final |
| : PluginSession<WebGPUSPIRVSession, WebGPUSPIRVOptions, |
| PluginActivationPolicy::DefaultActivated> { |
| void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) final { |
| // #hal.device.target<"webgpu", ... |
| targets.add("webgpu", [this]() { |
| return std::make_shared<WebGPUTargetDevice>(options); |
| }); |
| } |
| void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) final { |
| // #hal.executable.target<"webgpu-spirv", ... |
| targets.add("webgpu-spirv", [this]() { |
| return std::make_shared<WebGPUSPIRVTargetBackend>(options); |
| }); |
| } |
| }; |
| |
| } // namespace |
| } // namespace mlir::iree_compiler::IREE::HAL |
| |
// Instantiates IREE's compiler-option flag support for WebGPUSPIRVOptions.
// NOTE(review): presumably this is what hooks WebGPUSPIRVOptions::bindOptions
// into command-line flag registration; confirm against the macro definition
// in the PluginAPI headers.
IREE_DEFINE_COMPILER_OPTION_FLAGS(
    mlir::iree_compiler::IREE::HAL::WebGPUSPIRVOptions);
| |
| extern "C" bool iree_register_compiler_plugin_hal_target_webgpu_spirv( |
| mlir::iree_compiler::PluginRegistrar *registrar) { |
| registrar->registerPlugin<mlir::iree_compiler::IREE::HAL::WebGPUSPIRVSession>( |
| "hal_target_webgpu_spirv"); |
| return true; |
| } |