[Dispatch Creation] Bubble up ExtractSliceOp with FillOp when the latter has multiple consumers (#18896)

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
index 4b03b60..59ddb57 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
@@ -115,6 +115,30 @@
   }
 };
 
+/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
+/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users.
+/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the
+/// former can be folded away.
+struct SwapExtractSliceOfFill final
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = extractOp.getSource().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
+        extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+        extractOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()});
+    return success();
+  }
+};
+
 struct BubbleUpExtractSlicesPass
     : impl::BubbleUpExtractSlicesPassBase<BubbleUpExtractSlicesPass> {
   void runOnOperation() override {
@@ -122,6 +146,8 @@
     {
       RewritePatternSet patterns(context);
       patterns.insert<BubbleUpExtract>(context);
+      patterns.insert<SwapExtractSliceOfFill>(context);
+      tensor::populateFoldTensorEmptyPatterns(patterns, false);
       if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                               std::move(patterns)))) {
         return signalPassFailure();
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
index a5b7ea1..56fa91d 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
@@ -94,3 +94,24 @@
 //   CHECK-DAG:    %[[GENERIC1:.+]] = linalg.generic
 //  CHECK-SAME:      ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>)
 //       CHECK:    util.return %[[GENERIC1]], %[[GENERIC0]]
+
+util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> {
+  %cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ
+  %cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ
+  %1 = tensor.empty() : tensor<2x320x128x128xf8E4M3FNUZ>
+  %2 = linalg.fill ins(%cst_2 : f8E4M3FNUZ) outs(%1 : tensor<2x320x128x128xf8E4M3FNUZ>) -> tensor<2x320x128x128xf8E4M3FNUZ>
+  %3 = tensor.empty() : tensor<2x320x130x130xf8E4M3FNUZ>
+  %4 = linalg.fill ins(%cst_1 : f8E4M3FNUZ) outs(%3 : tensor<2x320x130x130xf8E4M3FNUZ>) -> tensor<2x320x130x130xf8E4M3FNUZ>
+  %extracted_slice_1 = tensor.extract_slice %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x130x130xf8E4M3FNUZ> to tensor<2x320x128x130xf8E4M3FNUZ>
+  %inserted_slice_1 = tensor.insert_slice %2 into %extracted_slice_1[0, 0, 0, 1] [2, 320, 128, 128] [1, 1, 1, 1] : tensor<2x320x128x128xf8E4M3FNUZ> into tensor<2x320x128x130xf8E4M3FNUZ>
+  %inserted_slice_2 = tensor.insert_slice %inserted_slice_1 into %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x128x130xf8E4M3FNUZ> into tensor<2x320x130x130xf8E4M3FNUZ>
+  util.return %inserted_slice_2 : tensor<2x320x130x130xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL:  @bubble_up_extract_fill_multi_use
+//       CHECK:    %[[FILL1:.+]] = linalg.fill
+//       CHECK:    %[[EMPTY1:.+]] = tensor.empty
+//       CHECK:    %[[FILL2:.+]] = linalg.fill
+//   CHECK-NOT:    %[[SLICE:.+]] = tensor.extract_slice
+//       CHECK:    %[[EMPTY2:.+]] = tensor.empty
+//       CHECK:    %[[FILL3:.+]] = linalg.fill