Add samples for custom kernel match+replace scripts (#16150)

Custom match-and-replace scripts are a workflow for injecting custom
dispatches into a module without the need for any surrounding compiler
infrastructure to build the dispatches. Custom kernels are paired with
an externally authored script that matches a subgraph and replaces it
with a call to the kernel, all delivered as independently valid IR.
The examples here use the transform dialect to do this by adding a
plugin point during preprocessing (the new
`--iree-preprocessing-transform-spec-filename` flag) that runs the
user-provided specification.

The flow demonstrated here requires authoring two functions per kernel
alongside some additional boilerplate; a minimal sketch follows the list.

1. A `func.func @my_kernel(...)` that takes (typically tensor) arguments
   and includes the call to the custom dispatch inline. This can use any
   of the other custom dispatch approaches.

2. A `transform.named_sequence @my_matcher` that describes the
   compatible subgraph to match.
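
A rough sketch of how these pieces fit together in one spec module. The
names `@my_kernel`, `@my_matcher`, and `@my_replacer` are placeholders;
complete, working versions live in the samples and the tests in this
change.

```mlir
module attributes {transform.with_named_sequence} {
  // (1) Hand-authored kernel function. The body would wrap the custom
  // dispatch using any of the existing custom dispatch approaches.
  func.func @my_kernel(%arg0: tensor<?xf32>) -> tensor<?xf32> {
    return %arg0 : tensor<?xf32>  // Placeholder body.
  }

  // (2) Matcher describing the subgraph the kernel can replace.
  transform.named_sequence @my_matcher(%op: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
    transform.match.operation_name %op ["linalg.generic"] : !transform.any_op
    transform.yield %op : !transform.any_op
  }

  // Replacement logic; typically imports @my_kernel into the payload
  // module and rewrites the matched subgraph into a call.
  transform.named_sequence @my_replacer(%op: !transform.any_op {transform.readonly}) {
    transform.yield
  }

  // Boilerplate entry point wiring matchers to replacements.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    transform.foreach_match in %module
        @my_matcher -> @my_replacer
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
```
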
diff --git a/build_tools/cmake/ctest_all.sh b/build_tools/cmake/ctest_all.sh
index 755aac7..ebdd89d 100755
--- a/build_tools/cmake/ctest_all.sh
+++ b/build_tools/cmake/ctest_all.sh
@@ -132,6 +132,7 @@
 excluded_tests+=(
   "iree/samples/custom_dispatch/cpu/embedded/example_hal.mlir.test"
   "iree/samples/custom_dispatch/cpu/embedded/example_stream.mlir.test"
+  "iree/samples/custom_dispatch/cpu/embedded/example_transform.mlir.test"
 )
 
 ctest_args=(
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index f24086d..ac8cf2e 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -213,6 +213,12 @@
       llvm::cl::desc("Textual description of the pass pipeline to run before "
                      "running normal IREE compilation pipelines"),
       llvm::cl::cat(category));
+  binder.opt<std::string>(
+      "iree-preprocessing-transform-spec-filename",
+      preprocessingTransformSpecFilename,
+      llvm::cl::desc(
+          "File name of a transform dialect spec to use for preprocessing"),
+      llvm::cl::cat(category));
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 01698b7..32c22a4 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -164,6 +164,7 @@
 
 struct PreprocessingOptions {
   std::string preprocessingPassPipeline;
+  std::string preprocessingTransformSpecFilename;
   void bindOptions(OptionsBinder &binder);
   using FromFlags = OptionsFromFlags<PreprocessingOptions>;
 };
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index c9acfa2..bbd7019 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -31,6 +31,7 @@
     name = "Transforms",
     srcs = [
         "ConvertConv2DToImg2Col.cpp",
+        "InterpreterPass.cpp",
         "MakeSingleDispatchForFunction.cpp",
         "PadLinalgOps.cpp",
         "PassDetail.h",
@@ -55,6 +56,8 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TensorUtils",
+        "@llvm-project//mlir:TransformDialect",
+        "@llvm-project//mlir:TransformDialectTransforms",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index fb7b26f..9fc8003 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -27,6 +27,7 @@
     "Passes.h.inc"
   SRCS
     "ConvertConv2DToImg2Col.cpp"
+    "InterpreterPass.cpp"
     "MakeSingleDispatchForFunction.cpp"
     "PadLinalgOps.cpp"
     "PassDetail.h"
@@ -43,6 +44,8 @@
     MLIRPass
     MLIRTensorDialect
     MLIRTensorUtils
+    MLIRTransformDialect
+    MLIRTransformDialectTransforms
     MLIRTransforms
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Flow::Transforms
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
index 1776519..77faa23 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
@@ -552,9 +552,6 @@
 
 struct ConvertConv2DToImg2ColPass
     : ConvertConv2DToImg2ColBase<ConvertConv2DToImg2ColPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect>();
-  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(&getContext());
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
new file mode 100644
index 0000000..84948bc
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_INTERPRETERPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
+} // namespace mlir::iree_compiler::Preprocessing
+
+namespace {
+class InterpreterPass
+    : public iree_compiler::Preprocessing::impl::InterpreterPassBase<
+          InterpreterPass> {
+public:
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    // Load the module from the spec path. The module will be unloaded once the
+    // pass finishes.
+    OwningOpRef<ModuleOp> transformModule;
+    if (failed(transform::detail::assembleTransformLibraryFromPaths(
+            context, transformSpecPath, transformModule)))
+      return signalPassFailure();
+    Operation *payloadRoot = getOperation();
+    Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
+        getOperation(), *transformModule, "__transform_main");
+    if (!transformEntryPoint) {
+      getOperation()->emitError() << "could not find transform entry point "
+                                     "__transform_main in transform module";
+      return signalPassFailure();
+    }
+
+    if (failed(transform::applyTransformNamedSequence(
+            payloadRoot, transformEntryPoint, *transformModule,
+            options.enableExpensiveChecks(!disableExpensiveChecks)))) {
+      return signalPassFailure();
+    }
+  }
+
+private:
+  /// Transform interpreter options.
+  transform::TransformOptions options;
+};
+} // namespace
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
index 3e762d0..120c464 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
@@ -19,10 +19,6 @@
 struct MakeSingleDispatchForFunctionPass
     : public MakeSingleDispatchForFunctionBase<
           MakeSingleDispatchForFunctionPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<IREE::Flow::FlowDialect>();
-  }
-
   void runOnOperation() override;
 };
 } // namespace
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
index 45a394c..68125a4 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
@@ -155,9 +155,6 @@
 
 class PadLinalgOpsPass : public PadLinalgOpsBase<PadLinalgOpsPass> {
 public:
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect>();
-  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
index 5fe9b54..447b597 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
@@ -7,6 +7,9 @@
 #ifndef IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
 #define IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
 
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index 7caa385..0b03c38 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -30,6 +30,9 @@
 // Register all Passes
 //===----------------------------------------------------------------------===//
 
+#define GEN_PASS_DECL
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: keep
+
 void registerCommonPreprocessingPasses();
 
 } // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 376d92f..9ac1436 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -13,12 +13,38 @@
     Pass<"iree-preprocessing-convert-conv2d-to-img2col", ""> {
   let summary = "Convert linalg convolution ops to matmul img2col based implementation";
   let constructor = "mlir::iree_compiler::Preprocessing::createConvertConv2DToImg2ColPass()";
+  let dependentDialects = [
+    "mlir::linalg::LinalgDialect",
+  ];
+}
+
+def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
+  let summary = "transform dialect interpreter";
+  let description = [{
+    This pass runs the transform dialect interpreter and applies the named
+    sequence transformation named `__transform_main`.
+
+    TODO: Drop this pass in favor of the one upstream. The one upstream requires
+    separate loading of the module and thus isn't suited for single-use
+    transform scripts.
+  }];
+  let dependentDialects = ["::mlir::transform::TransformDialect"];
+  let options = [
+    Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
+           "false",
+           "Disable expensive checks in the interpreter for a faster run.">,
+    Option<"transformSpecPath", "transform-spec-path", "std::string",
+           /*default=*/"", "File path to the transform spec to use.">,
+  ];
 }
 
 def MakeSingleDispatchForFunction :
     InterfacePass<"iree-preprocessing-make-single-dispatch-for-function", "mlir::FunctionOpInterface"> {
   let summary = "Convert entire function into a single dispatch";
   let constructor = "mlir::iree_compiler::Preprocessing::createMakeSingleDispatchForFunctionPass()";
+  let dependentDialects = [
+    "IREE::Flow::FlowDialect",
+  ];
 }
 
 def PadLinalgOps :
@@ -30,6 +56,9 @@
            /*default=*/"4",
            "Specify the padding size">,
   ];
+  let dependentDialects = [
+    "mlir::linalg::LinalgDialect",
+  ];
 }
 
 #endif  // IREE_PREPROCESSING_COMMON_PASSES
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index d11313b..539471f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -19,10 +19,18 @@
             "conv2d_to_img2col.mlir",
             "make_single_dispatch_for_function.mlir",
             "pad_linalg_ops.mlir",
+            "preprocessing_match_ops.mlir",
+            "transform_symbol_importing.mlir",
         ],
         include = ["*.mlir"],
+        exclude = [
+            "external_function_spec.mlir",
+        ],
     ),
     cfg = "//compiler:lit.cfg.py",
+    data = [
+        "external_function_spec.mlir",
+    ],
     tools = [
         "//tools:iree-opt",
         "@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 19425cb..991046d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -17,9 +17,13 @@
     "conv2d_to_img2col.mlir"
     "make_single_dispatch_for_function.mlir"
     "pad_linalg_ops.mlir"
+    "preprocessing_match_ops.mlir"
+    "transform_symbol_importing.mlir"
   TOOLS
     FileCheck
     iree-opt
+  DATA
+    external_function_spec.mlir
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir
new file mode 100644
index 0000000..579ea29
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir
@@ -0,0 +1,18 @@
+// Tests importing functions from this spec into a payload module.
+// Tested in `transform_symbol_importing.mlir`.
+module attributes {transform.with_named_sequence} {
+  func.func private @some_external_function(%arg0: tensor<?xf32>) -> tensor<?xf32>
+
+  func.func @some_function(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+    return %arg0 : tensor<?xf32>
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    %new_func = transform.iree.import_symbol @some_function into %module : (!transform.any_op) -> !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    %module_2 = transform.iree.get_nearest_symbol_table %func : (!transform.any_op) -> !transform.any_op
+    %new_func_2 = transform.iree.import_symbol @some_external_function into %module_2 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
new file mode 100644
index 0000000..83f7ec4
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
@@ -0,0 +1,138 @@
+// RUN: iree-opt --transform-interpreter %s --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @simple_max
+func.func @simple_max(%input: tensor<?xf32>, %dest: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK-SAME:   match_status = "matched"
+  %res = linalg.generic {indexing_maps = [#map, #map],
+                         iterator_types = ["parallel"]}
+                         ins(%input : tensor<?xf32>)
+                         outs(%dest : tensor<?xf32>) attrs = {match_status = "unmatched"} {
+  ^bb0(%in: f32, %out: f32):
+    %max = arith.maximumf %in, %out : f32
+    linalg.yield %max : f32
+  } -> tensor<?xf32>
+  return %res : tensor<?xf32>
+}
+
+// CHECK: func.func @simple_min
+func.func @simple_min(%input: tensor<?xf32>, %dest: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK-SAME:   match_status = "unmatched"
+  %res = linalg.generic {indexing_maps = [#map, #map],
+                         iterator_types = ["parallel"]}
+                         ins(%input : tensor<?xf32>)
+                         outs(%dest : tensor<?xf32>) attrs = {match_status = "unmatched"} {
+  ^bb0(%in: f32, %out: f32):
+    %max = arith.minimumf %in, %out : f32
+    linalg.yield %max : f32
+  } -> tensor<?xf32>
+  return %res : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @match(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op
+    transform.iree.match.regions %generic : !transform.any_op {
+      ^bb0(%target: tensor<f32>, %empty_max: tensor<f32>):
+        %0 = linalg.generic {indexing_maps = [affine_map<() -> ()>,
+                                                affine_map<() -> ()>],
+                               iterator_types = []}
+                               ins(%target : tensor<f32>)
+                               outs(%empty_max : tensor<f32>) {
+        ^bb0(%in: f32, %out: f32):
+          %max = arith.maximumf %in, %out : f32
+          linalg.yield %max : f32
+        } -> tensor<f32>
+    }
+    transform.yield %generic : !transform.any_op
+  }
+
+  transform.named_sequence @annotate(%generic: !transform.any_op {transform.readonly}) {
+    %0 = transform.param.constant "matched" -> !transform.any_param
+    transform.annotate %generic "match_status" = %0 : !transform.any_op, !transform.any_param
+    transform.yield
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    transform.foreach_match in %module
+        @match -> @annotate
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @external(%arg0: tensor<?xf32>)
+func.func private @external_aligned(%arg0: tensor<100xf32>)
+func.func private @external_static(%arg0: tensor<10xf32>)
+func.func private @other_external_static(%arg0: tensor<15xf32>)
+func.func private @external_2d(%arg0: tensor<?x?xf32>)
+
+// CHECK-LABEL: func.func @call_external
+func.func @call_external(%input: tensor<?xf32>,
+                         %input_2d: tensor<?x?xf32>,
+                         %input_aligned: tensor<100xf32>,
+                         %input_static: tensor<10xf32>,
+                         %other_static: tensor<15xf32>) {
+//       CHECK: call @external
+//  CHECK-SAME:   match_status = "matched"
+  func.call @external(%input) {match_status = "unmatched"} : (tensor<?xf32>) -> ()
+//       CHECK: call @external_2d
+//  CHECK-SAME:   match_status = "unmatched"
+  func.call @external_2d(%input_2d) {match_status = "unmatched"} : (tensor<?x?xf32>) -> ()
+//       CHECK: call @external_aligned
+//  CHECK-SAME:   match_status = "aligned_match"
+  func.call @external_aligned(%input_aligned) {match_status = "unmatched"} : (tensor<100xf32>) -> ()
+//       CHECK: call @external_static
+//  CHECK-SAME:   match_status = "static_matched"
+  func.call @external_static(%input_static) {match_status = "unmatched"} : (tensor<10xf32>) -> ()
+//       CHECK: call @other_external_static
+//  CHECK-SAME:   match_status = "matched"
+  func.call @other_external_static(%other_static) {match_status = "unmatched"} : (tensor<15xf32>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @static_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    transform.match.operation_name %call ["func.call"] : !transform.any_op
+    %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %in0 = tensor<10xf32> : !transform.any_value
+    %0 = transform.param.constant "static_matched" -> !transform.any_param
+    transform.yield %call, %0 : !transform.any_op, !transform.any_param
+  }
+  transform.named_sequence @static_alignment_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    transform.match.operation_name %call ["func.call"] : !transform.any_op
+    %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %in0 = tensor<?xf32> : !transform.any_value
+    transform.iree.match.dim_is_multiple_of %in0[0], 20 : !transform.any_value
+    %0 = transform.param.constant "aligned_match" -> !transform.any_param
+    transform.yield %call, %0 : !transform.any_op, !transform.any_param
+  }
+  transform.named_sequence @match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    transform.match.operation_name %call ["func.call"] : !transform.any_op
+    %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %in0 = tensor<?xf32> : !transform.any_value
+    %0 = transform.param.constant "matched" -> !transform.any_param
+    transform.yield %call, %0 : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @annotate(%call: !transform.any_op {transform.readonly},
+                                     %note: !transform.any_param {transform.readonly}) {
+    transform.annotate %call "match_status" = %note : !transform.any_op, !transform.any_param
+    transform.yield
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    transform.foreach_match in %module
+        @static_match -> @annotate,
+        @static_alignment_match -> @annotate,
+        @match -> @annotate
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir
new file mode 100644
index 0000000..77a6f5b
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt --transform-preload-library='transform-library-paths=%p/external_function_spec.mlir' --transform-interpreter %s | FileCheck %s
+
+module @example {
+  // empty
+}
+
+// CHECK-LABEL: module @example
+//       CHECK:   func.func private @some_external_function(tensor<?xf32>) -> tensor<?xf32>
+//       CHECK:   func.func @some_function(%arg0: tensor<?xf32>) -> tensor<?xf32>
+//  CHECK-NEXT:     return %arg0 : tensor<?xf32>
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index bd8b5a9..6d87a4b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -67,6 +67,14 @@
   if (pipelineExtensions) {
     pipelineExtensions->extendPreprocessingPassPipeline(passManager);
   }
+
+  if (!preprocessingOptions.preprocessingTransformSpecFilename.empty()) {
+    Preprocessing::InterpreterPassOptions interpreterOptions;
+    interpreterOptions.transformSpecPath =
+        preprocessingOptions.preprocessingTransformSpecFilename;
+    passManager.addPass(
+        Preprocessing::createInterpreterPass(interpreterOptions));
+  }
 }
 
 void registerPreprocessingPasses() { registerCommonPreprocessingPasses(); }
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel
new file mode 100644
index 0000000..08a41a0
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel
@@ -0,0 +1,69 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_td_library(
+    name = "td_files",
+    srcs = enforce_glob(
+        [
+            "PreprocessingExtensionsOps.td",
+        ],
+        include = ["*.td"],
+    ),
+    deps = [
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:TransformDialectTdFiles",
+    ],
+)
+
+iree_gentbl_cc_library(
+    name = "PreprocessingExtensionsOpGen",
+    tbl_outs = [
+        (
+            ["--gen-op-decls"],
+            "PreprocessingExtensionsOps.h.inc",
+        ),
+        (
+            ["--gen-op-defs"],
+            "PreprocessingExtensionsOps.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "PreprocessingExtensionsOps.td",
+    deps = [":td_files"],
+)
+
+iree_compiler_cc_library(
+    name = "PreprocessingExtensions",
+    srcs = [
+        "PreprocessingExtensions.cpp",
+        "PreprocessingExtensionsOps.cpp.inc",
+    ],
+    hdrs = [
+        "PreprocessingExtensions.h",
+        "PreprocessingExtensionsOps.h.inc",
+    ],
+    deps = [
+        ":PreprocessingExtensionsOpGen",
+        "//compiler/src/iree/compiler/Utils",
+        "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+        "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:PDLDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformDialect",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt
new file mode 100644
index 0000000..c74f3f7
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt
@@ -0,0 +1,46 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel     #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_tablegen_library(
+  NAME
+    PreprocessingExtensionsOpGen
+  TD_FILE
+    "PreprocessingExtensionsOps.td"
+  OUTS
+    --gen-op-decls PreprocessingExtensionsOps.h.inc
+    --gen-op-defs PreprocessingExtensionsOps.cpp.inc
+)
+
+iree_cc_library(
+  NAME
+    PreprocessingExtensions
+  HDRS
+    "PreprocessingExtensions.h"
+    "PreprocessingExtensionsOps.h.inc"
+  SRCS
+    "PreprocessingExtensions.cpp"
+    "PreprocessingExtensionsOps.cpp.inc"
+  DEPS
+    ::PreprocessingExtensionsOpGen
+    IREEDialectsTransforms
+    IREELinalgTransformDialect
+    LLVMSupport
+    MLIRIR
+    MLIRPDLDialect
+    MLIRSupport
+    MLIRTransformDialect
+    MLIRTransformUtils
+    iree::compiler::Utils
+  PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
new file mode 100644
index 0000000..486ca26
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
@@ -0,0 +1,409 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "PreprocessingExtensions.h"
+
+#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
+#include "iree/compiler/Utils/EquivalenceUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir::iree_compiler {
+
+IREE::transform_dialect::PreprocessingExtensions::PreprocessingExtensions() {
+  registerTransformOps<
+#define GET_OP_LIST
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.cpp.inc"
+      >();
+}
+
+void registerTransformDialectPreprocessingExtension(DialectRegistry &registry) {
+  registry.addExtensions<IREE::transform_dialect::PreprocessingExtensions>();
+}
+
+//===----------------------------------------------------------------------===//
+// GetNearestSymbolTableOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::GetNearestSymbolTableOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto tableOp = SymbolTable::getNearestSymbolTable(target);
+  if (!tableOp) {
+    return emitDefaultDefiniteFailure(target);
+  }
+  results.push_back(tableOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void IREE::transform_dialect::GetNearestSymbolTableOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// ImportSymbolOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure IREE::transform_dialect::ImportSymbolOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &transformResults,
+    transform::TransformState &state) {
+  auto symbolOp = SymbolTable::lookupNearestSymbolFrom(*this, getSymbol());
+  if (!symbolOp) {
+    return emitDefiniteFailure() << "could not find corresponding symbol op";
+  }
+  // Require isolated from above as the clone does not make sense with escaping
+  // values.
+  if (!symbolOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    return emitDefiniteFailure()
+           << "target symbol op is not isolated from above";
+  }
+  StringRef symbol = getSymbol().getLeafReference();
+  SmallVector<Operation *> results;
+  for (Operation *payloadOp : state.getPayloadOps(getSymbolTable())) {
+    if (!payloadOp->hasTrait<OpTrait::SymbolTable>()) {
+      return emitDefiniteFailure()
+             << "target symbol table " << payloadOp << " is not a symbol table";
+    }
+    SymbolTable symbolTable(payloadOp);
+
+    if (Operation *preExistingSymbolOp = symbolTable.lookup(symbol)) {
+      if (getForceImport()) {
+        // If we want to overwrite pre-existing symbols, just erase it here.
+        symbolTable.erase(preExistingSymbolOp);
+      } else if (getIfUndefined()) {
+        // Skip if we want to use the symbol that is already there.
+        results.push_back(preExistingSymbolOp);
+        continue;
+      } else {
+        return emitDefiniteFailure()
+               << "target symbol " << symbol << " is already defined";
+      }
+    }
+
+    // Symbol table ops must have exactly one region with exactly one block.
+    // Simply clone the target symbol op into the single block.
+    rewriter.setInsertionPointToStart(&payloadOp->getRegion(0).front());
+    results.push_back(rewriter.clone(*symbolOp));
+  }
+  transformResults.set(cast<OpResult>(getClonedSymbol()), results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void IREE::transform_dialect::ImportSymbolOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getSymbolTable(), effects);
+  transform::producesHandle(getClonedSymbol(), effects);
+  transform::modifiesPayload(effects);
+}
+
+LogicalResult IREE::transform_dialect::ImportSymbolOp::verify() {
+  if (getForceImport() && getIfUndefined()) {
+    return emitOpError()
+           << "force_import and if_undefined are mutually exclusive";
+  }
+  if (!SymbolTable::lookupNearestSymbolFrom(*this, getSymbol())) {
+    return emitOpError() << "invalid import of undefined symbol";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchCastCompatibleDagFromRootOp
+//===----------------------------------------------------------------------===//
+
+static bool isCastableToTensorType(Type from, RankedTensorType to) {
+  auto tensorType = dyn_cast<RankedTensorType>(from);
+  if (!tensorType) {
+    return false;
+  }
+  if (tensorType.getRank() != to.getRank()) {
+    return false;
+  }
+  if (tensorType.getElementType() != to.getElementType()) {
+    return false;
+  }
+  for (auto [fromSize, toSize] :
+       llvm::zip_equal(tensorType.getShape(), to.getShape())) {
+    // If the target dimension is dynamic we can always cast to it.
+    if (ShapedType::isDynamic(toSize)) {
+      continue;
+    }
+    // Casting a dynamic dimension to a static one is never valid, and static
+    // sizes must always match.
+    if (toSize != fromSize) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Compares the regions between two operations in lockstep for equality.
+static DiagnosedSilenceableFailure
+compareOperationRegions(transform::TransformOpInterface transformOp,
+                        Operation *target, Operation *payload) {
+  if (target->getNumRegions() != payload->getNumRegions()) {
+    return transformOp.emitSilenceableError() << "region count mismatch";
+  }
+  for (auto [r0, r1] :
+       llvm::zip_equal(target->getRegions(), payload->getRegions())) {
+    if (!isStructurallyEquivalentTo(r0, r1)) {
+      return transformOp.emitSilenceableError()
+             << "target op does not match specified body";
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+// Helper to check whether two operations are equivalent up to cast
+// compatibility of their arguments (i.e. the arguments of the payload
+// can be casted to the arguments of the target).
+static DiagnosedSilenceableFailure
+compareCastCompatibleOperations(transform::TransformOpInterface transformOp,
+                                Operation *target, Operation *payload) {
+  if (target->getName() != payload->getName()) {
+    return transformOp.emitSilenceableError()
+           << "target operation name " << target->getName()
+           << " does not match payload " << payload->getName();
+  }
+
+  if (target->getAttrDictionary() != payload->getAttrDictionary()) {
+    return transformOp.emitSilenceableError()
+           << "target attribute dictionary " << target->getAttrDictionary()
+           << " does not match payload attribute dictionary "
+           << payload->getAttrDictionary();
+  }
+
+  if (target->getNumResults() != payload->getNumResults()) {
+    return transformOp.emitSilenceableError() << "result count mismatch";
+  }
+
+  if (target->getNumOperands() != payload->getNumOperands()) {
+    return transformOp.emitSilenceableError() << "operand count mismatch";
+  }
+  for (auto [targetType, payloadType] :
+       llvm::zip_equal(target->getOperandTypes(), payload->getOperandTypes())) {
+    if (auto targetTensorType = dyn_cast<RankedTensorType>(targetType)) {
+      if (!isCastableToTensorType(payloadType, targetTensorType)) {
+        return transformOp.emitSilenceableError()
+               << "operand tensor type mismatch";
+      }
+    } else if (targetType != payloadType) {
+      return transformOp.emitSilenceableError() << "operand type mismatch";
+    }
+  }
+
+  return compareOperationRegions(transformOp, target, payload);
+}
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  Operation *targetDagRoot = getRegion().front().back().getPrevNode();
+
+  // Reserve the list of inputs based on the number of block arguments in
+  // the operation region.
+  int64_t numInputs = getRegion().getNumArguments();
+  SmallVector<Value> inputs(numInputs, nullptr);
+
+  // Maps from target/payload op to the order in which they were first
+  // processed. This is used to verify that two uses actually point to the
+  // same node in the dag.
+  llvm::MapVector<Operation *, int64_t> targetDagOrder;
+  llvm::MapVector<Operation *, int64_t> payloadDagOrder;
+  int64_t index = 0;
+
+  auto compareDagOrder = [&](Operation *target, Operation *payload) {
+    // Compare the stored traversal indices; both ops must have been recorded
+    // at the same position for the uses to refer to the same dag node.
+    return targetDagOrder.contains(target) &&
+           payloadDagOrder.contains(payload) &&
+           targetDagOrder.lookup(target) == payloadDagOrder.lookup(payload);
+  };
+
+  // Populate the paired worklist with the current target and payload root ops.
+  SmallVector<Operation *> targetWorklist = {targetDagRoot};
+  SmallVector<Operation *> payloadWorklist = {current};
+  while (!targetWorklist.empty()) {
+    Operation *targetOp = targetWorklist.pop_back_val();
+    Operation *payloadOp = payloadWorklist.pop_back_val();
+
+    if (targetOp->hasAttr("match.operation_name_only")) {
+      if (targetOp->getName() != payloadOp->getName()) {
+        return emitSilenceableError() << "only operation name op mismatch";
+      }
+      // Do not recurse and do not require any specific structure beyond the
+      // operation name.
+      continue;
+    }
+
+    // Verify that if already processed, both operations are at the same
+    // position.
+    if (targetDagOrder.contains(targetOp) ||
+        payloadDagOrder.contains(payloadOp)) {
+      if (!compareDagOrder(targetOp, payloadOp)) {
+        return emitSilenceableError() << "dag mismatch";
+      }
+      continue;
+    }
+
+    // Verify general operation equality (name, attributes, regions).
+    DiagnosedSilenceableFailure diag =
+        compareCastCompatibleOperations(*this, targetOp, payloadOp);
+    if (!diag.succeeded()) {
+      diag.attachNote() << "While processing operation " << *payloadOp;
+      return diag;
+    }
+
+    for (auto [payloadOperand, targetOperand] :
+         llvm::zip_equal(payloadOp->getOperands(), targetOp->getOperands())) {
+      // If the target value is a block argument, map the payload value to the
+      // associated input and don't process its producer.
+      if (auto targetBlockArg = dyn_cast<BlockArgument>(targetOperand)) {
+        if (targetBlockArg.getOwner() != &getRegion().front()) {
+          return emitDefiniteFailure() << "Invalid block argument in target";
+        }
+        int64_t argIdx = targetBlockArg.getArgNumber();
+        if (inputs[argIdx] && inputs[argIdx] != payloadOperand) {
+          return emitSilenceableError()
+                 << "input operand with conflicting uses";
+        }
+        inputs[argIdx] = payloadOperand;
+        continue;
+      }
+
+      Operation *payloadDefiningOp = payloadOperand.getDefiningOp();
+      if (!payloadDefiningOp) {
+        return emitSilenceableError()
+               << "early termination of the operation dag";
+      }
+
+      // Check whether the producer was already processed, and if so make sure
+      // the target and payload match.
+      Operation *targetDefiningOp = targetOperand.getDefiningOp();
+      if (targetDagOrder.contains(targetDefiningOp) ||
+          payloadDagOrder.contains(payloadDefiningOp)) {
+        if (!compareDagOrder(targetDefiningOp, payloadDefiningOp)) {
+          return emitSilenceableError() << "dag mismatch";
+        }
+        continue;
+      }
+      // Pop the producer of this value onto the worklist.
+      targetWorklist.push_back(targetDefiningOp);
+      payloadWorklist.push_back(payloadDefiningOp);
+    }
+
+    // Mark the current target + payload as processed.
+    targetDagOrder[targetOp] = index;
+    payloadDagOrder[payloadOp] = index++;
+  }
+
+  // Verify that all input arguments were successfully matched.
+  if (llvm::any_of(inputs, [](Value in) { return !in; })) {
+    return emitSilenceableError() << "failed to match all input nodes";
+  }
+
+  results.setValues(cast<OpResult>(getInputs()), inputs);
+  results.setValues(cast<OpResult>(getOutputs()), current->getResults());
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult
+IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::verify() {
+  auto &body = getRegion().front();
+  if (llvm::range_size(body.getOperations()) < 2) {
+    return emitOpError() << "match region must contain at least one operation";
+  }
+  // TODO: Region verification that it includes a single DAG.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchCastCompatibleTypesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchCastCompatibleTypesOp::matchValue(
+    Value current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  Type targetType = getTargetType();
+  if (auto targetTensorType = dyn_cast<RankedTensorType>(targetType)) {
+    if (!isCastableToTensorType(current.getType(), targetTensorType)) {
+      return emitSilenceableError()
+             << "type " << current.getType() << " is not castable to "
+             << targetTensorType;
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (current.getType() != targetType) {
+    return emitSilenceableError()
+           << "type " << current.getType() << " does not match " << targetType;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchDimIsMultipleOfOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchDimIsMultipleOfOp::matchValue(
+    Value current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto shapedType = dyn_cast<ShapedType>(current.getType());
+  if (!shapedType) {
+    return emitSilenceableError()
+           << "type " << current.getType() << " is not a shaped type";
+  }
+  int64_t dim = getDim();
+  if (dim >= shapedType.getRank()) {
+    return emitSilenceableError()
+           << "dim " << dim << " out of range for shaped type " << shapedType;
+  }
+  int64_t size = getSize();
+  if (shapedType.getShape()[dim] % size != 0) {
+    return emitSilenceableError()
+           << "dim " << dim << " of shaped type " << shapedType
+           << " is not a multiple of " << size;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchRegionsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchRegionsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  Operation *comparisonTarget = &getRegion().front().front();
+  return compareOperationRegions(*this, comparisonTarget, current);
+}
+
+LogicalResult IREE::transform_dialect::MatchRegionsOp::verify() {
+  auto &body = getRegion().front();
+  if (llvm::range_size(body.getOperations()) != 2) {
+    return emitOpError() << "match region must contain exactly one operation";
+  }
+  Operation *target = &body.front();
+  if (target->getNumRegions() == 0) {
+    return emitOpError() << "contained operation for comparison must have at "
+                            "least one region";
+  }
+  return success();
+}
+
+} // namespace mlir::iree_compiler
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h
new file mode 100644
index 0000000..dbc9579
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h
@@ -0,0 +1,40 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
+#define IREE_COMPILER_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+
+namespace mlir {
+class DialectRegistry;
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.h.inc"
+
+namespace mlir::iree_compiler {
+
+/// Registers Preprocessing transformations that require IREE-specific
+/// information into the transform dialect.
+void registerTransformDialectPreprocessingExtension(DialectRegistry &registry);
+
+namespace IREE::transform_dialect {
+// Hook to register Preprocessing transformations to the transform dialect.
+class PreprocessingExtensions
+    : public transform::TransformDialectExtension<PreprocessingExtensions> {
+public:
+  PreprocessingExtensions();
+};
+} // namespace IREE::transform_dialect
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td
new file mode 100644
index 0000000..de0232f
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td
@@ -0,0 +1,213 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
+#define IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
+
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def GetNearestSymbolTableOp : Op<Transform_Dialect, "iree.get_nearest_symbol_table",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Returns the nearest symbol table op for each op in the payload, inclusive.
+
+    #### Return modes
+
+    This operation reads the `target` handle and produces the `result`
+    handle. This operation emits a definite failure if the nearest symbol table
+    is unknown.
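+
+    For example, assuming `%func` is a handle to an op nested in a module:
+
+    ```mlir
+    %module = transform.iree.get_nearest_symbol_table %func
+        : (!transform.any_op) -> !transform.any_op
+    ```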
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation* target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def ImportSymbolOp : Op<Transform_Dialect, "iree.import_symbol",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Clones the op defined by the given symbol into the given symbol table and
+    returns the cloned symbol. If `force_import` is set, this will (unsafely)
+    overwrite any pre-existing definitions of the same symbol. If
+    `if_undefined` is set, this will return a handle to the pre-existing symbol
+    in the payload if found instead of failing.
+
+    #### Return modes
+
+    This operation reads the `symbol_table` handle and produces the
+    `cloned_symbol` handle. This operation emits a definite failure if the
+    `symbol_table` op does not define a symbol table.
+
+    This will emit a definite failure if the symbol already exists in the
+    symbol table and neither `force_import` nor `if_undefined` is set.
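+
+    For example, assuming the spec defines a `func.func @my_func` and
+    `%module` is a handle to a symbol table op:
+
+    ```mlir
+    %new_func = transform.iree.import_symbol @my_func into %module
+        : (!transform.any_op) -> !transform.any_op
+    ```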
+  }];
+
+  let arguments = (ins SymbolRefAttr:$symbol,
+                       UnitAttr:$if_undefined,
+                       UnitAttr:$force_import,
+                       TransformHandleTypeInterface:$symbol_table);
+  let results = (outs TransformHandleTypeInterface:$cloned_symbol);
+
+  let assemblyFormat = [{
+    (`force` $force_import^)? $symbol `into` $symbol_table
+    (`if` `undefined` $if_undefined^)? attr-dict 
+    `:` functional-type(operands, results)
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let hasVerifier = 1;
+}
+
+def MatchCastCompatibleDagFromRootOp : Op<Transform_Dialect, "iree.match.cast_compatible_dag_from_root",
+    [IsolatedFromAbove,
+     MatchOpInterface,
+     SingleOpMatcher,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+     MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if the body of the target op matches an operation dag starting "
+      "at the given root";
+  let description = [{
+    Checks whether the given root op matches an operation dag specified in the
+    body of this op. Enforces cast compatibility between types rather than a
+    strict equality, similar to `iree.match.cast_compatible_type`.
+
+    Note: This operation is experimental and subject to change. General subgraph
+    matching is difficult and can spawn various DSLs and a slew of transforms.
+    This op tries to keep it relatively simple and inflexible, reflecting the
+    expected use case of splicing in hand written kernels that can be equally
+    inflexible.
+
+    #### Return modes
+
+    Succeeds if the root operation matches the dag given by this op, and
+    produces a silenceable failure otherwise. Produces a definite failure
+    if the operand is not associated with a single payload op.
+
+    On success, this operation produces handles to the inputs and outputs
+    of the operation dag, based on the results of the root op and the block
+    arguments of this operation's body.
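+
+    A hypothetical sketch, matching a multiply feeding an add rooted at the
+    add (`%root` is a handle to the candidate root op):
+
+    ```mlir
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+    ^bb0(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>):
+      %0 = arith.mulf %arg0, %arg1 : tensor<?xf32>
+      %1 = arith.addf %0, %arg1 : tensor<?xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    ```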
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let results = (outs TransformValueHandleTypeInterface:$inputs,
+                      TransformValueHandleTypeInterface:$outputs);
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = "$operand_handle attr-dict-with-keyword regions `:` functional-type(operands, results)";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let hasVerifier = 1;
+}
+
+def MatchCastCompatibleTypesOp : Op<Transform_Dialect, "iree.match.cast_compatible_type",
+    [IsolatedFromAbove,
+     MatchOpInterface,
+     SingleValueMatcher,
+     MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if the body of the target op matches the body of the single contained op";
+  let description = [{
+    Checks whether the given value is cast-compatible with the given target
+    type attribute.
+
+    Currently this operation only allows casting of tensor types. Other types
+    must match exactly.
+
+    #### Return modes
+
+    Succeeds if the value's type is compatible with the target type, and
+    produces a silenceable failure otherwise. Produces a definite failure
+    if the operand is not associated with a single payload value.
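+
+    For example, assuming `%operand` is a value handle:
+
+    ```mlir
+    transform.iree.match.cast_compatible_type %operand = tensor<?xf32> : !transform.any_value
+    ```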
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$operand_handle,
+                       TypeAttr:$target_type);
+  let assemblyFormat = "$operand_handle `=` $target_type attr-dict `:` type($operand_handle)";
+  let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
+def MatchDimIsMultipleOfOp : Op<Transform_Dialect, "iree.match.dim_is_multiple_of",
+    [IsolatedFromAbove,
+     MatchOpInterface,
+     SingleValueMatcher,
+     MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if the body of the target op matches the body of the single contained op";
+  let description = [{
+    Checks whether the given dimension of the given shaped value is a multiple of the
+    given size.
+
+    #### Return modes
+
+    Succeeds if the given dimension of the value's shaped type is a multiple
+    of the given size, and produces a silenceable failure otherwise. Produces
+    a definite failure if the operand is not associated with a single payload
+    value.
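+
+    For example, requiring dimension 0 of `%operand` to be a multiple of 20:
+
+    ```mlir
+    transform.iree.match.dim_is_multiple_of %operand[0], 20 : !transform.any_value
+    ```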
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$operand_handle,
+                       I64Attr:$dim,
+                       I64Attr:$size);
+  let assemblyFormat = "$operand_handle `[` $dim `]` `,` $size "
+                       "attr-dict `:` type($operand_handle)";
+  let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
+def MatchRegionsOp : Op<Transform_Dialect, "iree.match.regions",
+    [IsolatedFromAbove,
+     MatchOpInterface,
+     SingleOpMatcher,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+     MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if the body of the target op matches the body of the single contained op";
+  let description = [{
+    Does a structural comparison of the regions of the single op contained
+    within the region of this op against the regions of the target operation.
+
+    #### Return modes
+
+    Succeeds if the operation body satisfies the specified criteria, produces a
+    silenceable failure otherwise. Produces a definite failure if the operand is
+    not associated with a single payload op.
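+
+    For example, matching any op whose region is structurally equivalent to an
+    elementwise-maximum `linalg.generic` (taken from the tests in this change;
+    `%generic` is a handle to the candidate op):
+
+    ```mlir
+    transform.iree.match.regions %generic : !transform.any_op {
+    ^bb0(%target: tensor<f32>, %empty_max: tensor<f32>):
+      %0 = linalg.generic {indexing_maps = [affine_map<() -> ()>,
+                                            affine_map<() -> ()>],
+                           iterator_types = []}
+                           ins(%target : tensor<f32>)
+                           outs(%empty_max : tensor<f32>) {
+      ^bb0(%in: f32, %out: f32):
+        %max = arith.maximumf %in, %out : f32
+        linalg.yield %max : f32
+      } -> tensor<f32>
+    }
+    ```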
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle) regions";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let hasVerifier = 1;
+}
+
+#endif // IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 5d87d65..637a66b 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -63,6 +63,7 @@
         "//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms",
         "//compiler/src/iree/compiler/Pipelines",
         "//compiler/src/iree/compiler/Preprocessing:Passes",
+        "//compiler/src/iree/compiler/Preprocessing/TransformExtensions:PreprocessingExtensions",
         "//llvm-external-projects/iree-dialects:IREEInputDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
@@ -112,8 +113,21 @@
         "@llvm-project//mlir:SPIRVTransforms",
         "@llvm-project//mlir:ShapeDialect",
         "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl",
+        "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorDialect",
+
+        # TransformExtensions
+        "@llvm-project//mlir:AffineTransformOps",
+        "@llvm-project//mlir:BufferizationTransformOps",
+        "@llvm-project//mlir:FuncTransformOps",
+        "@llvm-project//mlir:GPUTransformOps",
+        "@llvm-project//mlir:LinalgTransformOps",
+        "@llvm-project//mlir:MemRefTransformOps",
+        "@llvm-project//mlir:SCFTransformOps",
+        "@llvm-project//mlir:TensorTransformOps",
+        "@llvm-project//mlir:TransformLoopExtension",
+        "@llvm-project//mlir:VectorTransformOps",
     ],
 )
 
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index 97a3a92..03673ff 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -70,6 +70,7 @@
     iree::compiler::Modules::IO::Parameters::Transforms
     iree::compiler::Pipelines
     iree::compiler::Preprocessing::Passes
+    iree::compiler::Preprocessing::TransformExtensions::PreprocessingExtensions
   PUBLIC
 )
 
@@ -105,9 +106,21 @@
     MLIRFuncDialect
     MLIRFuncToSPIRV
     MLIRTensorInferTypeOpInterfaceImpl
+    MLIRTransformDialect
     MLIRTransforms
     MLIRVectorDialect
     iree::compiler::Dialect::VM::Target::init_targets
+
+    MLIRAffineTransformOps
+    MLIRBufferizationTransformOps
+    MLIRFuncTransformOps
+    MLIRGPUTransformOps
+    MLIRLinalgTransformOps
+    MLIRMemRefTransformOps
+    MLIRSCFTransformOps
+    MLIRTensorTransformOps
+    MLIRVectorTransformOps
+    MLIRTransformLoopExtension
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
index 5fcdca4..c8aca0e 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -31,6 +31,7 @@
 #include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
 #include "iree/compiler/Modules/HAL/Loader/IR/HALLoaderDialect.h"
 #include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
 #include "mlir/IR/Dialect.h"
 
 namespace mlir::iree_compiler {
@@ -59,6 +60,9 @@
   registerCodegenInterfaces(registry);
   registerGlobalOptimizationInterfaces(registry);
   registerUKernelBufferizationInterface(registry);
+
+  // Register transform dialect extensions.
+  registerTransformDialectPreprocessingExtension(registry);
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index 80d22d8..ddc310d 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -13,30 +13,40 @@
 #define IREE_COMPILER_TOOLS_INIT_MLIR_DIALECTS_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/IR/Dialect.h"
 
 #ifdef IREE_HAVE_C_OUTPUT_FORMAT
@@ -76,6 +86,18 @@
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
 
+  // Register all transform dialect extensions.
+  affine::registerTransformDialectExtension(registry);
+  bufferization::registerTransformDialectExtension(registry);
+  func::registerTransformDialectExtension(registry);
+  gpu::registerTransformDialectExtension(registry);
+  linalg::registerTransformDialectExtension(registry);
+  memref::registerTransformDialectExtension(registry);
+  scf::registerTransformDialectExtension(registry);
+  tensor::registerTransformDialectExtension(registry);
+  transform::registerLoopExtension(registry);
+  vector::registerTransformDialectExtension(registry);
+
 #ifdef IREE_HAVE_C_OUTPUT_FORMAT
   registry.insert<emitc::EmitCDialect>();
 #endif // IREE_HAVE_C_OUTPUT_FORMAT
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index 611797d..b3e1de8 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir {
@@ -57,6 +58,9 @@
   affine::registerAffinePipelineDataTransferPass();
   registerConvertAffineToStandardPass();
 
+  // Arm SME
+  arm_sme::registerArmSMEPasses();
+
   // Linalg
   registerLinalgPasses();
 
@@ -80,8 +84,8 @@
   registerConvertControlFlowToSPIRVPass();
   registerConvertFuncToSPIRVPass();
 
-  // Arm SME
-  arm_sme::registerArmSMEPasses();
+  // Transform Dialect
+  transform::registerTransformPasses();
 }
 
 } // namespace mlir
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
index 2f0b122..b4a938a 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
@@ -34,7 +34,7 @@
 static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs,
                                        IRMapping &parentMapping);
 
-static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) {
+bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) {
   IRMapping mapping;
   return isStructurallyEquivalentTo(lhs, rhs, mapping);
 }
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
index c533b04..1b51717 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
@@ -11,6 +11,12 @@
 
 namespace mlir::iree_compiler {
 
+// Recursively compares two regions for structural equivalence.
+//
+// Structural equivalence ensures that operations in both regions
+// |lhs| and |rhs| have the same attributes and same use-def structure.
+bool isStructurallyEquivalentTo(Region &lhs, Region &rhs);
+
 // Recursively compares two operations for structural equivalence.
 //
 // Structural equivalence ensures that operations in the regions of both the
diff --git a/samples/custom_dispatch/README.md b/samples/custom_dispatch/README.md
index a5579f7..63d5748 100644
--- a/samples/custom_dispatch/README.md
+++ b/samples/custom_dispatch/README.md
@@ -257,6 +257,49 @@
 substitution approaches should be used instead and in many cases that is
 sufficient for most workloads not involving other libraries.
 
+### Compile Time Inlining Custom Function Calls
+
+**Overview**: the user defines functions in MLIR dialects IREE is able to
+ingest, paired with a matcher and a replacement pattern. The matcher runs
+during preprocessing and invokes the replacement pattern for each successful
+match. The replacement pattern imports the externally authored function into
+the target module and replaces the matched subgraph with a call to that
+function, leaving the call interleaved with the surrounding IR.
+
+**Workflows**:
+
+Statically matched and imported external functions
+```
+                                            +--------------+
+                                            | example.mlir |
++--------------+       +--------------+     +--------------+
+| (one of the  |       | functions +  |            v
+| above static | ----> | matchers +   | ----> iree-compile
+| workflows)   |       | replace.mlir |            v
++--------------+       +--------------+     +--------------+
+                                            | example.vmfb |
+                                            +--------------+
+```
+
+**Samples**:
+
+* CPU: [custom_dispatch/cpu/embedded/](./cpu/embedded/) (.c -> .o)
+  * [example_transform_spec.mlir](./cpu/embedded/example_transform_spec.mlir) (.mlir)
+* Vulkan/SPIR-V: [custom_dispatch/vulkan/shaders/](./vulkan/shaders/) (.glsl -> .spv)
+  * [example_transform_spec.mlir](./vulkan/shaders/example_transform_spec.mlir) (.mlir)
+
+The above two samples build on top of a couple of the static workflows shown
+above, but should work with any of the other approaches. The idea is to separate
+the custom kernel from the target module to be compiled, allowing integration of
+custom dispatches with default IREE codegen without the need to build a custom
+set of compiler tools around IREE to generate the necessary IR.
+
+There are a number of possible points at which the match and replace can
+happen; the above shows it running after import + input conversion. Other
+plugin points are possible (e.g. before input conversion or after global
+optimization), but the matchers available at those points currently lack some
+ergonomics.
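+
+As a sketch of the invocation (the exact flags are documented in the sample
+READMEs; paths here assume the CPU sample with a build directory at
+`../iree-build/`), the spec is passed to `iree-compile` via the preprocessing
+flag added for this plugin point:
+
+```
+iree-compile \
+    --iree-hal-executable-object-search-path=../iree-build/ \
+    --iree-preprocessing-transform-spec-filename=samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir \
+    samples/custom_dispatch/cpu/embedded/example_transform.mlir \
+    -o=/tmp/example.vmfb
+```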
+
 ### Others
 
 Most other situations are covered by [custom modules](/samples/custom_module/).
diff --git a/samples/custom_dispatch/cpu/embedded/CMakeLists.txt b/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
index 95d32f9..2f428dc 100644
--- a/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
+++ b/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
@@ -66,9 +66,11 @@
   SRCS
     "example_hal.mlir"
     "example_stream.mlir"
+    "example_transform.mlir"
   DATA
   functions_arm_64.o
   functions_x86_64.o
+  example_transform_spec.mlir
   TOOLS
     FileCheck
     iree-compile
diff --git a/samples/custom_dispatch/cpu/embedded/README.md b/samples/custom_dispatch/cpu/embedded/README.md
index dc94a99..ede94a6 100644
--- a/samples/custom_dispatch/cpu/embedded/README.md
+++ b/samples/custom_dispatch/cpu/embedded/README.md
@@ -148,3 +148,29 @@
         --input=8xf32=4 \
         --module=/tmp/example.vmfb
     ```
+
+## Custom Kernel Match and Replace Scripting Instructions
+
+Follow the first two steps above to build the samples, and then compile with one
+additional flag to include the path to the kernel matcher and replacer.
+
+    ```
+    iree-compile \
+        --iree-hal-executable-object-search-path=../iree-build/ \
+        --iree-preprocessing-transform-spec-filename=samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir \
+        samples/custom_dispatch/cpu/embedded/example_transform.mlir \
+        -o=/tmp/example.vmfb
+    ```
+
+And then run the example the same way.
+
+    ```
+    iree-run-module \
+        --device=local-sync \
+        --function=mixed_invocation \
+        --input=5xf32=7 \
+        --input=5xf32=4 \
+        --input=10xf32=-4 \
+        --input=10xf32=3 \
+        --module=/tmp/example.vmfb
+    ```
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform.mlir b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
new file mode 100644
index 0000000..0735189
--- /dev/null
+++ b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
@@ -0,0 +1,111 @@
+// RUN: iree-compile %s \
+// RUN:     --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \
+// RUN:     --iree-preprocessing-transform-spec-filename=%p/example_transform_spec.mlir | \
+// RUN: iree-run-module \
+// RUN:     --device=local-sync \
+// RUN:     --module=- \
+// RUN:     --function=mixed_invocation \
+// RUN:     --input=5xf32=7 \
+// RUN:     --input=5xf32=4 \
+// RUN:     --input=10xf32=-4 \
+// RUN:     --input=10xf32=3 | \
+// RUN: FileCheck %s
+
+// The configuration used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 32 : index,
+  target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support llvm-cpu here.
+#cpu_target = #hal.device.target<"llvm-cpu", {
+  executable_targets = [
+    #x86_64_target
+  ]
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+module @example attributes {hal.device.targets = [#cpu_target]} {
+
+  // CHECK-LABEL: EXEC @mixed_invocation
+  func.func @mixed_invocation(%lhs: tensor<?xf32>,
+                              %rhs: tensor<?xf32>,
+                              %lhs_static: tensor<10xf32>,
+                              %rhs_static: tensor<10xf32>) -> (tensor<?xf32>, tensor<10xf32>) {
+    %c0 = arith.constant 0 : index
+    %dim = tensor.dim %lhs, %c0 : tensor<?xf32>
+    %empty = tensor.empty(%dim) : tensor<?xf32>
+    %max = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
+                           outs(%empty : tensor<?xf32>) {
+    ^bb0(%in: f32, %in0: f32, %out: f32):
+      %m = arith.mulf %in, %in0 : f32
+      linalg.yield %m : f32
+    } -> tensor<?xf32>
+    %abs = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%max : tensor<?xf32>)
+                           outs(%empty : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %a = math.absf %in : f32
+      linalg.yield %a : f32
+    } -> tensor<?xf32>
+    %neg = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%abs : tensor<?xf32>)
+                           outs(%empty : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %n = arith.negf %in : f32
+      linalg.yield %n : f32
+    } -> tensor<?xf32>
+
+    %empty_static = tensor.empty() : tensor<10xf32>
+    %max_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%lhs_static, %rhs_static : tensor<10xf32>, tensor<10xf32>)
+                           outs(%empty_static : tensor<10xf32>) {
+    ^bb0(%in: f32, %in0: f32, %out: f32):
+      %m = arith.mulf %in, %in0 : f32
+      linalg.yield %m : f32
+    } -> tensor<10xf32>
+    %abs_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%max_static : tensor<10xf32>)
+                           outs(%empty_static : tensor<10xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %a = math.absf %in : f32
+      linalg.yield %a : f32
+    } -> tensor<10xf32>
+    %neg_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> (d0)>],
+                           iterator_types = ["parallel"]}
+                           ins(%abs_static : tensor<10xf32>)
+                           outs(%empty_static : tensor<10xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %n = arith.negf %in : f32
+      linalg.yield %n : f32
+    } -> tensor<10xf32>
+
+    // Add 1 to show that it actually runs the custom kernel.
+    // CHECK: 5xf32=-27 -27 -27 -27 -27
+    // CHECK: 10xf32=-11 -11 -11 -11 -11 -11 -11 -11 -11 -11
+    return %neg, %neg_static : tensor<?xf32>, tensor<10xf32>
+  }
+}  // module
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
new file mode 100644
index 0000000..1867e5e
--- /dev/null
+++ b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
@@ -0,0 +1,175 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// The configuration used for executable compilation.
+// This specifies the device configurations that support this custom kernel.
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 32 : index,
+  target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on.
+// These can come from compiler flags, and multiple targets can be supported.
+// It's possible, for example, to target multiple devices in the same
+// compiled binary (CPU + Vulkan, etc).
+#cpu_target = #hal.device.target<"llvm-cpu", {
+  executable_targets = [
+    #x86_64_target
+  ]
+}>
+
+module attributes {transform.with_named_sequence} {
+
+  // Executable containing exported shims and calls to external functions.
+  // See the other examples in this directory for in-depth explanations of
+  // the IR structure of this executable.
+  hal.executable private @executable {
+    hal.executable.variant public @x86_64 target(#x86_64_target) objects([
+      #hal.executable.object<{
+        path = "samples/custom_dispatch/cpu/embedded/functions_x86_64.o"
+      }>
+    ]) {
+      hal.executable.export public @simple_mul_abs_negate ordinal(0)
+          layout(#hal.pipeline.layout<push_constants = 1, sets = [
+            <0, bindings = [
+                <0, storage_buffer, ReadOnly>,
+                <1, storage_buffer, ReadOnly>,
+                <2, storage_buffer>
+            ]>
+          ]>) {
+      ^bb0(%device: !hal.device, %workload: index):
+        %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+        %c1 = arith.constant 1 : index
+        hal.return %x, %c1, %c1 : index, index, index
+      }
+      builtin.module {
+        func.func private @simple_mul_abs_negate_workgroup(%binding0: memref<?xf32>, %binding1: memref<?xf32>, %binding2: memref<?xf32>, %dim: index, %tid: index) attributes {
+          hal.import.static
+        }
+        func.func @simple_mul_abs_negate() {
+          %c0 = arith.constant 0 : index
+          %dim_i32 = hal.interface.constant.load[0] : i32
+          %dim = arith.index_castui %dim_i32 : i32 to index
+          %workgroup_id_x = hal.interface.workgroup.id[0] : index
+          %tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
+
+          %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+          %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+          %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+
+          func.call @simple_mul_abs_negate_workgroup(%binding0, %binding1, %binding2, %dim, %tid) : (memref<?xf32>, memref<?xf32>, memref<?xf32>, index, index) -> ()
+          return
+        }
+      }
+    }  // hal.executable.variant
+  }  // hal.executable
+
+  func.func @call_mul_abs_negate(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %c0 = arith.constant 0 : index
+    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+    %dim_i32 = arith.index_cast %dim : index to i32
+
+    // Dispatch a basic `ret = -|lhs * rhs|` using an external function.
+    %0 = flow.dispatch @executable::@x86_64::@simple_mul_abs_negate[%dim](%dim_i32, %arg0, %arg1) {
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 0>,
+        #hal.interface.binding<0, 1>,
+        #hal.interface.binding<0, 2>
+      ],
+      // HACK: keep the executable live through DCE. Only required when
+      // using the automatic variant selection.
+      hal.executable.ref = [@executable]
+    } : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+    return %0 : tensor<?xf32>
+  }
+
+  transform.named_sequence @match_mul_abs_negate(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+      ^bb0(%lhs: tensor<?xf32>, %rhs: tensor<?xf32>):
+        // The matcher does not recurse into the constant index + dim because
+        // their only consumer (the tensor.empty below) is matched by operation
+        // name only.
+        %c0 = arith.constant 0 : index
+        %dim = tensor.dim %lhs, %c0 : tensor<?xf32>
+        // --------------------------------------------------------------------
+        %empty = tensor.empty(%dim) {"match.operation_name_only"} : tensor<?xf32>
+        %mul = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                                affine_map<(d0) -> (d0)>,
+                                                affine_map<(d0) -> (d0)>],
+                               iterator_types = ["parallel"]}
+                               ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
+                               outs(%empty : tensor<?xf32>) {
+        ^bb0(%in: f32, %in0: f32, %out: f32):
+          %m = arith.mulf %in, %in0 : f32
+          linalg.yield %m : f32
+        } -> tensor<?xf32>
+        %abs = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                                affine_map<(d0) -> (d0)>],
+                               iterator_types = ["parallel"]}
+                               ins(%mul : tensor<?xf32>)
+                               outs(%empty : tensor<?xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %a = math.absf %in : f32
+          linalg.yield %a : f32
+        } -> tensor<?xf32>
+        // The payload root is compared starting from here, walking up the
+        // chain of producers.
+        %neg = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                                affine_map<(d0) -> (d0)>],
+                               iterator_types = ["parallel"]}
+                               ins(%abs : tensor<?xf32>)
+                               outs(%empty : tensor<?xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %n = arith.negf %in : f32
+          linalg.yield %n : f32
+        } -> tensor<?xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    transform.yield %ins, %outs : !transform.any_value, !transform.any_value
+  }
+
+  // Rewrite callback for `transform.foreach_match`. The input signature for
+  // this sequence must match exactly with the outputs of the matcher. In this
+  // case the matcher returns the inputs and outputs of the matched DAG
+  // directly, so we just insert a call to the hand-authored function above.
+  transform.named_sequence @cast_and_call_dag(%ins: !transform.any_value {transform.readonly},
+                                              %out: !transform.any_value {transform.readonly}) {
+    %root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
+    %module = transform.iree.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
+    %executable = transform.iree.import_symbol @executable into %module if undefined : (!transform.any_op) -> !transform.any_op
+    %func = transform.iree.import_symbol @call_mul_abs_negate into %module if undefined : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call %func(%ins) -> %out after %root {
+          // This specifies how to resolve type mismatches between the arguments
+          // of the function and the inputs from the matcher. In this example,
+          // the only casts this will generate are same-rank tensor casts that
+          // drop static information.
+          transform.type_conversion.tensor.cast_shape_dynamic_dims
+      } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  // Entry point for the transform interpreter, nested on the full module. This
+  // is because the rewrites needed for importing the custom kernel need to add
+  // a new symbol to the module's symbol table.
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // Gather the set of functions within the module.
+    %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    // For each function in the module, run the matcher on all contained
+    // operations.
+    transform.foreach %funcs : !transform.any_op {
+      ^bb1(%func: !transform.any_op):
+        transform.foreach_match in %func
+            // <matcher name> -> <rewriter name>
+            // Multiple matcher-action pairs can be specified comma separated,
+            // here we are only doing a single kind of match and replace.
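+            // A hypothetical multi-pair form (names illustrative) would be:
+            //   @match_a -> @replace_a,
+            //   @match_b -> @replace_b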
+            @match_mul_abs_negate -> @cast_and_call_dag
+          : (!transform.any_op) -> (!transform.any_op)
+    }
+    // Clean up leftover dead code; cast_and_call does not do replacement, it
+    // only rewires uses.
+    transform.apply_dce to %module : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/samples/custom_dispatch/cpu/embedded/functions.c b/samples/custom_dispatch/cpu/embedded/functions.c
index 0746c4e..a962666 100644
--- a/samples/custom_dispatch/cpu/embedded/functions.c
+++ b/samples/custom_dispatch/cpu/embedded/functions.c
@@ -85,3 +85,33 @@
     binding1[i] *= binding0[i];
   }
 }
+
+// `ret = -|lhs * rhs|`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+//   <0, bindings = [
+//       <0, storage_buffer, ReadOnly>,
+//       <1, storage_buffer, ReadOnly>,
+//       <2, storage_buffer>
+//   ]>
+// ]>
+// With a workgroup size of 64x1x1.
+void simple_mul_abs_negate_workgroup(
+    // vvvv simplification pending (buffer + offset)
+    const float* restrict binding0, const float* restrict binding0_aligned,
+    size_t binding0_offset, size_t binding0_size, size_t binding0_stride,
+    const float* restrict binding1, const float* restrict binding1_aligned,
+    size_t binding1_offset, size_t binding1_size, size_t binding1_stride,
+    float* restrict binding2, float* restrict binding2_aligned,
+    size_t binding2_offset, size_t binding2_size, size_t binding2_stride,
+    // ^^^^ simplification pending (buffer + offset)
+    size_t dim, size_t tid) {
+  size_t end = tid + 64;
+  if (end > dim) end = dim;
+  for (size_t i = tid; i < end; ++i) {
+    float prod = binding0[i] * binding1[i];
+    if (prod >= 0) prod = -prod;
+    binding2[i] = prod + 1;
+  }
+}
diff --git a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
index 9d96884..13b0436 100644
--- a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
+++ b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
@@ -44,7 +44,18 @@
       ${CMAKE_CURRENT_SOURCE_DIR}/simple_mul_inplace.glsl
   VERBATIM
 )
+add_custom_command(
+  OUTPUT one_workgroup_argmax_subgroup_f32.spv
+  DEPENDS one_workgroup_argmax_subgroup_f32.glsl
+  COMMAND ${GLSLC}
+      -fshader-stage=compute
+      --target-spv=spv1.3
+      -o one_workgroup_argmax_subgroup_f32.spv
+      ${CMAKE_CURRENT_SOURCE_DIR}/one_workgroup_argmax_subgroup_f32.glsl
+  VERBATIM
+)
 add_custom_target(iree_samples_custom_dispatch_vulkan_shaders_spv DEPENDS
+  ${CMAKE_CURRENT_BINARY_DIR}/one_workgroup_argmax_subgroup_f32.spv
   ${CMAKE_CURRENT_BINARY_DIR}/simple_mul.spv
   ${CMAKE_CURRENT_BINARY_DIR}/simple_mul_inplace.spv
 )
@@ -56,8 +67,10 @@
   SRCS
     "example.mlir"
     "example_inline.mlir"
+    "example_transform.mlir"
   DATA
     ${_SPV_TARGET}
+    iree::samples::custom_dispatch::vulkan::shaders::example_transform_spec.mlir
   TOOLS
     FileCheck
     iree-compile
diff --git a/samples/custom_dispatch/vulkan/shaders/README.md b/samples/custom_dispatch/vulkan/shaders/README.md
index 00b71f8..fb00a88 100644
--- a/samples/custom_dispatch/vulkan/shaders/README.md
+++ b/samples/custom_dispatch/vulkan/shaders/README.md
@@ -145,3 +145,19 @@
         --input=8xf32=4 \
         /tmp/example.vmfb
     ```
+
+## Custom Kernel Match and Replace Scripting Instructions
+
+This is a flow for authoring custom dispatches externally alongside match and
+replace logic that can be fed directly into a pre-built version of the compiler.
+
+In addition to the above steps, when compiling the module, pass in both the
+target module and the transform spec implementing the matcher + kernel:
+
+    ```
+    iree-compile \
+        --iree-hal-executable-object-search-path=../iree-build/ \
+        samples/custom_dispatch/vulkan/shaders/example_transform.mlir \
+        --iree-preprocessing-transform-spec-filename=samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir \
+        -o=/tmp/example.vmfb
+    ```
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
new file mode 100644
index 0000000..82662ac
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -0,0 +1,74 @@
+// RUN: iree-compile %s \
+// RUN:     --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \
+// RUN:     --iree-preprocessing-transform-spec-filename=%p/example_transform_spec.mlir | \
+// RUN: iree-run-module \
+// RUN:     --device=vulkan \
+// RUN:     --module=- \
+// RUN:     --function=mixed_invocation \
+// RUN:     --input=1x128xf32=4 \
+// RUN:     --input=1x128xf32=3 | \
+// RUN: FileCheck %s
+
+// The configuration used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+// HACK: Currently this must match EXACTLY with the executable target for the
+// custom kernel. For things to be truly portable, we need to be able to compare
+// executable configurations.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+                     [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+    #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+  >
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example maintains an implicit requirement that
+// the custom kernel being spliced in is supported by the target device, hence
+// we only support vulkan here. It is possible to hand author a custom kernel
+// that supports multiple targets by specifying one object per target, but that
+// requires authoring the kernel separately for each target.
+#vulkan_target = #hal.device.target<"vulkan", {
+  executable_targets = [#spirv_target],
+  // HACK: Vulkan target currently uses the legacy synchronous execution model.
+  legacy_sync
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+module @example attributes {hal.device.targets = [#vulkan_target]} {
+
+  // CHECK-LABEL: EXEC @mixed_invocation
+  func.func @mixed_invocation(%arg0: tensor<1x128xf32>, %arg1: tensor<1x128xf32>) -> tensor<1xi64> {
+    // Code gen some other ops - these will interleave with the matched and
+    // replaced ones but naturally won't be able to fuse with them.
+    %add = arith.addf %arg0, %arg1 : tensor<1x128xf32>
+
+    %c0_i64 = arith.constant 0 : i64
+    %cst = arith.constant 0xFF800000 : f32
+    %1 = tensor.empty() : tensor<1xi64>
+    %2 = linalg.fill ins(%c0_i64 : i64) outs(%1 : tensor<1xi64>) -> tensor<1xi64>
+    %3 = tensor.empty() : tensor<1xf32>
+    %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+    // Argmax that is the target for the custom kernel. Note that this operation
+    // only has uses for a single result and takes a single input.
+    %5:2 = linalg.generic {indexing_maps = [#map, #map1, #map1],
+                           iterator_types = ["parallel", "reduction"]}
+                           ins(%add : tensor<1x128xf32>)
+                           outs(%4, %2 : tensor<1xf32>, tensor<1xi64>) {
+    ^bb0(%in: f32, %out: f32, %out_0: i64):
+      %6 = linalg.index 1 : index
+      %7 = arith.index_cast %6 : index to i64
+      %8 = arith.maximumf %in, %out : f32
+      %9 = arith.cmpf ogt, %in, %out : f32
+      %10 = arith.select %9, %7, %out_0 : i64
+      linalg.yield %8, %10 : f32, i64
+    } -> (tensor<1xf32>, tensor<1xi64>)
+
+    // CHECK: 1xi64=0
+    return %5#1 : tensor<1xi64>
+  }
+}  // module
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
new file mode 100644
index 0000000..065b795
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -0,0 +1,171 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// The configuration used for executable compilation.
+// This specifies the device configurations that support this custom kernel.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+                     [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+    #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+  >
+}>
+
+module attributes {transform.with_named_sequence} {
+  func.func private @argmax_1d_f32_entry_point(%arg0: tensor<1x?xf32>) -> tensor<1xi64> {
+    %c1 = arith.constant 1 : index
+    %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+    // Note: this is not safe if the dim size exceeds INT32_MAX. To pass a
+    // 64-bit value it must be broken down into two 32-bit values for the high
+    // and low bits.
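+    // A hedged sketch of such a split (names illustrative, not used below):
+    //   %dim_i64 = arith.index_cast %dim : index to i64
+    //   %lo = arith.trunci %dim_i64 : i64 to i32
+    //   %c32 = arith.constant 32 : i64
+    //   %shifted = arith.shrui %dim_i64, %c32 : i64
+    //   %hi = arith.trunci %shifted : i64 to i32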
+    %dim_i32 = arith.index_cast %dim : index to i32
+    // Inline external dispatch that conforms to the ABI that the kernel
+    // requires. This is the primary reason for the surrounding function as
+    // details like tensor shape and push constants need to line up after
+    // splicing in the custom dispatch. This allows the kernel author to manage
+    // such details by hand without needing the rewrite patterns to worry about
+    // things like order of push constants.
+    %4 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0) : (i32, tensor<1x?xf32>{%dim}) -> tensor<1xi64>
+      count(%device: !hal.device, %workload: index) -> (index, index, index) {
+        %c1_0 = arith.constant 1 : index
+        hal.return %c1_0, %c1_0, %c1_0 : index, index, index
+      }   
+      layout(#hal.pipeline.layout<push_constants = 1, sets = [
+        <0, bindings = [
+            <0, storage_buffer, ReadOnly>,
+            <1, storage_buffer>
+        ]>
+      ]>)
+      bindings([
+        #hal.interface.binding<0, 0>, 
+        #hal.interface.binding<0, 1>
+      ])  
+      objects({
+        #spirv_target ordinal(0) = [ 
+          #hal.executable.object<{
+            path = "samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.spv"
+          }>
+        ]
+      })
+    return %4 : tensor<1xi64>
+  }
+
+  // Custom matcher for argmax operations equivalent to the custom kernel. This
+  // matcher will be run one-by-one on all operations contained within the
+  // target function. On success, it will return the handle to the matched
+  // argmax operation.
+  transform.named_sequence @match_argmax(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    // Fail fast on non-linalg generics.
+    transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op
+    %matched = transform.match.structured failures(propagate) %generic : (!transform.any_op) -> (!transform.any_op) {
+    ^bb1(%argmax: !transform.any_op):
+      // Verify that the rank (i.e. number of loops) of the linalg op is 2,
+      // with one parallel iterator and one reduction iterator.
+      // TODO: Add optionality for the parallel dimensions.
+      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+      %rank = transform.match.structured.rank %argmax : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %rank, %c2 : !transform.param<i64>
+      transform.match.structured.dim %argmax[0] {parallel} : !transform.any_op
+      transform.match.structured.dim %argmax[-1] {reduction} : !transform.any_op
+
+      // Verify a single input (target vector to compute the argmax of) and two
+      // outputs, one for the maximum value and one for the index.
+      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+      %n_inputs = transform.match.structured.num_inputs %argmax : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
+      %n_outputs = transform.match.structured.num_inits %argmax : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_outputs, %c2 : !transform.param<i64>
+  
+      transform.match.structured.yield %argmax : !transform.any_op 
+    }
+
+    // Verify the operand shapes of the linalg op. For example, in the below,
+    // dim 0 must be statically 1, and dim 1 must be statically divisible by 64.
+    %in0 = transform.get_operand %matched[0] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %in0 = tensor<1x?xf32> : !transform.any_value
+    transform.iree.match.dim_is_multiple_of %in0[1], 64 : !transform.any_value
+    %out0 = transform.get_operand %matched[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %out0 = tensor<1xf32> : !transform.any_value
+    %out1 = transform.get_operand %matched[2] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %out1 = tensor<1xi64> : !transform.any_value
+
+    // Verify the region of the argmax op. This does a structural comparison of
+    // the region(s) of the payload operation against the single operation
+    // contained within the body of this operation. It does no verification of
+    // other input types/attributes; typically for kernel matching the most
+    // important part to get exactly right is the inner loop. Small variations
+    // in shape information, iterator counts, and the like are better handled
+    // by more general matchers.
+    transform.iree.match.regions %matched : !transform.any_op {
+      ^bb0(%target: tensor<1x?xf32>, %empty_max: tensor<1xf32>, %empty_idx: tensor<1xi64>):
+        %5:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                                affine_map<(d0, d1) -> (d0)>,
+                                                affine_map<(d0, d1) -> (d0)>],
+                               iterator_types = ["parallel", "reduction"]}
+                               ins(%target : tensor<1x?xf32>)
+                               outs(%empty_max, %empty_idx : tensor<1xf32>, tensor<1xi64>) {
+        ^bb0(%in: f32, %out: f32, %out_0: i64):
+          %6 = linalg.index 1 : index
+          %7 = arith.index_cast %6 : index to i64
+          %8 = arith.maximumf %in, %out : f32
+          %9 = arith.cmpf ogt, %in, %out : f32
+          %10 = arith.select %9, %7, %out_0 : i64
+          linalg.yield %8, %10 : f32, i64
+        } -> (tensor<1xf32>, tensor<1xi64>)
+    }
+    transform.yield %generic : !transform.any_op
+  }
+
+  // Rewrite callback for `transform.foreach_match`. The input signature for
+  // this sequence must match exactly with the outputs of the matcher. In this
+  // case we just take the argmax as an input, import the entry point for the
+  // custom kernel authored above, and replace the users of the argmax with a
+  // call to the function.
+  transform.named_sequence @cast_and_call_argmax(%argmax: !transform.any_op {transform.readonly}) {
+    %module = transform.iree.get_nearest_symbol_table %argmax : (!transform.any_op) -> !transform.any_op
+    %func = transform.iree.import_symbol @argmax_1d_f32_entry_point into %module : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %argmax[0] : (!transform.any_op) -> !transform.any_value
+    %outs = transform.get_result %argmax[1] : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %func(%ins) -> %outs before %argmax {
+          // This specifies how to resolve type mismatches between the arguments
+          // of the function and the inputs to the argmax. In this example, the
+          // only casts this will generate are same-rank tensor casts that drop
+          // static information.
+          transform.type_conversion.tensor.cast_shape_dynamic_dims
+      } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  // Entry point for the transform interpreter, nested on the full module. This
+  // is because the rewrites needed for importing the custom kernel need to add
+  // a new symbol to the module's symbol table.
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // Gather the set of functions within the module.
+    %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    // For each function in the module, run the matcher on all contained
+    // operations.
+    transform.foreach %funcs : !transform.any_op {
+      ^bb1(%func: !transform.any_op):
+        transform.foreach_match in %func
+            // <matcher name> -> <rewriter name>
+            // Multiple matcher-action pairs can be specified comma separated,
+            // here we are only doing a single kind of match and replace.
+            //
+            // Note that the operations within the module are walked in
+            // post-order, meaning actions must be very careful in their
+            // replacements not to modify successors of operations. Nested
+            // regions and DAG roots will be visited last so it is safest to
+            // do matching + replacement on the root of the DAG rather than
+            // trying to look ahead. The other option is to avoid dce/cse until
+            // after the walk is complete.
+            @match_argmax -> @cast_and_call_argmax
+          : (!transform.any_op) -> (!transform.any_op)
+    }
+    // Clean up the now-dead instances of argmax.
+    transform.apply_dce to %module : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl b/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl
new file mode 100644
index 0000000..c4e1e26
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl
@@ -0,0 +1,57 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// `ret = argmax(in)`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+//   <0, bindings = [
+//       <0, storage_buffer, ReadOnly>,
+//       <1, storage_buffer>
+//   ]>
+// ]>
+
+#version 450 core
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout(set=0, binding=0) buffer InputBuffer { float data[]; } Input;
+layout(set=0, binding=1) buffer OutputBuffer { uvec2 data; } Output;
+
+layout(push_constant) uniform PushConstants { uint totalCount; }; // Total number of scalars
+
+// Each workgroup contains just one subgroup.
+
+void main() {
+  uint laneID = gl_LocalInvocationID.x;
+  uint laneCount = gl_WorkGroupSize.x;
+
+  float laneMax = Input.data[laneID];
+  uint laneResult = 0;
+
+  uint numBatches = totalCount / laneCount;
+  for (uint i = 1; i < numBatches; ++i) {
+    uint idx = laneCount * i + laneID;
+    float new_in = Input.data[idx];
+    laneResult = new_in > laneMax ? idx : laneResult;
+    laneMax = max(laneMax, new_in);
+  }
+
+  // Final reduction with one subgroup
+  float wgMax = subgroupMax(laneMax);
+
+  // Find the smallest thread holding the maximum value.
+  bool eq = wgMax == laneMax;
+  uvec4 ballot = subgroupBallot(eq);
+  uint lsb = subgroupBallotFindLSB(ballot);
+
+  uint upper32bits = 0;
+  if (laneID == lsb) Output.data = uvec2(laneResult, upper32bits);
+}
+