Retire usage of the transform dialect at the Flow level for cpu/matmul.mlir and cuda/reduction.mlir (#10527)

diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index f2c1dc3..aeee547 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -310,28 +310,39 @@
         "scf.foreach_thread with rank > 3 does not lower to workgroup");
 
   // Step 0. Outline the compute workload region and set up the workload
-  // operands.
-  auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp);
-  if (failed(maybeWorkgroupCounts) ||
-      llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      }))
-    return foreachThreadOp->emitError(
-        "unsupported dynamic workgroup_count atm --- need to slice out "
-        "workgroup_count computation into ExecutableExport::workgroup_count. "
-        "This region may require arbitrary computations and cannot magically "
-        "match what the `stream.cmd.dispatch` has already imposed on us at a "
-        "distance. For now we must specify the number of values properly when "
-        "applying the topLevel tile_to_foreach_thread_op");
-
-  SmallVector<int64_t> workgroupCounts;
-  for (OpFoldResult ofr : *maybeWorkgroupCounts)
-    workgroupCounts.push_back(getConstantIntValue(ofr).value());
-  if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
-                                                   exportOp))) {
-    return foreachThreadOp->emitOpError(
-               "failed to populate workload region for dispatchOp: ")
-           << exportOp;
+  // operands, if this has not been done already.
+  // Using `transform.iree.tile_to_foreach_thread_and_workgroup_count_region` is
+  // the preferred way to set up tiling and workgroup_count region **at the same
+  // time**.
+  //
+  // The block of code below will be retired once there is enough confidence we
+  // can do everything without it. This includes in particular providing custom
+  // fusion heuristics at the flow level: at this time, the only way to fully
+  // control fusion of more advanced cases is to use the transform dialect at
+  // the flow level and explicitly match the ops we want to fuse.
+  // Once fusion is customizable enough in perpetuity, we can retire this.
+  if (exportOp.getWorkgroupCount().empty()) {
+    auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp);
+    if (failed(maybeWorkgroupCounts) ||
+        llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) {
+          return !getConstantIntValue(ofr).has_value();
+        }))
+      return foreachThreadOp->emitError(
+          "unsupported dynamic workgroup_count atm --- need to slice out "
+          "workgroup_count computation into ExecutableExport::workgroup_count. "
+          "This region may require arbitrary computations and cannot magically "
+          "match what the `stream.cmd.dispatch` has already imposed on us at a "
+          "distance. For now we must specify the number of values properly "
+          "when applying the topLevel tile_to_foreach_thread_op");
+    SmallVector<int64_t> workgroupCounts;
+    for (OpFoldResult ofr : *maybeWorkgroupCounts)
+      workgroupCounts.push_back(getConstantIntValue(ofr).value());
+    if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
+                                                     exportOp))) {
+      return foreachThreadOp->emitOpError(
+                 "failed to populate workload region for dispatchOp: ")
+             << exportOp;
+    }
   }
 
   // Step 1. Create the workgroup id and count ops.
@@ -406,10 +417,6 @@
     state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
   }
-  if (!exportOp.getWorkgroupCount().empty())
-    return emitDefaultSilenceableFailure(target)
-           << "export op must have an empty workgroup count region that "
-              "the transform fills --- the transform is not applied";
 
   scf::ForeachThreadOp topLevelForeachThreadOp;
   auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -436,8 +443,14 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+void transform_dialect::ForeachThreadToWorkgroupOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTarget(), effects);
+  transform::producesHandle(getTransformed(), effects);
+}
+
 //===---------------------------------------------------------------------===//
-// TileToWorkgroupsOp
+// TileToForeachThreadAndWorkgroupCountRegion
 //===---------------------------------------------------------------------===//
 
 /// Lower the ops within the workgroup count region of `exportOp` that
@@ -495,41 +508,42 @@
   return success();
 }
 
