Add iree.abi.affinity arg/result attrs to the native ABI.
These map to an opaque affinity on the hal.tensor.import/export ops and
act as a seed for placement when lowering into the stream dialect.
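
For example (an illustrative sketch: the affinity payload is deliberately
opaque, declared as OptionalAttr<AnyAttr>, so the #hal.device.promise value
below is only a placeholder for whatever attribute a target understands),
an entry point may be annotated as:

  func.func @predict(
      %input: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@dev_a>})
      -> (tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@dev_b>})

The entry point wrappers then carry the attr through to the marshaling ops
via the new `on(...)` assembly syntax added below:

  // Placeholder affinity value; any attribute may be used here.
  %t = hal.tensor.import on(#hal.device.promise<@dev_a>)
      wait(%fence) => %input "input0"
      : !hal.buffer_view -> tensor<4xf32>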
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index 4bdd5c4..b8a630d 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -292,13 +292,15 @@
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
- barrierInputDims, exportStorage, waitFence));
+ barrierInputDims, exportStorage, waitFence,
+ /*affinity=*/nullptr));
} else {
aliasedResults.push_back(barrierInput);
}
}
auto barrierOp = postambleBuilder.create<IREE::HAL::TensorBarrierOp>(
- funcOp.getLoc(), aliasedResults, coarseSignalFence);
+ funcOp.getLoc(), aliasedResults, coarseSignalFence,
+ /*affinity=*/nullptr);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
Value exportStorage;
@@ -308,7 +310,8 @@
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
- TypeAttr::get(barrierResult.getType()), StringAttr());
+ TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
if (returnIndex >= 0) {
newReturnOperands[returnIndex] = exportedValue;
}
@@ -380,7 +383,8 @@
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
- /*name=*/StringAttr());
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
@@ -415,7 +419,8 @@
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
- /*name=*/StringAttr());
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index bd2702f..e91bc54 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -26,9 +26,9 @@
static IREE::ABI::InvocationModel
getInvocationModel(Operation *op, IREE::ABI::InvocationModel defaultModel) {
auto modelAttr = op->getAttrOfType<StringAttr>("iree.abi.model");
- if (!modelAttr)
+ if (!modelAttr) {
return defaultModel;
- if (modelAttr == "coarse-fences") {
+ } else if (modelAttr == "coarse-fences") {
return IREE::ABI::InvocationModel::CoarseFences;
} else {
return IREE::ABI::InvocationModel::Sync;
@@ -51,7 +51,8 @@
for (auto attr : attrDict) {
// TODO(benvanik): faster lookup.
if (attr.getName() != "iree.abi.output" &&
- attr.getName() != "iree.abi.encoding") {
+ attr.getName() != "iree.abi.encoding" &&
+ attr.getName() != "iree.abi.affinity") {
attrs.push_back(attr);
}
}
@@ -59,6 +60,17 @@
}
}
+static void stripABIAttrs(FunctionOpInterface op) {
+ SmallVector<DictionaryAttr> argAttrs;
+ op.getAllArgAttrs(argAttrs);
+ stripABIAttrs(argAttrs);
+ op.setAllArgAttrs(argAttrs);
+ SmallVector<DictionaryAttr> resultAttrs;
+ op.getAllResultAttrs(resultAttrs);
+ stripABIAttrs(resultAttrs);
+ op.setAllResultAttrs(resultAttrs);
+}
+
// Creates the corresponding wrapper function for the given import function.
static IREE::Util::FuncOp
createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel,
@@ -150,7 +162,7 @@
importOp.getLoc(), entryBuilder.getType<IREE::HAL::FenceType>(),
device, IREE::HAL::FenceFlagBitfield::None);
auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>(
- importOp.getLoc(), tensorArgs, waitFence);
+ importOp.getLoc(), tensorArgs, waitFence, /*affinity=*/nullptr);
for (auto [argIndex, readyArg] :
llvm::zip_equal(tensorArgIndices, barrierOp.getResults())) {
entryArgs[argIndex] = readyArg;
@@ -187,20 +199,24 @@
// NOTE: we insert a barrier on this above if needed so that the wait
// fence will be signaled when the tensor is ready for consumption by the
// import.
- auto encoding =
+ auto encodingAttr =
importOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
- auto exportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
+ auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
arg.getLoc(), newType, arg,
- encoding ? encoding : TypeAttr::get(oldType), /*name=*/nullptr);
- arguments.push_back(exportOp.getTarget());
+ encodingAttr ? encodingAttr : TypeAttr::get(oldType),
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
+ arguments.push_back(tensorExportOp.getTarget());
} else {
arguments.push_back(arg);
}
}
- if (waitFence)
+ if (waitFence) {
arguments.push_back(waitFence);
- if (signalFence)
+ }
+ if (signalFence) {
arguments.push_back(signalFence);
+ }
// Make the call with the updated types.
auto callOp = entryBuilder.create<IREE::Util::CallOp>(importOp.getLoc(),
@@ -225,12 +241,14 @@
// NOTE: we set the import pending on the signal fence from the import
// indicating when the returned tensor is ready for consumption by the
// program.
- auto encoding = importOp.getResultAttrOfType<TypeAttr>(
+ auto encodingAttr = importOp.getResultAttrOfType<TypeAttr>(
resultIndex, "iree.abi.encoding");
- results.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
+ auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
importOp.getLoc(), oldType, result,
- encoding ? encoding : TypeAttr::get(oldType), signalFence,
- /*name=*/nullptr));
+ encodingAttr ? encodingAttr : TypeAttr::get(oldType), signalFence,
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
+ results.push_back(tensorImportOp);
} else {
results.push_back(result);
}
@@ -285,8 +303,9 @@
auto wrapperOp = createImportWrapperFunc(
invocationModel, importOp, cast<FunctionType>(importOp.getFunctionType()),
newImportType, privateName);
- if (!wrapperOp)
+ if (!wrapperOp) {
return failure();
+ }
moduleOp.insert(++Block::iterator(importOp), wrapperOp);
// Update the import to the new type and mark it as being converted so we
@@ -302,15 +321,17 @@
static StringAttr inferArgumentName(MLIRContext *context, int index,
DictionaryAttr attrs) {
- if (auto attrName = getNameFromDictAttr(attrs))
+ if (auto attrName = getNameFromDictAttr(attrs)) {
return attrName;
+ }
return StringAttr::get(context, "input" + std::to_string(index));
}
static StringAttr inferResultName(MLIRContext *context, int index,
DictionaryAttr attrs) {
- if (auto attrName = getNameFromDictAttr(attrs))
+ if (auto attrName = getNameFromDictAttr(attrs)) {
return attrName;
+ }
return StringAttr::get(context, "output" + std::to_string(index));
}
@@ -326,8 +347,9 @@
auto shouldIncludeAttr = [](const NamedAttribute &attr) {
return attr.getName().getValue() != "iree.abi.name";
};
- if (!llvm::any_of(attrs, shouldIncludeAttr))
+ if (!llvm::any_of(attrs, shouldIncludeAttr)) {
return;
+ }
os << " {";
llvm::interleaveComma(llvm::make_filter_range(attrs, shouldIncludeAttr), os,
[&](auto argAttr) {
@@ -363,8 +385,9 @@
os << "func @" << publicName;
os << "(";
for (auto arg : exportOp.getArguments()) {
- if (arg.getArgNumber() > 0)
+ if (arg.getArgNumber() > 0) {
os << ", ";
+ }
os << "%";
os << inferArgumentName(exportOp.getContext(), arg.getArgNumber(),
getIOAttr(allArgAttrs, arg.getArgNumber()))
@@ -377,8 +400,9 @@
os << ") -> (";
for (auto [resultNumber, resultType] :
llvm::enumerate(exportOp.getResultTypes())) {
- if (resultNumber > 0)
+ if (resultNumber > 0) {
os << ", ";
+ }
os << "%";
os << inferResultName(exportOp.getContext(), resultNumber,
getIOAttr(allResultAttrs, resultNumber))
@@ -494,8 +518,9 @@
// Populate the reflection attrs based on the original types.
populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
exportOp->removeAttr("iree.reflection");
- if (auto affinityAttr = exportOp->getAttr("stream.affinity"))
+ if (auto affinityAttr = exportOp->getAttr("stream.affinity")) {
wrapperOp->setAttr("stream.affinity", affinityAttr);
+ }
auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
@@ -506,8 +531,9 @@
for (unsigned i = 0; i < exportOp.getNumArguments(); ++i) {
auto outputAttr =
exportOp.getArgAttrOfType<IntegerAttr>(i, "iree.abi.output");
- if (!outputAttr)
+ if (!outputAttr) {
continue;
+ }
// Today all outputs need to be a !hal.buffer - we could change this
// in the future to be something more generalized.
auto storageArg = entryBlock->getArgument(i);
@@ -544,14 +570,15 @@
entryBlock->getArguments().slice(0, oldExportType.getNumInputs()))) {
auto oldType = oldExportType.getInput(argIndex);
if (llvm::isa<TensorType>(oldType)) {
- auto encoding =
+ auto encodingAttr =
exportOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
- auto importOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
+ auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), oldType, arg,
- encoding ? encoding : TypeAttr::get(oldType), waitFence,
+ encodingAttr ? encodingAttr : TypeAttr::get(oldType), waitFence,
inferArgumentName(entryBuilder.getContext(), argIndex,
- exportOp.getArgAttrDict(argIndex)));
- arguments.push_back(importOp.getTarget());
+ exportOp.getArgAttrDict(argIndex)),
+ exportOp.getArgAttr(argIndex, "iree.abi.affinity"));
+ arguments.push_back(tensorImportOp.getTarget());
} else {
arguments.push_back(arg);
}
@@ -565,14 +592,16 @@
// Alias results to storage buffers if provided.
for (unsigned resultIndex = 0; resultIndex < asyncResults.size();
++resultIndex) {
- if (!resultStorages[resultIndex])
+ if (!resultStorages[resultIndex]) {
continue;
+ }
auto source = asyncResults[resultIndex];
auto sourceDims = IREE::Util::buildDynamicDimsForValue(
exportOp.getLoc(), source, entryBuilder);
auto aliasOp = entryBuilder.create<IREE::HAL::TensorAliasOp>(
exportOp.getLoc(), source.getType(), source, sourceDims,
- resultStorages[resultIndex], waitFence);
+ resultStorages[resultIndex], waitFence,
+ exportOp.getResultAttr(resultIndex, "iree.abi.affinity"));
asyncResults[resultIndex] = cast<OpResult>(aliasOp.getResult());
}
@@ -582,8 +611,9 @@
if (signalFence) {
SmallVector<Value> asyncTensors;
for (auto result : asyncResults) {
- if (llvm::isa<TensorType>(result.getType()))
+ if (llvm::isa<TensorType>(result.getType())) {
asyncTensors.push_back(result);
+ }
}
if (asyncTensors.empty()) {
// TODO(benvanik): maybe use a global timeline? global stores may not
@@ -592,7 +622,7 @@
signalFence);
} else {
auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>(
- exportOp.getLoc(), asyncTensors, signalFence);
+ exportOp.getLoc(), asyncTensors, signalFence, /*affinity=*/nullptr);
asyncResults = llvm::to_vector(barrierOp.getResults());
}
}
@@ -603,20 +633,25 @@
auto oldType = oldExportType.getResult(resultIndex);
auto newType = newExportType.getResult(resultIndex);
if (llvm::isa<TensorType>(oldType)) {
- auto encoding = exportOp.getResultAttrOfType<TypeAttr>(
+ auto encodingAttr = exportOp.getResultAttrOfType<TypeAttr>(
resultIndex, "iree.abi.encoding");
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
result.getLoc(), result, entryBuilder);
- results.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
+ auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), newType, result,
- encoding ? encoding : TypeAttr::get(result.getType()), dynamicDims,
+ encodingAttr ? encodingAttr : TypeAttr::get(result.getType()),
+ dynamicDims,
inferResultName(entryBuilder.getContext(), resultIndex,
- exportOp.getResultAttrDict(resultIndex))));
+ exportOp.getResultAttrDict(resultIndex)),
+ exportOp.getResultAttr(resultIndex, "iree.abi.affinity"));
+ results.push_back(tensorExportOp);
} else {
results.push_back(result);
}
}
+ stripABIAttrs(exportOp);
+
entryBuilder.create<IREE::Util::ReturnOp>(exportOp.getLoc(), results);
return wrapperOp;
}
@@ -643,8 +678,9 @@
// marshals arguments/results to the original function.
auto wrapperOp =
createExportWrapperFunc(invocationModel, exportOp, publicName);
- if (!wrapperOp)
+ if (!wrapperOp) {
return failure();
+ }
symbolTable.insert(wrapperOp, Block::iterator(exportOp));
return success();
diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index 3e8c0fd..d21bfe0 100644
--- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -227,7 +227,8 @@
auto dynamicDims = inputDynamicDims.loadDynamicDims(recalculateBuilder);
auto castOp = recalculateBuilder.create<IREE::HAL::TensorImportOp>(
loc, inputValue.getType(), inputPlaceholder, inputValue.getType(),
- dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr);
+ dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr);
inputValue.replaceAllUsesWith(castOp.getTarget());
}
while (entryBlock.getNumArguments() > 0) {
@@ -525,7 +526,8 @@
callOperands.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), inputDynamicDims.tensorType, arg,
TypeAttr::get(inputDynamicDims.tensorType), dynamicDims,
- /*wait_fence=*/Value{}, /*name=*/nullptr));
+ /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr));
}
auto callOp = entryBuilder.create<IREE::Util::CallOp>(
entryFuncOp.getLoc(), entryFuncOp, callOperands);
@@ -541,7 +543,7 @@
}
callResults.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), bufferType, result, outputDynamicDims.tensorType,
- dynamicDims, /*name=*/nullptr));
+ dynamicDims, /*name=*/nullptr, /*affinity=*/nullptr));
for (auto [dynamicDim, globalOp] :
llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) {
globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index a2079c3..d9a6a0f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -288,7 +288,6 @@
struct ElideRedundantOperandsOfWorkgroupCountFromSliceOp
: OpRewritePattern<DispatchWorkgroupsOp> {
using OpRewritePattern::OpRewritePattern;
-
LogicalResult matchAndRewrite(DispatchWorkgroupsOp op,
PatternRewriter &rewriter) const override {
Region &count = op.getWorkgroupCount();
@@ -369,7 +368,6 @@
// Bubble up the ordinal ops so that all uses go through this operation.
struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
using OpRewritePattern::OpRewritePattern;
-
LogicalResult matchAndRewrite(DispatchWorkloadOrdinalOp ordinalOp,
PatternRewriter &rewriter) const override {
auto blockArg = llvm::dyn_cast<BlockArgument>(ordinalOp.getOperand());
@@ -894,7 +892,6 @@
template <typename CastOpTy>
struct FlattenTensorCastLikeChain : public OpRewritePattern<CastOpTy> {
using OpRewritePattern<CastOpTy>::OpRewritePattern;
-
LogicalResult matchAndRewrite(CastOpTy reshapeOp,
PatternRewriter &rewriter) const override {
// We want the same result value/shape but to source from the ancestor. We
@@ -1294,7 +1291,6 @@
// to be updated to use the source of the cast as the target tensor.
struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> {
using OpRewritePattern<TensorUpdateOp>::OpRewritePattern;
-
LogicalResult matchAndRewrite(TensorUpdateOp updateOp,
PatternRewriter &rewriter) const override {
auto targetCastOp = updateOp.getTarget().getDefiningOp<tensor::CastOp>();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index bcb1dbb..cf1e341 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -351,8 +351,8 @@
// -----
-// CHECK-LABEL: @cloneConst
-util.func public @cloneConst() -> tensor<4xi32> {
+// CHECK-LABEL: @cloneConstant
+util.func public @cloneConstant() -> tensor<4xi32> {
// CHECK-NEXT: %[[C:.+]] = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%1 = flow.tensor.clone %0 : tensor<4xi32>
@@ -362,8 +362,8 @@
// -----
-// CHECK-LABEL: @cloneConstZeroElements
-util.func public @cloneConstZeroElements() -> tensor<0x2xi32> {
+// CHECK-LABEL: @cloneConstantZeroElements
+util.func public @cloneConstantZeroElements() -> tensor<0x2xi32> {
// CHECK-NEXT: %[[C:.+]] = arith.constant dense<> : tensor<0x2xi32>
%0 = arith.constant dense<> : tensor<0x2xi32>
// CHECK-NOT: flow.tensor.clone
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
index b0a19ad..8b01c00 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -7,6 +7,8 @@
util.return %0 : tensor<16xf32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeScalar
util.func public @tensorReshapeScalar(%arg0 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.reshape %arg0 : tensor<f32> -> tensor<f32>
@@ -14,6 +16,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeDynamic
util.func public @tensorReshapeDynamic(%arg0 : tensor<?x4xf32>) -> tensor<?x2xf32> {
%c4 = arith.constant 4 : index
@@ -23,6 +27,8 @@
util.return %0 : tensor<?x2xf32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeComplex
util.func public @tensorReshapeComplex(%arg0 : tensor<4x4xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
// CHECK-NEXT: flow.tensor.reshape %arg0 : tensor<4x4xcomplex<f32>> -> tensor<16xcomplex<f32>>
@@ -48,6 +54,8 @@
util.return %0 : f32
}
+// -----
+
// CHECK-LABEL: @tensorLoadScalar
util.func public @tensorLoadScalar(%arg0 : tensor<f32>) -> f32 {
// CHECK-NEXT: %0 = flow.tensor.load %arg0 : tensor<f32>
@@ -55,6 +63,8 @@
util.return %0 : f32
}
+// -----
+
// CHECK-LABEL: @tensorLoadDynamic
util.func public @tensorLoadDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index) -> f32 {
%c4 = arith.constant 4 : index
@@ -72,6 +82,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorStoreScalar
util.func public @tensorStoreScalar(%arg0 : f32, %arg1 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.store %arg0, %arg1 : tensor<f32>
@@ -79,6 +91,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorStoreDynamic
util.func public @tensorStoreDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index, %arg3 : f32) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -114,6 +128,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorSplatScalar
util.func public @tensorSplatScalar(%arg0 : f32) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<f32>
@@ -121,6 +137,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorSplatDynamic
util.func public @tensorSplatDynamic(%arg0 : f32) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -138,6 +156,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorCloneScalar
util.func public @tensorCloneScalar(%arg0 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor<f32>
@@ -145,6 +165,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorCloneDynamic
util.func public @tensorCloneDynamic(%arg0 : tensor<?x4xf32>) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -162,6 +184,8 @@
util.return %0 : tensor<2x2xf32>
}
+// -----
+
// CHECK-LABEL: @tensorSliceDynamic
util.func public @tensorSliceDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index) -> tensor<?x2xf32> {
%c2 = arith.constant 2 : index
@@ -180,6 +204,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorUpdateDynamic
util.func public @tensorUpdateDynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x4xf32>, %arg2 : index, %arg3 : index) -> tensor<?x4xf32> {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
index ecfb86a..3967d50 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
@@ -81,7 +81,8 @@
// hal.tensor.export
auto bufferExportOp = initializerBuilder.create<IREE::HAL::TensorExportOp>(
loc, globalOp.getType(), splatOp.getResult(),
- TypeAttr::get(splatOp.getType()), /*name=*/nullptr);
+ TypeAttr::get(splatOp.getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
// util.optimization_barrier (try to prevent optimizations across the export)
auto barrierOp = initializerBuilder.create<IREE::Util::OptimizationBarrierOp>(
loc, bufferExportOp.getTarget());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
index beb995d..49212b5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
@@ -103,7 +103,8 @@
if (llvm::isa<TensorType>(retVal.getType())) {
auto type = IREE::HAL::BufferViewType::get(context);
auto exportOp = builder.create<IREE::HAL::TensorExportOp>(
- loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr);
+ loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
exports.push_back(exportOp.getResult());
newTypes.push_back(type);
} else {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index d082847..23dab24 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -96,8 +96,8 @@
}
if (orderedSources.size() == op.getSources().size())
return failure();
- auto newOp = rewriter.create<TensorBarrierOp>(op.getLoc(), orderedSources,
- op.getSignalFence());
+ auto newOp = rewriter.create<TensorBarrierOp>(
+ op.getLoc(), orderedSources, op.getSignalFence(), op.getAffinityAttr());
SmallVector<Value> newResults;
newResults.reserve(newOp.getNumResults());
for (unsigned newIndex : resultMapping) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 07ff982..e28025e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -431,15 +431,16 @@
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
- TypeAttr targetEncoding, StringAttr name) {
+ TypeAttr targetEncoding, StringAttr name,
+ Attribute affinity) {
build(builder, result, resultType, source, targetEncoding,
- /*waitFence=*/Value{}, name);
+ /*waitFence=*/Value{}, name, affinity);
}
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
TypeAttr targetEncoding, Value waitFence,
- StringAttr name) {
+ StringAttr name, Attribute affinity) {
auto shapedType = llvm::cast<ShapedType>(resultType);
assert((isa<IREE::HAL::BufferViewType>(source.getType()) ||
shapedType.hasStaticShape()) &&
@@ -454,7 +455,7 @@
builder.getIndexAttr(i)));
}
build(builder, result, resultType, source, targetEncoding, dynamicDims,
- waitFence, name);
+ waitFence, name, affinity);
}
Value TensorImportOp::getTiedResult(unsigned resultIndex) {
@@ -530,10 +531,12 @@
void TensorExportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
- TypeAttr sourceEncoding, StringAttr name) {
+ TypeAttr sourceEncoding, StringAttr name,
+ Attribute affinity) {
auto dynamicDims =
IREE::Util::buildDynamicDimsForValue(result.location, source, builder);
- build(builder, result, resultType, source, sourceEncoding, dynamicDims, name);
+ build(builder, result, resultType, source, sourceEncoding, dynamicDims, name,
+ affinity);
}
Value TensorExportOp::getTiedResult(unsigned resultIndex) {
@@ -592,10 +595,11 @@
//===----------------------------------------------------------------------===//
void TensorBarrierOp::build(OpBuilder &builder, OperationState &result,
- ValueRange sources, Value signalFence) {
+ ValueRange sources, Value signalFence,
+ Attribute affinity) {
auto resultTypes = llvm::map_to_vector(
sources, [](Value source) { return source.getType(); });
- build(builder, result, resultTypes, sources, signalFence);
+ build(builder, result, resultTypes, sources, signalFence, affinity);
}
Value TensorBarrierOp::getTiedResult(unsigned resultIndex) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 599c1ff..d35a0cb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -125,13 +125,15 @@
TypeAttr:$target_encoding,
HAL_ShapeDynamicDims:$target_dims,
Optional<HAL_Fence>:$wait_fence,
- OptionalAttr<StrAttr>:$name
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTensor:$target
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
(`wait` `(` $wait_fence^ `)` `=` `` `>`)?
$source
($name^)?
@@ -145,14 +147,16 @@
"Type":$resultType,
"Value":$source,
"TypeAttr":$targetEncoding,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
OpBuilder<(ins
"Type":$resultType,
"Value":$source,
"TypeAttr":$targetEncoding,
"Value":$waitFence,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
];
@@ -190,13 +194,15 @@
AnyTensor:$source,
TypeAttr:$source_encoding,
HAL_ShapeDynamicDims:$source_dims,
- OptionalAttr<StrAttr>:$name
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$target
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
$source
($name^)?
`:`
@@ -211,7 +217,8 @@
"Type":$resultType,
"Value":$source,
"TypeAttr":$sourceEncoding,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
];
@@ -273,13 +280,15 @@
AnyTensor:$source,
HAL_ShapeDynamicDims:$source_dims,
AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$storage,
- Optional<HAL_Fence>:$wait_fence
+ Optional<HAL_Fence>:$wait_fence,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTensor:$result
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
(`wait` `(` $wait_fence^ `)` `=` `` `>`)?
$source `:` type($source) (`{` $source_dims^ `}`)?
`to`
@@ -311,13 +320,15 @@
let arguments = (ins
Variadic<AnyTensor>:$sources,
- HAL_Fence:$signal_fence
+ HAL_Fence:$signal_fence,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
Variadic<AnyTensor>:$results
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
`join` `` `(` $sources `:` type($sources) `)`
`=` `` `>`
$signal_fence `:` type($signal_fence)
@@ -327,7 +338,8 @@
let builders = [
OpBuilder<(ins
"ValueRange":$sources,
- "Value":$signalFence
+ "Value":$signalFence,
+ "Attribute":$affinity
)>,
];
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 060aeb8..457f09f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -28,7 +28,7 @@
ConversionPatternRewriter &rewriter) {
// TODO(benvanik): see if we can stash this on the side to avoid expensive
// materialization of a bunch of redundant IR.
- return rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ return rewriter.create<IREE::Stream::TensorSizeOfOp>(
loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()),
dynamicDims,
IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp()));
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
index 2acd3ba..35eb31f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -49,11 +49,16 @@
}
}
+ auto affinityAttr =
+ dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
+ if (!affinityAttr) {
+ affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ }
+
// Import (buffer view to stream resource).
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::External);
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(),
affinityAttr);
@@ -77,7 +82,7 @@
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
op, unknownType, resource, resultSize, resultSize, affinityAttr,
- affinityAttr);
+ /*target_affinity=*/IREE::Stream::AffinityAttr{});
return success();
}
@@ -133,7 +138,11 @@
return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
}
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr =
+ dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
+ if (!affinityAttr) {
+ affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ }
auto source =
consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
@@ -145,7 +154,8 @@
if (source.resource.getType() != externalType) {
exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), externalType, source.resource, source.resourceSize,
- source.resourceSize, affinityAttr, affinityAttr);
+ source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr{},
+ affinityAttr);
}
// Export (stream resource to buffer view).
@@ -179,11 +189,15 @@
// All operations (if any) will happen on the device specified by the alias
// as that indicates the affinity of the storage.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr =
+ dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
+ if (!affinityAttr) {
+ affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ }
// Query the target storage buffer length; we will only populate up to
// what is required for the output.
- auto storageSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value storageSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(),
affinityAttr);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index dbe9abc..24d939a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1183,7 +1183,7 @@
return failure();
// Definitely empty if here.
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
constantOp.getLoc(), rewriter.getIndexType(),
TypeAttr::get(constantOp.getResultEncoding()),
constantOp.getResultEncodingDims(), constantOp.getAffinityAttr());
@@ -1219,7 +1219,7 @@
}
auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
constantOp.getLoc(), rewriter.getIndexType(),
TypeAttr::get(constantOp.getResultEncoding()),
constantOp.getResultEncodingDims(), constantOp.getAffinityAttr());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index 11873a2..6e26e24 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -55,7 +55,7 @@
// This may differ from the external encoding of the tensor as imports are
// a transfer operation that may need to reformat the tensor.
auto encodingAttr = TypeAttr::get(sourceTensor.getType());
- auto resultSize = builder.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>(
loc, builder.getIndexType(), encodingAttr, dynamicDims,
/*affinity=*/nullptr);
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
index b4b708a..318160f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
@@ -228,23 +228,23 @@
// op order issues.
SmallVector<std::map<StringRef, SmallVector<Operation *>>> sequencedBuckets;
sequencedBuckets.push_back({}); // Start in a sequence.
- block.walk([&](Operation *op) {
+ for (auto &op : block) {
auto &buckets = sequencedBuckets.back();
if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
if (!immutableGlobals.contains(loadOp.getGlobalName())) {
- buckets[loadOp.getGlobalName()].push_back(op);
+ buckets[loadOp.getGlobalName()].push_back(&op);
}
} else if (auto storeOp =
dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
- buckets[storeOp.getGlobalName()].push_back(op);
- } else if (doesOpBlockMotion(op)) {
+ buckets[storeOp.getGlobalName()].push_back(&op);
+ } else if (doesOpBlockMotion(&op)) {
// Split point - all accesses after this point must not assume anything
// about accesses before it.
if (!buckets.empty()) {
sequencedBuckets.push_back({});
}
}
- });
+ }
bool didRemoveAny = false;
for (auto &buckets : sequencedBuckets) {
didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny;
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index 3200829..6e8ecc0 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -209,13 +209,15 @@
// the work.
rewriter.replaceOpWithNewOp<IREE::HAL::TensorImportOp>(
srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType),
- /*name=*/nullptr);
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
} else {
// Dynamic dims explicitly provided (or wrong, in which case the verifier
// will get it).
rewriter.replaceOpWithNewOp<IREE::HAL::TensorImportOp>(
srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType),
- adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr);
+ adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr);
}
return success();
}
@@ -237,14 +239,16 @@
// the work.
rewriter.replaceOpWithNewOp<IREE::HAL::TensorExportOp>(
srcOp, resultType, adaptor.getSource(),
- TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr);
+ TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
} else {
// Dynamic dims explicitly provided (or wrong, in which case the verifier
// will get it).
rewriter.replaceOpWithNewOp<IREE::HAL::TensorExportOp>(
srcOp, resultType, adaptor.getSource(),
TypeAttr::get(adaptor.getSource().getType()), adaptor.getSourceDims(),
- /*name=*/nullptr);
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
}
return success();
}