Move TransformMatchers to llvm-external-projects (#11472)

This is a followup commit to #11422 addressing code move requests in the
review.

Co-authored-by: Thomas Raoux <thomasraoux@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD
index 814298b..66512d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD
@@ -100,14 +100,19 @@
     srcs = [
         "TransformDialectJitterPass.cpp",
         "TransformDialectStrategies.cpp",
+        "TransformDialectStrategiesCPU.cpp",
+        "TransformDialectStrategiesGPU.cpp",
     ],
     hdrs = [
         "TransformDialectStrategies.h",
+        "TransformDialectStrategiesCPU.h",
+        "TransformDialectStrategiesGPU.h",
     ],
     deps = [
         # Dialects
         "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
+        "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 201d5fc..fc6b489 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -72,10 +72,15 @@
     TransformDialectJitterPass
   HDRS
     "TransformDialectStrategies.h"
+    "TransformDialectStrategiesCPU.h"
+    "TransformDialectStrategiesGPU.h"
   SRCS
     "TransformDialectJitterPass.cpp"
     "TransformDialectStrategies.cpp"
+    "TransformDialectStrategiesCPU.cpp"
+    "TransformDialectStrategiesGPU.cpp"
   DEPS
+    IREEDialectsTransforms
     IREELinalgExtDialect
     IREELinalgExtTransformOps
     IREELinalgTransformDialect
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
index c2c0a5d..99dcc27 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
@@ -10,8 +10,8 @@
 #include <type_traits>
 
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
 #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
 #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -46,27 +47,18 @@
 
 using namespace mlir;
 
-namespace mlir {
-namespace iree_compiler {
-
 #define DEBUG_TYPE "iree-transform-builder"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 // TODO: significantly better namespacing.
-using iree_compiler::IREE::transform_dialect::AllDims;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
 using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
 using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
 using iree_compiler::IREE::transform_dialect::
     IREEEraseHALDescriptorTypeFromMemRefOp;
-using iree_compiler::IREE::transform_dialect::IsPermutation;
-using iree_compiler::IREE::transform_dialect::m_StructuredOp;
 using iree_compiler::IREE::transform_dialect::
     MapNestedForeachThreadToGpuThreadsOp;
-using iree_compiler::IREE::transform_dialect::NumEqualsTo;
-using iree_compiler::IREE::transform_dialect::ShapeKind;
-using iree_compiler::IREE::transform_dialect::StructuredOpMatcher;
 using iree_compiler::IREE::transform_dialect::
     TileToForeachThreadAndWorkgroupCountRegionOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
@@ -80,6 +72,12 @@
 using transform::SplitReductionOp;
 using transform::TileToForeachThreadOp;
 using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
 
 /// Matches `args` within `targetH` and unpacks a number of handles `N`.
 /// Assumes there are exactly `N` matched ops (but could be relaxed).
@@ -101,7 +99,8 @@
 //===----------------------------------------------------------------------===//
 
 /// Prints `handles` in order. Prints the whole IR if `handles` is empty.
-static void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {}) {
+void mlir::iree_compiler::buildPrint(ImplicitLocOpBuilder &b,
+                                     ValueRange handles) {
   if (handles.empty()) b.create<PrintOp>();
   for (auto h : handles) b.create<PrintOp>(h);
 }
@@ -141,56 +140,94 @@
   return foreachThreadH;
 }
 
-/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> tilesSizes.
-template <typename TilingTransformOp = TileToForeachThreadOp>
+// TODO: if someone knows how to properly export templates go for it ..
+// sigh.
+template <typename TilingTransformOp>
 static Value buildTileFuseDistWithTileSizes(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
-    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
+    SmallVectorImpl<Value> *resultingFusedOpsHandles) {
   return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
                                            transform::TileSizesSpec>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping,
       resultingFusedOpsHandles);
 }
-
-/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
-template <typename TilingTransformOp = TileToForeachThreadOp>
-static Value buildTileFuseDistWithNumThreads(
+Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
-    ArrayRef<int64_t> numThreads, ArrayAttr threadDimMapping,
-    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
-  return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
-                                           transform::NumThreadsSpec>(
-      b, rootH, opsHToFuse, getAsOpFoldResult(b.getI64ArrayAttr(numThreads)),
-      threadDimMapping, resultingFusedOpsHandles);
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+  return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+      b, rootH, opsHToFuse, tileSizes, threadDimMapping,
+      resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::
+    buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+        ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+        ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+        SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+  return buildTileFuseDistWithTileSizes<
+      TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
+                                                    tileSizes, threadDimMapping,
+                                                    resultingFusedOpsHandles);
 }
 
-/// Call buildTileAndFuseAndDistributeImpl with a handle to multiple numThreads.
-template <typename TilingTransformOp = TileToForeachThreadOp>
+/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
+// TODO: if someone knows how to properly export templates go for it ..
+// sigh.
+template <typename TilingTransformOp>
 static Value buildTileFuseDistWithNumThreads(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
-    Value numThreads, ArrayAttr threadDimMapping,
-    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
+    ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles) {
   return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
                                            transform::NumThreadsSpec>(
-      b, rootH, opsHToFuse, ArrayRef<OpFoldResult>{numThreads},
-      threadDimMapping, resultingFusedOpsHandles);
+      b, rootH, opsHToFuse, numThreads, threadDimMapping,
+      resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+  return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+      b, rootH, opsHToFuse, tileSizes, threadDimMapping,
+      resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::
+    buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+        ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+        ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+        SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+  return buildTileFuseDistWithTileSizes<
+      TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
+                                                    tileSizes, threadDimMapping,
+                                                    resultingFusedOpsHandles);
 }
 
 /// Apply patterns and vectorize (for now always applies rank-reduction).
 /// Takes a handle to a func.func and returns an updated handle to a
 /// func.func.
 // TODO: configure patterns.
-static Value buildVectorizeStrategy(ImplicitLocOpBuilder &b, Value funcH) {
+Value mlir::iree_compiler::buildVectorize(ImplicitLocOpBuilder &b,
+                                          Value funcH) {
   funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
   return b.create<VectorizeOp>(funcH);
 }
 
+/// Bufferize and drop HAL descriptor from memref ops.
+Value mlir::iree_compiler::buildBufferize(ImplicitLocOpBuilder &b,
+                                          Value variantH, bool targetGpu) {
+  variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
+  Value memrefFunc =
+      b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
+  return variantH;
+}
+
 /// Post-bufferization mapping to blocks and threads.
 /// Takes a handle to a func.func and returns an updated handle to a
 /// func.func.
-static Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH,
-                                       ArrayRef<int64_t> blockSize) {
+Value mlir::iree_compiler::buildMapToBlockAndThreads(
+    ImplicitLocOpBuilder &b, Value funcH, ArrayRef<int64_t> blockSize) {
   funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
   return b.create<MapNestedForeachThreadToGpuThreadsOp>(funcH, blockSize);
 }
@@ -200,9 +237,9 @@
 /// Post-bufferization vector distribution with rank-reduction.
 /// Takes a handle to a func.func and returns an updated handle to a
 /// func.func.
-static Value buildDistributeVectors(ImplicitLocOpBuilder &b, Value variantH,
-                                    Value funcH,
-                                    int64_t warpSize = kCudaWarpSize) {
+Value mlir::iree_compiler::buildDistributeVectors(ImplicitLocOpBuilder &b,
+                                                  Value variantH, Value funcH,
+                                                  int64_t warpSize) {
   funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
   Value ifH = b.create<MatchOp>(funcH, scf::IfOp::getOperationName());
   // Locally suppress failures for this op only because it doesn't cover the
@@ -220,11 +257,7 @@
   return funcH;
 }
 
-//===----------------------------------------------------------------------===//
-// Higher-level problem-specific strategy creation APIs, these should favor
-// user-friendliness.
-//===----------------------------------------------------------------------===//
-
+namespace {
 /// Various handles produced by reduction splitting.
 struct ReductionSplitResult {
   /// Handle to the leading elementwise operation, may be null if no such
@@ -237,13 +270,14 @@
   Value splitLinalgH;
   /// Handle to the final reduction.
   Value combinerH;
-  /// Handle to the original fill operation, may be null if the operation was
-  /// not re-matched.
+  /// Handle to the original fill operation, may be null if the operation
+  /// was not re-matched.
   Value originalFillH;
-  /// Handle to the trailing fill operation, may be null if the operation was
-  /// not re-matched.
+  /// Handle to the trailing fill operation, may be null if the operation
+  /// was not re-matched.
   Value trailingEltwiseH;
 };
+}  // namespace
 
 /// Builds transform IR requesting to bubble up the "expand_shape" operation
 /// produced as parent of reduction splitting if necessary for fusion of the
@@ -280,11 +314,11 @@
 
 /// Distribute to blocks using the current IREE lowering config.
 // TODO: consider passing a problem-specific struct to control information.
-static Value createReductionStrategyBlockDistributionPart(
+Value mlir::iree_compiler::createReductionStrategyBlockDistributionPart(
     ImplicitLocOpBuilder &b, Value variantH, Value originalFillH,
     Value reductionH, Value optionalFusionRootH,
-    ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise = false,
-    bool hasTrailingEltwise = false) {
+    ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise,
+    bool hasTrailingEltwise) {
   // Step 1. Split the reduction to get meatier parallelism.
   // TODO: use a scf.foreach_thread for this.
   auto splitReductionTransformOp =
@@ -302,8 +336,8 @@
   auto x = mlir::gpu::GPUBlockMappingAttr::get(b.getContext(),
                                                ::mlir::gpu::Blocks::DimX);
   // Step 2. First level of tiling + fusion parallelizes to blocks using
-  // `tileSizes`. If the fusion root was the reduction op, update it to be the
-  // combiner op. Otherwise, fuse the combiner op into root.
+  // `tileSizes`. If the fusion root was the reduction op, update it to be
+  // the combiner op. Otherwise, fuse the combiner op into root.
   SmallVector<Value> opsHToFuse(
       {rs.originalFillH ? rs.originalFillH : originalFillH, rs.splitFillH,
        rs.splitLinalgH});
@@ -318,192 +352,25 @@
     opsHToFuse.push_back(rs.leadingEltwiseH);
   }
 
-  // The presence of leading elementwise operation implies that dispatch region
-  // formation happened using another transform dialect script and doesn't need
-  // the workgroup count part.
+  // The presence of leading elementwise operation implies that dispatch
+  // region formation happened using another transform dialect script and
+  // doesn't need the workgroup count part.
   if (hasLeadingEltwise) {
-    buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+    iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
         b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
         b.getArrayAttr({x}));
   } else {
-    buildTileFuseDistWithTileSizes<
-        TileToForeachThreadAndWorkgroupCountRegionOp>(
-        b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
-        b.getArrayAttr({x}));
+    iree_compiler::
+        buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+            b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
+            b.getArrayAttr({x}));
   }
 
   return variantH;
 }
 
