Handling AffinityOpInterface on stream.async.transfer.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
index 395673c..67bf383 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
@@ -166,8 +166,9 @@
}
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- if (!affinityOp.getAffinity())
- affinityOp.setAffinity(affinityAttr);
+ if (!affinityOp.getAffinityAttr()) {
+ affinityOp.setAffinityAttr(affinityAttr);
+ }
} else {
if (!op.hasAttr(affinityName)) {
op.setAttr(affinityName, affinityAttr);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
index 5ed2ff8..93fcd37 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
@@ -58,9 +58,9 @@
for (auto *op : ops) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
if (!IREE::Stream::AffinityAttr::areCompatible(
- affinity, affinityOp.getAffinity())) {
+ affinity, affinityOp.getAffinityAttr())) {
return op->emitError("op affinity ")
- << affinityOp.getAffinity()
+ << affinityOp.getAffinityAttr()
<< " is not compatible with the partition affinity " << affinity;
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
index b86ff61..a4fff96 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
@@ -54,8 +54,8 @@
DenseSet<Operation *> clonedOps;
void insert(Operation *op) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- affinity = affinity ? affinity.joinAND(affinityOp.getAffinity())
- : affinityOp.getAffinity();
+ affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
+ : affinityOp.getAffinityAttr();
}
ops.insert(op);
}
@@ -109,7 +109,7 @@
IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- affinityAttr = affinityOp.getAffinity();
+ affinityAttr = affinityOp.getAffinityAttr();
}
LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 8e4d854..bdc5aaf 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -25,13 +25,13 @@
// size of operands must be queried from the input resource.
static Value buildResultSizeOf(Location loc, Value tensorValue,
ValueRange dynamicDims,
+ IREE::Stream::AffinityAttr affinityAttr,
ConversionPatternRewriter &rewriter) {
// TODO(benvanik): see if we can stash this on the side to avoid expensive
// materialization of a bunch of redundant IR.
return rewriter.create<IREE::Stream::TensorSizeOfOp>(
loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()),
- dynamicDims,
- IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp()));
+ dynamicDims, affinityAttr);
}
struct ConvertTensorConstantOp
@@ -123,13 +123,14 @@
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto source =
consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ affinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, op.getResult().getType(),
- adaptor.getResultDims(), resultSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getResultDims(), resultSize, affinityAttr);
return success();
}
};
@@ -141,10 +142,12 @@
matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ affinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncAllocaOp>(
- op, unknownType, resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ op, unknownType, resultSize, affinityAttr);
return success();
}
};
@@ -156,11 +159,13 @@
matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ affinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorEmptyOp>(
op, unknownType, op.getResult().getType(), adaptor.getResultDims(),
- resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ resultSize, affinityAttr);
return success();
}
};
@@ -172,12 +177,13 @@
matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ affinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>(
op, unknownType, adaptor.getValue(), op.getResult().getType(),
- adaptor.getResultDims(), resultSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getResultDims(), resultSize, affinityAttr);
return success();
}
};
@@ -230,13 +236,15 @@
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto source =
consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ affinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSliceOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(),
adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(),
- resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ resultSize, affinityAttr);
return success();
}
};
@@ -676,6 +684,8 @@
LogicalResult
matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -723,7 +733,8 @@
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, rewriter));
+ resultDynamicDims, affinityAttr,
+ rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -732,7 +743,7 @@
op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(),
dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets,
dispatchOperandEnds, dispatchOperandLengths, resultSizes,
- adaptor.getTiedOperandsAttr(), IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getTiedOperandsAttr(), affinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
@@ -778,6 +789,8 @@
LogicalResult
matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -825,7 +838,8 @@
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, rewriter));
+ resultDynamicDims, affinityAttr,
+ rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -834,7 +848,7 @@
op, resultTypes, adaptor.getCalleeAttr(), callOperands,
callOperandSizes, callOperandOffsets, callOperandEnds,
callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(),
- IREE::Stream::AffinityAttr::lookup(op));
+ affinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
index 35eb31f..4323473 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -49,11 +49,7 @@
}
}
- auto affinityAttr =
- dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
- if (!affinityAttr) {
- affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
- }
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
// Import (buffer view to stream resource).
auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
@@ -138,11 +134,7 @@
return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
}
- auto affinityAttr =
- dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
- if (!affinityAttr) {
- affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
- }
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto source =
consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
index f34003d..f17b753 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
@@ -145,9 +145,10 @@
Returns the stream affinity for the op, indicating where it should run.
}],
/*retTy=*/"IREE::Stream::AffinityAttr",
- /*methodName=*/"getAffinity",
+ /*methodName=*/"getAffinityAttr",
/*args=*/(ins),
- /*methodBody=*/[{
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return dyn_cast_or_null<IREE::Stream::AffinityAttr>($_self->getAttr("affinity"));
}]
>,
@@ -156,9 +157,10 @@
Sets the stream affinity for the op, indicating where it should run.
}],
/*retTy=*/"void",
- /*methodName=*/"setAffinity",
+ /*methodName=*/"setAffinityAttr",
/*args=*/(ins "IREE::Stream::AffinityAttr":$value),
- /*methodBody=*/[{
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
if (value) $_self->setAttr("affinity", value);
else $_self->removeAttr("affinity");
}]
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index d9da282..698c7b9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2019,6 +2019,17 @@
return success();
}
+IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
+ return getResultAffinityAttr();
+}
+
+void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) {
+ if (value)
+ setResultAffinityAttr(value);
+ else
+ removeResultAffinityAttr();
+}
+
void AsyncTransferOp::getAsyncAccessRanges(
SmallVectorImpl<AsyncAccessRange> &ranges) {
ranges.push_back({ResourceAccessBitfield::Read, getSource(), Value{},
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 99e793a..dbe5207 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -86,10 +86,7 @@
let opDocGroup = OpGroupResourceOps in {
def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Util_SizeAwareOp,
AlwaysSpeculatable,
MemoryEffects<[MemAlloc]>,
@@ -148,10 +145,7 @@
}
def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_TimelineOp,
Util_SizeAwareOp,
AlwaysSpeculatable,
@@ -209,10 +203,7 @@
}
def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_TimelineOp,
Util_SizeAwareOp,
MemoryEffects<[MemFree]>,
@@ -645,10 +636,7 @@
def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [
AttrSizedOperandSegments,
AllTypesMatch<["results"]>,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -702,10 +690,7 @@
}
def Stream_ParameterReadOp : Stream_Op<"parameter.read", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -757,10 +742,7 @@
}
def Stream_ParameterWriteOp : Stream_Op<"parameter.write", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -813,10 +795,7 @@
def Stream_ParameterGatherOp : Stream_Op<"parameter.gather", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -872,10 +851,7 @@
def Stream_ParameterScatterOp : Stream_Op<"parameter.scatter", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -982,10 +958,7 @@
}
def Stream_FileReadOp : Stream_Op<"file.read", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -1040,10 +1013,7 @@
}
def Stream_FileWriteOp : Stream_Op<"file.write", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -1783,10 +1753,7 @@
let opDocGroup = OpGroupAsyncOps in {
def Stream_AsyncAllocaOp : Stream_Op<"async.alloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_AsyncPhaseOp,
DeclareOpInterfaceMethods<Stream_StreamableOp, [
"isMetadata",
@@ -2250,7 +2217,10 @@
}
def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
- Stream_AffinityOp,
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinityAttr",
+ "setAffinityAttr",
+ ]>,
Stream_AsyncPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_AsyncAccessOp, [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 7ce79c9..19c2410 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -274,7 +274,7 @@
return attr;
// See if the affinity specified provides a resource configuration.
if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
- auto affinityAttr = affinityOp.getAffinity();
+ auto affinityAttr = affinityOp.getAffinityAttr();
if (affinityAttr) {
auto attr = affinityAttr.getResourceConfigAttr();
if (attr)
@@ -339,7 +339,7 @@
auto attrId = StringAttr::get(op->getContext(), "stream.affinity");
while (op) {
if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
- auto affinity = affinityOp.getAffinity();
+ auto affinity = affinityOp.getAffinityAttr();
if (affinity)
return affinity;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index aa5cb25..b0b66ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -46,6 +46,7 @@
static Value buildTensorImportOp(Location loc, Value sourceTensor,
Type targetType,
SmallPtrSetImpl<Operation *> &consumingOps,
+ IREE::Stream::AffinityAttr affinityAttr,
OpBuilder &builder) {
// Gather dynamic dimensions from the input value.
auto dynamicDims =
@@ -56,8 +57,7 @@
// a transfer operation that may need to reformat the tensor.
auto encodingAttr = TypeAttr::get(sourceTensor.getType());
Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>(
- loc, builder.getIndexType(), encodingAttr, dynamicDims,
- /*affinity=*/nullptr);
+ loc, builder.getIndexType(), encodingAttr, dynamicDims, affinityAttr);
// Associate the external SSA value, encoding, and shape information with the
// stream resource. When lowering we'll then have all the metadata required
@@ -66,7 +66,7 @@
IREE::Stream::Lifetime::External);
auto importOp = builder.create<IREE::Stream::TensorImportOp>(
loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
- /*affinity=*/nullptr);
+ affinityAttr);
consumingOps.insert(importOp);
// If needed insert a transfer to the target lifetime.
@@ -75,8 +75,8 @@
result = builder
.create<IREE::Stream::AsyncTransferOp>(
loc, targetType, result, resultSize, resultSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr)
+ /*source_affinity=*/affinityAttr,
+ /*result_affinity=*/affinityAttr)
.getResult();
}
@@ -90,6 +90,7 @@
// external tensor value.
static Value buildTensorExportOp(Location loc, Value sourceValue,
TensorType targetType, ValueRange dynamicDims,
+ IREE::Stream::AffinityAttr affinityAttr,
OpBuilder &builder) {
auto source = consumeTensorOperand(loc, sourceValue, builder);
@@ -101,14 +102,13 @@
loc, externalType, source.resource, source.resourceSize,
source.resourceSize,
/*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
+ /*result_affinity=*/affinityAttr);
}
// Associate the stream resource and external encoding and shape information.
auto newOp = builder.create<IREE::Stream::TensorExportOp>(
loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims,
- source.resourceSize,
- /*affinity=*/nullptr);
+ source.resourceSize, affinityAttr);
return newOp.getResult();
}
@@ -141,6 +141,8 @@
if (!doesOperationNeedWrapping(op))
return failure();
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+
// Export resources into tensor operands for the op to consume.
SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands());
@@ -156,8 +158,9 @@
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
op->getLoc(), oldOperand, rewriter);
- newOperands.push_back(buildTensorExportOp(
- op->getLoc(), newOperand, tensorType, dynamicDims, rewriter));
+ newOperands.push_back(buildTensorExportOp(op->getLoc(), newOperand,
+ tensorType, dynamicDims,
+ affinityAttr, rewriter));
}
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
@@ -173,7 +176,7 @@
SmallPtrSet<Operation *, 4> consumingOps;
auto importedValue = buildTensorImportOp(
op->getLoc(), result, rewriter.getType<IREE::Stream::ResourceType>(),
- consumingOps, rewriter);
+ consumingOps, affinityAttr, rewriter);
result.replaceAllUsesExcept(importedValue, consumingOps);
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
index 34c0ef8..5b7d3a9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
@@ -103,7 +103,7 @@
IREE::Stream::AffinityAttr affinity;
if (auto affinityOp =
dyn_cast<IREE::Stream::AffinityOpInterface>(tiedOp.getOperation())) {
- affinity = affinityOp.getAffinity();
+ affinity = affinityOp.getAffinityAttr();
}
// Clones each operand that is tied to a result and it may be required.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
index bc73616..02c2bb0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
@@ -65,7 +65,7 @@
// Returns either the affinity of |op| or nullptr.
static IREE::Stream::AffinityAttr getOpAffinity(Operation *op) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- return affinityOp.getAffinity();
+ return affinityOp.getAffinityAttr();
}
return {};
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
index 9c5e3d4..c850c3b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
@@ -152,8 +152,8 @@
// want to preserve those as long as possible.
if (auto affinityOp =
dyn_cast<IREE::Stream::AffinityOpInterface>(clonedOp)) {
- if (affinityOp.getAffinity() == partition->affinity) {
- affinityOp.setAffinity(nullptr);
+ if (affinityOp.getAffinityAttr() == partition->affinity) {
+ affinityOp.setAffinityAttr(nullptr);
}
}
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
index 7bbd7f5..66643e6 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
@@ -29,6 +29,7 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
index 4e2f29a..a63fca3 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
@@ -34,6 +34,7 @@
MLIRValueBoundsOpInterface
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
index e3ba257..b82d599 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
@@ -6,6 +6,10 @@
#include "iree/compiler/ExternalInterfaces/StreamExternalModels.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
@@ -14,27 +18,47 @@
namespace {
-template <typename OpT>
-struct AffinityOpAttrExternalModel
+struct FlowTransferTargetAffinityAttrExternalModel
: public IREE::Stream::AffinityOpInterface::ExternalModel<
- AffinityOpAttrExternalModel<OpT>, OpT> {
+ FlowTransferTargetAffinityAttrExternalModel,
+ IREE::Flow::TensorTransferOp> {
static void add(MLIRContext *context) {
- OpT::template attachInterface<AffinityOpAttrExternalModel<OpT>>(*context);
+ IREE::Flow::TensorTransferOp::attachInterface<
+ FlowTransferTargetAffinityAttrExternalModel>(*context);
}
- // Most structural ops don't require affinities and after placement we don't
- // use the affinities even if the ops still exist.
+ bool requiresAffinity(Operation *op) const { return true; }
+
+ IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("target");
+ }
+
+ void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
+ op->setAttr("target", value);
+ }
+};
+
+template <typename OpT>
+struct HALTensorAffinityAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ HALTensorAffinityAttrExternalModel<OpT>, OpT> {
+ static void add(MLIRContext *context) {
+ OpT::template attachInterface<HALTensorAffinityAttrExternalModel<OpT>>(
+ *context);
+ }
+
bool requiresAffinity(Operation *op) const { return false; }
IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
- return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("affinity");
}
void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
if (value)
- op->setAttr("stream.affinity", value);
- else
- op->removeAttr("stream.affinity");
+ op->setAttr("affinity", value);
+ } else {
+ op->removeAttr("affinity");
+ }
}
};
@@ -61,17 +85,58 @@
void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
if (value)
op->setAttr("stream.affinity", value);
- else
+ } else {
op->removeAttr("stream.affinity");
+ }
+ }
+};
+
+template <typename OpT>
+struct AffinityOpAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ AffinityOpAttrExternalModel<OpT, kRequiresAffinity>, OpT> {
+ static void add(MLIRContext *context) {
+ OpT::template attachInterface<
+ AffinityOpAttrExternalModel<OpT, kRequiresAffinity>>(*context);
+ }
+
+ // Most structural ops don't require affinities and after placement we don't
+ // use the affinities even if the ops still exist.
+ bool requiresAffinity(Operation *op) const { return false; }
+
+ IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
+ }
+
+ void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value)
+ op->setAttr("stream.affinity", value);
+ } else {
+ op->removeAttr("stream.affinity");
+ }
}
};
} // namespace
void registerStreamExternalModels(DialectRegistry ®istry) {
- // Must ensure that any dependent dialects are registered.
- registry.insert<IREE::Util::UtilDialect>();
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.addExtension(
+ +[](MLIRContext *context, IREE::Flow::FlowDialect *dialect) {
+ FlowTransferTargetAffinityAttrExternalModel::add(context);
+ });
+ registry.insert<IREE::HAL::HALDialect>();
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::HAL::HALDialect *dialect) {
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorImportOp>::add(context);
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorExportOp>::add(context);
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorAliasOp>::add(context);
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorBarrierOp>::add(
+ context);
+ });
+
+ registry.insert<IREE::Util::UtilDialect>();
registry.addExtension(
+[](MLIRContext *context, IREE::Util::UtilDialect *dialect) {
GlobalOpAffinityAttrExternalModel<IREE::Util::GlobalOp>::add(context);