Making FlattenFullFillToSplat more conservative. (#17079)
Full analysis is required to do this in all cases as we need to know
that the target storage isn't required to be the same. The pattern now
does a local check to see if it can be proven to be producing into a
non-tied value it knows the providence of and bails otherwise.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 4cb1080..0294b40 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1560,7 +1560,8 @@
// Turns fills that cover an entire target resource into splats.
// This acts as a discard as it indicates we don't care about the previous
-// resource contents.
+// resource contents. Note that we only do this when we can locally prove that
+// it's safe to disassociate the result storage.
//
// Example:
// %0 = stream.async.fill %cst, %dst[%c0 to %dstsz for %dstsz] ... {%dstsz}
@@ -1570,13 +1571,20 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsyncFillOp fillOp,
PatternRewriter &rewriter) const override {
- if (fillOp.getTargetLength() == fillOp.getTargetSize()) {
- rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
- fillOp, fillOp.getResult().getType(), fillOp.getValue(),
- fillOp.getTargetSize(), fillOp.getAffinityAttr());
- return success();
+ if (fillOp.getTargetLength() != fillOp.getTargetSize())
+ return failure();
+
+ auto targetOp = fillOp.getTarget().getDefiningOp();
+ if (!targetOp || IREE::Util::TiedOpInterface::findTiedBaseValue(
+ fillOp.getTarget()) != fillOp.getTarget()) {
+ return rewriter.notifyMatchFailure(
+ fillOp, "unable to locally determine safety of eliding the target");
}
- return failure();
+
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
+ fillOp, fillOp.getResult().getType(), fillOp.getValue(),
+ fillOp.getTargetSize(), fillOp.getAffinityAttr());
+ return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
index 1438588..8b86429 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -179,17 +179,35 @@
// -----
+// Allow pattern because we can verify the target is safe to elide.
+
// CHECK-LABEL: @FlattenFullFillToSplat
-util.func private @FlattenFullFillToSplat(%arg0: !stream.resource<*>, %arg1: index, %arg2: i32) -> !stream.resource<*> {
+util.func private @FlattenFullFillToSplat(%arg0: index, %arg1: i32) -> !stream.resource<*> {
%c0 = arith.constant 0 : index
- // CHECK: %[[T:.+]] = stream.async.splat %arg2 : i32 -> !stream.resource<*>{%arg1}
- %0 = stream.async.fill %arg2, %arg0[%c0 to %arg1 for %arg1] : i32 -> %arg0 as !stream.resource<*>{%arg1}
+ %c123_i32 = arith.constant 123 : i32
+ %target = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%arg0}
+ // CHECK: %[[T:.+]] = stream.async.splat %arg1 : i32 -> !stream.resource<*>{%arg0}
+ %0 = stream.async.fill %arg1, %target[%c0 to %arg0 for %arg0] : i32 -> %target as !stream.resource<*>{%arg0}
// CHECK: util.return %[[T]]
util.return %0 : !stream.resource<*>
}
// -----
+// The target is tied and we cannot avoid the fill.
+
+// CHECK-LABEL: @FlattenFullFillToSplatUnsafe
+util.func private @FlattenFullFillToSplatUnsafe(%arg0: index, %arg1: i32, %arg2: !hal.buffer_view) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK: stream.tensor.import
+ %target = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<8xi32> in !stream.resource<*>{%arg0}
+ // CHECK: stream.async.fill
+ %0 = stream.async.fill %arg1, %target[%c0 to %arg0 for %arg0] : i32 -> %target as !stream.resource<*>{%arg0}
+ util.return %0 : !stream.resource<*>
+}
+
+// -----
+
// CHECK-LABEL: @ElideRedundantFill
util.func private @ElideRedundantFill(%arg0: !stream.resource<*>, %arg1: index, %arg2: i32) -> !stream.resource<*> {
%c0 = arith.constant 0 : index