Add support for using PDL to replicate the functionality in MLP sample that uses Transform dialect. (#16453)
This PR adds a sample that uses PDL to match a subgraph corresponding to
MLP and replaces with a `flow.dispatch`, that invokes an external
function which is provided by a system plugin.
To enable this an new pass `--iree-preprocessing-apply-pdl-patterns` is
added that has an option to read in the PDL pattern file and applies it
to the input program.
To support this a custom rewrite function `rewriteAsFlowDispatch` is
added that takes as arguments
- the root of the matched DAG (this is replaced by the matcher)
- A list of values that represent the dynamic dimensions of the results
of the root
- The name of the external function provided by the plugin
- The operands to the external function.
What is missing is the support to specify the workload and number of
workgroups to use while invoking the external function. This could be
solved by having a custom PDL operation (if possible) that accepts the
workload and a region that computes the number of workgroups based on
the workload. For now that is not handled, and the nubmer of workgroups
is set to `{1, 1, 1}`. This is still a useful thing to
prototype/checkpoint, but for any reasonable deployment this needs to be
fixed.
This PR adds a sample that matches the input in TOSA dialect. Due to the
TOSA dialect definition, the matmul now has a batch dimension as well.
To be possible to use the same plugin implementation, the `llvm.bareptr`
calling convention is used for the external function so that the inputs
(outputs) are passed (passed by reference) using pointer, offset only,
and `memref.extract_strided_metadata` is used to extract this
information from the multi-dimensional memrefs within the dispatch.
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
index 8a57538..2e692a5 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "apply_pdl_patterns_tosa.mlir",
"auto_input_conversion.mlir",
"convert_i48_to_i64.mlir",
"strip_signedness.mlir",
@@ -23,8 +24,14 @@
"verify_compiler_tosa_input_legality.mlir",
],
include = ["*.mlir"],
+ exclude = [
+ "tosa.pdl.mlir",
+ ],
),
cfg = "//compiler:lit.cfg.py",
+ data = [
+ "tosa.pdl.mlir",
+ ],
tools = [
"//tools:iree-compile",
"//tools:iree-opt",
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
index d45111e..c61f067 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "apply_pdl_patterns_tosa.mlir"
"auto_input_conversion.mlir"
"convert_i48_to_i64.mlir"
"strip_signedness.mlir"
@@ -23,6 +24,8 @@
FileCheck
iree-compile
iree-opt
+ DATA
+ tosa.pdl.mlir
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/apply_pdl_patterns_tosa.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/apply_pdl_patterns_tosa.mlir
new file mode 100644
index 0000000..c889013
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/apply_pdl_patterns_tosa.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/tosa.pdl.mlir})" %s | FileCheck %s
+
+// CHECK-LABEL: stream.executable private @mlp_external_executable
+// CHECK: stream.executable.export public @mlp_external_entry_point
+// CHECK: builtin.module
+// CHECK: func.func private @mlp_external
+// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
+// CHECK-SAME: attributes {llvm.bareptr = [true]}
+// CHECK: func.func @mlp_external_entry_point
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x2x4xf32>
+// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM0]]
+// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<1x4x8xf32>
+// CHECK-NEXT: %[[STREAM1_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM1]]
+// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<1x2x8xf32>
+// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM2]]
+// CHECK: call @mlp_external
+// CHECK-SAME: %[[STREAM0_BASE]], %[[C0]], %[[STREAM1_BASE]], %[[C0]], %[[STREAM2_BASE]], %[[C0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+
+// CHECK: func.func @mlp_invocation
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xf32>, %[[ARG1:.+]]: tensor<4x8xf32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
+// CHECK-DAG: %[[LHS:.+]] = tosa.reshape %[[ARG0]]
+// CHECK-DAG: %[[RHS:.+]] = tosa.reshape %[[ARG1]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch
+// CHECK-SAME: @mlp_external_executable::@mlp_external_entry_point
+// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[C2]], %[[C8]], %[[C4]])
+// CHECK: tosa.negate %[[RESULT]]
+
+func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> {
+ %lhs_3D = tosa.reshape %lhs {new_shape = array<i64 : 1, 2, 2>} : (tensor<2x4xf32>) -> tensor<1x2x4xf32>
+ %rhs_3D = tosa.reshape %rhs {new_shape = array<i64 : 1, 2, 2>} : (tensor<4x8xf32>) -> tensor<1x4x8xf32>
+ %0 = tosa.matmul %lhs_3D, %rhs_3D : (tensor<1x2x4xf32>, tensor<1x4x8xf32>) -> tensor<1x2x8xf32>
+ %1 = tosa.clamp %0 {
+ min_int = 0 : i64, max_int = 9223372036854775807 : i64,
+ min_fp = 0.0 : f32, max_fp = 3.4028235e+38 : f32}
+ : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+ %2 = tosa.negate %1 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+ %3 = tosa.reshape %2 {new_shape = array<i64 : 2, 2>} : (tensor<1x2x8xf32>) -> tensor<2x8xf32>
+ return %3 : tensor<2x8xf32>
+}
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/tosa.pdl.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/tosa.pdl.mlir
new file mode 100644
index 0000000..e4f02da
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/tosa.pdl.mlir
@@ -0,0 +1,125 @@
+// PDL pattern spec to match an MLP and offload to an external function
+//
+// ```
+// void mlp_external(void *params, void *context, void *reserved)
+// ```
+//
+// which is the expected signature of an external function implemented
+// provided by a system plugin. See
+// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
+//
+// The `params` is the following struct
+//
+// ```
+// struct mlp_params_t {
+// const float *restrict lhs;
+// size_t lhs_offset;
+// const float *restrict rhs;
+// size_t rhs_offset;
+// int32_t M;
+// int32_t N;
+// int32_t K;
+// float *restrict result;
+// size_t result_offset;
+// };
+// ```
+//
+// In MLIR this corresponds to the function
+//
+// ```
+// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
+// %M: i32, %N : i32, %K : i32, %result : memref<..xf32>)
+// ```
+//
+// Note: In the above struct a `pointer, offset` pair represents a buffer
+// passed into the external function. So any access to `lhs`, `rhs` and
+// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
+// `rhs[rhs_offset + ]` and `result[result_offset + ...]`.
+pdl.pattern @mlp : benefit(1) {
+
+ // PDL matcher to match the MLP computation. This pattern is expected to
+ // match
+ //
+ // ```
+ // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
+ // %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
+ // ```
+ %lhs_type = pdl.type
+ %lhs = pdl.operand : %lhs_type
+ %rhs_type = pdl.type
+ %rhs = pdl.operand : %rhs_type
+ %matmul_type = pdl.type
+ %min_int = pdl.attribute = 0 : i64
+ %max_int = pdl.attribute
+ %min_fp = pdl.attribute = 0.0 : f32
+ %max_fp = pdl.attribute
+ %matmul = pdl.operation "tosa.matmul"(%lhs, %rhs : !pdl.value, !pdl.value)
+ -> (%matmul_type : !pdl.type)
+ %element_type = pdl.type : f32
+ pdl.apply_native_constraint "checkTensorElementType"(%lhs_type, %element_type : !pdl.type, !pdl.type)
+ pdl.apply_native_constraint "checkTensorElementType"(%rhs_type, %element_type : !pdl.type, !pdl.type)
+ pdl.apply_native_constraint "checkTensorElementType"(%matmul_type, %element_type : !pdl.type, !pdl.type)
+
+ %matmul_result = pdl.result 0 of %matmul
+ %relu_type = pdl.type
+ %relu = pdl.operation "tosa.clamp"(%matmul_result : !pdl.value) {
+ "min_int" = %min_int, "max_int" = %max_int,
+ "min_fp" = %min_fp, "max_fp" = %max_fp}
+ -> (%relu_type : !pdl.type)
+
+ pdl.rewrite %matmul {
+ // The pattern above matched `%result`, `%lhs`, `%rhs` needed for the
+ // external function call. The values of `%M`, `%N` and `%K` need to
+ // be generated.
+ %one_val = pdl.attribute = 1 : index
+ %two_val = pdl.attribute = 2 : index
+ %index_type = pdl.type : index
+ %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
+ %one = pdl.result 0 of %one_op
+ %two_op = pdl.operation "arith.constant" {"value" = %two_val} -> (%index_type : !pdl.type)
+ %two = pdl.result 0 of %two_op
+ %i32_type = pdl.type : i32
+ %m_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value)
+ %m = pdl.result 0 of %m_op
+ %n_op = pdl.operation "tensor.dim"(%rhs, %two : !pdl.value, !pdl.value)
+ %n = pdl.result 0 of %n_op
+ %k_op = pdl.operation "tensor.dim"(%lhs, %two : !pdl.value, !pdl.value)
+ %k = pdl.result 0 of %k_op
+ %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
+ %m_i32 = pdl.result 0 of %m_i32_op
+ %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
+ %n_i32 = pdl.result 0 of %n_i32_op
+ %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
+ %k_i32 = pdl.result 0 of %k_i32_op
+
+ %replaced_values_dims = pdl.range : !pdl.range<value>
+ %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
+ %replaced_value = pdl.result 0 of %relu
+ %replaced_values = pdl.range %replaced_value : !pdl.value
+ %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
+
+ // The `rewriteAsFlowDispatch` is a rewrite function that allows
+ // converting the matched dag into a call to the external function call
+ // provided by a system plugin. The rewrite method expects the following
+ // arguments
+ // - the root of the matched DAG. This op will be erased after the call.
+ // - `fn_name` the name of the function that is provided externally
+ // (using a plugin).
+ // - `input_values` are values that are captures as the part of the match
+ // and are inputs to the match.
+ // - `replaced_values` are the values that are captured as part of the
+ // match and are replaced by the `flow.dispatch`. The `flow.dispatch`
+ // returns as many values as `replaced_values` (and of same type).
+ // - `replaced_values_dims` are the values for the dynamic dimensions of
+ // all the `tensor` values in `replaced_values`. For matches that could
+ // be static or dynamic, it should be assumed that the shape is dynamic
+ // and the value needs to be passed to the rewrite function.
+ // - `other_operands` same as `input_values`, but kept separate to allow
+ // flexibility of where the results are passed through the ABI boundary.
+ %fn_name = pdl.attribute = "mlp_external"
+ pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
+ %relu, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
+ : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+ }
+}
+
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt
index 47dad20..6f86276 100644
--- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt
@@ -2,6 +2,7 @@
NAME
lit
SRCS
+ "apply_pdl_patterns_torch.mlir"
"assume_strict_symbols.mlir"
"auto_input_conversion.mlir"
"attention.mlir"
@@ -12,6 +13,8 @@
"scatter.mlir"
"sort.mlir"
"torch_to_iree.mlir"
+ DATA
+ "torch.pdl.mlir"
TOOLS
FileCheck
iree-compile
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/apply_pdl_patterns_torch.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/apply_pdl_patterns_torch.mlir
new file mode 100644
index 0000000..3275308
--- /dev/null
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/apply_pdl_patterns_torch.mlir
@@ -0,0 +1,75 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/torch.pdl.mlir}, cse)" %s | FileCheck %s
+
+// CHECK-LABEL: stream.executable private @mlp_external_executable
+// CHECK: stream.executable.export public @mlp_external_entry_point
+// CHECK: builtin.module
+// CHECK: func.func private @mlp_external
+// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
+// CHECK-SAME: attributes {llvm.bareptr = [true]}
+// CHECK: func.func @mlp_external_entry_point
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG11:[a-zA-Z0-9]+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf32>{%[[ARG6]], %[[ARG7]]}
+// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM0]]
+// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<?x?xf32>{%[[ARG8]], %[[ARG9]]}
+// CHECK-NEXT: %[[STREAM1_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM1]]
+// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<?x?xf32>{%[[ARG10]], %[[ARG11]]}
+// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM2]]
+// CHECK: call @mlp_external
+// CHECK-SAME: %[[STREAM0_BASE]], %[[C0]], %[[STREAM1_BASE]], %[[C0]], %[[STREAM2_BASE]], %[[C0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+
+// CHECK: func.func @mlp_invocation
+// CHECK-SAME: (%[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>, %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1
+// CHECK: %[[M:.+]] = tensor.dim %[[LHS]], %[[C0]]
+// CHECK: %[[N:.+]] = tensor.dim %[[RHS]], %[[C1]]
+// CHECK: %[[K:.+]] = tensor.dim %[[LHS]], %[[C1]]
+// CHECK: %[[M_I32:.+]] = arith.index_cast %[[M]] : index to i32
+// CHECK: %[[N_I32:.+]] = arith.index_cast %[[N]] : index to i32
+// CHECK: %[[K_I32:.+]] = arith.index_cast %[[K]] : index to i32
+// CHECK: %[[K_0:.+]] = tensor.dim %[[RHS]], %[[C0]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch
+// CHECK-SAME: @mlp_external_executable::@mlp_external_entry_point
+// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[M_I32]], %[[N_I32]], %[[K_I32]], %[[M]], %[[K]], %[[K_0]], %[[N]], %[[M]], %[[N]])
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[RESULT]] :
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @mlp_invocation(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+ %dim0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %rhs, %c1 : tensor<?x?xf32>
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %torch_lhs = torch_c.from_builtin_tensor %lhs : tensor<?x?xf32> -> !torch.vtensor<[?, ?], f32>
+ %torch_rhs = torch_c.from_builtin_tensor %rhs : tensor<?x?xf32> -> !torch.vtensor<[?, ?], f32>
+ %mm = torch.aten.mm %torch_lhs, %torch_rhs
+ : !torch.vtensor<[?, ?], f32>, !torch.vtensor<[?, ?], f32> -> !torch.vtensor<[?, ?], f32>
+ %relu = torch.aten.relu %mm : !torch.vtensor<[?, ?], f32> -> !torch.vtensor<[?, ?], f32>
+ %cast= torch_c.to_builtin_tensor %relu : !torch.vtensor<[?, ?], f32> -> tensor<?x?xf32>
+ %negf = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%cast : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.negf %b0 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %negf : tensor<?x?xf32>
+}
diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch.pdl.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch.pdl.mlir
new file mode 100644
index 0000000..2ae7532
--- /dev/null
+++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/torch.pdl.mlir
@@ -0,0 +1,118 @@
+// PDL pattern spec to match an MLP and offload to an external function
+//
+// ```
+// void mlp_external(void *params, void *context, void *reserved)
+// ```
+//
+// which is the expected signature of an external function implemented
+// provided by a system plugin. See
+// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
+//
+// The `params` is the following struct
+//
+// ```
+// struct mlp_params_t {
+// const float *restrict lhs;
+// size_t lhs_offset;
+// const float *restrict rhs;
+// size_t rhs_offset;
+// int32_t M;
+// int32_t N;
+// int32_t K;
+// float *restrict result;
+// size_t result_offset;
+// };
+// ```
+//
+// In MLIR this corresponds to the function
+//
+// ```
+// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
+// %M: i32, %N : i32, %K : i32, %result : memref<..xf32>)
+// ```
+//
+// Note: In the above struct a `pointer, offset` pair represents a buffer
+// passed into the external function. So any access to `lhs`, `rhs` and
+// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
+// `rhs[rhs_offset + ]` and `result[result_offset + ...]`.
+pdl.pattern @mlp : benefit(1) {
+
+ // PDL matcher to match the MLP computation. This pattern is expected to
+ // match
+ //
+ // ```
+ // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
+ // %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
+ // ```
+ %lhs = pdl.operand
+ %rhs = pdl.operand
+ %lhs_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %rhs_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %lhs_torch = pdl.operation "torch_c.from_builtin_tensor"(%lhs : !pdl.value) -> (%lhs_type : !pdl.type)
+ %lhs_val = pdl.result 0 of %lhs_torch
+ %rhs_torch = pdl.operation "torch_c.from_builtin_tensor"(%rhs : !pdl.value) -> (%rhs_type : !pdl.type)
+ %rhs_val = pdl.result 0 of %rhs_torch
+ %matmul_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %matmul = pdl.operation "torch.aten.mm"(%lhs_val, %rhs_val : !pdl.value, !pdl.value) -> (%matmul_type : !pdl.type)
+ %matmul_result = pdl.result 0 of %matmul
+ %relu = pdl.operation "torch.aten.relu"(%matmul_result : !pdl.value) -> (%matmul_type : !pdl.type)
+ %result_type = pdl.type
+ %relu_val = pdl.result 0 of %relu
+ %cast = pdl.operation "torch_c.to_builtin_tensor"(%relu_val : !pdl.value) -> (%result_type : !pdl.type)
+
+ pdl.rewrite %matmul {
+ // The pattern above matched `%result`, `%lhs`, `%rhs` needed for the
+ // external function call. The values of `%M`, `%N` and `%K` need to
+ // be generated.
+ %zero_val = pdl.attribute = 0 : index
+ %one_val = pdl.attribute = 1 : index
+ %index_type = pdl.type : index
+ %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%index_type : !pdl.type)
+ %zero = pdl.result 0 of %zero_op
+ %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
+ %one = pdl.result 0 of %one_op
+ %i32_type = pdl.type : i32
+ %m_op = pdl.operation "tensor.dim"(%lhs, %zero : !pdl.value, !pdl.value)
+ %m = pdl.result 0 of %m_op
+ %n_op = pdl.operation "tensor.dim"(%rhs, %one : !pdl.value, !pdl.value)
+ %n = pdl.result 0 of %n_op
+ %k_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value)
+ %k = pdl.result 0 of %k_op
+ %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
+ %m_i32 = pdl.result 0 of %m_i32_op
+ %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
+ %n_i32 = pdl.result 0 of %n_i32_op
+ %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
+ %k_i32 = pdl.result 0 of %k_i32_op
+
+ %replaced_values_dims = pdl.range %m, %n : !pdl.value, !pdl.value
+ %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
+ %replaced_value = pdl.result 0 of %cast
+ %replaced_values = pdl.range %replaced_value : !pdl.value
+ %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
+
+ // The `rewriteAsFlowDispatch` is a rewrite function that allows
+ // converting the matched dag into a call to the external function call
+ // provided by a system plugin. The rewrite method expects the following
+ // arguments
+ // - the root of the matched DAG. This op will be erased after the call.
+ // - `fn_name` the name of the function that is provided externally
+ // (using a plugin).
+ // - `input_values` are values that are captures as the part of the match
+ // and are inputs to the match.
+ // - `replaced_values` are the values that are captured as part of the
+ // match and are replaced by the `flow.dispatch`. The `flow.dispatch`
+ // returns as many values as `replaced_values` (and of same type).
+ // - `replaced_values_dims` are the values for the dynamic dimensions of
+ // all the `tensor` values in `replaced_values`. For matches that could
+ // be static or dynamic, it should be assumed that the shape is dynamic
+ // and the value needs to be passed to the rewrite function.
+ // - `other_operands` same as `input_values`, but kept separate to allow
+ // flexibility of where the results are passed through the ABI boundary.
+ %fn_name = pdl.attribute = "mlp_external"
+ pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
+ %cast, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
+ : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+ }
+}
+
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 6ad3b9c..7d00257 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1400,8 +1400,10 @@
state.addOperands(resultDims);
state.addAttributes(attributes);
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
- state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
- tiedOperands);
+ if (tiedOperands) {
+ state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
+ tiedOperands);
+ }
state.attributes.erase(getOperandSegmentSizeAttr());
state.addAttribute(getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr({
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp
new file mode 100644
index 0000000..b1a3e0c
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp
@@ -0,0 +1,501 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-preprocessing-apply-pdl-patterns"
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_APPLYPDLPATTERNS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
+} // namespace mlir::iree_compiler::Preprocessing
+
+// Get the `memref` type for a `tensor` type.
+static MemRefType getMemRefTypeFor(RankedTensorType tensorType) {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+}
+
+// Generates the external function call type that corresponds to the
+// matched list.
+static FunctionType getExternalFunctionCallType(MLIRContext *context,
+ Location loc,
+ TypeRange inputTypes,
+ TypeRange resultTypes,
+ TypeRange otherOperandTypes) {
+ SmallVector<Type> externalCallArgTypes;
+ // Conversion from tensor types to call arg types.
+ auto convertTensorTypeToCallArgTypes = [&](RankedTensorType tensorType) {
+ auto memRefType = getMemRefTypeFor(tensorType);
+ externalCallArgTypes.push_back(
+ MemRefType::get(ArrayRef<int64_t>{}, memRefType.getElementType()));
+ externalCallArgTypes.push_back(IndexType::get(context));
+ };
+
+ // Conversion from input type to call arg types.
+ auto convertInputTypeToCallArgTypes = [&](Type inputType) {
+ if (inputType.isIntOrFloat()) {
+ externalCallArgTypes.push_back(inputType);
+ return;
+ }
+
+ auto tensorType = inputType.cast<RankedTensorType>();
+ convertTensorTypeToCallArgTypes(tensorType);
+ return;
+ };
+
+ llvm::for_each(inputTypes, convertInputTypeToCallArgTypes);
+ llvm::for_each(resultTypes, [&](Type t) {
+ convertTensorTypeToCallArgTypes(t.cast<RankedTensorType>());
+ });
+ llvm::for_each(otherOperandTypes, convertInputTypeToCallArgTypes);
+
+ return FunctionType::get(context, externalCallArgTypes,
+ /*results=*/TypeRange{});
+}
+
+// Returns the base pointer and offset from the given binding.
+std::pair<Value, Value>
+getBasePtrAndOffsetForTensor(PatternRewriter &rewriter, Location loc,
+ RankedTensorType tensorType, Value value,
+ Value bindingOffset, ValueRange dynamicDims) {
+ auto memrefType = getMemRefTypeFor(tensorType);
+ Value memrefVal = rewriter.create<IREE::Stream::BindingSubspanOp>(
+ loc, memrefType, value, bindingOffset, dynamicDims);
+ auto extractMetadataOp =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, memrefVal);
+ return std::make_pair<Value, Value>(extractMetadataOp.getResult(0),
+ extractMetadataOp.getResult(1));
+}
+
+// Create the entry point function to marshal IREEs ABI and call the external
+// function.
+static func::FuncOp
+createEntryPointFn(PatternRewriter &rewriter, Operation *rootOp,
+ StringRef entryPointFnName, func::FuncOp externalFn,
+ TypeRange inputTypes, TypeRange resultTypes,
+ TypeRange otherOperandTypes) {
+ MLIRContext *context = rewriter.getContext();
+ Location loc = rootOp->getLoc();
+
+ // The ABI is
+ // - !stream.binding for all tensor type operands
+ // - !stream.binding for all tensor type results
+ // - all scalar operands
+ // - values of dynamic dimensions of all tensor operands and results.
+ SmallVector<Type> entryPointInputTypes;
+ SmallVector<Type> entryPointScalarInputTypes;
+ int64_t totalNumDynamicDims = 0;
+ auto bindingType = IREE::Stream::BindingType::get(context);
+
+ // Method to process tensor types.
+ auto processTensorType = [&](RankedTensorType tensorType) {
+ entryPointInputTypes.push_back(bindingType);
+ totalNumDynamicDims += tensorType.getNumDynamicDims();
+ };
+ // Method to process input types.
+ auto processInputType = [&](Type type) {
+ if (type.isIntOrFloat()) {
+ entryPointScalarInputTypes.push_back(type);
+ return;
+ }
+ auto tensorType = type.cast<RankedTensorType>();
+ processTensorType(tensorType);
+ };
+
+ llvm::for_each(inputTypes, processInputType);
+ llvm::for_each(resultTypes, [&](Type t) {
+ processTensorType(t.cast<RankedTensorType>());
+ });
+ llvm::for_each(otherOperandTypes, processInputType);
+
+ int64_t numTensorOperands = (int64_t)entryPointInputTypes.size();
+ int64_t numScalarOperands = (int64_t)entryPointScalarInputTypes.size();
+ entryPointInputTypes.append(entryPointScalarInputTypes);
+ entryPointInputTypes.append(totalNumDynamicDims, rewriter.getIndexType());
+
+ auto entryPointFnType = FunctionType::get(context, entryPointInputTypes,
+ /*results=*/TypeRange{});
+ auto entryPointFn =
+ rewriter.create<func::FuncOp>(loc, entryPointFnName, entryPointFnType);
+ Region &body = entryPointFn.getBody();
+ SmallVector<Location> locs(entryPointInputTypes.size(), loc);
+ rewriter.createBlock(&body, body.begin(), entryPointInputTypes, locs);
+
+ auto entryPointArgs = entryPointFn.getArguments();
+ auto tensorArgs = entryPointArgs.take_front(numTensorOperands);
+ auto scalarArgs = entryPointArgs.slice(numTensorOperands, numScalarOperands);
+ auto dynamicDimArgs = entryPointArgs.take_back(totalNumDynamicDims);
+ SmallVector<Value> callOperands;
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+ // Method to marshal tensor types into call operands.
+ auto marshalTensorTypes = [&](RankedTensorType tensorType) {
+ int64_t numDynamicDims = tensorType.getNumDynamicDims();
+ auto dynamicDims = dynamicDimArgs.take_front(numDynamicDims);
+ auto [basePtr, offset] = getBasePtrAndOffsetForTensor(
+ rewriter, loc, tensorType, tensorArgs.front(), zero, dynamicDims);
+ callOperands.push_back(basePtr);
+ callOperands.push_back(offset);
+ tensorArgs = tensorArgs.drop_front();
+ dynamicDimArgs = dynamicDimArgs.drop_front(numDynamicDims);
+ };
+ // Method to marshal input types into call operands.
+ auto marshalInputTypes = [&](Type type) {
+ if (type.isIntOrFloat()) {
+ callOperands.push_back(scalarArgs.front());
+ scalarArgs = scalarArgs.drop_front();
+ return;
+ }
+ marshalTensorTypes(type.cast<RankedTensorType>());
+ };
+
+ llvm::for_each(inputTypes, marshalInputTypes);
+ llvm::for_each(resultTypes, [&](Type t) {
+ marshalTensorTypes(t.cast<RankedTensorType>());
+ });
+ llvm::for_each(otherOperandTypes, marshalInputTypes);
+
+ rewriter.create<func::CallOp>(loc, externalFn, callOperands);
+ rewriter.create<func::ReturnOp>(loc, /*operands=*/ValueRange{});
+ return entryPointFn;
+}
+
+// Generate the `hal.executable` that calls into the external function.
+// Return the nested symbol reference to the entry point function generated.
+static SymbolRefAttr
+createStreamExecutableOp(PatternRewriter &rewriter, Operation *rootOp,
+ StringRef externalFnName, TypeRange inputTypes,
+ TypeRange resultTypes, TypeRange otherOperandTypes) {
+ auto moduleOp = rootOp->getParentOfType<ModuleOp>();
+ assert(moduleOp && "found op without surrounding module");
+
+ Block *body = moduleOp.getBody();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(body);
+
+ // Create the hal.executable to marshal calling the external function.
+ Location loc = rootOp->getLoc();
+ std::string executableOpName = externalFnName.str() + "_executable";
+ auto executableOp =
+ rewriter.create<IREE::Stream::ExecutableOp>(loc, executableOpName);
+ executableOp.setPrivate();
+ Block &executableOpBody = executableOp.getBlock();
+ rewriter.setInsertionPointToStart(&executableOpBody);
+
+ // Create the dispatch inner module.
+ auto innerModule = rewriter.create<ModuleOp>(loc);
+ Block *moduleBody = innerModule.getBody();
+ rewriter.setInsertionPointToStart(moduleBody);
+
+ // Create a private function call which is the external function call.
+ MLIRContext *context = rewriter.getContext();
+ FunctionType externalFnCallType = getExternalFunctionCallType(
+ context, loc, inputTypes, resultTypes, otherOperandTypes);
+ func::FuncOp externalFnCall =
+ rewriter.create<func::FuncOp>(loc, externalFnName, externalFnCallType);
+ externalFnCall.setPrivate();
+ externalFnCall->setAttr("llvm.bareptr", rewriter.getBoolArrayAttr(true));
+
+ // Create the executable entry point function.
+ std::string entryPointName = externalFnName.str() + "_entry_point";
+ func::FuncOp entryFn =
+ createEntryPointFn(rewriter, rootOp, entryPointName, externalFnCall,
+ inputTypes, resultTypes, otherOperandTypes);
+
+ // Create the export operation.
+ rewriter.setInsertionPoint(innerModule);
+ auto exportOp = rewriter.create<IREE::Stream::ExecutableExportOp>(
+ loc, entryPointName, FlatSymbolRefAttr::get(context, entryPointName));
+
+ // Create the body of the export operation.
+ // TODO(MaheshRavishankar): This represents the number of workgroups to use.
+ // Ideally this is somehow exposed to the rewrite mechanism to get the
+ // workload and the number of workgroups.
+ Region &exportOpRegion = exportOp.getRegion();
+ Block *exportOpBody =
+ rewriter.createBlock(&exportOpRegion, exportOpRegion.begin());
+ rewriter.setInsertionPointToStart(exportOpBody);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ rewriter.create<IREE::Stream::ReturnOp>(loc, ValueRange{one, one, one});
+ return SymbolRefAttr::get(rewriter.getStringAttr(executableOpName),
+ SymbolRefAttr::get(entryFn));
+}
+
+// Create the `flow.dispatch` op calling into the executable.
+static IREE::Flow::DispatchOp
+createFlowDispatchOp(PatternRewriter &rewriter, SymbolRefAttr exportOp,
+ Operation *rootOp, TypeRange resultTypes,
+ ValueRange resultDynamicDims, ValueRange operands) {
+ Location loc = rootOp->getLoc();
+ SmallVector<Value> operandsVec = llvm::to_vector(operands);
+ SmallVector<Value> operandDynamicDims;
+
+ // Get the dynamic dims for the operands.
+ for (auto operand : operands) {
+ auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
+ if (!tensorType)
+ continue;
+
+ for (auto [index, shape] : llvm::enumerate(tensorType.getShape())) {
+ if (!ShapedType::isDynamic(shape))
+ continue;
+
+ Value dim = rewriter.create<tensor::DimOp>(loc, operand, index);
+ operandDynamicDims.push_back(dim);
+ }
+ }
+
+ // Append all the dynamic dims to the operands.
+ operandsVec.append(operandDynamicDims);
+ operandsVec.append(resultDynamicDims.begin(), resultDynamicDims.end());
+
+ // Insert the `flow.dispatch`.
+ auto dispatchOp = rewriter.create<IREE::Flow::DispatchOp>(
+ loc, exportOp,
+ /*workload=*/ValueRange{}, resultTypes, resultDynamicDims, operandsVec,
+ operandDynamicDims, /*tiedOperands=*/nullptr);
+ return dispatchOp;
+}
+
+// Get the values for dynamic shape of results of `rootOp`.
+static FailureOr<SmallVector<Value>>
+getDynamicResultDims(PatternRewriter &rewriter, ValueRange givenResultDims) {
+ // Prune the given dimensions to get just the dynamic dims.
+ SmallVector<Value> dynamicResultDims;
+ SmallVector<OpFoldResult> mixedValues = getAsOpFoldResult(givenResultDims);
+ for (auto ofr : mixedValues) {
+ auto value = ofr.dyn_cast<Value>();
+ if (!value)
+ continue;
+ dynamicResultDims.push_back(value);
+ }
+ return dynamicResultDims;
+}
+
+// Check that the operand types and result type satisfy some constants
+// - All operands must be scalar type or tensor type.
+// - All results must be tensor type.
+static LogicalResult checkOperandAndResultTypes(Operation *rootOp,
+ TypeRange inputTypes,
+ TypeRange resultTypes,
+ TypeRange otherOperandTypes) {
+ if (llvm::any_of(inputTypes, [](Type type) {
+ return !type.isIntOrFloat() && !type.isa<RankedTensorType>();
+ })) {
+ return rootOp->emitOpError("operand types of external function can be "
+ "`int*`, `float*` or `tensor`");
+ }
+
+ if (llvm::any_of(resultTypes,
+ [](Type type) { return !type.isa<RankedTensorType>(); })) {
+ return rootOp->emitOpError("result types of external function can only be "
+ "`int*`, `float*` or `tensor`s");
+ }
+
+ if (llvm::any_of(otherOperandTypes, [](Type type) {
+ return !type.isIntOrFloat() && !type.isa<RankedTensorType>();
+ })) {
+ return rootOp->emitOpError("operand types of external function can be "
+ "`int*`, `float*` or `tensor`");
+ }
+ return success();
+}
+
+// Constraint function to check that a tensor has a given element type
+static LogicalResult checkTensorElementType(PatternRewriter &rewriter,
+ Type operandType,
+ Type elementType) {
+ auto tensorType = operandType.dyn_cast<RankedTensorType>();
+ return success(tensorType && tensorType.getElementType() == elementType);
+}
+
+// Rewrite function to rewrite a matched DAG into a flow.dispatch. Conceptually,
+// the matched DAG at the tensor level gets replaced by a function
+//
+// ```
+// <results> = <external fn>(<input operands>, <initial value of results>,
+// <other operands>)
+// ```
+//
+// `<other operands>` is handled same as `<input operands>`. The split is to
+// allow freedom for where the result buffers are passed in through the ABI.
+// `<results>` and `<initial values of result>` get tied to the same `memref`.
+// So conceptually, at a `memref` level the DAG gets replaced by
+//
+// ```
+// <external fn>(<input operands>, <result operands in-out>, <other operands>)
+// ```
+//
+// Each buffer object (input or output) is passed as a `pointer, offset` pair
+// and value at location `index` is expected to be accessed as `pointer[offset +
+// index]` (note: `offset` is number of elements)
+//
+//
+// The operands to this are
+// - `rootOp` is the root of the matched DAG. This op will be erased after the
+// call.
+// - `externalFnName` the name of the function that is provided externally
+// (using a plugin).
+// - `inputOperands` are values that are captures as the part of the match
+// and are inputs to the match.
+// - `replacedValues` are the values that are captured as part of the match
+// and are replaced by the `flow.dispatch`. The `flow.dispatch` returns
+// as many values as `replacedValues` (and of same type).
+// - `replacedValuesShape` are the values for the dynamic dimensions of all the
+// `tensor` values in `replacedValues`.
+// For matches that could be static or dynamic, it should be assumed that the
+// shape is dynamic and the value needs to be passed to the rewrite function.
+// - `otherOperands` same as `inputOperands`, but kept separate to allow
+// flexibility of where the
+// results are passed through the ABI boundary.
+static LogicalResult rewriteAsFlowDispatch(
+ PatternRewriter &rewriter, Operation *rootOp, Attribute externalFnName,
+ ValueRange inputOperands, ValueRange replacedValues,
+ ValueRange replacedValuesShapes, ValueRange otherOperands) {
+ auto getType = [](Value v) { return v.getType(); };
+ auto inputTypes = llvm::map_to_vector(inputOperands, getType);
+ SmallVector<Type> resultTypes = llvm::map_to_vector(replacedValues, getType);
+ auto otherOperandTypes = llvm::map_to_vector(otherOperands, getType);
+
+ if (failed(checkOperandAndResultTypes(rootOp, inputTypes, resultTypes,
+ otherOperandTypes))) {
+ return rewriter.notifyMatchFailure(rootOp,
+ "unhandled operand/result types");
+ }
+ StringAttr externalFnNameAttr = dyn_cast<StringAttr>(externalFnName);
+ if (!externalFnNameAttr) {
+ return rewriter.notifyMatchFailure(
+ rootOp, "expected string attribute for external fn name");
+ }
+
+ // Get the dynamic result dimensions.
+ FailureOr<SmallVector<Value>> dynamicResultDims =
+ getDynamicResultDims(rewriter, replacedValuesShapes);
+ if (failed(dynamicResultDims)) {
+ return rewriter.notifyMatchFailure(
+ rootOp, "failed to get dynamic result dimensions");
+ }
+
+ SymbolRefAttr entryPointFnRef =
+ createStreamExecutableOp(rewriter, rootOp, externalFnNameAttr.getValue(),
+ inputTypes, resultTypes, otherOperandTypes);
+
+ SmallVector<Value> operands = llvm::to_vector(inputOperands);
+ operands.append(otherOperands.begin(), otherOperands.end());
+ IREE::Flow::DispatchOp dispatchOp =
+ createFlowDispatchOp(rewriter, entryPointFnRef, rootOp, resultTypes,
+ dynamicResultDims.value(), operands);
+ assert(
+ dispatchOp.getNumResults() == replacedValues.size() &&
+ "expected dispatch op to return replacements for all specified values");
+
+ for (auto [origValue, replacement] :
+ llvm::zip_equal(replacedValues, dispatchOp->getResults())) {
+ rewriter.replaceAllUsesWith(origValue, replacement);
+ }
+ rewriter.eraseOp(rootOp);
+
+ return success();
+}
+
+// Populate patterns from files.
+static LogicalResult
+populatePDLModuleFromFileName(MLIRContext *context, RewritePatternSet &patterns,
+ llvm::StringRef pdlModuleFileName) {
+ std::string errorMessage;
+ auto memoryBuffer = mlir::openInputFile(pdlModuleFileName, &errorMessage);
+ if (!memoryBuffer) {
+ return emitError(FileLineColLoc::get(
+ StringAttr::get(context, pdlModuleFileName), 0, 0))
+ << "failed to open pattern module file: " << errorMessage;
+ }
+ // Tell sourceMgr about this buffer, the parser will pick it up.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+ PDLPatternModule pdlModule =
+ OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+ pdlModule.registerRewriteFunction("rewriteAsFlowDispatch",
+ rewriteAsFlowDispatch);
+ pdlModule.registerConstraintFunction("checkTensorElementType",
+ checkTensorElementType);
+ patterns.insert(std::move(pdlModule));
+ return success();
+}
+
+namespace {
+
+class ApplyPDLPatternsPass
+ : public iree_compiler::Preprocessing::impl::ApplyPDLPatternsBase<
+ ApplyPDLPatternsPass> {
+
+public:
+ using Base::Base;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<arith::ArithDialect, iree_compiler::IREE::Flow::FlowDialect,
+ iree_compiler::IREE::Stream::StreamDialect,
+ iree_compiler::IREE::Util::UtilDialect,
+ memref::MemRefDialect, pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect, tensor::TensorDialect>();
+ }
+
+ LogicalResult initialize(MLIRContext *context) override {
+ if (patternsFile.empty()) {
+ return success();
+ }
+ RewritePatternSet tmpPatterns(context);
+ if (failed(populatePDLModuleFromFileName(context, tmpPatterns,
+ patternsFile))) {
+ return failure();
+ }
+ patterns = std::move(tmpPatterns);
+ return success();
+ }
+
+ void runOnOperation() override {
+ // If there is nothing to do then return.
+ if (!patterns.getPDLByteCode()) {
+ return;
+ }
+
+ // Apply the patterns.
+ auto operation = getOperation();
+ if (failed(applyPatternsAndFoldGreedily(operation, patterns))) {
+ operation->emitOpError("failed to apply patterns specified in ")
+ << patternsFile;
+ return signalPassFailure();
+ }
+ }
+
+private:
+ /// Loaded PDL patterns
+ FrozenRewritePatternSet patterns;
+};
+
+} // namespace
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 7b8c052..ebbc221 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -30,6 +30,7 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
+ "ApplyPDLPatterns.cpp",
"ConvertConv2DToImg2Col.cpp",
"ConvertConvToChannelsLast.cpp",
"InterpreterPass.cpp",
@@ -46,6 +47,8 @@
":PassesIncGen",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
@@ -54,7 +57,12 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:PDLInterpDialect",
+ "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TransformDialect",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4df5ddb..9ed7bf4 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -26,6 +26,7 @@
"Passes.h"
"Passes.h.inc"
SRCS
+ "ApplyPDLPatterns.cpp"
"ConvertConv2DToImg2Col.cpp"
"ConvertConvToChannelsLast.cpp"
"InterpreterPass.cpp"
@@ -42,7 +43,12 @@
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransforms
+ MLIRMemRefDialect
+ MLIRPDLDialect
+ MLIRPDLInterpDialect
+ MLIRParser
MLIRPass
+ MLIRSupport
MLIRTensorDialect
MLIRTensorUtils
MLIRTransformDialect
@@ -50,6 +56,8 @@
MLIRTransforms
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 3eeeb9b..708d4c7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -9,6 +9,14 @@
include "mlir/Pass/PassBase.td"
+def ApplyPDLPatterns : Pass<"iree-preprocessing-apply-pdl-patterns", "ModuleOp"> {
+ let summary = "Parse an input file containing PDL patterns and apply them as patterns";
+ let options = [
+ Option<"patternsFile", "patterns-file", "std::string",
+ /*default=*/"", "File path to file containing PDL patterns to use.">,
+ ];
+}
+
def ConvertConv2DToImg2Col :
Pass<"iree-preprocessing-convert-conv2d-to-img2col", ""> {
let summary = "Convert linalg convolution ops to matmul img2col based implementation";
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index 40638e0..3d96476 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -68,9 +68,8 @@
const PreprocessingOptions &preprocessingOptions,
PipelineExtensions *pipelineExtensions) {
auto pipelineStr = preprocessingOptions.preprocessingPassPipeline;
- if (!preprocessingOptions.preprocessingPassPipeline.empty()) {
- extendWithTextPipeline(passManager,
- preprocessingOptions.preprocessingPassPipeline);
+ if (!pipelineStr.empty()) {
+ extendWithTextPipeline(passManager, pipelineStr);
}
if (pipelineExtensions) {
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt b/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt
index 721717e..2d4221d 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt
+++ b/samples/custom_dispatch/cpu/mlp_plugin/CMakeLists.txt
@@ -37,9 +37,7 @@
)
add_dependencies(iree-sample-deps
- iree_samples_custom_dispatch_cpu_mlp_plugin
- iree_samples_custom_dispatch_cpu_system_plugin)
-
+ iree_samples_custom_dispatch_cpu_mlp_plugin)
iree_lit_test_suite(
NAME
@@ -51,6 +49,42 @@
iree-compile
iree-run-module
iree_samples_custom_dispatch_cpu_mlp_plugin
+ DATA
+ "mlp_spec.mlir"
+ LABELS
+ "driver=local-sync"
+ "hostonly"
+)
+
+iree_lit_test_suite(
+ NAME
+ mlp_tosa_example
+ SRCS
+ "mlp_tosa.mlir"
+ TOOLS
+ FileCheck
+ iree-compile
+ iree-run-module
+ iree_samples_custom_dispatch_cpu_mlp_plugin
+ DATA
+ "mlp_tosa_spec.pdl.mlir"
+ LABELS
+ "driver=local-sync"
+ "hostonly"
+)
+
+iree_lit_test_suite(
+ NAME
+ mlp_torch_example
+ SRCS
+ "mlp_torch.mlir"
+ TOOLS
+ FileCheck
+ iree-compile
+ iree-run-module
+ iree_samples_custom_dispatch_cpu_mlp_plugin
+ DATA
+ "mlp_torch_spec.pdl.mlir"
LABELS
"driver=local-sync"
"hostonly"
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
index fbe68f2..3d99d26 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
@@ -4,7 +4,7 @@
// RUN: --module=- \
// RUN: --function=mlp_invocation \
// RUN: --input="2x2xf32=[[2.0, 2.0], [-2.0, -2.0]]" \
-// RUN: --input="2x2xf32=[[3.0 -3.0], [3.0, -3.0]]"
+// RUN: --input="2x2xf32=[[3.0, -3.0], [3.0, -3.0]]"
// The implementation of MLP is matched using a transform dialect script and is forwarded to a system plugin.
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c b/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
index 35b22e4..c895a9c 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
@@ -25,9 +25,8 @@
} mlp_plugin_t;
// Helper function to resolve index [i][j] into location for given strides.
-size_t get_index(size_t i, size_t j, size_t offset, size_t stride0,
- size_t stride1) {
- return offset + i * stride0 + j * stride1;
+static size_t get_index(size_t i, size_t j, size_t offset, size_t stride) {
+ return offset + i * stride + j;
}
// `ret = mlp(lhs, rhs)`
@@ -56,26 +55,11 @@
mlp_plugin_t* plugin = (mlp_plugin_t*)context;
typedef struct {
const float* restrict lhs;
- const float* restrict lhs_aligned;
size_t lhs_offset;
- size_t lhs_size0;
- size_t lhs_size1;
- size_t lhs_stride0;
- size_t lhs_stride1;
const float* restrict rhs;
- const float* restrict rhs_aligned;
size_t rhs_offset;
- size_t rhs_size0;
- size_t rhs_size1;
- size_t rhs_stride0;
- size_t rhs_stride1;
float* restrict result;
- float* restrict result_aligned;
size_t result_offset;
- size_t result_size0;
- size_t result_size1;
- size_t result_stride0;
- size_t result_stride1;
int32_t M;
int32_t N;
int32_t K;
@@ -87,16 +71,14 @@
for (int32_t j = 0; j < params->N; j++) {
float curr_result = 0.0;
for (int32_t k = 0; k < params->K; k++) {
- size_t lhs_index = get_index(i, k, params->lhs_offset,
- params->lhs_stride0, params->lhs_stride1);
- size_t rhs_index = get_index(k, j, params->rhs_offset,
- params->rhs_stride0, params->rhs_stride1);
+ size_t lhs_index =
+ get_index(i, k, params->lhs_offset, (size_t)params->K);
+ size_t rhs_index =
+ get_index(k, j, params->rhs_offset, (size_t)params->N);
curr_result += params->lhs[lhs_index] * params->rhs[rhs_index];
}
curr_result = curr_result < 0.0 ? 0.0 : curr_result;
- size_t result_index =
- get_index(i, j, params->result_offset, params->result_stride0,
- params->result_stride1);
+ size_t result_index = get_index(i, j, params->result_offset, params->N);
params->result[result_index] = curr_result;
}
}
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
index 58f3f15..24a1046 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_spec.mlir
@@ -2,51 +2,29 @@
// an implementation implemented by a system plugin.
// Is used along with samples/custom_dispatch/cpu/plugin/mlp.mlir
-#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
- data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
- native_vector_size = 32 : index,
- target_triple = "x86_64-none-elf"
-}>
-
-#cpu_target = #hal.device.target<"llvm-cpu", {
- executable_targets = [
- #x86_64_target
- ]
-}>
-
module attributes {transform.with_named_sequence} {
// Executable that stages call to the external functions.
- hal.executable private @executable {
- hal.executable.variant public @x86_64 target(#x86_64_target) {
- hal.executable.export public @mlp ordinal(0)
- layout(#hal.pipeline.layout<push_constants = 3, sets = [
- <0, bindings = [
- <0, storage_buffer, ReadOnly>,
- <1, storage_buffer, ReadOnly>,
- <2, storage_buffer>
- ]>
- ]>) {
- ^bb0(%device : !hal.device):
- %c1 = arith.constant 1 : index
- hal.return %c1, %c1, %c1 : index, index, index
- }
- builtin.module {
- func.func private @mlp_external(%lhs : memref<?x?xf32>, %rhs : memref<?x?xf32>, %result : memref<?x?xf32>, %m : i32, %n : i32, %k : i32)
- func.func @mlp() {
- %m_i32 = hal.interface.constant.load[0] : i32
- %n_i32 = hal.interface.constant.load[1] : i32
- %k_i32 = hal.interface.constant.load[2] : i32
- %c0 = arith.constant 0 : index
- %m = arith.index_cast %m_i32 : i32 to index
- %n = arith.index_cast %n_i32 : i32 to index
- %k = arith.index_cast %k_i32 : i32 to index
- %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %k}
- %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%k, %n}
- %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %n}
- func.call @mlp_external(%lhs, %rhs, %result, %m_i32, %n_i32, %k_i32) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, i32, i32, i32) -> ()
- return
- }
+ stream.executable private @executable {
+ stream.executable.export public @mlp workgroups() -> (index, index, index) {
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func private @mlp_external(%lhs : memref<f32>, %lhs_offset : index, %rhs : memref<f32>, %rhs_offset : index, %result : memref<f32>, %result_offset : index, %m : i32, %n : i32, %k : i32) attributes {llvm.bareptr}
+ func.func @mlp(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: i32, %arg4: i32, %arg5 : i32) {
+ %c0 = arith.constant 0 : index
+ %m = arith.index_cast %arg3 : i32 to index
+ %n = arith.index_cast %arg4 : i32 to index
+ %k = arith.index_cast %arg5 : i32 to index
+ %lhs = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref<?x?xf32>{%m, %k}
+ %rhs = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref<?x?xf32>{%k, %n}
+ %result = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref<?x?xf32>{%m, %n}
+ %p0, %o0, %s00, %s01, %t00, %t01 = memref.extract_strided_metadata %lhs : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
+ %p1, %o1, %s10, %s11, %t10, %t11 = memref.extract_strided_metadata %rhs : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
+ %p2, %o2, %s20, %s21, %t20, %t21 = memref.extract_strided_metadata %result : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
+ func.call @mlp_external(%p0, %o0, %p1, %o1, %p2, %o2, %arg3, %arg4, %arg5) : (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32) -> ()
+ return
}
}
}
@@ -61,7 +39,8 @@
%n_i32 = arith.index_cast %n : index to i32
%k_i32 = arith.index_cast %k : index to i32
- %mlp_result = flow.dispatch @executable::@x86_64::@mlp(%lhs, %rhs, %m_i32, %n_i32, %k_i32) : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}
+ %mlp_result = flow.dispatch @executable::@mlp(%lhs, %rhs, %m_i32, %n_i32, %k_i32)
+ : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}
util.return %mlp_result : tensor<?x?xf32>
}
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir
new file mode 100644
index 0000000..42a9d31
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir
@@ -0,0 +1,81 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/mlp_torch_spec.pdl.mlir})" %s | \
+// RUN: iree-compile - | \
+// RUN: iree-run-module --device=local-sync \
+// RUN: --executable_plugin=$IREE_BINARY_DIR/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin$IREE_DYLIB_EXT \
+// RUN: --module=- \
+// RUN: --function=mlp_invocation \
+// RUN: --input="2x4xf32=[[2.0, 2.0, 2.0, 2.0], [-2.0, -2.0, -2.0, -2.0]]" \
+// RUN: --input="4x8xf32=[[3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0]]"
+
+// Rewrite function to rewrite a matched DAG into a flow.dispatch. Conceptually,
+// the matched DAG at the tensor level gets replaced by a function
+//
+// ```
+// <results> = <external fn>(<input operands>, <initial value of results>,
+// <other operands>)
+// ```
+//
+// `<other operands>` is handled same as `<input operands>`. The split is to
+// allow freedom for where the result buffers are passed in through the ABI.
+// `<results>` and `<initial values of result>` get tied to the same `memref`.
+// So conceptually, at a `memref` level the DAG gets replaced by
+//
+// ```
+// <external fn>(<input operands>, <result operands in-out>, <other operands>)
+// ```
+//
+// Each buffer object (input or output) is passed as a `pointer, offset` pair
+// and value at location `index` is expected to be accessed as `pointer[offset +
+// index]` (note: `offset` is number of elements)
+
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 32 : index,
+ target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support llvm-cpu here.
+#cpu_target = #hal.device.target<"llvm-cpu", {
+ executable_targets = [
+ #x86_64_target
+ ]
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module @example attributes {hal.device.targets = [#cpu_target]} {
+
+ // CHECK-LABEL: EXEC @mlp_invocation
+ // CHECK: [Plugin]: M = 2, N = 8, K = 4
+ // CHECK: 2x8xf32=[-24 -0 -24 -0 -24 -0 -24 -0][-0 -24 -0 -24 -0 -24 -0 -24]
+ func.func @mlp_invocation(%lhs: tensor<?x?xf32>,
+ %rhs: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+ %dim0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %rhs, %c1 : tensor<?x?xf32>
+ %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+ %torch_lhs = torch_c.from_builtin_tensor %lhs : tensor<?x?xf32> -> !torch.vtensor<[?, ?], f32>
+ %torch_rhs = torch_c.from_builtin_tensor %rhs : tensor<?x?xf32> -> !torch.vtensor<[?, ?], f32>
+ %mm = torch.aten.mm %torch_lhs, %torch_rhs
+ : !torch.vtensor<[?, ?], f32>, !torch.vtensor<[?, ?], f32> -> !torch.vtensor<[?, ?], f32>
+ %relu = torch.aten.relu %mm : !torch.vtensor<[?, ?], f32> -> !torch.vtensor<[?, ?], f32>
+ %cast= torch_c.to_builtin_tensor %relu : !torch.vtensor<[?, ?], f32> -> tensor<?x?xf32>
+ %negf = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%cast : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %0 = arith.negf %b0 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %negf : tensor<?x?xf32>
+ }
+} // module
+
+// CHECK-LABEL: EXEC @mlp_invocation
+// CHECK: [Plugin]: M = 2, N = 8, K = 4
+// CHECK: 2x8xf32=[-24 -0 -24 -0 -24 -0 -24 -0][-0 -24 -0 -24 -0 -24 -0 -24]
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch_spec.pdl.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch_spec.pdl.mlir
new file mode 100644
index 0000000..2ae7532
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch_spec.pdl.mlir
@@ -0,0 +1,118 @@
+// PDL pattern spec to match an MLP and offload to an external function
+//
+// ```
+// void mlp_external(void *params, void *context, void *reserved)
+// ```
+//
+// which is the expected signature of an external function implemented
+// provided by a system plugin. See
+// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
+//
+// The `params` is the following struct
+//
+// ```
+// struct mlp_params_t {
+// const float *restrict lhs;
+// size_t lhs_offset;
+// const float *restrict rhs;
+// size_t rhs_offset;
+// int32_t M;
+// int32_t N;
+// int32_t K;
+// float *restrict result;
+// size_t result_offset;
+// };
+// ```
+//
+// In MLIR this corresponds to the function
+//
+// ```
+// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
+// %M: i32, %N : i32, %K : i32, %result : memref<..xf32>)
+// ```
+//
+// Note: In the above struct a `pointer, offset` pair represents a buffer
+// passed into the external function. So any access to `lhs`, `rhs` and
+// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
+// `rhs[rhs_offset + ]` and `result[result_offset + ...]`.
+pdl.pattern @mlp : benefit(1) {
+
+ // PDL matcher to match the MLP computation. This pattern is expected to
+ // match
+ //
+ // ```
+ // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
+ // %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
+ // ```
+ %lhs = pdl.operand
+ %rhs = pdl.operand
+ %lhs_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %rhs_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %lhs_torch = pdl.operation "torch_c.from_builtin_tensor"(%lhs : !pdl.value) -> (%lhs_type : !pdl.type)
+ %lhs_val = pdl.result 0 of %lhs_torch
+ %rhs_torch = pdl.operation "torch_c.from_builtin_tensor"(%rhs : !pdl.value) -> (%rhs_type : !pdl.type)
+ %rhs_val = pdl.result 0 of %rhs_torch
+ %matmul_type = pdl.type : !torch.vtensor<[?,?],f32>
+ %matmul = pdl.operation "torch.aten.mm"(%lhs_val, %rhs_val : !pdl.value, !pdl.value) -> (%matmul_type : !pdl.type)
+ %matmul_result = pdl.result 0 of %matmul
+ %relu = pdl.operation "torch.aten.relu"(%matmul_result : !pdl.value) -> (%matmul_type : !pdl.type)
+ %result_type = pdl.type
+ %relu_val = pdl.result 0 of %relu
+ %cast = pdl.operation "torch_c.to_builtin_tensor"(%relu_val : !pdl.value) -> (%result_type : !pdl.type)
+
+ pdl.rewrite %matmul {
+ // The pattern above matched `%result`, `%lhs`, `%rhs` needed for the
+ // external function call. The values of `%M`, `%N` and `%K` need to
+ // be generated.
+ %zero_val = pdl.attribute = 0 : index
+ %one_val = pdl.attribute = 1 : index
+ %index_type = pdl.type : index
+ %zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%index_type : !pdl.type)
+ %zero = pdl.result 0 of %zero_op
+ %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
+ %one = pdl.result 0 of %one_op
+ %i32_type = pdl.type : i32
+ %m_op = pdl.operation "tensor.dim"(%lhs, %zero : !pdl.value, !pdl.value)
+ %m = pdl.result 0 of %m_op
+ %n_op = pdl.operation "tensor.dim"(%rhs, %one : !pdl.value, !pdl.value)
+ %n = pdl.result 0 of %n_op
+ %k_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value)
+ %k = pdl.result 0 of %k_op
+ %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
+ %m_i32 = pdl.result 0 of %m_i32_op
+ %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
+ %n_i32 = pdl.result 0 of %n_i32_op
+ %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
+ %k_i32 = pdl.result 0 of %k_i32_op
+
+ %replaced_values_dims = pdl.range %m, %n : !pdl.value, !pdl.value
+ %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
+ %replaced_value = pdl.result 0 of %cast
+ %replaced_values = pdl.range %replaced_value : !pdl.value
+ %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
+
+ // The `rewriteAsFlowDispatch` is a rewrite function that allows
+ // converting the matched dag into a call to the external function call
+ // provided by a system plugin. The rewrite method expects the following
+ // arguments
+ // - the root of the matched DAG. This op will be erased after the call.
+ // - `fn_name` the name of the function that is provided externally
+ // (using a plugin).
+ // - `input_values` are values that are captures as the part of the match
+ // and are inputs to the match.
+ // - `replaced_values` are the values that are captured as part of the
+ // match and are replaced by the `flow.dispatch`. The `flow.dispatch`
+ // returns as many values as `replaced_values` (and of same type).
+ // - `replaced_values_dims` are the values for the dynamic dimensions of
+ // all the `tensor` values in `replaced_values`. For matches that could
+ // be static or dynamic, it should be assumed that the shape is dynamic
+ // and the value needs to be passed to the rewrite function.
+ // - `other_operands` same as `input_values`, but kept separate to allow
+ // flexibility of where the results are passed through the ABI boundary.
+ %fn_name = pdl.attribute = "mlp_external"
+ pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
+ %cast, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
+ : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+ }
+}
+
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir
new file mode 100644
index 0000000..10d418e
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir
@@ -0,0 +1,63 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/mlp_tosa_spec.pdl.mlir})" %s | \
+// RUN: iree-compile - | \
+// RUN: iree-run-module --device=local-sync \
+// RUN: --executable_plugin=$IREE_BINARY_DIR/samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin$IREE_DYLIB_EXT \
+// RUN: --module=- \
+// RUN: --function=mlp_invocation \
+// RUN: --input="2x4xf32=[[2.0, 2.0, 2.0, 2.0], [-2.0, -2.0, -2.0, -2.0]]" \
+// RUN: --input="4x8xf32=[[3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0]]"
+
+// Rewrite function to rewrite a matched DAG into a flow.dispatch. Conceptually,
+// the matched DAG at the tensor level gets replaced by a function
+//
+// ```
+// <results> = <external fn>(<input operands>, <initial value of results>,
+// <other operands>)
+// ```
+//
+// `<other operands>` is handled same as `<input operands>`. The split is to
+// allow freedom for where the result buffers are passed in through the ABI.
+// `<results>` and `<initial values of result>` get tied to the same `memref`.
+// So conceptually, at a `memref` level the DAG gets replaced by
+//
+// ```
+// <external fn>(<input operands>, <result operands in-out>, <other operands>)
+// ```
+//
+// Each buffer object (input or output) is passed as a `pointer, offset` pair
+// and value at location `index` is expected to be accessed as `pointer[offset +
+// index]` (note: `offset` is number of elements)
+
+#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 32 : index,
+ target_triple = "x86_64-none-elf"
+}>
+
+// The target devices that the program will run on. We can compile and run with
+// multiple targets, but this example is maintaining an implicit requirement
+// that the custom kernel being spliced in is supported by the target device,
+// hence we only support llvm-cpu here.
+#cpu_target = #hal.device.target<"llvm-cpu", {
+ executable_targets = [
+ #x86_64_target
+ ]
+}>
+
+module @example attributes {hal.device.targets = [#cpu_target]} {
+ func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> {
+ %lhs_3D = tosa.reshape %lhs {new_shape = array<i64 : 1, 2, 2>} : (tensor<2x4xf32>) -> tensor<1x2x4xf32>
+ %rhs_3D = tosa.reshape %rhs {new_shape = array<i64 : 1, 2, 2>} : (tensor<4x8xf32>) -> tensor<1x4x8xf32>
+ %0 = tosa.matmul %lhs_3D, %rhs_3D : (tensor<1x2x4xf32>, tensor<1x4x8xf32>) -> tensor<1x2x8xf32>
+ %1 = tosa.clamp %0 {
+ min_int = 0 : i64, max_int = 9223372036854775807 : i64,
+ min_fp = 0.0 : f32, max_fp = 3.4028235e+38 : f32}
+ : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+ %2 = tosa.negate %1 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+ %3 = tosa.reshape %2 {new_shape = array<i64 : 2, 2>} : (tensor<1x2x8xf32>) -> tensor<2x8xf32>
+ return %3 : tensor<2x8xf32>
+ }
+}
+// CHECK-LABEL: EXEC @mlp_invocation
+// CHECK: [Plugin]: M = 2, N = 8, K = 4
+// CHECK: 2x8xf32=[-24 -0 -24 -0 -24 -0 -24 -0][-0 -24 -0 -24 -0 -24 -0 -24]
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa_spec.pdl.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa_spec.pdl.mlir
new file mode 100644
index 0000000..cf9afb8
--- /dev/null
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa_spec.pdl.mlir
@@ -0,0 +1,125 @@
+// PDL pattern spec to match an MLP and offload to an external function
+//
+// ```
+// void mlp_external(void *params, void *context, void *reserved)
+// ```
+//
+// which is the expected signature of an external function implemented
+// provided by a system plugin. See
+// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
+//
+// The `params` is the following struct
+//
+// ```
+// struct mlp_params_t {
+// const float *restrict lhs;
+// size_t lhs_offset;
+// const float *restrict rhs;
+// size_t rhs_offset;
+// int32_t M;
+// int32_t N;
+// int32_t K;
+// float *restrict result;
+// size_t result_offset;
+// };
+// ```
+//
+// In MLIR this corresponds to the function
+//
+// ```
+// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
+// %M: i32, %N : i32, %K : i32, %result : memref<..xf32>)
+// ```
+//
+// Note: In the above struct a `pointer, offset` pair represents a buffer
+// passed into the external function. So any access to `lhs`, `rhs` and
+// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
+// `rhs[rhs_offset + ]` and `result[result_offset + ...]`.
+pdl.pattern @mlp : benefit(1) {
+
+ // PDL matcher to match the MLP computation. This pattern is expected to
+ // match
+ //
+ // ```
+ // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
+ // %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
+ // ```
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %matmul_type = pdl.type
+ %min_int = pdl.attribute = 0 : i64
+ %max_int = pdl.attribute
+ %min_fp = pdl.attribute = 0.0 : f32
+ %max_fp = pdl.attribute
+ %matmul = pdl.operation "tosa.matmul"(%lhs, %rhs : !pdl.value, !pdl.value)
+ -> (%matmul_type : !pdl.type)
+ %element_type = pdl.type : f32
+ pdl.apply_native_constraint "checkTensorElementType"(%lhs_type, %element_type : !pdl.type, !pdl.type)
+ pdl.apply_native_constraint "checkTensorElementType"(%rhs_type, %element_type : !pdl.type, !pdl.type)
+ pdl.apply_native_constraint "checkTensorElementType"(%matmul_type, %element_type : !pdl.type, !pdl.type)
+ %matmul_result = pdl.result 0 of %matmul
+ %relu_type = pdl.type
+ %relu = pdl.operation "tosa.clamp"(%matmul_result : !pdl.value) {
+ "min_int" = %min_int, "max_int" = %max_int,
+ "min_fp" = %min_fp, "max_fp" = %max_fp}
+ -> (%relu_type : !pdl.type)
+
+ pdl.rewrite %matmul {
+ // The pattern above matched `%result`, `%lhs`, `%rhs` needed for the
+ // external function call. The values of `%M`, `%N` and `%K` need to
+ // be generated.
+ %one_val = pdl.attribute = 1 : index
+ %two_val = pdl.attribute = 2 : index
+ %index_type = pdl.type : index
+ %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
+ %one = pdl.result 0 of %one_op
+ %two_op = pdl.operation "arith.constant" {"value" = %two_val} -> (%index_type : !pdl.type)
+ %two = pdl.result 0 of %two_op
+ %i32_type = pdl.type : i32
+ %m_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value)
+ %m = pdl.result 0 of %m_op
+ %n_op = pdl.operation "tensor.dim"(%rhs, %two : !pdl.value, !pdl.value)
+ %n = pdl.result 0 of %n_op
+ %k_op = pdl.operation "tensor.dim"(%lhs, %two : !pdl.value, !pdl.value)
+ %k = pdl.result 0 of %k_op
+ %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
+ %m_i32 = pdl.result 0 of %m_i32_op
+ %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
+ %n_i32 = pdl.result 0 of %n_i32_op
+ %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
+ %k_i32 = pdl.result 0 of %k_i32_op
+
+ %replaced_values_dims = pdl.range : !pdl.range<value>
+ %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
+ %replaced_value = pdl.result 0 of %relu
+ %replaced_values = pdl.range %replaced_value : !pdl.value
+ %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
+
+ // The `rewriteAsFlowDispatch` is a rewrite function that allows
+ // converting the matched dag into a call to the external function call
+ // provided by a system plugin. The rewrite method expects the following
+ // arguments
+ // - the root of the matched DAG. This op will be erased after the call.
+ // - `fn_name` the name of the function that is provided externally
+ // (using a plugin).
+ // - `input_values` are values that are captures as the part of the match
+ // and are inputs to the match.
+ // - `replaced_values` are the values that are captured as part of the
+ // match and are replaced by the `flow.dispatch`. The `flow.dispatch`
+ // returns as many values as `replaced_values` (and of same type).
+ // - `replaced_values_dims` are the values for the dynamic dimensions of
+ // all the `tensor` values in `replaced_values`. For matches that could
+ // be static or dynamic, it should be assumed that the shape is dynamic
+ // and the value needs to be passed to the rewrite function.
+ // - `other_operands` same as `input_values`, but kept separate to allow
+ // flexibility of where the results are passed through the ABI boundary.
+ %fn_name = pdl.attribute = "mlp_external"
+ pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
+ %relu, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
+ : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+ }
+}
+