blob: 58f3f151672a136221fdc8cc234c6a14e3df7a68 [file] [log] [blame]
// Sample spec that matches an MLP example and forwards to
// an implementation implemented by a system plugin.
// Is used along with samples/custom_dispatch/cpu/plugin/mlp.mlir
#x86_64_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
native_vector_size = 32 : index,
target_triple = "x86_64-none-elf"
}>
#cpu_target = #hal.device.target<"llvm-cpu", {
executable_targets = [
#x86_64_target
]
}>
module attributes {transform.with_named_sequence} {
// Executable that stages call to the external functions.
hal.executable private @executable {
hal.executable.variant public @x86_64 target(#x86_64_target) {
hal.executable.export public @mlp ordinal(0)
layout(#hal.pipeline.layout<push_constants = 3, sets = [
<0, bindings = [
<0, storage_buffer, ReadOnly>,
<1, storage_buffer, ReadOnly>,
<2, storage_buffer>
]>
]>) {
^bb0(%device : !hal.device):
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}
builtin.module {
func.func private @mlp_external(%lhs : memref<?x?xf32>, %rhs : memref<?x?xf32>, %result : memref<?x?xf32>, %m : i32, %n : i32, %k : i32)
func.func @mlp() {
%m_i32 = hal.interface.constant.load[0] : i32
%n_i32 = hal.interface.constant.load[1] : i32
%k_i32 = hal.interface.constant.load[2] : i32
%c0 = arith.constant 0 : index
%m = arith.index_cast %m_i32 : i32 to index
%n = arith.index_cast %n_i32 : i32 to index
%k = arith.index_cast %k_i32 : i32 to index
%lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %k}
%rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%k, %n}
%result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%m, %n}
func.call @mlp_external(%lhs, %rhs, %result, %m_i32, %n_i32, %k_i32) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, i32, i32, i32) -> ()
return
}
}
}
}
util.func private @call_mlp(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>, %init1 : tensor<?x?xf32>, %init2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%m = tensor.dim %lhs, %c0 : tensor<?x?xf32>
%n = tensor.dim %rhs, %c1 : tensor<?x?xf32>
%k = tensor.dim %lhs, %c1 : tensor<?x?xf32>
%m_i32 = arith.index_cast %m : index to i32
%n_i32 = arith.index_cast %n : index to i32
%k_i32 = arith.index_cast %k : index to i32
%mlp_result = flow.dispatch @executable::@x86_64::@mlp(%lhs, %rhs, %m_i32, %n_i32, %k_i32) : (tensor<?x?xf32>{%m, %k}, tensor<?x?xf32>{%k, %n}, i32, i32, i32) -> tensor<?x?xf32>{%m, %n}
util.return %mlp_result : tensor<?x?xf32>
}
transform.named_sequence @match_mlp(%root: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_value) {
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
^bb0(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init1 : tensor<?x?xf32>, %init2 : tensor<?x?xf32>):
%cst = arith.constant 0.0 : f32
%fill = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%matmul = linalg.matmul
ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%relu = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%matmul : tensor<?x?xf32>)
outs(%init2 : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.maximumf %b0, %cst : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
transform.yield %ins, %outs : !transform.any_value, !transform.any_value
}
// Rewrite callback for `transform.foreach_match`. The input signature for
// this sequence must match exactly with the outputs of the matcher. In this
// case the matcher returns the inputs and outputs to the matched dag directly
// so we just insert a call to the hand authored function above.
transform.named_sequence @cast_and_call_dag(%ins: !transform.any_value {transform.readonly},
%out: !transform.any_value {transform.readonly}) {
%root = transform.get_defining_op %out : (!transform.any_value) -> !transform.any_op
%module = transform.util.get_nearest_symbol_table %root : (!transform.any_op) -> !transform.any_op
%executable = transform.util.import_symbol @executable into %module if undefined : (!transform.any_op) -> !transform.any_op
%func = transform.util.import_symbol @call_mlp into %module if undefined : (!transform.any_op) -> !transform.any_op
transform.util.cast_and_call %func(%ins) -> %out after %root {
// This specifies how to resolve type mismatches between the arguments
// of the function and the inputs from the matcher. In this example,
// the only casts this will generate are same-rank tensor casts that
// drop static information.
transform.type_conversion.tensor.cast_shape_dynamic_dims
} : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
transform.yield
}
// Entry point for the transform interpreter, nested on the full module. This
// is because the rewrites needed for importing the custom kernel needs to
// add a new symbol to the module's symbol table.
transform.named_sequence @__transform_main(%module: !transform.any_op) {
// Gather the set of functions within the module.
%funcs = transform.structured.match ops{["util.func"]} in %module : (!transform.any_op) -> !transform.any_op
// For each function in the module, run the matcher on all contained
// operations.
transform.foreach %funcs : !transform.any_op {
^bb1(%func: !transform.any_op):
transform.foreach_match in %func
// <matcher name> -> <rewriter name>
// Multiple matcher-action pairs can be specified comma separated,
// here we are only doing a single kind of match and replace.
@match_mlp -> @cast_and_call_dag
: (!transform.any_op) -> (!transform.any_op)
}
// Cleanup leftover dead code; cast_and_call does not do replacement, only
// rewires uses.
transform.apply_dce to %module : !transform.any_op
transform.yield
}
}