blob: e4f02dac80d944b6254c0a0a7f48f48944f2e174 [file] [log] [blame]
// PDL pattern spec to match an MLP and offload to an external function
//
// ```
// void mlp_external(void *params, void *context, void *reserved)
// ```
//
// which is the expected signature of an external function implemented and
// provided by a system plugin. See
// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
//
// The `params` is the following struct
//
// ```
// struct mlp_params_t {
//   const float *restrict lhs;
//   size_t lhs_offset;
//   const float *restrict rhs;
//   size_t rhs_offset;
//   int32_t M;
//   int32_t N;
//   int32_t K;
//   float *restrict result;
//   size_t result_offset;
// };
// ```
//
// In MLIR this corresponds to the function
//
// ```
// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
//     %M : i32, %N : i32, %K : i32, %result : memref<..xf32>)
// ```
//
// Note: In the above struct a `pointer, offset` pair represents a buffer
// passed into the external function. So any access to `lhs`, `rhs` and
// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
// `rhs[rhs_offset + ...]` and `result[result_offset + ...]`.
pdl.pattern @mlp : benefit(1) {
  // PDL matcher for the MLP computation. It matches an f32 `tosa.matmul`
  // whose result feeds a `tosa.clamp` with zero lower bounds (a ReLU-style
  // clamp), i.e. roughly
  //
  // ```
  // %matmul = "tosa.matmul"(%lhs, %rhs)
  //     : (tensor<..xf32>, tensor<..xf32>) -> tensor<..xf32>
  // %relu = "tosa.clamp"(%matmul) {min_int = 0 : i64, min_fp = 0.0 : f32, ...}
  // ```
  //
  // and rewrites it so that the result of `%relu` is produced by the external
  // function instead, conceptually
  //
  // ```
  // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
  //     %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
  // ```
  //
  // NOTE(review): `max_int` / `max_fp` are captured but not constrained, so
  // any clamp with a zero lower bound matches (e.g. a ReLU6-style clamp), not
  // only an unbounded ReLU. Confirm that `mlp_external` applies the same
  // upper bound, or constrain these attributes.
  %lhs_type = pdl.type
  %lhs = pdl.operand : %lhs_type
  %rhs_type = pdl.type
  %rhs = pdl.operand : %rhs_type
  %matmul_type = pdl.type
  // Only the lower bounds of the clamp are pinned (to 0); the upper bounds
  // may be any attribute.
  %min_int = pdl.attribute = 0 : i64
  %max_int = pdl.attribute
  %min_fp = pdl.attribute = 0.0 : f32
  %max_fp = pdl.attribute
  %matmul = pdl.operation "tosa.matmul"(%lhs, %rhs : !pdl.value, !pdl.value)
      -> (%matmul_type : !pdl.type)
  // `mlp_external` works on `float` buffers, so restrict the operand and
  // result types to f32 (presumably `checkTensorElementType` checks the
  // tensor element type against `%element_type` — it is a native constraint
  // registered elsewhere).
  %element_type = pdl.type : f32
  pdl.apply_native_constraint "checkTensorElementType"(%lhs_type, %element_type : !pdl.type, !pdl.type)
  pdl.apply_native_constraint "checkTensorElementType"(%rhs_type, %element_type : !pdl.type, !pdl.type)
  pdl.apply_native_constraint "checkTensorElementType"(%matmul_type, %element_type : !pdl.type, !pdl.type)
  // The clamp must consume the matmul result directly.
  %matmul_result = pdl.result 0 of %matmul
  %relu_type = pdl.type
  %relu = pdl.operation "tosa.clamp"(%matmul_result : !pdl.value) {
      "min_int" = %min_int, "max_int" = %max_int,
      "min_fp" = %min_fp, "max_fp" = %max_fp}
      -> (%relu_type : !pdl.type)

  // NOTE(review): the rewrite is anchored at `%matmul`, while `%relu` (the
  // last op of the matched chain) is the op handed to the native rewrite.
  pdl.rewrite %matmul {
    // The matcher above captured the inputs `%lhs` and `%rhs` and the op
    // `%relu` whose result gets replaced. The `%M`, `%N` and `%K` operands
    // of `mlp_external` have to be materialized here from the operand shapes.
    %one_val = pdl.attribute = 1 : index
    %two_val = pdl.attribute = 2 : index
    %index_type = pdl.type : index
    %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
    %one = pdl.result 0 of %one_op
    %two_op = pdl.operation "arith.constant" {"value" = %two_val} -> (%index_type : !pdl.type)
    %two = pdl.result 0 of %two_op
    %i32_type = pdl.type : i32

    // The dim indices assume rank-3 `tosa.matmul` operands, lhs = [batch, M, K]
    // and rhs = [batch, K, N]: M = dim(lhs, 1), N = dim(rhs, 2),
    // K = dim(lhs, 2). Result types are spelled out explicitly, like the other
    // ops created in this rewrite, instead of relying on type inference.
    %m_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value) -> (%index_type : !pdl.type)
    %m = pdl.result 0 of %m_op
    %n_op = pdl.operation "tensor.dim"(%rhs, %two : !pdl.value, !pdl.value) -> (%index_type : !pdl.type)
    %n = pdl.result 0 of %n_op
    %k_op = pdl.operation "tensor.dim"(%lhs, %two : !pdl.value, !pdl.value) -> (%index_type : !pdl.type)
    %k = pdl.result 0 of %k_op

    // `M`, `N` and `K` are `int32_t` in `mlp_params_t`, so cast from index.
    %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
    %m_i32 = pdl.result 0 of %m_i32_op
    %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
    %n_i32 = pdl.result 0 of %n_i32_op
    %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
    %k_i32 = pdl.result 0 of %k_i32_op

    // No dynamic dimension values are passed for the replaced result.
    // NOTE(review): this is only correct if the `%relu` result type is fully
    // static; for dynamic shapes the dynamic dims would have to be supplied
    // here (see the `replaced_values_dims` description below).
    %replaced_values_dims = pdl.range : !pdl.range<value>
    %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
    %replaced_value = pdl.result 0 of %relu
    %replaced_values = pdl.range %replaced_value : !pdl.value
    %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value

    // `rewriteAsFlowDispatch` is a native rewrite function that converts the
    // matched DAG into a call to the external function provided by a system
    // plugin. It expects the following arguments:
    // - the root of the matched DAG (here `%relu`, the consumer at the end of
    //   the matched chain). This op will be erased after the call.
    // - `fn_name`, the name of the function that is provided externally
    //   (using a plugin).
    // - `input_values`, the values captured as part of the match that are
    //   inputs to the match.
    // - `replaced_values`, the values captured as part of the match that are
    //   replaced by the `flow.dispatch`. The `flow.dispatch` returns as many
    //   values as `replaced_values` (and of the same types).
    // - `replaced_values_dims`, the values for the dynamic dimensions of all
    //   the `tensor` values in `replaced_values`. For matches that could be
    //   static or dynamic, it should be assumed that the shape is dynamic and
    //   the dimension values need to be passed to the rewrite function.
    // - `other_operands`, same as `input_values`, but kept separate to allow
    //   flexibility in where they are passed through the ABI boundary.
    %fn_name = pdl.attribute = "mlp_external"
    pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
        %relu, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
        : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
  }
}