Add compiler target for WebGPU, using Tint for SPIR-V -> WGSL. (#7906)

Fixes https://github.com/google/iree/issues/7840

This can produce valid WGSL packed into a well formed IREE .vmfb file for simple programs like [simple_abs.mlir](https://github.com/google/iree/blob/main/iree/samples/models/simple_abs.mlir). More complex programs workarounds for push constant use, buffer mapping (https://github.com/google/iree/pull/7900), and other compiler changes.
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
new file mode 100644
index 0000000..2cc4b61
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT "${IREE_TARGET_BACKEND_WEBGPU}")
+  return()
+endif()
+
+iree_cc_library(
+  NAME
+    WebGPU
+  HDRS
+    "SPIRVToWGSL.h"
+    "WebGPUTarget.h"
+  SRCS
+    "SPIRVToWGSL.cpp"
+    "WebGPUTarget.cpp"
+  DEPS
+    LLVMSupport
+    MLIRGPUOps
+    MLIRIR
+    MLIRSPIRV
+    MLIRSPIRVSerialization
+    SPIRV-Tools
+    iree::compiler::Codegen::Dialect::IREECodegenDialect
+    iree::compiler::Codegen::SPIRV
+    iree::compiler::Dialect::HAL::Target
+    iree::compiler::Utils
+    iree::schemas::wgsl_executable_def_c_fbs
+    libtint
+  PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp
new file mode 100644
index 0000000..18e86b9
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.cpp
@@ -0,0 +1,75 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
+
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "tint/tint.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+llvm::Optional<std::string> compileSPIRVToWGSL(
+    llvm::ArrayRef<uint32_t> spvBinary) {
+  // TODO(scotttodd): reroute to MLIR diagnostics?
+  auto diagPrinter = tint::diag::Printer::create(stderr, true);
+  tint::diag::Formatter diagFormatter;
+
+  // TODO(scotttodd): remove this copy (API for std::span or [uint8_t*, size]?)
+  std::vector<uint32_t> binaryVector(spvBinary.size());
+  std::memcpy(binaryVector.data(), spvBinary.data(),
+              spvBinary.size() * sizeof(uint32_t));
+
+  auto program =
+      std::make_unique<tint::Program>(tint::reader::spirv::Parse(binaryVector));
+  if (!program) {
+    llvm::errs() << "Tint failed to parse SPIR-V program\n";
+    return llvm::None;
+  }
+
+  if (program->Diagnostics().contains_errors()) {
+    llvm::errs() << "Tint reported " << program->Diagnostics().error_count()
+                 << " error(s) for a SPIR-V program, see diagnostics:\n";
+    diagFormatter.format(program->Diagnostics(), diagPrinter.get());
+    return llvm::None;
+  }
+
+  if (!program->IsValid()) {
+    llvm::errs() << "Tint parsed an invalid SPIR-V program\n";
+    return llvm::None;
+  }
+
+  // TODO(scotttodd): Refine this set of transforms
+  tint::transform::Manager transformManager;
+  tint::transform::DataMap transformInputs;
+  transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(0, 0);
+  transformManager.Add<tint::transform::FirstIndexOffset>();
+  transformManager.Add<tint::transform::FoldTrivialSingleUseLets>();
+
+  auto output = transformManager.Run(program.get(), std::move(transformInputs));
+  if (!output.program.IsValid()) {
+    llvm::errs() << "Tint transforms failed on the parsed SPIR-V program\n";
+    diagFormatter.format(output.program.Diagnostics(), diagPrinter.get());
+    return llvm::None;
+  }
+
+  tint::writer::wgsl::Options genOptions;
+  auto result = tint::writer::wgsl::Generate(&output.program, genOptions);
+  if (!result.success) {
+    llvm::errs() << "Tint failed to generate WGSL: " << result.error << "\n";
+    return llvm::None;
+  }
+
+  return result.wgsl;
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h
new file mode 100644
index 0000000..c712631
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h
@@ -0,0 +1,30 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
+
+#include <string>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Compiles SPIR-V into WebGPU Shading Language (WGSL) source code.
+// Returns llvm::None on failure.
+llvm::Optional<std::string> compileSPIRVToWGSL(
+    llvm::ArrayRef<uint32_t> spvBinary);
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_SPIRVTOWGSL_H_
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
new file mode 100644
index 0000000..66de197
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -0,0 +1,273 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"
+
+#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/wgsl_executable_def_builder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Target/SPIRV/Serialization.h"
+#include "spirv-tools/libspirv.hpp"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+WebGPUTargetOptions getWebGPUTargetOptionsFromFlags() {
+  static llvm::cl::opt<bool> clDebugSymbols(
+      "iree-webgpu-debug-symbols",
+      llvm::cl::desc(
+          "Include debug information like variable names in outputs"),
+      llvm::cl::init(false));
+
+  static llvm::cl::opt<bool> clWebGPUKeepShaderModules(
+      "iree-webgpu-keep-shader-modules",
+      llvm::cl::desc("Save shader modules to disk separately"),
+      llvm::cl::init(false));
+
+  WebGPUTargetOptions targetOptions;
+  targetOptions.keepShaderModules = clWebGPUKeepShaderModules;
+
+  return targetOptions;
+}
+
+// TODO(scotttodd): provide a proper target environment for WebGPU.
+static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {
+  // TODO(scotttodd): find list of SPIR-V extensions supported by WebGPU/WGSL
+  auto triple = spirv::VerCapExtAttr::get(
+      spirv::Version::V_1_0, {spirv::Capability::Shader},
+      {spirv::Extension::SPV_KHR_storage_buffer_storage_class}, context);
+  return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
+                                   spirv::DeviceType::Unknown,
+                                   spirv::TargetEnvAttr::kUnknownDeviceID,
+                                   spirv::getDefaultResourceLimits(context));
+}
+
+class WebGPUTargetBackend : public TargetBackend {
+ public:
+  WebGPUTargetBackend(WebGPUTargetOptions options)
+      : options_(std::move(options)) {}
+
+  // NOTE: we could vary this based on the options such as 'webgpu-v2'.
+  std::string name() const override { return "webgpu"; }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Codegen::IREECodegenDialect, spirv::SPIRVDialect,
+                    gpu::GPUDialect>();
+  }
+
+  IREE::HAL::DeviceTargetAttr getDefaultDeviceTarget(
+      MLIRContext *context) const override {
+    Builder b(context);
+    SmallVector<NamedAttribute> configItems;
+
+    configItems.emplace_back(b.getIdentifier("executable_targets"),
+                             getExecutableTargets(context));
+
+    auto configAttr = b.getDictionaryAttr(configItems);
+    return IREE::HAL::DeviceTargetAttr::get(
+        context, b.getStringAttr(deviceID()), configAttr);
+  }
+
+  void buildTranslationPassPipeline(OpPassManager &passManager) override {
+    buildSPIRVCodegenPassPipeline(passManager);
+    // TODO(scotttodd): additional passes for WebGPU/WGSL
+    //                  (here or during serialization?)
+  }
+
+  LogicalResult serializeExecutable(IREE::HAL::ExecutableVariantOp variantOp,
+                                    OpBuilder &executableBuilder) override {
+    ModuleOp innerModuleOp = variantOp.getInnerModule();
+    auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+    if (!llvm::hasSingleElement(spirvModuleOps)) {
+      // TODO(#7824): Implement linking / shader module combining and relax this
+      return variantOp.emitError()
+             << "should only contain exactly one spv.module op";
+    }
+    auto spvModuleOp = *spirvModuleOps.begin();
+
+    // The schema expects each shader module to have entry points named "dN",
+    // where N is the entry point ordinal.
+    // For each executable entry point op, rename the entry point symbol using
+    // that convention and keep track of the mapping between entry point
+    // ordinals to which shader module they reference.
+    auto entryPointOps = llvm::to_vector<4>(
+        variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>());
+    llvm::SmallVector<uint32_t, 4> entryPointOrdinals(entryPointOps.size());
+    SymbolTableCollection symbolTable;
+    SymbolUserMap symbolUsers(symbolTable, variantOp);
+    for (auto entryPointOp : entryPointOps) {
+      auto entryPointFunc = dyn_cast<spirv::FuncOp>(
+          SymbolTable::lookupSymbolIn(spvModuleOp, entryPointOp.sym_name()));
+
+      std::string symbolName = llvm::formatv("d{0}", entryPointOp.ordinal());
+      mlir::StringAttr nameAttr =
+          mlir::StringAttr::get(variantOp->getContext(), symbolName);
+
+      symbolUsers.replaceAllUsesWith(entryPointFunc, nameAttr);
+      entryPointOp.setName(symbolName);  // Same symbol reference? Not in table?
+      SymbolTable::setSymbolName(entryPointFunc, symbolName);
+
+      // We only have one shader module right now, so all point to index 0.
+      // TODO(#7824): Support multiple shader modules per executable
+      entryPointOrdinals[entryPointOp.ordinal().getZExtValue()] = 0;
+    }
+
+    // Serialize the spirv::ModuleOp into binary format.
+    SmallVector<uint32_t, 0> spvBinary;
+    if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
+      return variantOp.emitError() << "failed to serialize spv.module";
+    }
+    if (options_.keepShaderModules) {
+      saveShaderToTempFile(variantOp, "spv",
+                           reinterpret_cast<const char *>(spvBinary.data()),
+                           spvBinary.size_in_bytes());
+
+      // Disassemble the shader and save that too.
+      // Note: this should match what getWebGPUTargetEnv used.
+      // TODO(scotttodd): Query spirv env from the executable variant?
+      spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
+      std::string spvDisassembled;
+      if (spirvTools.Disassemble(
+              spvBinary.data(), spvBinary.size(), &spvDisassembled,
+              SPV_BINARY_TO_TEXT_OPTION_INDENT |
+                  SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)) {
+        saveShaderToTempFile(variantOp, "spvasm", spvDisassembled.data(),
+                             spvDisassembled.size());
+      } else {
+        llvm::errs() << "Failed to disassemble SPIR-V binary\n";
+      }
+    }
+
+    // Compile SPIR-V to WGSL source code.
+    auto wgsl = compileSPIRVToWGSL(spvBinary);
+    if (!wgsl.hasValue()) {
+      // TODO(scotttodd): restructure branching and write disassembled SPIR-V
+      //                  to stderr / an error diagnostic (don't want to
+      //                  disassemble if successful + option not set, also
+      //                  don't want to disassemble twice :P)
+      return variantOp.emitError()
+             << "failed to compile SPIR-V to WGSL. Consider inspecting the "
+                "shader program using -iree-webgpu-keep-shader-modules";
+    }
+    if (options_.keepShaderModules) {
+      saveShaderToTempFile(variantOp, "wgsl", wgsl.getValue().data(),
+                           wgsl.getValue().length());
+    }
+
+    // Pack the WGSL and metadata into a flatbuffer.
+    FlatbufferBuilder builder;
+    iree_WGSLExecutableDef_start_as_root(builder);
+
+    iree_WGSLShaderModuleDef_start(builder);
+    auto wgslRef = builder.createString(wgsl.getValue());
+    iree_WGSLShaderModuleDef_code_add(builder, wgslRef);
+    // TODO(scotttodd): populate source map
+    auto shaderModuleRef = iree_WGSLShaderModuleDef_end(builder);
+
+    auto shaderModulesVec = iree_WGSLShaderModuleDef_vec_create(
+        builder, &shaderModuleRef, /*len=*/1);
+    iree_WGSLExecutableDef_shader_modules_add(builder, shaderModulesVec);
+
+    auto entryPointsRef = flatbuffers_uint32_vec_create(
+        builder, entryPointOrdinals.data(), entryPointOrdinals.size());
+    iree_WGSLExecutableDef_entry_points_add(builder, entryPointsRef);
+
+    iree_WGSLExecutableDef_end_as_root(builder);
+
+    // Add the binary data to the target executable.
+    auto binaryOp = executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
+        variantOp.getLoc(), variantOp.sym_name(),
+        variantOp.target().getFormat(),
+        builder.getBufferAttr(executableBuilder.getContext()));
+    binaryOp.mime_typeAttr(
+        executableBuilder.getStringAttr("application/x-flatbuffers"));
+
+    return success();
+  }
+
+ private:
+  ArrayAttr getExecutableTargets(MLIRContext *context) const {
+    SmallVector<Attribute> targetAttrs;
+    // If we had multiple target environments we would generate one target attr
+    // per environment, with each setting its own environment attribute.
+    targetAttrs.push_back(
+        getExecutableTarget(context, getWebGPUTargetEnv(context)));
+    return ArrayAttr::get(context, targetAttrs);
+  }
+
+  IREE::HAL::ExecutableTargetAttr getExecutableTarget(
+      MLIRContext *context, spirv::TargetEnvAttr targetEnv) const {
+    Builder b(context);
+    SmallVector<NamedAttribute> configItems;
+
+    configItems.emplace_back(b.getIdentifier(spirv::getTargetEnvAttrName()),
+                             targetEnv);
+
+    auto configAttr = b.getDictionaryAttr(configItems);
+    return IREE::HAL::ExecutableTargetAttr::get(
+        context, b.getStringAttr("webgpu"), b.getStringAttr("webgpu-wgsl-fb"),
+        configAttr);
+  }
+
+  void saveShaderToTempFile(IREE::HAL::ExecutableVariantOp variantOp,
+                            llvm::StringRef suffix, const char *data,
+                            size_t size) {
+    llvm::SmallString<32> filePath;
+    if (std::error_code error = llvm::sys::fs::createTemporaryFile(
+            variantOp.getName(), suffix, filePath)) {
+      llvm::errs() << "failed to generate temp file for shader: "
+                   << error.message();
+      return;
+    }
+    std::error_code error;
+    auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+                                                       llvm::sys::fs::OF_None);
+    if (error) {
+      llvm::errs() << "failed to open temp file for shader '" << filePath
+                   << "': " << error.message();
+      return;
+    }
+
+    // TODO(scotttodd): refactor to group these messages
+    mlir::emitRemark(variantOp.getLoc())
+        << "Shader file for " << variantOp.getName() << " preserved:\n"
+        << "    " << filePath;
+    file->os().write(data, size);
+    file->keep();
+  }
+
+  WebGPUTargetOptions options_;
+};
+
+void registerWebGPUTargetBackends(
+    std::function<WebGPUTargetOptions()> queryOptions) {
+  getWebGPUTargetOptionsFromFlags();
+  auto backendFactory = [=]() {
+    return std::make_shared<WebGPUTargetBackend>(queryOptions());
+  };
+  // #hal.device.target<"webgpu", ...
+  static TargetBackendRegistration registration0("webgpu", backendFactory);
+  // #hal.executable.target<"webgpu-wgsl", ...
+  static TargetBackendRegistration registration1("webgpu-wgsl", backendFactory);
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h
new file mode 100644
index 0000000..eb77188
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
+
+#include <functional>
+#include <string>
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Options controlling the WebGPU/WGSL translation.
+struct WebGPUTargetOptions {
+  // Include debug information like variable names in outputs.
+  bool debugSymbols = true;
+
+  // True to keep shader modules for debugging.
+  bool keepShaderModules = false;
+};
+
+// Returns a WebGPUTargetOptions struct initialized with WebGPU/WGSL related
+// command-line flags.
+WebGPUTargetOptions getWebGPUTargetOptionsFromFlags();
+
+// Registers the WebGPU/WGSL backends.
+void registerWebGPUTargetBackends(
+    std::function<WebGPUTargetOptions()> queryOptions);
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_HAL_TARGET_WEBGPU_WEBGPUTARGET_H_
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 3ca8abe..5c1fc8b 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -35,6 +35,10 @@
   list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VulkanSPIRV)
   list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VULKANSPIRV_TARGET")
 endif()
