stream: Improve sinking stability further (#12834)

I found other cases that resulted in unstable sinking behavior. Use a
more principled formulation of sinking stability.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index e40ba5e..6e44cbe 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -55,9 +55,18 @@
   return std::nullopt;
 }
 
-// Finds the insertion point before |targetOp| and after |earliestOp| that would
-// not oscillate if an op was moved there. Oscillations can occur if there are
-// multiple ops inserted before a single op as insertion order based on
+// Various patterns try to sink ops, and in case of uses in multiple blocks
+// they might be sunk to the end of a block. When multiple such ops are being
+// sunk, they can "fight" over who is at the end of the block, resulting in
+// infinite pattern recursion. To avoid this, we need to collectively know
+// across patterns which ops are liable to be sunk that way.
+static bool isSinkCandidate(Operation *op) {
+  return isa<AsyncSplatOp, AsyncAllocaOp, TimepointAwaitOp>(op);
+}
+
+// Determine if sinking |toBeSunkOp| before |targetOp| won't result in an
+// unstable oscillation across patterns. Oscillations can occur if there
+// are multiple ops inserted before a single op as insertion order based on
 // canonicalization is undefined.
 //
 // Example:
@@ -66,39 +75,73 @@
 //   %2 = op.c %0, %1
 // If %0 and %1 are sunk to %2 the ordering will depend on which sink pattern
 // runs first and each of the patterns will fight trying to sink lower than the
-// other.
-static Block::iterator findInsertionPointBefore(Operation *earliestOp,
-                                                Operation *targetOp) {
-  // Check if ops between this and the target are all used by the target.
-  // If they are, we skip sinking so that we don't get stuck in an infinite loop
-  // if there are two splats used by the same op (or another pattern sinking).
-  if (earliestOp->getBlock() == targetOp->getBlock()) {
-    SmallPtrSet<Operation *, 4> producerOps;
+// other. As long as sinking only happens when this function returns `true`,
+// then the sinking across patterns will reach a fixed-point.
+static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) {
+  // Stably sinking implies that other sinking won't "fight" with this
+  // sinking. This is obviously not possible in an open pattern ecosystem,
+  // but for the purpose of this function, we assume that all sinking patterns
+  // that we are concerned with are the other patterns in the `stream` dialect.
+  //
+  // In typical usage, this function will result in various patterns sinking
+  // their relevant ops before `targetOp`. This results in a sequence of
+  // sinkable ops before `targetOp`. This is fine, until we start to sink
+  // them again, which can result in "fighting". We detect that scenario
+  // by seeing if all the ops between `toBeSunkOp` and `targetOp` might be sunk
+  // again.
+  //
+  // To prove that this function results in sinking that reaches a fixed-point,
+  // we can design a potential function `f(the_module) -> int`, and show that it
+  // decreases strictly monotonically with each sinking operation (and cannot go
+  // below 0). In particular, we choose the following function: `f(the_module) =
+  // sum(g(op) for op in the_module)`, where `g(op) -> int` gives the distance
+  // between op's current location and the latest it could appear in the program
+  // (infinite, if that location is in another block).
+  assert(isSinkCandidate(toBeSunkOp) && "asking to sink a non-sinkable op");
+
+  // If `targetOp` is a terminator, then it might be chosen as a sink location
+  // purely for control flow reasons, and not due to use-def chains. This means
+  // that if `targetOp` is not a terminator, then we can prune the set of
+  // sinkable ops that might fight with `toBeSunkOp` more aggressively by using
+  // use-def chains.
+  bool allowUseDefPruning = !targetOp->hasTrait<mlir::OpTrait::IsTerminator>();
+
+  // If the sinking operation would be a no-op, then we need to prevent
+  // the sinking operation, to avoid infinite pattern applications.
+  if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp)))
+    return false;
+
+  // If the sinking is to a different block, then it okay, since for any later
+  // sinkings, this reduces the problem to stable sinking within a single
+  // block (handled below).
+  if (toBeSunkOp->getBlock() != targetOp->getBlock()) return true;
+
+  SmallPtrSet<Operation *, 4> producerOps;
+  if (allowUseDefPruning) {
     for (auto operand : targetOp->getOperands()) {
       if (operand.getDefiningOp()) {
         producerOps.insert(operand.getDefiningOp());
       }
     }
-    bool allUsed = true;
-    for (auto it = Block::iterator(earliestOp); it != Block::iterator(targetOp);
-         ++it) {
-      if (!producerOps.contains(&*it)) {
-        allUsed = false;
-        break;
-      }
-    }
-    if (allUsed) return Block::iterator(earliestOp);
   }