-// TODO: consider passing a problem-specific struct to control information.
-static Value createReductionStrategyThreadDistributionPart(
-    ImplicitLocOpBuilder &b, Value variantH, ArrayRef<int64_t> tileSizes1Fill,
-    ArrayRef<int64_t> tileSizes1Generic, bool hasLeadingEltwise,
-    bool hasTrailingEltwise) {
-  // TODO: Relying on ordering is brittle, harden this.
-  Value matchedH = b.create<MatchOp>(
-      variantH, ArrayRef<StringRef>{linalg::GenericOp::getOperationName(),
-                                    linalg::FillOp::getOperationName()});
-  auto split = b.create<SplitHandlesOp>(
-      matchedH,
-      /*numResultHandles=*/4 + hasLeadingEltwise + hasTrailingEltwise);
-  Value firstFusionRootH = split.getResults()[1 + hasLeadingEltwise];
-  SmallVector<Value> firstFusionGroupHs =
-      split.getResults().take_front(1 + hasLeadingEltwise);
-  Value secondFusionRootH = split.getResults().back();
-  SmallVector<Value> secondFusionGroupHs =
-      split.getResults().drop_front(2 + hasLeadingEltwise).drop_back();
-
-  auto z = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
-                                                ::mlir::gpu::Threads::DimZ);
-  auto y = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
-                                                ::mlir::gpu::Threads::DimY);
-
-  // clang-format off
-  buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(b,
-                   /*rootH=*/secondFusionRootH,
-                   /*opsHToFuse=*/secondFusionGroupHs,
-                   /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Fill)),
-                   /*threadDimMapping=*/b.getArrayAttr({z}));
-  buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(b,
-                   /*rootH=*/firstFusionRootH,
-                   /*opsHToFuse=*/firstFusionGroupHs,
-                   /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Generic)),
-                   /*threadDimMapping=*/b.getArrayAttr({z,y}));
-  // clang-format on
-  return variantH;
-}
-
-/// Structure to hold the parameters related to GPU reduction strategy.
-struct GPUReductionStrategyInfos {
-  std::array<int64_t, 3> workgroupSize;
-  SmallVector<int64_t> workgroupTileSizes;
-  SmallVector<int64_t> fillSecondTileSizes;
-  SmallVector<int64_t> genericSecondTileSizes;
-  bool hasLeadingEltwise;
-  bool hasTrailingEltwise;
-};
-
-/// Returns a triple of handles: the leading elementwise operation, the
-/// reduction operation and the fusion root. The leading elementwise and the
-/// fusion root may be null. If the fusion root is null, the reduction operation
-/// should be used as fusion root instead.
-// TODO: consider passing a problem-specific struct to control information.
-static std::tuple<Value, Value, Value>
-createMatchReductionBlockDistributionHandles(ImplicitLocOpBuilder &b,
-                                             Value variantH,
-                                             bool hasLeadingEltwise,
-                                             bool hasTrailingEltwise) {
-  Value originalGenericH =
-      b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
-  auto op = b.create<SplitHandlesOp>(
-      originalGenericH,
-      /*numResultHandles=*/1 + hasLeadingEltwise + hasTrailingEltwise);
-  return std::make_tuple(hasLeadingEltwise ? op.getResults().front() : Value(),
-                         op.getResults().drop_front(hasLeadingEltwise).front(),
-                         hasTrailingEltwise ? op.getResults().back() : Value());
-}
-
-// TODO: generalize and automate over and over.
-// TODO: significantly shrink this down.
-// TODO: consider passing a problem-specific struct to control information.
-static void createReductionCudaStrategy(
-    ImplicitLocOpBuilder &b, Value variantH,
-    const GPUReductionStrategyInfos &infos) {
-  // Step 0. Match the ops.
-  Value originalFillH =
-      b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
-  auto [leadingH, reductionH, fusionRootH] =
-      createMatchReductionBlockDistributionHandles(
-          b, variantH, infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
-  // Step 1: Distribute to blocks using the current IREE lowering config.
-  variantH = createReductionStrategyBlockDistributionPart(
-      b, variantH, originalFillH, reductionH, fusionRootH,
-      getAsOpFoldResult(b.getI64ArrayAttr(infos.workgroupTileSizes)),
-      infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
-  // Step 2. Second level of tiling + fusion parallelizes to threads.
-  variantH = createReductionStrategyThreadDistributionPart(
-      b, variantH, infos.fillSecondTileSizes, infos.genericSecondTileSizes,
-      infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
-  // Step 3. Rank-reduce and vectorize.
-  // TODO: assumes a single func::FuncOp to transform, may need hardening.
-  Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  funcH = buildVectorizeStrategy(b, funcH);
-
-  // Step 4. Bufferize and drop HAL decriptor from memref ops.
-  variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
-  Value memrefFunc =
-      b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
-
-  // Step 5. Post-bufferization mapping to blocks and threads.
-  // Need to match again since bufferize invalidated all handles.
-  // TODO: assumes a single func::FuncOp to transform, may need hardening.
-  funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  funcH = buildMapToBlockAndThreads(b, funcH, infos.workgroupSize);
-
-  // Step 6. Post-bufferization vector distribution with rank-reduction.
-  buildDistributeVectors(b, variantH, funcH);
-}
-
-// TODO: consider passing a problem-specific struct to control information.
-static bool matchGPUReduction(linalg::LinalgOp op,
-                              GPUReductionStrategyInfos &info) {
-  // TODO: match the sequence the strategy supports.
-  StructuredOpMatcher pattern, fill, leadingEltwise, trailingEltwise;
-  makeGPUReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
-  if (!matchPattern(op, pattern)) return false;
-
-  info.hasLeadingEltwise = leadingEltwise.getCaptured() != nullptr;
-  info.hasTrailingEltwise = trailingEltwise.getCaptured() != nullptr;
-
-  // Hardcoded workagroup size, this could be deduced from the reduction dim.
-  info.workgroupSize = {32, 2, 1};
-  SmallVector<unsigned> partitionedLoops =
-      cast<PartitionableLoopsInterface>(op.getOperation())
-          .getPartitionableLoops(kNumMaxParallelDims);
-  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
-  // Tile all the parallel dimension to 1.
-  info.workgroupTileSizes.append(numLoops, 1);
-  info.fillSecondTileSizes = {1, 0, 0};
-  info.genericSecondTileSizes = {1, 1, 0};
-  return true;
-}
-
-/// Structure to hold the parameters related to GPU reduction strategy.
-struct CPUReductionStrategyInfos {
-  int64_t workgroupSize;
-  SmallVector<int64_t> tileSizes;
-};
-
-static bool matchCPUReduction(linalg::LinalgOp op,
-                              CPUReductionStrategyInfos &infos) {
-  // TODO: match the sequence the strategy supports.
-  auto fill = m_StructuredOp<linalg::FillOp>();
-  auto pattern = m_StructuredOp()
-                     .dim(AllDims(), ShapeKind::Static)
-                     .dim(-1, utils::IteratorType::reduction)
-                     .output(NumEqualsTo(1))
-                     .output(0, fill);
-
-  // TODO: set the right config as expected by the strategy.
-  infos.workgroupSize = 1;
-  SmallVector<unsigned> partitionedLoops =
-      cast<PartitionableLoopsInterface>(op.getOperation())
-          .getPartitionableLoops(kNumMaxParallelDims);
-  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
-  // Tile all the parallel dimension to 1.
-  infos.tileSizes.append(numLoops, 1);
-  return true;
-}
-
-using StrategyBuilderFn = std::function<void(ImplicitLocOpBuilder &, Value)>;
-
-static void createTransformRegion(func::FuncOp entryPoint,
-                                  StrategyBuilderFn buildStrategy) {
+void mlir::iree_compiler::createTransformRegion(
+    func::FuncOp entryPoint, StrategyBuilderFn buildStrategy) {
   MLIRContext *ctx = entryPoint.getContext();
   Location loc = entryPoint.getLoc();
   OpBuilder b(ctx);
@@ -524,71 +391,3 @@
                     << "\n");
   LLVM_DEBUG(sequence.print(DBGS()));
 }
-
-// TODO: generalize and automate over and over.
-// TODO: significantly shrink this down.
-static LogicalResult createReductionCpuStrategy(
-    ImplicitLocOpBuilder &b, Value variantH,
-    const CPUReductionStrategyInfos &info) {
-  // Step 0. Fetch transform information from the config and materialize it in
-  // the payload IR.
-  // TODO: this still requires specific knowledge of ops present in the IR
-  // and is very brittle.
-  Value originalFillH =
-      b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
-  Value originalGenericH =
-      b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
-
-  // Step 1: Distribute to blocks using the current IREE lowering config.
-  variantH = createReductionStrategyBlockDistributionPart(
-      b, variantH, originalFillH, originalGenericH, Value(),
-      getAsOpFoldResult(b.getI64ArrayAttr(info.tileSizes)));
-
-  // Step 2. Rank-reduce and buildVectorizeStrategy.
-  // TODO: assumes a single func::FuncOp to transform, may need hardening.
-  Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  funcH = buildVectorizeStrategy(b, funcH);
-
-  // Step 3. Bufferize and drop HAL decriptor from memref ops.
-  variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
-  Value memrefFunc =
-      b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
-
-  // Step 4. Post-bufferization mapping to blocks only.
-  // Need to match again since bufferize invalidated all handles.
-  // TODO: assumes a single func::FuncOp to transform, may need hardening.
-  funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
-
-  return success();
-}
-
-LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
-                                                       linalg::LinalgOp op) {
-  // 1. Match
-  GPUReductionStrategyInfos infos;
-  if (!matchGPUReduction(op, infos)) return failure();
-  auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
-    return createReductionCudaStrategy(b, variant, infos);
-  };
-  // 2. Add the strategy.
-  createTransformRegion(entryPoint, strategyBuilder);
-  return success();
-}
-
-LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
-                                                       linalg::LinalgOp op) {
-  // 1. Match
-  CPUReductionStrategyInfos infos;
-  if (!matchCPUReduction(op, infos)) return failure();
-  auto startegyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
-    return createReductionCpuStrategy(b, variant, infos);
-  };
-  // 2. Add the strategy.
-  createTransformRegion(entryPoint, startegyBuilder);
-  return success();
-}
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
index 83ddb85..c7ca9fb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
@@ -3,7 +3,8 @@
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_GPUTRANSFORMDIALECT_STRATEGIES_H_
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_H_
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -12,15 +13,95 @@
 namespace mlir {
 namespace iree_compiler {
 
-/// Return success if the IR matches what the GPU reduction strategy can handle.
-/// If it is success it will append the transform dialect after the entry point
-/// module.
-LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
-                                                       linalg::LinalgOp op);
+//===----------------------------------------------------------------------===//
+// Low-level reusable builder APIs, these should follow MLIR-style builders.
+//===----------------------------------------------------------------------===//
 
-LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
-                                                       linalg::LinalgOp op);
+/// Prints `handles` in order. Prints the whole IR if `handles` is empty.
+static void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {});
+
+/// Performs the following transformations:
+///   1. Tiles `rootH` to scf.foreach_thread to with `tileSizesOrNumThreads`
+///      according to whether spec is a TileSizesSpec or a NumThreadsSpec.
+///   2. Maps the resulting scf.foreach_thread to threads according to
+///      `threadDimMapping`.
+///   3. Iterates over `opsHToFuse` in order and fuses into the containing op.
+/// Returns a handle to the resulting scf.foreach_thread.
+///
+/// Fusion operates in batch mode: a single fusion command is issued and a
+/// topological sort is automatically computed by the fusion.
+/// Since this applies a single fusion, no interleaved canonicalization / cse /
+/// enabling transformation occurs and the resulting fusion may not be as good.
+///
+/// In the future, an iterative mode in which the user is responsible for
+/// providing the fusion order and has interleaved canonicalization / cse /
+/// enabling transform will be introduced and may result in better fusions.
+///
+/// If `resultingFusedOpsHandles` is a non-null pointer, the fused operation are
+/// appended in order.
+///
+// TODO: if someone knows how to properly export templates go for it .. sigh.
+Value buildTileFuseDistToForeachThreadWithTileSizes(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+Value buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+
+/// See buildTileFuseDistWithTileSizes.
+Value buildTileFuseDistToForeachThreadWithNumThreads(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+Value buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+    SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+
+/// Apply patterns and vectorize (for now always applies rank-reduction).
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildVectorize(ImplicitLocOpBuilder &b, Value funcH);
+
+/// Bufferize and drop HAL decriptor from memref ops.
+/// Takes a handle variantOp and returns a handle to the same variant op.
+Value buildBufferize(ImplicitLocOpBuilder &b, Value variantH,
+                     bool targetGpu = false);
+
+/// Post-bufferization mapping to blocks and threads.
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH,
+                                ArrayRef<int64_t> blockSize);
+
+static constexpr unsigned kCudaWarpSize = 32;
+
+/// Post-bufferization vector distribution with rank-reduction.
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildDistributeVectors(ImplicitLocOpBuilder &b, Value variantH,
+                             Value funcH, int64_t warpSize = kCudaWarpSize);
+
+using StrategyBuilderFn = std::function<void(ImplicitLocOpBuilder &, Value)>;
+
+void createTransformRegion(func::FuncOp entryPoint,
+                           StrategyBuilderFn buildStrategy);
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+/// Distribute to blocks using the current IREE lowering config.
+// TODO: consider passing a problem-specific struct to control information.
+Value createReductionStrategyBlockDistributionPart(
+    ImplicitLocOpBuilder &b, Value variantH, Value originalFillH,
+    Value reductionH, Value optionalFusionRootH,
+    ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise = false,
+    bool hasTrailingEltwise = false);
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
-#endif  // IREE_COMPILER_CODEGEN_LLVMGPU_GPUTRANSFORMDIALECT_STRATEGIES_H_
+#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp
new file mode 100644
index 0000000..64b3d09
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp
@@ -0,0 +1,178 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h"
+
+#include <numeric>
+#include <type_traits>
+
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
+#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "iree-transform-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+// TODO: significantly better namespacing.
+using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
+using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
+using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
+using iree_compiler::IREE::transform_dialect::
+    IREEEraseHALDescriptorTypeFromMemRefOp;
+using iree_compiler::IREE::transform_dialect::
+    MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::
+    TileToForeachThreadAndWorkgroupCountRegionOp;
+using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
+using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
+using transform::FuseIntoContainingOp;
+using transform::MatchOp;
+using transform::MergeHandlesOp;
+using transform::PrintOp;
+using transform::SequenceOp;
+using transform::SplitHandlesOp;
+using transform::SplitReductionOp;
+using transform::TileToForeachThreadOp;
+using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
+
+/// Matches `args` within `targetH` and unpacks a number of handles `N`.
+/// Assumes there are exactly `N` matched ops (but could be relaxed).
+/// Returns the tuple of handles.
+template <int N, typename... MatchingArgs>
+auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
+                    MatchingArgs... args) {
+  Value matchedH = b.create<MatchOp>(targetH, args...);
+  auto matchOp = b.create<SplitHandlesOp>(matchedH,
+                                          /*numHandles=*/N);
+  assert(matchOp->getNumResults() == N && "Unexpected number of results");
+  std::array<Value, N> a;
+  for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
+  return std::tuple_cat(a);
+}
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+
+// TODO: consider lifting and exposing.
+
+/// Structure to hold the parameters related to GPU reduction strategy.
+struct CPUReductionStrategyInfos {
+  int64_t workgroupSize;
+  SmallVector<int64_t> tileSizes;
+};
+
+static bool matchCPUReduction(linalg::LinalgOp op,
+                              CPUReductionStrategyInfos &infos) {
+  // TODO: match the sequence the strategy supports.
+  auto fill = m_StructuredOp<linalg::FillOp>();
+  auto pattern = m_StructuredOp()
+                     .dim(AllDims(), ShapeKind::Static)
+                     .dim(-1, utils::IteratorType::reduction)
+                     .output(NumEqualsTo(1))
+                     .output(0, fill);
+
+  // TODO: set the right config as expected by the strategy.
+  infos.workgroupSize = 1;
+  SmallVector<unsigned> partitionedLoops =
+      cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
+          .getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
+  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+  // Tile all the parallel dimension to 1.
+  infos.tileSizes.append(numLoops, 1);
+  return true;
+}
+
+// TODO: generalize and automate over and over.
+// TODO: significantly shrink this down.
+static LogicalResult createReductionCpuStrategy(
+    ImplicitLocOpBuilder &b, Value variantH,
+    const CPUReductionStrategyInfos &info) {
+  // Step 0. Fetch transform information from the config and materialize it in
+  // the payload IR.
+  // TODO: this still requires specific knowledge of ops present in the IR
+  // and is very brittle.
+  Value originalFillH =
+      b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
+  Value originalGenericH =
+      b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
+
+  // Step 1: Distribute to blocks using the current IREE lowering config.
+  variantH = iree_compiler::createReductionStrategyBlockDistributionPart(
+      b, variantH, originalFillH, originalGenericH, Value(),
+      getAsOpFoldResult(b.getI64ArrayAttr(info.tileSizes)));
+
+  // Step 2. Rank-reduce and buildVectorize.
+  // TODO: assumes a single func::FuncOp to transform, may need hardening.
+  Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  funcH = iree_compiler::buildVectorize(b, funcH);
+
+  // Step 3. Bufferize and drop HAL descriptor from memref ops.
+  variantH = iree_compiler::buildBufferize(b, variantH, /*targetGpu=*/true);
+
+  // Step 4. Post-bufferization mapping to blocks only.
+  // Need to match again since bufferize invalidated all handles.
+  // TODO: assumes a single func::FuncOp to transform, may need hardening.
+  funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
+
+  return success();
+}
+
+LogicalResult iree_compiler::matchAndSetCPUReductionTransformStrategy(
+    func::FuncOp entryPoint, linalg::LinalgOp op) {
+  // 1. Match
+  CPUReductionStrategyInfos infos;
+  if (!matchCPUReduction(op, infos)) return failure();
+  auto startegyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
+    return createReductionCpuStrategy(b, variant, infos);
+  };
+  // 2. Add the strategy.
+  createTransformRegion(entryPoint, startegyBuilder);
+  return success();
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h
new file mode 100644
index 0000000..965a366
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h
@@ -0,0 +1,23 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_CPU_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+/// Return success if the IR matches what the GPU reduction strategy can handle.
+/// If it is success it will append the transform dialect after the entry point
+/// module.
+LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
+                                                       linalg::LinalgOp op);
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_CPU_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
new file mode 100644
index 0000000..c4e119d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
@@ -0,0 +1,249 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h"
+
+#include <numeric>
+#include <type_traits>
+
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
+#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "iree-transform-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+// TODO: significantly better namespacing.
+using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
+using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
+using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
+using iree_compiler::IREE::transform_dialect::
+    IREEEraseHALDescriptorTypeFromMemRefOp;
+using iree_compiler::IREE::transform_dialect::
+    MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::
+    TileToForeachThreadAndWorkgroupCountRegionOp;
+using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
+using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
+using transform::FuseIntoContainingOp;
+using transform::MatchOp;
+using transform::MergeHandlesOp;
+using transform::PrintOp;
+using transform::SequenceOp;
+using transform::SplitHandlesOp;
+using transform::SplitReductionOp;
+using transform::TileToForeachThreadOp;
+using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
+
+/// Matches `args` within `targetH` and unpacks a number of handles `N`.
+/// Assumes there are exactly `N` matched ops (but could be relaxed).
+/// Returns the tuple of handles.
+template <int N, typename... MatchingArgs>
+auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
+                    MatchingArgs... args) {
+  Value matchedH = b.create<MatchOp>(targetH, args...);
+  auto matchOp = b.create<SplitHandlesOp>(matchedH,
+                                          /*numHandles=*/N);
+  assert(matchOp->getNumResults() == N && "Unexpected number of results");
+  std::array<Value, N> a;
+  for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
+  return std::tuple_cat(a);
+}
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+
+// TODO: consider passing a problem-specific struct to control information.
+static Value createReductionStrategyThreadDistributionPart(
+    ImplicitLocOpBuilder &b, Value variantH, ArrayRef<int64_t> tileSizes1Fill,
+    ArrayRef<int64_t> tileSizes1Generic, bool hasLeadingEltwise,
+    bool hasTrailingEltwise) {
+  // TODO: Relying on ordering is brittle, harden this.
+  Value matchedH = b.create<MatchOp>(
+      variantH, ArrayRef<StringRef>{linalg::GenericOp::getOperationName(),
+                                    linalg::FillOp::getOperationName()});
+  auto split = b.create<SplitHandlesOp>(
+      matchedH,
+      /*numResultHandles=*/4 + hasLeadingEltwise + hasTrailingEltwise);
+  Value firstFusionRootH = split.getResults()[1 + hasLeadingEltwise];
+  SmallVector<Value> firstFusionGroupHs =
+      split.getResults().take_front(1 + hasLeadingEltwise);
+  Value secondFusionRootH = split.getResults().back();
+  SmallVector<Value> secondFusionGroupHs =
+      split.getResults().drop_front(2 + hasLeadingEltwise).drop_back();
+
+  auto z = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+                                                ::mlir::gpu::Threads::DimZ);
+  auto y = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+                                                ::mlir::gpu::Threads::DimY);
+
+  // clang-format off
+  iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
+                   /*rootH=*/secondFusionRootH,
+                   /*opsHToFuse=*/secondFusionGroupHs,
+                   /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Fill)),
+                   /*threadDimMapping=*/b.getArrayAttr({z}));
+  iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
+                   /*rootH=*/firstFusionRootH,
+                   /*opsHToFuse=*/firstFusionGroupHs,
+                   /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Generic)),
+                   /*threadDimMapping=*/b.getArrayAttr({z,y}));
+  // clang-format on
+  return variantH;
+}
+
+/// Structure to hold the parameters related to GPU reduction strategy.
+struct GPUReductionStrategyInfos {
+  std::array<int64_t, 3> workgroupSize;
+  SmallVector<int64_t> workgroupTileSizes;
+  SmallVector<int64_t> fillSecondTileSizes;
+  SmallVector<int64_t> genericSecondTileSizes;
+  bool hasLeadingEltwise;
+  bool hasTrailingEltwise;
+};
+
+/// Returns a triple of handles: the leading elementwise operation, the
+/// reduction operation and the fusion root. The leading elementwise and the
+/// fusion root may be null. If the fusion root is null, the reduction operation
+/// should be used as fusion root instead.
+// TODO: consider passing a problem-specific struct to control information.
+static std::tuple<Value, Value, Value>
+createMatchReductionBlockDistributionHandles(ImplicitLocOpBuilder &b,
+                                             Value variantH,
+                                             bool hasLeadingEltwise,
+                                             bool hasTrailingEltwise) {
+  Value originalGenericH =
+      b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
+  auto op = b.create<SplitHandlesOp>(
+      originalGenericH,
+      /*numResultHandles=*/1 + hasLeadingEltwise + hasTrailingEltwise);
+  return std::make_tuple(hasLeadingEltwise ? op.getResults().front() : Value(),
+                         op.getResults().drop_front(hasLeadingEltwise).front(),
+                         hasTrailingEltwise ? op.getResults().back() : Value());
+}
+
+// TODO: generalize and automate over and over.
+// TODO: significantly shrink this down.
+// TODO: consider passing a problem-specific struct to control information.
+static void createReductionCudaStrategy(
+    ImplicitLocOpBuilder &b, Value variantH,
+    const GPUReductionStrategyInfos &infos) {
+  // Step 0. Match the ops.
+  Value originalFillH =
+      b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
+  auto [leadingH, reductionH, fusionRootH] =
+      createMatchReductionBlockDistributionHandles(
+          b, variantH, infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+  // Step 1: Distribute to blocks using the current IREE lowering config.
+  variantH = iree_compiler::createReductionStrategyBlockDistributionPart(
+      b, variantH, originalFillH, reductionH, fusionRootH,
+      getAsOpFoldResult(b.getI64ArrayAttr(infos.workgroupTileSizes)),
+      infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+  // Step 2. Second level of tiling + fusion parallelizes to threads.
+  variantH = createReductionStrategyThreadDistributionPart(
+      b, variantH, infos.fillSecondTileSizes, infos.genericSecondTileSizes,
+      infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+  // Step 3. Rank-reduce and vectorize.
+  // TODO: assumes a single func::FuncOp to transform, may need hardening.
+  Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  funcH = iree_compiler::buildVectorize(b, funcH);
+
+  // Step 4. Bufferize and drop HAL descriptor from memref ops.
+  variantH = iree_compiler::buildBufferize(b, variantH, /*targetGpu=*/true);
+
+  // Step 5. Post-bufferization mapping to blocks and threads.
+  // Need to match again since bufferize invalidated all handles.
+  // TODO: assumes a single func::FuncOp to transform, may need hardening.
+  funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  funcH =
+      iree_compiler::buildMapToBlockAndThreads(b, funcH, infos.workgroupSize);
+
+  // Step 6. Post-bufferization vector distribution with rank-reduction.
+  iree_compiler::buildDistributeVectors(b, variantH, funcH);
+}
+
+// TODO: consider passing a problem-specific struct to control information.
+static bool matchGPUReduction(linalg::LinalgOp op,
+                              GPUReductionStrategyInfos &info) {
+  // TODO: match the sequence the strategy supports.
+  StructuredOpMatcher pattern, fill, leadingEltwise, trailingEltwise;
+  makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+  if (!matchPattern(op, pattern)) return false;
+
+  info.hasLeadingEltwise = leadingEltwise.getCaptured() != nullptr;
+  info.hasTrailingEltwise = trailingEltwise.getCaptured() != nullptr;
+
+  // Hardcoded workagroup size, this could be deduced from the reduction dim.
+  info.workgroupSize = {32, 2, 1};
+  SmallVector<unsigned> partitionedLoops =
+      cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
+          .getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
+  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+  // Tile all the parallel dimension to 1.
+  info.workgroupTileSizes.append(numLoops, 1);
+  info.fillSecondTileSizes = {1, 0, 0};
+  info.genericSecondTileSizes = {1, 1, 0};
+  return true;
+}
+
+LogicalResult iree_compiler::matchAndSetGPUReductionTransformStrategy(
+    func::FuncOp entryPoint, linalg::LinalgOp op) {
+  // 1. Match
+  GPUReductionStrategyInfos infos;
+  if (!matchGPUReduction(op, infos)) return failure();
+  auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
+    return createReductionCudaStrategy(b, variant, infos);
+  };
+  // 2. Add the strategy.
+  createTransformRegion(entryPoint, strategyBuilder);
+  return success();
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h
new file mode 100644
index 0000000..f83d0d6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h
@@ -0,0 +1,25 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_GPU_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Return success if the IR matches what the GPU reduction strategy can handle.
+/// If it is success it will append the transform dialect after the entry point
+/// module.
+LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
+                                                       linalg::LinalgOp op);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_GPU_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
index 2b4d3e8..41376fd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -57,12 +57,10 @@
     srcs = [
         "CommonExtensions.cpp",
         "CommonExtensionsOps.cpp.inc",
-        "TransformMatchers.cpp",
     ],
     hdrs = [
         "CommonExtensions.h",
         "CommonExtensionsOps.h.inc",
-        "TransformMatchers.h",
     ],
     deps = [
         ":CommonExtensionsOpGen",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 2fd84a7..04ee20b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -26,11 +26,9 @@
   HDRS
     "CommonExtensions.h"
     "CommonExtensionsOps.h.inc"
-    "TransformMatchers.h"
   SRCS
     "CommonExtensions.cpp"
     "CommonExtensionsOps.cpp.inc"
-    "TransformMatchers.cpp"
   DEPS
     ::CommonExtensionsOpGen
     IREEDialectsTransforms
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 2ecd1a2..1e1e953 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -10,7 +10,7 @@
 #include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
 #include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Passes.h"
@@ -1038,229 +1038,5 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-//===---------------------------------------------------------------------===//
-// RegisterMatchCallbacksOp
-//===---------------------------------------------------------------------===//
-
-/// Match callback for "_test_match_callback" hook. Matches any payload
-/// operations associated with operand handles unless they have the
-/// "test.iree_transform_do_not_match" attribute, in which case produces a
-/// silenceable failure.
-static DiagnosedSilenceableFailure testMatchCallbackCallback(
-    transform_dialect::MatchCallbackResult &res, Location loc,
-    const transform::TransformState &state, ValueRange handles) {
-  bool hadFailures = false;
-  for (Value handle : handles) {
-    if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) {
-          return op->hasAttr("test.iree_transform_do_not_match");
-        })) {
-      res.addPayloadGroup(ArrayRef<Operation *>());
-      hadFailures = true;
-    } else {
-      res.addPayloadGroup(state.getPayloadOps(handle));
-    }
-  }
-  if (hadFailures) return emitSilenceableFailure(loc) << "failed to match";
-  return DiagnosedSilenceableFailure::success();
-}
-
-/// Match callback for a reduction with optional leading and trailing
-/// elementwise operations. Matches *the first* occurrence of such a reduction
-/// within an op associated with the given handle.
-///
-/// Input handles:
-///
-///   - container op, must be associated with one operation.
-///
-/// Output handles:
-///
-///   - leading elementwise op, if any;
-///   - the "fill" op preceding the reduction;
-///   - reduction op;
-///   - trailing elementwise op, if any.
-static DiagnosedSilenceableFailure reductionCallback(
-    transform_dialect::MatchCallbackResult &res, Location loc,
-    const transform::TransformState &state, ValueRange handles) {
-  if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
-    return emitSilenceableFailure(loc)
-           << "expected one handle to one operation";
-  }
-
-  transform_dialect::StructuredOpMatcher pattern, fill, leadingEltwise,
-      trailingEltwise;
-  makeGPUReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
-
-  // TODO: need a mechanism for this to go around the entire IR,
-  // potentially with list matches for each group.
-  Operation *root = state.getPayloadOps(handles[0])[0];
-  WalkResult walkResult = root->walk([&](Operation *op) {
-    pattern.resetCapture();
-    if (!matchPattern(op, pattern)) return WalkResult::advance();
-
-    res.addPotentiallyEmptyPayloadGroup(leadingEltwise.getCaptured());
-    res.addPayloadGroup({fill.getCaptured()});
-    res.addPayloadGroup({pattern.getCaptured()});
-    res.addPotentiallyEmptyPayloadGroup(trailingEltwise.getCaptured());
-    return WalkResult::interrupt();
-  });
-
-  if (walkResult.wasInterrupted())
-    return DiagnosedSilenceableFailure::success();
-  return emitSilenceableFailure(loc) << "failed to match";
-}
-
-/// Match callback for a reduction after splitting with optional leading and
-/// trailing elementwise operations. Matches *the first* occurrence of such a
-/// reduction within an op associated with the given handle.
-///
-/// Input handles:
-///
-///   - container op, must be associated with one operation.
-///
-/// Output handles:
-///
-///   - leading elementwise op, if any;
-///   - the "fill" op preceding the original reduction;
-///   - the "fill" op preceding the split, more parallel reduction;
-///   - the split, more parallel reduction op;
-///   - reduction op;
-///   - trailing elementwise op, if any.
-static DiagnosedSilenceableFailure splitReductionCallback(
-    transform_dialect::MatchCallbackResult &res, Location loc,
-    const transform::TransformState &state, ValueRange handles) {
-  if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
-    return emitSilenceableFailure(loc)
-           << "expected one handle to one operation";
-  }
-
-  transform_dialect::StructuredOpMatcher parallel_reduction, combiner_reduction,
-      parallel_fill, original_fill, leading, trailing;
-  makeGPUSplitReductionMatcher(parallel_reduction, combiner_reduction,
-                               parallel_fill, original_fill, leading, trailing);
-
-  // TODO: need a mechanism for this to go around the entire IR,
-  // potentially with list matches for each group.
-  Operation *root = state.getPayloadOps(handles[0])[0];
-  WalkResult walkResult = root->walk([&](Operation *op) {
-    combiner_reduction.resetCapture();
-    if (!matchPattern(op, combiner_reduction)) return WalkResult::advance();
-
-    res.addPotentiallyEmptyPayloadGroup(leading.getCaptured());
-    res.addPayloadGroup({original_fill.getCaptured()});
-    res.addPayloadGroup({parallel_fill.getCaptured()});
-    res.addPayloadGroup({parallel_reduction.getCaptured()});
-    res.addPayloadGroup({combiner_reduction.getCaptured()});
-    res.addPotentiallyEmptyPayloadGroup(trailing.getCaptured());
-    return WalkResult::interrupt();
-  });
-
-  if (walkResult.wasInterrupted())
-    return DiagnosedSilenceableFailure::success();
-  return emitSilenceableFailure(loc) << "failed to match";
-}
-
-DiagnosedSilenceableFailure transform_dialect::RegisterMatchCallbacksOp::apply(
-    transform::TransformResults &results, transform::TransformState &state) {
-  auto &registry = state.addExtension<MatchCallbacksRegistry>();
-  registry.registerCallback("_test_match_callback", testMatchCallbackCallback);
-  registry.registerCallback("reduction", reductionCallback);
-  registry.registerCallback("split_reduction", splitReductionCallback);
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::RegisterMatchCallbacksOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  // TODO: it doesn't really modify the payload, we need a separate resource for
-  // this mapping.
-  transform::modifiesPayload(effects);
-}
-
-//===---------------------------------------------------------------------===//
-// MatchCallbackOp
-//===---------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure transform_dialect::MatchCallbackOp::apply(
-    transform::TransformResults &results, transform::TransformState &state) {
-  auto setEmptyResults = [&results, this] {
-    for (OpResult value : getResults()) {
-      results.set(value, {});
-    }
-  };
-  auto errorOut = [this, &setEmptyResults] {
-    setEmptyResults();
-    return emitSilenceableError();
-  };
-
-  auto *registry = state.getExtension<MatchCallbacksRegistry>();
-  if (!registry) return errorOut() << "match registry not available";
-
-  const MatchCallbacksRegistry::MatchCallbackFn *callback =
-      registry->get(getCallbackName());
-  if (!callback) {
-    return errorOut() << "callback '" << getCallbackName()
-                      << "' not found in the registry";
-  }
-
-  MatchCallbackResult result;
-  DiagnosedSilenceableFailure status =
-      (*callback)(result, getLoc(), state, getInputs());
-  if (!status.succeeded()) {
-    setEmptyResults();
-    if (status.isDefiniteFailure()) return status;
-    if (getFailurePropagationMode() ==
-        transform::FailurePropagationMode::Propagate) {
-      return emitSilenceableError() << "failed to match";
-    } else {
-      return DiagnosedSilenceableFailure::success();
-    }
-  }
-  if (getNumResults() != result.getNumPayloadGroups()) {
-    return errorOut()
-           << "callback produced a different number of handles than expected ( "
-           << result.getNumPayloadGroups() << " vs " << getNumResults() << " )";
-  }
-
-  for (OpResult value : getResults()) {
-    results.set(value, result.getPayloadGroup(value.getResultNumber()));
-  }
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::MatchCallbackOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getInputs(), effects);
-  transform::producesHandle(getOutputs(), effects);
-  // TODO: it doesn't really modify the payload, we need a separate resource for
-  // this mapping.
-  transform::modifiesPayload(effects);
-}
-
-DiagnosedSilenceableFailure transform_dialect::TakeFirstOp::apply(
-    transform::TransformResults &results, transform::TransformState &state) {
-  SmallVector<Operation *> concatenated;
-  bool found = false;
-  for (Value handle : getInputs()) {
-    ArrayRef<Operation *> payloads = state.getPayloadOps(handle);
-    if (payloads.empty()) continue;
-    if (!found) {
-      results.set(getFirst().cast<OpResult>(), payloads);
-      found = true;
-    } else {
-      llvm::append_range(concatenated, payloads);
-    }
-  }
-
-  if (!found) results.set(getFirst().cast<OpResult>(), {});
-  results.set(getRest().cast<OpResult>(), concatenated);
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::TakeFirstOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getInputs(), effects);
-  transform::producesHandle(getFirst(), effects);
-  transform::producesHandle(getRest(), effects);
-}
-
 #define GET_OP_CLASSES
 #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index f012f51..8f7cf12 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -323,91 +323,4 @@
   }];
 }
 
