Add simple heuristic to decide if TensorReshape should be folded during fusion (#5797)
This solves a performance regression on MobileBert.
diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
index f6e3b9e..3034e91 100644
--- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
@@ -108,14 +108,20 @@
}
return numUsers.empty();
};
- linalg::ControlElementwiseOpsFusionFn foldAllFn =
+ // Simple heuristic to decide if reshaope should be folded in the linalg.
+ // If the source of the reshape is a linalg op fold to potentially allow the
+ // two linalg ops to be fused. Otherwise leave it to avoid adding dimensions
+ // to the consumer linalg op.
+ linalg::ControlElementwiseOpsFusionFn foldReshapeBetweenLinalgFn =
[](const OpResult &producer, const OpOperand &consumer) {
- return true;
+ auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
+ return reshapeOp.src().getDefiningOp<linalg::LinalgOp>() != nullptr;
};
linalg::populateElementwiseOpsFusionPatterns(
- fusionPatterns, linalg::LinalgElementwiseFusionOptions()
- .setControlFoldingReshapes(foldAllFn)
- .setControlElementwiseOpsFusionFn(controlFn));
+ fusionPatterns,
+ linalg::LinalgElementwiseFusionOptions()
+ .setControlFoldingReshapes(foldReshapeBetweenLinalgFn)
+ .setControlElementwiseOpsFusionFn(controlFn));
(void)applyPatternsAndFoldGreedily(op->getRegions(),
std::move(fusionPatterns));