Fixing lifetime transitions. (#7678)
This was showing up in some internal tests - the repro added here covers the core case hit by those.
diff --git a/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index fef5d08..f3a2db1 100644
--- a/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -85,6 +85,37 @@
static_assert(BEST_STATE == BaseType::getBestState(),
"unexpected BEST_STATE value");
+ static bool isValidState(uint16_t bits) {
+ // bool isIndirect = (bits & NOT_INDIRECT) != NOT_INDIRECT;
+ // bool isExternal = (bits & NOT_EXTERNAL) != NOT_EXTERNAL;
+ bool isMutated = (bits & NOT_MUTATED) != NOT_MUTATED;
+ bool isConstant = (bits & NOT_CONSTANT) != NOT_CONSTANT;
+ // bool isTransferRead = (bits & NOT_TRANSFER_READ) != NOT_TRANSFER_READ;
+ // bool isTransferWrite = (bits & NOT_TRANSFER_WRITE) != NOT_TRANSFER_WRITE;
+ bool isStagingRead = (bits & NOT_STAGING_READ) != NOT_STAGING_READ;
+ bool isStagingWrite = (bits & NOT_STAGING_WRITE) != NOT_STAGING_WRITE;
+ bool isDispatchRead = (bits & NOT_DISPATCH_READ) != NOT_DISPATCH_READ;
+ bool isDispatchWrite = (bits & NOT_DISPATCH_WRITE) != NOT_DISPATCH_WRITE;
+ // bool isGlobalRead = (bits & NOT_GLOBAL_READ) != NOT_GLOBAL_READ;
+ // bool isGlobalWrite = (bits & NOT_GLOBAL_WRITE) != NOT_GLOBAL_WRITE;
+
+ // Cannot be both staging and dispatch.
+ if ((isStagingRead || isStagingWrite) &&
+ (isDispatchRead || isDispatchWrite)) {
+ return false;
+ }
+
+ // Cannot be both constant and mutated.
+ // TODO(benvanik): maybe allow this for initializers that perform dispatches
+ // to initialize the resources. This introduces copies of those that are
+ // annoying to elide later on.
+ if (isConstant && isMutated) {
+ return false;
+ }
+
+ return true;
+ }
+
ResourceUsageBitfield convertBitsToResourceUsage(uint16_t bits) const {
return static_cast<ResourceUsageBitfield>(~bits & BEST_STATE);
}
@@ -232,9 +263,17 @@
default:
break;
}
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.result()),
+ DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
})
.Case([&](IREE::Util::GlobalLoadIndirectOp op) {
removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_READ);
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.result()),
+ DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::ResourceStoreOp op) {
removeAssumedBits(NOT_STAGING_WRITE);
@@ -245,12 +284,24 @@
})
.Case([&](IREE::Stream::TensorImportOp op) {
removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.result()),
+ DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncConstantOp op) {
removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_WRITE);
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.result()),
+ DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncSplatOp op) {
removeAssumedBits(NOT_TRANSFER_WRITE);
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.result()),
+ DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncCloneOp op) {
removeAssumedBits(NOT_TRANSFER_WRITE);
@@ -334,6 +385,10 @@
*this, Position::forValue(tiedOperand),
DFX::Resolution::REQUIRED);
getState() ^= tiedUsage.getState();
+ } else {
+ auto resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(result), DFX::Resolution::REQUIRED);
+ getState() ^= resultUsage.getState();
}
})
.Default([&](Operation *op) {});
@@ -354,22 +409,9 @@
DFX::Resolution::REQUIRED);
getState() ^= resultUsage.getState();
})
- .Case([&](mlir::BranchOp op) {
+ .Case([&](mlir::BranchOpInterface op) {
auto operandUsage = solver.getElementFor<ValueResourceUsage>(
- *this, Position::forValue(op.getOperand(operandIdx)),
- DFX::Resolution::OPTIONAL);
- getState() ^= operandUsage.getState();
- solver.getExplorer().walkOutgoingBranchOperandArguments(
- op, operandIdx, [&](Block *targetBlock, BlockArgument arg) {
- auto argUsage = solver.getElementFor<ValueResourceUsage>(
- *this, Position::forValue(arg), DFX::Resolution::OPTIONAL);
- getState() ^= argUsage;
- return WalkResult::advance();
- });
- })
- .Case([&](mlir::CondBranchOp op) {
- auto operandUsage = solver.getElementFor<ValueResourceUsage>(
- *this, Position::forValue(op.getOperand(operandIdx)),
+ *this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
solver.getExplorer().walkOutgoingBranchOperandArguments(
@@ -383,14 +425,13 @@
.Case([&](mlir::ReturnOp op) {
auto operandUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getOperand(operandIdx)),
- DFX::Resolution::OPTIONAL);
+ DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
solver.getExplorer().walkIncomingCalls(
op->getParentOfType<mlir::CallableOpInterface>(),
[&](mlir::CallOpInterface callOp) {
auto argUsage = solver.getElementFor<ValueResourceUsage>(
- *this,
- Position::forValue(callOp.getArgOperands()[operandIdx]),
+ *this, Position::forValue(callOp->getResult(operandIdx)),
DFX::Resolution::OPTIONAL);
getState() ^= argUsage;
return WalkResult::advance();
@@ -546,6 +587,13 @@
if (traversalResult == TraversalResult::INCOMPLETE) {
removeAssumedBits(NOT_EXTERNAL);
}
+
+ // Filter out impossible states by marking the state invalid.
+ // The fixpoint framework will try again.
+ if (!isValidState(assumedBits)) {
+ return indicatePessimisticFixpoint();
+ }
+
return assumedBits == getAssumed() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
index d35a7d2..fca1be7 100644
--- a/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
@@ -61,6 +61,14 @@
}
}
+// Returns either the affinity of |op| or nullptr.
+static IREE::Stream::AffinityAttr getOpAffinity(Operation *op) {
+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ return affinityOp.getAffinity();
+ }
+ return {};
+}
+
// Base pattern type for resource usage refinement.
// The results of the usage analysis are available for use by subclasses.
template <typename OpT>
@@ -70,6 +78,27 @@
ResourceUsageAnalysis &analysis;
+ // Updates the |arg| type to the lifetime derived by analysis, if needed.
+ // Returns true if a change was made.
+ bool applyArgTransition(BlockArgument arg, PatternRewriter &rewriter) const {
+ auto oldType = arg.getType().dyn_cast<IREE::Stream::ResourceType>();
+ if (!oldType) return false;
+ auto newUsage = analysis.lookupResourceUsage(arg);
+ auto newLifetime = convertUsageToLifetime(newUsage);
+ auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+ if (oldType == newType) {
+ // Old and new lifetimes match; no need to apply a transition.
+ return false;
+ } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+ // Transitioning lifetimes; rely on users to insert the transitions.
+ return false;
+ } else {
+ // Directly overwrite the existing lifetime.
+ arg.setType(newType);
+ return true;
+ }
+ }
+
// Updates the |result| type to the lifetime derived by analysis, if needed.
// Returns true if a change was made.
bool applyResultTransition(Operation *op, Value result,
@@ -78,25 +107,62 @@
if (!oldType) return false;
auto newUsage = analysis.lookupResourceUsage(result);
auto newLifetime = convertUsageToLifetime(newUsage);
- if (oldType.getLifetime() == newLifetime) return false;
auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
- result.setType(newType);
- return true;
+ if (oldType == newType) {
+ // Old and new lifetimes match; no need to apply a transition.
+ return false;
+ } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+ // Transitioning from one lifetime to another; insert a transfer
+ // placeholder (as we may later decide it's ok to transition on a
+ // particular device).
+ auto resultSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ op->getLoc(), result);
+ auto affinityAttr = getOpAffinity(op);
+ auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ op->getLoc(), newType, result, resultSize, resultSize,
+ /*source_affinity=*/affinityAttr,
+ /*target_affinity=*/affinityAttr);
+ result.replaceUsesWithIf(transferOp.result(), [&](OpOperand &operand) {
+ return operand.getOwner() != transferOp &&
+ operand.getOwner() != resultSize.getDefiningOp();
+ });
+ return true;
+ } else {
+ // Directly overwrite the existing lifetime.
+ result.setType(newType);
+ return true;
+ }
}
// Updates the |result| type to the lifetime derived by analysis, if needed.
- // Returns true if a change was made.
- bool applyResultTransition(Operation *op, Value result, Value resultSize,
- Attribute affinityAttr,
+ // Returns true if a change was made. Same as above but for when we have the
+ // information available and don't need to insert the queries.
+ bool applyResultTransition(Value result, Value resultSize,
+ IREE::Stream::AffinityAttr affinityAttr,
PatternRewriter &rewriter) const {
auto oldType = result.getType().dyn_cast<IREE::Stream::ResourceType>();
if (!oldType) return false;
auto newUsage = analysis.lookupResourceUsage(result);
auto newLifetime = convertUsageToLifetime(newUsage);
- if (oldType.getLifetime() == newLifetime) return false;
auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
- result.setType(newType);
- return true;
+ if (oldType == newType) {
+ // Old and new lifetimes match; no need to apply a transition.
+ return false;
+ } else if (oldType.getLifetime() != IREE::Stream::Lifetime::Unknown) {
+ // Transitioning from one lifetime to another; insert a transfer
+ // placeholder (as we may later decide it's ok to transition on a
+ // particular device).
+ auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ result.getLoc(), newType, result, resultSize, resultSize,
+ /*source_affinity=*/affinityAttr,
+ /*target_affinity=*/affinityAttr);
+ result.replaceAllUsesExcept(transferOp.result(), transferOp);
+ return true;
+ } else {
+ // Directly overwrite the existing lifetime.
+ result.setType(newType);
+ return true;
+ }
}
// Updates all blocks argument lifetimes within the regions of |op|.
@@ -108,16 +174,9 @@
for (auto &block : region) {
rewriter.setInsertionPoint(&block, block.begin());
for (auto &blockArg : block.getArguments()) {
- auto oldType =
- blockArg.getType().dyn_cast<IREE::Stream::ResourceType>();
- if (!oldType) continue;
- auto newUsage = analysis.lookupResourceUsage(blockArg);
- auto newLifetime = convertUsageToLifetime(newUsage);
- if (oldType.getLifetime() == newLifetime) return false;
- auto newType =
- rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
- blockArg.setType(newType);
- didChange = true;
+ if (applyArgTransition(blockArg, rewriter)) {
+ didChange = true;
+ }
}
}
}
@@ -158,13 +217,16 @@
auto oldType = inputType.value().dyn_cast<IREE::Stream::ResourceType>();
if (!oldType) {
newInputs.push_back(inputType.value());
- continue;
+ } else if (oldType.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+ auto blockArg = op.getArgument(inputType.index());
+ auto newUsage = analysis.lookupResourceUsage(blockArg);
+ auto newLifetime = convertUsageToLifetime(newUsage);
+ auto newType =
+ rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+ newInputs.push_back(newType);
+ } else {
+ newInputs.push_back(oldType);
}
- auto blockArg = op.getArgument(inputType.index());
- auto newUsage = analysis.lookupResourceUsage(blockArg);
- auto newLifetime = convertUsageToLifetime(newUsage);
- auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
- newInputs.push_back(newType);
}
// Results:
@@ -174,13 +236,16 @@
auto oldType = outputType.value().dyn_cast<IREE::Stream::ResourceType>();
if (!oldType) {
newOutputs.push_back(outputType.value());
- continue;
+ } else if (oldType.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+ auto returnValue = anyReturnOp.getOperand(outputType.index());
+ auto newUsage = analysis.lookupResourceUsage(returnValue);
+ auto newLifetime = convertUsageToLifetime(newUsage);
+ auto newType =
+ rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
+ newOutputs.push_back(newType);
+ } else {
+ newOutputs.push_back(oldType);
}
- auto returnValue = anyReturnOp.getOperand(outputType.index());
- auto newUsage = analysis.lookupResourceUsage(returnValue);
- auto newLifetime = convertUsageToLifetime(newUsage);
- auto newType = rewriter.getType<IREE::Stream::ResourceType>(newLifetime);
- newOutputs.push_back(newType);
}
auto newFuncType = rewriter.getFunctionType(newInputs, newOutputs);
if (op.getType() != newFuncType) {
@@ -231,11 +296,7 @@
// Walk into nested regions first so we have the final result types returned
// by the regions.
bool didChange = this->applyRegionTransitions(op, rewriter);
- Attribute affinityAttr;
- if (auto affinityOp =
- dyn_cast<IREE::Stream::AffinityOpInterface>(op.getOperation())) {
- affinityAttr = affinityOp.getAffinity();
- }
+ auto affinityAttr = getOpAffinity(op);
rewriter.startRootUpdate(op);
@@ -247,7 +308,7 @@
continue;
}
auto resultSize = sizeAwareOp.getResultSize(i);
- if (this->applyResultTransition(op, result, resultSize, affinityAttr,
+ if (this->applyResultTransition(result, resultSize, affinityAttr,
rewriter)) {
didChange = true;
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir b/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
index 4b0e6a2..b94dccc 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir
@@ -114,6 +114,26 @@
// -----
+// Tests invalid transfer conflict resolution.
+// Constants cannot be mutated even though it is tied. This survives after
+// copy-on-write materialization because of the transfer and we need to preserve
+// it such that the copy is performed as expected.
+
+// CHECK-LABEL: @transferResolution
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<constant>, %[[SIZE:.+]]: index)
+// CHECK-SAME: -> !stream.resource<external>
+func @transferResolution(%arg0: !stream.resource<constant>, %size: index) -> !stream.resource<*> {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[ARG0_EXT:.+]] = stream.async.transfer %[[ARG0]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
+ %arg0_any = stream.async.transfer %arg0 : !stream.resource<constant>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[RET0:.+]] = stream.async.dispatch @ex::@dispatch[%c1, %c1, %c1](%[[ARG0_EXT]]) : (!stream.resource<external>{%[[SIZE]]}) -> %[[ARG0_EXT]]{%[[SIZE]]}
+ %ret0_any = stream.async.dispatch @ex::@dispatch[%c1, %c1, %c1](%arg0_any) : (!stream.resource<*>{%size}) -> %arg0_any{%size}
+ // CHECK: return %[[RET0]] : !stream.resource<external>
+ return %ret0_any : !stream.resource<*>
+}
+
+// -----
+
// Tests that global usage propagates through loads/stores.
util.global private mutable @variable : !stream.resource<variable>