[SPIRV] Add ability to specify transform dialect codegen spec file (#13267)
Adds an `--iree-spirv-use-transform-dialect=/path/to/spec.mlir` flag so a
custom transform dialect spec file can be supplied for testing,
mirroring the flag already available for the CUDA backend.
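As a usage sketch (file names are placeholders), the new lit test exercises the flag through `iree-opt` along these lines:

```shell
# Run the SPIR-V lowering pipeline on a dispatch, driving codegen with a
# user-provided transform dialect spec instead of the default strategy.
# "input_dispatch.mlir" and "/path/to/spec.mlir" are illustrative paths.
iree-opt input_dispatch.mlir \
  --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))" \
  --iree-spirv-use-transform-dialect=/path/to/spec.mlir
```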
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index 849d351..9e2ea6f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -106,6 +106,7 @@
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TosaToArith",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorInterfaces",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 2331597..7095601 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -93,6 +93,7 @@
MLIRTensorTransforms
MLIRTosaDialect
MLIRTosaToArith
+ MLIRTransformDialect
MLIRTransforms
MLIRVectorDialect
MLIRVectorInterfaces
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 2f332bf..7cd949d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Analysis/SliceAnalysis.h"
@@ -42,6 +43,12 @@
namespace mlir {
namespace iree_compiler {
+llvm::cl::opt<std::string> clSPIRVTransformDialectFileName(
+ "iree-spirv-use-transform-dialect",
+ llvm::cl::desc(
+ "MLIR file containing a transform dialect specification to apply"),
+ llvm::cl::init(""));
+
using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
//===----------------------------------------------------------------------===//
@@ -1400,6 +1407,15 @@
return setUserConfig(entryPointFn, rootOp, compilationInfo);
}
+ if (!clSPIRVTransformDialectFileName.empty()) {
+ MLIRContext *context = entryPointFn.getContext();
+ auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+ context, CodeGenPipeline::TransformDialectCodegen);
+ LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n");
+
+ return setTranslationInfo(entryPointFn, translationInfo);
+ }
+
// First try to find a proper CodeGen configuration to tile and vectorize for
// the current target architecture.
switch (targetEnv.getVendorID()) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 9dad76c..ddcf653 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -246,10 +246,13 @@
spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
}
+extern llvm::cl::opt<std::string> clSPIRVTransformDialectFileName;
+
void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
// Give control to the transform dialect.
passManager.addPass(
- mlir::iree_compiler::createTransformDialectInterpreterPass());
+ mlir::iree_compiler::createTransformDialectInterpreterPass(
+ clSPIRVTransformDialectFileName));
// Dropping the schedule is needed:
// 1. if we want to embed the transform in the module: we should drop the
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 6d2649a..f7d8a47 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -47,7 +48,8 @@
gpu::GPUDialect, IREE::HAL::HALDialect, linalg::LinalgDialect,
IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
bufferization::BufferizationDialect, scf::SCFDialect,
- spirv::SPIRVDialect, vector::VectorDialect>();
+ spirv::SPIRVDialect, transform::TransformDialect,
+ vector::VectorDialect>();
}
void runOnOperation() override;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index ff2a1a0..ebe52b6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -48,6 +48,7 @@
"pipeline_matmul_promotion.mlir",
"pipeline_matmul_vectorization.mlir",
"pipeline_reduction_subgroup.mlir",
+ "set_transform_strategy.mlir",
"tile_and_distribute.mlir",
"tile_and_distribute_scatter.mlir",
"tile_and_distribute_sort.mlir",
@@ -67,8 +68,14 @@
"vectorize_reduction.mlir",
],
include = ["*.mlir"],
+ exclude = [
+ "transform_dialect_dummy_spec.mlir",
+ ],
),
cfg = "//compiler:lit.cfg.py",
+ data = [
+ "transform_dialect_dummy_spec.mlir",
+ ],
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 02dfd88..20d6e8f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -44,6 +44,7 @@
"pipeline_matmul_promotion.mlir"
"pipeline_matmul_vectorization.mlir"
"pipeline_reduction_subgroup.mlir"
+ "set_transform_strategy.mlir"
"tile_and_distribute.mlir"
"tile_and_distribute_scatter.mlir"
"tile_and_distribute_sort.mlir"
@@ -64,6 +65,8 @@
TOOLS
FileCheck
iree-opt
+ DATA
+ transform_dialect_dummy_spec.mlir
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
new file mode 100644
index 0000000..5c2fb35
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
@@ -0,0 +1,47 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))" --iree-spirv-use-transform-dialect=%p/transform_dialect_dummy_spec.mlir | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable private @copy_f32 {
+ hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
+ max_compute_shared_memory_size = 32768,
+ max_compute_workgroup_invocations = 512,
+ max_compute_workgroup_size = [512, 512, 512],
+ subgroup_size = 16>>
+ }> {
+ hal.executable.export public @copy_f32 ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ // CHECK: IR printer:
+ func.func @copy_f32() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2xf32>> -> tensor<2x2xf32>
+ %3 = tensor.empty() : tensor<2x2xf32>
+ %4 = linalg.generic {
+ indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%2 : tensor<2x2xf32>) outs(%3 : tensor<2x2xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %5 = math.sqrt %arg0 : f32
+ linalg.yield %5 : f32
+ } -> tensor<2x2xf32>
+ flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
+ return
+ }
+ }
+ // CHECK-COUNT-2: vector.transfer_read
+ // CHECK-COUNT-2: math.sqrt
+ // CHECK-COUNT-2: vector.transfer_write
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
new file mode 100644
index 0000000..6bdcadc
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
@@ -0,0 +1,6 @@
+// RUN: iree-opt %s
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+ print %arg0 : !pdl.operation
+}