Hook up the transform dialect interpreter to CUDA and activate the e2e test for CUDA too (#9529)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 3ef85c5..1597fd0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -70,7 +70,7 @@
     llvm::cl::init(false));
 
 llvm::cl::opt<std::string> clCPUCodegenTransformDialectFileName(
-    "iree-codegen-use-transform-dialect",
+    "iree-codegen-llvmcpu-use-transform-dialect",
     llvm::cl::desc(
         "MLIR file containing a transform dialect specification to apply"),
     llvm::cl::init(""));
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
index 0ea9dea..a454677 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s  --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-transform-dialect=%p/linalg_transform_spec.mlir | FileCheck %s
+// RUN: iree-opt %s  --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-llvmcpu-use-transform-dialect=%p/linalg_transform_spec.mlir | FileCheck %s
 
 #device_target_cpu = #hal.device.target<"cpu", {executable_targets = [#hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
index abb6635..afc99a2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s 
+// RUN: iree-opt %s
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
index 3f54307..933605b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -45,6 +45,8 @@
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
+        "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+        "//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:AffineToStandard",
@@ -71,12 +73,15 @@
         "@llvm-project//mlir:NVGPUDialect",
         "@llvm-project//mlir:NVGPUToNVVM",
         "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:PDLDialect",
+        "@llvm-project//mlir:PDLInterpDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:ROCDLDialect",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:SCFToControlFlow",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorDialect",
         "@llvm-project//mlir:VectorToGPU",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 80528fe..9d855af 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -34,6 +34,8 @@
   DEPS
     IREELinalgExtDialect
     IREELinalgExtPasses
+    IREELinalgTransformDialect
+    IREELinalgTransformDialectPasses
     LLVMSupport
     MLIRAffineDialect
     MLIRAffineToStandard
@@ -60,12 +62,15 @@
     MLIRNVGPUDialect
     MLIRNVGPUToNVVM
     MLIRNVVMDialect
+    MLIRPDLDialect
+    MLIRPDLInterpDialect
     MLIRPass
     MLIRROCDLDialect
     MLIRSCFDialect
     MLIRSCFToControlFlow
     MLIRSCFTransforms
     MLIRSupport
+    MLIRTransformDialect
     MLIRTransforms
     MLIRVectorDialect
     MLIRVectorToGPU
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index eaf6dc8..c6db82e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
 #include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
@@ -14,7 +15,10 @@
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -35,11 +39,22 @@
           LLVMGPULowerExecutableTargetPass> {
  public:
   void getDependentDialects(DialectRegistry &registry) const override {
+    // clang-format off
     registry
-        .insert<IREE::Codegen::IREECodegenDialect, IREE::HAL::HALDialect,
-                linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect,
-                vector::VectorDialect, gpu::GPUDialect, nvgpu::NVGPUDialect,
-                scf::SCFDialect>();
+        .insert<IREE::Codegen::IREECodegenDialect,
+                IREE::HAL::HALDialect,
+                IREE::LinalgExt::IREELinalgExtDialect,
+                linalg::LinalgDialect,
+                linalg::transform::LinalgTransformDialect,
+                gpu::GPUDialect,
+                nvgpu::NVGPUDialect,
+                pdl::PDLDialect,
+                pdl_interp::PDLInterpDialect,
+                scf::SCFDialect,
+                tensor::TensorDialect,
+                transform::TransformDialect,
+                vector::VectorDialect>();
+    // clang-format on
   }
 
   LLVMGPULowerExecutableTargetPass() = default;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index fc68e7d..f098f36 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/Passes.h"
 
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -249,7 +250,15 @@
 extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName;
 
 void addGPUTransformDialectInterpreterPasses(OpPassManager &passManager) {
-  assert(0 && "TODO: implement transform dialect path for LLVMGPU");
+  // Give control to the transform dialect.
+  passManager.addPass(createTransformDialectInterpreterPass(
+      clGPUCodegenTransformDialectFileName));
+
+  // Dropping the schedule is only needed if we want to embed the transform in
+  // the module: we should drop the schedule once applied.
+  // This pass does nothing in the case where we apply a separate policy
+  // through a file.
+  passManager.addPass(createDropSchedulePass());
 }
 
 void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index f3e1b85..fdc427e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -27,14 +27,19 @@
             "reduce_bank_conflicts.mlir",
             "rocdl_pipeline_test.mlir",
             "illegal_configuration.mlir",
+            "linalg_transform.mlir",
             "legalize.mlir",
             "tensorcore_vectorization.mlir",
             "vector_to_gpu.mlir",
             "vectorization.mlir",
         ],
         include = ["*.mlir"],
