Add compiler target for WebGPU, using Tint for SPIR-V -> WGSL. (#7906)
Fixes https://github.com/google/iree/issues/7840
This can produce valid WGSL packed into a well formed IREE .vmfb file for simple programs like [simple_abs.mlir](https://github.com/google/iree/blob/main/iree/samples/models/simple_abs.mlir). More complex programs workarounds for push constant use, buffer mapping (https://github.com/google/iree/pull/7900), and other compiler changes.
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
new file mode 100644
index 0000000..2cc4b61
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT "${IREE_TARGET_BACKEND_WEBGPU}")
+ return()
+endif()
+
+iree_cc_library(
+ NAME
+ WebGPU
+ HDRS
+ "SPIRVToWGSL.h"
+ "WebGPUTarget.h"
+ SRCS
+ "SPIRVToWGSL.cpp"
+ "WebGPUTarget.cpp"
+ DEPS
+ LLVMSupport
+ MLIRGPUOps
+ MLIRIR
+ MLIRSPIRV
+ MLIRSPIRVSerialization
+ SPIRV-Tools
+ iree::compiler::Codegen::Dialect::IREECodegenDialect
+ iree::compiler::Codegen::SPIRV
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Utils
+ iree::schemas::wgsl_executable_def_c_fbs
+ libtint
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp
new file mode 100644
index 0000000..18e86b9
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp
@@ -0,0 +1,75 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
+
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "tint/tint.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+llvm::Optional<std::string> compileSPIRVToWGSL(
+ llvm::ArrayRef<uint32_t> spvBinary) {
+ // TODO(scotttodd): reroute to MLIR diagnostics?
+ auto diagPrinter = tint::diag::Printer::create(stderr, true);
+ tint::diag::Formatter diagFormatter;
+
+ // TODO(scotttodd): remove this copy (API for std::span or [uint8_t*, size]?)
+ std::vector<uint32_t> binaryVector(spvBinary.size());
+ std::memcpy(binaryVector.data(), spvBinary.data(),
+ spvBinary.size() * sizeof(uint32_t));
+
+ auto program =
+ std::make_unique<tint::Program>(tint::reader::spirv::Parse(binaryVector));
+ if (!program) {
+ llvm::errs() << "Tint failed to parse SPIR-V program\n";
+ return llvm::None;
+ }
+
+ if (program->Diagnostics().contains_errors()) {
+ llvm::errs() << "Tint reported " << program->Diagnostics().error_count()
+ << " error(s) for a SPIR-V program, see diagnostics:\n";
+ diagFormatter.format(program->Diagnostics(), diagPrinter.get());
+ return llvm::None;
+ }
+
+ if (!program->IsValid()) {
+ llvm::errs() << "Tint parsed an invalid SPIR-V program\n";
+ return llvm::None;
+ }
+
+ // TODO(scotttodd): Refine this set of transforms
+ tint::transform::Manager transformManager;
+ tint::transform::DataMap transformInputs;
+ transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(0, 0);
+ transformManager.Add<tint::transform::FirstIndexOffset>();
+ transformManager.Add<tint::transform::FoldTrivialSingleUseLets>();
+
+ auto output = transformManager.Run(program.get(), std::move(transformInputs));
+ if (!output.program.IsValid()) {
+ llvm::errs() << "Tint transforms failed on the parsed SPIR-V program\n";
+ diagFormatter.format(output.program.Diagnostics(), diagPrinter.get());
+ return llvm::None;
+ }
+
+ tint::writer::wgsl::Options genOptions;
+ auto result = tint::writer::wgsl::Generate(&output.program, genOptions);
+ if (!result.success) {
+ llvm::errs() << "Tint failed to generate WGSL: " << result.error << "\n";
+ return llvm::None;
+ }
+
+ return result.wgsl;
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h
new file mode 100644
index 0000000..c712631
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h
@@ -0,0 +1,30 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
+
+#include <string>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Compiles SPIR-V into WebGPU Shading Language (WGSL) source code.
+// Returns llvm::None on failure.
+llvm::Optional<std::string> compileSPIRVToWGSL(
+ llvm::ArrayRef<uint32_t> spvBinary);
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
new file mode 100644
index 0000000..66de197
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -0,0 +1,273 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"
+
+#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/wgsl_executable_def_builder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Target/SPIRV/Serialization.h"
+#include "spirv-tools/libspirv.hpp"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+WebGPUTargetOptions getWebGPUTargetOptionsFromFlags() {
+ static llvm::cl::opt<bool> clDebugSymbols(
+ "iree-webgpu-debug-symbols",
+ llvm::cl::desc(
+ "Include debug information like variable names in outputs"),
+ llvm::cl::init(false));
+
+ static llvm::cl::opt<bool> clWebGPUKeepShaderModules(
+ "iree-webgpu-keep-shader-modules",
+ llvm::cl::desc("Save shader modules to disk separately"),
+ llvm::cl::init(false));
+
+ WebGPUTargetOptions targetOptions;
+ targetOptions.keepShaderModules = clWebGPUKeepShaderModules;
+
+ return targetOptions;
+}
+
+// TODO(scotttodd): provide a proper target environment for WebGPU.
+static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {
+ // TODO(scotttodd): find list of SPIR-V extensions supported by WebGPU/WGSL
+ auto triple = spirv::VerCapExtAttr::get(
+ spirv::Version::V_1_0, {spirv::Capability::Shader},
+ {spirv::Extension::SPV_KHR_storage_buffer_storage_class}, context);
+ return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
+ spirv::DeviceType::Unknown,
+ spirv::TargetEnvAttr::kUnknownDeviceID,
+ spirv::getDefaultResourceLimits(context));
+}
+
+class WebGPUTargetBackend : public TargetBackend {
+ public:
+ WebGPUTargetBackend(WebGPUTargetOptions options)
+ : options_(std::move(options)) {}
+
+ // NOTE: we could vary this based on the options such as 'webgpu-v2'.
+ std::string name() const override { return "webgpu"; }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Codegen::IREECodegenDialect, spirv::SPIRVDialect,
+ gpu::GPUDialect>();
+ }
+
+ IREE::HAL::DeviceTargetAttr getDefaultDeviceTarget(
+ MLIRContext *context) const override {
+ Builder b(context);
+ SmallVector<NamedAttribute> configItems;
+
+ configItems.emplace_back(b.getIdentifier("executable_targets"),
+ getExecutableTargets(context));
+
+ auto configAttr = b.getDictionaryAttr(configItems);
+ return IREE::HAL::DeviceTargetAttr::get(
+ context, b.getStringAttr(deviceID()), configAttr);
+ }
+
+ void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ buildSPIRVCodegenPassPipeline(passManager);
+ // TODO(scotttodd): additional passes for WebGPU/WGSL
+ // (here or during serialization?)
+ }
+
+ LogicalResult serializeExecutable(IREE::HAL::ExecutableVariantOp variantOp,
+ OpBuilder &executableBuilder) override {
+ ModuleOp innerModuleOp = variantOp.getInnerModule();
+ auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+ if (!llvm::hasSingleElement(spirvModuleOps)) {
+ // TODO(#7824): Implement linking / shader module combining and relax this
+ return variantOp.emitError()
+ << "should only contain exactly one spv.module op";
+ }
+ auto spvModuleOp = *spirvModuleOps.begin();
+
+ // The schema expects each shader module to have entry points named "dN",
+ // where N is the entry point ordinal.
+ // For each executable entry point op, rename the entry point symbol using
+ // that convention and keep track of the mapping between entry point
+ // ordinals to which shader module they reference.
+ auto entryPointOps = llvm::to_vector<4>(
+ variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>());
+ llvm::SmallVector<uint32_t, 4> entryPointOrdinals(entryPointOps.size());
+ SymbolTableCollection symbolTable;
+ SymbolUserMap symbolUsers(symbolTable, variantOp);
+ for (auto entryPointOp : entryPointOps) {
+ auto entryPointFunc = dyn_cast<spirv::FuncOp>(
+ SymbolTable::lookupSymbolIn(spvModuleOp, entryPointOp.sym_name()));
+
+ std::string symbolName = llvm::formatv("d{0}", entryPointOp.ordinal());
+ mlir::StringAttr nameAttr =
+ mlir::StringAttr::get(variantOp->getContext(), symbolName);
+
+ symbolUsers.replaceAllUsesWith(entryPointFunc, nameAttr);
+ entryPointOp.setName(symbolName); // Same symbol reference? Not in table?
+ SymbolTable::setSymbolName(entryPointFunc, symbolName);
+
+ // We only have one shader module right now, so all point to index 0.
+ // TODO(#7824): Support multiple shader modules per executable
+ entryPointOrdinals[entryPointOp.ordinal().getZExtValue()] = 0;
+ }
+
+ // Serialize the spirv::ModuleOp into binary format.
+ SmallVector<uint32_t, 0> spvBinary;
+ if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
+ return variantOp.emitError() << "failed to serialize spv.module";
+ }
+ if (options_.keepShaderModules) {
+ saveShaderToTempFile(variantOp, "spv",
+ reinterpret_cast<const char *>(spvBinary.data()),
+ spvBinary.size_in_bytes());
+
+ // Disassemble the shader and save that too.
+ // Note: this should match what getWebGPUTargetEnv used.
+ // TODO(scotttodd): Query spirv env from the executable variant?
+ spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
+ std::string spvDisassembled;
+ if (spirvTools.Disassemble(
+ spvBinary.data(), spvBinary.size(), &spvDisassembled,
+ SPV_BINARY_TO_TEXT_OPTION_INDENT |
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)) {
+ saveShaderToTempFile(variantOp, "spvasm", spvDisassembled.data(),
+ spvDisassembled.size());
+ } else {
+ llvm::errs() << "Failed to disassemble SPIR-V binary\n";
+ }
+ }
+
+ // Compile SPIR-V to WGSL source code.
+ auto wgsl = compileSPIRVToWGSL(spvBinary);
+ if (!wgsl.hasValue()) {
+ // TODO(scotttodd): restructure branching and write disassembled SPIR-V
+ // to stderr / an error diagnostic (don't want to
+ // disassemble if successful + option not set, also
+ // don't want to disassemble twice :P)
+ return variantOp.emitError()
+ << "failed to compile SPIR-V to WGSL. Consider inspecting the "
+ "shader program using -iree-webgpu-keep-shader-modules";
+ }
+ if (options_.keepShaderModules) {
+ saveShaderToTempFile(variantOp, "wgsl", wgsl.getValue().data(),
+ wgsl.getValue().length());
+ }
+
+ // Pack the WGSL and metadata into a flatbuffer.
+ FlatbufferBuilder builder;
+ iree_WGSLExecutableDef_start_as_root(builder);
+
+ iree_WGSLShaderModuleDef_start(builder);
+ auto wgslRef = builder.createString(wgsl.getValue());
+ iree_WGSLShaderModuleDef_code_add(builder, wgslRef);
+ // TODO(scotttodd): populate source map
+ auto shaderModuleRef = iree_WGSLShaderModuleDef_end(builder);
+
+ auto shaderModulesVec = iree_WGSLShaderModuleDef_vec_create(
+ builder, &shaderModuleRef, /*len=*/1);
+ iree_WGSLExecutableDef_shader_modules_add(builder, shaderModulesVec);
+
+ auto entryPointsRef = flatbuffers_uint32_vec_create(
+ builder, entryPointOrdinals.data(), entryPointOrdinals.size());
+ iree_WGSLExecutableDef_entry_points_add(builder, entryPointsRef);
+
+ iree_WGSLExecutableDef_end_as_root(builder);
+
+ // Add the binary data to the target executable.
+ auto binaryOp = executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
+ variantOp.getLoc(), variantOp.sym_name(),
+ variantOp.target().getFormat(),
+ builder.getBufferAttr(executableBuilder.getContext()));
+ binaryOp.mime_typeAttr(
+ executableBuilder.getStringAttr("application/x-flatbuffers"));
+
+ return success();
+ }
+
+ private:
+ ArrayAttr getExecutableTargets(MLIRContext *context) const {
+ SmallVector<Attribute> targetAttrs;
+ // If we had multiple target environments we would generate one target attr
+ // per environment, with each setting its own environment attribute.
+ targetAttrs.push_back(
+ getExecutableTarget(context, getWebGPUTargetEnv(context)));
+ return ArrayAttr::get(context, targetAttrs);
+ }
+
+ IREE::HAL::ExecutableTargetAttr getExecutableTarget(
+ MLIRContext *context, spirv::TargetEnvAttr targetEnv) const {
+ Builder b(context);
+ SmallVector<NamedAttribute> configItems;
+
+ configItems.emplace_back(b.getIdentifier(spirv::getTargetEnvAttrName()),
+ targetEnv);
+
+ auto configAttr = b.getDictionaryAttr(configItems);
+ return IREE::HAL::ExecutableTargetAttr::get(
+ context, b.getStringAttr("webgpu"), b.getStringAttr("webgpu-wgsl-fb"),
+ configAttr);
+ }
+
+ void saveShaderToTempFile(IREE::HAL::ExecutableVariantOp variantOp,
+ llvm::StringRef suffix, const char *data,
+ size_t size) {
+ llvm::SmallString<32> filePath;
+ if (std::error_code error = llvm::sys::fs::createTemporaryFile(
+ variantOp.getName(), suffix, filePath)) {
+ llvm::errs() << "failed to generate temp file for shader: "
+ << error.message();
+ return;
+ }
+ std::error_code error;
+ auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+ llvm::sys::fs::OF_None);
+ if (error) {
+ llvm::errs() << "failed to open temp file for shader '" << filePath
+ << "': " << error.message();
+ return;
+ }
+
+ // TODO(scotttodd): refactor to group these messages
+ mlir::emitRemark(variantOp.getLoc())
+ << "Shader file for " << variantOp.getName() << " preserved:\n"
+ << " " << filePath;
+ file->os().write(data, size);
+ file->keep();
+ }
+
+ WebGPUTargetOptions options_;
+};
+
+void registerWebGPUTargetBackends(
+ std::function<WebGPUTargetOptions()> queryOptions) {
+ getWebGPUTargetOptionsFromFlags();
+ auto backendFactory = [=]() {
+ return std::make_shared<WebGPUTargetBackend>(queryOptions());
+ };
+ // #hal.device.target<"webgpu", ...
+ static TargetBackendRegistration registration0("webgpu", backendFactory);
+ // #hal.executable.target<"webgpu-wgsl", ...
+ static TargetBackendRegistration registration1("webgpu-wgsl", backendFactory);
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h
new file mode 100644
index 0000000..eb77188
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
+
+#include <functional>
+#include <string>
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Options controlling the WebGPU/WGSL translation.
+struct WebGPUTargetOptions {
+ // Include debug information like variable names in outputs.
+ bool debugSymbols = true;
+
+ // True to keep shader modules for debugging.
+ bool keepShaderModules = false;
+};
+
+// Returns a WebGPUTargetOptions struct initialized with WebGPU/WGSL related
+// command-line flags.
+WebGPUTargetOptions getWebGPUTargetOptionsFromFlags();
+
+// Registers the WebGPU/WGSL backends.
+void registerWebGPUTargetBackends(
+ std::function<WebGPUTargetOptions()> queryOptions);
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 3ca8abe..5c1fc8b 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -35,6 +35,10 @@
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VulkanSPIRV)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VULKANSPIRV_TARGET")
endif()
+if("${IREE_TARGET_BACKEND_WEBGPU}")
+ list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::WebGPU)
+ list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_WEBGPU_TARGET")
+endif()
if("${IREE_TARGET_BACKEND_CUDA}")
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::CUDA)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_CUDA_TARGET")
diff --git a/iree/tools/init_targets.cc b/iree/tools/init_targets.cc
index 74e7e1b..cc30903 100644
--- a/iree/tools/init_targets.cc
+++ b/iree/tools/init_targets.cc
@@ -26,6 +26,9 @@
#ifdef IREE_HAVE_VULKANSPIRV_TARGET
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
#endif // IREE_HAVE_VULKANSPIRV_TARGET
+#ifdef IREE_HAVE_WEBGPU_TARGET
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"
+#endif // IREE_HAVE_WEBGPU_TARGET
namespace mlir {
namespace iree_compiler {
@@ -58,6 +61,10 @@
IREE::HAL::registerVulkanSPIRVTargetBackends(
[]() { return IREE::HAL::getVulkanSPIRVTargetOptionsFromFlags(); });
#endif // IREE_HAVE_VULKANSPIRV_TARGET
+#ifdef IREE_HAVE_WEBGPU_TARGET
+ IREE::HAL::registerWebGPUTargetBackends(
+ []() { return IREE::HAL::getWebGPUTargetOptionsFromFlags(); });
+#endif // IREE_HAVE_WEBGPU_TARGET
return true;
}();
(void)init_once;