Fixing lifetime transitions. (#7678)

This was showing up in some internal tests - the repro added here covers the core case hit by those.
diff --git a/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index fef5d08..f3a2db1 100644
--- a/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -85,6 +85,37 @@
   static_assert(BEST_STATE == BaseType::getBestState(),
                 "unexpected BEST_STATE value");
 
+  static bool isValidState(uint16_t bits) {
+    // bool isIndirect = (bits & NOT_INDIRECT) != NOT_INDIRECT;
+    // bool isExternal = (bits & NOT_EXTERNAL) != NOT_EXTERNAL;
+    bool isMutated = (bits & NOT_MUTATED) != NOT_MUTATED;
+    bool isConstant = (bits & NOT_CONSTANT) != NOT_CONSTANT;
+    // bool isTransferRead = (bits & NOT_TRANSFER_READ) != NOT_TRANSFER_READ;
+    // bool isTransferWrite = (bits & NOT_TRANSFER_WRITE) != NOT_TRANSFER_WRITE;
+    bool isStagingRead = (bits & NOT_STAGING_READ) != NOT_STAGING_READ;
+    bool isStagingWrite = (bits & NOT_STAGING_WRITE) != NOT_STAGING_WRITE;
+    bool isDispatchRead = (bits & NOT_DISPATCH_READ) != NOT_DISPATCH_READ;
+    bool isDispatchWrite = (bits & NOT_DISPATCH_WRITE) != NOT_DISPATCH_WRITE;
+    // bool isGlobalRead = (bits & NOT_GLOBAL_READ) != NOT_GLOBAL_READ;
+    // bool isGlobalWrite = (bits & NOT_GLOBAL_WRITE) != NOT_GLOBAL_WRITE;
+
+    // Cannot be both staging and dispatch.
+    if ((isStagingRead || isStagingWrite) &&
+        (isDispatchRead || isDispatchWrite)) {
+      return false;
+    }
+
+    // Cannot be both constant and mutated.
+    // TODO(benvanik): maybe allow this for initializers that perform dispatches
+    // to initialize the resources. This introduces copies of those that are
+    // annoying to elide later on.
+    if (isConstant && isMutated) {
+      return false;
+    }
+
+    return true;
+  }
+
   ResourceUsageBitfield convertBitsToResourceUsage(uint16_t bits) const {
     return static_cast<ResourceUsageBitfield>(~bits & BEST_STATE);
   }
@@ -232,9 +263,17 @@
             default:
               break;
           }
+          auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+              *this, Position::forValue(op.result()),
+              DFX::Resolution::REQUIRED);
+          getState() ^= resultUsage.getState();
         })
         .Case([&](IREE::Util::GlobalLoadIndirectOp op) {
           removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_READ);
+          auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+              *this, Position::forValue(op.result()),
+              DFX::Resolution::REQUIRED);
+          getState() ^= resultUsage.getState();
         })
         .Case([&](IREE::Stream::ResourceStoreOp op) {
           removeAssumedBits(NOT_STAGING_WRITE);
@@ -245,12 +284,24 @@
         })
         .Case([&](IREE::Stream::TensorImportOp op) {
           removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+          auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+              *this, Position::forValue(op.result()),
+              DFX::Resolution::REQUIRED);
+          getState() ^= resultUsage.getState();
         })
         .Case([&](IREE::Stream::AsyncConstantOp op) {
           removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_WRITE);
+          auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+              *this, Position::forValue(op.result()),
+              DFX::Resolution::REQUIRED);
+          getState() ^= resultUsage.getState();
         })
         .Case([&](IREE::Stream::AsyncSplatOp op) {
           removeAssumedBits(NOT_TRANSFER_WRITE);
+          auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+              *this, Position::forValue(op.result()),
+              DFX::Resolution::REQUIRED);
+          getState() ^= resultUsage.getState();
         })
         .Case([&](IREE::Stream::AsyncCloneOp op) {
           removeAssumedBits(NOT_TRANSFER_WRITE);
@@ -334,6 +385,10 @@
                 *this, Position::forValue(tiedOperand),
                 DFX::Resolution::REQUIRED);
             getState() ^= tiedUsage.getState();
+          } else {
+            auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+                *this, Position::forValue(result), DFX::Resolution::REQUIRED);
+            getState() ^= resultUsage.getState();
           }
         })
         .Default([&](Operation *op) {});
