Change stream conversion to use a value/op affinity analysis.
This reworks parts of the prior stack to support transfer ops and an analysis
that determines the placement of ops for execution and resource control.
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index b8a630d..553dbde 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -124,9 +124,17 @@
FENCE,
};
+struct BarrierResult { // Metadata recorded per barrier input by addBarrierInput().
+ BlockArgument storage; // Argument whose storage the result is aliased into; null when there is none.
+ Type torchType; // Torch type recorded alongside the barrier input.
+ int returnIndex = -1; // Slot in the new return operands, or -1 when the value is not returned.
+};
+
struct ConvertedAsyncFunctionInfo {
IREE::Util::FuncOp funcOp;
SmallVector<IREE::Util::ReturnOp> returnOps;
+ SmallVector<DictionaryAttr> torchArgAttrs;
+ SmallVector<DictionaryAttr> torchResultAttrs;
SmallVector<Type> torchInputTypes;
SmallVector<Type> torchResultTypes;
SmallVector<TypeDisposition> inputDispositions;
@@ -136,7 +144,7 @@
// Values that must be captured in the coarse barrier.
SmallVector<Value> barrierInputs;
// Meta data per barrier input: storage, torchType, returnIndex (or -1)
- SmallVector<std::tuple<Value, Type, int>> barrierResultMeta;
+ SmallVector<BarrierResult> barrierResultMeta;
LogicalResult postProcess();
LogicalResult convertImmutableTensorArg(BlockArgument argValue,
@@ -144,10 +152,25 @@
LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType,
OpBuilder &builder);
- void addBarrierInput(Value inputTensor, Value storage, Type torchType,
+ void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType,
int returnIndex) {
barrierInputs.push_back(inputTensor);
- barrierResultMeta.emplace_back(storage, torchType, returnIndex);
+ barrierResultMeta.emplace_back(BarrierResult{
+ storage,
+ torchType,
+ returnIndex,
+ });
+ }
+
+ Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) { // Looks up |attrName| in the stashed torch attrs for |argValue|'s argument slot.
+ return torchArgAttrs.empty()
+ ? Attribute{} // Nothing was stashed: yield a null attribute.
+ : torchArgAttrs[argValue.getArgNumber()].get(attrName); // Indexed by argument number; not bounds-checked here.
+ }
+ Attribute getTorchResultAttr(int returnIndex, StringRef attrName) { // Looks up |attrName| in the stashed torch attrs for result |returnIndex|.
+ return torchResultAttrs.empty()
+ ? Attribute{} // Nothing was stashed: yield a null attribute.
+ : torchResultAttrs[returnIndex].get(attrName); // Not bounds-checked here; the visible caller guards returnIndex >= 0.
}
};
@@ -232,7 +255,8 @@
}
if (needsBarrier) {
Value source = convertToBuiltinTensor(postambleBuilder, returnValue);
- addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex);
+ addBarrierInput(source, /*storage=*/BlockArgument{}, torchType,
+ returnIndex);
}
break;
}
@@ -276,15 +300,13 @@
SmallVector<Value> aliasedResults;
for (auto [barrierInput, meta] :
llvm::zip_equal(barrierInputs, barrierResultMeta)) {
- Value exportStorage;
- Type torchType;
- int returnIndex;
- std::tie(exportStorage, torchType, returnIndex) = meta;
- if (exportStorage) {
+ if (meta.storage) {
// Use the wait fence indicating when the storage is available for
// mutation. We need to ensure that no writes are made to the storage
// until it indicates it's safe to do so.
- auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage);
+ auto storageAffinityAttr =
+ getTorchArgAttr(meta.storage, "iree.abi.affinity");
+ auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
auto barrierInputDims = IREE::Util::buildDynamicDimsForValue(
@@ -292,28 +314,30 @@
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
- barrierInputDims, exportStorage, waitFence,
- /*affinity=*/nullptr));
+ barrierInputDims, meta.storage, waitFence,
+ storageAffinityAttr));
} else {
aliasedResults.push_back(barrierInput);
}
}
auto barrierOp = postambleBuilder.create<IREE::HAL::TensorBarrierOp>(
- funcOp.getLoc(), aliasedResults, coarseSignalFence,
- /*affinity=*/nullptr);
+ funcOp.getLoc(), aliasedResults, coarseSignalFence);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
- Value exportStorage;
- Type torchType;
- int returnIndex;
- std::tie(exportStorage, torchType, returnIndex) = meta;
+ Attribute exportAffinityAttr;
+ if (meta.storage) {
+ exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity");
+ } else if (meta.returnIndex >= 0) {
+ exportAffinityAttr =
+ getTorchResultAttr(meta.returnIndex, "iree.abi.affinity");
+ }
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
- /*affinity=*/nullptr);
- if (returnIndex >= 0) {
- newReturnOperands[returnIndex] = exportedValue;
+ exportAffinityAttr);
+ if (meta.returnIndex >= 0) {
+ newReturnOperands[meta.returnIndex] = exportedValue;
}
}
}
@@ -377,14 +401,16 @@
<< torchType;
}
+ // Propagate explicit affinities to the read.
+ auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
+
auto waitSignalFences = getEnclosingWaitSignalFences(argValue);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
- /*name=*/nullptr,
- /*affinity=*/nullptr);
+ /*name=*/nullptr, affinityAttr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
@@ -408,6 +434,9 @@
.toBuiltinTensor();
}
+ // Propagate explicit affinities to the read and write.
+ auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
+
// There are only a small set of possible users of a mutable tensor.
// Handle them by operation here.
SmallVector<Operation *> users(argValue.getUsers());
@@ -419,8 +448,7 @@
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
- /*name=*/nullptr,
- /*affinity=*/nullptr);
+ /*name=*/nullptr, affinityAttr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
@@ -444,7 +472,6 @@
// Allowlist of function attributes to retain when importing funcs.
constexpr const char *kRetainedAttributes[] = {
"iree.reflection",
- "stream.affinity",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
@@ -476,6 +503,9 @@
syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr());
retainFunctionAttributes(asyncFuncOp, syncFuncOp);
syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
+ if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) {
+ syncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
+ }
Block *entryBlock = syncFuncOp.addEntryBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);
@@ -584,6 +614,10 @@
asyncFunctionName.append("$async");
}
+ // Stash arg/result attrs so they can be referenced during conversion.
+ torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs);
+ torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs);
+
// Convert function signature.
Type fenceType = rewriter.getType<IREE::HAL::FenceType>();
FunctionType torchFuncType = torchFunc.getFunctionType();
@@ -644,6 +678,9 @@
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
rewriter.getStringAttr("coarse-fences"));
+ if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) {
+ asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
+ }
rewriter.inlineRegionBefore(
torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());
diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
index 3e167ad..3bca01b 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
@@ -111,6 +111,37 @@
}
// -----
+// Tests the immutable + mutable argument case with explicit affinities.
+// CHECK-LABEL: @mutable_input_overwrite_no_return_affinities
+// CHECK: util.func public @main$async(
+// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
+// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view
+// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0
+// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]]
+// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1
+// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]]
+// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]])
+// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]])
+// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]]
+// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]]
+// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view
+// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence
+// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0
+// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1
+// CHECK: util.return %[[EXPORT_RESULT1]]
+builtin.module @mutable_input_overwrite_no_return_affinities {
+func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>},
+ %arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
+ -> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) {
+ %0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32>
+ %1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
+ %2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+ torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
+ return %2 : !torch.vtensor<[4,5],si32>
+}
+}
+
+// -----
// CHECK-LABEL: @retained_attribute_reflection
// CHECK: util.func public @main$async(
// CHECK-SAME: iree.reflection = {some.attr = 4 : index}
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index e91bc54..577cf52 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -43,16 +43,22 @@
return type;
}
+// Returns true if the given |attr| is a known ABI attribute that is only used
+// by this pass: iree.abi.{affinity,encoding,model,output}, compared by name.
+static bool isABIAttr(NamedAttribute attr) {
+ return attr.getName() == "iree.abi.affinity" ||
+ attr.getName() == "iree.abi.encoding" ||
+ attr.getName() == "iree.abi.model" ||
+ attr.getName() == "iree.abi.output";
+}
+
// Removes all ABI attrs handled by this pass from all dictionaries.
static void stripABIAttrs(SmallVectorImpl<DictionaryAttr> &allAttrs) {
for (auto &attrDict : allAttrs) {
SmallVector<NamedAttribute> attrs;
attrs.reserve(attrDict.size());
for (auto attr : attrDict) {
- // TODO(benvanik): faster lookup.
- if (attr.getName() != "iree.abi.output" &&
- attr.getName() != "iree.abi.encoding" &&
- attr.getName() != "iree.abi.affinity") {
+ if (!isABIAttr(attr)) {
attrs.push_back(attr);
}
}
@@ -60,7 +66,16 @@
}
}
+// Removes all ABI attrs from the |op| and its args/results.
static void stripABIAttrs(FunctionOpInterface op) {
+ NamedAttrList attrs;
+ for (auto attr : op->getAttrs()) {
+ if (!isABIAttr(attr)) {
+ attrs.push_back(attr);
+ }
+ }
+ op->setAttrs(attrs);
+
SmallVector<DictionaryAttr> argAttrs;
op.getAllArgAttrs(argAttrs);
stripABIAttrs(argAttrs);
@@ -71,6 +86,11 @@
op.setAllResultAttrs(resultAttrs);
}
+template <typename T> // T must be usable as a boolean condition (used below with attribute types).
+static T fallback(T optionalValue, T defaultValue) { // Null-coalescing helper: first truthy value wins.
+ return optionalValue ? optionalValue : defaultValue; // |defaultValue| has already been evaluated by the caller.
+}
+
// Creates the corresponding wrapper function for the given import function.
static IREE::Util::FuncOp
createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel,
@@ -101,12 +121,7 @@
argAttrDict.push_back(nullptr); // signal
break;
}
-
- // Update the import type and propagate back the attributes we may have
- // modified above.
importOp.setType(newImportType);
- importOp.setAllArgAttrs(argAttrDict);
- importOp.setAllResultAttrs(resultAttrDict);
auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
@@ -129,6 +144,12 @@
// users mark their functions 'nosideeffects' to avoid the host wait.
const bool hasSideEffects = !importOp->hasAttr("nosideeffects");
+ // Fetch and normalize any explicitly assigned affinity.
+ auto defaultAffinityAttr = importOp->getAttr("iree.abi.affinity");
+ if (defaultAffinityAttr) {
+ importOp->setAttr("stream.affinity", defaultAffinityAttr);
+ }
+
// When running async we insert a barrier on tensor arguments and attach that
// to the fence we pass to the import for waiting. We'll also allocate the
// signal fence that the import must signal when the returned tensors are
@@ -141,15 +162,24 @@
// No fences.
break;
case IREE::ABI::InvocationModel::CoarseFences: {
- // HACK: this is relying on the fact that there's only one HAL device.
- // We should instead have a way of creating fences on the device that
- // is used to produce the tensors we're wrapping.
- //
- // TODO(multi-device): emit get with derived ordinal or lookup with attr. We
- // could always say device 0 for now but could instead look for an
- // iree.abi.affinity/iree.abi.device/etc.
- Value device =
- IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
+ Value device;
+ // TODO(benvanik): support other affinity types.
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(
+ defaultAffinityAttr)) {
+ device = entryBuilder
+ .create<IREE::HAL::DeviceResolveOp>(
+ importOp.getLoc(),
+ entryBuilder.getType<IREE::HAL::DeviceType>(),
+ deviceAffinityAttr)
+ .getResult(0);
+ } else {
+ // HACK: if no device affinity was specified we resolve any device available
+ // at runtime. This is suboptimal but we expect most usage to have affinities
+ // assigned prior to ABI conversion.
+ device =
+ IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
+ }
// When exporting a fence we need to put a barrier between the rest of the
// program and the tensors consumed by the import.
@@ -162,7 +192,7 @@
importOp.getLoc(), entryBuilder.getType<IREE::HAL::FenceType>(),
device, IREE::HAL::FenceFlagBitfield::None);
auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>(
- importOp.getLoc(), tensorArgs, waitFence, /*affinity=*/nullptr);
+ importOp.getLoc(), tensorArgs, waitFence);
for (auto [argIndex, readyArg] :
llvm::zip_equal(tensorArgIndices, barrierOp.getResults())) {
entryArgs[argIndex] = readyArg;
@@ -203,9 +233,10 @@
importOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
arg.getLoc(), newType, arg,
- encodingAttr ? encodingAttr : TypeAttr::get(oldType),
+ fallback(encodingAttr, TypeAttr::get(oldType)),
/*name=*/nullptr,
- /*affinity=*/nullptr);
+ fallback(importOp.getArgAttr(argIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
arguments.push_back(tensorExportOp.getTarget());
} else {
arguments.push_back(arg);
@@ -245,9 +276,10 @@
resultIndex, "iree.abi.encoding");
auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
importOp.getLoc(), oldType, result,
- encodingAttr ? encodingAttr : TypeAttr::get(oldType), signalFence,
+ fallback(encodingAttr, TypeAttr::get(oldType)), signalFence,
/*name=*/nullptr,
- /*affinity=*/nullptr);
+ fallback(importOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
results.push_back(tensorImportOp);
} else {
results.push_back(result);
@@ -255,6 +287,9 @@
}
entryBuilder.create<IREE::Util::ReturnOp>(importOp.getLoc(), results);
+
+ stripABIAttrs(importOp);
+
return wrapperOp;
}
@@ -518,8 +553,11 @@
// Populate the reflection attrs based on the original types.
populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
exportOp->removeAttr("iree.reflection");
- if (auto affinityAttr = exportOp->getAttr("stream.affinity")) {
- wrapperOp->setAttr("stream.affinity", affinityAttr);
+
+ // Fetch and normalize any explicitly assigned affinity.
+ auto defaultAffinityAttr = exportOp->getAttr("iree.abi.affinity");
+ if (defaultAffinityAttr) {
+ exportOp->setAttr("stream.affinity", defaultAffinityAttr);
}
auto *entryBlock = wrapperOp.addEntryBlock();
@@ -572,12 +610,13 @@
if (llvm::isa<TensorType>(oldType)) {
auto encodingAttr =
exportOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
+ auto argName = inferArgumentName(entryBuilder.getContext(), argIndex,
+ exportOp.getArgAttrDict(argIndex));
auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), oldType, arg,
- encodingAttr ? encodingAttr : TypeAttr::get(oldType), waitFence,
- inferArgumentName(entryBuilder.getContext(), argIndex,
- exportOp.getArgAttrDict(argIndex)),
- exportOp.getArgAttr(argIndex, "iree.abi.affinity"));
+ fallback(encodingAttr, TypeAttr::get(oldType)), waitFence, argName,
+ fallback(exportOp.getArgAttr(argIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
arguments.push_back(tensorImportOp.getTarget());
} else {
arguments.push_back(arg);
@@ -601,7 +640,8 @@
auto aliasOp = entryBuilder.create<IREE::HAL::TensorAliasOp>(
exportOp.getLoc(), source.getType(), source, sourceDims,
resultStorages[resultIndex], waitFence,
- exportOp.getResultAttr(resultIndex, "iree.abi.affinity"));
+ fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
asyncResults[resultIndex] = cast<OpResult>(aliasOp.getResult());
}
@@ -622,7 +662,7 @@
signalFence);
} else {
auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>(
- exportOp.getLoc(), asyncTensors, signalFence, /*affinity=*/nullptr);
+ exportOp.getLoc(), asyncTensors, signalFence);
asyncResults = llvm::to_vector(barrierOp.getResults());
}
}
@@ -635,15 +675,17 @@
if (llvm::isa<TensorType>(oldType)) {
auto encodingAttr = exportOp.getResultAttrOfType<TypeAttr>(
resultIndex, "iree.abi.encoding");
+ auto resultName =
+ inferResultName(entryBuilder.getContext(), resultIndex,
+ exportOp.getResultAttrDict(resultIndex));
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
result.getLoc(), result, entryBuilder);
auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), newType, result,
- encodingAttr ? encodingAttr : TypeAttr::get(result.getType()),
- dynamicDims,
- inferResultName(entryBuilder.getContext(), resultIndex,
- exportOp.getResultAttrDict(resultIndex)),
- exportOp.getResultAttr(resultIndex, "iree.abi.affinity"));
+ fallback(encodingAttr, TypeAttr::get(result.getType())), dynamicDims,
+ resultName,
+ fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
results.push_back(tensorExportOp);
} else {
results.push_back(result);
@@ -731,12 +773,15 @@
exportOps.push_back(funcOp);
}
}
+ if (importOps.empty() && exportOps.empty()) {
+ return; // no-op
+ }
SymbolTable symbolTable(moduleOp);
// Create a wrapper function for each imported function.
- // This will preserve the internal types (tensors/etc) but change the import
- // to taking the ABI types and rewrite calls.
+ // This will preserve the internal types (tensors/etc) but change the
+ // import to taking the ABI types and rewrite calls.
for (auto importOp : importOps) {
if (failed(wrapImportFunc(getInvocationModel(importOp, invocationModel),
moduleOp, importOp, symbolTable))) {
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 3780ee2..72a0441 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -181,6 +181,28 @@
// -----
+// Tests that explicit import affinity specification is carried through to
+// the marshaling ops.
+
+// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view) -> !hal.buffer_view
+util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
+
+// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> {
+// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[ARG_TENSOR]] : tensor<2xi32> -> !hal.buffer_view
+// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32>
+// CHECK: util.return %[[RET_TENSOR]]
+// CHECK: }
+
+// CHECK: util.func private @pinnedCaller(%arg0: tensor
+util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ util.return %0 : tensor<2xi32>
+}
+
+// -----
+
// Tests that imports with encodings specified are propagated to the HAL ops.
// CHECK-LABEL: util.func private @importEncodings(%arg0: !hal.buffer_view) -> !hal.buffer_view
@@ -188,10 +210,10 @@
// CHECK: util.func private @_importEncodings(%[[ARG_TENSOR:.+]]: tensor<?x2xi32>) -> tensor<2x?xi32> {
// CHECK: %[[ARG_DIM:.+]] = tensor.dim %[[ARG_TENSOR]], %c0
-// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view
+// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xf32> as tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view
// CHECK: %[[RET_VIEW:.+]] = util.call @importEncodings(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view
// CHECK: %[[RET_DIM:.+]] = hal.buffer_view.dim<%[[RET_VIEW]] : !hal.buffer_view>[1]
-// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xi32>{%[[RET_DIM]]}
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xf32> as tensor<2x?xi32>{%[[RET_DIM]]}
// CHECK: util.return %[[RET_TENSOR]]
// CHECK: }
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
index 4505a54..f9af505 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
@@ -170,6 +170,40 @@
// -----
+// Tests that explicit import affinity specification is carried through to
+// the marshaling ops.
+
+util.global private @dev_a : !hal.device
+util.global private @dev_b : !hal.device
+util.global private @dev_c : !hal.device
+
+// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view
+util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_b>}) attributes {
+ iree.abi.affinity = #hal.device.affinity<@dev_c>,
+ iree.abi.model = "coarse-fences",
+ nosideeffects
+}
+
+// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> {
+// CHECK-DAG: %[[DEVICE_C:.+]] = hal.device.resolve on(<@dev_c>) : !hal.device
+// CHECK-DAG: %[[ARG_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence
+// CHECK-DAG: %[[ARG_READY:.+]] = hal.tensor.barrier join(%[[ARG_TENSOR]] : tensor<2xi32>) => %[[ARG_FENCE]] : !hal.fence
+// CHECK-DAG: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.affinity<@dev_a>) %[[ARG_READY]] : tensor<2xi32> -> !hal.buffer_view
+// CHECK-DAG: %[[RESULT_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence
+// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]], %[[ARG_FENCE]], %[[RESULT_FENCE]]) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.affinity<@dev_b>) wait(%[[RESULT_FENCE]]) => %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32>
+// CHECK: util.return %[[RET_TENSOR]]
+// CHECK: }
+
+// CHECK: util.func private @pinnedCaller(%arg0: tensor
+util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ util.return %0 : tensor<2xi32>
+}
+
+// -----
+
// Tests a side-effect-free import that doesn't take/return reference types.
// CHECK-LABEL: util.func private @importI32(%arg0: i32, %arg1: !hal.fence, %arg2: !hal.fence) -> i32
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index 98da9d9..fb3b309 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -64,6 +64,7 @@
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index 7236121..0b79e7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -86,6 +86,7 @@
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Stream::Analysis
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
index c107bc7..4edaffa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
@@ -547,6 +548,35 @@
return success();
}
+// Returns the executable targets used within |funcOp|.
+//
+// TODO(multi-device): delete this pass and rely on tensor-based analysis to
+// materialize encodings based on where tensors are used. This pass is not able
+// to handle that.
+static std::optional<SetVector<IREE::HAL::ExecutableTargetAttr>>
+getFuncExecutableTargetAttrs(FunctionOpInterface funcOp,
+ IREE::Stream::AffinityAnalysis &affinityAnalysis,
+ IREE::HAL::DeviceAnalysis &deviceAnalysis) {
+ // Get a set of all unique affinities used by resources within the function.
+ SetVector<IREE::Stream::AffinityAttr> uniqueAffinityAttrs;
+ SmallVector<IREE::Stream::AffinityAttr> lookupAffinityAttrs; // Scratch buffer reused for every op visited by the walk.
+ funcOp.walk([&](Operation *op) {
+ if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) { // Ops whose lookup fails contribute nothing.
+ uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(),
+ lookupAffinityAttrs.end());
+ }
+ lookupAffinityAttrs.clear(); // Reset the scratch buffer before the next op.
+ });
+
+ // Resolve affinities to executable targets.
+ SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ for (auto affinityAttr : uniqueAffinityAttrs) {
+ deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp,
+ executableTargetAttrs);
+ }
+ return executableTargetAttrs; // NOTE(review): std::nullopt is never returned, so callers' failure branch cannot fire as written; an empty set means no affinities were found.
+}
+
struct CPUMaterializeHostEncodingPass
: public CPUMaterializeHostEncodingBase<CPUMaterializeHostEncodingPass> {
CPUMaterializeHostEncodingPass() = default;
@@ -560,23 +590,36 @@
auto moduleOp = getOperation();
// Run required analysis passes.
- IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
- if (failed(deviceAnalysis.run()))
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
// Gather the required executable targets for the function. Note that it's
// possible there are more required for ops nested within the function but
// this pass is a hack and can't handle that :shrug:.
- SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
- deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets);
+ auto executableTargets = getFuncExecutableTargetAttrs(
+ funcOp, affinityAnalysis, deviceAnalysis);
+ if (!executableTargets) {
+ funcOp.emitOpError()
+ << "could not determine executable targets for the function";
+ return signalPassFailure();
+ } else if (executableTargets->empty()) {
+ // Probably no tensors.
+ continue;
+ }
// HACK: this pass is run on the host _but shouldn't be_. Because it's
// run on the host and IREE is a compiler capable of multi-targeting there
// may be multiple executable targets at any point in the host program.
// This pass can't handle that and assumes it's been checked earlier by
// spooky action at a distance. This needs to be fixed.
- if (executableTargets.size() != 1) {
+ if (executableTargets->size() != 1) {
funcOp.emitOpError() << "has multiple executable targets and CPU data "
"tiling isn't built to support that";
return signalPassFailure();
@@ -584,7 +627,7 @@
// Materialize encodings within the function.
if (failed(
- materializeFuncOpEncodings(funcOp, executableTargets.front()))) {
+ materializeFuncOpEncodings(funcOp, executableTargets->front()))) {
return signalPassFailure();
}
}
@@ -636,22 +679,35 @@
auto moduleOp = getOperation();
// Run required analysis passes.
- IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
- if (failed(deviceAnalysis.run()))
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
// Gather the required executable targets for the function. Note that it's
// possible there are more required for ops nested within the function but
// this pass is a hack and can't handle that :shrug:.
- SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
- deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets);
+ auto executableTargets = getFuncExecutableTargetAttrs(
+ funcOp, affinityAnalysis, deviceAnalysis);
+ if (!executableTargets) {
+ funcOp.emitOpError()
+ << "could not determine executable targets for the function";
+ return signalPassFailure();
+ } else if (executableTargets->empty()) {
+ // Probably no tensors.
+ continue;
+ }
// Get patterns specialized for the executable targets used by the
// function.
RewritePatternSet patterns(&getContext());
MaterializeEncodingFn materializeEncodingFn =
- getUpperBoundMaterializeEncodingFn(executableTargets.getArrayRef());
+ getUpperBoundMaterializeEncodingFn(executableTargets->getArrayRef());
if (!materializeEncodingFn)
return signalPassFailure();
populateMaterializeUpperBoundTileSizePatterns(patterns,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index f762a1f..fc26729 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -211,6 +211,15 @@
}
}
+// Returns true as soon as any export info has a non-empty dispatch op list;
+// returns false when none do (including when there are no exports).
+bool BindingLayoutAnalysis::hasDispatches() const {
+  for (auto &exportInfo : exportInfos) {
+    auto &dispatchOps = exportInfo.second->dispatchOps;
+    if (dispatchOps.empty()) {
+      continue; // nothing dispatches this export; keep scanning
+    }
+    return true;
+  }
+  return false;
+}
+
ArrayRef<IREE::Stream::CmdDispatchOp>
BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const {
auto it = exportInfos.find(exportOp);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
index 050e18e..7d08959 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
@@ -59,6 +59,9 @@
public:
explicit BindingLayoutAnalysis(Operation *rootOp, SymbolTable &symbolTable);
+ // Whether there are any dispatches in the program.
+ bool hasDispatches() const;
+
// Returns all of the dispatches to the given executable export.
ArrayRef<IREE::Stream::CmdDispatchOp>
getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
index a17a100..705292b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
@@ -15,6 +15,92 @@
namespace {
+// Lowers a hal.device.resolve op that has no affinity by resolving against
+// the device fetched with hal.devices.get at ordinal 0. Ops that carry an
+// affinity are rejected here so that another pattern can handle them.
+struct ConvertDeviceResolveAnyOp
+    : public OpConversionPattern<IREE::HAL::DeviceResolveOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getAffinity()) {
+      return rewriter.notifyMatchFailure(
+          resolveOp, "only resolving unspecified affinities to any device");
+    }
+
+    // Reject result types we cannot materialize before creating any IR.
+    // Without this check an unknown type would be skipped below and the
+    // replacement list would be shorter than the op's result list.
+    for (auto resultType : resolveOp.getResultTypes()) {
+      if (!isa<IREE::HAL::DeviceType, IREE::HAL::AllocatorType, IntegerType>(
+              resultType)) {
+        return rewriter.notifyMatchFailure(
+            resolveOp, "unrecognized resolve result types for a HAL target");
+      }
+    }
+
+    // The device lookup is created lazily so it is emitted at most once and
+    // only when a device or allocator result asks for it.
+    auto deviceType = rewriter.getType<IREE::HAL::DeviceType>();
+    Value device;
+    auto resolveDevice = [&]() {
+      if (!device) {
+        device = rewriter.create<IREE::HAL::DevicesGetOp>(
+            resolveOp.getLoc(), deviceType,
+            rewriter.create<arith::ConstantIndexOp>(resolveOp.getLoc(), 0));
+      }
+      return device;
+    };
+
+    // Produce one replacement value per result, chosen by the result type.
+    SmallVector<Value> results;
+    for (auto resultType : resolveOp.getResultTypes()) {
+      if (isa<IREE::HAL::DeviceType>(resultType)) {
+        results.push_back(resolveDevice());
+      } else if (isa<IREE::HAL::AllocatorType>(resultType)) {
+        results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
+            resolveOp.getLoc(), resolveDevice()));
+      } else if (isa<IntegerType>(resultType)) {
+        // No affinity was given, so integer results become a 64-bit -1.
+        results.push_back(rewriter.create<arith::ConstantIntOp>(
+            resolveOp.getLoc(), -1ll, 64));
+      }
+    }
+
+    rewriter.replaceOp(resolveOp, results);
+    return success();
+  }
+};
+
+// Lowers a hal.device.resolve op that carries a device affinity by loading
+// the referenced device global. Ops without an affinity, or whose device
+// reference is not a flat symbol, are rejected with a match failure.
+struct ConvertDeviceResolveAffinityOp
+    : public OpConversionPattern<IREE::HAL::DeviceResolveOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto affinityAttr = adaptor.getAffinityAttr();
+    if (!affinityAttr) {
+      return rewriter.notifyMatchFailure(
+          resolveOp, "only resolving fully specified affinities");
+    }
+    auto flatDeviceAttr = dyn_cast<FlatSymbolRefAttr>(affinityAttr.getDevice());
+    if (!flatDeviceAttr) {
+      return rewriter.notifyMatchFailure(
+          resolveOp, "nested device references not yet supported");
+    }
+
+    // Reject result types we cannot materialize before creating any IR.
+    // Without this check an unknown type would be skipped below and the
+    // replacement list would be shorter than the op's result list.
+    for (auto resultType : resolveOp.getResultTypes()) {
+      if (!isa<IREE::HAL::DeviceType, IREE::HAL::AllocatorType, IntegerType>(
+              resultType)) {
+        return rewriter.notifyMatchFailure(
+            resolveOp, "unrecognized resolve result types for a HAL target");
+      }
+    }
+
+    // The global load is created lazily so it is emitted at most once and
+    // only when a device or allocator result asks for it.
+    auto deviceType = rewriter.getType<IREE::HAL::DeviceType>();
+    Value device;
+    auto resolveDevice = [&]() {
+      if (!device) {
+        device = rewriter.create<IREE::Util::GlobalLoadOp>(
+            resolveOp.getLoc(), deviceType, flatDeviceAttr.getValue(),
+            /*is_immutable=*/true);
+      }
+      return device;
+    };
+
+    // Produce one replacement value per result, chosen by the result type.
+    SmallVector<Value> results;
+    for (auto resultType : resolveOp.getResultTypes()) {
+      if (isa<IREE::HAL::DeviceType>(resultType)) {
+        results.push_back(resolveDevice());
+      } else if (isa<IREE::HAL::AllocatorType>(resultType)) {
+        results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
+            resolveOp.getLoc(), resolveDevice()));
+      } else if (isa<IntegerType>(resultType)) {
+        // Integer results carry the affinity's queue mask as a 64-bit value.
+        results.push_back(rewriter.create<arith::ConstantIntOp>(
+            resolveOp.getLoc(), affinityAttr.getQueueMask(), 64));
+      }
+    }
+
+    rewriter.replaceOp(resolveOp, results);
+    return success();
+  }
+};
+
struct ConvertExecutableCalculateWorkgroupsOp
: public OpConversionPattern<IREE::HAL::ExecutableCalculateWorkgroupsOp> {
using OpConversionPattern::OpConversionPattern;
@@ -43,6 +129,10 @@
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
+ conversionTarget.addIllegalOp<IREE::HAL::DeviceResolveOp>();
+ patterns.insert<ConvertDeviceResolveAnyOp>(typeConverter, context);
+ patterns.insert<ConvertDeviceResolveAffinityOp>(typeConverter, context);
+
conversionTarget.addIllegalOp<IREE::HAL::ExecutableCalculateWorkgroupsOp>();
patterns.insert<ConvertExecutableCalculateWorkgroupsOp>(typeConverter,
context);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
index 6056652..b38bbbe 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
@@ -15,7 +15,10 @@
iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
- ["pseudo_ops.mlir"],
+ [
+ "device_ops.mlir",
+ "pseudo_ops.mlir",
+ ],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
index 2a57a80..6757109 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "device_ops.mlir"
"pseudo_ops.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir
new file mode 100644
index 0000000..44fb128
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir
@@ -0,0 +1,75 @@
+// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s
+
+// Resolving with no affinity is expected to fetch a device with
+// hal.devices.get using a constant 0 ordinal.
+// CHECK-LABEL: @deviceResolveAnyDevice
+util.func public @deviceResolveAnyDevice() -> !hal.device {
+ // CHECK-DAG: %[[ANY_ORDINAL:.+]] = arith.constant 0
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %[[ANY_ORDINAL]] : !hal.device
+ %device = hal.device.resolve : !hal.device
+ // CHECK: util.return %[[DEVICE]]
+ util.return %device : !hal.device
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// Resolving a device with an explicit affinity is expected to become an
+// immutable load of the referenced device global.
+// CHECK-LABEL: @deviceResolveDevice
+util.func public @deviceResolveDevice() -> !hal.device {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ %device = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device
+ // CHECK: util.return %[[DEVICE]]
+ util.return %device : !hal.device
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// When the affinity lists no queues the i64 result is expected to be the
+// constant -1, alongside the loaded device.
+// CHECK-LABEL: @deviceResolveDeviceQueueAffinityAny
+util.func public @deviceResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
+ %device, %queue_affinity_any = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device, i64
+ // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ util.return %device, %queue_affinity_any : !hal.device, i64
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// An affinity listing queues [4, 5] is expected to yield the constant 48 for
+// the i64 result (presumably a bitmask with bits 4 and 5 set: 16 + 32).
+// CHECK-LABEL: @deviceResolveDeviceQueueAffinity45
+util.func public @deviceResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
+ %device, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64
+ // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ util.return %device, %queue_affinity_45 : !hal.device, i64
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// Resolving an allocator is expected to load the device global and then query
+// its allocator with hal.device.allocator.
+// CHECK-LABEL: @deviceResolveAllocator
+util.func public @deviceResolveAllocator() -> !hal.allocator {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
+ %allocator = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.allocator
+ // CHECK: util.return %[[ALLOCATOR]]
+ util.return %allocator : !hal.allocator
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// Resolving an allocator together with an i64 for queues [4, 5] is expected
+// to produce both the device's allocator and the constant 48.
+// CHECK-LABEL: @deviceResolveAllocatorQueueAffinity45
+util.func public @deviceResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
+ %allocator, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64
+ // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]]
+ util.return %allocator, %queue_affinity_45 : !hal.allocator, i64
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
index f8efeba..a911de8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -69,32 +69,6 @@
return constantBuffer;
}
-IREE::VM::RodataOp
-createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp,
- OpBuilder &builder) {
- auto executableOp =
- binaryOp.getOperation()->getParentOfType<IREE::HAL::ExecutableOp>();
- auto insertPoint = builder.saveInsertionPoint();
- builder.setInsertionPoint(builder.getInsertionBlock()->getParentOp());
-
- std::string rodataName = sanitizeSymbolName(
- (executableOp.getName() + "_" + binaryOp.getName()).str());
- auto rodataOp = builder.create<IREE::VM::RodataOp>(
- binaryOp.getLoc(), rodataName, binaryOp.getData());
- rodataOp.setPrivate();
- if (binaryOp.getMimeType().has_value()) {
- rodataOp.setMimeTypeAttr(binaryOp.getMimeTypeAttr());
- }
-
- // TODO(benvanik): should these be page aligned? memcpy fastpath is fine for
- // now.
- rodataOp.setAlignmentAttr(builder.getI64IntegerAttr(16));
-
- builder.restoreInsertionPoint(insertPoint);
-
- return rodataOp;
-}
-
namespace {
class RemoveExecutableOpConversion
@@ -128,9 +102,15 @@
auto executableBinaryOp =
SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>(
createOp, createOp.getExecutableTarget());
- auto rodataOp = createExecutableBinaryRodata(executableBinaryOp, rewriter);
- auto executableRodata = rewriter.createOrFold<IREE::VM::ConstRefRodataOp>(
- createOp.getLoc(), rodataOp);
+ auto executableOp = executableBinaryOp.getOperation()
+ ->getParentOfType<IREE::HAL::ExecutableOp>();
+ std::string rodataName = sanitizeSymbolName(
+ (executableOp.getName() + "_" + executableBinaryOp.getName()).str());
+ auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>(
+ executableBinaryOp.getLoc(),
+ IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()),
+ rewriter.getStringAttr(rodataName), executableBinaryOp.getData(),
+ rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr());
// Get format string as a rodata blob.
auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>(
@@ -151,7 +131,7 @@
SmallVector<Value, 8> callOperands = {
adaptor.getDevice(),
executableFormatStr,
- executableRodata,
+ rodataOp,
constantBuffer,
};
callOperands.append(adaptor.getLayouts().begin(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
index 071643a..0596524 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
@@ -23,11 +23,6 @@
Value createPackedConstantBuffer(Location loc, ValueRange constantValues,
OpBuilder &builder);
-// Creates a vm.rodata containing the contents of a hal.executable.binary.
-IREE::VM::RodataOp
-createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp,
- OpBuilder &builder);
-
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOVM_PATTERNS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
index 9249b56..5dd5341 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
@@ -1,7 +1,5 @@
// RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s
-// CHECK: vm.rodata private @exe_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
-// CHECK: vm.rodata private @exe_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8>
hal.executable @exe {
hal.executable.binary @binary1 attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
@@ -24,7 +22,7 @@
) -> (!hal.executable, !hal.executable) {
// CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_
- // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe_binary1 : !vm.buffer
+ // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
// CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE1:.+]] = vm.call.variadic @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]], [%[[LAYOUT0]], %[[LAYOUT1]]]
@@ -32,7 +30,7 @@
%0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) layouts([%layout0, %layout1]) : !hal.executable
// CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_
- // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe_binary2 : !vm.buffer
+ // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
// CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE2:.+]] = vm.call.variadic @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]], [%[[LAYOUT1]], %[[LAYOUT0]]]
@@ -45,14 +43,12 @@
// -----
-// CHECK: vm.rodata private @exe1_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
hal.executable @exe1 {
hal.executable.binary @binary1 attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
format = "format"
}
}
-// CHECK: vm.rodata private @exe2_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8>
hal.executable @exe2 {
hal.executable.binary @binary2 attributes {
data = dense<[4, 5, 6, 7]> : vector<4xi8>,
@@ -67,17 +63,16 @@
%layout1: !hal.pipeline_layout
) -> (!hal.executable, !hal.executable) {
// CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe1_binary1 : !vm.buffer
+ // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe1_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
%0 = hal.executable.create device(%device : !hal.device) target(@exe1::@binary1) layouts([%layout0, %layout1]) : !hal.executable
// CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe2_binary2 : !vm.buffer
+ // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe2_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
%1 = hal.executable.create device(%device : !hal.device) target(@exe2::@binary2) layouts([%layout1, %layout0]) : !hal.executable
util.return %0, %1 : !hal.executable, !hal.executable
}
// -----
-// CHECK: vm.rodata private @exe_binary {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
hal.executable @exe {
hal.executable.binary @binary attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
@@ -95,7 +90,7 @@
%constant0: i32, %constant1: i32
) -> !hal.executable {
// CHECK-DAG: %[[FORMAT:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY:.+]] = vm.const.ref.rodata @exe_binary : !vm.buffer
+ // CHECK-DAG: %[[BINARY:.+]] = vm.rodata.inline "exe_binary" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
// CHECK: %[[CONSTANTS:.+]] = vm.buffer.alloc %c12, %c16 : !vm.buffer
// CHECK-DAG: %[[INDEX0:.+]] = vm.const.i64 0
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index de9bbb4..a58f32a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -34,58 +34,32 @@
// Get the affinity from the op or an ancestor. Note that there may be no
// affinity specified at all.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(resolveOp);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(resolveOp);
+
+ // If no affinity was specified then resolve as 'any'.
+ if (!affinityAttr) {
+ rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>(
+ resolveOp, resolveOp.getResultTypes(),
+ IREE::HAL::DeviceAffinityAttr{});
+ return success();
+ }
// We currently only handle HAL device affinities.
// We could make this an interface to select the device and allow users to
// provide their own affinities to convert to HAL. In the future users may
// also want to provide devices as function arguments post-initialization.
// For now we just have one way to specify device globals.
- auto deviceAffinityAttr =
- dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr);
- if (!deviceAffinityAttr) {
- resolveOp.emitOpError() << "failed to resolve affinity: only HAL device "
- "affinities are supported";
- return rewriter.notifyMatchFailure(
- resolveOp, "only HAL device affinities are supported");
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) {
+ rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>(
+ resolveOp, resolveOp.getResultTypes(), deviceAffinityAttr);
+ return success();
}
- // Get the device handle and queue.
- //
- // TODO(multi-device): specialized types; may need analysis we don't have
- // or at least a symbol lookup. An alternative would be an optional type
- // on the affinity in cases where we've evaluated it early but for now
- // we assume all device types are unspecialized.
- auto deviceType = rewriter.getType<IREE::HAL::DeviceType>();
- Value device = rewriter.create<IREE::Util::GlobalLoadOp>(
- resolveOp.getLoc(), deviceType,
- deviceAffinityAttr.getDevice().getValue(),
- /*is_immutable=*/true);
- int64_t queueMask = deviceAffinityAttr.getQueueMask();
-
- SmallVector<Value> results;
- if (isa<IREE::HAL::DeviceType>(resultTypes[0])) {
- results.push_back(device);
- } else if (isa<IREE::HAL::AllocatorType>(resultTypes[0])) {
- results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
- resolveOp.getLoc(), device));
- } else {
- return rewriter.notifyMatchFailure(
- resolveOp, "unrecognized context resolve types for a HAL target");
- }
- if (resultTypes.size() > 1) {
- if (isa<IntegerType>(resultTypes[1])) {
- results.push_back(rewriter.create<arith::ConstantIntOp>(
- resolveOp.getLoc(), queueMask, 64));
- } else {
- return rewriter.notifyMatchFailure(
- resolveOp,
- "unrecognized context resolve types for a HAL target (extended)");
- }
- }
-
- rewriter.replaceOp(resolveOp, results);
- return success();
+ resolveOp.emitOpError() << "failed to resolve affinity: only HAL device "
+ "affinities are supported";
+ return rewriter.notifyMatchFailure(
+ resolveOp, "only HAL device affinities are supported");
}
};
@@ -684,7 +658,7 @@
// make this difficult. For now we assume each stream region being lowered
// has a singular affinity that may itself reference multiple devices in the
// future but currently uniquely identifies a device.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp);
// Get the device handle we're executing against in this execution region.
// Note that this is a dynamic value: we have to treat the device as unknown
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
index 8b628da..5c685c0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
@@ -22,7 +22,7 @@
namespace mlir::iree_compiler {
Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -34,7 +34,7 @@
std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(Operation *op,
OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -46,7 +46,7 @@
}
Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -58,7 +58,7 @@
std::tuple<Value, Value>
lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
index 20a9c59..c60d0da 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -1,5 +1,7 @@
// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s
+// NOTE: the hal.device.resolve lowering in HAL-to-HAL does most of the work.
+
util.global private @device : !hal.device
// CHECK-LABEL: @contextResolveDefaultDevice
@@ -16,63 +18,12 @@
util.global private @device : !hal.device
-// CHECK-LABEL: @contextResolveDevice
-util.func public @contextResolveDevice() -> !hal.device {
- // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
- %device = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device
- // CHECK: util.return %[[DEVICE]]
- util.return %device : !hal.device
-}
-
-// -----
-
-util.global private @device : !hal.device
-
-// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny
-util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
- // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
- // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
- %device, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64
- // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
- util.return %device, %queue_affinity_any : !hal.device, i64
-}
-
-// -----
-
-util.global private @device : !hal.device
-
-// CHECK-LABEL: @contextResolveDeviceQueueAffinity45
-util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
- // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
- // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
- %device, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64
- // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
- util.return %device, %queue_affinity_45 : !hal.device, i64
-}
-
-// -----
-
-util.global private @device : !hal.device
-
-// CHECK-LABEL: @contextResolveAllocator
-util.func public @contextResolveAllocator() -> !hal.allocator {
- // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
- // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
- %allocator = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.allocator
- // CHECK: util.return %[[ALLOCATOR]]
- util.return %allocator : !hal.allocator
-}
-
-// -----
-
-util.global private @device : !hal.device
-
// CHECK-LABEL: @contextResolveAllocatorQueueAffinity45
-util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) {
+util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.device, !hal.allocator, i64) {
// CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
// CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
- %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64
- // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]]
- util.return %allocator, %queue_affinity_45 : !hal.allocator, i64
+ %device, %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, !hal.allocator, i64
+ // CHECK: util.return %[[DEVICE]], %[[ALLOCATOR]], %[[QUEUE_AFFINITY]]
+ util.return %device, %allocator, %queue_affinity_45 : !hal.device, !hal.allocator, i64
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
index 6d3c4a2..fe32e8b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -448,28 +448,34 @@
// `[targets, ...]` (optional)
do {
IREE::HAL::ExecutableTargetAttr executableTargetAttr;
- if (failed(p.parseAttribute(executableTargetAttr)))
+ if (failed(p.parseAttribute(executableTargetAttr))) {
return {};
+ }
executableTargetAttrs.push_back(executableTargetAttr);
} while (succeeded(p.parseOptionalComma()));
- if (failed(p.parseRSquare()))
+ if (failed(p.parseRSquare())) {
return {};
+ }
} else {
// `{config dict}` (optional)
- if (failed(p.parseAttribute(configAttr)))
+ if (failed(p.parseAttribute(configAttr))) {
return {};
+ }
// `, [targets, ...]` (optional)
if (succeeded(p.parseOptionalComma())) {
- if (failed(p.parseLSquare()))
+ if (failed(p.parseLSquare())) {
return {};
+ }
do {
IREE::HAL::ExecutableTargetAttr executableTargetAttr;
- if (failed(p.parseAttribute(executableTargetAttr)))
+ if (failed(p.parseAttribute(executableTargetAttr))) {
return {};
+ }
executableTargetAttrs.push_back(executableTargetAttr);
} while (succeeded(p.parseOptionalComma()));
- if (failed(p.parseRSquare()))
+ if (failed(p.parseRSquare())) {
return {};
+ }
}
}
}
@@ -502,7 +508,14 @@
}
std::string DeviceTargetAttr::getSymbolNameFragment() {
- return sanitizeSymbolName(getDeviceID().getValue().lower());
+ // Builds the fragment from the lowercased device ID. When the configuration
+ // has an integer `ordinal` attribute, `_<ordinal>_` is appended. The result
+ // is passed through sanitizeSymbolName before being returned.
+ std::string name = getDeviceID().getValue().lower();
+ if (auto ordinalAttr =
+ dyn_cast_if_present<IntegerAttr>(getConfigurationAttr("ordinal"))) {
+ name += "_";
+ name += std::to_string(ordinalAttr.getInt());
+ name += "_"; // can't have trailing numbers
+ }
+ return sanitizeSymbolName(name);
}
bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) {
@@ -510,6 +523,13 @@
return configAttr && configAttr.get(name);
}
+// Looks up `name` in the configuration dictionary. Yields a null Attribute
+// when there is no configuration dictionary; otherwise yields whatever the
+// dictionary lookup returns.
+Attribute DeviceTargetAttr::getConfigurationAttr(StringRef name) {
+  auto configAttr = getConfiguration();
+  return configAttr ? configAttr.get(name) : Attribute{};
+}
+
void DeviceTargetAttr::getExecutableTargets(
SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) {
for (auto attr : getExecutableTargets()) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 2d10dc3..9511cf8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -754,6 +754,8 @@
// Returns true if there's an attribute with the given name in the
// configuration dictionary.
bool hasConfigurationAttr(StringRef name);
+ // Returns the configuration attribute with the given name if found.
+ Attribute getConfigurationAttr(StringRef name);
// Returns zero or more executable targets that this device supports.
void getExecutableTargets(
@@ -916,7 +918,7 @@
}];
let parameters = (ins
- AttrParameter<"FlatSymbolRefAttr", "">:$device,
+ AttrParameter<"SymbolRefAttr", "">:$device,
AttrParameter<"int64_t", "">:$queue_mask
);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 23dab24..bf596ce 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -90,14 +90,16 @@
for (auto source : op.getSources()) {
auto it =
uniqueSources.insert(std::make_pair(source, orderedSources.size()));
- if (it.second)
+ if (it.second) {
orderedSources.push_back(source);
+ }
resultMapping.push_back(it.first->second);
}
- if (orderedSources.size() == op.getSources().size())
+ if (orderedSources.size() == op.getSources().size()) {
return failure();
- auto newOp = rewriter.create<TensorBarrierOp>(
- op.getLoc(), orderedSources, op.getSignalFence(), op.getAffinityAttr());
+ }
+ auto newOp = rewriter.create<TensorBarrierOp>(op.getLoc(), orderedSources,
+ op.getSignalFence());
SmallVector<Value> newResults;
newResults.reserve(newOp.getNumResults());
for (unsigned newIndex : resultMapping) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index e28025e..538b32d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -458,19 +458,6 @@
waitFence, name, affinity);
}
-Value TensorImportOp::getTiedResult(unsigned resultIndex) {
- return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
-}
-
-::std::optional<unsigned>
-TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) {
- return {0}; // source
-}
-
-SmallVector<int64_t> TensorImportOp::getTiedResultOperandIndices() {
- return {0}; // source
-}
-
static LogicalResult verifyTypeStorageCompatibility(Operation *op,
Type encodingType,
Type storageType) {
@@ -539,19 +526,6 @@
affinity);
}
-Value TensorExportOp::getTiedResult(unsigned resultIndex) {
- return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
-}
-
-::std::optional<unsigned>
-TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) {
- return {0}; // source
-}
-
-SmallVector<int64_t> TensorExportOp::getTiedResultOperandIndices() {
- return {0}; // source
-}
-
LogicalResult TensorExportOp::verify() {
TensorExportOp op = *this;
auto sourceType = llvm::cast<TensorType>(op.getSource().getType());
@@ -595,11 +569,10 @@
//===----------------------------------------------------------------------===//
void TensorBarrierOp::build(OpBuilder &builder, OperationState &result,
- ValueRange sources, Value signalFence,
- Attribute affinity) {
+ ValueRange sources, Value signalFence) { // the affinity parameter is dropped by this change
auto resultTypes = llvm::map_to_vector(
sources, [](Value source) { return source.getType(); });
- build(builder, result, resultTypes, sources, signalFence, affinity);
+ build(builder, result, resultTypes, sources, signalFence); // forwards to the overload taking explicit result types (one per source, same type)
}
Value TensorBarrierOp::getTiedResult(unsigned resultIndex) {
@@ -1064,6 +1037,23 @@
}
//===----------------------------------------------------------------------===//
+// hal.device.resolve
+//===----------------------------------------------------------------------===//
+
+void DeviceResolveOp::getAsmResultNames( // picks a printed name for each result from its type
+ function_ref<void(Value, StringRef)> setNameFn) { // callback invoked with (result, name)
+ for (auto result : getResults()) { // visit every result in order
+ if (isa<IREE::HAL::DeviceType>(result.getType())) {
+ setNameFn(result, "device");
+ } else if (isa<IREE::HAL::AllocatorType>(result.getType())) {
+ setNameFn(result, "allocator");
+ } else if (isa<IntegerType>(result.getType())) { // matches any integer type, not only i64
+ setNameFn(result, "queue_affinity");
+ } // results of any other type get no name from this hook
+ }
+}
+
+//===----------------------------------------------------------------------===//
// hal.device.allocator
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index d35a0cb..dff5438 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -98,11 +98,6 @@
def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Util_TiedOpInterface, [
- "getTiedResult",
- "getTiedResultOperandIndex",
- "getTiedResultOperandIndices",
- ]>,
Util_ShapeAwareOp,
]> {
let summary = [{imports a tensor from a HAL buffer view}];
@@ -171,11 +166,6 @@
}
def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [
- DeclareOpInterfaceMethods<Util_TiedOpInterface, [
- "getTiedResult",
- "getTiedResultOperandIndex",
- "getTiedResultOperandIndices",
- ]>,
Util_ShapeAwareOp,
]> {
let summary = [{exports a tensor to a HAL buffer view}];
@@ -320,15 +310,13 @@
let arguments = (ins
Variadic<AnyTensor>:$sources,
- HAL_Fence:$signal_fence,
- OptionalAttr<AnyAttr>:$affinity
+ HAL_Fence:$signal_fence
);
let results = (outs
Variadic<AnyTensor>:$results
);
let assemblyFormat = [{
- (`on` `(` $affinity^ `)`)?
`join` `` `(` $sources `:` type($sources) `)`
`=` `` `>`
$signal_fence `:` type($signal_fence)
@@ -338,8 +326,7 @@
let builders = [
OpBuilder<(ins
"ValueRange":$sources,
- "Value":$signalFence,
- "Attribute":$affinity
+ "Value":$signalFence
)>,
];
@@ -1616,6 +1603,41 @@
let opDocGroup = OpGroupDeviceOps in {
+def HAL_DeviceResolveOp : HAL_PureOp<"device.resolve", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, // names are assigned in DeviceResolveOp::getAsmResultNames
+]> {
+ let summary = [{resolves device handles based on affinity}];
+ let description = [{
+ Examples:
+ ```
+ // Returns a HAL device.
+ = hal.device.resolve on(#something) : !hal.device
+ // Returns a HAL device, allocator, and (optional) queue affinity.
+ = hal.device.resolve on(#something) : !hal.device, !hal.allocator, i64
+ // Returns a HAL allocator and (optional) queue affinity.
+ = hal.device.resolve on(#something) : !hal.allocator, i64
+ // Returns "any" device. Should only be used as a fallback.
+ = hal.device.resolve : !hal.device
+ ```
+ }];
+
+ let arguments = (ins
+ OptionalAttr<HAL_DeviceAffinityAttr>:$affinity // optional; the description's last example omits it to get "any" device
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[ // any number of results, each one of the types listed below
+ HAL_Device,
+ HAL_Allocator,
+ I64, // the description's examples use the i64 result for the queue affinity
+ ]>>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ attr-dict `:` type($results)
+ }];
+}
+
def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
]> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
index 06d2366..eedb427 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -49,6 +49,7 @@
: public IREE::HAL::impl::ConvertToHALPassBase<ConvertToHALPass> {
void runOnOperation() override {
auto *context = &getContext();
+ auto moduleOp = getOperation();
// Gather all interfaces from registered dialects.
// These will perform the tensor->buffer mapping for their ops.
@@ -64,8 +65,7 @@
HALTypeConverter typeConverter(conversionInterfaces);
HALConversionTarget conversionTarget(context, typeConverter);
- RewritePatternSet patterns(&getContext());
-
+ RewritePatternSet patterns(context);
populateHALToHALPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilToHALPatterns(context, conversionTarget, typeConverter,
@@ -84,13 +84,14 @@
// NOTE: we allow ops that we don't know about to allow custom dialects
// that don't need anything HAL-specific to pass through.
- if (failed(applyPartialConversion(getOperation(), conversionTarget,
+ if (failed(applyPartialConversion(moduleOp, conversionTarget,
std::move(patterns)))) {
return signalPassFailure();
}
// Cleanup conversion attributes used for spooky action at a distance.
- for (auto executableOp : getOperation().getOps<IREE::HAL::ExecutableOp>()) {
+ moduleOp->removeAttr("stream.affinity.default");
+ for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
for (auto exportOp :
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 6e5f110..7ad1f4a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -165,12 +165,30 @@
return map;
}
+static std::pair<Value, Value> // returns {device, queue affinity}
+getDeviceAndQueueAffinity(Location loc, IREE::Stream::AffinityAttr affinityAttr,
+ OpBuilder &builder) { // ops are created through |builder|
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) { // _if_present: a null attr takes the fallback below
+ auto resolveOp = builder.create<IREE::HAL::DeviceResolveOp>(
+ loc,
+ TypeRange{
+ builder.getType<IREE::HAL::DeviceType>(), // result 0: the device
+ builder.getI64Type(), // result 1: the queue affinity as i64
+ },
+ deviceAffinityAttr);
+ return std::make_pair(resolveOp.getResult(0), resolveOp.getResult(1));
+ }
+ auto device = IREE::HAL::DeviceType::resolveAny(loc, builder); // fallback: no HAL device affinity was provided
+ auto queueAffinity = builder.create<arith::ConstantIntOp>(loc, -1, 64); // i64 -1 (all bits set); presumably "any queue" - TODO confirm
+ return std::make_pair(device, queueAffinity);
+}
+
// Appends a global hal.buffer initialized to the size required for all
// of the bindings in |dispatchParams| (plus alignment).
-static IREE::Util::GlobalOp
-appendGlobalBuffer(Location loc, StringRef baseName,
- const DispatchParams &dispatchParams,
- OpBuilder &moduleBuilder) {
+static IREE::Util::GlobalOp appendGlobalBuffer(
+ Location loc, StringRef baseName, const DispatchParams &dispatchParams,
+ IREE::Stream::AffinityAttr affinityAttr, OpBuilder &moduleBuilder) {
// Create a global to hold the HAL buffer.
auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
loc, (baseName + "_buffer").str(),
@@ -191,12 +209,12 @@
auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock());
IndexSet indexSet(loc, initBuilder);
- // TODO(multi-device): support multiple devices in benchmark generation.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder);
+ // Resolve allocator for the benchmark device.
+ auto [device, queueAffinity] =
+ getDeviceAndQueueAffinity(loc, affinityAttr, initBuilder);
auto allocator =
initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).getResult();
- auto queueAffinity = initBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal;
auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer |
IREE::HAL::BufferUsageBitfield::DispatchStorage;
@@ -234,8 +252,8 @@
}
// Add a global variable holding an initialized buffer for the dispatch IO.
- auto bufferGlobalOp =
- appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder);
+ auto bufferGlobalOp = appendGlobalBuffer(loc, baseName, dispatchParams,
+ affinityAttr, moduleBuilder);
// Create an exported benchmark function that runs the dispatches.
auto funcType =
@@ -261,10 +279,9 @@
auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
- // TODO(multi-device): support multiple devices in benchmark generation.
- // For now we should just use the affinityAttr to resolve the device.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder);
- Value queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
+ // Resolve device for this particular benchmark.
+ auto [device, queueAffinity] =
+ getDeviceAndQueueAffinity(loc, affinityAttr, funcBuilder);
// Create and begin command buffer.
// TODO(benvanik): reuse the command buffer (initialize once and store).
@@ -423,8 +440,9 @@
// would be to generate one module per device dispatches are made on such
// that users can isolate to individual devices. For now we just deal with
// it.
- for (auto globalOp : deviceAnalysis.getDeviceGlobals())
+ for (auto globalOp : deviceAnalysis.getDeviceGlobals()) {
moduleBuilder.clone(*globalOp.getOperation());
+ }
// Clone the executable variant into the new module.
auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index e1437b2..221f754 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -600,6 +600,18 @@
return signalPassFailure();
}
+ // If no devices were defined and there are dispatches in the program then
+ // error out. This provides a better error message than if we were to allow
+ // this pass to no-op and then fail during conversion later on.
+ if (layoutAnalysis.hasDispatches() &&
+ deviceAnalysis.getDeviceGlobals().empty()) {
+ mlir::emitError(moduleOp.getLoc())
+ << "no HAL devices defined in the module; use the module-level "
+ "hal.device.targets attribute, the --iree-hal-target-device= "
+ "flag, or provide inputs with global !hal.devices defined";
+ return signalPassFailure();
+ }
+
// Gather the required executable targets per executable and dispatch site.
auto requiredExecutableTargets = buildRequiredExecutableTargetsMap(
moduleOp, deviceAnalysis, layoutAnalysis);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
index 67bf383..5d70f1b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
@@ -142,36 +142,32 @@
// have one set.
static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp,
FlatSymbolRefAttr defaultDeviceRef) {
- Builder builder(moduleOp);
- auto affinityName = builder.getStringAttr("stream.affinity");
- auto affinityAttr = builder.getAttr<IREE::HAL::DeviceAffinityAttr>(
- defaultDeviceRef, /*queue_mask=*/-1ll);
+ auto affinityAttr = IREE::HAL::DeviceAffinityAttr::get(
+ moduleOp.getContext(), defaultDeviceRef, /*queue_mask=*/-1ll);
- // TODO(benvanik): make this an interface that can be registered on types.
- auto isAnnotatableType = [](Type type) {
- return isa<TensorType>(type) || isa<IREE::Stream::ResourceType>(type);
- };
- for (auto &op : moduleOp.getOps()) {
- bool shouldAnnotate = true;
- if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
- if (!isAnnotatableType(globalOp.getGlobalType())) {
- shouldAnnotate = false;
- }
- } else if (op.hasTrait<OpTrait::SymbolTable>()) {
- // Symbol table ops can't reference parent symbols properly.
- shouldAnnotate = false;
- }
- if (!shouldAnnotate) {
- continue; // skip op
- }
+ // Module-level default that applies to any ops that don't otherwise have a
+ // placement. Ideally we never need this but some programs may take/return no
+ // tensors or have tensors come from unattributed containers (lists/dicts).
+ moduleOp->setAttr("stream.affinity.default", affinityAttr);
- if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- if (!affinityOp.getAffinityAttr()) {
- affinityOp.setAffinityAttr(affinityAttr);
+ // Set all arg/results to route through the default device unless they've
+ // already been assigned.
+ auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity");
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ if (funcOp.isPublic()) {
+ for (auto arg : funcOp.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ if (!funcOp.getArgAttr(arg.getArgNumber(), affinityName)) {
+ funcOp.setArgAttr(arg.getArgNumber(), affinityName, affinityAttr);
+ }
+ }
}
- } else {
- if (!op.hasAttr(affinityName)) {
- op.setAttr(affinityName, affinityAttr);
+ for (auto result : llvm::enumerate(funcOp.getResultTypes())) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.value())) {
+ if (!funcOp.getResultAttr(result.index(), affinityName)) {
+ funcOp.setResultAttr(result.index(), affinityName, affinityAttr);
+ }
+ }
}
}
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 773ed92..54a87b0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -198,11 +198,17 @@
passManager.addPass(IREE::HAL::createAssignTargetDevicesPass(
{assignmentOptions.targetDevices}));
}
+
+ // Create globals for each device (if needed).
passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass(
{assignmentOptions.defaultDevice}));
+
+ // Resolve #hal.device.promise and #hal.device.alias attributes.
passManager.addPass(IREE::HAL::createResolveDevicePromisesPass());
passManager.addPass(
IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry}));
+
+ // Verify devices are valid.
passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry}));
}
@@ -222,6 +228,9 @@
// and initial interface analysis (we rely on CSE and such having been run).
addCleanupPatterns(passManager);
+ // Verify devices are valid.
+ passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry}));
+
//----------------------------------------------------------------------------
// Device-specific interface materialization
//----------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
index e1ca624..6ac9cc8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
@@ -127,13 +127,38 @@
return signalPassFailure();
}
- // Must have at least one device specified.
- if (deviceAnalysis.getDeviceGlobals().empty()) {
+ // Devices are only required if we have dialects we may lower into device
+ // code. For now checking for tensor types is probably sufficient though we
+ // may want a pluggable way to decide this (e.g. dialect/type/op
+ // interfaces).
+ auto isTensor = [](Type type) { return isa<TensorType>(type); };
+ bool anyTensors = false;
+ for (auto &op : moduleOp.getOps()) {
+ if (op.hasTrait<OpTrait::IREE::Util::ObjectLike>()) {
+ continue; // ignore executables
+ }
+ op.walk([&](Operation *childOp) {
+ if (llvm::any_of(childOp->getOperandTypes(), isTensor) ||
+ llvm::any_of(childOp->getResultTypes(), isTensor)) {
+ anyTensors = true;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
+ // TODO(multi-device): the logic above is insufficient; we only need devices
+ // if the program will end up requiring them but we don't know that here.
+ // We have to wait until we've lowered to the point where we do require a
+ // device _and_ we actually want one (aren't compiling a non-HAL program).
+ // We could probably have an op interface, better output from the pass that
+ // requires the devices, etc. For now we error out in HAL conversion when we
+ // try to resolve devices.
+ if (false && anyTensors && deviceAnalysis.getDeviceGlobals().empty()) {
auto diagnostic = moduleOp.emitError();
diagnostic
<< "no HAL devices defined in the module; use the module-level "
"hal.device.targets attribute, the --iree-hal-target-device= "
- "flags, or provide inputs with global !hal.devices defined; ";
+ "flag, or provide inputs with global !hal.devices defined; ";
printAvailable(diagnostic, *targetRegistry.value);
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
index 61926d3..adce11f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
@@ -4,7 +4,7 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a[0],device_a[1]})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ORDINALS
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.target<"local">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ATTR
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.alias<"device_a">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ALIAS
-// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices="device_a,#hal.device.alias<"device_b">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices={"device_a,#hal.device.alias<"device_b">"}})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a=#hal.device.alias<"device_a">,"device_bc=device_b,#hal.device.alias<"device_c">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT-MULTI
// CHECK: module
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
index 11b3918..6360abe 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
@@ -33,6 +33,7 @@
// CHECK: module @module
// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_0>
module @module attributes {
hal.device.targets = [
#hal.device.select<[#device_a, #device_b]> : !hal.device
@@ -42,18 +43,6 @@
// CHECK-SAME: #[[DEVICE_A]],
// CHECK-SAME: #[[DEVICE_B]]
// CHECK-SAME: ]> : !hal.device
-
- // CHECK: util.global private @tensor_global
- // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0>
- util.global private @tensor_global : tensor<4xf32>
-
- // CHECK: util.global private @primitive_global
- // CHECK-NOT: stream.affinity
- util.global private @primitive_global : i32
-
- // CHECK: util.func private @func
- // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0>
- util.func private @func() -> ()
}
// -----
@@ -69,6 +58,7 @@
// CHECK: module @module
// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_a>
module @module attributes {
hal.device.targets = {
device_a = #device_a,
@@ -77,10 +67,6 @@
} {
// CHECK: util.global private @device_a = #[[DEVICE_A]]
// CHECK: util.global private @device_bc = #hal.device.select<[#[[DEVICE_B]], #[[DEVICE_C]]]>
-
- // CHECK: util.global private @tensor_global
- // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a>
- util.global private @tensor_global : tensor<4xf32>
}
// -----
@@ -94,6 +80,7 @@
// CHECK: module @module
// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_b>
module @module attributes {
hal.device.targets = {
device_a = #device_a,
@@ -103,10 +90,6 @@
} {
// CHECK: util.global private @device_a
// CHECK: util.global private @device_b
-
- // CHECK: util.global private @tensor_global
- // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b>
- util.global private @tensor_global : tensor<4xf32>
}
// -----
@@ -120,6 +103,7 @@
// CHECK: module @module
// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_1>
module @module attributes {
hal.device.targets = [
#device_a,
@@ -129,9 +113,4 @@
} {
// CHECK: util.global private @__device_0
// CHECK: util.global private @__device_1
-
- // CHECK: util.global private @tensor_global
- // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_1>
- util.global private @tensor_global : tensor<4xf32>
}
-
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
index 4511a0b..b4e2264 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
@@ -1,12 +1,27 @@
// RUN: iree-opt --split-input-file --iree-hal-verify-devices %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s
-// expected-error@+1 {{no HAL devices defined in the module}}
+// Tests that modules without tensors don't need devices.
+
module @module {
+ // CHECK: util.func private @func
util.func private @func() -> ()
}
// -----
+// TODO(multi-device): find a way to verify that devices exist when required.
+// The check is currently disabled as it's difficult to tell whether a device is
+// needed by the time we get to the HAL layer: plugins may absorb things, etc.
+// NO-expected-errorx@+1 {{no HAL devices defined in the module}}
+module @module {
+ util.func private @func() -> () {
+ arith.constant dense<1.0> : tensor<4xf32>
+ util.return
+ }
+}
+
+// -----
+
module @module {
// expected-error@+1 {{unregistered target device "__unregistered__"}}
util.global private @device = #hal.device.target<"__unregistered__"> : !hal.device
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
index fbc0e51..10d1456 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@
],
deps = [
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
index 05bbb79..cc472aa 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index bdc5aaf..31d6151 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -35,41 +35,42 @@
}
struct ConvertTensorConstantOp
- : public OpConversionPattern<IREE::Flow::TensorConstantOp> {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorConstantOp> { // base is declared outside this view
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr, // replaces the AffinityAttr::lookup(constantOp) the old body did; presumably resolved by the base - TODO confirm
+ ConversionPatternRewriter &rewriter) const override {
// Capture the tensor constant strongly typed with constant lifetime.
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
constantOp.getLoc(), constantType,
convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr);
+ TypeAttr::get(constantOp.getType()), ValueRange{}, // empty range: no dynamic dims are passed for this constant
+ executionAffinityAttr);
// Transfer to unknown lifetime.
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); // no lifetime argument given
auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
+ /*source_affinity=*/executionAffinityAttr, // same affinity on both sides:
+ /*result_affinity=*/executionAffinityAttr); // only the resource type differs across the transfer
return success();
}
};
struct ConvertTensorDynamicConstantOp
- : public OpConversionPattern<IREE::Flow::TensorDynamicConstantOp> {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorDynamicConstantOp> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorDynamicConstantOp constantOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorDynamicConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto attrType = dyn_cast<RankedTensorType>(constantOp.getValue().getType());
if (!attrType)
return failure();
@@ -91,22 +92,21 @@
}
// Capture the tensor constant strongly typed with constant lifetime.
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
constantOp.getLoc(), constantType,
convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(resultType), dynamicDims, affinityAttr);
+ TypeAttr::get(resultType), dynamicDims, executionAffinityAttr);
// Transfer to unknown lifetime.
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr);
return success();
}
};
@@ -114,157 +114,169 @@
// Reshapes and bitcasts become clones here to preserve shape/element type
// information (which may become actual transfers depending on source/target
// shape) - they'll be elided if not needed.
+//
+// NOTE: we transfer to the target before cloning. This may not be optimal
+// as the clone may otherwise have been able to be elided on the producer
+// side but we leave that for future copy elision to determine.
template <typename CastOpTy>
-struct ConvertTensorCastLikeOp : public OpConversionPattern<CastOpTy> {
- using OpConversionPattern<CastOpTy>::OpConversionPattern;
+struct ConvertTensorCastLikeOp
+ : public AffinityAwareConversionPattern<CastOpTy> { // base is declared outside this view
+ using AffinityAwareConversionPattern<
+ CastOpTy>::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(CastOpTy op, typename CastOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto resultAffinityAttr = this->lookupResultAffinity(op.getResult()); // this-> is needed because the base class depends on CastOpTy
+ auto source = this->transferTensorOperand(op.getLoc(), op.getSource(), // takes both the original operand
+ adaptor.getSource(), // and its converted value;
+ resultAffinityAttr, rewriter); // target is the result's affinity
auto resultSize =
buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
- affinityAttr, rewriter);
+ resultAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); // no lifetime argument given
rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, op.getResult().getType(),
- adaptor.getResultDims(), resultSize, affinityAttr);
+ adaptor.getResultDims(), resultSize, resultAffinityAttr); // the clone is tagged with the result's affinity
return success();
}
};
struct ConvertTensorAllocaOp
- : public OpConversionPattern<IREE::Flow::TensorAllocaOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorAllocaOp> { // base is declared outside this view
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr, // replaces the AffinityAttr::lookup(op) the old body did; presumably resolved by the base - TODO confirm
+ ConversionPatternRewriter &rewriter) const override {
auto resultSize =
buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
- affinityAttr, rewriter);
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); // no lifetime argument given
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncAllocaOp>(
- op, unknownType, resultSize, affinityAttr);
+ op, unknownType, resultSize, executionAffinityAttr); // the alloca is tagged with the execution affinity
return success();
}
};
struct ConvertTensorEmptyOp
- : public OpConversionPattern<IREE::Flow::TensorEmptyOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorEmptyOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto resultSize =
buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
- affinityAttr, rewriter);
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorEmptyOp>(
op, unknownType, op.getResult().getType(), adaptor.getResultDims(),
- resultSize, affinityAttr);
+ resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorSplatOp
- : public OpConversionPattern<IREE::Flow::TensorSplatOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorSplatOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorSplatOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto resultSize =
buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
- affinityAttr, rewriter);
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>(
op, unknownType, adaptor.getValue(), op.getResult().getType(),
- adaptor.getResultDims(), resultSize, affinityAttr);
+ adaptor.getResultDims(), resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorCloneOp
- : public OpConversionPattern<IREE::Flow::TensorCloneOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorCloneOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorCloneOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorCloneOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto operand = transferTensorOperand(op.getLoc(), op.getOperand(),
+ adaptor.getOperand(),
+ executionAffinityAttr, rewriter);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto operand =
- consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
op, unknownType, operand.resource, op.getOperand().getType(),
op.getArgumentDims(), operand.resourceSize, op.getResult().getType(),
- adaptor.getArgumentDims(), operand.resourceSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getArgumentDims(), operand.resourceSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorTransferOp
- : public OpConversionPattern<IREE::Flow::TensorTransferOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorTransferOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto targetAffinityAttr =
- dyn_cast<IREE::Stream::AffinityAttr>(adaptor.getTarget());
- if (!targetAffinityAttr)
+ : public AffinityOpConversionPattern<IREE::Flow::TensorTransferOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorTransferOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!executionAffinityAttr) {
return rewriter.notifyMatchFailure(op, "invalid stream affinity attr");
+ }
+ auto operand = resolveTensorOperand(op.getLoc(), op.getOperand(),
+ adaptor.getOperand(), rewriter);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto operand =
- consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
op, unknownType, operand.resource, operand.resourceSize,
operand.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr{}, targetAffinityAttr);
+ /*source_affinity=*/operand.affinity,
+ /*result_affinity=*/executionAffinityAttr);
return success();
}
};
struct ConvertTensorSliceOp
- : public OpConversionPattern<IREE::Flow::TensorSliceOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ : public AffinityOpConversionPattern<IREE::Flow::TensorSliceOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
auto resultSize =
buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
- affinityAttr, rewriter);
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSliceOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(),
adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(),
- resultSize, affinityAttr);
+ resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorUpdateOp
- : public OpConversionPattern<IREE::Flow::TensorUpdateOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto update =
- consumeTensorOperand(op.getLoc(), adaptor.getUpdate(), rewriter);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorUpdateOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto target =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
+ auto update =
+ transferTensorOperand(op.getLoc(), op.getUpdate(), adaptor.getUpdate(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorUpdateOp>(
op, target.resource.getType(), target.resource,
op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize,
adaptor.getStartIndices(), update.resource, op.getUpdate().getType(),
- op.getUpdateDims(), update.resourceSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ op.getUpdateDims(), update.resourceSize, executionAffinityAttr);
return success();
}
};
@@ -281,14 +293,13 @@
}
struct ConvertTensorLoadOp
- : public OpConversionPattern<IREE::Flow::TensorLoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorLoadOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resultType = getTypeConverter()->convertType(op.getResult().getType());
- auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ auto source = resolveTensorOperand(op.getLoc(), op.getSource(),
+ adaptor.getSource(), rewriter);
// If the source is not a staging resource then we need to transfer it to
// a staging resource. We slice out just what is being loaded so that we
@@ -299,6 +310,7 @@
// If already a staging resource then we can fast-path load the value.
auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::Staging);
+ auto resultType = getTypeConverter()->convertType(op.getResult().getType());
if (source.resource.getType() == stagingType) {
rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
op, resultType, source.resource, op.getSource().getType(),
@@ -306,16 +318,14 @@
return success();
}
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
-
// Scalar tensors get transferred without slicing.
auto sourceEncoding = op.getSource().getType();
if (isScalarTensor(sourceEncoding)) {
auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), stagingType, source.resource, source.resourceSize,
source.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op));
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
op, resultType, transferOp.getResult(), sourceEncoding,
adaptor.getSourceDims(), transferOp.getResultSize(),
@@ -341,16 +351,17 @@
RankedTensorType::get(resultDims, sourceEncoding.getElementType(),
sourceEncoding.getEncoding());
Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
- op.getLoc(), resultEncoding, ValueRange{}, affinityAttr);
+ op.getLoc(), resultEncoding, ValueRange{}, source.affinity);
auto sliceOp = rewriter.create<IREE::Stream::TensorSliceOp>(
op.getLoc(), source.resource.getType(), source.resource, sourceEncoding,
adaptor.getSourceDims(), source.resourceSize, sliceIndices,
- sliceLengths, resultEncoding, ValueRange{}, resultSize, affinityAttr);
+ sliceLengths, resultEncoding, ValueRange{}, resultSize,
+ source.affinity);
auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), stagingType, sliceOp.getResult(), sliceOp.getResultSize(),
sliceOp.getResultSize(),
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op));
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
op, resultType, transferOp.getResult(), sliceOp.getResultEncoding(),
sliceOp.getResultEncodingDims(), transferOp.getResultSize(),
@@ -360,13 +371,13 @@
};
struct ConvertTensorStoreOp
- : public OpConversionPattern<IREE::Flow::TensorStoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorStoreOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto target =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ auto target = resolveTensorOperand(op.getLoc(), op.getTarget(),
+ adaptor.getTarget(), rewriter);
// If the target is a staging resource then we can directly store into it
// with a fast-path. Otherwise we need to stage an upload.
@@ -380,34 +391,23 @@
return success();
}
- // Scalar tensors disconnect from the original target.
- auto targetEncoding = op.getTarget().getType();
- if (isScalarTensor(targetEncoding)) {
- rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>(
- op, target.resource.getType(), adaptor.getValue(), targetEncoding,
- adaptor.getTargetDims(), target.resourceSize,
- IREE::Stream::AffinityAttr::lookup(op));
- return success();
- }
-
// Use fill to store the value.
// TODO(benvanik): support larger buffer slices (stage + update).
IndexSet indexSet(op.getLoc(), rewriter);
indexSet.populate(adaptor.getIndices());
- SmallVector<Value> lengths;
- for (auto index : adaptor.getIndices())
- lengths.push_back(indexSet.get(1));
+ SmallVector<Value> lengths(adaptor.getIndices().size(), indexSet.get(1));
+ auto targetEncoding = op.getTarget().getType();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorFillOp>(
op, target.resource, targetEncoding, adaptor.getTargetDims(),
target.resourceSize, adaptor.getIndices(), lengths, adaptor.getValue(),
- IREE::Stream::AffinityAttr::lookup(op));
+ target.affinity);
return success();
}
};
struct ConvertTensorTraceOp
- : public OpConversionPattern<IREE::Flow::TensorTraceOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorTraceOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorTraceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -416,8 +416,8 @@
SmallVector<Attribute> resourceEncodings;
for (auto [tensorOperand, resourceOperand] :
llvm::zip_equal(op.getValues(), adaptor.getValues())) {
- auto source =
- consumeTensorOperand(op.getLoc(), resourceOperand, rewriter);
+ auto source = resolveTensorOperand(op.getLoc(), tensorOperand,
+ resourceOperand, rewriter);
auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::Staging);
auto traceSource = source.resource;
@@ -425,13 +425,14 @@
traceSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), stagingType, source.resource, source.resourceSize,
source.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/nullptr);
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
}
resources.push_back(traceSource);
resourceSizes.push_back(source.resourceSize);
resourceEncodings.push_back(TypeAttr::get(tensorOperand.getType()));
}
+
rewriter.replaceOpWithNewOp<IREE::Stream::TensorTraceOp>(
op, adaptor.getKey(), resources, resourceSizes,
rewriter.getArrayAttr(resourceEncodings), adaptor.getValueDims());
@@ -440,16 +441,18 @@
};
struct ConvertChannelDefaultOp
- : public OpConversionPattern<IREE::Flow::ChannelDefaultOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::Flow::ChannelDefaultOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::Stream::ChannelCreateOp>(
- op, /*id=*/Value{},
+ op,
+ /*id=*/Value{},
/*group=*/adaptor.getGroupAttr(),
/*rank=*/Value{},
- /*count=*/Value{}, IREE::Stream::AffinityAttr::lookup(op));
+ /*count=*/Value{}, executionAffinityAttr);
return success();
}
};
@@ -491,164 +494,190 @@
};
struct ConvertAllGatherOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllGatherOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getSource().getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllGather,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllGatherOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllGather,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertAllReduceOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllReduceOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllReduce,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllReduceOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllReduce,
static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertAllToAllOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllToAllOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getSource().getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllToAll,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllToAllOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllToAll,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
-struct ConvertReduceScatterOp
- : public OpConversionPattern<IREE::Flow::CollectiveReduceScatterOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::ReduceScatter,
+struct ConvertReduceScatterOp : public AffinityOpConversionPattern<
+ IREE::Flow::CollectiveReduceScatterOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::ReduceScatter,
static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertCollectiveSendRecvOp
- : public OpConversionPattern<IREE::Flow::CollectiveSendRecvOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::SendRecv,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveSendRecvOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::SendRecv,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Pack send, recv into param. The values are checked to be within the
// 16-bit range during lowering to Flow dialect.
@@ -665,27 +694,31 @@
auto param = rewriter.create<arith::OrIOp>(op.getLoc(), hi, lo);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/param, IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/param, executionAffinityAttr);
return success();
}
};
-struct ConvertDispatchOp : public OpConversionPattern<IREE::Flow::DispatchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
-
+struct ConvertDispatchOp
+ : public AffinityOpConversionPattern<IREE::Flow::DispatchOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::DispatchOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -700,7 +733,8 @@
llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
if (llvm::isa<ShapedType>(oldOperand.getType())) {
auto newOperandCast =
- consumeTensorOperand(op.getLoc(), newOperand, rewriter);
+ transferTensorOperand(op.getLoc(), oldOperand, newOperand,
+ executionAffinityAttr, rewriter);
newOperand = newOperandCast.resource;
dispatchOperandSizes.push_back(newOperandCast.resourceSize);
operandSizes.push_back(newOperandCast.resourceSize);
@@ -732,9 +766,9 @@
} else {
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
- resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, affinityAttr,
- rewriter));
+ resultSizes.push_back(
+ buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
+ executionAffinityAttr, rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -743,7 +777,7 @@
op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(),
dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets,
dispatchOperandEnds, dispatchOperandLengths, resultSizes,
- adaptor.getTiedOperandsAttr(), affinityAttr);
+ adaptor.getTiedOperandsAttr(), executionAffinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
@@ -759,8 +793,8 @@
// Tensors become resources without sizes. The default type converter
// adds the size so we bypass that here. We may want to allow the user
// to override the lifetime with attributes, too.
- return IREE::Stream::ResourceType::get(type.getContext(),
- IREE::Stream::Lifetime::Unknown);
+ return rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Unknown);
}
return getTypeConverter()->convertType(type);
};
@@ -784,13 +818,12 @@
}
};
-struct ConvertCallOp : public OpConversionPattern<IREE::Flow::CallOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
-
+struct ConvertCallOp : public AffinityOpConversionPattern<IREE::Flow::CallOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CallOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -805,7 +838,8 @@
llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
if (llvm::isa<ShapedType>(oldOperand.getType())) {
auto newOperandCast =
- consumeTensorOperand(op.getLoc(), newOperand, rewriter);
+ transferTensorOperand(op.getLoc(), oldOperand, newOperand,
+ executionAffinityAttr, rewriter);
newOperand = newOperandCast.resource;
callOperandSizes.push_back(newOperandCast.resourceSize);
operandSizes.push_back(newOperandCast.resourceSize);
@@ -837,9 +871,9 @@
} else {
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
- resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, affinityAttr,
- rewriter));
+ resultSizes.push_back(
+ buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
+ executionAffinityAttr, rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -848,7 +882,7 @@
op, resultTypes, adaptor.getCalleeAttr(), callOperands,
callOperandSizes, callOperandOffsets, callOperandEnds,
callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(),
- affinityAttr);
+ executionAffinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
@@ -1065,9 +1099,10 @@
} // namespace
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
patterns
.insert<ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
@@ -1075,17 +1110,19 @@
ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
ConvertTensorCloneOp, ConvertTensorTransferOp,
ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
- ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter,
- context);
- patterns.insert<ConvertChannelDefaultOp, ConvertChannelSplitOp,
- ConvertChannelRankOp, ConvertChannelCountOp>(typeConverter,
- context);
+ ConvertTensorStoreOp, ConvertTensorTraceOp>(
+ typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertChannelDefaultOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertChannelSplitOp, ConvertChannelRankOp,
+ ConvertChannelCountOp>(typeConverter, context);
patterns
.insert<ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
- ConvertAllToAllOp, ConvertCollectiveSendRecvOp>(typeConverter,
- context);
- patterns.insert<ConvertDispatchOp>(typeConverter, context);
- patterns.insert<ConvertFuncOp, ConvertCallOp>(typeConverter, context);
+ ConvertAllToAllOp, ConvertCollectiveSendRecvOp>(
+ typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertDispatchOp>(typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertFuncOp>(typeConverter, context);
+ patterns.insert<ConvertCallOp>(typeConverter, context, affinityAnalysis);
patterns.insert<ConvertExecutableOp>(typeConverter, context);
patterns.insert<
ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp,
@@ -1098,10 +1135,11 @@
patterns.insert<ConvertReturnOp>(typeConverter, context);
}
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
// Disallow all flow ops besides the ones we pass through (today).
// We don't have a stream-equivalent of several of the dispatch-level flow
// ops as the codegen backends directly touch them and so long as we have both
@@ -1111,7 +1149,8 @@
conversionTarget.addLegalOp<IREE::Stream::ExecutableOp>();
conversionTarget.markOpRecursivelyLegal<IREE::Stream::ExecutableOp>();
- populateFlowToStreamConversionPatterns(context, typeConverter, patterns);
+ populateFlowToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
index ad2c95a..0379be2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform flow->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
index 4106307..063389f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
@@ -52,13 +52,15 @@
// CHECK-LABEL: @dispatchAffinity
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) {
+ // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
- // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+ // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
%0 = flow.dispatch @ex::@entry0(%input) {
stream.affinity = #hal.device.affinity<@device_a>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
+ // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
- // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+ // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
%1 = flow.dispatch @ex::@entry1(%input) {
stream.affinity = #hal.device.affinity<@device_b>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index a755d44..ee68211 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -231,9 +231,9 @@
util.func public @tensorStoreScalar(%target : tensor<i32>) -> tensor<i32> {
// CHECK: %[[VALUE:.+]] = arith.constant 9
%value = arith.constant 9 : i32
- // CHECK: %[[SPLAT:.+]] = stream.tensor.splat %[[VALUE]] : i32 -> tensor<i32> in !stream.resource<*>{%[[TARGET_SIZE]]}
+ // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]] : i32 -> tensor<i32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]}
%0 = flow.tensor.store %value, %target : tensor<i32>
- // CHECK: util.return %[[SPLAT]]
+ // CHECK: util.return %[[FILL]]
util.return %0 : tensor<i32>
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
index 4323473..76eef8b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -21,11 +21,12 @@
// %1 = stream.tensor.import %0 : !hal.buffer_view ->
// tensor<4xf32> in !stream.resource<*>
struct ConvertTensorImportOp
- : public OpConversionPattern<IREE::HAL::TensorImportOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorImportOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSource().getType();
auto targetType = op.getTargetEncoding();
if (!llvm::isa<IREE::HAL::BufferType>(sourceType) &&
@@ -49,25 +50,23 @@
}
}
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
-
// Import (buffer view to stream resource).
auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::External);
Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(),
- affinityAttr);
+ executionAffinityAttr);
Value resource = rewriter.create<IREE::Stream::TensorImportOp>(
op.getLoc(), resultType, adaptor.getSource(), TypeAttr::get(targetType),
- adaptor.getTargetDims(), resultSize, affinityAttr);
+ adaptor.getTargetDims(), resultSize, executionAffinityAttr);
// Await the fence, if needed. When not specified the resource is assumed to
// be immediately available.
if (auto waitFence = op.getWaitFence()) {
Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>(
op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(),
- ValueRange{waitFence}, affinityAttr);
+ ValueRange{waitFence}, executionAffinityAttr);
resource = rewriter
.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), ValueRange{resource},
@@ -77,8 +76,9 @@
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
- op, unknownType, resource, resultSize, resultSize, affinityAttr,
- /*target_affinity=*/IREE::Stream::AffinityAttr{});
+ op, unknownType, resource, resultSize, resultSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*target_affinity=*/executionAffinityAttr);
return success();
}
@@ -122,11 +122,12 @@
// %1 = stream.tensor.export %0 : tensor<4xf32> in !stream.resource<*> ->
// !hal.buffer_view
struct ConvertTensorExportOp
- : public OpConversionPattern<IREE::HAL::TensorExportOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorExportOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSourceEncoding();
auto targetType = op.getTarget().getType();
if (!llvm::isa<IREE::HAL::BufferType>(targetType) &&
@@ -134,9 +135,9 @@
return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
}
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Exporting a produced value - transfer our source value to an externally
// usable resource and directly export it. This will cause an allocation.
@@ -146,14 +147,14 @@
if (source.resource.getType() != externalType) {
exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), externalType, source.resource, source.resourceSize,
- source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr{},
- affinityAttr);
+ source.resourceSize, /*source_affinity=*/source.affinity,
+ /*target_affinity=*/executionAffinityAttr);
}
// Export (stream resource to buffer view).
rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
op, targetType, exportSource, TypeAttr::get(sourceType),
- adaptor.getSourceDims(), source.resourceSize, affinityAttr);
+ adaptor.getSourceDims(), source.resourceSize, executionAffinityAttr);
return success();
}
};
@@ -170,29 +171,23 @@
// %update = stream.async.update %0, %storage[...]
// %2 = stream.async.slice %update[...]
struct ConvertTensorAliasOp
- : public OpConversionPattern<IREE::HAL::TensorAliasOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorAliasOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorAliasOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSource().getType();
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
-
- // All operations (if any) will happen on the device specified by the alias
- // as that indicates the affinity of the storage.
- auto affinityAttr =
- dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr());
- if (!affinityAttr) {
- affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
- }
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Query the target storage buffer length; we will only populate up to
// what is required for the output.
Value storageSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(),
- affinityAttr);
+ executionAffinityAttr);
// Import the target storage as a resource that we can use as an update
// target. We overwrite the contents and just cast the storage to the
@@ -202,7 +197,7 @@
auto importOp = rewriter.create<IREE::Stream::TensorImportOp>(
op.getLoc(), externalType, adaptor.getStorage(),
TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize,
- affinityAttr);
+ executionAffinityAttr);
// Await the fence, if needed. When not specified the storage is assumed to
// be immediately available.
@@ -210,7 +205,7 @@
if (auto waitFence = op.getWaitFence()) {
Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>(
op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(),
- ValueRange{waitFence}, affinityAttr);
+ ValueRange{waitFence}, executionAffinityAttr);
storage = rewriter
.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), ValueRange{storage},
@@ -223,7 +218,7 @@
auto updateOp = rewriter.create<IREE::Stream::AsyncUpdateOp>(
op.getLoc(), externalType, storage, storageSize, zeroOffset,
source.resourceSize, source.resource, source.resourceSize,
- affinityAttr);
+ executionAffinityAttr);
// Slice out the value from the updated tensor.
// This preserves the use-def chain but is almost always elided by aliasing
@@ -231,14 +226,14 @@
auto sliceOp = rewriter.create<IREE::Stream::AsyncSliceOp>(
op.getLoc(), externalType, updateOp.getResult(),
updateOp.getTargetSize(), zeroOffset, source.resourceSize,
- source.resourceSize, affinityAttr);
+ source.resourceSize, executionAffinityAttr);
// Transfer to match original lifetime (if needed).
Value result = sliceOp.getResult();
if (source.resource.getType() != result.getType()) {
result = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), source.resource.getType(), result, source.resourceSize,
- source.resourceSize, affinityAttr, affinityAttr);
+ source.resourceSize, executionAffinityAttr, executionAffinityAttr);
}
rewriter.replaceOp(op, result);
@@ -256,28 +251,38 @@
// %t01 = stream.timepoint.join max(%t0, %t1)
// stream.timepoint.export %t01 => %fence
struct ConvertTensorBarrierOp
- : public OpConversionPattern<IREE::HAL::TensorBarrierOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::HAL::TensorBarrierOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::HAL::TensorBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto timepointType = rewriter.getType<IREE::Stream::TimepointType>();
+ IREE::Stream::AffinityAttr anyAffinityAttr; // Remains null when the op has no sources; otherwise overwritten per source below.
SmallVector<Value> signaledResources;
SmallVector<Value> signaledTimepoints;
- for (auto sourceResource : adaptor.getSources()) {
- auto source = consumeTensorOperand(op.getLoc(), sourceResource, rewriter);
+ for (auto [sourceTensor, sourceResource] : // Walk original tensors and their converted values in lockstep.
+ llvm::zip_equal(op.getSources(), adaptor.getSources())) {
+ auto source = resolveTensorOperand(op.getLoc(), sourceTensor, // Resolve only: no transfer is requested here.
+ sourceResource, rewriter);
auto barrierOp = rewriter.create<IREE::Stream::TimepointBarrierOp>(
sourceResource.getLoc(), source.resource.getType(), timepointType,
- source.resource, source.resourceSize, affinityAttr);
+ source.resource, source.resourceSize, source.affinity); // Each barrier uses the affinity resolved for its own source.
signaledResources.push_back(barrierOp.getResult());
signaledTimepoints.push_back(barrierOp.getResultTimepoint());
+
+ // The external chain op below takes a single affinity even though the
+ // sources may resolve to different ones, so one has to be picked. For now
+ // the affinity of the last source tensor is used in the hope that the
+ // final signal is performed on the affinity that is running. This should
+ // probably be set after timepoint propagation instead so that it is
+ // guaranteed to happen on the final signal when not acting as a join.
+ anyAffinityAttr = source.affinity;
}
Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join(
op.getLoc(), signaledTimepoints, rewriter);
rewriter.create<IREE::Stream::TimepointChainExternalOp>(
op.getLoc(), joinedTimepoint, ValueRange{adaptor.getSignalFence()},
- affinityAttr);
+ anyAffinityAttr); // Affinity of the last source (see loop above); may be null.
rewriter.replaceOp(op, signaledResources);
return success();
}
@@ -285,21 +290,27 @@
} // namespace
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, // Forwarded to each of the four patterns registered below.
+ RewritePatternSet &patterns) { // NOTE(review): |affinityAnalysis| presumably must outlive |patterns| - confirm at call sites.
typeConverter.addConversion(
[](IREE::HAL::BufferViewType type) { return type; });
- patterns.insert<ConvertTensorImportOp>(typeConverter, context);
- patterns.insert<ConvertTensorExportOp>(typeConverter, context);
- patterns.insert<ConvertTensorAliasOp>(typeConverter, context);
- patterns.insert<ConvertTensorBarrierOp>(typeConverter, context);
+ patterns.insert<ConvertTensorImportOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorExportOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorAliasOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorBarrierOp>(typeConverter, context,
+ affinityAnalysis);
}
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
// Allow executables through without modification.
conversionTarget.addLegalOp<IREE::HAL::ExecutableOp>();
conversionTarget.markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();
@@ -315,7 +326,8 @@
typeConverter.isLegal(op.getTarget().getType());
});
- populateHALToStreamConversionPatterns(context, typeConverter, patterns);
+ populateHALToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
index ed2a3c0..f3e955d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform hal->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
index 6bb26f1..fee06f2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
@@ -23,13 +24,59 @@
return attr;
}
+IREE::Stream::AffinityAttr
+tryLookupGlobalAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) { // |affinityAnalysis| is dereferenced unchecked; callers must pass non-null.
+ return affinityAnalysis->lookupGlobalAffinity(op); // Thin forwarder; NOTE(review): the "try" prefix suggests a null result when unknown - confirm.
+}
+
+IREE::Stream::AffinityAttr
+tryLookupExecutionAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) { // |affinityAnalysis| is dereferenced unchecked; callers must pass non-null.
+ assert(llvm::isa<IREE::Stream::AffinityOpInterface>(op) &&
+ "must be an affinity op"); // Precondition on |op|; only enforced in assert-enabled builds.
+ return affinityAnalysis->lookupExecutionAffinity(op); // Thin forwarder to the analysis' execution-affinity query.
+}
+
+IREE::Stream::AffinityAttr
+tryLookupResultAffinity(Value value,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) { // |affinityAnalysis| is dereferenced unchecked; callers must pass non-null.
+ return affinityAnalysis->lookupResourceAffinity(value); // Despite the "Result" name this forwards to the resource-affinity query for |value|.
+}
+
+static std::pair<Value, Value>
+resolveTensorOperand(Location loc, Value convertedOperand, OpBuilder &builder) { // Returns {resource, size} for |convertedOperand|; no affinity is consulted here.
+ auto operandType = convertedOperand.getType();
+ if (llvm::isa<IREE::Stream::ResourceType>(operandType)) {
+ // Prior to https://reviews.llvm.org/D111620 this is the path we'd take;
+ // the tensor operands would be remapped into their new resource types.
+ // This is still possible during rewriting if we ourselves produce a new
+ // resource type, but the automatic materialization will go down the
+ // unrealized_conversion_cast path below.
+ return std::make_pair(convertedOperand,
+ builder.createOrFold<IREE::Stream::ResourceSizeOp>(
+ loc, builder.getIndexType(), convertedOperand)); // Size is queried (or folded) from the resource itself.
+ } else if (auto castOp =
+ convertedOperand
+ .getDefiningOp<mlir::UnrealizedConversionCastOp>()) {
+ // We only have a single tensor type conversion and it expands to (resource,
+ // size) so that's all we look for here.
+ assert(castOp.getNumOperands() == 2 && "expected (resource, size)");
+ return std::make_pair(castOp.getOperand(0), castOp.getOperand(1)); // Peel the cast: operand 0 is the resource, operand 1 its size.
+ }
+ assert(false && // Neither supported form matched: hard failure in assert-enabled builds.
+ "unexpected operand; expected either a IREE::Stream::ResourceType or "
+ "the result of a mlir::UnrealizedConversionCastOp");
+ return std::make_pair(Value{}, Value{}); // With asserts compiled out, callers receive a pair of null Values.
+}
+
void expandResourceOperand(Location loc, Value operand,
SmallVectorImpl<Value> &newOperands,
OpBuilder &builder) {
if (llvm::isa<TensorType>(operand.getType())) {
- auto value = consumeTensorOperand(loc, operand, builder);
- newOperands.push_back(value.resource);
- newOperands.push_back(value.resourceSize);
+ auto [resource, resourceSize] = resolveTensorOperand(loc, operand, builder);
+ newOperands.push_back(resource);
+ newOperands.push_back(resourceSize);
} else if (llvm::isa<IREE::Stream::ResourceType>(operand.getType())) {
newOperands.push_back(operand);
newOperands.push_back(
@@ -49,34 +96,28 @@
return expandedOperands;
}
-ConvertedTensor consumeTensorOperand(Location loc, Value operand,
- OpBuilder &builder) {
- auto operandType = operand.getType();
- if (llvm::isa<IREE::Stream::ResourceType>(operandType)) {
- // Prior to https://reviews.llvm.org/D111620 this is the path we'd take;
- // the tensor operands would be remapped into their new resource types.
- // This is still possible during rewriting if we ourselves produce a new
- // resource type, but the automatic materialization will go down the
- // unrealized_conversion_cast path below.
- return {
- operand,
- builder.createOrFold<IREE::Stream::ResourceSizeOp>(
- loc, builder.getIndexType(), operand),
- };
- } else if (auto castOp =
- operand.getDefiningOp<mlir::UnrealizedConversionCastOp>()) {
- // We only have a single tensor type conversion and it expands to (resource,
- // size) so that's all we look for here.
- assert(castOp.getNumOperands() == 2 && "expected (resource, size)");
- return {
- castOp.getOperand(0),
- castOp.getOperand(1),
- };
+ConvertedTensor resolveTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { // Pairs the converted (resource, size) with the analyzed affinity; never inserts a transfer.
+ auto [resource, resourceSize] =
+ resolveTensorOperand(loc, convertedOperand, builder); // File-local 3-argument overload: unwraps the resource and its size.
+ auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); // Looked up on the pre-conversion value; |affinityAnalysis| must be non-null.
+ return {affinityAttr, resource, resourceSize};
+}
+
+ConvertedTensor transferTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { // Resolves the operand and routes it through an AsyncTransferOp if its analyzed affinity differs from the required one.
+ auto [resource, resourceSize] =
+ resolveTensorOperand(loc, convertedOperand, builder); // File-local 3-argument overload: unwraps the resource and its size.
+ auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); // Looked up on the pre-conversion value; |affinityAnalysis| must be non-null.
+ if (affinityAttr != requiredAffinityAttr) { // Plain attribute inequality; presumably a null (unknown) affinity also takes this path when one is required - TODO confirm.
+ resource = builder.create<IREE::Stream::AsyncTransferOp>(
+ loc, resource.getType(), resource, resourceSize, resourceSize, // Same type and size in and out; only the affinity changes.
+ affinityAttr, requiredAffinityAttr);
}
- assert(false &&
- "unexpected operand; expected either a IREE::Stream::ResourceType or "
- "the result of a mlir::UnrealizedConversionCastOp");
- return ConvertedTensor();
+ return {requiredAffinityAttr, resource, resourceSize}; // Reports the required affinity whether or not a transfer was inserted.
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
index fd9249e..43cfbb0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
@@ -7,37 +7,123 @@
#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Converts a supported attribute type to the corresponding stream dialect
// value. Returns the provided value if it is natively supported.
TypedAttr convertAttributeToStream(TypedAttr attr);
-void expandResourceOperand(Location loc, Value operand,
- SmallVectorImpl<Value> &newOperands,
- OpBuilder &builder);
+IREE::Stream::AffinityAttr
+tryLookupGlobalAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis); // Forwards to AffinityAnalysis::lookupGlobalAffinity; |affinityAnalysis| must be non-null.
+IREE::Stream::AffinityAttr
+tryLookupExecutionAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis); // Forwards to lookupExecutionAffinity; asserts |op| implements AffinityOpInterface.
+IREE::Stream::AffinityAttr
+tryLookupResultAffinity(Value value,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis); // Forwards to lookupResourceAffinity for |value|; |affinityAnalysis| must be non-null.
-SmallVector<Value> expandResourceOperands(Location loc, ValueRange operands,
- ConversionPatternRewriter &rewriter);
-
-// https://reviews.llvm.org/D111620 broke 1->N type expansion during dialect
-// conversion. It inserts unrealized_conversion_casts but then passes the
-// illegal source dialect types for pattern operands, meaning that even though
-// we say tensors are illegal the patterns get the new remapped values as
-// tensors. This, naturally, breaks everything. To work around this we have this
-// helper that tries to peek through the unrealized_conversion_casts and get out
-// the actual values we expected to see from the conversion (and did before that
-// change).
struct ConvertedTensor {
+ // Optional affinity of the resource at the time it is consumed.
+ // May be nullptr if the affinity could not be determined.
+ IREE::Stream::AffinityAttr affinity;
+ // Resource storing the tensor.
Value resource;
+ // Size of the resource in bytes.
Value resourceSize;
};
-ConvertedTensor consumeTensorOperand(Location loc, Value operand,
- OpBuilder &builder);
+
+void expandResourceOperand(Location loc, Value convertedOperand,
+ SmallVectorImpl<Value> &newOperands,
+ OpBuilder &builder);
+SmallVector<Value> expandResourceOperands(Location loc,
+ ValueRange convertedOperands,
+ ConversionPatternRewriter &rewriter);
+
+ConvertedTensor resolveTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder);
+ConvertedTensor transferTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder);
+
+template <typename OpT>
+struct AffinityAwareConversionPattern : public OpConversionPattern<OpT> {
+public:
+ AffinityAwareConversionPattern(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<OpT>(typeConverter, context, benefit),
+ affinityAnalysis(affinityAnalysis) {}
+
+ IREE::Stream::AffinityAnalysis *getAffinityAnalysis() const {
+ return affinityAnalysis;
+ }
+
+protected:
+ ConvertedTensor resolveTensorOperand(Location loc, Value originalOperand,
+ Value convertedOperand,
+ OpBuilder &builder) const {
+ return mlir::iree_compiler::resolveTensorOperand(
+ loc, originalOperand, convertedOperand, affinityAnalysis, builder);
+ }
+
+ ConvertedTensor
+ transferTensorOperand(Location loc, Value originalOperand,
+ Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ OpBuilder &builder) const {
+ return mlir::iree_compiler::transferTensorOperand(
+ loc, originalOperand, convertedOperand, requiredAffinityAttr,
+ affinityAnalysis, builder);
+ }
+
+ IREE::Stream::AffinityAttr lookupResultAffinity(Value originalResult) const {
+ return mlir::iree_compiler::tryLookupResultAffinity(originalResult,
+ affinityAnalysis);
+ }
+
+ IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr;
+};
+
+template <typename OpT>
+struct AffinityOpConversionPattern
+ : public AffinityAwareConversionPattern<OpT> {
+public:
+ AffinityOpConversionPattern(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ PatternBenefit benefit = 1)
+ : AffinityAwareConversionPattern<OpT>(typeConverter, context,
+ affinityAnalysis, benefit) {}
+
+protected:
+ virtual LogicalResult matchAndRewriteOnAffinity(
+ OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const = 0;
+
+private:
+ LogicalResult
+ matchAndRewrite(OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override final {
+ auto executionAffinityAttr =
+ tryLookupExecutionAffinity(op, this->getAffinityAnalysis());
+ return matchAndRewriteOnAffinity(op, adaptor, executionAffinityAttr,
+ rewriter);
+ }
+};
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
index 646d55e..38ad9ee 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
@@ -15,8 +15,6 @@
iree_compiler_cc_library(
name = "StandardToStream",
srcs = [
- "ConvertConstantOps.cpp",
- "ConvertStructuralOps.cpp",
"Patterns.cpp",
],
hdrs = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
index b910c60..3def716 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
@@ -16,8 +16,6 @@
HDRS
"Patterns.h"
SRCS
- "ConvertConstantOps.cpp"
- "ConvertStructuralOps.cpp"
"Patterns.cpp"
DEPS
LLVMSupport
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
deleted file mode 100644
index 5ff99f7..0000000
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
-#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-struct ConvertTensorConstantOp : public OpConversionPattern<arith::ConstantOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle tensor types - other arith.constant types (like i32) are
- // ignored.
- if (!llvm::isa<TensorType>(constantOp.getType()))
- return failure();
-
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
- auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
- constantOp.getLoc(), constantType,
- convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(constantOp.getType()),
- /*result_encoding_dims=*/ValueRange{}, affinityAttr);
-
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
- constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
- rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
- constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
- return success();
- }
-};
-
-} // namespace
-
-void populateStandardConstantToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
- conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>(
- [](arith::ConstantOp op) {
- return !llvm::isa<TensorType>(op.getType());
- });
-
- patterns.insert<ConvertTensorConstantOp>(typeConverter, context);
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
deleted file mode 100644
index 5b29504..0000000
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
+++ /dev/null
@@ -1,406 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
-#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-struct BranchOpConversion : public OpConversionPattern<mlir::cf::BranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands = expandResourceOperands(
- op.getLoc(), adaptor.getDestOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
- expandedOperands);
- return success();
- }
-};
-
-struct CondBranchOpConversion
- : public OpConversionPattern<mlir::cf::CondBranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto trueDestOperands = expandResourceOperands(
- op.getLoc(), adaptor.getTrueDestOperands(), rewriter);
- auto falseDestOperands = expandResourceOperands(
- op.getLoc(), adaptor.getFalseDestOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
- op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands,
- op.getFalseDest(), falseDestOperands);
- return success();
- }
-};
-
-static ValueRange asValueRange(ArrayRef<Value> values) { return values; }
-
-struct SwitchOpConversion : public OpConversionPattern<mlir::cf::SwitchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto defaultOperands = expandResourceOperands(
- op.getLoc(), adaptor.getDefaultOperands(), rewriter);
- auto caseOperands = llvm::to_vector(
- llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) {
- return expandResourceOperands(op.getLoc(), operands, rewriter);
- }));
- rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>(
- op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands,
- op.getCaseValuesAttr(), op.getCaseDestinations(),
- llvm::to_vector(llvm::map_range(caseOperands, asValueRange)));
- return success();
- }
-};
-
-struct SelectOpConversion : public OpConversionPattern<mlir::arith::SelectOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle selects where the operands are tensors (resources).
- if (!llvm::isa<TensorType>(op.getTrueValue().getType()))
- return failure();
- auto trueOperand =
- consumeTensorOperand(op.getLoc(), adaptor.getTrueValue(), rewriter);
- auto falseOperand =
- consumeTensorOperand(op.getLoc(), adaptor.getFalseValue(), rewriter);
- auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>(
- op.getLoc(), adaptor.getCondition(), trueOperand.resource,
- falseOperand.resource);
- auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>(
- op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize,
- falseOperand.resourceSize);
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, adaptor.getTrueValue().getType(),
- ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()});
- return success();
- }
-};
-
-struct ScfIfOpConversion : public OpConversionPattern<mlir::scf::IfOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- // Create a new call that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original call as
- // the result counts differ.
- auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes,
- op.getCondition());
-
- ifOp.getThenRegion().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(),
- ifOp.getThenRegion().end());
-
- ifOp.getElseRegion().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(),
- ifOp.getElseRegion().end());
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = ifOp.getResult(result.newIndex + 0);
- auto resourceSize = ifOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(ifOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfForOpConversion : public OpConversionPattern<mlir::scf::ForOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto &typeConverter = *getTypeConverter();
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- auto &block = op.getRegion().front();
- TypeConverter::SignatureConversion newSignature(block.getNumArguments());
- for (auto arg : llvm::enumerate(block.getArgumentTypes())) {
- if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(),
- newSignature))) {
- return failure();
- }
- }
-
- // Create a new loop that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original loop as
- // the result counts differ.
- auto forOp = rewriter.create<mlir::scf::ForOp>(
- op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
- adaptor.getStep(), expandedOperands);
-
- // Inline the block and update the block arguments.
- rewriter.eraseBlock(forOp.getBody());
- rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(),
- forOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = forOp.getResult(result.newIndex + 0);
- auto resourceSize = forOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(forOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfWhileOpConversion : public OpConversionPattern<mlir::scf::WhileOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto &typeConverter = *getTypeConverter();
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- TypeConverter::SignatureConversion newSignature(op.getNumOperands());
- for (auto argType : llvm::enumerate(op.getOperandTypes())) {
- if (failed(typeConverter.convertSignatureArg(
- argType.index(), argType.value(), newSignature))) {
- return failure();
- }
- }
-
- // Create a new call that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original call as
- // the result counts differ.
- auto whileOp = rewriter.create<mlir::scf::WhileOp>(
- op.getLoc(), expandedTypes, expandedOperands);
-
- // Inline the `before` block and update the block arguments.
- whileOp.getBefore().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(),
- whileOp.getBefore().end());
- if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Inline the `after` block and update the block arguments.
- whileOp.getAfter().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(),
- whileOp.getAfter().end());
- if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = whileOp.getResult(result.newIndex + 0);
- auto resourceSize = whileOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(whileOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfConditionOpConversion
- : public OpConversionPattern<mlir::scf::ConditionOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
- op, adaptor.getCondition(), expandedOperands);
- return success();
- }
-};
-
-struct ScfYieldOpConversion : public OpConversionPattern<mlir::scf::YieldOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands);
- return success();
- }
-};
-
-} // namespace
-
-template <typename OpT>
-static inline void addGenericLegalOp(ConversionTarget &conversionTarget,
- TypeConverter &typeConverter) {
- conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) {
- return llvm::all_of(
- op->getOperandTypes(),
- [&typeConverter](Type t) { return typeConverter.isLegal(t); }) &&
- llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) {
- return typeConverter.isLegal(t);
- });
- });
-}
-
-void populateStandardStructuralToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
- conversionTarget.addLegalOp<mlir::ModuleOp>();
-
- // We need to rewrite certain types on operands/results so use the default
- // dynamic legality checker to force any ops using such types to run through
- // our patterns.
-
- addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter);
- patterns
- .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
- typeConverter, context);
-
- addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter);
- patterns.insert<SelectOpConversion>(typeConverter, context);
-
- addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter);
- patterns
- .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion,
- ScfWhileOpConversion, ScfYieldOpConversion>(typeConverter,
- context);
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
index 1725fb0..9924fd2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
@@ -6,26 +6,419 @@
#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
-void populateStandardConstantToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+namespace {
-void populateStandardStructuralToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+struct ConvertTensorConstantOp
+ : public AffinityOpConversionPattern<arith::ConstantOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ arith::ConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle tensor types - other arith.constant types (like i32) are
+ // ignored.
+ if (!llvm::isa<TensorType>(constantOp.getType())) {
+ return failure();
+ }
+
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
+ auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
+ constantOp.getLoc(), constantType,
+ convertAttributeToStream(constantOp.getValue()),
+ TypeAttr::get(constantOp.getType()),
+ /*result_encoding_dims=*/ValueRange{}, executionAffinityAttr);
+
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr);
+ return success();
+ }
+};
+
+struct BranchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::BranchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getDestOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
+ expandedOperands);
+ return success();
+ }
+};
+
+struct CondBranchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::CondBranchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto trueDestOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getTrueDestOperands(), rewriter);
+ auto falseDestOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getFalseDestOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
+ op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands,
+ op.getFalseDest(), falseDestOperands);
+ return success();
+ }
+};
+
+static ValueRange asValueRange(ArrayRef<Value> values) { return values; }
+
+struct SwitchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::SwitchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto defaultOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getDefaultOperands(), rewriter);
+ auto caseOperands = llvm::to_vector(
+ llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) {
+ return expandResourceOperands(op.getLoc(), operands, rewriter);
+ }));
+ rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>(
+ op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands,
+ op.getCaseValuesAttr(), op.getCaseDestinations(),
+ llvm::to_vector(llvm::map_range(caseOperands, asValueRange)));
+ return success();
+ }
+};
+
+struct SelectOpConversion
+ : public AffinityAwareConversionPattern<mlir::arith::SelectOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle selects where the operands are tensors (resources).
+ if (!llvm::isa<TensorType>(op.getTrueValue().getType()))
+ return failure();
+ auto trueOperand = resolveTensorOperand(op.getLoc(), op.getTrueValue(),
+ adaptor.getTrueValue(), rewriter);
+ auto falseOperand = resolveTensorOperand(op.getLoc(), op.getFalseValue(),
+ adaptor.getFalseValue(), rewriter);
+ auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>(
+ op.getLoc(), adaptor.getCondition(), trueOperand.resource,
+ falseOperand.resource);
+ auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>(
+ op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize,
+ falseOperand.resourceSize);
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, adaptor.getTrueValue().getType(),
+ ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()});
+ return success();
+ }
+};
+
+struct ScfIfOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::IfOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource results to resource + size.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ // Create a new scf.if that returns the expanded results; the original
+ // can't be replaced in place as the result counts differ. The condition
+ // is taken from the adaptor so the remapped value is used.
+ auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes,
+ adaptor.getCondition());
+
+ ifOp.getThenRegion().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(),
+ ifOp.getThenRegion().end());
+
+ ifOp.getElseRegion().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(),
+ ifOp.getElseRegion().end());
+
+ // Cast each resource + size pair back to the original result type so that
+ // the replacement values are 1:1 with the original op results.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = ifOp.getResult(result.newIndex + 0);
+ auto resourceSize = ifOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(ifOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfForOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::ForOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Expand any resource init args taken from the adaptor to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter);
+
+ // Convert each original result type, recording in resultMap where its first converted type lands in expandedTypes.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ auto &block = op.getRegion().front();
+ TypeConverter::SignatureConversion newSignature(block.getNumArguments());
+ for (auto arg : llvm::enumerate(block.getArgumentTypes())) {
+ if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(),
+ newSignature))) {
+ return failure();
+ }
+ }
+
+ // Create a new loop over the adaptor's bounds/step that takes the expanded
+ // init operands (expandedTypes is not passed to this builder). We can't
+ // directly replace the original loop as the result counts differ.
+ auto forOp = rewriter.create<mlir::scf::ForOp>(
+ op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
+ adaptor.getStep(), expandedOperands);
+
+ // Erase the new loop's body block, move the original region in, and convert its block arguments using newSignature.
+ rewriter.eraseBlock(forOp.getBody());
+ rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(),
+ forOp.getRegion().end());
+ if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Tie each resource + size result pair back together with an unrealized cast
+ // to the original type so we end up with 1:1 results with the original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = forOp.getResult(result.newIndex + 0);
+ auto resourceSize = forOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(forOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfWhileOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::WhileOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Expand any resource operands taken from the adaptor to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
+
+ // Convert each original result type, recording in resultMap where its first converted type lands in expandedTypes.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ TypeConverter::SignatureConversion newSignature(op.getNumOperands());
+ for (auto argType : llvm::enumerate(op.getOperandTypes())) {
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), newSignature))) {
+ return failure();
+ }
+ }
+
+ // Create a new while loop that takes the expanded input operands and returns
+ // the expanded output results. We can't directly replace the original loop
+ // as the result counts differ.
+ auto whileOp = rewriter.create<mlir::scf::WhileOp>(
+ op.getLoc(), expandedTypes, expandedOperands);
+
+ // Inline the `before` block and update the block arguments using newSignature.
+ whileOp.getBefore().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(),
+ whileOp.getBefore().end());
+ if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Inline the `after` block and update the block arguments. NOTE(review): this reuses newSignature, which was built from the op operand types; confirm it is valid when the `after` arguments differ from the operands.
+ whileOp.getAfter().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(),
+ whileOp.getAfter().end());
+ if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Tie each resource + size result pair back together with an unrealized cast
+ // to the original type so we end up with 1:1 results with the original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = whileOp.getResult(result.newIndex + 0);
+ auto resourceSize = whileOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(whileOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfConditionOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::ConditionOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource args to resource + size; the condition is forwarded from the adaptor as-is.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
+ op, adaptor.getCondition(), expandedOperands);
+ return success();
+ }
+};
+
+struct ScfYieldOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::YieldOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size and re-emit the yield with the expanded list.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands);
+ return success();
+ }
+};
+
+template <typename OpT>
+static inline void addGenericLegalOp(ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter) {
+ conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) {
+ return llvm::all_of(
+ op->getOperandTypes(),
+ [&typeConverter](Type t) { return typeConverter.isLegal(t); }) &&
+ llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) {
+ return typeConverter.isLegal(t);
+ });
+ });
+}
+
+} // namespace
void populateStandardToStreamConversionPatterns(
MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
typeConverter.addConversion([](IndexType type) { return type; });
typeConverter.addConversion([](IntegerType type) { return type; });
typeConverter.addConversion([](FloatType type) { return type; });
@@ -35,10 +428,38 @@
conversionTarget.addIllegalOp<memref::DimOp, memref::RankOp, tensor::DimOp,
tensor::RankOp>();
- populateStandardConstantToStreamPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateStandardStructuralToStreamPatterns(context, conversionTarget,
- typeConverter, patterns);
+ conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>(
+ [](arith::ConstantOp op) {
+ return !llvm::isa<TensorType>(op.getType());
+ });
+ patterns.insert<ConvertTensorConstantOp>(typeConverter, context,
+ affinityAnalysis);
+
+ conversionTarget.addLegalOp<mlir::ModuleOp>();
+
+ // We need to rewrite certain types on operands/results so use the default
+ // dynamic legality checker to force any ops using such types to run through
+ // our patterns.
+
+ addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter);
+ patterns
+ .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+ typeConverter, context, affinityAnalysis);
+
+ addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter);
+ patterns.insert<SelectOpConversion>(typeConverter, context, affinityAnalysis);
+
+ addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter);
+ patterns
+ .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion,
+ ScfWhileOpConversion, ScfYieldOpConversion>(
+ typeConverter, context, affinityAnalysis);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
index 3314dcf..112e602 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
@@ -10,6 +10,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform standard/builtin->stream
@@ -17,7 +21,9 @@
// provided |typeConverter|.
void populateStandardToStreamConversionPatterns(
MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index 4d1fa5f..35e1ca8 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -67,8 +67,9 @@
}
};
-struct CallOpConversion : public OpConversionPattern<IREE::Util::CallOp> {
- using OpConversionPattern::OpConversionPattern;
+struct CallOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::CallOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Util::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -122,8 +123,9 @@
}
};
-struct ReturnOpConversion : public OpConversionPattern<IREE::Util::ReturnOp> {
- using OpConversionPattern::OpConversionPattern;
+struct ReturnOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::ReturnOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Util::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -142,6 +144,7 @@
struct ExpandedGlobalResource {
IREE::Util::GlobalOp resourceOp;
IREE::Util::GlobalOp resourceSizeOp;
+ IREE::Stream::AffinityAttr affinityAttr;
};
struct GlobalExpansionState {
@@ -163,13 +166,16 @@
public:
BaseGlobalConversionPattern(
std::shared_ptr<GlobalExpansionState> expansionState,
- TypeConverter &typeConverter, MLIRContext *context,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern<T>(typeConverter, context, benefit),
- expansionState(std::move(expansionState)) {}
+ expansionState(std::move(expansionState)),
+ affinityAnalysis(affinityAnalysis) {}
protected:
mutable std::shared_ptr<GlobalExpansionState> expansionState;
+ IREE::Stream::AffinityAnalysis *affinityAnalysis;
};
struct GlobalOpExpansion
@@ -230,9 +236,13 @@
globalOp.getIsMutable(), indexType, std::optional<TypedAttr>{});
resourceSizeOp.setVisibility(globalOp.getVisibility());
+ // Resolve the affinity of the global.
+ // We require this to be a single value today that is usually chosen from
+ // consumers (we take the hit on transfer from producers if needed).
+ auto affinityAttr = tryLookupGlobalAffinity(globalOp, affinityAnalysis);
+
// Materialize the initializer if we need to setup a tensor-like constant.
if (tensorInitializerRequired) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp);
auto initializerOp =
rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc());
auto *entryBlock = rewriter.createBlock(&initializerOp.getBody());
@@ -265,6 +275,7 @@
expansionState->globalMap[globalOp.getSymName()] = ExpandedGlobalResource{
resourceOp,
resourceSizeOp,
+ affinityAttr,
};
return success();
@@ -289,7 +300,7 @@
auto &expandedGlobal = expandedGlobalIt->getSecond();
// Insert a load/transfer to the unknown resource lifetime.
- auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto resource =
rewriter
.create<IREE::Util::GlobalLoadOp>(
@@ -303,8 +314,8 @@
.getResult();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
loadOp, unknownType, resource, resourceSize, resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
+ /*source_affinity=*/expandedGlobal.affinityAttr,
+ /*result_affinity=*/expandedGlobal.affinityAttr);
return success();
}
@@ -330,12 +341,14 @@
// Insert a transfer/store to the global with unknown lifetime. Lifetime
// refinement will make this go away if possible.
auto value =
- consumeTensorOperand(storeOp.getLoc(), adaptor.getValue(), rewriter);
+ resolveTensorOperand(storeOp.getLoc(), storeOp.getValue(),
+ adaptor.getValue(), affinityAnalysis, rewriter);
assert(expandedGlobal.resourceOp && "Missing resource op");
auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
storeOp.getLoc(), expandedGlobal.resourceOp.getType(), value.resource,
- value.resourceSize, value.resourceSize, /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
+ value.resourceSize, value.resourceSize,
+ /*source_affinity=*/value.affinity,
+ /*result_affinity=*/expandedGlobal.affinityAttr);
rewriter.replaceOpWithNewOp<IREE::Util::GlobalStoreOp>(
storeOp, transferOp.getResult(),
expandedGlobal.resourceOp.getSymName());
@@ -347,30 +360,59 @@
}
};
+struct OptimizationBarrierOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::OptimizationBarrierOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> newOperands;
+ for (auto [originalOperand, convertedOperand] :
+ llvm::zip_equal(op.getOperands(), adaptor.getOperands())) {
+ if (isa<TensorType>(convertedOperand.getType())) {
+ newOperands.push_back(resolveTensorOperand(op.getLoc(), originalOperand,
+ convertedOperand, rewriter)
+ .resource);
+ } else {
+ newOperands.push_back(convertedOperand);
+ }
+ }
+ rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op,
+ newOperands);
+ return success();
+ }
+};
+
} // namespace
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns
- .insert<FuncOpSignatureConversion, CallOpConversion, ReturnOpConversion>(
- typeConverter, context);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
+ patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
+ patterns.insert<CallOpConversion, ReturnOpConversion>(typeConverter, context,
+ affinityAnalysis);
auto expansionState = std::make_shared<GlobalExpansionState>();
// TODO(#7432): add indirect global expansion support to streams.
patterns
.insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>(
- expansionState, typeConverter, context);
+ expansionState, typeConverter, affinityAnalysis, context);
patterns.add<GenericConvertTypesPattern<IREE::Util::GlobalOp>,
GenericConvertTypesPattern<IREE::Util::GlobalLoadOp>,
GenericConvertTypesPattern<IREE::Util::GlobalStoreOp>>(
typeConverter, context);
+
+ patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context,
+ affinityAnalysis,
+ /*benefit=*/2);
}
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
typeConverter.addConversion([=](IREE::Util::PtrType type,
SmallVectorImpl<Type> &resultTypes) {
// Expand pointers to tensors to [resource, sizeof resource] pointers.
@@ -432,7 +474,8 @@
return typeConverter.isLegal(op.getResultTypes());
});
- populateUtilToStreamConversionPatterns(context, typeConverter, patterns);
+ populateUtilToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
index 5673c74..56fcca2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform util->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index b0b66ac..91f5c0f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h"
#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h"
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
@@ -39,79 +40,6 @@
namespace {
-// Builds a stream.tensor.import op that imports an external tensor value into
-// a stream resource. |consumingOps| will be populated with all ops that consume
-// the original |sourceTensor| and that should not be replaced with the returned
-// value.
-static Value buildTensorImportOp(Location loc, Value sourceTensor,
- Type targetType,
- SmallPtrSetImpl<Operation *> &consumingOps,
- IREE::Stream::AffinityAttr affinityAttr,
- OpBuilder &builder) {
- // Gather dynamic dimensions from the input value.
- auto dynamicDims =
- IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder);
-
- // Compute the size of the tensor once in the stream resource.
- // This may differ from the external encoding of the tensor as imports are
- // a transfer operation that may need to reformat the tensor.
- auto encodingAttr = TypeAttr::get(sourceTensor.getType());
- Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>(
- loc, builder.getIndexType(), encodingAttr, dynamicDims, affinityAttr);
-
- // Associate the external SSA value, encoding, and shape information with the
- // stream resource. When lowering we'll then have all the metadata required
- // even after erasing it all on the resource.
- auto externalType = builder.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- auto importOp = builder.create<IREE::Stream::TensorImportOp>(
- loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
- affinityAttr);
- consumingOps.insert(importOp);
-
- // If needed insert a transfer to the target lifetime.
- Value result = importOp.getResult();
- if (targetType != externalType) {
- result = builder
- .create<IREE::Stream::AsyncTransferOp>(
- loc, targetType, result, resultSize, resultSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr)
- .getResult();
- }
-
- auto castOp = builder.create<mlir::UnrealizedConversionCastOp>(
- loc, sourceTensor.getType(), ValueRange{result, resultSize});
- consumingOps.insert(castOp);
- return castOp.getResult(0);
-}
-
-// Builds a stream.tensor.export op that exports a stream resource into an
-// external tensor value.
-static Value buildTensorExportOp(Location loc, Value sourceValue,
- TensorType targetType, ValueRange dynamicDims,
- IREE::Stream::AffinityAttr affinityAttr,
- OpBuilder &builder) {
- auto source = consumeTensorOperand(loc, sourceValue, builder);
-
- // If needed insert a transfer to external resource lifetime.
- auto externalType = builder.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- if (source.resource.getType() != externalType) {
- source.resource = builder.create<IREE::Stream::AsyncTransferOp>(
- loc, externalType, source.resource, source.resourceSize,
- source.resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/affinityAttr);
- }
-
- // Associate the stream resource and external encoding and shape information.
- auto newOp = builder.create<IREE::Stream::TensorExportOp>(
- loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims,
- source.resourceSize, affinityAttr);
- return newOp.getResult();
-}
-
// Returns true if |op| has tensor I/O that is not yet imported/exported using
// the stream ops that capture encodings and shapes.
static bool doesOperationNeedWrapping(Operation *op) {
@@ -123,8 +51,9 @@
operand.getDefiningOp());
}) ||
llvm::any_of(op->getResults(), [](Value result) {
- if (!isa<TensorType>(result.getType()))
+ if (!isa<TensorType>(result.getType())) {
return false;
+ }
return !llvm::all_of(result.getUsers(),
llvm::IsaPred<TensorImportOp>);
});
@@ -133,15 +62,19 @@
// Fallback handler for unknown ops taking/returning tensors that need to be
// marshaled into/outof stream resource types.
struct GenericResourcePattern : public ConversionPattern {
- GenericResourcePattern(MLIRContext *context, TypeConverter &converter)
- : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+ GenericResourcePattern(MLIRContext *context, TypeConverter &converter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis)
+ : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context),
+ affinityAnalysis(affinityAnalysis) {}
+
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!doesOperationNeedWrapping(op))
+ if (!doesOperationNeedWrapping(op)) {
return failure();
+ }
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto executionAffinityAttr = affinityAnalysis->inferExecutionAffinity(op);
// Export resources into tensor operands for the op to consume.
SmallVector<Value> newOperands;
@@ -156,11 +89,14 @@
auto tensorType = dyn_cast<TensorType>(oldOperand.getType());
assert(tensorType && "must have a tensor type to map to a resource");
+ auto exportAffinityAttr =
+ affinityAnalysis->lookupResourceAffinity(oldOperand);
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
op->getLoc(), oldOperand, rewriter);
- newOperands.push_back(buildTensorExportOp(op->getLoc(), newOperand,
- tensorType, dynamicDims,
- affinityAttr, rewriter));
+ newOperands.push_back(buildTensorExportOp(
+ op->getLoc(), oldOperand, newOperand, tensorType, dynamicDims,
+ exportAffinityAttr ? exportAffinityAttr : executionAffinityAttr,
+ rewriter));
}
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
@@ -168,46 +104,107 @@
rewriter.setInsertionPointAfter(op);
for (auto result : op->getResults()) {
auto tensorType = dyn_cast<TensorType>(result.getType());
- if (!tensorType)
+ if (!tensorType) {
continue;
+ }
+ auto importAffinityAttr =
+ affinityAnalysis->lookupResourceAffinity(result);
auto dynamicDims =
IREE::Util::buildDynamicDimsForValue(op->getLoc(), result, rewriter);
SmallPtrSet<Operation *, 4> consumingOps;
auto importedValue = buildTensorImportOp(
op->getLoc(), result, rewriter.getType<IREE::Stream::ResourceType>(),
- consumingOps, affinityAttr, rewriter);
+ consumingOps,
+ importAffinityAttr ? importAffinityAttr : executionAffinityAttr,
+ rewriter);
result.replaceAllUsesExcept(importedValue, consumingOps);
}
return success();
}
-};
-struct OptimizationBarrierOpConversion
- : public OpConversionPattern<IREE::Util::OptimizationBarrierOp> {
- using OpConversionPattern<
- IREE::Util::OptimizationBarrierOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> newOperands;
- for (Value v : adaptor.getOperands()) {
- if (isa<TensorType>(v.getType())) {
- newOperands.push_back(
- consumeTensorOperand(op.getLoc(), v, rewriter).resource);
- } else {
- newOperands.push_back(v);
- }
+ // Builds a stream.tensor.export op that exports a stream resource into an
+ // external tensor value.
+ Value buildTensorExportOp(Location loc, Value originalValue,
+ Value convertedValue, TensorType targetType,
+ ValueRange dynamicDims,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ OpBuilder &builder) const {
+ auto source =
+ transferTensorOperand(loc, originalValue, convertedValue,
+ executionAffinityAttr, affinityAnalysis, builder);
+
+ // If needed insert a transfer to external resource lifetime.
+ auto externalType = builder.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ if (source.resource.getType() != externalType) {
+ source.resource = builder.create<IREE::Stream::AsyncTransferOp>(
+ loc, externalType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/executionAffinityAttr);
}
- rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op,
- newOperands);
- return success();
+
+ // Associate the stream resource and external encoding and shape
+ // information.
+ auto newOp = builder.create<IREE::Stream::TensorExportOp>(
+ loc, targetType, source.resource, TypeAttr::get(targetType),
+ dynamicDims, source.resourceSize, executionAffinityAttr);
+ return newOp.getResult();
}
+
+ // Builds a stream.tensor.import op that imports an external tensor value into
+ // a stream resource. |consumingOps| will be populated with all ops that
+ // consume the original |sourceTensor| and that should not be replaced with
+ // the returned value.
+ Value buildTensorImportOp(Location loc, Value sourceTensor, Type targetType,
+ SmallPtrSetImpl<Operation *> &consumingOps,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ OpBuilder &builder) const {
+ // Gather dynamic dimensions from the input value.
+ auto dynamicDims =
+ IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder);
+
+ // Compute the size of the tensor once in the stream resource.
+ // This may differ from the external encoding of the tensor as imports are
+ // a transfer operation that may need to reformat the tensor.
+ auto encodingAttr = TypeAttr::get(sourceTensor.getType());
+ Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>(
+ loc, builder.getIndexType(), encodingAttr, dynamicDims,
+ executionAffinityAttr);
+
+ // Associate the external SSA value, encoding, and shape information with
+ // the stream resource. When lowering we'll then have all the metadata
+ // required even after erasing it all on the resource.
+ auto externalType = builder.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto importOp = builder.create<IREE::Stream::TensorImportOp>(
+ loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
+ executionAffinityAttr);
+ consumingOps.insert(importOp);
+
+ // If needed insert a transfer to the target lifetime.
+ Value result = importOp.getResult();
+ if (targetType != externalType) {
+ result = builder
+ .create<IREE::Stream::AsyncTransferOp>(
+ loc, targetType, result, resultSize, resultSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr)
+ .getResult();
+ }
+
+ auto castOp = builder.create<mlir::UnrealizedConversionCastOp>(
+ loc, sourceTensor.getType(), ValueRange{result, resultSize});
+ consumingOps.insert(castOp);
+ return castOp.getResult(0);
+ }
+
+ IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr;
};
static void stripAffinityAttrs(ModuleOp moduleOp) {
- moduleOp->removeAttr("stream.affinity.default");
auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity");
for (auto &op : moduleOp.getOps()) {
op.removeDiscardableAttr(affinityName);
@@ -223,6 +220,13 @@
void runOnOperation() override {
auto *context = &getContext();
+ // Run affinity analysis so that the required producer/consumer affinities
+ // for all SSA values we'll use during conversion are available.
+ AffinityAnalysis affinityAnalysis(getOperation());
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+
TypeConverter typeConverter;
ConversionTarget conversionTarget(getContext());
RewritePatternSet patterns(&getContext());
@@ -235,10 +239,9 @@
// Allow unknown types to pass through; these come from custom dialects that
// may be mixed into the IR we are converting.
typeConverter.addConversion([=](Type type) -> Type {
- // convert flow.channel into stream.channel
- if (llvm::isa<IREE::Flow::ChannelType>(type))
+ if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
-
+ }
return !llvm::isa<TensorType>(type) ? type : Type{};
});
@@ -275,21 +278,20 @@
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
patterns);
- populateUtilToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
+ populateUtilToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
- populateStandardToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateFlowToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateHALToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
+ populateStandardToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
+ populateFlowToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
+ populateHALToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
conversionTarget.markUnknownOpDynamicallyLegal(
[&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); });
- patterns.insert<GenericResourcePattern>(context, typeConverter);
- patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context,
- /*benefit=*/2);
+ patterns.insert<GenericResourcePattern>(context, typeConverter,
+ &affinityAnalysis);
// NOTE: we allow ops that we don't know about to allow custom dialects
// that don't need anything Stream-specific to pass through.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
index 7e9c231..6732b97 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
@@ -399,7 +399,8 @@
if (sourceType != targetType &&
sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
LLVM_DEBUG(llvm::dbgs()
- << " - clone source is a constant; cannot elide\n");
+ << " - clone is a resource lifetime cast (" << sourceType
+ << " to " << targetType << "); cannot elide\n");
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
index f3cd827..2db2cd6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
@@ -33,7 +33,9 @@
// Emplacement
//===----------------------------------------------------------------------===//
-static void replaceUsesAndTransfer(Value oldValue, Value newValue) {
+static void
+replaceUsesAndTransfer(Value oldValue, Value newValue,
+ IREE::Stream::AffinityAttr usageAffinityAttr) {
assert(isa<IREE::Stream::ResourceType>(oldValue.getType()));
assert(isa<IREE::Stream::ResourceType>(newValue.getType()));
if (oldValue.getType() == newValue.getType()) {
@@ -44,8 +46,8 @@
builder.setInsertionPointAfterValue(newValue);
Value newValueSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
newValue.getLoc(), newValue, builder);
- IREE::Stream::AffinityAttr sourceAffinity;
- IREE::Stream::AffinityAttr resultAffinity;
+ IREE::Stream::AffinityAttr sourceAffinity = usageAffinityAttr;
+ IREE::Stream::AffinityAttr resultAffinity = usageAffinityAttr;
Value transferValue = builder.create<IREE::Stream::AsyncTransferOp>(
newValue.getLoc(), oldValue.getType(), newValue, newValueSize,
newValueSize, sourceAffinity, resultAffinity);
@@ -74,7 +76,7 @@
break;
}
- // Find potential.
+ // Find potential update to place the dispatch result into.
Value targetResource;
Value targetResourceSize;
Value targetOffset;
@@ -82,12 +84,22 @@
Value targetLength;
Value targetResult;
Value targetResultSize;
+ Attribute targetAffinityAttr;
Operation *userOp = *result.user_begin();
if (auto updateOp = dyn_cast<IREE::Stream::AsyncUpdateOp>(userOp)) {
if (updateOp.getUpdate() != result) {
// TODO(#14566): continue if sparse emplacement on multiple results.
break;
}
+
+ // Currently only allow exactly matching affinities.
+ // TODO(multi-device): memory compatibility - if compatible then allow.
+ if (updateOp.getAffinityAttr() != dispatchOp.getAffinityAttr()) {
+ continue;
+ }
+
+ // Try to move all SSA values required into the appropriate place.
+ // TODO(benvanik): undo this if there's a failure (or record/roll-back).
if (!IREE::Util::tryMoveProducerBefore(updateOp.getUpdateSize(),
dispatchOp) ||
!IREE::Util::tryMoveProducerBefore(updateOp.getTargetSize(),
@@ -102,6 +114,7 @@
// TODO(#14566): continue if sparse emplacement on multiple results.
break;
}
+
targetResource = updateOp.getTarget();
if (targetResource.getDefiningOp() == dispatchOp) {
// NOTE: we may have already replaced the update target with one of our
@@ -115,6 +128,7 @@
targetLength = updateOp.getUpdateSize();
targetResult = updateOp.getResult();
targetResultSize = updateOp.getTargetSize();
+ targetAffinityAttr = updateOp.getAffinityAttr();
}
if (!targetResource) {
// TODO(#14566): continue if sparse emplacement on multiple results.
@@ -136,7 +150,7 @@
dispatchOp.getResultSizesMutable().assign(resultSizes);
// Replace users with the result of the dispatch op.
- replaceUsesAndTransfer(targetResult, result);
+ replaceUsesAndTransfer(targetResult, result, dispatchOp.getAffinityAttr());
userOp->erase();
didChange = true;
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
index a3eae92..b11839b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
@@ -57,7 +57,7 @@
}
}
-// vulkan uses default materialization patterns which unsets the encodings.
+// Vulkan uses default materialization patterns, which unset the encodings.
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: util.return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
index 4842e51..5354b82 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -1,5 +1,7 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterLoad
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.buffer, !hal.fence)
util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
@@ -7,7 +9,7 @@
%c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
%c101 = arith.constant 101 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -15,7 +17,7 @@
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
// CHECK-NEXT: "scope"::"key0"[%c50_i64] : !hal.buffer{%c100}
// CHECK-NEXT: "scope"::"key1"[%c51_i64] : !hal.buffer{%c101}
- %results:2, %result_timepoint = stream.parameter.load await(%wait) => {
+ %results:2, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => {
"scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
"scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c101}
} => !stream.timepoint
@@ -25,19 +27,21 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterLoadNoScope
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence)
util.func public @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
// CHECK-NEXT: "key"[%c50_i64] : !hal.buffer{%c100}
- %result, %result_timepoint = stream.parameter.load await(%wait) => {
+ %result, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => {
"key"[%c50_i64] : !stream.resource<constant>{%c100}
} => !stream.timepoint
// CHECK: return %[[BUFFER]], %[[SIGNAL]]
@@ -46,6 +50,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterRead
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterRead(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -53,19 +59,21 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: "scope"::"key"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer
- %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ %timepoint = stream.parameter.read on(#hal.device.affinity<@device>) await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
// CHECK: return %[[SIGNAL]]
util.return %timepoint : !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterWrite
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
@@ -73,19 +81,21 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !hal.buffer -> "scope"::"key"[%c50_i64]
- %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ %timepoint = stream.parameter.write on(#hal.device.affinity<@device>) await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
// CHECK: return %[[SIGNAL]]
util.return %timepoint : !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterGather
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -99,7 +109,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -107,7 +117,7 @@
// CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
// CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer,
// CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !hal.buffer
- %timepoint = stream.parameter.gather await(%wait) => {
+ %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => {
"scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
"scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300},
"scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource<transient>{%c300}
@@ -118,6 +128,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterGatherNoScope
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -128,14 +140,14 @@
%c200 = arith.constant 200 : index
%c201 = arith.constant 201 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
// CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer
- %timepoint = stream.parameter.gather await(%wait) => {
+ %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => {
"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300}
} => !stream.timepoint
@@ -145,6 +157,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterScatter
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
@@ -158,7 +172,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -167,7 +181,7 @@
// CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64],
// CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64]
// CHECK-NEXT: }
- %timepoint = stream.parameter.scatter await(%wait) => {
+ %timepoint = stream.parameter.scatter on(#hal.device.affinity<@device>) await(%wait) => {
%source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64],
%source[%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64],
%source[%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index 0d730cd..bd56dc8 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -136,18 +136,6 @@
if (compileTo == IREEVMPipelinePhase::Input)
return; // early-exit
- // If the user specified a set of target devices we attach them to the module
- // IR so that they are available for all passes that may want to use this
- // information. If trying to compile in a generic mode the user should omit
- // specifying targets.
- IREE::HAL::AssignmentOptions halAssignmentOptions;
- halAssignmentOptions.legacyTargetBackends =
- halTargetOptions.legacyTargetBackends;
- halAssignmentOptions.targetDevices = halTargetOptions.targetDevices;
- halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice;
- IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
- halAssignmentOptions);
-
// Now that inputs are legalized, generate wrapper for entry functions.
if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI");
@@ -172,6 +160,18 @@
if (compileTo == IREEVMPipelinePhase::ABI)
return; // early-exit
+ // If the user specified a set of target devices we attach them to the module
+ // IR so that they are available for all passes that may want to use this
+ // information. If trying to compile in a generic mode the user should omit
+ // specifying targets.
+ IREE::HAL::AssignmentOptions halAssignmentOptions;
+ halAssignmentOptions.legacyTargetBackends =
+ halTargetOptions.legacyTargetBackends;
+ halAssignmentOptions.targetDevices = halTargetOptions.targetDevices;
+ halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice;
+ IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
+ halAssignmentOptions);
+
GlobalOptimization::TransformOptions globalTransformOptions;
globalTransformOptions.options = globalOptimizationOptions;
diff --git a/tools/test/compile_pipelines.mlir b/tools/test/compile_pipelines.mlir
index 2fd4a6c..fb6dbbe 100644
--- a/tools/test/compile_pipelines.mlir
+++ b/tools/test/compile_pipelines.mlir
@@ -1,10 +1,10 @@
// RUN: iree-opt --iree-common-input-transformation-pipeline %s | \
// RUN: iree-opt --iree-abi-transformation-pipeline - | \
-// RUN: iree-opt --iree-common-input-transformation-pipeline - | \
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-device-assignment-pipeline{target-devices=local})' --iree-hal-local-target-device-backends=vmvx - | \
// RUN: iree-opt --iree-global-optimization-transformation-pipeline - | \
// RUN: iree-opt --iree-flow-transformation-pipeline - | \
// RUN: iree-opt --iree-stream-transformation-pipeline - | \
-// RUN: iree-opt --iree-hal-transformation-pipeline --iree-hal-target-backends=vmvx - | \
+// RUN: iree-opt --iree-hal-transformation-pipeline - | \
// RUN: iree-opt --iree-vm-transformation-pipeline - | \
// RUN: FileCheck %s
diff --git a/tools/test/compile_to_continuation.mlir b/tools/test/compile_to_continuation.mlir
index 9c78153..5476462 100644
--- a/tools/test/compile_to_continuation.mlir
+++ b/tools/test/compile_to_continuation.mlir
@@ -1,79 +1,89 @@
// RUN: iree-compile --compile-to=input %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=INPUT-PHASE
// INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=abi %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=ABI-PHASE
// ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=flow %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FLOW-PHASE
// FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=stream %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow %s | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
+// RUN: FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE
+// FLOW-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
+
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=STREAM-PHASE
// STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=stream %s | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
+// RUN: FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE
+// STREAM-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
+
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
// EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
// EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=HAL-PHASE
// HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=VM-PHASE
// VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=input %s | \
-// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE
// FROM-INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=abi %s | \
-// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE
// FROM-ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=flow %s | \
-// RUN: iree-compile --compile-from=flow --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --compile-from=flow --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-FLOW-PHASE
// FROM-FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=stream %s | \
-// RUN: iree-compile --compile-from=stream --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --compile-from=stream --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-STREAM-PHASE
// FROM-STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=executable-sources --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-SOURCES-PHASE
// FROM-EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=executable-targets --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-TARGETS-PHASE
// FROM-EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=hal --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-HAL-PHASE
// FROM-HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=vm --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-VM-PHASE
// FROM-VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
diff --git a/tools/test/compile_to_phase.mlir b/tools/test/compile_to_phase.mlir
index 0390564..f1861a0 100644
--- a/tools/test/compile_to_phase.mlir
+++ b/tools/test/compile_to_phase.mlir
@@ -7,36 +7,44 @@
// ABI-PHASE: %[[INPUT:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor<f32>
// ABI-PHASE: math.absf %[[INPUT]] : tensor<f32>
-// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE
+// RUN: iree-compile --compile-to=flow %s --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx | FileCheck %s --check-prefix=FLOW-PHASE
// FLOW-PHASE: flow.executable.export public @abs_dispatch_0
// FLOW-PHASE: flow.dispatch @abs_dispatch_0
-// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE
+// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE
+// FLOW-PHASE-NO-DEVICE: flow.executable.export public @abs_dispatch_0
+// FLOW-PHASE-NO-DEVICE: flow.dispatch @abs_dispatch_0
+
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=STREAM-PHASE
// STREAM-PHASE: stream.executable.export public @abs_dispatch_0
// STREAM-PHASE: stream.cmd.dispatch @abs_dispatch_0
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
+// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE
+// STREAM-PHASE-NO-DEVICE: stream.executable.export public @abs_dispatch_0
+// STREAM-PHASE-NO-DEVICE: stream.cmd.dispatch @abs_dispatch_0
+
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
// EXECUTABLE-SOURCES-PHASE: hal.executable private @abs_dispatch_0
// EXECUTABLE-SOURCES-PHASE: hal.executable.variant
// EXECUTABLE-SOURCES-PHASE: linalg.generic
// EXECUTABLE-SOURCES-PHASE: math.absf
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
// EXECUTABLE-TARGETS-PHASE: hal.executable private @abs_dispatch_0
// EXECUTABLE-TARGETS-PHASE: hal.executable.variant
// EXECUTABLE-TARGETS-PHASE: vm.abs.f32
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE
// HAL-PHASE: hal.executable private @abs_dispatch_0
// HAL-PHASE: hal.executable.binary
// HAL-PHASE: hal.command_buffer.dispatch
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE
// VM-PHASE: vm.rodata private @abs_dispatch_0
// VM-PHASE: vm.call @hal.command_buffer.dispatch
-// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
+// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
// END-PHASE: vm.rodata private @abs_dispatch_0
// END-PHASE: vm.call @hal.command_buffer.dispatch