-SmallVector<OpFoldResult>
-transform_dialect::TileToWorkgroupsOp::getMixedNumThreads() {
+SmallVector<OpFoldResult> transform_dialect::
+    TileToForeachThreadAndWorkgroupCountRegion::getMixedNumThreads() {
   return getMixedSizes(getStaticNumThreads(), getNumThreads());
 }
 
-SmallVector<OpFoldResult>
-transform_dialect::TileToWorkgroupsOp::getMixedTileSizes() {
+SmallVector<OpFoldResult> transform_dialect::
+    TileToForeachThreadAndWorkgroupCountRegion::getMixedTileSizes() {
   return getMixedSizes(getStaticTileSizes(), getTileSizes());
 }
 
-LogicalResult transform_dialect::TileToWorkgroupsOp::verify() {
+LogicalResult
+transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::verify() {
   if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
     return emitOpError("either num_threads or tile_sizes must be specified");
   return success();
 }
 
-void transform_dialect::TileToWorkgroupsOp::getEffects(
+void transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::consumesHandle(getTarget(), effects);
-  transform::consumesHandle(getFunc(), effects);
   transform::onlyReadsHandle(getTileSizes(), effects);
   transform::onlyReadsHandle(getNumThreads(), effects);
   transform::producesHandle(getResults(), effects);
 }
 
-DiagnosedSilenceableFailure transform_dialect::TileToWorkgroupsOp::apply(
+DiagnosedSilenceableFailure
+transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply(
     transform::TransformResults &transformResults,
     transform::TransformState &state) {
-  ArrayRef<Operation *> funcOps = state.getPayloadOps(getFunc());
-  assert(funcOps.size() == 1 && "expected single func op in payload");
-  FailureOr<IREE::HAL::ExecutableExportOp> exportOp =
-      getEntryPoint(cast<func::FuncOp>(funcOps[0]));
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  assert(targetOps.size() == 1 && "expected single target op in payload");
+  auto funcOp = targetOps.front()->getParentOfType<func::FuncOp>();
+  FailureOr<IREE::HAL::ExecutableExportOp> exportOp = getEntryPoint(funcOp);
   if (failed(exportOp)) {
     state.getTopLevel()->emitOpError("couldn't find export op for func");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOps[0]));
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOp));
   }
 
   SmallVector<OpFoldResult> mixedTileSizes = getMixedTileSizes();
@@ -539,8 +553,8 @@
         reportUnknownTransformError(exportOp.value()));
   }
 
-  /// Lower the workgroup count region in keeping with the way dispatch regions
-  /// are created by default in IREEs compilation flow.
+  /// Lower the workgroup count region in keeping with the way dispatch
+  /// regions are created by default in IREEs compilation flow.
   IRRewriter rewriter(getContext());
   if (failed(lowerWorkgroupCountComputingRegion(rewriter, exportOp.value(),
                                                 mixedTileSizes))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 216e25b..7571fb7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -101,7 +101,7 @@
 def ForeachThreadToWorkgroupOp : Op<Transform_Dialect, 
     "iree.foreach_thread_to_workgroup",
     [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface,
      TransformEachOpTrait]> {
   let description = [{
@@ -153,8 +153,8 @@
   }];
 }
 
-def TileToWorkgroupsOp :
-    Op<Transform_Dialect, "iree.tile_to_workgroups_op",
+def TileToForeachThreadAndWorkgroupCountRegion :
+    Op<Transform_Dialect, "iree.tile_to_foreach_thread_and_workgroup_count_region",
       [AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        TransformOpInterface]> {
@@ -162,14 +162,13 @@
     Wrapper around `structured.tile_to_foreach_thread_op` for use within IREE.
 
     In addition to tile and distribute using `scf.foreach_thread`, lowers the
-    the `workgroup_count` region of the export op corresponding to the
-    `func.func` to return the number of workgroups.
+    the `workgroup_count` region of the export op corresponding to the parent
+    `func.func` of the target to return the number of workgroups.
     Please see the doc of `structured.tile_to_foreach_thread_op` for full
     description of op semantics. 
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   PDL_Operation:$func,
                    Variadic<PDL_Operation>:$num_threads,
                    Variadic<PDL_Operation>:$tile_sizes,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
@@ -178,7 +177,7 @@
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
   let assemblyFormat = [{
-    $target $func oilist(
+    $target oilist(
         `num_threads` custom<DynamicIndexList>($num_threads,
                                                $static_num_threads,
                                                "ShapedType::kDynamicSize") |
diff --git a/tests/transform_dialect/cpu/BUILD b/tests/transform_dialect/cpu/BUILD
index 7a7d652..88ac647 100644
--- a/tests/transform_dialect/cpu/BUILD
+++ b/tests/transform_dialect/cpu/BUILD
@@ -20,10 +20,8 @@
     # transform dialect spec files are MLIR files that specify a transformation,
     # they need to be included as data.
     data = [
+        "matmul_codegen_custom_dispatch_formation_spec.mlir",
         "matmul_codegen_default_spec.mlir",
-        "matmul_codegen_spec.mlir",
-        "matmul_dispatch_spec.mlir",
-        "matmul_tiled_dispatch_spec.mlir",
     ],
     tags = [
         "noasan",
diff --git a/tests/transform_dialect/cpu/CMakeLists.txt b/tests/transform_dialect/cpu/CMakeLists.txt
index 5bfb667..ea778f5 100644
--- a/tests/transform_dialect/cpu/CMakeLists.txt
+++ b/tests/transform_dialect/cpu/CMakeLists.txt
@@ -22,10 +22,8 @@
     iree-opt
     iree-run-module
   DATA
+    matmul_codegen_custom_dispatch_formation_spec.mlir
     matmul_codegen_default_spec.mlir
-    matmul_codegen_spec.mlir
-    matmul_dispatch_spec.mlir
-    matmul_tiled_dispatch_spec.mlir
   LABELS
     "noasan"
     "nomsan"
diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir
index 5b0cf05..e7a26e1 100644
--- a/tests/transform_dialect/cpu/matmul.mlir
+++ b/tests/transform_dialect/cpu/matmul.mlir
@@ -10,71 +10,45 @@
   return %0 : !C_size
 }
 
-// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
-// RUN:   --iree-abi-transformation-pipeline \
-// RUN:   --iree-flow-transformation-pipeline \
-// RUN:   --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=DISPATCH
-
-// TODO: make this test drop transform dialect usage at the flow level and use:
-//   --iree-flow-transformation-pipeline --iree-flow-convert-region-to-workgroups
-// Atm the 3rd flow.dispatch.tensor.load shows as readonly instead of readwrite.
-
-// DISPATCH: flow.executable private @matmul_static_dispatch_0 {
-// DISPATCH:   flow.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
-// DISPATCH:     builtin.module {
-// DISPATCH:       func.func @matmul_static_dispatch_0_matmul_3x3x5
-// DISPATCH:         flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : !flow.dispatch.tensor<readonly:3x5xf32> -> tensor<3x5xf32>
-// DISPATCH:         flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [5, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:5x3xf32> -> tensor<5x3xf32>
-// DISPATCH:         flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : !flow.dispatch.tensor<readwrite:3x3xf32> -> tensor<3x3xf32>
-// DISPATCH:         linalg.matmul ins({{.*}} : tensor<3x5xf32>, tensor<5x3xf32>) outs({{.*}} : tensor<3x3xf32>) -> tensor<3x3xf32>
-// DISPATCH:         flow.dispatch.tensor.store {{.*}} offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : tensor<3x3xf32> -> !flow.dispatch.tensor<readwrite:3x3xf32>
-// DISPATCH:         return
-
-// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
-// RUN:   --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline  --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
-// RUN:   --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
-// RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=CODEGEN
-
 // Run with C++ dispatch region formation but transform dialect codegen
 // RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
 // RUN:   --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline \
-// RUN:   --iree-flow-dispatch-via-region-ops --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
-// RUN:   --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
+// RUN:   --iree-flow-dispatch-via-region-ops \
+// RUN:   --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
+// RUN:   --iree-stream-transformation-pipeline \
+// RUN:   --iree-hal-configuration-pipeline | \
 // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
-// RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=CODEGEN
+// RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_custom_dispatch_formation_spec.mlir | \
+// RUN: FileCheck %s --check-prefix=CODEGEN-CUSTOM-DISPATCH-FORMATION
 
-// CODEGEN: hal.executable private @matmul_static_dispatch_0 {
-// CODEGEN:   hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
-// CODEGEN:     hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} {
-// CODEGEN:       ^bb0(%{{.*}}: !hal.device):
-// CODEGEN:         arith.constant 2 : index
-// CODEGEN:         arith.constant 1 : index
-// CODEGEN:         hal.return %{{.*}}, %{{.*}}, %{{.*}} : index, index, index
-// CODEGEN:       }
-// CODEGEN:       builtin.module {
-// CODEGEN:         func.func @matmul_static_dispatch_0_matmul_3x3x5() {
-// CODEGEN:           arith.constant 0 : index
-// CODEGEN:           hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32>
-// CODEGEN:           memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32>
-// CODEGEN:           hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32>
-// CODEGEN:           memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32>
-// CODEGEN:           hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32>
-// CODEGEN:           memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32>
-// CODEGEN:           %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CODEGEN:           affine.apply {{.*}}()[%workgroup_id_x]
-// CODEGEN:           memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref<?x5xf32, strided<[5, 1], offset: ?>>
-// CODEGEN:           memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref<?x3xf32, strided<[3, 1], offset: ?>>
-// CODEGEN:           linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x5xf32, strided<[5, 1], offset: ?>>, memref<5x3xf32>) outs(%{{.*}} : memref<?x3xf32, strided<[3, 1], offset: ?>>)
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable private @matmul_static_dispatch_0 {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:   hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:     hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:       ^bb0(%{{.*}}: !hal.device):
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:         %[[C2:.*]] = arith.constant 2 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:         %[[C1:.*]] = arith.constant 1 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:         hal.return %[[C2]], %[[C1]], %[[C1]] : index, index, index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:       }
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:       builtin.module {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:         func.func @matmul_static_dispatch_0_matmul_3x3x5() {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           arith.constant 0 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           affine.apply {{.*}}()[%workgroup_id_x]
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref<?x5xf32, strided<[5, 1], offset: ?>>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref<?x3xf32, strided<[3, 1], offset: ?>>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION:           linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x5xf32, strided<[5, 1], offset: ?>>, memref<5x3xf32>) outs(%{{.*}} : memref<?x3xf32, strided<[3, 1], offset: ?>>)
 
 // RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
 // RUN:   --iree-abi-transformation-pipeline \
 // RUN:   --iree-flow-transformation-pipeline \
 // RUN:   --iree-stream-transformation-pipeline \
-// RUN:    --iree-hal-configuration-pipeline | \
+// RUN:   --iree-hal-configuration-pipeline | \
 // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
 // RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_default_spec.mlir | \
 // RUN: FileCheck %s --check-prefixes=CODEGEN-DEFAULT
@@ -86,22 +60,4 @@
 // CODEGEN-DEFAULT:         %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
 // CODEGEN-DEFAULT:         hal.return %[[D0]], %[[C1]], %[[C1]]
 
-// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
-// RUN:   --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
-// RUN:   --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: iree-run-module --entry_function=matmul_static \
-// RUN:   --function_input="3x5xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \
-// RUN:   --function_input="5x3xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \
-// RUN:   --function_input="3x3xf32=0 0 0 0 0 0 0 0 0"| \
-// RUN: FileCheck %s --check-prefixes=EXEC
-
 // EXEC: 3x3xf32=[5 5 5][5 5 5][5 5 5]
-
-// RUN: iree-compile --iree-hal-target-backends=llvm-cpu \
-// RUN:     --iree-flow-dispatch-use-transform-dialect=%p/matmul_tiled_dispatch_spec.mlir \
-// RUN:     --iree-flow-export-benchmark-funcs %s | \
-// RUN: iree-benchmark-module --device=local-task | \
-// RUN: FileCheck %s --check-prefixes=BENCHMARK-MODULE
-
-// When running iree-benchmark-module, we only check the existence of the func.
-// BENCHMARK-MODULE: matmul_static
diff --git a/tests/transform_dialect/cpu/matmul_codegen_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
similarity index 100%
rename from tests/transform_dialect/cpu/matmul_codegen_spec.mlir
rename to tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
index d26eaf1..9d7b392 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
@@ -2,11 +2,12 @@
 
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%variant_op: !pdl.operation):
-  %func = transform.structured.match ops{["func.func"]} in %variant_op
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op
 
   %foreach_thread, %tiled_generic =
-    transform.iree.tile_to_workgroups_op %0 %func tile_sizes [2]
-
-  %1 = transform.iree.bufferize %variant_op
+    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %matmul tile_sizes [2]
+  
+  %variant_op_2 = transform.iree.bufferize %variant_op
+  %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  transform.iree.foreach_thread_to_workgroup %func
 }
diff --git a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir
deleted file mode 100644
index eba247b..0000000
--- a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir
+++ /dev/null
@@ -1,9 +0,0 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
-  transform.sequence %arg0 failures(propagate) {
-  ^bb1(%arg1: !pdl.operation):
-    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    %region_op = transform.iree.wrap_in_dispatch_region %0
-    transform.iree.region_to_workgroups %region_op
-  }
-}
diff --git a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
deleted file mode 100644
index 41dc2dd..0000000
--- a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-transform.structured.canonicalized_sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
-  %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
-}
diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD
index 8af33f6..4aeb99b 100644
--- a/tests/transform_dialect/cuda/BUILD
+++ b/tests/transform_dialect/cuda/BUILD
@@ -34,8 +34,9 @@
     # they need to be included as data.
     data = [
         "reduction_codegen_spec.mlir",
-        "reduction_dispatch_spec.mlir",
         "softmax_codegen_spec.mlir",
+        # FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+        # issue and we even end up emitting out of bounds accesses.
         "softmax_dispatch_spec.mlir",
         "softmax_fused_codegen_spec.mlir",
     ],
diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt
index d5a0fba..139817f 100644
--- a/tests/transform_dialect/cuda/CMakeLists.txt
+++ b/tests/transform_dialect/cuda/CMakeLists.txt
@@ -27,7 +27,6 @@
     iree-run-module
   DATA
     reduction_codegen_spec.mlir
-    reduction_dispatch_spec.mlir
     softmax_codegen_spec.mlir
     softmax_dispatch_spec.mlir
     softmax_fused_codegen_spec.mlir
diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir
index 5863591..4ea9300 100644
--- a/tests/transform_dialect/cuda/reduction.mlir
+++ b/tests/transform_dialect/cuda/reduction.mlir
@@ -24,7 +24,6 @@
 // RUN: iree-opt %s --iree-hal-target-backends=cuda \
 // RUN:     --iree-abi-transformation-pipeline \
 // RUN:     --iree-flow-transformation-pipeline  \
-// RUN:     --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \
 // RUN:     --iree-stream-transformation-pipeline \
 // RUN:     --iree-hal-configuration-pipeline | \
 // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
@@ -32,7 +31,6 @@
 // RUN: FileCheck %s --check-prefix=CHECK
 
 // RUN: iree-compile %s --iree-hal-target-backends=cuda \
-// RUN:     --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \
 // RUN: iree-run-module --entry_function=reduce --device=cuda |\
 // RUN: FileCheck %s --check-prefix=EXEC
@@ -49,7 +47,7 @@
   //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
 
   // Distributed reduction: everyone loads then 5 xor + addf expected
-  //         CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+  //         CHECK: vector.transfer_read %{{.*}}[%[[TIDX]]]
   // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
   //         CHECK: %[[RES:.*]] = arith.addf %{{.*}}
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 1015856..18db174 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -14,7 +14,7 @@
   // First level of tiling + fusion parallelizes to blocks.
   // The mapping to block ids can only happen after bufferization atm.
   %foreach_thread_grid, %grid_combiner_op =
-    transform.structured.tile_to_foreach_thread_op %combiner_op tile_sizes [1]
+    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %combiner_op tile_sizes [1]
   %not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op
   transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid
 
diff --git a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir b/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
deleted file mode 100644
index 7dc0521..0000000
--- a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: iree-opt %s
-
-// Dispatch reduction.
-transform.structured.canonicalized_sequence failures(propagate){
-^bb1(%variant_op: !pdl.operation):
-  %root = transform.structured.match interface{LinalgOp}
-    attributes{iterator_types = ["parallel", "reduction"]} in %variant_op
-  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
-
-  // TODO: this could be replaced by a C++ only version.
-  // Atm the IR produced is not the same so all pieces do not connect.
-  %region_op = transform.iree.wrap_in_dispatch_region %root
-  %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %fill into %region_op
-  transform.iree.region_to_workgroups %region_op_2
-}
diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir
index b6e4fed..7fa6c2c 100644
--- a/tests/transform_dialect/cuda/softmax.mlir
+++ b/tests/transform_dialect/cuda/softmax.mlir
@@ -2,7 +2,6 @@
 // RUN: iree-opt %s --iree-hal-target-backends=cuda \
 // RUN:     --iree-abi-transformation-pipeline \
 // RUN:     --iree-flow-transformation-pipeline  \
-// RUN:     --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
 // RUN:     --iree-stream-transformation-pipeline \
 // RUN:     --iree-hal-configuration-pipeline | \
 // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
@@ -10,7 +9,6 @@
 // RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
 
 // RUN: iree-compile %s --iree-hal-target-backends=cuda \
-// RUN:     --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \
 // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
 // RUN: FileCheck %s
@@ -18,6 +16,10 @@
 // RUN: iree-opt %s --iree-hal-target-backends=cuda \
 // RUN:     --iree-abi-transformation-pipeline \
 // RUN:     --iree-flow-transformation-pipeline  \
+///
+/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+/// issue and we even end up emitting out of bounds accesses.
+///
 // RUN:     --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
 // RUN:     --iree-stream-transformation-pipeline \
 // RUN:     --iree-hal-configuration-pipeline | \
@@ -26,6 +28,10 @@
 // RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
 
 // RUN: iree-compile %s --iree-hal-target-backends=cuda \
+///
+/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+/// issue and we even end up emitting out of bounds accesses.
+///
 // RUN:     --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \
 // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
@@ -41,11 +47,13 @@
 // CHECK-SHUFFLE: gpu.shuffle  xor
 
 // Execution only checks that @max_sub_exp runs.
-// CHECK: EXEC @max_sub_exp
+//      CHECK: EXEC @max_sub_exp
+//      CHECK: 16x128x128xf32=[
+// CHECK-SAME:                [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 
 
-func.func @max_sub_exp() {
+func.func @max_sub_exp() -> !out_tensor_t {
   %cst = arith.constant -3.40282347E+38 : f32
-  %cst_0 = arith.constant dense<1.000000e+00> : !out_tensor_t
+  %cst_0 = arith.constant dense<1121212.000000e+00> : !out_tensor_t
   %cst_1 = arith.constant dense<5.000000e+00> : !out_tensor_t
   %0 = util.do_not_optimize(%cst_1) : !out_tensor_t
 
@@ -73,6 +81,5 @@
     linalg.yield %7 : f32
   } -> !out_tensor_t
 
-  check.expect_almost_eq(%5, %cst_0) : !out_tensor_t
-  return
+  return %5: !out_tensor_t
 }
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 3666f20..4607570 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -1,82 +1,77 @@
 // RUN: iree-opt %s
 
 // Codegen
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
-  transform.structured.canonicalized_sequence %arg0 failures(propagate) {
-  ^bb1(%variant_op: !pdl.operation):
-    // First level of tiling + fusion parallelizes to blocks.
-    // The mapping  to block ids can only happen after bufferization atm
-    %root = transform.structured.match interface{LinalgOp}
-      attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
-    %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
-    %red = transform.structured.match interface{LinalgOp}
-      attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
-    %not_root = merge_handles %fill, %red
-    %foreach_thread, %tiled_generic =
-      transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 4]
-    transform.structured.fuse_into_containing_op %not_root into %foreach_thread
-    
-    // Second level of tiling + fusion parallelizes to threads.
-    // Leaving the reduction untiled on threadIdx.x makes it sequential on
-    // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
-    // introduced and opportunities for distributing vector ops across warps
-    // appear.
-    %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
-    %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
-      attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
-    %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
-      attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
-    %foreach_thread_reduction, %tiled_reduction_generic =
-      transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1]
-        (mapped to dims [2, 1, 0])
-    // TODO: this fusion currently does not happen properly, this is related to the clone
-    // behavior when fusing into scf.foreach_thread.
-    // Once fixed we'll be able to fuse.
-    // Fusion will save us one roundtrip to memory.
-    // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction
-    transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32]
-        (mapped to dims [2, 1, 0])
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+  // First level of tiling + fusion parallelizes to blocks.
+  // The mapping  to block ids can only happen after bufferization atm
+  %root = transform.structured.match interface{LinalgOp}
+    attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+  %red = transform.structured.match interface{LinalgOp}
+    attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+  %not_root = merge_handles %fill, %red
+  %foreach_thread, %tiled_generic =
+    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4]
+  transform.structured.fuse_into_containing_op %not_root into %foreach_thread
+  
+  // Second level of tiling + fusion parallelizes to threads.
+  // Leaving the reduction untiled on threadIdx.x makes it sequential on
+  // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
+  // introduced and opportunities for distributing vector ops across warps
+  // appear.
+  %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
+  %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
+    attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+  %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
+    attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+  %foreach_thread_reduction, %tiled_reduction_generic =
+    transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1]
+      (mapped to dims [2, 1, 0])
+  // TODO: this fusion currently does not happen properly, this is related to the clone
+  // behavior when fusing into scf.foreach_thread.
+  // Once fixed we'll be able to fuse.
+  // Fusion will save us one roundtrip to memory.
+  // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction
+  transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32]
+      (mapped to dims [2, 1, 0])
 
 
-    // Inability to tile reductions to scf.foreach_thread has 2 implications:
-    //   1. since no scf.foreach_thread is present, no gpu.barrier is added.
-    //      This should be fixed independently: ops that are not nested in an scf.foreach_thread
-    //      should have a gpu.barrier. Later needs to be complemented by a barrier
-    //      removal pass.
-    //   2. Similarly, needs to be predicated under an if threadIx == 0 to avoid
-    //      multiple threads updating the buffer inplace once bufferized.
-    //
-    // Instead, we can vectorize and go to vector SSA values that sidestep these
-    // issues.
-    // Everyone will race to the write while still computing the same value.
-    //
-    // That is still not good enough because we need to predicate this in order
-    // to enable the parallel reduction on warps.
-    %func = transform.structured.match ops{["func.func"]} in %variant_op
-    %funcx = transform.iree.apply_patterns %func { rank_reducing }
-    transform.structured.vectorize %funcx
+  // Inability to tile reductions to scf.foreach_thread has 2 implications:
+  //   1. since no scf.foreach_thread is present, no gpu.barrier is added.
+  //      This should be fixed independently: ops that are not nested in an scf.foreach_thread
+  //      should have a gpu.barrier. Later needs to be complemented by a barrier
+  //      removal pass.
+  //   2. Similarly, needs to be predicated under an if threadIx == 0 to avoid
+  //      multiple threads updating the buffer inplace once bufferized.
+  //
+  // Instead, we can vectorize and go to vector SSA values that sidestep these
+  // issues.
+  // Everyone will race to the write while still computing the same value.
+  //
+  // That is still not good enough because we need to predicate this in order
+  // to enable the parallel reduction on warps.
+  %func = transform.structured.match ops{["func.func"]} in %variant_op
+  %funcx = transform.iree.apply_patterns %func { rank_reducing }
+  transform.structured.vectorize %funcx
 
