[Dispatch Creation] Bubble up ExtractSliceOp with FillOp when the latter has multiple consumers (#18896)
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
index 4b03b60..59ddb57 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
@@ -115,6 +115,30 @@
}
};
+/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
+/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users.
+/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the
+/// former can be folded away.
+struct SwapExtractSliceOfFill final
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto fillOp = extractOp.getSource().getDefiningOp<linalg::FillOp>();
+ if (!fillOp)
+ return failure();
+
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
+ extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+ extractOp.getMixedStrides());
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(
+ extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()});
+ return success();
+ }
+};
+
struct BubbleUpExtractSlicesPass
: impl::BubbleUpExtractSlicesPassBase<BubbleUpExtractSlicesPass> {
void runOnOperation() override {
@@ -122,6 +146,8 @@
{
RewritePatternSet patterns(context);
patterns.insert<BubbleUpExtract>(context);
+ patterns.insert<SwapExtractSliceOfFill>(context);
+ tensor::populateFoldTensorEmptyPatterns(patterns, false);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
index a5b7ea1..56fa91d 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
@@ -94,3 +94,24 @@
// CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>)
// CHECK: util.return %[[GENERIC1]], %[[GENERIC0]]
+
+util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> {
+ %cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ
+ %cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ
+ %1 = tensor.empty() : tensor<2x320x128x128xf8E4M3FNUZ>
+ %2 = linalg.fill ins(%cst_2 : f8E4M3FNUZ) outs(%1 : tensor<2x320x128x128xf8E4M3FNUZ>) -> tensor<2x320x128x128xf8E4M3FNUZ>
+ %3 = tensor.empty() : tensor<2x320x130x130xf8E4M3FNUZ>
+ %4 = linalg.fill ins(%cst_1 : f8E4M3FNUZ) outs(%3 : tensor<2x320x130x130xf8E4M3FNUZ>) -> tensor<2x320x130x130xf8E4M3FNUZ>
+ %extracted_slice_1 = tensor.extract_slice %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x130x130xf8E4M3FNUZ> to tensor<2x320x128x130xf8E4M3FNUZ>
+ %inserted_slice_1 = tensor.insert_slice %2 into %extracted_slice_1[0, 0, 0, 1] [2, 320, 128, 128] [1, 1, 1, 1] : tensor<2x320x128x128xf8E4M3FNUZ> into tensor<2x320x128x130xf8E4M3FNUZ>
+ %inserted_slice_2 = tensor.insert_slice %inserted_slice_1 into %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x128x130xf8E4M3FNUZ> into tensor<2x320x130x130xf8E4M3FNUZ>
+ util.return %inserted_slice_2 : tensor<2x320x130x130xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: @bubble_up_extract_fill_multi_use
+// CHECK: %[[FILL1:.+]] = linalg.fill
+// CHECK: %[[EMPTY1:.+]] = tensor.empty
+// CHECK: %[[FILL2:.+]] = linalg.fill
+// CHECK-NOT: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK: %[[EMPTY2:.+]] = tensor.empty
+// CHECK: %[[FILL3:.+]] = linalg.fill