+        # linalg_transform_spec is an MLIR file that specifies a
+        # transformation; it needs to be included as data.
+        exclude = ["linalg_transform_spec.mlir"],
     ),
     cfg = "//compiler:lit.cfg.py",
+    data = ["linalg_transform_spec.mlir"],
     tools = [
         "//tools:iree-opt",
         "@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index c96a726..7754051 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -20,6 +20,7 @@
     "gpu_set_num_workgroups.mlir"
     "illegal_configuration.mlir"
     "legalize.mlir"
+    "linalg_transform.mlir"
     "nvvm_pipeline_test.mlir"
     "reduce_bank_conflicts.mlir"
     "rocdl_pipeline_test.mlir"
@@ -29,6 +30,8 @@
   TOOLS
     FileCheck
     iree-opt
+  DATA
+    linalg_transform_spec.mlir
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
new file mode 100644
index 0000000..f5f3926
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt %s  --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' --iree-codegen-llvmgpu-use-transform-dialect=%p/linalg_transform_spec.mlir | FileCheck %s
+
+#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}>]}>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>]>]>
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}>
+module attributes {hal.device.targets = [#device_target_cuda]} {
+  hal.executable private @matmul_static_dispatch_0 {
+    hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+      hal.executable.export public @matmul_static_dispatch_0 ordinal(0) layout(#executable_layout)
+      builtin.module {
+        func.func @matmul_static_dispatch_0() {
+          %c0 = arith.constant 0 : index
+          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:250x500xf32>
+          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:500x1020xf32>
+          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:250x1020xf32>
+          %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor<readonly:250x500xf32> -> tensor<250x500xf32>
+          %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor<readonly:500x1020xf32> -> tensor<500x1020xf32>
+
+          %50 = linalg.init_tensor [250, 1020] : tensor<250x1020xf32>
+          %cst = arith.constant 0.000000e+00 : f32
+          %5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
+
+          //      CHECK: memref.assume_alignment %{{.*}}, 64 : memref<250x1020xf32>
+          // CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32>)
+          // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32>, memref<500x1020xf32>) outs(%{{.*}} : memref<250x1020xf32>)
+          // CHECK-NEXT: return
+
+          %6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
+          flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
+          return
+        }
+      }
+    }
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform_spec.mlir
new file mode 100644
index 0000000..afc99a2
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform_spec.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_matmul_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.structured.canonicalized_sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_matmul_target in %arg1
+    transform.iree.bufferize
+  }
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index f483aba..63352f5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -249,11 +249,9 @@
                            return createDispatchWithTransformDialect(
                                clDispatchTransformFileName);
                          })
-      // TODO: we may only want to use the transform dialect for some dispatch
-      // regions and let the DispatchLinalgOnTensorsPass unconditionally handle
-      // the rest.
-      .addPredicatedPass(clDispatchTransformFileName.empty(),
-                         createDispatchLinalgOnTensorsPass)
+      // We only want to use the transform dialect for some dispatch regions
+      // and let the DispatchLinalgOnTensorsPass unconditionally handle the rest.
+      .addPass(createDispatchLinalgOnTensorsPass)
       ////////////////////////////////////////////////////////////////////////
       .addPass(createCaptureDispatchDynamicDimsPass)
       .addPass(mlir::createCanonicalizerPass)
diff --git a/tests/e2e/linalg_transform/linalg_transform.mlir b/tests/e2e/linalg_transform/linalg_transform.mlir
index d9764b6..b6b4ce4 100644
--- a/tests/e2e/linalg_transform/linalg_transform.mlir
+++ b/tests/e2e/linalg_transform/linalg_transform.mlir
@@ -2,7 +2,7 @@
 /// Specify the dispatch region formation with the transform dialect.
 // RUN:   --iree-flow-dispatch-use-transform-dialect=%p/transform_dialect_dispatch_spec.mlir \
 /// Specify the codegen strategy with the transform dialect.
-// RUN:   --iree-codegen-use-transform-dialect=%p/transform_dialect_codegen_spec.mlir \
+// RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/transform_dialect_codegen_spec.mlir \
 // RUN: | FileCheck %s
 
 func.func @matmul_static() -> tensor<5x5xf32> {
diff --git a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
index 32f6369..afc99a2 100644
--- a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
+++ b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
@@ -13,7 +13,6 @@
   transform.structured.canonicalized_sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_matmul_target in %arg1
-    transform.iree.set_num_workgroups_to_one
     transform.iree.bufferize
   }
 }