blob: 3eafa2ea56a3f31f3c2143ff897edcfb87aeeb9b [file] [log] [blame]
// RUN: iree-opt %s
// Dispatch softmax.
transform.sequence failures(propagate) {
^bb1(%payload_root: !transform.any_op):
  // Step 1: match every linalg.fill / linalg.generic under the payload root
  // and split the resulting handle into six individual handles. The split
  // requests exactly six results, so the match is expected to hold six ops
  // (the 6-op softmax form).
  %linalg_ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %payload_root
    : (!transform.any_op) -> !transform.any_op
  %max_fill, %max, %sum_fill, %exp, %sum, %quotient = transform.split_handle %linalg_ops
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op,
                              !transform.any_op, !transform.any_op, !transform.any_op)

  // Step 2: build the dispatch region by hand, starting from the last op.
  // Custom dispatch region formation is required here: IREE's own formation
  // does not fuse the 6-op softmax, even with --iree-flow-fuse-multi-use.
  %dispatch = transform.iree.wrap_in_dispatch_region %quotient { generateWorkload = false }
    : (!transform.any_op) -> !transform.any_op

  // Step 3: pull the five remaining ops into that region through a single
  // merged handle.
  %producers = transform.merge_handles %max_fill, %max, %sum_fill, %exp, %sum : !transform.any_op
  %dispatch_with_producers = transform.iree.move_preceding_op_into_dispatch_region %producers into %dispatch
    : (!transform.any_op, !transform.any_op) -> !transform.any_op

  // Step 4: also move the matched tensor.empty ops into the region.
  %empties = transform.structured.match ops{["tensor.empty"]} in %payload_root
    : (!transform.any_op) -> !transform.any_op
  %dispatch_complete = transform.iree.move_preceding_op_into_dispatch_region %empties into %dispatch_with_producers
    : (!transform.any_op, !transform.any_op) -> !transform.any_op

  // Step 5: convert the completed dispatch region to workgroups form.
  transform.iree.region_to_workgroups %dispatch_complete : (!transform.any_op) -> !transform.any_op
}