Add samples for custom kernel match+replace scripts (#16150)
Custom match-and-replace scripts are a workflow for injecting custom
dispatches into a module without needing any surrounding compiler
infrastructure to build the dispatches. Custom kernels are paired with an
externally authored script that matches a subgraph and replaces it with a
call to the kernel, all delivered as independently valid IR.
The examples here use the transform dialect to do this, adding a plugin
point during preprocessing that runs the user-provided specification.
The flow demonstrated here requires authoring two functions per kernel
alongside some additional boilerplate, as sketched after the list below.
1. A `func.func @my_kernel(...)` that takes (typically tensor) arguments
and includes the call to the custom dispatch inline. This can use any
of the other custom dispatch approaches.
2. A `transform.named_sequence @my_matcher` that describes the
compatible subgraph to match.
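A minimal sketch of such a spec (illustrative only: every name below is a
placeholder and the bodies are elided; see the samples for complete,
working specs):

```mlir
module attributes {transform.with_named_sequence} {
  // (1) Kernel wrapper that calls the custom dispatch inline.
  func.func private @my_kernel(%arg0: tensor<?xf32>) -> tensor<?xf32> {
    // ... any of the other custom dispatch approaches ...
  }

  // (2) Matcher describing the compatible subgraph.
  transform.named_sequence @my_matcher(
      %op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %op ["linalg.generic"] : !transform.any_op
    transform.yield %op : !transform.any_op
  }

  // Boilerplate: a rewriter that imports @my_kernel into the payload
  // module (e.g. with transform.iree.import_symbol) and redirects the
  // matched subgraph to it.
  transform.named_sequence @rewrite(%op: !transform.any_op {transform.readonly}) {
    transform.yield
  }

  // Entry point invoked by the preprocessing plugin point.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    transform.foreach_match in %module @my_matcher -> @rewrite
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
```

The spec is supplied to the compiler via the new
`--iree-preprocessing-transform-spec-filename=` flag added below.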
diff --git a/build_tools/cmake/ctest_all.sh b/build_tools/cmake/ctest_all.sh
index 755aac7..ebdd89d 100755
--- a/build_tools/cmake/ctest_all.sh
+++ b/build_tools/cmake/ctest_all.sh
@@ -132,6 +132,7 @@
excluded_tests+=(
"iree/samples/custom_dispatch/cpu/embedded/example_hal.mlir.test"
"iree/samples/custom_dispatch/cpu/embedded/example_stream.mlir.test"
+ "iree/samples/custom_dispatch/cpu/embedded/example_transform.mlir.test"
)
ctest_args=(
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index f24086d..ac8cf2e 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -213,6 +213,12 @@
llvm::cl::desc("Textual description of the pass pipeline to run before "
"running normal IREE compilation pipelines"),
llvm::cl::cat(category));
+ binder.opt<std::string>(
+ "iree-preprocessing-transform-spec-filename",
+ preprocessingTransformSpecFilename,
+ llvm::cl::desc(
+ "File name of a transform dialect spec to use for preprocessing"),
+ llvm::cl::cat(category));
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 01698b7..32c22a4 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -164,6 +164,7 @@
struct PreprocessingOptions {
std::string preprocessingPassPipeline;
+ std::string preprocessingTransformSpecFilename;
void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<PreprocessingOptions>;
};
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index c9acfa2..bbd7019 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -31,6 +31,7 @@
name = "Transforms",
srcs = [
"ConvertConv2DToImg2Col.cpp",
+ "InterpreterPass.cpp",
"MakeSingleDispatchForFunction.cpp",
"PadLinalgOps.cpp",
"PassDetail.h",
@@ -55,6 +56,8 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorUtils",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index fb7b26f..9fc8003 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -27,6 +27,7 @@
"Passes.h.inc"
SRCS
"ConvertConv2DToImg2Col.cpp"
+ "InterpreterPass.cpp"
"MakeSingleDispatchForFunction.cpp"
"PadLinalgOps.cpp"
"PassDetail.h"
@@ -43,6 +44,8 @@
MLIRPass
MLIRTensorDialect
MLIRTensorUtils
+ MLIRTransformDialect
+ MLIRTransformDialectTransforms
MLIRTransforms
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
index 1776519..77faa23 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
@@ -552,9 +552,6 @@
struct ConvertConv2DToImg2ColPass
: ConvertConv2DToImg2ColBase<ConvertConv2DToImg2ColPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect>();
- }
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
new file mode 100644
index 0000000..84948bc
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_INTERPRETERPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
+} // namespace mlir::iree_compiler::Preprocessing
+
+namespace {
+class InterpreterPass
+ : public iree_compiler::Preprocessing::impl::InterpreterPassBase<
+ InterpreterPass> {
+public:
+ using Base::Base;
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ // Load the module from the spec path. The module will be unloaded once the
+ // pass finishes.
+ OwningOpRef<ModuleOp> transformModule;
+ if (failed(transform::detail::assembleTransformLibraryFromPaths(
+ context, transformSpecPath, transformModule)))
+ return signalPassFailure();
+ Operation *payloadRoot = getOperation();
+ Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
+ getOperation(), *transformModule, "__transform_main");
+ if (!transformEntryPoint) {
+ getOperation()->emitError() << "could not find transform entry point "
+ "__transform_main in transform module";
+ return signalPassFailure();
+ }
+
+ if (failed(transform::applyTransformNamedSequence(
+ payloadRoot, transformEntryPoint, *transformModule,
+ options.enableExpensiveChecks(!disableExpensiveChecks)))) {
+ return signalPassFailure();
+ }
+ }
+
+private:
+ /// Transform interpreter options.
+ transform::TransformOptions options;
+};
+} // namespace
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
index 3e762d0..120c464 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
@@ -19,10 +19,6 @@
struct MakeSingleDispatchForFunctionPass
: public MakeSingleDispatchForFunctionBase<
MakeSingleDispatchForFunctionPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<IREE::Flow::FlowDialect>();
- }
-
void runOnOperation() override;
};
} // namespace
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
index 45a394c..68125a4 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
@@ -155,9 +155,6 @@
class PadLinalgOpsPass : public PadLinalgOpsBase<PadLinalgOpsPass> {
public:
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect>();
- }
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
index 5fe9b54..447b597 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
@@ -7,6 +7,9 @@
#ifndef IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
#define IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index 7caa385..0b03c38 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -30,6 +30,9 @@
// Register all Passes
//===----------------------------------------------------------------------===//
+#define GEN_PASS_DECL
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: keep
+
void registerCommonPreprocessingPasses();
} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 376d92f..9ac1436 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -13,12 +13,38 @@
Pass<"iree-preprocessing-convert-conv2d-to-img2col", ""> {
let summary = "Convert linalg convolution ops to matmul img2col based implementation";
let constructor = "mlir::iree_compiler::Preprocessing::createConvertConv2DToImg2ColPass()";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ ];
+}
+
+def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
+ let summary = "transform dialect interpreter";
+ let description = [{
+ This pass runs the transform dialect interpreter and applies the
+ named sequence `__transform_main`.
+
+ TODO: Drop this pass in favor of the one upstream. The one upstream requires
+ separate loading of the module and thus isn't suited for single-use
+ transform scripts.
+ }];
+ let dependentDialects = ["::mlir::transform::TransformDialect"];
+ let options = [
+ Option<"disableExpensiveChecks", "disable-expensive-checks", "bool",
+ "false",
+ "Disable expensive checks in the interpreter for a faster run.">,
+ Option<"transformSpecPath", "transform-spec-path", "std::string",
+ /*default=*/"", "File path to the transform spec to use.">,
+ ];
}
def MakeSingleDispatchForFunction :
InterfacePass<"iree-preprocessing-make-single-dispatch-for-function", "mlir::FunctionOpInterface"> {
let summary = "Convert entire function into a single dispatch";
let constructor = "mlir::iree_compiler::Preprocessing::createMakeSingleDispatchForFunctionPass()";
+ let dependentDialects = [
+ "IREE::Flow::FlowDialect",
+ ];
}
def PadLinalgOps :
@@ -30,6 +56,9 @@
/*default=*/"4",
"Specify the padding size">,
];
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ ];
}
#endif // IREE_PREPROCESSING_COMMON_PASSES
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index d11313b..539471f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -19,10 +19,18 @@
"conv2d_to_img2col.mlir",
"make_single_dispatch_for_function.mlir",
"pad_linalg_ops.mlir",
+ "preprocessing_match_ops.mlir",
+ "transform_symbol_importing.mlir",
],
include = ["*.mlir"],
+ exclude = [
+ "external_function_spec.mlir",
+ ],
),
cfg = "//compiler:lit.cfg.py",
+ data = [
+ "external_function_spec.mlir",
+ ],
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 19425cb..991046d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -17,9 +17,13 @@
"conv2d_to_img2col.mlir"
"make_single_dispatch_for_function.mlir"
"pad_linalg_ops.mlir"
+ "preprocessing_match_ops.mlir"
+ "transform_symbol_importing.mlir"
TOOLS
FileCheck
iree-opt
+ DATA
+ external_function_spec.mlir
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir
new file mode 100644
index 0000000..579ea29
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/external_function_spec.mlir
@@ -0,0 +1,18 @@
+// Test for importing functions from this spec to a payload module.
+// Tested in `transform_symbol_importing.mlir`
+module attributes {transform.with_named_sequence} {
+ func.func private @some_external_function(%arg0: tensor<?xf32>) -> tensor<?xf32>
+
+ func.func @some_function(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg0 : tensor<?xf32>
+ }
+
+ transform.named_sequence @__transform_main(%module: !transform.any_op) {
+ %new_func = transform.iree.import_symbol @some_function into %module : (!transform.any_op) -> !transform.any_op
+
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ %module_2 = transform.iree.get_nearest_symbol_table %func : (!transform.any_op) -> !transform.any_op
+ %new_func_2 = transform.iree.import_symbol @some_external_function into %module_2 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
new file mode 100644
index 0000000..83f7ec4
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
@@ -0,0 +1,138 @@
+// RUN: iree-opt -transform-interpreter %s --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @simple_max
+func.func @simple_max(%input: tensor<?xf32>, %dest: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK-NEXT: linalg.generic
+ // CHECK-SAME: match_status = "matched"
+ %res = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ ins(%input : tensor<?xf32>)
+ outs(%dest : tensor<?xf32>) attrs = {match_status = "unmatched"} {
+ ^bb0(%in: f32, %out: f32):
+ %max = arith.maximumf %in, %out : f32
+ linalg.yield %max : f32
+ } -> tensor<?xf32>
+ return %res : tensor<?xf32>
+}
+
+// CHECK: func.func @simple_min
+func.func @simple_min(%input: tensor<?xf32>, %dest: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK-NEXT: linalg.generic
+ // CHECK-SAME: match_status = "unmatched"
+ %res = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ ins(%input : tensor<?xf32>)
+ outs(%dest : tensor<?xf32>) attrs = {match_status = "unmatched"} {
+ ^bb0(%in: f32, %out: f32):
+ %max = arith.minimumf %in, %out : f32
+ linalg.yield %max : f32
+ } -> tensor<?xf32>
+ return %res : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @match(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+ transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op
+ transform.iree.match.regions %generic : !transform.any_op {
+ ^bb0(%target: tensor<f32>, %empty_max: tensor<f32>):
+ %0 = linalg.generic {indexing_maps = [affine_map<() -> ()>,
+ affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%target : tensor<f32>)
+ outs(%empty_max : tensor<f32>) {
+ ^bb0(%in: f32, %out: f32):
+ %max = arith.maximumf %in, %out : f32
+ linalg.yield %max : f32
+ } -> tensor<f32>
+ }
+ transform.yield %generic : !transform.any_op
+ }
+
+ transform.named_sequence @annotate(%generic: !transform.any_op {transform.readonly}) {
+ %0 = transform.param.constant "matched" -> !transform.any_param
+ transform.annotate %generic "match_status" = %0 : !transform.any_op, !transform.any_param
+ transform.yield
+ }
+
+ transform.named_sequence @__transform_main(%module: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ transform.foreach_match in %module
+ @match -> @annotate
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @external(%arg0: tensor<?xf32>)
+func.func private @external_aligned(%arg0: tensor<100xf32>)
+func.func private @external_static(%arg0: tensor<10xf32>)
+func.func private @other_external_static(%arg0: tensor<15xf32>)
+func.func private @external_2d(%arg0: tensor<?x?xf32>)
+
+// CHECK-LABEL: func.func @call_external
+func.func @call_external(%input: tensor<?xf32>,
+ %input_2d: tensor<?x?xf32>,
+ %input_aligned: tensor<100xf32>,
+ %input_static: tensor<10xf32>,
+ %other_static: tensor<15xf32>) {
+// CHECK: call @external
+// CHECK-SAME: match_status = "matched"
+ func.call @external(%input) {match_status = "unmatched"} : (tensor<?xf32>) -> ()
+// CHECK: call @external_2d
+// CHECK-SAME: match_status = "unmatched"
+ func.call @external_2d(%input_2d) {match_status = "unmatched"} : (tensor<?x?xf32>) -> ()
+// CHECK: call @external_aligned
+// CHECK-SAME: match_status = "aligned_match"
+ func.call @external_aligned(%input_aligned) {match_status = "unmatched"} : (tensor<100xf32>) -> ()
+// CHECK: call @external_static
+// CHECK-SAME: match_status = "static_matched"
+ func.call @external_static(%input_static) {match_status = "unmatched"} : (tensor<10xf32>) -> ()
+// CHECK: call @other_external_static
+// CHECK-SAME: match_status = "matched"
+ func.call @other_external_static(%other_static) {match_status = "unmatched"} : (tensor<15xf32>) -> ()
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @static_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+ transform.match.operation_name %call ["func.call"] : !transform.any_op
+ %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %in0 = tensor<10xf32> : !transform.any_value
+ %0 = transform.param.constant "static_matched" -> !transform.any_param
+ transform.yield %call, %0 : !transform.any_op, !transform.any_param
+ }
+ transform.named_sequence @static_alignment_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+ transform.match.operation_name %call ["func.call"] : !transform.any_op
+ %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %in0 = tensor<?xf32> : !transform.any_value
+ transform.iree.match.dim_is_multiple_of %in0[0], 20 : !transform.any_value
+ %0 = transform.param.constant "aligned_match" -> !transform.any_param
+ transform.yield %call, %0 : !transform.any_op, !transform.any_param
+ }
+ transform.named_sequence @match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+ transform.match.operation_name %call ["func.call"] : !transform.any_op
+ %in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %in0 = tensor<?xf32> : !transform.any_value
+ %0 = transform.param.constant "matched" -> !transform.any_param
+ transform.yield %call, %0 : !transform.any_op, !transform.any_param
+ }
+
+ transform.named_sequence @annotate(%call: !transform.any_op {transform.readonly},
+ %note: !transform.any_param {transform.readonly}) {
+ transform.annotate %call "match_status" = %note : !transform.any_op, !transform.any_param
+ transform.yield
+ }
+
+ transform.named_sequence @__transform_main(%module: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ transform.foreach_match in %module
+ @static_match -> @annotate,
+ @static_alignment_match -> @annotate,
+ @match -> @annotate
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir
new file mode 100644
index 0000000..77a6f5b
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/transform_symbol_importing.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt --transform-preload-library='transform-library-paths=%p/external_function_spec.mlir' --transform-interpreter %s | FileCheck %s
+
+module @example {
+ // empty
+}
+
+// CHECK-LABEL: module @example
+// CHECK: func.func private @some_external_function(tensor<?xf32>) -> tensor<?xf32>
+// CHECK: func.func @some_function(%arg0: tensor<?xf32>) -> tensor<?xf32>
+// CHECK-NEXT: return %arg0 : tensor<?xf32>
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index bd8b5a9..6d87a4b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -67,6 +67,14 @@
if (pipelineExtensions) {
pipelineExtensions->extendPreprocessingPassPipeline(passManager);
}
+
+ if (!preprocessingOptions.preprocessingTransformSpecFilename.empty()) {
+ Preprocessing::InterpreterPassOptions interpreterOptions;
+ interpreterOptions.transformSpecPath =
+ preprocessingOptions.preprocessingTransformSpecFilename;
+ passManager.addPass(
+ Preprocessing::createInterpreterPass(interpreterOptions));
+ }
}
void registerPreprocessingPasses() { registerCommonPreprocessingPasses(); }
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel
new file mode 100644
index 0000000..08a41a0
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel
@@ -0,0 +1,69 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "PreprocessingExtensionsOps.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = [
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:TransformDialectTdFiles",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "PreprocessingExtensionsOpGen",
+ tbl_outs = [
+ (
+ ["--gen-op-decls"],
+ "PreprocessingExtensionsOps.h.inc",
+ ),
+ (
+ ["--gen-op-defs"],
+ "PreprocessingExtensionsOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "PreprocessingExtensionsOps.td",
+ deps = [":td_files"],
+)
+
+iree_compiler_cc_library(
+ name = "PreprocessingExtensions",
+ srcs = [
+ "PreprocessingExtensions.cpp",
+ "PreprocessingExtensionsOps.cpp.inc",
+ ],
+ hdrs = [
+ "PreprocessingExtensions.h",
+ "PreprocessingExtensionsOps.h.inc",
+ ],
+ deps = [
+ ":PreprocessingExtensionsOpGen",
+ "//compiler/src/iree/compiler/Utils",
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TransformDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt
new file mode 100644
index 0000000..c74f3f7
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt
@@ -0,0 +1,46 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_tablegen_library(
+ NAME
+ PreprocessingExtensionsOpGen
+ TD_FILE
+ "PreprocessingExtensionsOps.td"
+ OUTS
+ --gen-op-decls PreprocessingExtensionsOps.h.inc
+ --gen-op-defs PreprocessingExtensionsOps.cpp.inc
+)
+
+iree_cc_library(
+ NAME
+ PreprocessingExtensions
+ HDRS
+ "PreprocessingExtensions.h"
+ "PreprocessingExtensionsOps.h.inc"
+ SRCS
+ "PreprocessingExtensions.cpp"
+ "PreprocessingExtensionsOps.cpp.inc"
+ DEPS
+ ::PreprocessingExtensionsOpGen
+ IREEDialectsTransforms
+ IREELinalgTransformDialect
+ LLVMSupport
+ MLIRIR
+ MLIRPDLDialect
+ MLIRSupport
+ MLIRTransformDialect
+ MLIRTransformUtils
+ iree::compiler::Utils
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
new file mode 100644
index 0000000..486ca26
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
@@ -0,0 +1,409 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "PreprocessingExtensions.h"
+
+#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
+#include "iree/compiler/Utils/EquivalenceUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir::iree_compiler {
+
+IREE::transform_dialect::PreprocessingExtensions::PreprocessingExtensions() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.cpp.inc"
+ >();
+}
+
+void registerTransformDialectPreprocessingExtension(DialectRegistry &registry) {
+ registry.addExtensions<IREE::transform_dialect::PreprocessingExtensions>();
+}
+
+//===----------------------------------------------------------------------===//
+// GetNearestSymbolTableOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::GetNearestSymbolTableOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ auto tableOp = SymbolTable::getNearestSymbolTable(target);
+ if (!tableOp) {
+ return emitDefaultDefiniteFailure(target);
+ }
+ results.push_back(tableOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void IREE::transform_dialect::GetNearestSymbolTableOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTarget(), effects);
+ transform::producesHandle(getResult(), effects);
+ transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// ImportSymbolOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure IREE::transform_dialect::ImportSymbolOp::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ auto symbolOp = SymbolTable::lookupNearestSymbolFrom(*this, getSymbol());
+ if (!symbolOp) {
+ return emitDefiniteFailure() << "could not find corresponding symbol op";
+ }
+ // Require isolated from above as the clone does not make sense with escaping
+ // values.
+ if (!symbolOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ return emitDefiniteFailure()
+ << "target symbol op is not isolated from above";
+ }
+ StringRef symbol = getSymbol().getLeafReference();
+ SmallVector<Operation *> results;
+ for (Operation *payloadOp : state.getPayloadOps(getSymbolTable())) {
+ if (!payloadOp->hasTrait<OpTrait::SymbolTable>()) {
+ return emitDefiniteFailure()
+ << "target symbol table " << payloadOp << " is not a symbol table";
+ }
+ SymbolTable symbolTable(payloadOp);
+
+ if (Operation *preExistingSymbolOp = symbolTable.lookup(symbol)) {
+ if (getForceImport()) {
+ // If we want to overwrite pre-existing symbols, just erase it here.
+ symbolTable.erase(preExistingSymbolOp);
+ } else if (getIfUndefined()) {
+ // Skip if we want to use the symbol that is already there.
+ results.push_back(preExistingSymbolOp);
+ continue;
+ } else {
+ return emitDefiniteFailure()
+ << "target symbol " << symbol << " is already defined";
+ }
+ }
+
+ // Symbol table ops must have exactly one region with exactly one block.
+ // Simply clone the target symbol op into the single block.
+ rewriter.setInsertionPointToStart(&payloadOp->getRegion(0).front());
+ results.push_back(rewriter.clone(*symbolOp));
+ }
+ transformResults.set(cast<OpResult>(getClonedSymbol()), results);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void IREE::transform_dialect::ImportSymbolOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getSymbolTable(), effects);
+ transform::producesHandle(getClonedSymbol(), effects);
+ transform::modifiesPayload(effects);
+}
+
+LogicalResult IREE::transform_dialect::ImportSymbolOp::verify() {
+ if (getForceImport() && getIfUndefined()) {
+ return emitOpError()
+ << "force_import and if_undefined are mutually exclusive";
+ }
+ if (!SymbolTable::lookupNearestSymbolFrom(*this, getSymbol())) {
+ return emitOpError() << "invalid import of undefined symbol";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchCastCompatibleDagFromRootOp
+//===----------------------------------------------------------------------===//
+
+static bool isCastableToTensorType(Type from, RankedTensorType to) {
+ auto tensorType = dyn_cast<RankedTensorType>(from);
+ if (!tensorType) {
+ return false;
+ }
+ if (tensorType.getRank() != to.getRank()) {
+ return false;
+ }
+ if (tensorType.getElementType() != to.getElementType()) {
+ return false;
+ }
+ for (auto [fromSize, toSize] :
+ llvm::zip_equal(tensorType.getShape(), to.getShape())) {
+ // If the target dimension is dynamic we can always cast to it.
+ if (ShapedType::isDynamic(toSize)) {
+ continue;
+ }
+ // Casting a dynamic dimension to a static one is never valid, and static
+ // sizes must always match.
+ if (toSize != fromSize) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// Compares the regions between two operations in lockstep for equality.
+static DiagnosedSilenceableFailure
+compareOperationRegions(transform::TransformOpInterface transformOp,
+ Operation *target, Operation *payload) {
+ if (target->getNumRegions() != payload->getNumRegions()) {
+ return transformOp.emitSilenceableError() << "region count mismatch";
+ }
+ for (auto [r0, r1] :
+ llvm::zip_equal(target->getRegions(), payload->getRegions())) {
+ if (!isStructurallyEquivalentTo(r0, r1)) {
+ return transformOp.emitSilenceableError()
+ << "target op does not match specified body";
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+// Helper to check whether two operations are equivalent up to cast
+// compatibility of their arguments (i.e. the arguments of the payload
+// can be casted to the arguments of the target).
+static DiagnosedSilenceableFailure
+compareCastCompatibleOperations(transform::TransformOpInterface transformOp,
+ Operation *target, Operation *payload) {
+ if (target->getName() != payload->getName()) {
+ return transformOp.emitSilenceableError()
+ << "target operation name " << target->getName()
+ << " does not match payload " << payload->getName();
+ }
+
+ if (target->getAttrDictionary() != payload->getAttrDictionary()) {
+ return transformOp.emitSilenceableError()
+ << "target attribute dictionary " << target->getAttrDictionary()
+ << " does not match payload attribute dictionary "
+ << payload->getAttrDictionary();
+ }
+
+ if (target->getNumResults() != payload->getNumResults()) {
+ return transformOp.emitSilenceableError() << "result count mismatch";
+ }
+
+ if (target->getNumOperands() != payload->getNumOperands()) {
+ return transformOp.emitSilenceableError() << "operand count mismatch";
+ }
+ for (auto [targetType, payloadType] :
+ llvm::zip_equal(target->getOperandTypes(), payload->getOperandTypes())) {
+ if (auto targetTensorType = dyn_cast<RankedTensorType>(targetType)) {
+ if (!isCastableToTensorType(payloadType, targetTensorType)) {
+ return transformOp.emitSilenceableError()
+ << "operand tensor type mismatch";
+ }
+ } else if (targetType != payloadType) {
+ return transformOp.emitSilenceableError() << "operand type mismatch";
+ }
+ }
+
+ return compareOperationRegions(transformOp, target, payload);
+}
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::matchOperation(
+ Operation *current, transform::TransformResults &results,
+ transform::TransformState &state) {
+ Operation *targetDagRoot = getRegion().front().back().getPrevNode();
+
+ // Reserve the list of inputs based on the number of block arguments in
+ // the operation region.
+ int64_t numInputs = getRegion().getNumArguments();
+ SmallVector<Value> inputs(numInputs, nullptr);
+
+ // Maps from target/payload op to the order in which they were first
+ // processed. This is used to verify that two uses actually point to the
+ // same node in the dag.
+ llvm::MapVector<Operation *, int64_t> targetDagOrder;
+ llvm::MapVector<Operation *, int64_t> payloadDagOrder;
+ int64_t index = 0;
+
+ auto compareDagOrder = [&](Operation *target, Operation *payload) {
+ // Compare the first-processed indices; iterators from two different
+ // MapVectors must not be compared directly, and absent ops never match.
+ return targetDagOrder.contains(target) && payloadDagOrder.contains(payload) &&
+ targetDagOrder.find(target)->second == payloadDagOrder.find(payload)->second;
+ };
+
+ // Populate the paired worklist with the current target and payload root ops.
+ SmallVector<Operation *> targetWorklist = {targetDagRoot};
+ SmallVector<Operation *> payloadWorklist = {current};
+ while (!targetWorklist.empty()) {
+ Operation *targetOp = targetWorklist.pop_back_val();
+ Operation *payloadOp = payloadWorklist.pop_back_val();
+
+ if (targetOp->hasAttr("match.operation_name_only")) {
+ if (targetOp->getName() != payloadOp->getName()) {
+ return emitSilenceableError() << "only operation name op mismatch";
+ }
+ // Do not recurse and do not require any specific structure beyond the
+ // operation name.
+ continue;
+ }
+
+ // Verify that if already processed, both operations are at the same
+ // position.
+ if (targetDagOrder.contains(targetOp) ||
+ payloadDagOrder.contains(payloadOp)) {
+ if (!compareDagOrder(targetOp, payloadOp)) {
+ return emitSilenceableError() << "dag mismatch";
+ }
+ continue;
+ }
+
+ // Verify general operation equality (name, attributes, regions).
+ DiagnosedSilenceableFailure diag =
+ compareCastCompatibleOperations(*this, targetOp, payloadOp);
+ if (!diag.succeeded()) {
+ diag.attachNote() << "While processing operation " << *payloadOp;
+ return diag;
+ }
+
+ for (auto [payloadOperand, targetOperand] :
+ llvm::zip_equal(payloadOp->getOperands(), targetOp->getOperands())) {
+ // If the target value is a block argument, map the payload value to the
+ // associated input and don't process its producer.
+ if (auto targetBlockArg = dyn_cast<BlockArgument>(targetOperand)) {
+ if (targetBlockArg.getOwner() != &getRegion().front()) {
+ return emitDefiniteFailure() << "Invalid block argument in target";
+ }
+ int64_t argIdx = targetBlockArg.getArgNumber();
+ if (inputs[argIdx] && inputs[argIdx] != payloadOperand) {
+ return emitSilenceableError()
+ << "input operand with conflicting uses";
+ }
+ inputs[argIdx] = payloadOperand;
+ continue;
+ }
+
+ Operation *payloadDefiningOp = payloadOperand.getDefiningOp();
+ if (!payloadDefiningOp) {
+ return emitSilenceableError()
+ << "early termination of the operation dag";
+ }
+
+ // Check whether the producer was already processed, and if so make sure
+ // the target and payload match.
+ Operation *targetDefiningOp = targetOperand.getDefiningOp();
+ if (targetDagOrder.contains(targetDefiningOp) ||
+ payloadDagOrder.contains(payloadDefiningOp)) {
+ if (!compareDagOrder(targetDefiningOp, payloadDefiningOp)) {
+ return emitSilenceableError() << "dag mismatch";
+ }
+ continue;
+ }
+ // Pop the producer of this value onto the worklist.
+ targetWorklist.push_back(targetDefiningOp);
+ payloadWorklist.push_back(payloadDefiningOp);
+ }
+
+ // Mark the current target + payload as processed.
+ targetDagOrder[targetOp] = index;
+ payloadDagOrder[payloadOp] = index++;
+ }
+
+ // Verify that all input arguments were successfully matched.
+ if (llvm::any_of(inputs, [](Value in) { return !in; })) {
+ return emitSilenceableError() << "failed to match all input nodes";
+ }
+
+ results.setValues(cast<OpResult>(getInputs()), inputs);
+ results.setValues(cast<OpResult>(getOutputs()), current->getResults());
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult
+IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::verify() {
+ auto &body = getRegion().front();
+ if (llvm::range_size(body.getOperations()) < 2) {
+ return emitOpError() << "match region must contain at least one operation";
+ }
+ // TODO: Region verification that it includes a single DAG.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchCastCompatibleTypesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchCastCompatibleTypesOp::matchValue(
+ Value current, transform::TransformResults &results,
+ transform::TransformState &state) {
+ Type targetType = getTargetType();
+ if (auto targetTensorType = dyn_cast<RankedTensorType>(targetType)) {
+ if (!isCastableToTensorType(current.getType(), targetTensorType)) {
+ return emitSilenceableError()
+ << "type " << current.getType() << " is not castable to "
+ << targetTensorType;
+ }
+ return DiagnosedSilenceableFailure::success();
+ }
+ if (current.getType() != targetType) {
+ return emitSilenceableError()
+ << "type " << current.getType() << " does not match " << targetType;
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchDimIsMultipleOfOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchDimIsMultipleOfOp::matchValue(
+ Value current, transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto shapedType = dyn_cast<ShapedType>(current.getType());
+ if (!shapedType) {
+ return emitSilenceableError()
+ << "type " << current.getType() << " is not a shaped type";
+ }
+ int64_t dim = getDim();
+ if (dim >= shapedType.getRank()) {
+ return emitSilenceableError()
+ << "dim " << dim << " out of range for shaped type " << shapedType;
+ }
+ int64_t size = getSize();
+ if (shapedType.getShape()[dim] % size != 0) {
+ return emitSilenceableError()
+ << "dim " << dim << " of shaped type " << shapedType
+ << " is not a multiple of " << size;
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchRegionsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchRegionsOp::matchOperation(
+ Operation *current, transform::TransformResults &results,
+ transform::TransformState &state) {
+ Operation *comparisonTarget = &getRegion().front().front();
+ return compareOperationRegions(*this, comparisonTarget, current);
+}
+
+LogicalResult IREE::transform_dialect::MatchRegionsOp::verify() {
+ auto &body = getRegion().front();
+ if (llvm::range_size(body.getOperations()) != 2) {
+ return emitOpError() << "match region must contain exactly one operation";
+ }
+ Operation *target = &body.front();
+ if (target->getNumRegions() == 0) {
+ return emitOpError() << "contained operation for comparison must have at "
+ "least one region";
+ }
+ return success();
+}
+
+} // namespace mlir::iree_compiler
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h
new file mode 100644
index 0000000..dbc9579
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h
@@ -0,0 +1,40 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
+#define IREE_COMPILER_CODEGEN_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+
+namespace mlir {
+class DialectRegistry;
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.h.inc"
+
+namespace mlir::iree_compiler {
+
+/// Registers Preprocessing transformations that require IREE-specific
+/// information into the transform dialect.
+void registerTransformDialectPreprocessingExtension(DialectRegistry &registry);
+
+namespace IREE::transform_dialect {
+// Hook to register Preprocessing transformations to the transform dialect.
+class PreprocessingExtensions
+ : public transform::TransformDialectExtension<PreprocessingExtensions> {
+public:
+ PreprocessingExtensions();
+};
+} // namespace IREE::transform_dialect
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS_H_
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td
new file mode 100644
index 0000000..de0232f
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td
@@ -0,0 +1,213 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
+#define IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
+
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def GetNearestSymbolTableOp : Op<Transform_Dialect, "iree.get_nearest_symbol_table",
+ [FunctionalStyleTransformOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface,
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Returns the nearest symbol table op for each op in the payload, inclusive.
+
+ #### Return modes
+
+ This operation reads the `target` handle and produces the `result`
+ handle. This operation emits a definite failure if the nearest symbol table
+ is unknown.
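+
+ Example (an illustrative sketch; handle names are placeholders):
+
+ ```mlir
+ %table = transform.iree.get_nearest_symbol_table %func : (!transform.any_op) -> !transform.any_op
+ ```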
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$result);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation* target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def ImportSymbolOp : Op<Transform_Dialect, "iree.import_symbol",
+ [FunctionalStyleTransformOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Clones the op defined by the given symbol into the given symbol table and
+ returns the cloned symbol. If `force_import` is set, this will (unsafely)
+ overwrite any pre-existing definitions of the same symbol. If
+ `if_undefined` is set, this will return a handle to the pre-existing symbol
+ in the payload, if found, instead of failing.
+
+ #### Return modes
+
+ This operation reads the `symbol_table` handle and produces the
+ `cloned_symbol` handle. This operation emits a definite failure if the
+ `symbol_table` op does not define a symbol table.
+
+ This will emit a definite failure if the symbol already exists in the
+ symbol table and neither `force_import` nor `if_undefined` is set.
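+
+ Example (an illustrative sketch mirroring the tests in this change):
+
+ ```mlir
+ %new_func = transform.iree.import_symbol @some_function into %module : (!transform.any_op) -> !transform.any_op
+ ```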
+ }];
+
+ let arguments = (ins SymbolRefAttr:$symbol,
+ UnitAttr:$if_undefined,
+ UnitAttr:$force_import,
+ TransformHandleTypeInterface:$symbol_table);
+ let results = (outs TransformHandleTypeInterface:$cloned_symbol);
+
+ let assemblyFormat = [{
+ (`force` $force_import^)? $symbol `into` $symbol_table
+ (`if` `undefined` $if_undefined^)? attr-dict
+ `:` functional-type(operands, results)
+ }];
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+ let hasVerifier = 1;
+}
+
+def MatchCastCompatibleDagFromRootOp : Op<Transform_Dialect, "iree.match.cast_compatible_dag_from_root",
+ [IsolatedFromAbove,
+ MatchOpInterface,
+ SingleOpMatcher,
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Checks if the body of the target op matches an operation dag starting "
+ "at the given root";
+ let description = [{
+ Checks whether the given root op matches an operation dag specified in the
+ body of this op. Enforces cast compatibility between types rather than a
+ strict equality, similar to `iree.match.cast_compatible_type`.
+
+ Note: This operation is experimental and subject to change. General subgraph
+ matching is difficult and can spawn various DSLs and a slew of transforms.
+ This op tries to keep the matching relatively simple and inflexible,
+ reflecting the expected use case of splicing in hand-written kernels that can be equally
+ inflexible.
+
+ #### Return modes
+
+ Succeeds if the root operation matches the dag given by this op, and
+ produces a silenceable failure otherwise. Produces a definite failure
+ if the operand is not associated with a single payload op.
+
+ On success, this operation produces handles to the inputs and outputs
+ of the operation dag, derived from the results of the root op and the
+ block arguments of this operation's body.
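+
+ Example (an illustrative sketch; the reference dag is elided):
+
+ ```mlir
+ %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+ ^bb0(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>):
+   // Reference dag whose last operation is the root compared against the
+   // payload; the block arguments become the matched inputs.
+ } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+ ```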
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let results = (outs TransformValueHandleTypeInterface:$inputs,
+ TransformValueHandleTypeInterface:$outputs);
+ let regions = (region SizedRegion<1>:$region);
+ let assemblyFormat = "$operand_handle attr-dict-with-keyword regions `:` functional-type(operands, results)";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let hasVerifier = 1;
+}
+
+def MatchCastCompatibleTypesOp : Op<Transform_Dialect, "iree.match.cast_compatible_type",
+ [IsolatedFromAbove,
+ MatchOpInterface,
+ SingleValueMatcher,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Checks if the body of the target op matches the body of the single contained op";
+ let description = [{
+ Checks whether the given value is cast-compatible with the given target
+ type attribute.
+
+ Currently this operation only allows casting of tensor types. Other types
+ must match exactly.
+
+ #### Return modes
+
+ Succeeds if the value's type is compatible with the target type, and
+ produces a silenceable failure otherwise. Produces a definite failure
+ if the operand is not associated with a single payload value.
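+
+ Example (from the tests in this change):
+
+ ```mlir
+ transform.iree.match.cast_compatible_type %in0 = tensor<10xf32> : !transform.any_value
+ ```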
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$operand_handle,
+ TypeAttr:$target_type);
+ let assemblyFormat = "$operand_handle `=` $target_type attr-dict `:` type($operand_handle)";
+ let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
+def MatchDimIsMultipleOfOp : Op<Transform_Dialect, "iree.match.dim_is_multiple_of",
+ [IsolatedFromAbove,
+ MatchOpInterface,
+ SingleValueMatcher,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Checks if the body of the target op matches the body of the single contained op";
+ let description = [{
+ Checks whether the given dimension of the given shaped value is a static
+ multiple of the given size.
+
+ #### Return modes
+
+ Succeeds if the dimension is a static multiple of the given size, and
+ produces a silenceable failure otherwise.
+ if the operand is not associated with a single payload value.
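+
+ Example (from the tests in this change; succeeds when dim 0 is a static
+ multiple of 20):
+
+ ```mlir
+ transform.iree.match.dim_is_multiple_of %in0[0], 20 : !transform.any_value
+ ```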
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$operand_handle,
+ I64Attr:$dim,
+ I64Attr:$size);
+ let assemblyFormat = "$operand_handle `[` $dim `]` `,` $size "
+ "attr-dict `:` type($operand_handle)";
+ let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
+def MatchRegionsOp : Op<Transform_Dialect, "iree.match.regions",
+ [IsolatedFromAbove,
+ MatchOpInterface,
+ SingleOpMatcher,
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Checks if the body of the target op matches the body of the single contained op";
+ let description = [{
+ Does a structural comparison of the regions of the single op contained
+ within the region of this op against the regions of the target operation.
+
+ #### Return modes
+
+ Succeeds if the operation body satisfies the specified criteria, produces a
+ silenceable failure otherwise. Produces a definite failure if the operand is
+ not associated with a single payload op.
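+
+ Example (an illustrative sketch; the reference op's body is elided):
+
+ ```mlir
+ transform.iree.match.regions %generic : !transform.any_op {
+ ^bb0(%target: tensor<f32>, %init: tensor<f32>):
+   // Single reference op whose regions are structurally compared against
+   // the regions of the payload op.
+ }
+ ```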
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let regions = (region SizedRegion<1>:$region);
+ let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle) regions";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let hasVerifier = 1;
+}
+
+#endif // IREE_COMPILER_DIALECT_PREPROCESSING_TRANSFORMEXTENSIONS_PREPROCESSINGEXTENSIONS
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 5d87d65..637a66b 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -63,6 +63,7 @@
"//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms",
"//compiler/src/iree/compiler/Pipelines",
"//compiler/src/iree/compiler/Preprocessing:Passes",
+ "//compiler/src/iree/compiler/Preprocessing/TransformExtensions:PreprocessingExtensions",
"//llvm-external-projects/iree-dialects:IREEInputDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
@@ -112,8 +113,21 @@
"@llvm-project//mlir:SPIRVTransforms",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:TensorInferTypeOpInterfaceImpl",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
+
+ # TransformExtensions
+ "@llvm-project//mlir:AffineTransformOps",
+ "@llvm-project//mlir:BufferizationTransformOps",
+ "@llvm-project//mlir:FuncTransformOps",
+ "@llvm-project//mlir:GPUTransformOps",
+ "@llvm-project//mlir:LinalgTransformOps",
+ "@llvm-project//mlir:MemRefTransformOps",
+ "@llvm-project//mlir:SCFTransformOps",
+ "@llvm-project//mlir:TensorTransformOps",
+ "@llvm-project//mlir:TransformLoopExtension",
+ "@llvm-project//mlir:VectorTransformOps",
],
)
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index 97a3a92..03673ff 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -70,6 +70,7 @@
iree::compiler::Modules::IO::Parameters::Transforms
iree::compiler::Pipelines
iree::compiler::Preprocessing::Passes
+ iree::compiler::Preprocessing::TransformExtensions::PreprocessingExtensions
PUBLIC
)
@@ -105,9 +106,21 @@
MLIRFuncDialect
MLIRFuncToSPIRV
MLIRTensorInferTypeOpInterfaceImpl
+ MLIRTransformDialect
MLIRTransforms
MLIRVectorDialect
iree::compiler::Dialect::VM::Target::init_targets
+
+ MLIRAffineTransformOps
+ MLIRBufferizationTransformOps
+ MLIRFuncTransformOps
+ MLIRGPUTransformOps
+ MLIRLinalgTransformOps
+ MLIRMemRefTransformOps
+ MLIRSCFTransformOps
+ MLIRTensorTransformOps
+ MLIRVectorTransformOps
+ MLIRTransformLoopExtension
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
index 5fcdca4..c8aca0e 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -31,6 +31,7 @@
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
#include "iree/compiler/Modules/HAL/Loader/IR/HALLoaderDialect.h"
#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
+#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
#include "mlir/IR/Dialect.h"
namespace mlir::iree_compiler {
@@ -59,6 +60,9 @@
registerCodegenInterfaces(registry);
registerGlobalOptimizationInterfaces(registry);
registerUKernelBufferizationInterface(registry);
+
+ // Register transform dialect extensions.
+ registerTransformDialectPreprocessingExtension(registry);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index 80d22d8..ddc310d 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -13,30 +13,40 @@
#define IREE_COMPILER_TOOLS_INIT_MLIR_DIALECTS_H_
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/IR/Dialect.h"
#ifdef IREE_HAVE_C_OUTPUT_FORMAT
@@ -76,6 +86,18 @@
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
+ // Register all transform dialect extensions.
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ func::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ transform::registerLoopExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+
#ifdef IREE_HAVE_C_OUTPUT_FORMAT
registry.insert<emitc::EmitCDialect>();
#endif // IREE_HAVE_C_OUTPUT_FORMAT
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index 611797d..b3e1de8 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
@@ -57,6 +58,9 @@
affine::registerAffinePipelineDataTransferPass();
registerConvertAffineToStandardPass();
+ // Arm SME
+ arm_sme::registerArmSMEPasses();
+
// Linalg
registerLinalgPasses();
@@ -80,8 +84,8 @@
registerConvertControlFlowToSPIRVPass();
registerConvertFuncToSPIRVPass();
- // Arm SME
- arm_sme::registerArmSMEPasses();
+ // Transform Dialect
+ transform::registerTransformPasses();
}
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
index 2f0b122..b4a938a 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp
@@ -34,7 +34,7 @@
static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs,
IRMapping &parentMapping);
-static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) {
+bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) {
IRMapping mapping;
return isStructurallyEquivalentTo(lhs, rhs, mapping);
}
diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
index c533b04..1b51717 100644
--- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
+++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.h
@@ -11,6 +11,12 @@
namespace mlir::iree_compiler {
+// Recursively compares two regions for structural equivalence.
+//
+// Structural equivalence ensures that operations in both regions
+// |lhs| and |rhs| have the same attributes and same use-def structure.
+bool isStructurallyEquivalentTo(Region &lhs, Region &rhs);
+
// Recursively compares two operations for structural equivalence.
//
// Structural equivalence ensures that operations in the regions of both the
diff --git a/samples/custom_dispatch/README.md b/samples/custom_dispatch/README.md
index a5579f7..63d5748 100644
--- a/samples/custom_dispatch/README.md
+++ b/samples/custom_dispatch/README.md
@@ -257,6 +257,49 @@
substitution approaches should be used instead and in many cases that is
sufficient for most workloads not involving other libraries.
+### Compile Time Inlining Custom Function Calls
+
+**Overview**: the user defines functions with MLIR dialects IREE is able to
+ingest, paired with a matcher and a replacement pattern, all delivered as a
+single transform dialect spec. The matcher runs during preprocessing and calls
+into the replacement pattern for every successful match. The replacement
+pattern imports the externally authored function containing the call to the
+custom dispatch into the module and replaces the matched subgraph with a call
+to that function, interleaved with the other IR in the program.
+
+**Workflows**:
+
+Statically matched and imported external functions
+```
+                                            +--------------+
+                                            | example.mlir |
++--------------+       +--------------+     +--------------+
+| (one of the  |       | functions +  |            v
+| above static | ----> | matchers +   | ----> iree-compile
+|  workflows)  |       | replace.mlir |            v
++--------------+       +--------------+     +--------------+
+                                            | example.vmfb |
+                                            +--------------+
+```
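+
+The spec is passed to `iree-compile` with the
+`--iree-preprocessing-transform-spec-filename=` flag. A minimal sketch (paths
+illustrative):
+
+```
+iree-compile example.mlir \
+    --iree-preprocessing-transform-spec-filename=replace.mlir \
+    -o example.vmfb
+```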
+
+**Samples**:
+
+* CPU: [custom_dispatch/cpu/embedded/](./cpu/embedded/) (.c -> .o)
+ * [custom_dispatch/cpu/embedded/example_transform_spec.mlir](./cpu/embedded/example_transform_spec.mlir) (.mlir)
+* Vulkan/SPIR-V: [custom_dispatch/vulkan/shaders/](./vulkan/shaders/) (.glsl -> .spv)
+ * [custom_dispatch/vulkan/shaders/example_transform_spec.mlir](./vulkan/shaders/example_transform_spec.mlir) (.mlir)
+
+The above two samples build on top of a couple of the static workflows shown
+above, but should work with any of the other approaches. The idea is to separate
+the custom kernel from the target module to be compiled, allowing integration of
+custom dispatches with default IREE codegen without the need to build a custom
+set of compiler tools around IREE to generate the necessary IR.
+
+There are a number of possible points at which the match and replace can happen;
+the above shows it after import + input conversion. Other plugin points are
+possible (e.g. before input conversion or after global optimization), but the
+available matchers are currently missing some ergonomics at those points.
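+
+Structurally, a spec is a module annotated with `transform.with_named_sequence`
+that holds the externally authored function, the matcher, the rewriter, and the
+interpreter entry point. A minimal outline (names illustrative, bodies elided;
+see the samples above for complete, working versions):
+
+```
+module attributes {transform.with_named_sequence} {
+  // Function containing the call to the custom dispatch; imported into the
+  // target module on a successful match.
+  func.func @call_my_kernel(%arg0: tensor<?xf32>) -> tensor<?xf32> { ... }
+
+  // Matcher that describes the compatible subgraph and yields the values the
+  // rewriter needs on success.
+  transform.named_sequence @my_matcher(%root: !transform.any_op {transform.readonly})
+      -> (!transform.any_value, !transform.any_value) { ... }
+
+  // Rewriter that imports @call_my_kernel into the module and swaps the
+  // matched dag for a call to it.
+  transform.named_sequence @my_rewriter(%ins: !transform.any_value {transform.readonly},
+                                        %outs: !transform.any_value {transform.readonly}) { ... }
+
+  // Interpreter entry point; runs @my_matcher -> @my_rewriter over each
+  // function via transform.foreach_match.
+  transform.named_sequence @__transform_main(%module: !transform.any_op) { ... }
+}
+```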
+
### Others
Most other situations are covered by [custom modules](/samples/custom_module/).
diff --git a/samples/custom_dispatch/cpu/embedded/CMakeLists.txt b/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
index 95d32f9..2f428dc 100644
--- a/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
+++ b/samples/custom_dispatch/cpu/embedded/CMakeLists.txt
@@ -66,9 +66,11 @@
SRCS
"example_hal.mlir"
"example_stream.mlir"
+ "example_transform.mlir"
DATA
functions_arm_64.o
functions_x86_64.o
+ example_transform_spec.mlir
TOOLS
FileCheck
iree-compile
diff --git a/samples/custom_dispatch/cpu/embedded/README.md b/samples/custom_dispatch/cpu/embedded/README.md
index dc94a99..ede94a6 100644
--- a/samples/custom_dispatch/cpu/embedded/README.md
+++ b/samples/custom_dispatch/cpu/embedded/README.md
@@ -148,3 +148,29 @@
--input=8xf32=4 \
--module=/tmp/example.vmfb
```
+
+## Custom Kernel Match and Replace Scripting Instructions
+
+Follow the first two steps above to build the samples, then compile with one
+additional flag giving the path to the transform spec that contains the kernel
+matcher and replacement logic:
+
+ ```
+ iree-compile \
+ --iree-hal-executable-object-search-path=../iree-build/ \
+ --iree-preprocessing-transform-spec-filename=samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir \
+ samples/custom_dispatch/cpu/embedded/example_transform.mlir \
+ -o=/tmp/example.vmfb
+ ```
+
+Then run the example the same way:
+
+ ```
+ iree-run-module \
+ --device=local-sync \
+ --function=mixed_invocation \
+ --input=5xf32=7 \
+ --input=5xf32=4 \
+ --input=10xf32=-4 \
+ --input=10xf32=3 \
+ --module=/tmp/example.vmfb
+ ```
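+
+The custom kernel computes `-|lhs * rhs| + 1`; the extra `+ 1` makes it easy to
+see that the custom kernel ran instead of the default codegen path. With the
+inputs above, the printed results should contain:
+
+ ```
+ 5xf32=-27 -27 -27 -27 -27
+ 10xf32=-11 -11 -11 -11 -11 -11 -11 -11 -11 -11
+ ```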
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform.mlir b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
new file mode 100644
index 0000000..0735189
--- /dev/null
+++ b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
@@ -0,0 +1,111 @@
+// RUN: iree-compile %s \
+// RUN: --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \
+// RUN: --iree-preprocessing-transform-spec-filename=%p/example_transform_spec.mlir | \
+// RUN: iree-run-module \
+// RUN: --device=local-sync \
+// RUN: --module=- \
+// RUN: --function=mixed_invocation \
+// RUN: --input=5xf32=7 \
+// RUN: --input=5xf32=4 \
+// RUN: --input=10xf32=-4 \
+// RUN: --input=10xf32=3 | \
+// RUN: FileCheck %s
+
+// The configuration used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 32 : index,
+ target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support llvm-cpu here.
+#cpu_target = #hal.device.target<"llvm-cpu", {
+ executable_targets = [
+ #x86_64_target
+ ]
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+module @example attributes {hal.device.targets = [#cpu_target]} {
+
+ // CHECK-LABEL: EXEC @mixed_invocation
+ func.func @mixed_invocation(%lhs: tensor<?xf32>,
+ %rhs: tensor<?xf32>,
+ %lhs_static: tensor<10xf32>,
+ %rhs_static: tensor<10xf32>) -> (tensor<?xf32>, tensor<10xf32>) {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %lhs, %c0 : tensor<?xf32>
+ %empty = tensor.empty(%dim) : tensor<?xf32>
+ %mul = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %in0: f32, %out: f32):
+ %m = arith.mulf %in, %in0 : f32
+ linalg.yield %m : f32
+ } -> tensor<?xf32>
+ %abs = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%mul : tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %a = math.absf %in : f32
+ linalg.yield %a : f32
+ } -> tensor<?xf32>
+ %neg = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%abs : tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %n = arith.negf %in : f32
+ linalg.yield %n : f32
+ } -> tensor<?xf32>
+
+ %empty_static = tensor.empty() : tensor<10xf32>
+ %mul_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%lhs_static, %rhs_static : tensor<10xf32>, tensor<10xf32>)
+ outs(%empty_static : tensor<10xf32>) {
+ ^bb0(%in: f32, %in0: f32, %out: f32):
+ %m = arith.mulf %in, %in0 : f32
+ linalg.yield %m : f32
+ } -> tensor<10xf32>
+ %abs_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%mul_static : tensor<10xf32>)
+ outs(%empty_static : tensor<10xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %a = math.absf %in : f32
+ linalg.yield %a : f32
+ } -> tensor<10xf32>
+ %neg_static = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%abs_static : tensor<10xf32>)
+ outs(%empty_static : tensor<10xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %n = arith.negf %in : f32
+ linalg.yield %n : f32
+ } -> tensor<10xf32>
+
+ // Add 1 to show that it actually runs the custom kernel.
+ // CHECK: 5xf32=-27 -27 -27 -27 -27
+ // CHECK: 10xf32=-11 -11 -11 -11 -11 -11 -11 -11 -11 -11
+ return %neg, %neg_static : tensor<?xf32>, tensor<10xf32>
+ }
+} // module
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
new file mode 100644
index 0000000..1867e5e
--- /dev/null
+++ b/samples/custom_dispatch/cpu/embedded/example_transform_spec.mlir
@@ -0,0 +1,175 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// The configuration used for executable compilation.
+// This specifies the device configurations that support this custom kernel.
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 32 : index,
+ target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on.
+// These can come from compiler flags; multiple targets can be supported.
+// It's possible, for example, to target multiple devices in the same
+// compiled binary (CPU + Vulkan, etc).
+#cpu_target = #hal.device.target<"llvm-cpu", {
+ executable_targets = [
+ #x86_64_target
+ ]
+}>
+
+module attributes {transform.with_named_sequence} {
+
+ // Executable containing exported shims and calls to external functions.
+ // See the other examples in this directory for in-depth explanations of
+ // the IR structure of this executable.
+ hal.executable private @executable {
+ hal.executable.variant public @x86_64 target(#x86_64_target) objects([
+ #hal.executable.object<{
+ path = "samples/custom_dispatch/cpu/embedded/functions_x86_64.o"
+ }>
+ ]) {
+ hal.executable.export public @simple_mul_abs_negate ordinal(0)
+ layout(#hal.pipeline.layout<push_constants = 1, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer, ReadOnly>,
+ <2, storage_buffer>
+ ]>
+ ]>) {
+ ^bb0(%device: !hal.device, %workload: index):
+ %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+ %c1 = arith.constant 1 : index
+ hal.return %x, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func private @simple_mul_abs_negate_workgroup(%binding0: memref<?xf32>, %binding1: memref<?xf32>, %binding2: memref<?xf32>, %dim: index, %tid: index) attributes {
+ hal.import.static
+ }
+ func.func @simple_mul_abs_negate() {
+ %c0 = arith.constant 0 : index
+ %dim_i32 = hal.interface.constant.load[0] : i32
+ %dim = arith.index_castui %dim_i32 : i32 to index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
+
+ %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+ %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+ %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf32>{%dim}
+
+ func.call @simple_mul_abs_negate_workgroup(%binding0, %binding1, %binding2, %dim, %tid) : (memref<?xf32>, memref<?xf32>, memref<?xf32>, index, index) -> ()
+ return
+ }
+ }
+ } // hal.executable.variant
+ } // hal.executable
+
+ func.func @call_mul_abs_negate(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %dim_i32 = arith.index_cast %dim : index to i32
+
+ // Dispatch a basic `ret = -|lhs * rhs|` using an external function.
+ %0 = flow.dispatch @executable::@x86_64::@simple_mul_abs_negate[%dim](%dim_i32, %arg0, %arg1) {
+ hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>,
+ #hal.interface.binding<0, 2>
+ ],
+ // HACK: keep the executable live through DCE. Only required when
+ // using the automatic variant selection.
+ hal.executable.ref = [@executable]
+ } : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+ return %0 : tensor<?xf32>
+ }
+
+ transform.named_sequence @match_mul_abs_negate(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
+ %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+ ^bb0(%lhs: tensor<?xf32>, %rhs: tensor<?xf32>):
+ // The matcher does not recurse into the constant index + dim because
+ // their only consumer (the tensor.empty) is matched by operation name only.
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %lhs, %c0 : tensor<?xf32>
+ // --------------------------------------------------------------------
+ %empty = tensor.empty(%dim) {"match.operation_name_only"} : tensor<?xf32>
+ %mul = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %in0: f32, %out: f32):
+ %m = arith.mulf %in, %in0 : f32
+ linalg.yield %m : f32
+ } -> tensor<?xf32>
+ %abs = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%mul : tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %a = math.absf %in : f32
+ linalg.yield %a : f32
+ } -> tensor<?xf32>
+ // The payload root is compared starting from here, walking up the chain
+ // of producers.
+ %neg = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%abs : tensor<?xf32>)
+ outs(%empty : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %n = arith.negf %in : f32
+ linalg.yield %n : f32
+ } -> tensor<?xf32>
+ } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+ transform.yield %ins, %outs : !transform.any_value, !transform.any_value
+ }
+
+ // Rewrite callback for `transform.foreach_match`. The input signature for
+ // this sequence must match exactly with the outputs of the matcher. In this
+ // case the matcher returns the inputs and outputs of the matched dag
+ // directly, so we just insert a call to the hand-authored function above.
+ transform.named_sequence @cast_and_call_dag(%ins: !transform.any_value {transform.readonly},
+ %out: !transform.any_value {transform.readonly}) {
+ %root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
+ %module = transform.iree.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
+ %executable = transform.iree.import_symbol @executable into %module if undefined : (!transform.any_op) -> !transform.any_op
+ %func = transform.iree.import_symbol @call_mul_abs_negate into %module if undefined : (!transform.any_op) -> !transform.any_op
+ transform.func.cast_and_call %func(%ins) -> %out after %root {
+ // This specifies how to resolve type mismatches between the arguments
+ // of the function and the inputs from the matcher. In this example,
+ // the only casts this will generate are same-rank tensor casts that
+ // drop static information.
+ transform.type_conversion.tensor.cast_shape_dynamic_dims
+ } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ // Entry point for the transform interpreter, nested on the full module. This
+ // is because the rewrites needed for importing the custom kernel need to
+ // add a new symbol to the module's symbol table.
+ transform.named_sequence @__transform_main(%module: !transform.any_op) {
+ // Gather the set of functions within the module.
+ %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ // For each function in the module, run the matcher on all contained
+ // operations.
+ transform.foreach %funcs : !transform.any_op {
+ ^bb1(%func: !transform.any_op):
+ transform.foreach_match in %func
+ // <matcher name> -> <rewriter name>
+ // Multiple matcher-action pairs can be specified comma separated,
+ // here we are only doing a single kind of match and replace.
+ @match_mul_abs_negate -> @cast_and_call_dag
+ : (!transform.any_op) -> (!transform.any_op)
+ }
+ // Clean up leftover dead code; cast_and_call does not do replacement, it
+ // only rewires uses.
+ transform.apply_dce to %module : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/samples/custom_dispatch/cpu/embedded/functions.c b/samples/custom_dispatch/cpu/embedded/functions.c
index 0746c4e..a962666 100644
--- a/samples/custom_dispatch/cpu/embedded/functions.c
+++ b/samples/custom_dispatch/cpu/embedded/functions.c
@@ -85,3 +85,33 @@
binding1[i] *= binding0[i];
}
}
+
+// `ret = -|lhs * rhs| + 1`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+// <0, bindings = [
+// <0, storage_buffer, ReadOnly>,
+// <1, storage_buffer, ReadOnly>,
+// <2, storage_buffer>
+// ]>
+// ]>
+// With a workgroup size of 64x1x1.
+void simple_mul_abs_negate_workgroup(
+ // vvvv simplification pending (buffer + offset)
+ const float* restrict binding0, const float* restrict binding0_aligned,
+ size_t binding0_offset, size_t binding0_size, size_t binding0_stride,
+ const float* restrict binding1, const float* restrict binding1_aligned,
+ size_t binding1_offset, size_t binding1_size, size_t binding1_stride,
+ float* restrict binding2, float* restrict binding2_aligned,
+ size_t binding2_offset, size_t binding2_size, size_t binding2_stride,
+ // ^^^^ simplification pending (buffer + offset)
+ size_t dim, size_t tid) {
+ size_t end = tid + 64;
+ if (end > dim) end = dim;
+ for (size_t i = tid; i < end; ++i) {
+ float prod = binding0[i] * binding1[i];
+ if (prod >= 0) prod = -prod;
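+ // The + 1 makes it observable that this kernel ran rather than the default
+ // codegen path (see the checks in example_transform.mlir).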
+ binding2[i] = prod + 1;
+ }
+}
diff --git a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
index 9d96884..13b0436 100644
--- a/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
+++ b/samples/custom_dispatch/vulkan/shaders/CMakeLists.txt
@@ -44,7 +44,18 @@
${CMAKE_CURRENT_SOURCE_DIR}/simple_mul_inplace.glsl
VERBATIM
)
+add_custom_command(
+ OUTPUT one_workgroup_argmax_subgroup_f32.spv
+ DEPENDS one_workgroup_argmax_subgroup_f32.glsl
+ COMMAND ${GLSLC}
+ -fshader-stage=compute
+ --target-spv=spv1.3
+ -o one_workgroup_argmax_subgroup_f32.spv
+ ${CMAKE_CURRENT_SOURCE_DIR}/one_workgroup_argmax_subgroup_f32.glsl
+ VERBATIM
+)
add_custom_target(iree_samples_custom_dispatch_vulkan_shaders_spv DEPENDS
+ ${CMAKE_CURRENT_BINARY_DIR}/one_workgroup_argmax_subgroup_f32.spv
${CMAKE_CURRENT_BINARY_DIR}/simple_mul.spv
${CMAKE_CURRENT_BINARY_DIR}/simple_mul_inplace.spv
)
@@ -56,8 +67,10 @@
SRCS
"example.mlir"
"example_inline.mlir"
+ "example_transform.mlir"
DATA
${_SPV_TARGET}
+ iree::samples::custom_dispatch::vulkan::shaders::example_transform_spec.mlir
TOOLS
FileCheck
iree-compile
diff --git a/samples/custom_dispatch/vulkan/shaders/README.md b/samples/custom_dispatch/vulkan/shaders/README.md
index 00b71f8..fb00a88 100644
--- a/samples/custom_dispatch/vulkan/shaders/README.md
+++ b/samples/custom_dispatch/vulkan/shaders/README.md
@@ -145,3 +145,19 @@
--input=8xf32=4 \
/tmp/example.vmfb
```
+
+## Custom Kernel Match and Replace Scripting Instructions
+
+This is a flow for authoring custom dispatches externally alongside match and
+replace logic that can be fed directly into a pre-built version of the compiler.
+
+In addition to the above steps, when compiling the module, pass in both the
+target module and the transform spec implementing the matcher + kernel.
+
+ ```
+ iree-compile \
+ --iree-hal-executable-object-search-path=../iree-build/ \
+ samples/custom_dispatch/vulkan/shaders/example_transform.mlir \
+ --iree-preprocessing-transform-spec-filename=samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir \
+ -o=/tmp/example.vmfb
+ ```
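+
+Then run the module as with the other samples, mirroring the test invocation
+embedded in `example_transform.mlir`:
+
+ ```
+ iree-run-module \
+   --device=vulkan \
+   --function=mixed_invocation \
+   --input=1x128xf32=4 \
+   --input=1x128xf32=3 \
+   --module=/tmp/example.vmfb
+ ```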
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
new file mode 100644
index 0000000..82662ac
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -0,0 +1,74 @@
+// RUN: iree-compile %s \
+// RUN: --iree-hal-executable-object-search-path=$IREE_BINARY_DIR \
+// RUN: --iree-preprocessing-transform-spec-filename=%p/example_transform_spec.mlir | \
+// RUN: iree-run-module \
+// RUN: --device=vulkan \
+// RUN: --module=- \
+// RUN: --function=mixed_invocation \
+// RUN: --input=1x128xf32=4 \
+// RUN: --input=1x128xf32=3 | \
+// RUN: FileCheck %s
+
+// The configuration used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+// HACK: Currently this must match EXACTLY with the executable target for the
+// custom kernel. For things to be truly portable, we need to be able to compare
+// executable configurations.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+ >
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support vulkan here. It is possible to hand author a custom
+// kernel that supports multiple targets by specifying an object per target,
+// but that requires authoring the kernel once per target.
+#vulkan_target = #hal.device.target<"vulkan", {
+ executable_targets = [#spirv_target],
+ // HACK: Vulkan target currently uses the legacy synchronous execution model.
+ legacy_sync
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+module @example attributes {hal.device.targets = [#vulkan_target]} {
+
+ // CHECK-LABEL: EXEC @mixed_invocation
+ func.func @mixed_invocation(%arg0: tensor<1x128xf32>, %arg1: tensor<1x128xf32>) -> tensor<1xi64> {
+ // Code gen some other ops - these will interleave with the matched and
+ // replaced ones but naturally won't be able to fuse with them.
+ %add = arith.addf %arg0, %arg1 : tensor<1x128xf32>
+
+ %c0_i64 = arith.constant 0 : i64
+ %cst = arith.constant 0xFF800000 : f32
+ %1 = tensor.empty() : tensor<1xi64>
+ %2 = linalg.fill ins(%c0_i64 : i64) outs(%1 : tensor<1xi64>) -> tensor<1xi64>
+ %3 = tensor.empty() : tensor<1xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+ // Argmax that is the target for the custom kernel. Note that only one of
+ // this operation's two results (the index) has uses, and it takes a single
+ // input.
+ %5:2 = linalg.generic {indexing_maps = [#map, #map1, #map1],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%add : tensor<1x128xf32>)
+ outs(%4, %2 : tensor<1xf32>, tensor<1xi64>) {
+ ^bb0(%in: f32, %out: f32, %out_0: i64):
+ %6 = linalg.index 1 : index
+ %7 = arith.index_cast %6 : index to i64
+ %8 = arith.maximumf %in, %out : f32
+ %9 = arith.cmpf ogt, %in, %out : f32
+ %10 = arith.select %9, %7, %out_0 : i64
+ linalg.yield %8, %10 : f32, i64
+ } -> (tensor<1xf32>, tensor<1xi64>)
+
+ // CHECK: 1xi64=0
+ return %5#1 : tensor<1xi64>
+ }
+} // module
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
new file mode 100644
index 0000000..065b795
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -0,0 +1,171 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// The configuration used for executable compilation.
+// This specifies the device configurations that support this custom kernel.
+#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
+ >
+}>
+
+module attributes {transform.with_named_sequence} {
+ func.func private @argmax_1d_f32_entry_point(%arg0: tensor<1x?xf32>) -> tensor<1xi64> {
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+ // Note: This is not safe if the dim size exceeds INT32_MAX. To pass a 64
+ // bit value it must be broken down into two 32-bit values for the high and
+ // low bits.
+ %dim_i32 = arith.index_cast %dim : index to i32
+ // Inline external dispatch that conforms to the ABI that the kernel
+ // requires. This is the primary reason for the surrounding function as
+ // details like tensor shape and push constants need to line up after
+ // splicing in the custom dispatch. This allows the kernel author to manage
+ // such details by hand without needing the rewrite patterns to worry about
+ // things like order of push constants.
+ %4 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0) : (i32, tensor<1x?xf32>{%dim}) -> tensor<1xi64>
+ count(%device: !hal.device, %workload: index) -> (index, index, index) {
+ %c1_0 = arith.constant 1 : index
+ hal.return %c1_0, %c1_0, %c1_0 : index, index, index
+ }
+ layout(#hal.pipeline.layout<push_constants = 1, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer>
+ ]>
+ ]>)
+ bindings([
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>
+ ])
+ objects({
+ #spirv_target ordinal(0) = [
+ #hal.executable.object<{
+ path = "samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.spv"
+ }>
+ ]
+ })
+ return %4 : tensor<1xi64>
+ }
+
+ // Custom matcher for argmax operations equivalent to the custom kernel. This
+ // matcher will be run one-by-one on all operations contained within the
+ // target function. On success, it will return the handle to the matched
+ // argmax operation.
+ transform.named_sequence @match_argmax(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+ // Fail fast on non-linalg generics.
+ transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op
+ %matched = transform.match.structured failures(propagate) %generic : (!transform.any_op) -> (!transform.any_op) {
+ ^bb1(%argmax: !transform.any_op):
+ // Verify that the rank (i.e. number of loops) of the linalg op is 2,
+ // with one parallel iterator and one reduction iterator.
+ // TODO: Add optionality for the parallel dimensions.
+ %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+ %rank = transform.match.structured.rank %argmax : (!transform.any_op) -> !transform.param<i64>
+ transform.match.param.cmpi eq %rank, %c2 : !transform.param<i64>
+ transform.match.structured.dim %argmax[0] {parallel} : !transform.any_op
+ transform.match.structured.dim %argmax[-1] {reduction} : !transform.any_op
+
+ // Verify a single input (target vector to compute the argmax of) and two
+ // outputs, one for the maximum value and one for the index.
+ %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+ %n_inputs = transform.match.structured.num_inputs %argmax : (!transform.any_op) -> !transform.param<i64>
+ transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
+ %n_outputs = transform.match.structured.num_inits %argmax : (!transform.any_op) -> !transform.param<i64>
+ transform.match.param.cmpi eq %n_outputs, %c2 : !transform.param<i64>
+
+ transform.match.structured.yield %argmax : !transform.any_op
+ }
+
+ // Verify the operand shapes of the linalg op. For example, in the below,
+ // dim 0 must be statically 1, and dim 1 must be statically divisible by 64.
+ %in0 = transform.get_operand %matched[0] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %in0 = tensor<1x?xf32> : !transform.any_value
+ transform.iree.match.dim_is_multiple_of %in0[1], 64 : !transform.any_value
+ %out0 = transform.get_operand %matched[1] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %out0 = tensor<1xf32> : !transform.any_value
+ %out1 = transform.get_operand %matched[2] : (!transform.any_op) -> !transform.any_value
+ transform.iree.match.cast_compatible_type %out1 = tensor<1xi64> : !transform.any_value
+
+ // Verify the region of the argmax op. This does a structural comparison of
+ // region(s) of the payload operation against the single operation contained
+ // within the body of this operation. This does no verification of other
+ // input types/attributes, because for kernel matching the most important
+ // part to get exactly right is typically the inner loop body. Small
+ // variations in shape information, iterator counts, and the like are better
+ // handled by the more general structured matchers used above.
+ transform.iree.match.regions %matched : !transform.any_op {
+ ^bb0(%target: tensor<1x?xf32>, %empty_max: tensor<1xf32>, %empty_idx: tensor<1xi64>):
+ %5:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%target : tensor<1x?xf32>)
+ outs(%empty_max, %empty_idx : tensor<1xf32>, tensor<1xi64>) {
+ ^bb0(%in: f32, %out: f32, %out_0: i64):
+ %6 = linalg.index 1 : index
+ %7 = arith.index_cast %6 : index to i64
+ %8 = arith.maximumf %in, %out : f32
+ %9 = arith.cmpf ogt, %in, %out : f32
+ %10 = arith.select %9, %7, %out_0 : i64
+ linalg.yield %8, %10 : f32, i64
+ } -> (tensor<1xf32>, tensor<1xi64>)
+ }
+ transform.yield %generic : !transform.any_op
+ }
+
+ // Rewrite callback for `transform.foreach_match`. The input signature for
+ // this sequence must match exactly with the outputs of the matcher. In this
+ // case we just take the argmax as an input, import the entry point for the
+ // custom kernel authored above, and replace the users of the argmax with a
+ // call to the function.
+ transform.named_sequence @cast_and_call_argmax(%argmax: !transform.any_op {transform.readonly}) {
+ %module = transform.iree.get_nearest_symbol_table %argmax : (!transform.any_op) -> !transform.any_op
+ %func = transform.iree.import_symbol @argmax_1d_f32_entry_point into %module : (!transform.any_op) -> !transform.any_op
+ %ins = transform.get_operand %argmax[0] : (!transform.any_op) -> !transform.any_value
+ %outs = transform.get_result %argmax[1] : (!transform.any_op) -> !transform.any_value
+ transform.func.cast_and_call %func(%ins) -> %outs before %argmax {
+ // This specifies how to resolve type mismatches between the arguments
+ // of the function and the inputs to the argmax. In this example, the
+ // only casts this will generate are same-rank tensor casts that drop
+ // static information.
+ transform.type_conversion.tensor.cast_shape_dynamic_dims
+ } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ // Entry point for the transform interpreter, nested on the full module. This
+ // is because the rewrites needed for importing the custom kernel need to
+ // add a new symbol to the module's symbol table.
+ transform.named_sequence @__transform_main(%module: !transform.any_op) {
+ // Gather the set of functions within the module.
+ %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ // For each function in the module, run the matcher on all contained
+ // operations.
+ transform.foreach %funcs : !transform.any_op {
+ ^bb1(%func: !transform.any_op):
+ transform.foreach_match in %func
+ // <matcher name> -> <rewriter name>
+ // Multiple matcher-action pairs can be specified comma separated,
+ // here we are only doing a single kind of match and replace.
+ //
+ // Note that the operations within the module are walked in
+ // post-order, meaning actions must be very careful in their
+ // replacements not to modify successors of operations. Nested
+ // regions and DAG roots will be visited last so it is safest to
+ // do matching + replacement on the root of the DAG rather than
+ // trying to look ahead. The other option is to avoid dce/cse until
+ // after the walk is complete.
+ @match_argmax -> @cast_and_call_argmax
+ : (!transform.any_op) -> (!transform.any_op)
+ }
+ // Clean up the now-dead instances of argmax.
+ transform.apply_dce to %module : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl b/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl
new file mode 100644
index 0000000..c4e1e26
--- /dev/null
+++ b/samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.glsl
@@ -0,0 +1,57 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// `ret = argmax(in)`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+// <0, bindings = [
+// <0, storage_buffer, ReadOnly>,
+// <1, storage_buffer>
+// ]>
+// ]>
+
+#version 450 core
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout(set=0, binding=0) buffer InputBuffer { float data[]; } Input;
+layout(set=0, binding=1) buffer OutputBuffer { uvec2 data; } Output;
+
+layout(push_constant) uniform PushConstants { uint totalCount; }; // Total number of scalars
+
+// Each workgroup contains just one subgroup.
+
+void main() {
+ uint laneID = gl_LocalInvocationID.x;
+ uint laneCount = gl_WorkGroupSize.x;
+
+ float laneMax = Input.data[laneID];
+ uint laneResult = 0;
+
+ // Batch 0 is covered by the initialization above.
+ uint numBatches = totalCount / laneCount;
+ for (uint i = 1; i < numBatches; ++i) {
+ uint idx = laneCount * i + laneID;
+ float new_in = Input.data[idx];
+ laneResult = new_in > laneMax ? idx : laneResult;
+ laneMax = max(laneMax, new_in);
+ }
+
+ // Final reduction with one subgroup
+ float wgMax = subgroupMax(laneMax);
+
+ // Find the smallest thread holding the maximum value.
+ bool eq = wgMax == laneMax;
+ uvec4 ballot = subgroupBallot(eq);
+ uint lsb = subgroupBallotFindLSB(ballot);
+
+ uint upper32bits = 0;
+ if (laneID == lsb) Output.data = uvec2(laneResult, upper32bits);
+}
+