+if("${IREE_TARGET_BACKEND_WEBGPU}")
+  list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::WebGPU)
+  list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_WEBGPU_TARGET")
+endif()
 if("${IREE_TARGET_BACKEND_CUDA}")
   list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::CUDA)
   list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_CUDA_TARGET")
diff --git a/iree/tools/init_targets.cc b/iree/tools/init_targets.cc
index 74e7e1b..cc30903 100644
--- a/iree/tools/init_targets.cc
+++ b/iree/tools/init_targets.cc
@@ -26,6 +26,9 @@
 #ifdef IREE_HAVE_VULKANSPIRV_TARGET
 #include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
 #endif  // IREE_HAVE_VULKANSPIRV_TARGET
+#ifdef IREE_HAVE_WEBGPU_TARGET
+#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"
+#endif  // IREE_HAVE_WEBGPU_TARGET
 
 namespace mlir {
 namespace iree_compiler {
@@ -58,6 +61,10 @@
     IREE::HAL::registerVulkanSPIRVTargetBackends(
         []() { return IREE::HAL::getVulkanSPIRVTargetOptionsFromFlags(); });
 #endif  // IREE_HAVE_VULKANSPIRV_TARGET
+#ifdef IREE_HAVE_WEBGPU_TARGET
+    IREE::HAL::registerWebGPUTargetBackends(
+        []() { return IREE::HAL::getWebGPUTargetOptionsFromFlags(); });
+#endif  // IREE_HAVE_WEBGPU_TARGET
     return true;
   }();
   (void)init_once;