-def RegisterMatchCallbacksOp :
-    Op<Transform_Dialect, "iree.register_match_callbacks",
-      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-       DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let description = [{
-    Registers named structured op matcher callbacks specific for IREE to use
-    with `transform.iree.match_callback`. This should be called before first
-    `match_callback` may be executed following the transform dialect control
-    flow.
-
-    The callbacks must have a unique name and a signature compatible with
-    `MatchCallbacksRegistry::MatchCallbackFn`, which currently means
-    `DiagnosedSilenceableFailure(MatchCallbackResult &, Location,
-     const TransformState &, ValueRange)`. The callback receives a "result",
-     followed by a location at which errors should be reported, a transform
-     state at the moment of the _match_ (not registration) and a list of
-     handle values passed as operands to the `match_callback` operation.
-     It is expected to populate the "result" object with lists of payload
-     operations that will be bound to the handles produced by the
-     `match_callback` operation. The callback may fail, at which point
-     it should produce a silenceable error. The callback currently is not
-     allowed to modify the payload IR (though this may be revised in the
-     future for the purpose of communicating the properties of the IR
-     captured by the match). Therefore, it should not have a reason to
-     produce a definite error.
-  }];
-
-  let arguments = (ins);
-  let results = (outs);
-  let assemblyFormat = "attr-dict";
-  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
-def MatchCallbackOp :
-    Op<Transform_Dialect, "iree.match_callback",
-       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-        DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let description = [{
-    Performs payload IR matching using a C++ callback registered beforehand.
-    The callback is identified by name and is passed the current transform
-    state and the list of handle operands, along with information necessary
-    for error propagation. See `register_match_callbacks` for the description
-    of the callback contract.
-
-    If `failure_propagation_mode` is set to `suppress`, any silenceable errors
-    in the callback (typically, "failure to match") will be ignored and the
-    resulting handles will be associated with empty lists of payload
-    operations. Otherwise, silenceable failures are propagated.
-  }];
-
-  let arguments = (ins StrAttr:$callback_name,
-                       FailurePropagationMode:$failure_propagation_mode,
-                       Variadic<TransformTypeInterface>:$inputs);
-  let results = (outs Variadic<TransformTypeInterface>:$outputs);
-  let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` "
-                       "$callback_name `(` $inputs `)` attr-dict "
-                       "`:` functional-type($inputs, $outputs)";
-  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
-def TakeFirstOp :
-    Op<Transform_Dialect, "iree.take_first",
-       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-        DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let description = [{
-    Given an arbitrary list of handles associated with potentially empty lists
-    of payload operations, produces two new handles:
-
-      - a handle pointing to the same payload operations as the first operand
-        handle with a non-empty list of payload operations;
-      - a handle pointing to the concatenated list of payload operations
-        associated with any other handle.
-
-    Note that this does not perform any deduplication.
-
-    This operation is useful to select a single target after some potentially
-    unsuccessful matches.
-  }];
-
-  let arguments = (ins Variadic<TransformTypeInterface>:$inputs);
-  let results = (outs TransformTypeInterface:$first,
-                      TransformTypeInterface:$rest);
-  let assemblyFormat =
-      "$inputs attr-dict `:` functional-type($inputs, results)";
-  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
 #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 6dfeffd..2f4b4a3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -38,6 +38,7 @@
     deps = [
         "//compiler/src/iree/compiler/Codegen:PassHeaders",
         "//compiler/src/iree/compiler/Codegen/Common",
+        "//compiler/src/iree/compiler/Codegen/Common:TransformDialectJitterPass",
         "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
         "//compiler/src/iree/compiler/Codegen/Sandbox",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 17b2e78..fcd81f6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -81,6 +81,7 @@
     MLIRVectorToSCF
     MLIRVectorTransforms
     iree::compiler::Codegen::Common
+    iree::compiler::Codegen::Common::TransformDialectJitterPass
     iree::compiler::Codegen::Dialect::IREECodegenDialect
     iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
     iree::compiler::Codegen::PassHeaders
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 4ec9218..321020d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -10,7 +10,7 @@
 
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
-#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h"
 #include "iree/compiler/Codegen/Common/UserConfig.h"
 #include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
index 2243acf..b2d35d9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -40,6 +40,7 @@
     deps = [
         "//compiler/src/iree/compiler/Codegen:PassHeaders",
         "//compiler/src/iree/compiler/Codegen/Common",
+        "//compiler/src/iree/compiler/Codegen/Common:TransformDialectJitterPass",
         "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
         "//compiler/src/iree/compiler/Codegen/Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 7d2cb6b..1aa5a15 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -88,6 +88,7 @@
     MLIRVectorToSCF
     MLIRVectorTransforms
     iree::compiler::Codegen::Common
+    iree::compiler::Codegen::Common::TransformDialectJitterPass
     iree::compiler::Codegen::Dialect::IREECodegenDialect
     iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
     iree::compiler::Codegen::PassHeaders
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index cc3317c..6423570 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -10,7 +10,7 @@
 
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
-#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h"
 #include "iree/compiler/Codegen/Common/UserConfig.h"
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/LLVMGPU/TransposeUtils.h"
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 1821fd3..46b5217 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -145,8 +145,13 @@
     includes = ["include"],
     deps = [
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Rewrite",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
index 07a4cfd..954deaa 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -73,6 +73,7 @@
 #define GET_OP_CLASSES
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc"
 
+namespace mlir {
 namespace transform_ext {
 class StructuredTransformOpsExtension
     : public mlir::transform::TransformDialectExtension<
@@ -81,5 +82,6 @@
   StructuredTransformOpsExtension();
 };
 } // namespace transform_ext
+} // namespace mlir
 
 #endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
index d8bfe7c..e2b716f 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
@@ -55,7 +55,7 @@
     static ::llvm::StringRef getDefaultDialect() { return "transform"; }
   }];
 
-  let cppNamespace = "transform_ext";
+  let cppNamespace = "mlir::transform_ext";
   let hasVerifier = 1;
 }
 
@@ -67,7 +67,7 @@
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let description = [{Indicates that the entire module should be bufferized.}];
   let assemblyFormat = "attr-dict";
-  let cppNamespace = "transform_ext";
+  let cppNamespace = "mlir::transform_ext";
 }
 
 def LowerVectorsOp : Op<Transform_Dialect, "lower_vectors",
@@ -89,7 +89,7 @@
     );
 
   let assemblyFormat = "attr-dict";
-  let cppNamespace = "transform_ext";
+  let cppNamespace = "mlir::transform_ext";
 }
 
 def LowerToLLVMOp : Op<Transform_Dialect, "lower_to_llvm",
@@ -110,7 +110,95 @@
      DefaultValuedAttr<BoolAttr, "false">:$enable_async);
 
   let assemblyFormat = "attr-dict";
-  let cppNamespace = "transform_ext";
+  let cppNamespace = "mlir::transform_ext";
+}
+
+
+def RegisterMatchCallbacksOp :
+    Op<Transform_Dialect, "iree.register_match_callbacks",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Registers named structured op matcher callbacks specific for IREE to use
+    with `transform.iree.match_callback`. This should be called before first
+    `match_callback` may be executed following the transform dialect control
+    flow.
+
+    The callbacks must have a unique name and a signature compatible with
+    `MatchCallbacksRegistry::MatchCallbackFn`, which currently means
+    `DiagnosedSilenceableFailure(MatchCallbackResult &, Location,
+     const TransformState &, ValueRange)`. The callback receives a "result",
+     followed by a location at which errors should be reported, a transform
+     state at the moment of the _match_ (not registration) and a list of
+     handle values passed as operands to the `match_callback` operation.
+     It is expected to populate the "result" object with lists of payload
+     operations that will be bound to the handles produced by the
+     `match_callback` operation. The callback may fail, at which point
+     it should produce a silenceable error. The callback currently is not
+     allowed to modify the payload IR (though this may be revised in the
+     future for the purpose of communicating the properties of the IR
+     captured by the match). Therefore, it should not have a reason to
+     produce a definite error.
+  }];
+
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+  let cppNamespace = "mlir::transform_ext";
+}
+
+def MatchCallbackOp :
+    Op<Transform_Dialect, "iree.match_callback",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Performs payload IR matching using a C++ callback registered beforehand.
+    The callback is identified by name and is passed the current transform
+    state and the list of handle operands, along with information necessary
+    for error propagation. See `register_match_callbacks` for the description
+    of the callback contract.
+
+    If `failure_propagation_mode` is set to `suppress`, any silenceable errors
+    in the callback (typically, "failure to match") will be ignored and the
+    resulting handles will be associated with empty lists of payload
+    operations. Otherwise, silenceable failures are propagated.
+  }];
+
+  let arguments = (ins StrAttr:$callback_name,
+                       FailurePropagationMode:$failure_propagation_mode,
+                       Variadic<TransformTypeInterface>:$inputs);
+  let results = (outs Variadic<TransformTypeInterface>:$outputs);
+  let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` "
+                       "$callback_name `(` $inputs `)` attr-dict "
+                       "`:` functional-type($inputs, $outputs)";
+  let cppNamespace = "mlir::transform_ext";
+}
+
+def TakeFirstOp :
+    Op<Transform_Dialect, "iree.take_first",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Given an arbitrary list of handles associated with potentially empty lists
+    of payload operations, produces two new handles:
+
+      - a handle pointing to the same payload operations as the first operand
+        handle with a non-empty list of payload operations;
+      - a handle pointing to the concatenated list of payload operations
+        associated with any other handle.
+
+    Note that this does not perform any deduplication.
+
+    This operation is useful to select a single target after some potentially
+    unsuccessful matches.
+  }];
+
+  let arguments = (ins Variadic<TransformTypeInterface>:$inputs);
+  let results = (outs TransformTypeInterface:$first,
+                      TransformTypeInterface:$rest);
+  let assemblyFormat =
+      "$inputs attr-dict `:` functional-type($inputs, results)";
+  let cppNamespace = "mlir::transform_ext";
 }
 
 #endif // STRUCTURED_TRANSFORM_OPS_EXT
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
similarity index 92%
rename from compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 4e467d5..430d879 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -11,14 +11,15 @@
 #include <cstdint>
 #include <functional>
 
