[Codegen] Re-Enable transform dialect configuration strategy sample (#15787)

This re-adds a mechanism for setting strategy configurations using the
transform dialect through
--iree-codegen-transform-dialect-library as well as re-enables
the associated sample.
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
index c8f00fb..cd588bc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
@@ -38,13 +38,43 @@
     "iree-codegen-transform-dialect-library",
     llvm::cl::desc(
         "File path to a module containing a library of transform dialect"
-        "strategies"),
+        " strategies. Can be suffixed with the name of a transform sequence "
+        "within the library to run as preprocessing per executable variant. "
+        "This is specified as <file-path>::<sequence-name>. If not specified, "
+        "this will default to `__kernel_config`."),
     llvm::cl::init(""));
 
 namespace {
 
 static const char kTranslationInfoAttrName[] = "translation_info";
 
+enum StrategyRunResult {
+  Success = 0,
+  NotFound = 1,
+  Failed = 2,
+};
+
+static StrategyRunResult
+runTransformConfigurationStrategy(Operation *payloadRoot,
+                                  StringRef entryPointName,
+                                  ModuleOp &transformLibrary) {
+  /// If we have a symbol, verify the existence of the symbol within the
+  /// transform library.
+  Operation *entryPoint = transform::detail::findTransformEntryPoint(
+      payloadRoot, transformLibrary, entryPointName);
+  if (!entryPoint) {
+    return StrategyRunResult::NotFound;
+  }
+
+  transform::TransformOptions options;
+  if (failed(transform::applyTransformNamedSequence(
+          payloadRoot, entryPoint, transformLibrary,
+          options.enableExpensiveChecks(true)))) {
+    return StrategyRunResult::Failed;
+  }
+  return StrategyRunResult::Success;
+}
+
 struct MaterializeUserConfigsPass
     : public MaterializeUserConfigsBase<MaterializeUserConfigsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -58,21 +88,98 @@
         getAllEntryPoints(moduleOp);
     MLIRContext *context = moduleOp.getContext();
 
+    // Parse the file path and kernel config strategy from flags. There are
+    // three possible usage flows:
+    //   1. Specify only a transform dialect library, e.g.
+    //      --iree-codegen-transform-dialect-library=/path/to/library.mlir
+    //      This will load the transform library and attempt to use two default
+    //      strategies, `__kernel_config` before strategy configuration (i.e.
+    //      now), and `__transform_main` for codegen. If `__kernel_config` is
+    //      present in the library, it will run it and use the annotations set
+    //      by it. If not present, `__transform_main` will be broadcasted to all
+    //      dispatches.
+    //
+    //   2. Specify a library path with a strategy name instead of
+    //      `__transform_main`. This is the same as (1) except it uses a
+    //      different default strategy.
+    //
+    //   3. Specify a library path suffixed with a kernel config entry point.
+    //      This will throw an error if the specified kernel config strategy is
+    //      not found.
+    SmallVector<StringRef, 2> parts;
+    llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
+                      parts, "::");
+    if (parts.size() > 2) {
+      variantOp.emitError()
+          << "Invalid transform library path and sequence name "
+          << clCodegenTransformDialectLibraryFileName;
+      return signalPassFailure();
+    }
+    bool hasTransformConfig = parts.size() == 2;
+    bool hasTransformLibrary = parts.size() >= 1;
+    bool hasTransformStrategy = !clCodegenTransformDialectStrategyName.empty();
+
+    std::string libraryFileName;
+    if (hasTransformLibrary) {
+      if (parts[0].empty()) {
+        variantOp.emitError() << "Cannot specify an empty library path";
+        return signalPassFailure();
+      }
+      libraryFileName = parts[0];
+    }
+
+    std::string entrySequenceName;
+    if (hasTransformConfig) {
+      if (parts[1].empty()) {
+        variantOp.emitError() << "Cannot specify an empty sequence name";
+        return signalPassFailure();
+      }
+      entrySequenceName = parts[1];
+    }
+
     LDBG("MaterializeUserConfigsPass on variant: " << variantOp);
     std::optional<ModuleOp> transformLibrary = std::nullopt;