-    // Bufferization is necessary for:
-    //   1. lowering scf.foreach_thread to workgroup (block level parallelism)
-    //   2. lowering scf.foreach_thread to gpu (thread level parallelism)
-    //   3. introducing predication (due to 1. + 2.) which enables rewriting to
-    //      warp_execute_on_lane_0 and later vector distribution.
-    %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+  // Bufferization is necessary for:
+  //   1. lowering scf.foreach_thread to workgroup (block level parallelism)
+  //   2. lowering scf.foreach_thread to gpu (thread level parallelism)
+  //   3. introducing predication (due to 1. + 2.) which enables rewriting to
+  //      warp_execute_on_lane_0 and later vector distribution.
+  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
 
-    %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
-    %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
-    transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
-      { workgroup_size = [32, 4, 1] }
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
+  transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+    { workgroup_size = [32, 4, 1] }
 
-    %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
-    %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
+  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
 
-    // Vector distribution needs to happen on buffers.
-    %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
-    %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
-    transform.iree.vector.warp_distribute %end_func_2
-  }
+  // Vector distribution needs to happen on buffers.
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+  transform.iree.vector.warp_distribute %end_func_2
 }
-
-
diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken
new file mode 100644
index 0000000..68f891e
--- /dev/null
+++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken
@@ -0,0 +1,57 @@
+// RUN: iree-opt %s 
+ 
+// Codegen
+transform.structured.canonicalized_sequence failures(propagate) {
+// transform.sequence %arg0 failures(propagate) {
+^bb1(%variant_op: !pdl.operation):    
+  // First level of tiling + fusion parallelizes to blocks.
+  // The mapping  to block ids can only happen after bufferization atm 
+  %root = transform.structured.match interface{LinalgOp} 
+    attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+  %red = transform.structured.match interface{LinalgOp} 
+    attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+  %not_root = merge_handles %fill, %red
+  %foreach_thread, %tiled_generic = 
+    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 1]
+      (mapped to dims [0, 1, 2])
+  transform.structured.fuse_into_containing_op %not_root into %foreach_thread
+  
+  // Second level of tiling + fusion parallelizes to threads.
+  // Leaving the reduction untiled on threadIdx.x makes it sequential on 
+  // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
+  // introduced and opportunities for distributing vector ops across warps
+  // appear.
+  %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
+  %reduction_linalg = transform.structured.match ops{["linalg.generic"]} 
+    attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+  %not_root_2 = merge_handles %fill_linalg, %reduction_linalg
+  %parallel_linalg = transform.structured.match ops{["linalg.generic"]} 
+    attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+  %foreach_thread_2, %parallel_linalg_2 = 
+    transform.structured.tile_to_foreach_thread_op %parallel_linalg tile_sizes [1, 1, 0]
+      (mapped to dims [2, 1, 0])
+  transform.structured.fuse_into_containing_op %not_root_2 into %foreach_thread_2
+  
+  // Rank-reduce and vectorize.
+  %func = transform.structured.match ops{["func.func"]} in %variant_op
+  %funcx = transform.iree.apply_patterns %func { rank_reducing }
+  transform.structured.vectorize %funcx
+  
+  // Bufferization is necessary for:
+  //   1. lowering scf.foreach_thread to workgroup (block level parallelism)
+  //   2. lowering scf.foreach_thread to gpu (thread level parallelism)
+  //   3. introducing predication (due to 1. + 2.) which enables rewriting to
+  //      warp_execute_on_lane_0 and later vector distribution.
+  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
+  transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+    { workgroup_size = [32, 1, 1] }
+  
+  // Vector distribution needs to happen on buffers.
+  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+  transform.iree.vector.warp_distribute %end_func
+}