-  return Block::iterator(targetOp);
+
+  // If any of the ops between `toBeSunkOp` and `targetOp` are known to not
+  // fight with this op, then it is stable to sink.
+  for (Operation &op : llvm::make_range(Block::iterator(toBeSunkOp),
+                                        Block::iterator(targetOp))) {
+    // If the intervening op that is not even a sink candidate itself,
+    // then it cannot fight.
+    if (!isSinkCandidate(&op)) return true;
+    // If the op is pruned by use-def chains, then it won't fight.
+    if (allowUseDefPruning && !producerOps.contains(&op)) return true;
+  }
+  return false;
 }
 
 // Sinks |op| down to |targetOp|, ensuring that we don't oscillate.
 // Returns success if the op was sunk and failure if sinking was not needed.
 static LogicalResult sinkOp(Operation *op, Operation *targetOp) {
-  auto ip = findInsertionPointBefore(op, targetOp);
-  if (ip == Block::iterator(op)) return failure();
-  // If the moveBefore would be a no-op, then there is no work to do.
-  if (ip == std::next(Block::iterator(op))) return failure();
+  if (!canStablySinkTo(op, targetOp)) return failure();
   op->moveBefore(targetOp);
   return success();
 }
@@ -2569,13 +2612,12 @@
       }
     }
 
-    // Find the earliest point before |user| that is safe to insert into. If it
-    // ends up being where we already are then no-op.
-    auto ip = findInsertionPointBefore(op, firstUserInDominator);
-    if (ip == Block::iterator(op)) return failure();
+    // If sinking to `firstUserInDominator` could result in patterns
+    // fighting each other, then don't sink.
+    if (!canStablySinkTo(op, firstUserInDominator)) return failure();
 
     rewriter.updateRootInPlace(op,
-                               [&]() { op->moveBefore(ip->getBlock(), ip); });
+                               [&]() { op->moveBefore(firstUserInDominator); });
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
index 5e29612..6993991 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -54,23 +54,25 @@
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c100 = arith.constant 100 : index
+  %c101 = arith.constant 101 : index
   %c121_i32 = arith.constant 121 : i32
   // The splat is already where we would sink it to -- this used to trigger
-  // a bug where we would "move" the splat to its current location, triggering
   // infinite pattern recursion.
-  // CHECK: %[[SPLAT:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+  // CHECK: %[[SPLAT100:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+  // CHECK-NEXT: %[[SPLAT101:.+]] = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c101}
   // CHECK-NEXT: cf.cond_br %arg1, ^bb1, ^bb2
   %0 = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c100}
+  %1 = stream.async.splat %c121_i32 : i32 -> !stream.resource<*>{%c101}
   cf.cond_br %arg1, ^bb1, ^bb2
 // CHECK: ^bb1:
 ^bb1:
-  // CHECK: = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%[[SPLAT]][%c0 to %c100 for %c100])
-  %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+  // CHECK: stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%[[SPLAT100]][%c0 to %c100 for %c100], %[[SPLAT101]][%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
+  %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
   cf.br ^bb3(%2 : !stream.resource<*>)
 // CHECK: ^bb2:
 ^bb2:
-  // CHECK: = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%[[SPLAT]][%c0 to %c100 for %c100])
-  %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100]) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+  // CHECK: stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%[[SPLAT100]][%c0 to %c100 for %c100], %[[SPLAT101]][%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
+  %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0[%c0 to %c100 for %c100], %1[%c0 to %c101 for %c101]) : (!stream.resource<*>{%c100}, !stream.resource<*>{%c101}) -> !stream.resource<*>{%c100}
   cf.br ^bb3(%3 : !stream.resource<*>)
 // CHECK: ^bb3(
 ^bb3(%arg6: !stream.resource<*>):
@@ -96,12 +98,12 @@
 
 // CHECK-LABEL: @ConvertSplatConstantsIntoSplats
 func.func @ConvertSplatConstantsIntoSplats(%arg0: index) -> (!stream.resource<transient>, !stream.resource<transient>) {
-  // CHECK-NOT: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[3]> : tensor<8xi32>
   // CHECK: %[[CST:.+]] = arith.constant 3 : i32
-  // CHECK: %0 = stream.async.splat %[[CST]] : i32 -> !stream.resource<transient>{%arg0}
-  %0 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
   // CHECK: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
-  %1 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
+  %0 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[1, 2, 3, 4, 5, 6, 7, 8]> : tensor<8xi32>
+  // CHECK-NOT: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<[3]> : tensor<8xi32>
+  // CHECK: = stream.async.splat %[[CST]] : i32 -> !stream.resource<transient>{%arg0}
+  %1 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
   return %0, %1 : !stream.resource<transient>, !stream.resource<transient>
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
index 99dd8d0..42297e5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
@@ -271,8 +271,8 @@
   %updated = stream.async.store %load, %download[%c0] : i8 -> !stream.resource<staging>{%c1}
 
   // CHECK: stream.async.execute
-  // CHECK-NEXT: stream.async.splat
   // CHECK-NEXT: stream.async.transfer
+  // CHECK-NEXT: stream.async.splat
   %upload = stream.async.transfer %updated : !stream.resource<staging>{%c1} -> !stream.resource<transient>{%c1}
   // CHECK-NEXT: stream.async.dispatch
   %dispatch1 = stream.async.dispatch @ex::@dispatch1[%c1, %c1, %c1](%upload[%c0 to %c1 for %c1], %splat[%c0 to %c1 for %c1]) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}