-    if (!clCodegenTransformDialectLibraryFileName.empty()) {
+    if (hasTransformLibrary) {
       auto dialect =
           context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
-      auto maybeTransformLibrary = dialect->getOrLoadTransformLibraryModule(
-          clCodegenTransformDialectLibraryFileName);
+      auto maybeTransformLibrary =
+          dialect->getOrLoadTransformLibraryModule(libraryFileName);
       if (failed(maybeTransformLibrary)) {
-        variantOp.emitError() << "failed to load transform library module: "
-                              << clCodegenTransformDialectLibraryFileName;
+        variantOp.emitError()
+            << "failed to load transform library module: " << libraryFileName;
         return signalPassFailure();
       }
       transformLibrary = *maybeTransformLibrary;
-      LDBG("--found transform library @"
-           << clCodegenTransformDialectLibraryFileName);
+      LDBG("--found transform library @" << libraryFileName);
+    }
+
+    // Run the user specified transform configuration, or attempt to run
+    // `__kernel_config` if no sequence name is specified.
+    if (hasTransformConfig) {
+      // If we get here, the transform library must necessarily have been
+      // loaded.
+      assert(transformLibrary && *transformLibrary &&
+             "Unexpected unloaded transform library");
+      if (runTransformConfigurationStrategy(variantOp, entrySequenceName,
+                                            *transformLibrary) !=
+          StrategyRunResult::Success) {
+        variantOp.emitError() << "transform kernel config strategy `"
+                              << entrySequenceName << "` failed to apply";
+        return signalPassFailure();
+      }
+    } else if (transformLibrary && (*transformLibrary)) {
+      StrategyRunResult res = runTransformConfigurationStrategy(
+          variantOp, "__kernel_config", *transformLibrary);
+      if (res == StrategyRunResult::Failed) {
+        variantOp.emitError()
+            << "default transform __kernel_config strategy failed to apply";
+        return signalPassFailure();
+      }
+      // If `__kernel_config` was found and ran successfully, indicate that
+      // there was a config strategy.
+      if (res == StrategyRunResult::Success) {
+        hasTransformConfig = true;
+      }
     }
 
     IREE::Codegen::DispatchLoweringPassPipeline tdPipeline =
@@ -81,10 +188,9 @@
     // Here we always set the pipeline strategy to transform dialect if the
     // flag is non-empty to ensure we pick the right lowering pipeline in the
     // event a strategy symbol is defined.
-    if (!clCodegenTransformDialectLibraryFileName.empty() ||
-        !clCodegenTransformDialectStrategyName.empty()) {
+    if ((hasTransformLibrary && !hasTransformConfig) || hasTransformStrategy) {
       StringRef strategyName =
-          (clCodegenTransformDialectStrategyName.empty())
+          (!hasTransformStrategy)
               ? StringRef(
                     transform::TransformDialect::kTransformEntryPointSymbolName)
               : clCodegenTransformDialectStrategyName;
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 95fe68e..c3c1d3a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -259,7 +259,8 @@
 
 /// Create an IREE-specific Transform dialect interpreter pass with all
 /// registrations necessary for IREE.
-std::unique_ptr<Pass> createTransformDialectInterpreterPass();
+std::unique_ptr<Pass>
+createTransformDialectInterpreterPass(StringRef transformSequenceName = "");
 
 /// Pass to propagate type to avoid generating load/stores of illegal types.
 std::unique_ptr<OperationPass<func::FuncOp>> createTypePropagationPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index 4852485..fb87d06 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -76,9 +76,19 @@
 extern llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName;
 
 /// Create a Transform dialect interpreter pass.
-std::unique_ptr<Pass> createTransformDialectInterpreterPass() {
-  return std::make_unique<TransformDialectInterpreterPass>(
-      clCodegenTransformDialectLibraryFileName,
-      clCodegenTransformDialectStrategyName);
+std::unique_ptr<Pass>
+createTransformDialectInterpreterPass(StringRef transformSequenceName) {
+  StringRef strategyName = transformSequenceName.empty()
+                               ? clCodegenTransformDialectStrategyName
+                               : transformSequenceName;
+  StringRef libraryPath = "";
+  SmallVector<StringRef, 2> parts;
+  llvm::SplitString(llvm::StringRef(clCodegenTransformDialectLibraryFileName),
+                    parts, "::");
+  if (!parts.empty()) {
+    libraryPath = parts[0];
+  }
+  return std::make_unique<TransformDialectInterpreterPass>(libraryPath,
+                                                           strategyName);
 }
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index aa0cd8a..091ca37 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -203,9 +203,12 @@
     break;
   }
   // Transform-dialect pipelines.
-  case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
-    addTransformDialectPasses(pipeline);
+  case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
+    SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+    addTransformDialectPasses(
+        pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
     break;
+  }
   default:
     moduleOp.emitOpError("Unsupported pipeline on CPU target.");
     return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 4f876ba..b5bb3c8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -560,10 +560,11 @@
   addCPUBufferizePasses(nestedModulePM);
 }
 
-void addTransformDialectPasses(OpPassManager &passManager) {
+void addTransformDialectPasses(OpPassManager &passManager,
+                               StringRef entryPoint) {
   // Give control to the transform dialect.
   passManager.addPass(
-      mlir::iree_compiler::createTransformDialectInterpreterPass());
+      mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
   // Dropping the schedule is needed:
   //   1. if we want to embed the transform in the module: we should drop the
   //      schedule once applied.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index 74d688d..c964342 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -145,7 +145,8 @@
                                     bool lowerToVectors = true);
 
 /// Transform dialect-based common.
-void addTransformDialectPasses(OpPassManager &passManager);
+void addTransformDialectPasses(OpPassManager &passManager,
+                               StringRef entryPoint);
 
 // Populates the passes needed to do tiling, decomposing, and vectorizing the
 // convolution ops.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index cef1407..4eb31ca 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -108,9 +108,12 @@
     addGPUPackUnPackPasses(pipeline);
     break;
   // Transform-dialect pipelines.
-  case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen:
-    addGPUTransformDialectPasses(pipeline);
+  case IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen: {
+    SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+    addGPUTransformDialectPasses(
+        pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
     break;
+  }
   // no pipeline specified, nothing to do.
   case IREE::Codegen::DispatchLoweringPassPipeline::None:
     return;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index e529e8a..eddea59 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -597,9 +597,10 @@
 extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugPayloadTag;
 extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectDebugTransformTag;
 
-void addGPUTransformDialectPasses(OpPassManager &passManager) {
+void addGPUTransformDialectPasses(OpPassManager &passManager,
+                                  StringRef entryPoint) {
   passManager.addPass(
-      mlir::iree_compiler::createTransformDialectInterpreterPass());
+      mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
 
   // Dropping the schedule is needed:
   //   1. if we want to embed the transform in the module: we should drop the
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 9860197..08f83a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -40,7 +40,7 @@
 void addGPUSimpleDistributePassPipeline(OpPassManager &pm);
 
 /// Transform dialect-based path.
-void addGPUTransformDialectPasses(OpPassManager &pm);
+void addGPUTransformDialectPasses(OpPassManager &pm, StringRef entryPoint);
 
 /// Lowering transpose using shared memory.
 void addGPUTransposePassPipeline(OpPassManager &pm);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 13f7dcd..a6611e8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -263,9 +263,10 @@
   spirvPM.addPass(spirv::createSPIRVUpdateVCEPass());
 }
 
-void addSPIRVTransformDialectPasses(OpPassManager &passManager) {
+void addSPIRVTransformDialectPasses(OpPassManager &passManager,
+                                    StringRef entryPoint) {
   passManager.addPass(
-      mlir::iree_compiler::createTransformDialectInterpreterPass());
+      mlir::iree_compiler::createTransformDialectInterpreterPass(entryPoint));
 
   // Dropping the schedule is needed:
   //   1. if we want to embed the transform in the module: we should drop the
@@ -645,8 +646,9 @@
   nestedModulePM.addPass(createCSEPass());
 }
 
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {
-  addSPIRVTransformDialectPasses(pm);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
+                                          StringRef entryPoint) {
+  addSPIRVTransformDialectPasses(pm, entryPoint);
 
   // Run GenericVectorization pass additionally to convert vectors into forms
   // needed for SPIR-V.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index dc4a22c..84cd3c7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -44,7 +44,8 @@
 void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm);
 
 /// Pass pipeline to lower IREE HAL executables via transform dialect schedules.
-void addSPIRVTransformDialectPassPipeline(OpPassManager &pm);
+void addSPIRVTransformDialectPassPipeline(OpPassManager &pm,
+                                          StringRef entryPoint);
 
 /// Pass pipeline to lower winograd ops. This pipeline follows the
 /// SPIRVBaseVectorize pipeline with the following exception:
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 3943e99..14ec33f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -95,9 +95,12 @@
   case CodeGenPipeline::SPIRVWinogradVectorize:
     addSPIRVWinogradVectorizePassPipeline(pipeline);
     break;
-  case CodeGenPipeline::TransformDialectCodegen:
-    addSPIRVTransformDialectPassPipeline(pipeline);
+  case CodeGenPipeline::TransformDialectCodegen: {
+    SymbolRefAttr codegenSpec = translationInfo.value().getCodegenSpec();
+    addSPIRVTransformDialectPassPipeline(
+        pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
     break;
+  }
   // No pipeline specified, nothing to do.
   case CodeGenPipeline::None:
     return;
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index c5eab46..95b80af 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -107,28 +107,29 @@
 }
 
 /// We test first with threading off so that the printers are legible.
-// R-UN: iree-compile %s --iree-hal-target-backends=vulkan \
-// R-UN:   --iree-codegen-use-transform-dialect-strategy=transform_main \
-// R-UN:   --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
-// R-UN:   --compile-from=executable-sources \
-// R-UN:   --compile-to=executable-targets \
-// R-UN:   --mlir-disable-threading | \
-// R-UN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
+// RUN: iree-compile %s --iree-hal-target-backends=vulkan \
+// RUN:   --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
+// RUN:   --compile-from=executable-sources \
+// RUN:   --compile-to=executable-targets \
+// RUN:   --mlir-disable-threading | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN-PRINTER
 
-// CODEGEN-PRINTER:     IR printer: Setting matmul strategy to default top-level
-// CODEGEN-PRINTER:       translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @transform_main
+// CODEGEN-PRINTER:     IR printer: Setting matmul strategy to custom_transform_strategy
+// CODEGEN-PRINTER:       translation_info = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
 // CODEGEN-PRINTER:     IR printer: Setting reduce strategy to base vectorize top-level
 // CODEGEN-PRINTER:       translation_info = #iree_codegen.translation_info<SPIRVBaseVectorize>, workgroup_size = [16 : index, 1 : index, 1 : index]
 
 /// Then test with threading to make sure it runs
 // RUN: iree-compile %s --iree-hal-target-backends=vulkan \
-// RUN:   --iree-codegen-use-transform-dialect-strategy=@transform_main \
-// RUN:   --iree-codegen-transform-dialect-library=%p/transform_library.mlir \
+// RUN:   --iree-codegen-transform-dialect-library=%p/transform_library.mlir::kernel_config \
 // RUN:   --compile-from=executable-sources \
 // RUN:   --compile-to=executable-targets \
 // RUN:   --mlir-disable-threading | \
 // RUN: FileCheck %s --check-prefixes=CODEGEN
 
+// CODEGEN: Ran custom_transform_strategy
 // CODEGEN: spirv.func @example_module_dispatch_0_generic_80_f32
-// CODEGEN: spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
+// CODEGEN: hal.executable private @example_module_dispatch_1
+// CODEGEN:   #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @custom_transform_strategy>
+// CODEGEN:     spirv.func @example_module_dispatch_1_matmul_16x16x5_f32
 // CODEGEN: spirv.func @example_module_dispatch_2_generic_16x16_f32
diff --git a/samples/transform_dialect/transform_library.mlir b/samples/transform_dialect/transform_library.mlir
index 3bb75ad..d4ca3b4 100644
--- a/samples/transform_dialect/transform_library.mlir
+++ b/samples/transform_dialect/transform_library.mlir
@@ -1,13 +1,76 @@
 module attributes { transform.with_named_sequence } {
-  // Print and send it down normal IREE codegen.
-  transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.consumed}) {  
-    %1 = transform.structured.generalize %matmul : (!transform.any_op) -> !transform.any_op
-    transform.print {name = "Setting matmul strategy to default"}
+  // Example of a custom matmul strategy. The target matmul is annotated with
+  // the name of this strategy down below before strategy selection, overriding
+  // default IREE codegen.
+  transform.named_sequence @custom_transform_strategy(
+      %variant_op: !transform.any_op {transform.consumed}) {
+    // Step 1. Re-match the matmul
+    // ===========================================================================
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+
+    // Step 2. Tile to grid
+    // ===========================================================================
+    %grid_reduction, %forall_grid =
+    transform.structured.tile_using_forall %matmul tile_sizes [16, 16] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
+
+    // Step 3. Vectorize
+    // ===========================================================================
+    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
+      transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+    } : !transform.any_op
+    %func_1 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
+
+    // Step 4. Bufferize
+    // ===========================================================================
+    transform.apply_patterns to %func_1 {
+      transform.apply_patterns.iree.fold_fill_into_pad
+      transform.apply_patterns.linalg.tiling_canonicalization
+      transform.apply_patterns.scf.for_loop_canonicalization
+    } : !transform.any_op
+    transform.apply_patterns to %func_1 {
+      transform.apply_patterns.tensor.reassociative_reshape_folding
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func_1 : !transform.any_op
+    transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_1 {
+      transform.apply_patterns.linalg.erase_unnecessary_inputs
+    } : !transform.any_op
+    %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+
+    // Step 5. Post-bufferization vector distribution
+    // ===========================================================================
+    %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+    transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
+    transform.iree.map_nested_forall_to_gpu_threads %func_7
+        workgroup_dims = [4, 8, 1] : (!transform.any_op) -> ()
+
+    // Step 6. Do layout analysis and lower to mma
+    // ===========================================================================
+    %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
+    %func_11 = transform.iree.layout_analysis_and_distribution %func_10 : (!transform.any_op) -> (!transform.any_op)
+    transform.print {name = "Ran custom_transform_strategy"}
     transform.yield
   }
 
-  // Send it down subgroup reduce.
-  transform.named_sequence @use_subgroup_reduce(%reduce: !transform.any_op {transform.readonly}) {  
+  // Send it down a custom transform dialect pipeline.
+  transform.named_sequence @custom_matmul(%matmul: !transform.any_op {transform.readonly}) {
+    %variant_op = transform.get_parent_op %matmul {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
+    %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    %subgroup_reduce = transform.param.constant #iree_codegen.translation_info<TransformDialectCodegen
+                                                                               codegen_spec = @custom_transform_strategy> -> !transform.any_param
+    transform.annotate %exports "translation_info" = %subgroup_reduce : !transform.any_op, !transform.any_param
+    transform.print {name = "Setting matmul strategy to custom_transform_strategy"}
+    transform.yield
+  }
+
+  // Send it down SPIR-V base vectorize with a custom tiling configuration.
+  transform.named_sequence @use_base_vectorize(%reduce: !transform.any_op {transform.readonly}) {
     %variant_op = transform.get_parent_op %reduce {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op
     %lowering_config = transform.param.constant #iree_codegen.lowering_config<tile_sizes = [[8, 0], [1, 0], [0, 0, 4]]> -> !transform.any_param
     transform.annotate %reduce "lowering_config" = %lowering_config : !transform.any_op, !transform.any_param
@@ -42,10 +105,34 @@
     transform.yield %matched : !transform.any_op
   }
 
-  transform.named_sequence @transform_main(%variant_op: !transform.any_op {transform.consumed}) {  
+  // An example of a custom transform dialect based kernel config. Note that
+  // because of the way `transform.foreach_match` works, the callback cannot
+  // manipulate IR beyond the op *given* to the matcher, as foreach_match will
+  // attempt to keep walking the IR even after a successful match. The expected
+  // flow for a strategy like this is as follows:
+  //
+  // Author an entry point like this (@kernel_config) that walks the IR and
+  // attempts to annotate the dispatch with the codegen strategy to use, i.e.
+  //   transform.foreach_match in %variant_op
+  //       @matcher_0 -> @annotator_0,
+  //       @matcher_1 -> @annotator_1,
+  //       ...
+  //
+  // the annotators should attach an #iree_codegen.translation_info attribute
+  // to the `hal.executable.export` ops within the variant as well as any
+  // relevant op specific tile sizes (and other important attributes like
+  // workgroup_size and subgroup_size, if relevant). This will then get handed
+  // off to backend specific kernel config, which will let these user configs
+  // pass through unperturbed.
+  //
+  // To couple this with a transform dialect based codegen strategy, the target
+  // codegen strategy can be included inline with this library and relevant ops
+  // can be annotated with `TransformDialectCodegen` as the lowering pipeline,
+  // with a reference to the strategy to use (see an example above).
+  transform.named_sequence @kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
     transform.foreach_match in %variant_op
         @match_matmul -> @custom_matmul,
-        @match_reduce -> @use_subgroup_reduce
+        @match_reduce -> @use_base_vectorize
       : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }