[GlobalOptimization] Drop obsolete transpose propagation pattern (#17203)
The pattern that fuses transposes with linalg generics was changed to
accept any linalg op, so the GeneralizeInputTransposedNamedOp pattern
is no longer needed. Dropping it also prevents transpose fusion with
named convolution ops, which matches the behavior of the other
aggressive propagation patterns.
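As a rough illustration (the function name and shapes below are made up
for this message, not taken from the patch), the sinking patterns in
this pass target IR roughly like the following, where a linalg.transpose
feeds an operand of a named op:

  util.func public @transpose_into_matmul(%lhs_t: tensor<128x64xf32>,
      %rhs: tensor<128x256xf32>, %init: tensor<64x256xf32>) -> tensor<64x256xf32> {
    // Transposed producer feeding the LHS of the named matmul.
    %empty = tensor.empty() : tensor<64x128xf32>
    %lhs = linalg.transpose ins(%lhs_t : tensor<128x64xf32>)
        outs(%empty : tensor<64x128xf32>) permutation = [1, 0]
    %mm = linalg.matmul ins(%lhs, %rhs : tensor<64x128xf32>, tensor<128x256xf32>)
        outs(%init : tensor<64x256xf32>) -> tensor<64x256xf32>
    util.return %mm : tensor<64x256xf32>
  }

Previously the aggressive path handled such cases by generalizing the
named op first (GeneralizeInputTransposedNamedOp) and letting the
generic fusion pattern fold the transpose; with the updated fusion
pattern that extra generalization step is unnecessary. Named
convolutions, as in the new do_not_propagate_to_conv test below, are
deliberately left alone.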
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index b381174..8e0ccee 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -735,47 +735,6 @@
   }
 };

-// Sinks a transpose to the input of a linalg named op. The conditions for the
-// rewrite are
-// 1) One of the input producers to the named op is a linalg.transpose
-// 2) The named op is generalizable (and is not a transpose)
-// The easiest way to get the rewrite we want then is to just try to generalize
-// all transposed named ops and let the generic pattern handle the actual
-// rewrite.
-class GeneralizeInputTransposedNamedOp
-    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-public:
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
-                                PatternRewriter &rewriter) const override {
-    if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
-      return failure();
-    }
-    // Don't generalize transposes.
-    if (isa<linalg::TransposeOp>(linalgOp)) {
-      return rewriter.notifyMatchFailure(linalgOp,
-                                         "do not generalize transposes");
-    }
-    bool hasTranspose = false;
-    for (Value input : linalgOp.getDpsInputs()) {
-      auto definingTranspose = input.getDefiningOp<linalg::TransposeOp>();
-      if (definingTranspose && definingTranspose->hasOneUse()) {
-        hasTranspose = true;
-        break;
-      }
-    }
-    if (!hasTranspose) {
-      return rewriter.notifyMatchFailure(linalgOp, "no transpose input");
-    }
-    if (failed(linalg::generalizeNamedOp(rewriter, linalgOp))) {
-      return rewriter.notifyMatchFailure(linalgOp,
-                                         "failed to generalize named op");
-    }
-    return success();
-  }
-};
-
 } // namespace

 //===----------------------------------------------------------------------===//
@@ -1046,10 +1005,6 @@
   populateNamedOpSinkingPatterns(context, sinkingPatterns);
   sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
       context, /*benefit=*/2);
-
-  if (enableAggressivePropagation) {
-    sinkingPatterns.insert<GeneralizeInputTransposedNamedOp>(context);
-  }
   if (failed(
           applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) {
     return signalPassFailure();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index df35095..78b9c0a 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -224,6 +224,23 @@
 // -----

+util.func public @do_not_propagate_to_conv(%transposed_lhs: tensor<18x2x18x8xf32>,
+                                           %rhs: tensor<3x3x8x32xf32>) -> tensor<2x16x16x32xf32> {
+  %empty = tensor.empty(): tensor<2x18x18x8xf32>
+  %lhs = linalg.transpose ins(%transposed_lhs : tensor<18x2x18x8xf32>)
+                          outs(%empty : tensor<2x18x18x8xf32>) permutation = [1, 0, 2, 3]
+  %out = tensor.empty(): tensor<2x16x16x32xf32>
+  %conv = linalg.conv_2d_nhwc_hwcf {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
+            ins(%lhs, %rhs : tensor<2x18x18x8xf32>, tensor<3x3x8x32xf32>)
+            outs(%out : tensor<2x16x16x32xf32>) -> tensor<2x16x16x32xf32>
+  util.return %conv : tensor<2x16x16x32xf32>
+}
+
+// APROP-LABEL: util.func public @do_not_propagate_to_conv
+// APROP: linalg.conv_2d_nhwc_hwcf
+
+// -----
+
 util.func public @sink_through_expand_shape(%arg0 : tensor<?x?x?xf32>) -> tensor<32x?x16x?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index