Refine fusion heuristic (#7165)
Only merge duplicate broadcast ops, to avoid duplicating an op that
contains a broadcast along with expensive work.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 02b3b25..2912248 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -87,7 +87,10 @@
return attr.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
});
- if (parallelOp) {
+ // Detect ops that only broadcast their input, as fusing them makes
+ // the new op cheaper.
+ if (parallelOp &&
+ isa<linalg::YieldOp>(genericOp.getBody()->front())) {
for (OpOperand *opOperand : genericOp.getInputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
if (indexingMap.isProjectedPermutation() &&