@@ -354,22 +409,9 @@
               DFX::Resolution::REQUIRED);
           getState() ^= resultUsage.getState();
         })
-        .Case([&](mlir::BranchOp op) {
+        .Case([&](mlir::BranchOpInterface op) {
           auto operandUsage = solver.getElementFor<ValueResourceUsage>(
-              *this, Position::forValue(op.getOperand(operandIdx)),
-              DFX::Resolution::OPTIONAL);
-          getState() ^= operandUsage.getState();
-          solver.getExplorer().walkOutgoingBranchOperandArguments(
-              op, operandIdx, [&](Block *targetBlock, BlockArgument arg) {
-                auto argUsage = solver.getElementFor<ValueResourceUsage>(
-                    *this, Position::forValue(arg), DFX::Resolution::OPTIONAL);
-                getState() ^= argUsage;
-                return WalkResult::advance();
-              });
-        })
-        .Case([&](mlir::CondBranchOp op) {
-          auto operandUsage = solver.getElementFor<ValueResourceUsage>(
-              *this, Position::forValue(op.getOperand(operandIdx)),
+              *this, Position::forValue(op->getOperand(operandIdx)),
               DFX::Resolution::REQUIRED);
           getState() ^= operandUsage.getState();
           solver.getExplorer().walkOutgoingBranchOperandArguments(
@@ -383,14 +425,13 @@
         .Case([&](mlir::ReturnOp op) {
           auto operandUsage = solver.getElementFor<ValueResourceUsage>(
               *this, Position::forValue(op.getOperand(operandIdx)),
-              DFX::Resolution::OPTIONAL);
+              DFX::Resolution::REQUIRED);
           getState() ^= operandUsage.getState();
           solver.getExplorer().walkIncomingCalls(
               op->getParentOfType<mlir::CallableOpInterface>(),
               [&](mlir::CallOpInterface callOp) {
                 auto argUsage = solver.getElementFor<ValueResourceUsage>(
-                    *this,
-                    Position::forValue(callOp.getArgOperands()[operandIdx]),
+                    *this, Position::forValue(callOp->getResult(operandIdx)),
                     DFX::Resolution::OPTIONAL);
                 getState() ^= argUsage;
                 return WalkResult::advance();
@@ -546,6 +587,13 @@
     if (traversalResult == TraversalResult::INCOMPLETE) {
       removeAssumedBits(NOT_EXTERNAL);
     }
+
+    // Filter out impossible states by marking the state invalid.
+    // The fixpoint framework will try again.
+    if (!isValidState(assumedBits)) {
+      return indicatePessimisticFixpoint();
+    }
+
     return assumedBits == getAssumed() ? ChangeStatus::UNCHANGED
                                        : ChangeStatus::CHANGED;
   }
diff --git a/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
index d35a7d2..fca1be7 100644
--- a/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
@@ -61,6 +61,14 @@
   }
 }
 
+// Returns either the affinity of |op| or nullptr.
+static IREE::Stream::AffinityAttr getOpAffinity(Operation *op) {
+  if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+    return affinityOp.getAffinity();
+  }
+  return {};
+}
+
 // Base pattern type for resource usage refinement.
 // The results of the usage analysis are available for use by subclasses.
 template <typename OpT>
@@ -70,6 +78,27 @@
 
   ResourceUsageAnalysis &analysis;
 
+  // Updates the |arg| type to the lifetime derived by analysis, if needed.
+  // Returns true if a change was made.
+  bool applyArgTransition(BlockArgument arg, PatternRewriter &rewriter) const {
+    auto oldType = arg.getType().dyn_cast<IREE::Stream::ResourceType>();
+    if (!oldType) return false;
+    auto newUsage = analysis.lookupResourceUsage(arg);
+    auto newLifetime = convertUsageToLifetime(newUsage);
+    auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+    if (oldType == newType) {
+      // Old and new lifetimes match; no need to apply a transition.
+      return false;
+    } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+      // Transitioning lifetimes; rely on users to insert the transitions.
+      return false;
+    } else {
+      // Directly overwrite the existing lifetime.
+      arg.setType(newType);
+      return true;
+    }
+  }
+
   // Updates the |result| type to the lifetime derived by analysis, if needed.
   // Returns true if a change was made.
   bool applyResultTransition(Operation *op, Value result,
@@ -78,25 +107,62 @@
     if (!oldType) return false;
     auto newUsage = analysis.lookupResourceUsage(result);
     auto newLifetime = convertUsageToLifetime(newUsage);
-    if (oldType.getLifetime() == newLifetime) return false;
     auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
-    result.setType(newType);
-    return true;
+    if (oldType == newType) {
+      // Old and new lifetimes match; no need to apply a transition.
+      return false;
+    } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+      // Transitioning from one lifetime to another; insert a transfer
+      // placeholder (as we may later decide it's ok to transition on a
+      // particular device).
+      auto resultSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+          op->getLoc(), result);
+      auto affinityAttr = getOpAffinity(op);
+      auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+          op->getLoc(), newType, result, resultSize, resultSize,
+          /*source_affinity=*/affinityAttr,
+          /*target_affinity=*/affinityAttr);
+      result.replaceUsesWithIf(transferOp.result(), [&](OpOperand &operand) {
+        return operand.getOwner() != transferOp &&
+               operand.getOwner() != resultSize.getDefiningOp();
+      });
+      return true;
+    } else {
+      // Directly overwrite the existing lifetime.
+      result.setType(newType);
+      return true;
+    }
   }
 
   // Updates the |result| type to the lifetime derived by analysis, if needed.
