[spirv][webgpu] Expand extended multiplication ops (#13274)
These are not supported by WGSL and need to be expanded before
converting SPIR-V to WGSL.
Re-enable the `mul_shift` tosa_ops test to exercise this pass. The exact
lowering is tested in MLIR.
Fixes: https://github.com/openxla/iree/issues/11571
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
index 469c53c..ce6edc9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/CMakeLists.txt
@@ -25,6 +25,7 @@
MLIRIR
MLIRSPIRVDialect
MLIRSPIRVSerialization
+ MLIRSPIRVTransforms
SPIRV-Tools
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::SPIRV
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
index ae8a544..b8e96c5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "spirv-tools/libspirv.hpp"
@@ -108,6 +109,12 @@
// Therefore, just let the SPIR-V CodeGen to avoid generating guards w.r.t.
// NaN and infinity.
buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/true);
+
+ // WGSL does not support extended multiplication:
+ // https://github.com/gpuweb/gpuweb/issues/1565. Make sure to lower it to
+ // regular multiplication before we convert SPIR-V to WGSL.
+ passManager.nest<ModuleOp>().nest<spirv::ModuleOp>().addPass(
+ spirv::createSPIRVWebGPUPreparePass());
}
LogicalResult serializeExecutable(const SerializationOptions &options,
diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel
index 56f5a6f..7d2e684 100644
--- a/tests/e2e/tosa_ops/BUILD.bazel
+++ b/tests/e2e/tosa_ops/BUILD.bazel
@@ -218,6 +218,7 @@
"max_pool.mlir",
"maximum.mlir",
"minimum.mlir",
+ "mul_shift.mlir",
"mul.mlir",
"negate.mlir",
"pad.mlir",
@@ -236,7 +237,6 @@
include = ["*.mlir"],
exclude = [
"logical_right_shift_16.mlir", # TODO(#11828)
- "mul_shift.mlir", # TODO(#11571)
],
)
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 11340a5..9ce27b4 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -199,6 +199,7 @@
"maximum.mlir"
"minimum.mlir"
"mul.mlir"
+ "mul_shift.mlir"
"negate.mlir"
"pad.mlir"
"reciprocal.mlir"
@@ -252,7 +253,7 @@
"maximum.mlir"
"minimum.mlir"
"mul.mlir"
- # "mul_shift.mlir" # TODO(#11571): error: extended arithmetic is not finalized for WGSL
+ "mul_shift.mlir"
"negate.mlir"
"pad.mlir"
"reciprocal.mlir"