stream: Improve sinking stability further (#12834)
I found other cases that resulted in unstable sinking behavior. Use a
more principled formulation of sinking stability.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index e40ba5e..6e44cbe 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -55,9 +55,18 @@
return std::nullopt;
}
-// Finds the insertion point before |targetOp| and after |earliestOp| that would
-// not oscillate if an op was moved there. Oscillations can occur if there are
-// multiple ops inserted before a single op as insertion order based on
+// Various patterns try to sink ops, and in case of uses in multiple blocks
+// they might be sunk to the end of a block. When multiple such ops are being
+// sunk, they can "fight" over who is at the end of the block, resulting in
+// infinite pattern recursion. To avoid this, we need to collectively know
+// across patterns which ops are liable to be sunk that way.
+static bool isSinkCandidate(Operation *op) {
+ return isa<AsyncSplatOp, AsyncAllocaOp, TimepointAwaitOp>(op);
+}
+
+// Determine if sinking |toBeSunkOp| before |targetOp| won't result in an
+// unstable oscillation across patterns. Oscillations can occur if there
+// are multiple ops inserted before a single op as insertion order based on
// canonicalization is undefined.
//
// Example:
@@ -66,39 +75,73 @@
// %2 = op.c %0, %1
// If %0 and %1 are sunk to %2 the ordering will depend on which sink pattern
// runs first and each of the patterns will fight trying to sink lower than the
-// other.
-static Block::iterator findInsertionPointBefore(Operation *earliestOp,
- Operation *targetOp) {
- // Check if ops between this and the target are all used by the target.
- // If they are, we skip sinking so that we don't get stuck in an infinite loop
- // if there are two splats used by the same op (or another pattern sinking).
- if (earliestOp->getBlock() == targetOp->getBlock()) {
- SmallPtrSet<Operation *, 4> producerOps;
+// other. As long as sinking only happens when this function returns `true`,
+// then the sinking across patterns will reach a fixed-point.
+static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) {
+ // Stably sinking implies that other sinking won't "fight" with this
+ // sinking. This is obviously not possible in an open pattern ecosystem,
+ // but for the purpose of this function, we assume that all sinking patterns
+ // that we are concerned with are the other patterns in the `stream` dialect.
+ //
+ // In typical usage, this function will result in various patterns sinking
+ // their relevant ops before `targetOp`. This results in a sequence of
+ // sinkable ops before `targetOp`. This is fine, until we start to sink
+ // them again, which can result in "fighting". We detect that scenario
+ // by seeing if all the ops between `toBeSunkOp` and `targetOp` might be sunk
+ // again.
+ //
+ // To prove that this function results in sinking that reaches a fixed-point,
+ // we can design a potential function `f(the_module) -> int`, and show that it
+ // decreases strictly monotonically with each sinking operation (and cannot go
+ // below 0). In particular, we choose the following function: `f(the_module) =
+ // sum(g(op) for op in the_module)`, where `g(op) -> int` gives the distance
+ // between op's current location and the latest it could appear in the program
+ // (infinite, if that location is in another block).
+ assert(isSinkCandidate(toBeSunkOp) && "asking to sink a non-sinkable op");
+
+ // If `targetOp` is a terminator, then it might be chosen as a sink location
+ // purely for control flow reasons, and not due to use-def chains. This means
+ // that if `targetOp` is not a terminator, then we can prune the set of
+ // sinkable ops that might fight with `toBeSunkOp` more aggressively by using
+ // use-def chains.
+ bool allowUseDefPruning = !targetOp->hasTrait<mlir::OpTrait::IsTerminator>();
+
+ // If the sinking operation would be a no-op, then we need to prevent
+ // the sinking operation, to avoid infinite pattern applications.
+ if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp)))
+ return false;
+
+ // If the sinking is to a different block, then it is okay, since for any
+ // later sinkings, this reduces the problem to stable sinking within a single
+ // block (handled below).
+ if (toBeSunkOp->getBlock() != targetOp->getBlock()) return true;
+
+ SmallPtrSet<Operation *, 4> producerOps;
+ if (allowUseDefPruning) {
for (auto operand : targetOp->getOperands()) {
if (operand.getDefiningOp()) {
producerOps.insert(operand.getDefiningOp());
}
}
- bool allUsed = true;
- for (auto it = Block::iterator(earliestOp); it != Block::iterator(targetOp);
- ++it) {
- if (!producerOps.contains(&*it)) {
- allUsed = false;
- break;
- }
- }
- if (allUsed) return Block::iterator(earliestOp);
}
- return Block::iterator(targetOp);
+
+ // If any of the ops between `toBeSunkOp` and `targetOp` are known to not
+ // fight with this op, then it is stable to sink.
+ for (Operation &op : llvm::make_range(Block::iterator(toBeSunkOp),
+ Block::iterator(targetOp))) {
+ // If the intervening op is not even a sink candidate itself,
+ // then it cannot fight.
+ if (!isSinkCandidate(&op)) return true;
+ // If the op is pruned by use-def chains, then it won't fight.
+ if (allowUseDefPruning && !producerOps.contains(&op)) return true;
+ }
+ return false;
}
// Sinks |op| down to |targetOp|, ensuring that we don't oscillate.
// Returns success if the op was sunk and failure if sinking was not needed.
static LogicalResult sinkOp(Operation *op, Operation *targetOp) {
- auto ip = findInsertionPointBefore(op, targetOp);
- if (ip == Block::iterator(op)) return failure();
- // If the moveBefore would be a no-op, then there is no work to do.
- if (ip == std::next(Block::iterator(op))) return failure();
+ if (!canStablySinkTo(op, targetOp)) return failure();
op->moveBefore(targetOp);
return success();
}
@@ -2569,13 +2612,12 @@
}
}
- // Find the earliest point before |user| that is safe to insert into. If it
- // ends up being where we already are then no-op.
- auto ip = findInsertionPointBefore(op, firstUserInDominator);
- if (ip == Block::iterator(op)) return failure();
+ // If sinking to `firstUserInDominator` could result in patterns
+ // fighting each other, then don't sink.
+ if (!canStablySinkTo(op, firstUserInDominator)) return failure();
rewriter.updateRootInPlace(op,
- [&]() { op->moveBefore(ip->getBlock(), ip); });
+ [&]() { op->moveBefore(firstUserInDominator); });
return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
index 5e29612..6993991 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -54,23 +54,25 @@
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
%c121_i32 = arith.constant 121 : i32
// The splat is already where we would sink it to -- this used to trigger
- // a bug where we would "move" the splat to its current location, triggering
// infinite pattern recursion.
- // CHECK: %[[SPLAT:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+ // CHECK: %[[SPLAT100:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+ // CHECK-NEXT: %[[SPLAT101:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c101}
// CHECK-NEXT: cf.cond_br %arg1, ^bb1, ^bb2
%0 = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+ %1 = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c101}
cf.cond_br %arg1, ^bb1, ^bb2
// CHECK: ^bb1:
^bb1:
- // CHECK: = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%[[SPLAT]][%c0 to %c100 for %c100])
- %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+ // CHECK: stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%[[SPLAT100]][%c0 to %c100 for %c100], %[[SPLAT101]][%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
+ %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
cf.br ^bb3(%2 : !stream.resource<*>)
// CHECK: ^bb2:
^bb2:
- // CHECK: = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%[[SPLAT]][%c0 to %c100 for %c100])
- %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+ // CHECK: stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%[[SPLAT100]][%c0 to %c100 for %c100], %[[SPLAT101]][%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
+ %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
cf.br ^bb3(%3 : !stream.resource<*>)
// CHECK: ^bb3(
^bb3(%arg6: !stream.resource<*>):
@@ -96,12 +98,12 @@
// CHECK-LABEL: @ConvertSplatConstantsIntoSplats
func.func @ConvertSplatConstantsIntoSplats(%arg0: index) -> (!stream.resource<transient>, !stream.resource<transient>) {
- // CHECK-NOT: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[3]> : tensor<8xi32>
// CHECK: %[[CST:.+]] = arith.constant 3 : i32
- // CHECK: %0 = stream.async.splat %[[CST]] : i32 -> !stream.resource<transient>{%arg0}
- %0 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
// CHECK: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
- %1 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
+ %0 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
+ // CHECK-NOT: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[3]> : tensor<8xi32>
+ // CHECK: = stream.async.splat %[[CST]] : i32 -> !stream.resource<transient>{%arg0}
+ %1 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
return %0, %1 : !stream.resource<transient>, !stream.resource<transient>
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
index 99dd8d0..42297e5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
@@ -271,8 +271,8 @@
%updated = stream.async.store %load, %download[%c0] : i8 -> !stream.resource<staging>{%c1}
// CHECK: stream.async.execute
- // CHECK-NEXT: stream.async.splat
// CHECK-NEXT: stream.async.transfer
+ // CHECK-NEXT: stream.async.splat
%upload = stream.async.transfer %updated : !stream.resource<staging>{%c1} -> !stream.resource<transient>{%c1}
// CHECK-NEXT: stream.async.dispatch
%dispatch1 = stream.async.dispatch @ex::@dispatch1[%c1, %c1, %c1](%upload[%c0 to %c1 for %c1], %splat[%c0 to %c1 for %c1]) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}