-#include "llvm/ADT/StringMap.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Matchers.h"
+#include "llvm/ADT/StringMap.h"
 
-namespace mlir::iree_compiler::IREE::transform_dialect {
+namespace mlir {
+namespace transform_ext {
 
 //===---------------------------------------------------------------------===//
 // StructuredOpMatcher and predicates.
@@ -93,7 +94,7 @@
 namespace detail {
 template <typename T>
 using has_reset_capture_t = decltype(std::declval<T>().resetCapture());
-}  // namespace detail
+} // namespace detail
 
 /// Structured op matcher with additional predicates attachable through the
 /// fluent, a.k.a. chainable, API. Note that public API must *not* accept
@@ -112,7 +113,7 @@
     predicates.push_back(std::move(firstPredicate));
   }
 
- public:
+public:
   /// Matches any structured operation, i.e., operation with LinalgOp interface.
   StructuredOpMatcher() {}
 
@@ -178,15 +179,18 @@
                           &operandMatcher](linalg::LinalgOp linalgOp) -> bool {
       int64_t transformedPosition =
           position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
-      if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+      if (transformedPosition >= linalgOp.getNumDpsInputs())
+        return false;
 
       Operation *definingOp = linalgOp.getDpsInputOperand(transformedPosition)
                                   ->get()
                                   .getDefiningOp();
-      if (!definingOp) return optional.value;
+      if (!definingOp)
+        return optional.value;
       // We MUST run the matcher at this point, even if the match is optional,
       // to allow for capture.
-      if (operandMatcher.match(definingOp)) return true;
+      if (operandMatcher.match(definingOp))
+        return true;
       return optional.value;
     });
     recordNestedMatcher(operandMatcher);