-  // Returns true if a change was made.
-  bool applyResultTransition(Operation *op, Value result, Value resultSize,
-                             Attribute affinityAttr,
+  // Returns true if a change was made. Same as above but for when we have the
+  // information available and don't need to insert the queries.
+  bool applyResultTransition(Value result, Value resultSize,
+                             IREE::Stream::AffinityAttr affinityAttr,
                              PatternRewriter &rewriter) const {
     auto oldType = result.getType().dyn_cast<IREE::Stream::ResourceType>();
     if (!oldType) return false;
     auto newUsage = analysis.lookupResourceUsage(result);
     auto newLifetime = convertUsageToLifetime(newUsage);
-    if (oldType.getLifetime() == newLifetime) return false;
     auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
-    result.setType(newType);
-    return true;
+    if (oldType == newType) {
+      // Old and new lifetimes match; no need to apply a transition.
+      return false;
+    } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+      // Transitioning from one lifetime to another; insert a transfer
+      // placeholder (as we may later decide it's ok to transition on a
+      // particular device).
+      auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+          result.getLoc(), newType, result, resultSize, resultSize,
+          /*source_affinity=*/affinityAttr,
+          /*target_affinity=*/affinityAttr);
+      result.replaceAllUsesExcept(transferOp.result(), transferOp);
+      return true;
+    } else {
+      // Directly overwrite the existing lifetime.
+      result.setType(newType);
+      return true;
+    }
   }
 
   // Updates all blocks argument lifetimes within the regions of |op|.
@@ -108,16 +174,9 @@
       for (auto &block : region) {
         rewriter.setInsertionPoint(&block, block.begin());
         for (auto &blockArg : block.getArguments()) {
-          auto oldType =
-              blockArg.getType().dyn_cast<IREE::Stream::ResourceType>();
-          if (!oldType) continue;
-          auto newUsage = analysis.lookupResourceUsage(blockArg);
-          auto newLifetime = convertUsageToLifetime(newUsage);
-          if (oldType.getLifetime() == newLifetime) return false;
-          auto newType =
-              rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
-          blockArg.setType(newType);
-          didChange = true;
+          if (applyArgTransition(blockArg, rewriter)) {
+            didChange = true;
+          }
         }
       }
     }
