blob: 804ce2dfac3504bd59d3c92b3ab338b246c56d3b [file] [log] [blame]
// RUN: iree-opt %s
// Form a single dispatch region for a softmax and convert it to workgroups.
// Transform script that wraps the matched softmax ops (linalg.fill /
// linalg.generic) plus the tensor.empty ops into one dispatch region, then
// converts that region to workgroups.
transform.sequence failures(propagate){
^bb1(%variant_op: !pdl.operation):
// Collect every linalg.fill and linalg.generic op under %variant_op into a
// single handle.
%ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]}
in %variant_op : (!pdl.operation) -> !pdl.operation
// Split the combined handle into 6 separate handles, which requires exactly 6
// matched ops.
// NOTE(review): the result names assume the matched ops come back in this
// order (max fill, max, sum fill, exps, sum, div). That is presumably payload
// IR order; confirm against the input IR.
%input_max_fill, %input_max, %exps_sum_fill, %exps, %exps_sum, %div =
transform.split_handles %ops in [6]
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation, !pdl.operation)
/// This must be used with the custom dispatch region formation below,
/// because IREE's own dispatch region formation does not fuse the 6-op
/// softmax version, even with --iree-flow-fuse-multi-use.
// Start the region from the last op (%div). The workload is not generated
// here (generateWorkload = false).
%region_op = transform.iree.wrap_in_dispatch_region %div { generateWorkload = false }
// Gather the other five ops into one handle so they can be moved together.
%non_div = transform.merge_handles %input_max_fill, %input_max, %exps_sum_fill, %exps, %exps_sum
: !pdl.operation
// Pull the five preceding ops into the region created for %div.
%region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %non_div into %region_op
// Match every tensor.empty under %variant_op and move it into the region as
// well.
// NOTE(review): this matches all tensor.empty ops in the payload, not only
// those feeding the softmax. That presumably relies on the test input
// containing only softmax-related ones; confirm.
%empty = transform.structured.match ops{["tensor.empty"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%region_op_3 = transform.iree.move_preceding_op_into_dispatch_region %empty into %region_op_2
// Convert the completed dispatch region to workgroups.
transform.iree.region_to_workgroups %region_op_3
}