@@ -246,15 +250,18 @@
                           &operandMatcher](linalg::LinalgOp linalgOp) -> bool {
       int64_t transformedPosition =
           position >= 0 ? position : linalgOp.getNumDpsInits() + position;
-      if (transformedPosition >= linalgOp.getNumDpsInits()) return false;
+      if (transformedPosition >= linalgOp.getNumDpsInits())
+        return false;
 
       Operation *definingOp = linalgOp.getDpsInitOperand(transformedPosition)
                                   ->get()
                                   .getDefiningOp();
-      if (!definingOp) return optional.value;
+      if (!definingOp)
+        return optional.value;
       // We MUST run the matcher at this point, even if the match is optional,
       // to allow for capture.
-      if (operandMatcher.match(definingOp)) return true;
+      if (operandMatcher.match(definingOp))
+        return true;
       return optional.value;
     });
     recordNestedMatcher(operandMatcher);
@@ -276,7 +283,8 @@
                           position](linalg::LinalgOp linalgOp) -> bool {
       int64_t transformedPosition =
           position >= 0 ? position : linalgOp->getNumResults() + position;
-      if (transformedPosition >= linalgOp->getNumResults()) return false;
+      if (transformedPosition >= linalgOp->getNumResults())
+        return false;
 
       // We MUST run the matcher at this point, even if the match is optional,
       // to allow for capture.
@@ -305,10 +313,11 @@
   /// for optional nested predicates from the previous application.
   void resetCapture() {
     captured = nullptr;
-    for (const CaptureResetFn &fn : captureResetFns) fn();
+    for (const CaptureResetFn &fn : captureResetFns)
+      fn();
   }
 
- private:
+private:
   /// Informs the matcher that it has another, nested matcher. Practically,
   /// records the captured value cleanup function so it runs when required.
   template <typename T>
@@ -348,7 +357,7 @@
 /// transform operation. Conceptually, a list of lists of payload operations to
 /// be associated with each result handle.
 class MatchCallbackResult {
- public:
+public:
   /// Returns the number of lists of payload operations.
   unsigned getNumPayloadGroups() const { return payloadGroupLengths.size(); }
 
@@ -380,7 +389,7 @@
       addPayloadGroup(ArrayRef<Operation *>(op));
   }
 
- private:
+private:
   /// The flat list of all payload opreations. `payloadGroupLengths` can be used
   /// to compute the sublist that corresponds to one nested list.
   // TODO: if somebody implements such a flattened vector generically, use it.
@@ -391,7 +400,7 @@
 /// A transform state extension that maintains the mapping between callback
 /// names as strings usable in `match_callback` and their implementations.
 class MatchCallbacksRegistry : public transform::TransformState::Extension {
- public:
+public:
   using MatchCallbackFn = std::function<DiagnosedSilenceableFailure(
       MatchCallbackResult &, Location, const transform::TransformState &,
       ValueRange)>;
@@ -414,11 +423,12 @@
   /// name, or null if it is not present in the registry.
   const MatchCallbackFn *get(StringRef name) const {
     auto iter = callbacks.find(name);
-    if (iter == callbacks.end()) return nullptr;
+    if (iter == callbacks.end())
+      return nullptr;
     return &iter->getValue();
   }
 
- private:
+private:
   llvm::StringMap<MatchCallbackFn> callbacks;
 };
 
@@ -432,10 +442,10 @@
 ///
 /// where trailing and leading are elementwise operations whose presence is
 /// optional. Each matcher will capture the corresponding operation.
-void makeGPUReductionMatcher(StructuredOpMatcher &reduction,
-                             StructuredOpMatcher &fill,
-                             StructuredOpMatcher &leading,
-                             StructuredOpMatcher &trailing);
+void makeReductionMatcher(StructuredOpMatcher &reduction,
+                          StructuredOpMatcher &fill,
+                          StructuredOpMatcher &leading,
+                          StructuredOpMatcher &trailing);
 
 /// Creates a group of matchers for:
 ///
@@ -447,13 +457,14 @@
 /// where trailing and leading are elementwise operations whose presence is
 /// optional, and with subsetting ops potentially present on the operand use-def
 /// chains.
-void makeGPUSplitReductionMatcher(StructuredOpMatcher &parallel_reduction,
-                                  StructuredOpMatcher &combiner_reduction,
-                                  StructuredOpMatcher &parallel_fill,
-                                  StructuredOpMatcher &original_fill,
-                                  StructuredOpMatcher &leading,
-                                  StructuredOpMatcher &trailing);
+void makeSplitReductionMatcher(StructuredOpMatcher &parallel_reduction,
+                               StructuredOpMatcher &combiner_reduction,
+                               StructuredOpMatcher &parallel_fill,
+                               StructuredOpMatcher &original_fill,
+                               StructuredOpMatcher &leading,
+                               StructuredOpMatcher &trailing);
 
-}  // namespace mlir::iree_compiler::IREE::transform_dialect
+} // namespace transform_ext
+} // namespace mlir
 
