[GlobalOpt] Improve reshape/empty cleanup in transpose propagation (#17905)
This adds a few simplifying patterns for reshape/empty/fill folding
now that the pass uses reshape propagation patterns. This helps
clean up the majority of transposed destinations, which are almost always
either a fill or an empty.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 8e0ccee..ac2c356 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -23,7 +23,9 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -58,9 +60,17 @@
return empty;
}
-// Constructs a transpose of the given tensor and permutation.
+// Constructs a transpose of the given tensor and permutation,
+// or produces a transposed version of the producing tensor.empty op.
static Value createTranspose(OpBuilder &builder, Value source,
ArrayRef<int64_t> perm) {
+ if (auto empty = source.getDefiningOp<tensor::EmptyOp>()) {
+ Type elementType = empty.getType().getElementType();
+ SmallVector<OpFoldResult> mixedSizes = empty.getMixedSizes();
+ applyPermutationToVector(mixedSizes, perm);
+ return builder.create<tensor::EmptyOp>(empty.getLoc(), mixedSizes,
+ elementType);
+ }
Value empty = createTransposeInit(builder, source, perm);
return builder
.create<linalg::TransposeOp>(source.getLoc(), source, empty, perm)
@@ -861,6 +871,17 @@
SmallVector<int64_t>{0, 2, 1});
}
+static void
+populateCommonCanonicalizationPatterns(MLIRContext *context,
+ RewritePatternSet &patterns) {
+ linalg::FillOp::getCanonicalizationPatterns(patterns, context);
+ tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
+ tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+ tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+ tensor::populateFoldTensorEmptyPatterns(patterns,
+ /*foldSingleUseOnly=*/false);
+}
+
void PropagateLinalgTransposePass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
@@ -895,6 +916,7 @@
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
populateNamedOpSinkingPatterns(context, sinkingPatterns);
+ populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
if (failed(
@@ -952,6 +974,7 @@
bubblingPatterns.add<BubbleTransposeThroughUnaryElementwiseDpsInit>(
context, /*benefit=*/2);
bubblingPatterns.insert<ComposeTransposes>(context);
+ populateCommonCanonicalizationPatterns(context, bubblingPatterns);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(bubblingPatterns)))) {
return signalPassFailure();
@@ -1003,6 +1026,7 @@
context, enableAggressivePropagation);
sinkingPatterns.insert<ComposeTransposes>(context);
populateNamedOpSinkingPatterns(context, sinkingPatterns);
+ populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
if (failed(
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index feb5874..939e650 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -43,6 +43,21 @@
// -----
+util.func public @fold_transpose_of_fill() -> tensor<32x128xf32> {
+ %cst = arith.constant 1.0 : f32
+ %empty = tensor.empty(): tensor<128x32xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x32xf32>) -> tensor<128x32xf32>
+ %empty_t = tensor.empty(): tensor<32x128xf32>
+ %transposed = linalg.transpose ins(%fill : tensor<128x32xf32>)
+ outs(%empty_t : tensor<32x128xf32>) permutation = [1, 0]
+ util.return %transposed : tensor<32x128xf32>
+}
+// CHECK-LABEL: util.func public @fold_transpose_of_fill
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: util.return %[[FILL]]
+
+// -----
+
util.func public @propagate_through_extract_slice(%arg0 : tensor<1x256x128xf32>) -> tensor<1x128x32xf32> {
%empty = tensor.empty(): tensor<1x128x256xf32>
%transposed = linalg.transpose ins(%arg0 : tensor<1x256x128xf32>)
@@ -465,6 +480,8 @@
// APROP-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// APROP-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// APROP: util.func public @bubble_through_matmul
+// APROP: %[[EMPTY:.+]] = tensor.empty() : tensor<16x16xf32>
// APROP: %[[MATMUL:.+]] = linalg.generic
-// APROP: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// APROP-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// APROP-SAME: outs(%[[EMPTY]] : tensor<16x16xf32>)
// APROP: util.return %[[MATMUL]]