Add sample to match subgraph and call implementation in system plugin. (#16356)

This adds a sample that uses a transform dialect script to match an MLP
DAG and replace it with a dispatch that uses an external function for the
actual implementation. The implementation is provided by a system
plugin.
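
At a high level, the spec wraps the external function in a hal.executable and
replaces the matched DAG with a dispatch to it. A reduced sketch (names follow
mlp_spec.mlir in this diff; see that file for the full version):

    // Declaration of the externally provided function. The system plugin
    // resolves the `mlp_external` symbol when the module is loaded.
    func.func private @mlp_external(%lhs : memref<?x?xf32>, %rhs : memref<?x?xf32>,
                                    %result : memref<?x?xf32>, %m : i32, %n : i32, %k : i32)

    // The matched MLP DAG is rewritten into a dispatch to the executable that
    // calls @mlp_external, passing M/N/K as push constants.
    %mlp_result = flow.dispatch @executable::@x86_64::@mlp[](%lhs, %rhs, %m_i32, %n_i32, %k_i32)
        : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}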

This also refactors the matcher to allow matching implicit captures
within the regions of matched ops, as long as the captured value is one
of the inputs or is produced by another matched op.
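
For example, in the matcher used by the sample (reduced from mlp_spec.mlir in
this diff), the %cst produced by a matched arith.constant is implicitly
captured inside the region of the matched linalg.generic; this is now accepted
because the captured value is defined by another matched op:

    ^bb0(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init1 : tensor<?x?xf32>, %init2 : tensor<?x?xf32>):
      %cst = arith.constant 0.0 : f32
      %fill = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
      %matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
          outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
      %relu = linalg.generic {...}
          ins(%matmul : tensor<?x?xf32>) outs(%init2 : tensor<?x?xf32>) {
        ^bb0(%b0 : f32, %b1 : f32):
          // %cst is an implicit capture from the enclosing matcher body.
          %0 = arith.maximumf %b0, %cst : f32
          linalg.yield %0 : f32
        } -> tensor<?x?xf32>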

A follow-up to this will be to add some transform dialect ops that
handle some of the boilerplate.
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
index 486ca26..8208e74 100644
--- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Value.h"
 
@@ -153,13 +154,14 @@
 // Compares the regions between two operations in lockstep for equality.
 static DiagnosedSilenceableFailure
 compareOperationRegions(transform::TransformOpInterface transformOp,
+                        OperationEquivalenceCache &cache, IRMapping &mapping,
                         Operation *target, Operation *payload) {
   if (target->getNumRegions() != payload->getNumRegions()) {
     return transformOp.emitSilenceableError() << "region count mismatch";
   }
   for (auto [r0, r1] :
        llvm::zip_equal(target->getRegions(), payload->getRegions())) {
-    if (!isStructurallyEquivalentTo(r0, r1)) {
+    if (!isStructurallyEquivalentTo(cache, r0, r1, mapping)) {
       return transformOp.emitSilenceableError()
              << "target op does not match specified body";
     }
@@ -204,8 +206,7 @@
       return transformOp.emitSilenceableError() << "operand type mismatch";
     }
   }
-
-  return compareOperationRegions(transformOp, target, payload);
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure
@@ -222,13 +223,10 @@
   // Maps from target/payload op to the order in which they were first
   // processed. This is used to verify that two uses actually point to the
   // same node in the dag.
-  llvm::MapVector<Operation *, int64_t> targetDagOrder;
-  llvm::MapVector<Operation *, int64_t> payloadDagOrder;
-  int64_t index = 0;
+  llvm::MapVector<Operation *, Operation *> targetToPayloadMapping;
 
-  auto compareDagOrder = [&](Operation *target, Operation *payload) {
-    return targetDagOrder.find(target) == payloadDagOrder.find(payload);
-  };
+  // Step 1. Walk from the root op "upwards" to check the basic
+  // producer-consumer structure (without checking regions).
 
   // Populate the paired worklist with the current target and payload root ops.
   SmallVector<Operation *> targetWorklist = {targetDagRoot};
@@ -248,15 +246,14 @@
 
     // Verify that if already processed, both operations are at the same
     // position.
-    if (targetDagOrder.contains(targetOp) ||
-        payloadDagOrder.contains(payloadOp)) {
-      if (!compareDagOrder(targetOp, payloadOp)) {
+    if (targetToPayloadMapping.contains(targetOp)) {
+      if (targetToPayloadMapping.lookup(targetOp) != payloadOp) {
         return emitSilenceableError() << "dag mismatch";
       }
       continue;
     }
 
-    // Verify general operation equality (name, attributes, regions).
+    // Verify general operation equality (name, attributes).
     DiagnosedSilenceableFailure diag =
         compareCastCompatibleOperations(*this, targetOp, payloadOp);
     if (!diag.succeeded()) {
@@ -290,21 +287,38 @@
       // Check whether the producer was already processed, and if so make sure
       // the target and payload match.
       Operation *targetDefiningOp = targetOperand.getDefiningOp();
-      if (targetDagOrder.contains(targetDefiningOp) ||
-          payloadDagOrder.contains(payloadDefiningOp)) {
-        if (!compareDagOrder(targetDefiningOp, payloadDefiningOp)) {
+      if (targetToPayloadMapping.contains(targetDefiningOp)) {
+        if (targetToPayloadMapping.lookup(targetDefiningOp) !=
+            payloadDefiningOp) {
           return emitSilenceableError() << "dag mismatch";
         }
         continue;
       }
+
       // Pop the producer of this value onto the worklist.
       targetWorklist.push_back(targetDefiningOp);
       payloadWorklist.push_back(payloadDefiningOp);
     }
 
     // Mark the current target + payload as processed.
-    targetDagOrder[targetOp] = index;
-    payloadDagOrder[payloadOp] = index++;
+    targetToPayloadMapping[targetOp] = payloadOp;
+  }
+
+  // Step 2. Now check that the regions of all the matched ops match.
+  OperationEquivalenceCache cache(getContext());
+  auto mapping = cache.acquireMapping();
+  for (auto [targetOp, payloadOp] : llvm::reverse(targetToPayloadMapping)) {
+    DiagnosedSilenceableFailure diag =
+        compareOperationRegions(*this, cache, *mapping, targetOp, payloadOp);
+    if (!diag.succeeded()) {
+      diag.attachNote() << "While processing region of operation "
+                        << *payloadOp;
+      return diag;
+    }
+    for (auto [targetOpResult, payloadOpResult] :
+         llvm::zip_equal(targetOp->getResults(), payloadOp->getResults())) {
+      mapping->map(targetOpResult, payloadOpResult);
+    }
   }
 
   // Verify that all input arguments were successfully matched.
@@ -387,7 +401,10 @@
     Operation *current, transform::TransformResults &results,
     transform::TransformState &state) {
   Operation *comparisonTarget = &getRegion().front().front();
-  return compareOperationRegions(*this, comparisonTarget, current);
+  OperationEquivalenceCache cache(current->getContext());
+  auto mapping = cache.acquireMapping();
+  return compareOperationRegions(*this, cache, *mapping, comparisonTarget,
+                                 current);
 }
 
 LogicalResult IREE::transform_dialect::MatchRegionsOp::verify() {
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
index bb5079c..debfe47 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
@@ -107,14 +107,11 @@
 }
 
 static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache,
-                                       Region &lhs, Region &rhs,
-                                       IRMapping &parentMapping);
-static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache,
                                        Operation &lhs, Operation &rhs,
                                        IRMapping &parentMapping);
 
-static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache,
-                                       Region &lhs, Region &rhs) {
+bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs,
+                                Region &rhs) {
   auto mapping = cache.acquireMapping();
   return isStructurallyEquivalentTo(cache, lhs, rhs, *mapping);
 }
@@ -156,9 +153,8 @@
 //
 // TODO(#3996): upstream into mlir::OperationEquivalence if this works.
 // TODO(#3996): add symbol ref comparison (add to IRMapping).
-static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache,
-                                       Region &lhs, Region &rhs,
-                                       IRMapping &mapping) {
+bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs,
+                                Region &rhs, IRMapping &mapping) {
   auto &lhsRegionEntry = cache.getRegion(&lhs);
   auto &rhsRegionEntry = cache.getRegion(&rhs);
   if (lhsRegionEntry.blocks.size() != rhsRegionEntry.blocks.size())
@@ -186,6 +182,7 @@
     const auto &rhsBlockEntry = cache.getBlock(rhsBlock);
     if (lhsBlockEntry.count != rhsBlockEntry.count)
       return false;
+
     for (auto [lhsOp, rhsOp] : llvm::zip_equal(lhsBlock->getOperations(),
                                                rhsBlock->getOperations())) {
       if (!isStructurallyEquivalentTo(cache, lhsOp, rhsOp, mapping))
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
index fadadd0..9bf8fdb 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
@@ -75,6 +75,8 @@
 //
 // Uses |cache| to memoize operation information to improve repeated queries.
 // Callers must not mutate any IR that may be in the cache between queries.
+bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs,
+                                Region &rhs, IRMapping &mapping);
 bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache,
                                 Operation &lhs, Operation &rhs);
 
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt b/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt
new file mode 100644
index 0000000..507fdd7
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt
@@ -0,0 +1,54 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
+   NOT IREE_HAL_DRIVER_LOCAL_SYNC OR
+   NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
+  return()
+endif()
+
+# system-library plugin mechanism using the system dynamic library loader.
+if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY)
+
+
+add_library(iree_samples_custom_dispatch_cpu_mlp_plugin SHARED
+  mlp_plugin.c
+)
+target_include_directories(iree_samples_custom_dispatch_cpu_mlp_plugin
+  PRIVATE
+    ${IREE_SOURCE_DIR}/runtime/src/
+)
+
+# NOTE: this is only required because we want this sample to run on all
+# platforms without needing to change the library name (libfoo.so/foo.dll).
+set_target_properties(iree_samples_custom_dispatch_cpu_mlp_plugin
+  PROPERTIES
+    WINDOWS_EXPORT_ALL_SYMBOLS ON
+    PREFIX ""
+    OUTPUT_NAME "mlp_plugin"
+)
+
+add_dependencies(iree-sample-deps
+  iree_samples_custom_dispatch_cpu_mlp_plugin
+  iree_samples_custom_dispatch_cpu_system_plugin)
+
+
+iree_lit_test_suite(
+  NAME
+    mlp_example
+  SRCS
+    "mlp.mlir"
+  TOOLS
+    FileCheck
+    iree-compile
+    iree-run-module
+    iree_samples_custom_dispatch_cpu_mlp_plugin
+  LABELS
+    "driver=local-sync"
+    "hostonly"
+)
+
+endif(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY)
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
new file mode 100644
index 0000000..fbe68f2
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
@@ -0,0 +1,62 @@
+// RUN: iree-compile --iree-preprocessing-transform-spec-filename=%p/mlp_spec.mlir  %s | \
+// RUN: iree-run-module --device=local-sync \
+// RUN:     --executable_plugin=$IREE_BINARY_DIR/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin$IREE_DYLIB_EXT \
+// RUN:     --module=- \
+// RUN:     --function=mlp_invocation \
+// RUN:     --input="2x2xf32=[[2.0, 2.0], [-2.0, -2.0]]" \
+// RUN:     --input="2x2xf32=[[3.0, -3.0], [3.0, -3.0]]"
+
+// The implementation of MLP is matched using a transform dialect script and is forwarded to a system plugin.
+
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 32 : index,
+  target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support llvm-cpu here.
+#cpu_target = #hal.device.target<"llvm-cpu", {
+  executable_targets = [
+    #x86_64_target
+  ]
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module @example attributes {hal.device.targets = [#cpu_target]} {
+
+  // CHECK-LABEL: EXEC @mlp_invocation
+  //       CHECK: [Plugin]: M = 2, N = 2, K = 2
+  //       CHECK: 2x2xf32=[-12 0][0 -12]
+  func.func @mlp_invocation(%lhs: tensor<?x?xf32>,
+                            %rhs: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %cst = arith.constant 0.0 : f32
+    %dim0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+    %dim1 = tensor.dim %rhs, %c1 : tensor<?x?xf32>
+    %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %relu = linalg.generic {
+        indexing_maps = [#map, #map],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%matmul : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.maximumf %b0, %cst : f32
+        linalg.yield %0 : f32
+      } -> tensor<?x?xf32>
+    %neg = linalg.generic {
+        indexing_maps = [#map, #map],
+        iterator_types  = ["parallel", "parallel"]}
+        ins(%relu : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.negf %b0 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    return %neg : tensor<?x?xf32>
+  }
+}  // module
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c b/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
new file mode 100644
index 0000000..35b22e4
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
@@ -0,0 +1,215 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Demonstrates an MLP example where the implementation of the MLP is
+// provided by a system-linked plugin exporting a single `mlp_external`
+// function. See samples/custom_dispatch/cpu/plugin/system_plugin.c
+// for more information about system plugins and their caveats.
+
+#include <inttypes.h>
+#include <stdio.h>
+
+// The only header required from IREE:
+#include "iree/hal/local/executable_plugin.h"
+
+// Stateful plugin instance.
+// There may be multiple of these in a process at a time, each with its own
+// load/unload pairing. We pass a pointer to this to all import calls via the
+// context argument.
+typedef struct {
+  iree_hal_executable_plugin_allocator_t host_allocator;
+  FILE* file;
+} mlp_plugin_t;
+
+// Helper function to resolve index [i][j] into location for given strides.
+size_t get_index(size_t i, size_t j, size_t offset, size_t stride0,
+                 size_t stride1) {
+  return offset + i * stride0 + j * stride1;
+}
+
+// `ret = mlp(lhs, rhs)`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 3, sets = [
+//   <0, bindings = [
+//       <0, storage_buffer, ReadOnly>,
+//       <1, storage_buffer, ReadOnly>,
+//       <2, storage_buffer>
+//   ]>
+// ]>
+// With a workgroup count of 1x1x1.
+//
+// |context| is whatever was set in out_fn_contexts. This could point to shared
+// state or each import can have its own context (pointer into some JIT lookup
+// table, etc). In this sample we pass the sample plugin pointer to all imports.
+//
+// |params_ptr| points to a packed struct of all results followed by all args
+// using native arch packing/alignment rules. Results should be set before
+// returning.
+//
+// Expects a return of 0 on success and any other value indicates failure.
+// Try not to fail!
+static int mlp_external(void* params_ptr, void* context, void* reserved) {
+  mlp_plugin_t* plugin = (mlp_plugin_t*)context;
+  typedef struct {
+    const float* restrict lhs;
+    const float* restrict lhs_aligned;
+    size_t lhs_offset;
+    size_t lhs_size0;
+    size_t lhs_size1;
+    size_t lhs_stride0;
+    size_t lhs_stride1;
+    const float* restrict rhs;
+    const float* restrict rhs_aligned;
+    size_t rhs_offset;
+    size_t rhs_size0;
+    size_t rhs_size1;
+    size_t rhs_stride0;
+    size_t rhs_stride1;
+    float* restrict result;
+    float* restrict result_aligned;
+    size_t result_offset;
+    size_t result_size0;
+    size_t result_size1;
+    size_t result_stride0;
+    size_t result_stride1;
+    int32_t M;
+    int32_t N;
+    int32_t K;
+  } params_t;
+  const params_t* params = (const params_t*)params_ptr;
+  fprintf(plugin->file, "[Plugin]: M = %d, N = %d, K = %d\n", params->M,
+          params->N, params->K);
+  for (int32_t i = 0; i < params->M; i++) {
+    for (int32_t j = 0; j < params->N; j++) {
+      float curr_result = 0.0;
+      for (int32_t k = 0; k < params->K; k++) {
+        size_t lhs_index = get_index(i, k, params->lhs_offset,
+                                     params->lhs_stride0, params->lhs_stride1);
+        size_t rhs_index = get_index(k, j, params->rhs_offset,
+                                     params->rhs_stride0, params->rhs_stride1);
+        curr_result += params->lhs[lhs_index] * params->rhs[rhs_index];
+      }
+      curr_result = curr_result < 0.0 ? 0.0 : curr_result;
+      size_t result_index =
+          get_index(i, j, params->result_offset, params->result_stride0,
+                    params->result_stride1);
+      params->result[result_index] = curr_result;
+    }
+  }
+  return 0;
+}
+
+// Called once for each plugin load and paired with a future call to unload.
+// Even in standalone mode we could allocate using environment->host_allocator,
+// set an out_self pointer, and parse parameters but here in system mode we can
+// do whatever we want.
+//
+// If any state is required it should be allocated and stored in |out_self|.
+// This self value will be passed to all future calls related to the particular
+// instance. Note that there may be multiple instances of a plugin in any
+// particular process and this must be thread-safe.
+static iree_hal_executable_plugin_status_t mlp_plugin_load(
+    const iree_hal_executable_plugin_environment_v0_t* environment,
+    size_t param_count, const iree_hal_executable_plugin_string_pair_t* params,
+    void** out_self) {
+  // Allocate the plugin state.
+  mlp_plugin_t* plugin = NULL;
+  iree_hal_executable_plugin_status_t status =
+      iree_hal_executable_plugin_allocator_malloc(
+          environment->host_allocator, sizeof(*plugin), (void**)&plugin);
+  if (status) return status;
+  plugin->host_allocator = environment->host_allocator;
+
+  // "Open standard out" simulating us doing some syscalls or other expensive
+  // stateful/side-effecting things.
+  plugin->file = stdout;
+
+  // Pass back the plugin instance that'll be passed to resolve.
+  *out_self = plugin;
+  return iree_hal_executable_plugin_ok_status();
+}
+
+// Called to free any plugin state allocated in load.
+static void mlp_plugin_unload(void* self) {
+  mlp_plugin_t* plugin = (mlp_plugin_t*)self;
+  iree_hal_executable_plugin_allocator_t host_allocator =
+      plugin->host_allocator;
+
+  // "Close standard out" simulating us doing some syscalls and other expensive
+  // stateful/side-effecting things.
+  fflush(plugin->file);
+  plugin->file = NULL;
+
+  // Free the plugin state using the same allocator it came from.
+  iree_hal_executable_plugin_allocator_free(host_allocator, plugin);
+}
+
+// Called to resolve one or more imports by symbol name.
+// See the plugin API header for more information. Note that some of the
+// functions may already be resolved and some may be optional.
+static iree_hal_executable_plugin_status_t mlp_plugin_resolve(
+    void* self, const iree_hal_executable_plugin_resolve_params_v0_t* params,
+    iree_hal_executable_plugin_resolution_t* out_resolution) {
+  mlp_plugin_t* plugin = (mlp_plugin_t*)self;
+  *out_resolution = 0;
+  bool any_required_not_found = false;
+  for (size_t i = 0; i < params->count; ++i) {
+    if (params->out_fn_ptrs[i]) continue;
+    const char* symbol_name = params->symbol_names[i];
+    bool is_optional =
+        iree_hal_executable_plugin_import_is_optional(symbol_name);
+    if (is_optional) ++symbol_name;
+    if (iree_hal_executable_plugin_strcmp(symbol_name, "mlp_external") == 0) {
+      params->out_fn_ptrs[i] = mlp_external;
+      params->out_fn_contexts[i] =
+          plugin;  // passing plugin to each import call
+    } else {
+      if (is_optional) {
+        *out_resolution |=
+            IREE_HAL_EXECUTABLE_PLUGIN_RESOLUTION_MISSING_OPTIONAL;
+      } else {
+        any_required_not_found = true;
+      }
+    }
+  }
+  return any_required_not_found
+             ? iree_hal_executable_plugin_status_from_code(
+                   IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND)
+             : iree_hal_executable_plugin_ok_status();
+}
+
+// Exported on the shared library and used by the runtime to query the plugin
+// interface. When statically linking the plugin this is just a function that
+// can be called and can have any name to allow for multiple plugins. When
+// dynamically linking the exported symbol must be exactly this with no C++
+// name mangling.
+IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t**
+iree_hal_executable_plugin_query(
+    iree_hal_executable_plugin_version_t max_version, void* reserved) {
+  static const iree_hal_executable_plugin_header_t header = {
+      // Declares what library version is present: newer runtimes may support
+      // loading older plugins but newer plugins cannot load on older runtimes.
+      .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST,
+      // Name and description are used for tracing/logging/diagnostics.
+      .name = "sample_system",
+      .description =
+          "system plugin sample "
+          "(custom_dispatch/cpu/plugin/mlp_plugin.c)",
+      .features = 0,
+      // Let the runtime know what sanitizer this plugin was compiled with.
+      .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND,
+  };
+  static const iree_hal_executable_plugin_v0_t plugin = {
+      .header = &header,
+      .load = mlp_plugin_load,
+      .unload = mlp_plugin_unload,
+      .resolve = mlp_plugin_resolve,
+  };
+  return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST
+             ? (const iree_hal_executable_plugin_header_t**)&plugin
+             : NULL;
+}
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
new file mode 100644
index 0000000..72a9860
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
@@ -0,0 +1,142 @@
+// Sample spec that matches an MLP example and forwards it to
+// an implementation provided by a system plugin.
+// It is used along with samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir.
+
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 32 : index,
+  target_triple = "x86_64-none-elf"
+}>
+
+#cpu_target = #hal.device.target<"llvm-cpu", {
+  executable_targets = [
+    #x86_64_target
+  ]
+}>
+
+module attributes {transform.with_named_sequence} {
+
+  // Executable that stages the call to the external function.
+  hal.executable private @executable {
+    hal.executable.variant public @x86_64 target(#x86_64_target) {
+      hal.executable.export public @mlp ordinal(0)
+          layout(#hal.pipeline.layout<push_constants = 3, sets = [
+            <0, bindings = [
+              <0, storage_buffer, ReadOnly>,
+              <1, storage_buffer, ReadOnly>,
+              <2, storage_buffer>
+            ]>
+          ]>) {
+      ^bb0(%device : !hal.device):
+        %c1 = arith.constant 1 : index
+        hal.return %c1, %c1, %c1 : index, index, index
+      }
+      builtin.module {
+        func.func private @mlp_external(%lhs : memref<?x?xf32>, %rhs : memref<?x?xf32>, %result : memref<?x?xf32>, %m : i32, %n : i32, %k : i32)
+        func.func @mlp() {
+          %m_i32 = hal.interface.constant.load[0] : i32
+          %n_i32 = hal.interface.constant.load[1] : i32
+          %k_i32 = hal.interface.constant.load[2] : i32
+          %c0 = arith.constant 0 : index
+          %m = arith.index_cast %m_i32 : i32 to index
+          %n = arith.index_cast %n_i32 : i32 to index
+          %k = arith.index_cast %k_i32 : i32 to index
+          %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %k}
+          %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%k, %n}
+          %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %n}
+          func.call @mlp_external(%lhs, %rhs, %result, %m_i32, %n_i32, %k_i32) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, i32, i32, i32) -> ()
+          return
+        }
+      }
+    }
+  }
+
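+  // Wrapper function that the rewriter below calls in place of the matched
+  // DAG. It computes the M/N/K push constants from the operand shapes and
+  // dispatches to the executable above, which calls the externally provided
+  // @mlp_external. The %init operands are unused here; they are present only
+  // so the signature lines up with the values captured by the matcher.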
+  func.func private @call_mlp(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>, %init1 : tensor<?x?xf32>, %init2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %m = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+    %n = tensor.dim %rhs, %c1 : tensor<?x?xf32>
+    %k = tensor.dim %lhs, %c1 : tensor<?x?xf32>
+    %m_i32 = arith.index_cast %m : index to i32
+    %n_i32 = arith.index_cast %n : index to i32
+    %k_i32 = arith.index_cast %k : index to i32
+
+    %mlp_result = flow.dispatch @executable::@x86_64::@mlp[](%lhs, %rhs, %m_i32, %n_i32, %k_i32) {
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 0>,
+        #hal.interface.binding<0, 1>,
+        #hal.interface.binding<0, 2>
+      ],
+      // HACK: keep the executable live through DCE. Only required when
+      // using the automatic variant selection.
+      hal.executable.ref = [@executable]
+    } : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}  
+    return %mlp_result : tensor<?x?xf32>    
+  }
+
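+  // Matcher for the MLP DAG rooted at the ReLU: fill -> matmul -> ReLU.
+  // Yields the inputs (%lhs, %rhs, %init1, %init2) and the result of the
+  // matched DAG so the rewriter below can wire them into the call.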
+  transform.named_sequence @match_mlp(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+      ^bb0(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init1 : tensor<?x?xf32>, %init2 : tensor<?x?xf32>):
+        %cst = arith.constant 0.0 : f32
+        %fill = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %matmul = linalg.matmul
+            ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+                outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %relu = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                             affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%matmul : tensor<?x?xf32>)
+            outs(%init2 : tensor<?x?xf32>) {
+          ^bb0(%b0 : f32, %b1 : f32):
+            %0 = arith.maximumf %b0, %cst : f32
+            linalg.yield %0 : f32
+          } -> tensor<?x?xf32>
+      } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    transform.yield %ins, %outs : !transform.any_value, !transform.any_value
+  }
+
+
+  // Rewrite callback for `transform.foreach_match`. The input signature of
+  // this sequence must exactly match the outputs of the matcher. In this
+  // case the matcher returns the inputs and outputs of the matched DAG
+  // directly, so we just insert a call to the hand-authored function above.
+  transform.named_sequence @cast_and_call_dag(%ins: !transform.any_value {transform.readonly},
+                                              %out: !transform.any_value {transform.readonly}) {
+    %root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
+    %module = transform.iree.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
+    %executable = transform.iree.import_symbol @executable into %module if undefined : (!transform.any_op) -> !transform.any_op
+    %func = transform.iree.import_symbol @call_mlp into %module if undefined : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call %func(%ins) -> %out after %root {
+      // This specifies how to resolve type mismatches between the arguments
+      // of the function and the inputs from the matcher. In this example,
+      // the only casts this will generate are same-rank tensor casts that
+      // drop static information.
+      transform.type_conversion.tensor.cast_shape_dynamic_dims
+    } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  // Entry point for the transform interpreter, nested on the full module.
+  // This is because the rewrites needed for importing the custom kernel need
+  // to add a new symbol to the module's symbol table.
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    // Gather the set of functions within the module.
+    %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op   
+    // For each function in the module, run the matcher on all contained
+    // operations.
+    transform.foreach %funcs : !transform.any_op {
+      ^bb1(%func: !transform.any_op):
+        transform.foreach_match in %func
+          // <matcher name> -> <rewriter name>
+          // Multiple matcher-action pairs can be specified, comma separated;
+          // here we are only doing a single kind of match-and-replace.
+          @match_mlp -> @cast_and_call_dag
+        : (!transform.any_op) -> (!transform.any_op)
+    }
+    // Clean up leftover dead code; cast_and_call does not do replacement, it
+    // only rewires uses.
+    transform.apply_dce to %module : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/samples/custom_dispatch/cpu/plugin/CMakeLists.txt b/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
index 021ff0a..6054fd5 100644
--- a/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
+++ b/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
@@ -30,7 +30,8 @@
     OUTPUT_NAME "system_plugin"
 )
 
-add_dependencies(iree-sample-deps iree_samples_custom_dispatch_cpu_system_plugin)
+add_dependencies(iree-sample-deps
+  iree_samples_custom_dispatch_cpu_system_plugin)
 
 iree_lit_test_suite(
   NAME