-#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
+#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 4a4594a..576e626 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -62,7 +62,7 @@
   DialectRegistry registry;
   registry.addExtensions<
       mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
-      transform_ext::StructuredTransformOpsExtension>();
+      mlir::transform_ext::StructuredTransformOpsExtension>();
   ctx->appendDialectRegistry(registry);
 }
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index dc013b2..2afd477 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -12,6 +12,7 @@
 #include "iree-dialects/Transforms/Listener.h"
 #include "iree-dialects/Transforms/ListenerCSE.h"
 #include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
@@ -317,7 +318,7 @@
 // StructuredTransformOpsExtension
 //===----------------------------------------------------------------------===//
 
-transform_ext::StructuredTransformOpsExtension::
+mlir::transform_ext::StructuredTransformOpsExtension::
     StructuredTransformOpsExtension() {
   registerTransformOps<
 #define GET_OP_LIST
@@ -1115,3 +1116,244 @@
   // TODO: make composable...
   return DiagnosedSilenceableFailure::success();
 }
+
+//===---------------------------------------------------------------------===//
+// MatchCallbackOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply(
+    mlir::transform::TransformResults &results,
+    mlir::transform::TransformState &state) {
+  auto setEmptyResults = [&results, this] {
+    for (OpResult value : getResults()) {
+      results.set(value, {});
+    }
+  };
+  auto errorOut = [this, &setEmptyResults] {
+    setEmptyResults();
+    return emitSilenceableError();
+  };
+
+  auto *registry = state.getExtension<transform_ext::MatchCallbacksRegistry>();
+  if (!registry)
+    return errorOut() << "match registry not available";
+
+  const transform_ext::MatchCallbacksRegistry::MatchCallbackFn *callback =
+      registry->get(getCallbackName());
+  if (!callback) {
+    return errorOut() << "callback '" << getCallbackName()
+                      << "' not found in the registry";
+  }
+
+  MatchCallbackResult result;
+  DiagnosedSilenceableFailure status =
+      (*callback)(result, getLoc(), state, getInputs());
+  if (!status.succeeded()) {
+    setEmptyResults();
+    if (status.isDefiniteFailure())
+      return status;
+    if (getFailurePropagationMode() ==
+        mlir::transform::FailurePropagationMode::Propagate) {
+      return emitSilenceableError() << "failed to match";
+    } else {
+      return DiagnosedSilenceableFailure::success();
+    }
+  }
+  if (getNumResults() != result.getNumPayloadGroups()) {
+    return errorOut()
+           << "callback produced a different number of handles than expected ( "
+           << result.getNumPayloadGroups() << " vs " << getNumResults() << " )";
+  }
+
+  for (OpResult value : getResults()) {
+    results.set(value, result.getPayloadGroup(value.getResultNumber()));
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::MatchCallbackOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  mlir::transform::onlyReadsHandle(getInputs(), effects);
+  mlir::transform::producesHandle(getOutputs(), effects);
+  // TODO: it doesn't really modify the payload, we need a separate resource for
+  // this mapping.
+  mlir::transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// RegisterMatchCallbacksOp
+//===---------------------------------------------------------------------===//
+
+/// Match callback for "_test_match_callback" hook. Matches any payload
+/// operations associated with operand handles unless they have the
+/// "test.iree_transform_do_not_match" attribute, in which case produces a
+/// silenceable failure.
+static DiagnosedSilenceableFailure
+testMatchCallbackCallback(transform_ext::MatchCallbackResult &res, Location loc,
+                          const mlir::transform::TransformState &state,
+                          ValueRange handles) {
+  bool hadFailures = false;
+  for (Value handle : handles) {
+    if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) {
+          return op->hasAttr("test.iree_transform_do_not_match");
+        })) {
+      res.addPayloadGroup(ArrayRef<Operation *>());
+      hadFailures = true;
+    } else {
+      res.addPayloadGroup(state.getPayloadOps(handle));
+    }
+  }
+  if (hadFailures)
+    return emitSilenceableFailure(loc) << "failed to match";
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Match callback for a reduction with optional leading and trailing
+/// elementwise operations. Matches *the first* occurrence of such a reduction
+/// within an op associated with the given handle.
+///
+/// Input handles:
+///
+///   - container op, must be associated with one operation.
+///
+/// Output handles:
+///
+///   - leading elementwise op, if any;
+///   - the "fill" op preceding the reduction;
+///   - reduction op;
+///   - trailing elementwise op, if any.
+static DiagnosedSilenceableFailure
+reductionCallback(transform_ext::MatchCallbackResult &res, Location loc,
+                  const mlir::transform::TransformState &state,
+                  ValueRange handles) {
+  if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
+    return emitSilenceableFailure(loc)
+           << "expected one handle to one operation";
+  }
+
+  transform_ext::StructuredOpMatcher pattern, fill, leadingEltwise,
+      trailingEltwise;
+  makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+
+  // TODO: need a mechanism for this to go around the entire IR,
+  // potentially with list matches for each group.
+  Operation *root = state.getPayloadOps(handles[0])[0];
+  WalkResult walkResult = root->walk([&](Operation *op) {
+    pattern.resetCapture();
+    if (!matchPattern(op, pattern))
+      return WalkResult::advance();
+
+    res.addPotentiallyEmptyPayloadGroup(leadingEltwise.getCaptured());
+    res.addPayloadGroup({fill.getCaptured()});
+    res.addPayloadGroup({pattern.getCaptured()});
+    res.addPotentiallyEmptyPayloadGroup(trailingEltwise.getCaptured());
+    return WalkResult::interrupt();
+  });
+
+  if (walkResult.wasInterrupted())
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableFailure(loc) << "failed to match";
+}
+
+/// Match callback for a reduction after splitting with optional leading and
+/// trailing elementwise operations. Matches *the first* occurrence of such a
+/// reduction within an op associated with the given handle.
+///
+/// Input handles:
+///
+///   - container op, must be associated with one operation.
+///
+/// Output handles:
+///
+///   - leading elementwise op, if any;
+///   - the "fill" op preceding the original reduction;
+///   - the "fill" op preceding the split, more parallel reduction;
+///   - the split, more parallel reduction op;
+///   - reduction op;
+///   - trailing elementwise op, if any.
+static DiagnosedSilenceableFailure
+splitReductionCallback(transform_ext::MatchCallbackResult &res, Location loc,
+                       const mlir::transform::TransformState &state,
+                       ValueRange handles) {
+  if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
+    return emitSilenceableFailure(loc)
+           << "expected one handle to one operation";
+  }
+
+  transform_ext::StructuredOpMatcher parallel_reduction, combiner_reduction,
+      parallel_fill, original_fill, leading, trailing;
+  makeSplitReductionMatcher(parallel_reduction, combiner_reduction,
+                            parallel_fill, original_fill, leading, trailing);
+
+  // TODO: need a mechanism for this to go around the entire IR,
+  // potentially with list matches for each group.
+  Operation *root = state.getPayloadOps(handles[0])[0];
+  WalkResult walkResult = root->walk([&](Operation *op) {
+    combiner_reduction.resetCapture();
+    if (!matchPattern(op, combiner_reduction))
+      return WalkResult::advance();
+
+    res.addPotentiallyEmptyPayloadGroup(leading.getCaptured());
+    res.addPayloadGroup({original_fill.getCaptured()});
+    res.addPayloadGroup({parallel_fill.getCaptured()});
+    res.addPayloadGroup({parallel_reduction.getCaptured()});
+    res.addPayloadGroup({combiner_reduction.getCaptured()});
+    res.addPotentiallyEmptyPayloadGroup(trailing.getCaptured());
+    return WalkResult::interrupt();
+  });
+
+  if (walkResult.wasInterrupted())
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableFailure(loc) << "failed to match";
+}
+
+DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply(
+    mlir::transform::TransformResults &results,
+    mlir::transform::TransformState &state) {
+  auto &registry = state.addExtension<transform_ext::MatchCallbacksRegistry>();
+  registry.registerCallback("_test_match_callback", testMatchCallbackCallback);
+  registry.registerCallback("reduction", reductionCallback);
+  registry.registerCallback("split_reduction", splitReductionCallback);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::RegisterMatchCallbacksOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // TODO: it doesn't really modify the payload, we need a separate resource for
+  // this mapping.
+  mlir::transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// TakeFirstOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_ext::TakeFirstOp::apply(mlir::transform::TransformResults &results,
+                                  mlir::transform::TransformState &state) {
+  SmallVector<Operation *> concatenated;
+  bool found = false;
+  for (Value handle : getInputs()) {
+    ArrayRef<Operation *> payloads = state.getPayloadOps(handle);
+    if (payloads.empty())
+      continue;
+    if (!found) {
+      results.set(getFirst().cast<OpResult>(), payloads);
+      found = true;
+    } else {
+      llvm::append_range(concatenated, payloads);
+    }
+  }
+
+  if (!found)
+    results.set(getFirst().cast<OpResult>(), {});
+  results.set(getRest().cast<OpResult>(), concatenated);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::TakeFirstOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  mlir::transform::onlyReadsHandle(getInputs(), effects);
+  mlir::transform::producesHandle(getFirst(), effects);
+  mlir::transform::producesHandle(getRest(), effects);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index 8047e74..00eda6f 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   Listener.cpp
   ListenerCSE.cpp
   ListenerGreedyPatternRewriteDriver.cpp
+  TransformMatchers.cpp
   
   LINK_LIBS PRIVATE
   # TODO: break dialect dependency by implementing the transformation separately
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
similarity index 75%
rename from compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp
rename to llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 4085a16..eda5d55 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -4,22 +4,21 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
 using namespace mlir;
-using namespace mlir::iree_compiler;
-using namespace mlir::iree_compiler::IREE;
 
 //===---------------------------------------------------------------------===//
 // StructuredOpMatcher and friends.
 //===---------------------------------------------------------------------===//
 
-bool transform_dialect::StructuredOpMatcher::match(Operation *op) {
+bool transform_ext::StructuredOpMatcher::match(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp) return false;
+  if (!linalgOp)
+    return false;
 
   if (!llvm::all_of(predicates, [linalgOp](const PredicateFn &fn) {
         return fn(linalgOp);
@@ -31,21 +30,22 @@
   return true;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension, ShapeKind kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension, ShapeKind kind) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
     int64_t transformedDimension =
         dimension >= 0 ? dimension : shape.size() + dimension;
-    if (transformedDimension >= shape.size()) return false;
+    if (transformedDimension >= shape.size())
+      return false;
     return ShapedType::isDynamic(shape[transformedDimension]) ^
            (kind == ShapeKind::Static);
   });
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
     return llvm::all_of(shape, [=](int64_t dimension) {
@@ -55,14 +55,15 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension,
-                                            utils::IteratorType kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension,
+                                        utils::IteratorType kind) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     unsigned rank = linalgOp.getNumLoops();
     int64_t transformedDimension =
         dimension >= 0 ? dimension : rank + dimension;
-    if (transformedDimension >= rank) return false;
+    if (transformedDimension >= rank)
+      return false;
 
     utils::IteratorType iteratorKind =
         linalgOp.getIteratorTypesArray()[transformedDimension];
@@ -70,9 +71,8 @@
   });
   return *this;
 }
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(AllDims tag,
-                                            utils::IteratorType kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     return llvm::all_of(
         linalgOp.getIteratorTypesArray(),
@@ -81,14 +81,15 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension,
-                                            DivisibleBy divisibleBy) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension,
+                                        DivisibleBy divisibleBy) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     unsigned rank = linalgOp.getNumLoops();
     int64_t transformedDimension =
         dimension >= 0 ? dimension : rank + dimension;
-    if (transformedDimension >= rank) return false;
+    if (transformedDimension >= rank)
+      return false;
 
     int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension];
     return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0);
@@ -96,8 +97,8 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     // all_of with a lambda requires const-casting dance, so using a loop.
     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -151,7 +152,8 @@
         auto it = llvm::find_if(range, [&](BlockArgument bbarg) {
           return loop.getTiedOpOperand(bbarg) != &use;
         });
-        if (it == range.end()) return user;
+        if (it == range.end())
+          return user;
         val = *it;
         continue;
       }
@@ -164,9 +166,8 @@
   } while (true);
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::input(int64_t position,
-                                              SubsetOf subset) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position, SubsetOf subset) {
   // Implementation note: SubsetOf must *not* be passed by-reference because
   // it is typically a temporary constructed within the argument of a function
   // call, but it will be used in the lambda that outlives the temporary. The
@@ -174,7 +175,8 @@
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     int64_t transformedPosition =
         position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
-    if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+    if (transformedPosition >= linalgOp.getNumDpsInputs())
+      return false;
 
     Operation *producer = traverseSubsetsBackwards(
         linalgOp.getDpsInputOperand(transformedPosition)->get());
@@ -184,8 +186,8 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
       if (!linalgOp.getMatchingIndexingMap(operand).isPermutation())
@@ -196,13 +198,14 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
-                                               ElementTypeBitWidth width) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position,
+                                           ElementTypeBitWidth width) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     int64_t updatedPosition =
         position >= 0 ? position : linalgOp.getNumDpsInits() + position;
-    if (updatedPosition >= linalgOp.getNumDpsInits()) return false;
+    if (updatedPosition >= linalgOp.getNumDpsInits())
+      return false;
     auto shapedType = linalgOp.getDpsInitOperand(updatedPosition)
                           ->get()
                           .getType()
@@ -213,13 +216,14 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
-                                               SingleCombinerReduction tag) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position,
+                                           SingleCombinerReduction tag) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     int64_t updatedPosition =
         position >= 0 ? position : linalgOp.getNumDpsInits() + position;
-    if (updatedPosition >= linalgOp.getNumDpsInits()) return false;
+    if (updatedPosition >= linalgOp.getNumDpsInits())
+      return false;
     SmallVector<Operation *> combinerOps;
     return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition,
                           combinerOps) &&
@@ -228,9 +232,8 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
-                                               SubsetOf subset) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position, SubsetOf subset) {
   // Implementation note: SubsetOf must *not* be passed by-reference because
   // it is typically a temporary constructed within the argument of a function
   // call, but it will be used in the lambda that outlives the temporary. The
@@ -238,7 +241,8 @@
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     int64_t transformedPosition =
         position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
-    if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+    if (transformedPosition >= linalgOp.getNumDpsInputs())
+      return false;
 
     Operation *producer = traverseSubsetsBackwards(
         linalgOp.getDpsInitOperand(transformedPosition)->get());
@@ -248,14 +252,13 @@
   return *this;
 }
 
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::result(int64_t position, HasAnyUse tag,
-                                               SubsetOf subset,
-                                               OptionalMatch optional) {
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::result(
+    int64_t position, HasAnyUse tag, SubsetOf subset, OptionalMatch optional) {
   predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
     int64_t transformedPosition =
         position >= 0 ? position : linalgOp->getNumResults() + position;
-    if (transformedPosition >= linalgOp->getNumResults()) return false;
+    if (transformedPosition >= linalgOp->getNumResults())
+      return false;
 
     Operation *user =
         traverseSubsetsForwardAnyUse(linalgOp->getResult(transformedPosition));
@@ -268,8 +271,8 @@
 // MatchCallbackResult.
 //===---------------------------------------------------------------------===//
 
-ArrayRef<Operation *> transform_dialect::MatchCallbackResult::getPayloadGroup(
-    unsigned position) const {
+ArrayRef<Operation *>
+transform_ext::MatchCallbackResult::getPayloadGroup(unsigned position) const {
   assert(position < payloadGroupLengths.size());
   int64_t start = 0;
   for (unsigned i = 0; i < position; ++i) {
@@ -285,11 +288,11 @@
 
 static constexpr unsigned kCudaWarpSize = 32;
 
-void transform_dialect::makeGPUReductionMatcher(
-    transform_dialect::StructuredOpMatcher &reduction,
-    transform_dialect::StructuredOpMatcher &fill,
-    transform_dialect::StructuredOpMatcher &leading,
-    transform_dialect::StructuredOpMatcher &trailing) {
+void transform_ext::makeReductionMatcher(
+    transform_ext::StructuredOpMatcher &reduction,
+    transform_ext::StructuredOpMatcher &fill,
+    transform_ext::StructuredOpMatcher &leading,
+    transform_ext::StructuredOpMatcher &trailing) {
   fill = m_StructuredOp<linalg::FillOp>();
   trailing = m_StructuredOp<linalg::GenericOp>()
                  .input(AllOperands(), IsPermutation())
@@ -314,13 +317,13 @@
                   .result(0, HasAnyUse(), trailing, OptionalMatch());
 }
 
-void transform_dialect::makeGPUSplitReductionMatcher(
-    transform_dialect::StructuredOpMatcher &parallel_reduction,
-    transform_dialect::StructuredOpMatcher &combiner_reduction,
-    transform_dialect::StructuredOpMatcher &parallel_fill,
-    transform_dialect::StructuredOpMatcher &original_fill,
-    transform_dialect::StructuredOpMatcher &leading,
-    transform_dialect::StructuredOpMatcher &trailing) {
+void transform_ext::makeSplitReductionMatcher(
+    transform_ext::StructuredOpMatcher &parallel_reduction,
+    transform_ext::StructuredOpMatcher &combiner_reduction,
+    transform_ext::StructuredOpMatcher &parallel_fill,
+    transform_ext::StructuredOpMatcher &original_fill,
+    transform_ext::StructuredOpMatcher &leading,
+    transform_ext::StructuredOpMatcher &trailing) {
   original_fill = m_StructuredOp<linalg::FillOp>();
   parallel_fill = m_StructuredOp<linalg::FillOp>();
   trailing = m_StructuredOp<linalg::GenericOp>()