@@ -158,13 +217,16 @@
       auto oldType = inputType.value().dyn_cast<IREE::Stream::ResourceType>();
       if (!oldType) {
         newInputs.push_back(inputType.value());
-        continue;
+      } else if (oldType.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+        auto blockArg = op.getArgument(inputType.index());
+        auto newUsage = analysis.lookupResourceUsage(blockArg);
+        auto newLifetime = convertUsageToLifetime(newUsage);
+        auto newType =
+            rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+        newInputs.push_back(newType);
+      } else {
+        newInputs.push_back(oldType);
       }
-      auto blockArg = op.getArgument(inputType.index());
-      auto newUsage = analysis.lookupResourceUsage(blockArg);
-      auto newLifetime = convertUsageToLifetime(newUsage);
-      auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
-      newInputs.push_back(newType);
     }
 
     // Results:
@@ -174,13 +236,16 @@
       auto oldType = outputType.value().dyn_cast<IREE::Stream::ResourceType>();
       if (!oldType) {
         newOutputs.push_back(outputType.value());
-        continue;
+      } else if (oldType.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+        auto returnValue = anyReturnOp.getOperand(outputType.index());
+        auto newUsage = analysis.lookupResourceUsage(returnValue);
+        auto newLifetime = convertUsageToLifetime(newUsage);
+        auto newType =
+            rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+        newOutputs.push_back(newType);
+      } else {
+        newOutputs.push_back(oldType);
       }
-      auto returnValue = anyReturnOp.getOperand(outputType.index());
-      auto newUsage = analysis.lookupResourceUsage(returnValue);
-      auto newLifetime = convertUsageToLifetime(newUsage);
-      auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
-      newOutputs.push_back(newType);
     }
     auto newFuncType = rewriter.getFunctionType(newInputs, newOutputs);
     if (op.getType() != newFuncType) {
@@ -231,11 +296,7 @@
     // Walk into nested regions first so we have the final result types returned
     // by the regions.
     bool didChange = this->applyRegionTransitions(op, rewriter);
-    Attribute affinityAttr;
-    if (auto affinityOp =
-            dyn_cast<IREE::Stream::AffinityOpInterface>(op.getOperation())) {
-      affinityAttr = affinityOp.getAffinity();
-    }
+    auto affinityAttr = getOpAffinity(op);
 
     rewriter.startRootUpdate(op);
 
@@ -247,7 +308,7 @@
         continue;
       }
       auto resultSize = sizeAwareOp.getResultSize(i);
-      if (this->applyResultTransition(op, result, resultSize, affinityAttr,
+      if (this->applyResultTransition(result, resultSize, affinityAttr,
                                       rewriter)) {
         didChange = true;
       }
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir b/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
index 4b0e6a2..b94dccc 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
@@ -114,6 +114,26 @@
 
 // -----
 
+// Tests invalid transfer conflict resolution.
+// Constants cannot be mutated even when they are tied. This survives after
+// copy-on-write materialization because of the transfer and we need to preserve
+// it such that the copy is performed as expected.
+
+// CHECK-LABEL: @transferResolution
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<constant>, %[[SIZE:.+]]: index)
+// CHECK-SAME: -> !stream.resource<external>
+func @transferResolution(%arg0: !stream.resource<constant>, %size: index) -> !stream.resource<*> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[ARG0_EXT:.+]] = stream.async.transfer %[[ARG0]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
+  %arg0_any = stream.async.transfer %arg0 : !stream.resource<constant>{%size} -> !stream.resource<*>{%size}
+  // CHECK: %[[RET0:.+]] = stream.async.dispatch @ex::@dispatch[%c1, %c1, %c1](%[[ARG0_EXT]]) : (!stream.resource<external>{%[[SIZE]]}) -> %[[ARG0_EXT]]{%[[SIZE]]}
+  %ret0_any = stream.async.dispatch @ex::@dispatch[%c1, %c1, %c1](%arg0_any) : (!stream.resource<*>{%size}) -> %arg0_any{%size}
+  // CHECK: return %[[RET0]] : !stream.resource<external>
+  return %ret0_any : !stream.resource<*>
+}
+
+// -----
+
 // Tests that global usage propagates through loads/stores.
 
 util.global private mutable @variable : !stream.resource<variable>