[GlobalOpt] Improve reshape/empty cleanup in transpose propagation (#17905)

This adds a few canonicalization patterns for folding reshape/empty/fill
ops now that the pass uses reshape propagation patterns. This helps
clean up the majority of transposed destinations, which are almost
always either a fill or an empty.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 8e0ccee..ac2c356 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -23,7 +23,9 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -58,9 +60,17 @@
   return empty;
 }
 
-// Constructs a transpose of the given tensor and permutation.
+// Constructs a transpose of the given tensor and permutation,
+// or produces a transposed version of the producing tensor.empty op.
 static Value createTranspose(OpBuilder &builder, Value source,
                              ArrayRef<int64_t> perm) {
+  if (auto empty = source.getDefiningOp<tensor::EmptyOp>()) {
+    Type elementType = empty.getType().getElementType();
+    SmallVector<OpFoldResult> mixedSizes = empty.getMixedSizes();
+    applyPermutationToVector(mixedSizes, perm);
+    return builder.create<tensor::EmptyOp>(empty.getLoc(), mixedSizes,
+                                           elementType);
+  }
   Value empty = createTransposeInit(builder, source, perm);
   return builder
       .create<linalg::TransposeOp>(source.getLoc(), source, empty, perm)
@@ -861,6 +871,17 @@
                                                  SmallVector<int64_t>{0, 2, 1});
 }
 
+static void
+populateCommonCanonicalizationPatterns(MLIRContext *context,
+                                       RewritePatternSet &patterns) {
+  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
+  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
+  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+  tensor::populateFoldTensorEmptyPatterns(patterns,
+                                          /*foldSingleUseOnly=*/false);
+}
+
 void PropagateLinalgTransposePass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
@@ -895,6 +916,7 @@
     sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
     sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
     populateNamedOpSinkingPatterns(context, sinkingPatterns);
+    populateCommonCanonicalizationPatterns(context, sinkingPatterns);
     sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
         context, /*benefit=*/2);
     if (failed(
@@ -952,6 +974,7 @@
     bubblingPatterns.add<BubbleTransposeThroughUnaryElementwiseDpsInit>(
         context, /*benefit=*/2);
     bubblingPatterns.insert<ComposeTransposes>(context);
+    populateCommonCanonicalizationPatterns(context, bubblingPatterns);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(bubblingPatterns)))) {
       return signalPassFailure();
@@ -1003,6 +1026,7 @@
         context, enableAggressivePropagation);
     sinkingPatterns.insert<ComposeTransposes>(context);
     populateNamedOpSinkingPatterns(context, sinkingPatterns);
+    populateCommonCanonicalizationPatterns(context, sinkingPatterns);
     sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
         context, /*benefit=*/2);
     if (failed(
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index feb5874..939e650 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -43,6 +43,21 @@
 
 // -----
 
+util.func public @fold_transpose_of_fill() -> tensor<32x128xf32> {
+  %cst = arith.constant 1.0 : f32
+  %empty = tensor.empty(): tensor<128x32xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x32xf32>) -> tensor<128x32xf32>
+  %empty_t = tensor.empty(): tensor<32x128xf32>
+  %transposed = linalg.transpose ins(%fill : tensor<128x32xf32>)
+      outs(%empty_t : tensor<32x128xf32>) permutation = [1, 0]
+  util.return %transposed : tensor<32x128xf32>
+}
+// CHECK-LABEL: util.func public @fold_transpose_of_fill
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   util.return %[[FILL]]
+
+// -----
+
 util.func public @propagate_through_extract_slice(%arg0 : tensor<1x256x128xf32>) -> tensor<1x128x32xf32> {
   %empty = tensor.empty(): tensor<1x128x256xf32>
   %transposed = linalg.transpose ins(%arg0 : tensor<1x256x128xf32>)
@@ -465,6 +480,8 @@
 //   APROP-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
 //   APROP-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 //       APROP: util.func public @bubble_through_matmul
+//       APROP:   %[[EMPTY:.+]] = tensor.empty() : tensor<16x16xf32>
 //       APROP:   %[[MATMUL:.+]] = linalg.generic
-//       APROP:     indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+//  APROP-SAME:     indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+//  APROP-SAME:     outs(%[[EMPTY]] : tensor<16x16xf32>)
 //       APROP:   util.return %[[MATMUL]]