[SPIRV] Add ability to specify transform dialect codegen spec file (#13267)

Adds `--iree-spirv-use-transform-dialect=/path/to/spec.mlir` so a custom
transform dialect spec file can be supplied when testing SPIR-V codegen,
mirroring the flag already available for the CUDA backend.
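As a usage sketch (the pass pipeline string matches the RUN line in the new `set_transform_strategy.mlir` test; `input.mlir` is just a placeholder for a dispatch to lower):

```sh
# Lower a hal.executable through the SPIR-V backend, letting the strategy in
# spec.mlir drive codegen instead of the default kernel configuration.
iree-opt input.mlir \
  --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))" \
  --iree-spirv-use-transform-dialect=/path/to/spec.mlir
```

When the flag is non-empty, the kernel configuration in `KernelConfig.cpp` switches the translation info to the `TransformDialectCodegen` pipeline, and `addSPIRVTransformDialectPasses` forwards the file name to the transform dialect interpreter pass, so the `transform.sequence` in the spec (such as the dummy spec added here) runs over the dispatch.
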
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index 849d351..9e2ea6f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -106,6 +106,7 @@
         "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:TosaDialect",
         "@llvm-project//mlir:TosaToArith",
+        "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorDialect",
         "@llvm-project//mlir:VectorInterfaces",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 2331597..7095601 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -93,6 +93,7 @@
     MLIRTensorTransforms
     MLIRTosaDialect
     MLIRTosaToArith
+    MLIRTransformDialect
     MLIRTransforms
     MLIRVectorDialect
     MLIRVectorInterfaces
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 2f332bf..7cd949d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -17,6 +17,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include "mlir/Analysis/SliceAnalysis.h"
@@ -42,6 +43,12 @@
 namespace mlir {
 namespace iree_compiler {
 
+llvm::cl::opt<std::string> clSPIRVTransformDialectFileName(
+    "iree-spirv-use-transform-dialect",
+    llvm::cl::desc(
+        "MLIR file containing a transform dialect specification to apply"),
+    llvm::cl::init(""));
+
 using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
 
 //===----------------------------------------------------------------------===//
@@ -1400,6 +1407,15 @@
     return setUserConfig(entryPointFn, rootOp, compilationInfo);
   }
 
+  if (!clSPIRVTransformDialectFileName.empty()) {
+    MLIRContext *context = entryPointFn.getContext();
+    auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+        context, CodeGenPipeline::TransformDialectCodegen);
+    LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n");
+
+    return setTranslationInfo(entryPointFn, translationInfo);
+  }
+
   // First try to find a proper CodeGen configuration to tile and vectorize for
   // the current target architecture.
   switch (targetEnv.getVendorID()) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 9dad76c..ddcf653 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -246,10 +246,13 @@
   spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
 }
 
+extern llvm::cl::opt<std::string> clSPIRVTransformDialectFileName;
+
 void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
   // Give control to the transform dialect.
   passManager.addPass(
-      mlir::iree_compiler::createTransformDialectInterpreterPass());
+      mlir::iree_compiler::createTransformDialectInterpreterPass(
+          clSPIRVTransformDialectFileName));
 
   // Dropping the schedule is needed:
   //   1. if we want to embed the transform in the module: we should drop the
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 6d2649a..f7d8a47 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -47,7 +48,8 @@
                 gpu::GPUDialect, IREE::HAL::HALDialect, linalg::LinalgDialect,
                 IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
                 bufferization::BufferizationDialect, scf::SCFDialect,
-                spirv::SPIRVDialect, vector::VectorDialect>();
+                spirv::SPIRVDialect, transform::TransformDialect,
+                vector::VectorDialect>();
   }
 
   void runOnOperation() override;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index ff2a1a0..ebe52b6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -48,6 +48,7 @@
             "pipeline_matmul_promotion.mlir",
             "pipeline_matmul_vectorization.mlir",
             "pipeline_reduction_subgroup.mlir",
+            "set_transform_strategy.mlir",
             "tile_and_distribute.mlir",
             "tile_and_distribute_scatter.mlir",
             "tile_and_distribute_sort.mlir",
@@ -67,8 +68,14 @@
             "vectorize_reduction.mlir",
         ],
         include = ["*.mlir"],
+        exclude = [
+            "transform_dialect_dummy_spec.mlir",
+        ],
     ),
     cfg = "//compiler:lit.cfg.py",
+    data = [
+        "transform_dialect_dummy_spec.mlir",
+    ],
     tools = [
         "//tools:iree-opt",
         "@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 02dfd88..20d6e8f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -44,6 +44,7 @@
     "pipeline_matmul_promotion.mlir"
     "pipeline_matmul_vectorization.mlir"
     "pipeline_reduction_subgroup.mlir"
+    "set_transform_strategy.mlir"
     "tile_and_distribute.mlir"
     "tile_and_distribute_scatter.mlir"
     "tile_and_distribute_sort.mlir"
@@ -64,6 +65,8 @@
   TOOLS
     FileCheck
     iree-opt
+  DATA
+    transform_dialect_dummy_spec.mlir
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
new file mode 100644
index 0000000..5c2fb35
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
@@ -0,0 +1,47 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))" --iree-spirv-use-transform-dialect=%p/transform_dialect_dummy_spec.mlir | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @copy_f32 {
+  hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {
+      spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
+        max_compute_shared_memory_size = 32768,
+        max_compute_workgroup_invocations = 512,
+        max_compute_workgroup_size = [512, 512, 512],
+        subgroup_size = 16>>
+    }> {
+    hal.executable.export public @copy_f32 ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      // CHECK: IR printer:
+      func.func @copy_f32() {
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x2xf32>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2xf32>> -> tensor<2x2xf32>
+        %3 = tensor.empty() : tensor<2x2xf32>
+        %4 = linalg.generic {
+            indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+            ins(%2 : tensor<2x2xf32>) outs(%3 : tensor<2x2xf32>) {
+          ^bb0(%arg0: f32, %arg1: f32):
+            %5 = math.sqrt %arg0 : f32
+            linalg.yield %5 : f32
+          } -> tensor<2x2xf32>
+        flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x2xf32>>
+        return
+      }
+    }
+    // CHECK-COUNT-2: vector.transfer_read
+    // CHECK-COUNT-2: math.sqrt
+    // CHECK-COUNT-2: vector.transfer_write
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
new file mode 100644
index 0000000..6bdcadc
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/transform_dialect_dummy_spec.mlir
@@ -0,0 +1,6 @@
+// RUN: iree-opt %s
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  print %arg0 : !pdl.operation
+}