Merging multi-device branch to main. (#17987)
**TLDR**: nothing should break, `--iree-hal-target-backends=` is
deprecated, use `--iree-hal-target-device=` and appropriate
target-specific flags instead.
This reworks the target device concept in the IREE pipeline - in some
cases introducing the concept (flow and HAL) and in others replacing
placeholder mechanisms around stream affinity. This builds upon prior
work that added support for enumerating available devices via the HAL
and for providing multiple devices to the runtime tools. This change adds
the ability to define devices, allows execution and storage resources to
be assigned to a device, and upgrades passes to support multiple
devices. "Multi-device" here means several things and all are
accomplished with the same mechanism: a single device that may be one of
multiple types (multiple CPU/GPU archs, CPU _or_ GPU, etc.), multiple
homogeneous devices (4 of the same exact GPUs accessed through the same
runtime HAL driver), multiple heterogeneous devices (a CPU and a
GPU/NPU/etc), and optional devices (a CPU with some portions offloaded
to a GPU/NPU if it's compatible and available at runtime). In this way
we can provide cross-compilation/targeting, multi-targeting, and
multiple devices with one set of flags, compiler analysis, passes
dealing with the devices, and runtime infrastructure.
Early warning: **it's strongly discouraged to use device information
prior to codegen** - any pass using such information earlier on is a red
flag that will receive pushback. IREE is designed first and foremost as
a cross-compiler with multi-targeting at its core, and radically changing
program behavior near the frontend makes it nearly impossible to have
configuration control over the compilation pipeline. Consider
specializing on the device prior to codegen to be tantamount to using C
preprocessor macros based on operating system or architecture: it means
that a problem has not been solved and a workaround has been taken.
There are exceptionally few cases that require device information early,
and those that do can do so in generic ways that do not disturb the
debuggability of the program. For example, far better than preprocessor
macros in C++ are function calls and if statements (as we can do in our
programs), and even better than that are virtual interfaces (ops that
are only lowered to one of multiple implementations later on). That
disclaimer out of the way: it's now possible to query device information
after the input pipeline (global opt/preprocessing/flow). Upstream will
push back against doing so in nearly all cases, but it is a useful
mechanism for downstream projects.
The key change here is that the `--iree-hal-target-backends=` compiler
flag has been deprecated. It continues to work for now with the same
behavior as before but usage will shift to the replacement
`--iree-hal-target-device=` flag. A single instance of this flag defines
a single device within the program and repeated uses of it will define
new devices. Devices may be named ("my_device") or anonymous (in which
case they will be assigned an ordinal like 0 or 1), and each device may
be backed by one or more target devices (Vulkan, local host, HIP, etc).
Each target device in the compiler (represented by
`IREE::HAL::TargetDevice`) may have any number of backends with various
configurations (multiple archs, different deployment formats, etc
represented by one or more `IREE::HAL::ExecutableTargetAttr` values).
Example flags:
```sh
# Two devices, one the local host device and the other a Vulkan device:
--iree-hal-target-device=local --iree-hal-target-device=vulkan
# One device that selects Vulkan if available and otherwise uses the local host device:
--iree-hal-target-device=vulkan,local
# Two CUDA devices selected by runtime ordinal; at runtime two --device=
# flags are required to configure both devices:
--iree-hal-target-device=cuda[0] --iree-hal-target-device=cuda[1]
# A fully-defined target specification:
--iree-hal-target-device=#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>
# Named device for defining a reference by #hal.device.promise<@some_name>:
--iree-hal-target-device=some_name=vulkan
```
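The flag grammar above can be sketched as a small parser. This is a minimal C++ illustration under stated assumptions, not IREE's implementation: `TargetChoice`, `ParsedDevice`, and `parseTargetDeviceFlag` are hypothetical names, and a fully-defined `#hal.device.target<...>` spec is kept as an opaque string rather than parsed.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of one --iree-hal-target-device= value (not IREE's parser).
struct TargetChoice {
  std::string target;          // e.g. "vulkan", "cuda", or an opaque "#hal.device.target<...>"
  std::optional<int> ordinal;  // e.g. 1 for "cuda[1]"; empty when no [N] suffix is given
};

struct ParsedDevice {
  std::optional<std::string> name;    // e.g. "some_name" for "some_name=vulkan"; empty if anonymous
  std::vector<TargetChoice> choices;  // alternatives in the order they were written
};

// Splits "cuda[1]" into target "cuda" and ordinal 1; "vulkan" has no ordinal.
// std::stoi throws on a malformed ordinal such as "cuda[x]".
static TargetChoice parseChoice(const std::string &text) {
  TargetChoice choice;
  std::size_t open = text.find('[');
  if (open != std::string::npos && text.back() == ']') {
    choice.target = text.substr(0, open);
    choice.ordinal = std::stoi(text.substr(open + 1, text.size() - open - 2));
  } else {
    choice.target = text;
  }
  return choice;
}

ParsedDevice parseTargetDeviceFlag(const std::string &value) {
  ParsedDevice device;
  std::string rest = value;
  // An optional "name=" prefix names the device. A fully-defined spec starts
  // with '#', so '=' characters inside it are never treated as the separator.
  if (rest.empty() || rest[0] != '#') {
    std::size_t eq = rest.find('=');
    if (eq != std::string::npos) {
      device.name = rest.substr(0, eq);
      rest = rest.substr(eq + 1);
    }
  }
  // Fully-defined specs are kept whole since they may contain ',' as well.
  if (!rest.empty() && rest[0] == '#') {
    device.choices.push_back({rest, std::nullopt});
    return device;
  }
  // Comma-separated alternatives such as "vulkan,local".
  std::size_t start = 0;
  while (start <= rest.size()) {
    std::size_t comma = rest.find(',', start);
    if (comma == std::string::npos) comma = rest.size();
    device.choices.push_back(parseChoice(rest.substr(start, comma - start)));
    start = comma + 1;
  }
  return device;
}
```

With this model, `vulkan,local` yields two alternatives tried in order, `cuda[1]` yields a `cuda` choice with ordinal 1, and `some_name=vulkan` yields a device named `some_name`.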
The device metadata as specified in the compiler is used to produce
enumeration code that executes at runtime and queries the available
devices to find the appropriate matches. This means that if the program
is compiled to target two CUDA devices then at runtime there must be two
CUDA devices specified - the indirection allows for the compiled
artifact to work with any two CUDA devices targeted by UUID, device
ordinal, etc., and not just the first and second CUDA devices in the
system. E.g. `iree-compile --iree-hal-target-device=cuda[0]
--iree-hal-target-device=cuda[1]` and `iree-run-module
--device=cuda://UUID_A --device=cuda://UUID_B`. Device targets in the
compiler can now specify the ordinal of the device in order to
differentiate between multiple devices at runtime (the `cuda[0]` and
`cuda[1]` above indicate the first CUDA device and second CUDA device
provided to the runtime).
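The ordinal indirection can be illustrated with a small lookup. This C++ sketch assumes that runtime devices arrive as an ordered list of `--device=` URIs and that `cuda[N]` means "the N-th provided device whose driver is `cuda`". `findProvidedDevice` is a hypothetical helper. The enumeration code the compiler actually generates is not shown here and may match on more than the driver name.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Driver portion of a device URI: "cuda://UUID_A" -> "cuda". URIs without
// "://" are treated as a bare driver name.
static std::string driverOf(const std::string &uri) {
  std::size_t sep = uri.find("://");
  return sep == std::string::npos ? uri : uri.substr(0, sep);
}

// Hypothetical helper: index into `provided` of the `ordinal`-th device whose
// driver is `driver`, or std::nullopt if too few such devices were provided.
std::optional<std::size_t> findProvidedDevice(
    const std::vector<std::string> &provided, const std::string &driver,
    int ordinal) {
  int seen = 0;
  for (std::size_t i = 0; i < provided.size(); ++i) {
    if (driverOf(provided[i]) != driver) continue;
    if (seen == ordinal) return i;
    ++seen;
  }
  return std::nullopt;
}
```

Given `--device=cuda://UUID_A --device=cuda://UUID_B`, `cuda[0]` and `cuda[1]` map to the first and second provided CUDA devices. Where those GPUs fall in the system's own enumeration order does not matter.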
Major new attributes:
* `#hal.device.promise<@device>` is a reference to a device that will be
provided at a later stage. Frontends can use this as a placeholder for
devices that are specified on the command line without needing to say
what those devices are when exporting.
* `#hal.device.alias<"name">` specifies an `IREE::HAL::TargetDevice` in
the compiler (`vulkan`, `local`, `hip`, etc) and expands to a full
`#hal.device.target` based on target-specific flags.
* `#hal.device.select<[...]>` controls selection by enumerating each
device in turn and matching the first found.
* `#hal.device.fallback<@other_device>` provides a fallback reference
that the device will match if no other device matches. Note that having
two devices with the same target will create two copies at runtime; to
reuse the existing device, the fallback mechanism must be used.
* `#hal.device.affinity<@device>` (and optional queue mask) is used on
ops to indicate on which device they should execute.
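The following C++ sketch is a rough model of how `#hal.device.select` and `#hal.device.fallback` interact. It assumes that a fallback appears as an alternative inside a select list; that layout is the sketch's assumption and not necessarily how IREE encodes it. `Alternative` and `resolveDevice` are hypothetical. Instance ids stand in for runtime device objects so that the difference between creating a copy and reusing an existing device is visible.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical alternative inside a select list: either a concrete target
// (e.g. "vulkan") or a fallback to another device global (e.g. "cpu_device").
struct Alternative {
  std::string target;      // non-empty for a concrete target
  std::string fallbackTo;  // non-empty for a fallback reference
};

// Tries each alternative in order and returns the first match. A concrete
// target creates a new instance owned by `self`; a fallback returns the
// instance already resolved for the referenced device, so nothing is copied.
std::optional<std::string> resolveDevice(
    const std::vector<Alternative> &alternatives,
    const std::set<std::string> &availableDrivers,
    const std::map<std::string, std::string> &resolved,
    const std::string &self) {
  for (const Alternative &alt : alternatives) {
    if (!alt.target.empty() && availableDrivers.count(alt.target) != 0) {
      return self + ":" + alt.target;
    }
    if (!alt.fallbackTo.empty()) {
      auto it = resolved.find(alt.fallbackTo);
      if (it != resolved.end()) return it->second;
    }
  }
  return std::nullopt;  // no alternative matched
}
```

If only `local` is available, a `gpu_device` that selects `vulkan` and then falls back to `@cpu_device` resolves to the same instance as `cpu_device`. It does not create a second local device.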
All of the above flags are just syntactic sugar that adds these
attributes to the program IR, and it's possible for frontends to insert
these attributes or ops directly depending on the use case. In most cases
leaving placeholders in the IR such that the exact target can be
specified during compilation is ideal: this allows one output from the
frontend to be used with any number of targets and configurations.
Online compilers, though, may want to bake in their exact configuration
and can do so without the need for flags that may lose information. The
general flow of the `buildHALDeviceAssignmentPassPipeline`/`iree-opt
--iree-hal-device-assignment-pipeline` is:
1. `--iree-hal-target-device=` flags are parsed and a
`hal.device.targets` attribute is added to the module.
* `--iree-hal-target-device=cpu_device=local` becomes
`hal.device.targets = [#hal.device.alias<"local"> : !hal.device]`
* `--iree-hal-target-device=cpu_device=local
--iree-hal-target-device=gpu_device=cuda,hip` becomes
```mlir
hal.device.targets = {
cpu_device = #hal.device.alias<"local"> : !hal.device,
gpu_device = #hal.device.select<[#hal.device.alias<"cuda"> :
!hal.device, #hal.device.alias<"hip"> : !hal.device]> :
!hal.device
}
```
2. The `hal.device.targets` attribute (if any) is expanded into
`util.global` ops for each device. These globals are initialized with
one of the supported attributes, which are much later turned into
enumeration/selection logic. The above multi-device example becomes:
```mlir
builtin.module attributes {stream.affinity.default =
#hal.device.affinity<@cpu_device>} {
util.global private @cpu_device = #hal.device.alias<"local"> :
!hal.device
util.global private @gpu_device =
#hal.device.select<[#hal.device.alias<"cuda"> : !hal.device,
#hal.device.alias<"hip"> : !hal.device]> :
!hal.device
}
```
3. Any `#hal.device.promise` attributes will be changed to reference the
globals with the same name. This allows for retargeting of inputs by
letting a frontend specify named devices prior to them having been
passed on the command line (or inserted by some other pipeline).
4. Any `#hal.device.alias` attributes are converted to full
`#hal.device.target` attributes using the appropriate
`IREE::HAL::TargetDevice` implementation.
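Steps 2 and 3 can be modeled in a few lines. This simplified C++ sketch uses hypothetical types (`Global`, `AssignedModule`) and functions (`expandTargets`, `resolvePromise`). It keeps devices in the order given and makes the first one the default affinity, as in the example above. It does not model alias expansion (step 4).

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a device util.global.
struct Global {
  std::string name;  // symbol, e.g. "cpu_device"
  std::string init;  // initializer text, e.g. "#hal.device.alias<\"local\">"
};

struct AssignedModule {
  std::vector<Global> globals;
  std::string defaultAffinity;  // mirrors stream.affinity.default
};

// Step 2 (simplified): one global per hal.device.targets entry, in the order
// given; the first device becomes the default affinity.
AssignedModule expandTargets(
    const std::vector<std::pair<std::string, std::string>> &targets) {
  AssignedModule result;
  for (const auto &[name, init] : targets) {
    result.globals.push_back({name, init});
  }
  if (!result.globals.empty()) {
    result.defaultAffinity =
        "#hal.device.affinity<@" + result.globals.front().name + ">";
  }
  return result;
}

// Step 3 (simplified): #hal.device.promise<@name> becomes a reference to the
// global with the same name, or fails if no such device was defined.
std::optional<std::string> resolvePromise(const AssignedModule &assigned,
                                          const std::string &promisedName) {
  for (const Global &global : assigned.globals) {
    if (global.name == promisedName) return "@" + global.name;
  }
  return std::nullopt;
}
```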
Upon completion of the pipeline there are globals initialized with
either a specific device target or a selection mechanism to pick between
targets. From that point onward devices are a structural part of the
program and can be referenced by symbol name via attributes like
`#hal.device.affinity`.
Programs are expected to specify the device affinity for all operations
either explicitly or implicitly. By default (as today) the first device
defined will be used, but going forward we will want frontends to start
specifying devices. To that end the `flow.tensor.transfer` operation was
added to allow a tensor to have a device affinity assigned to it. A new
analysis is added that allows all tensors (or stream resources) and ops
interacting with them to be queried for which device they should be
placed on. For example, a frontend can specify that multiple devices be used
in a computation by transferring the tensors used:
```mlir
util.func private @my_func(%arg0: tensor<4xi32>) -> tensor<4xi32> {
%arg0_device_a = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@device_a>
%compute_device_a = arith.addi %arg0_device_a, %arg0_device_a : tensor<4xi32>
%transient_device_b = flow.tensor.transfer %compute_device_a : tensor<4xi32> to #hal.device.promise<@device_b>
%compute_device_b = arith.muli %transient_device_b, %transient_device_b : tensor<4xi32>
util.return %compute_device_b : tensor<4xi32>
}
```
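The following C++ sketch is a toy forward propagation over the example above. It illustrates the kind of question the new analysis answers. It is not the analysis added in this change. The rule that an op inherits the device of its first placed operand, and the fallback to a default device, are simplifying assumptions. `Op` and `placeValues` are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical straight-line IR: each op defines exactly one value.
struct Op {
  std::string result;
  std::vector<std::string> operands;
  std::string transferTo;  // non-empty only for flow.tensor.transfer-like ops
};

// Toy placement: a transfer pins its result to the target device, other ops
// run where their first placed operand lives, and anything still unplaced
// uses the default (first defined) device.
std::map<std::string, std::string> placeValues(const std::vector<Op> &ops,
                                               const std::string &defaultDevice) {
  std::map<std::string, std::string> placement;
  for (const Op &op : ops) {
    std::string device = defaultDevice;
    if (!op.transferTo.empty()) {
      device = op.transferTo;
    } else {
      for (const std::string &operand : op.operands) {
        auto it = placement.find(operand);
        if (it != placement.end()) {
          device = it->second;
          break;
        }
      }
    }
    placement[op.result] = device;
  }
  return placement;
}
```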
To avoid copies, there are also ways for frontends to indicate where
argument and result tensors are placed. The best way (in that it's the most
general/powerful) is for frontends to emit `hal.tensor.import`,
`hal.tensor.export`, and `hal.tensor.alias` ops directly, as they all now
take affinities. When using the default ABI translation pass, it's
possible to add arg/result attrs to public functions, e.g. `util.func
public @my_func(%arg0: tensor<2xi32> {iree.abi.affinity =
#hal.device.promise<@device_a>}) -> (tensor<2xi32> {iree.abi.affinity =
#hal.device.promise<@device_b>})`. Shorthand is provided to allow
specifying an `iree.abi.affinity` on functions themselves for when all
arguments and results are placed on the same device.
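The following C++ sketch shows how a converter might look up the effective `iree.abi.affinity` for an argument or result. The precedence used here is an assumption for illustration: a per-argument or per-result attribute wins, and otherwise the function-level shorthand applies. The `FuncAbi` type and the helper names are also hypothetical. The string-keyed map stands in for MLIR's `DictionaryAttr`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Stand-in for an MLIR DictionaryAttr: attribute name -> printed value.
using AttrDict = std::map<std::string, std::string>;

// Hypothetical view of a public function's ABI attributes.
struct FuncAbi {
  AttrDict funcAttrs;                 // function-level attributes
  std::vector<AttrDict> argAttrs;     // empty when no arg attrs were given
  std::vector<AttrDict> resultAttrs;  // empty when no result attrs were given
};

static std::optional<std::string> lookup(const AttrDict &dict,
                                         const std::string &key) {
  auto it = dict.find(key);
  if (it == dict.end()) return std::nullopt;
  return it->second;
}

// Assumed precedence: the per-entry attribute wins; otherwise fall back to the
// function-level shorthand; otherwise there is no explicit affinity.
static std::optional<std::string> effectiveAffinity(
    const FuncAbi &abi, const std::vector<AttrDict> &entries,
    std::size_t index) {
  if (index < entries.size()) {
    if (auto attr = lookup(entries[index], "iree.abi.affinity")) return attr;
  }
  return lookup(abi.funcAttrs, "iree.abi.affinity");
}

std::optional<std::string> argAffinity(const FuncAbi &abi, std::size_t index) {
  return effectiveAffinity(abi, abi.argAttrs, index);
}

std::optional<std::string> resultAffinity(const FuncAbi &abi,
                                          std::size_t index) {
  return effectiveAffinity(abi, abi.resultAttrs, index);
}
```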
Once devices are specified, materialized in the program as globals, and
referenced via the magic default attribute (`stream.affinity.default`),
scoped attributes, or explicit transfer operations, most of the mechanics
are implementation details of the stream and HAL dialect lowerings.
Partitioning, allocation, and scheduling in the stream dialect were
always affinity-aware and required only minor tweaks as part of this
work, while the HAL TODOs for multi-device were implemented by memoizing
resources per-device and adding the machinery to enumerate and select
devices.
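Memoizing per device boils down to keying cached resources by device as well as by what they represent. The C++ class below is a hypothetical illustration of that idea; `PerDeviceCache` and the resource strings are made up. It does not show how the HAL lowering actually stores per-device resources in the compiled program.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical per-device cache: one resource (e.g. a loaded executable) is
// created lazily for each (device, key) pair and reused on later lookups.
template <typename Resource>
class PerDeviceCache {
 public:
  using Factory =
      std::function<Resource(const std::string &device, const std::string &key)>;
  explicit PerDeviceCache(Factory factory) : factory_(std::move(factory)) {}

  const Resource &lookup(const std::string &device, const std::string &key) {
    auto cacheKey = std::make_pair(device, key);
    auto it = cache_.find(cacheKey);
    if (it == cache_.end()) {
      // First request for this device: create and remember the resource.
      it = cache_.emplace(cacheKey, factory_(device, key)).first;
    }
    return it->second;
  }

  std::size_t size() const { return cache_.size(); }

 private:
  Factory factory_;
  std::map<std::pair<std::string, std::string>, Resource> cache_;
};
```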
This was reviewed in the following chunks and tested in a roll-up PR
#17482:
* https://github.com/iree-org/iree/pull/17915
* https://github.com/iree-org/iree/pull/17917
* https://github.com/iree-org/iree/pull/17916
* https://github.com/iree-org/iree/pull/17918
* https://github.com/iree-org/iree/pull/17919
* https://github.com/iree-org/iree/pull/17920
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index d9541a7..553dbde 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -124,9 +124,17 @@
FENCE,
};
+struct BarrierResult {
+ BlockArgument storage;
+ Type torchType;
+ int returnIndex = -1;
+};
+
struct ConvertedAsyncFunctionInfo {
IREE::Util::FuncOp funcOp;
SmallVector<IREE::Util::ReturnOp> returnOps;
+ SmallVector<DictionaryAttr> torchArgAttrs;
+ SmallVector<DictionaryAttr> torchResultAttrs;
SmallVector<Type> torchInputTypes;
SmallVector<Type> torchResultTypes;
SmallVector<TypeDisposition> inputDispositions;
@@ -136,7 +144,7 @@
// Values that must be captured in the coarse barrier.
SmallVector<Value> barrierInputs;
// Meta data per barrier input: storage, torchType, returnIndex (or -1)
- SmallVector<std::tuple<Value, Type, int>> barrierResultMeta;
+ SmallVector<BarrierResult> barrierResultMeta;
LogicalResult postProcess();
LogicalResult convertImmutableTensorArg(BlockArgument argValue,
@@ -144,10 +152,25 @@
LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType,
OpBuilder &builder);
- void addBarrierInput(Value inputTensor, Value storage, Type torchType,
+ void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType,
int returnIndex) {
barrierInputs.push_back(inputTensor);
- barrierResultMeta.emplace_back(storage, torchType, returnIndex);
+ barrierResultMeta.emplace_back(BarrierResult{
+ storage,
+ torchType,
+ returnIndex,
+ });
+ }
+
+ Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) {
+ return torchArgAttrs.empty()
+ ? Attribute{}
+ : torchArgAttrs[argValue.getArgNumber()].get(attrName);
+ }
+ Attribute getTorchResultAttr(int returnIndex, StringRef attrName) {
+ return torchResultAttrs.empty()
+ ? Attribute{}
+ : torchResultAttrs[returnIndex].get(attrName);
}
};
@@ -232,7 +255,8 @@
}
if (needsBarrier) {
Value source = convertToBuiltinTensor(postambleBuilder, returnValue);
- addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex);
+ addBarrierInput(source, /*storage=*/BlockArgument{}, torchType,
+ returnIndex);
}
break;
}
@@ -276,15 +300,13 @@
SmallVector<Value> aliasedResults;
for (auto [barrierInput, meta] :
llvm::zip_equal(barrierInputs, barrierResultMeta)) {
- Value exportStorage;
- Type torchType;
- int returnIndex;
- std::tie(exportStorage, torchType, returnIndex) = meta;
- if (exportStorage) {
+ if (meta.storage) {
// Use the wait fence indicating when the storage is available for
// mutation. We need to ensure that no writes are made to the storage
// until it indicates it's safe to do so.
- auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage);
+ auto storageAffinityAttr =
+ getTorchArgAttr(meta.storage, "iree.abi.affinity");
+ auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
auto barrierInputDims = IREE::Util::buildDynamicDimsForValue(
@@ -292,7 +314,8 @@
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
- barrierInputDims, exportStorage, waitFence));
+ barrierInputDims, meta.storage, waitFence,
+ storageAffinityAttr));
} else {
aliasedResults.push_back(barrierInput);
}
@@ -301,16 +324,20 @@
funcOp.getLoc(), aliasedResults, coarseSignalFence);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
- Value exportStorage;
- Type torchType;
- int returnIndex;
- std::tie(exportStorage, torchType, returnIndex) = meta;
+ Attribute exportAffinityAttr;
+ if (meta.storage) {
+ exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity");
+ } else if (meta.returnIndex >= 0) {
+ exportAffinityAttr =
+ getTorchResultAttr(meta.returnIndex, "iree.abi.affinity");
+ }
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
- TypeAttr::get(barrierResult.getType()), StringAttr());
- if (returnIndex >= 0) {
- newReturnOperands[returnIndex] = exportedValue;
+ TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
+ exportAffinityAttr);
+ if (meta.returnIndex >= 0) {
+ newReturnOperands[meta.returnIndex] = exportedValue;
}
}
}
@@ -374,13 +401,16 @@
<< torchType;
}
+ // Propagate explicit affinities to the read.
+ auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
+
auto waitSignalFences = getEnclosingWaitSignalFences(argValue);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
- /*name=*/StringAttr());
+ /*name=*/nullptr, affinityAttr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
@@ -404,6 +434,9 @@
.toBuiltinTensor();
}
+ // Propagate explicit affinities to the read and write.
+ auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
+
// There are only a small set of possible users of a mutable tensor.
// Handle them by operation here.
SmallVector<Operation *> users(argValue.getUsers());
@@ -415,7 +448,7 @@
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
- /*name=*/StringAttr());
+ /*name=*/nullptr, affinityAttr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
@@ -470,6 +503,9 @@
syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr());
retainFunctionAttributes(asyncFuncOp, syncFuncOp);
syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
+ if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) {
+ syncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
+ }
Block *entryBlock = syncFuncOp.addEntryBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);
@@ -578,6 +614,10 @@
asyncFunctionName.append("$async");
}
+ // Stash arg/result attrs so they can be referenced during conversion.
+ torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs);
+ torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs);
+
// Convert function signature.
Type fenceType = rewriter.getType<IREE::HAL::FenceType>();
FunctionType torchFuncType = torchFunc.getFunctionType();
@@ -638,6 +678,9 @@
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
rewriter.getStringAttr("coarse-fences"));
+ if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) {
+ asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
+ }
rewriter.inlineRegionBefore(
torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
index 2939218..8f51c61 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -56,6 +56,7 @@
TorchInput::createConvertTMTensorToLinalgExtPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
+ pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
index 3e167ad..3bca01b 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
@@ -111,6 +111,37 @@
}
// -----
+// Tests the immutable + mutable argument case with explicit affinities.
+// CHECK-LABEL: @mutable_input_overwrite_no_return
+// CHECK: util.func public @main$async(
+// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
+// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view
+// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0
+// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]]
+// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1
+// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]]
+// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]])
+// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]])
+// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]]
+// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]]
+// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view
+// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence
+// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0
+// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1
+// CHECK: util.return %[[EXPORT_RESULT1]]
+builtin.module @mutable_input_overwrite_no_return_affinities {
+func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>},
+ %arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
+ -> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) {
+ %0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32>
+ %1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
+ %2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+ torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
+ return %2 : !torch.vtensor<[4,5],si32>
+}
+}
+
+// -----
// CHECK-LABEL: @retained_attribute_reflection
// CHECK: util.func public @main$async(
// CHECK-SAME: iree.reflection = {some.attr = 4 : index}
diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir
index 54606d7..fc7d8fc 100644
--- a/compiler/plugins/target/CUDA/test/smoketest.mlir
+++ b/compiler/plugins/target/CUDA/test/smoketest.mlir
@@ -5,7 +5,9 @@
module attributes {
hal.device.targets = [
- #hal.device.target<"cuda", [#hal.executable.target<"cuda", "cuda-nvptx-fb">]>
+ #hal.device.target<"cuda", [
+ #hal.executable.target<"cuda", "cuda-nvptx-fb">
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
index 14b13f9..2332f86 100644
--- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
+++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "materialize_homogeneous_encodings.mlir",
"smoketest_embedded.mlir",
"smoketest_system.mlir",
],
diff --git a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
index dde5618..5eee1f4 100644
--- a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "materialize_homogeneous_encodings.mlir"
"smoketest_embedded.mlir"
"smoketest_system.mlir"
TOOLS
diff --git a/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
new file mode 100644
index 0000000..5d5b591
--- /dev/null
+++ b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
+#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
+module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
+ util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
+ %1 = affine.apply #map()[%0#0, %dim]
+ %2 = affine.apply #map()[%0#1, %dim_0]
+ %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
+ %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
+ util.return %4 : tensor<?x?xf32>
+ }
+}
+// CHECK-LABEL: util.func public @lhs_encoding
+// CHECK: tensor.pack
+// CHECK: tensor.unpack
diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir
index 493a3c7..f9e0a4b 100644
--- a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir
+++ b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir
@@ -4,8 +4,10 @@
module attributes {
hal.device.targets = [
#hal.device.target<"llvm-cpu", [
- #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { native_vector_size = 16 : index }>
- ]>
+ #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ native_vector_size = 16 : index
+ }>
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir
index bb5c607..d6c6658 100644
--- a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir
+++ b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir
@@ -6,8 +6,10 @@
module attributes {
hal.device.targets = [
#hal.device.target<"llvm-cpu", [
- #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",{ native_vector_size = 16 : index } >
- ]>
+ #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
+ native_vector_size = 16 : index
+ }>
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
index 720e00b..d32ac8e 100644
--- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -8,7 +8,7 @@
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
}>
- ]>
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 5f97a97..75e4bbd 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -226,7 +226,7 @@
targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
context, "rocm", configAttr, executableTargetAttrs);
- return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"),
+ return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
configAttr, executableTargetAttrs);
}
@@ -238,7 +238,7 @@
public:
ROCMTargetBackend(const ROCmOptions &options) : options(options) {}
- std::string getLegacyDefaultDeviceID() const override { return "rocm"; }
+ std::string getLegacyDefaultDeviceID() const override { return "hip"; }
void getDefaultExecutableTargets(
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
@@ -702,8 +702,8 @@
: PluginSession<ROCMSession, ROCmOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
- // #hal.device.target<"rocm", ...
- targets.add("rocm",
+ // #hal.device.target<"hip", ...
+ targets.add("hip",
[&]() { return std::make_shared<ROCMTargetDevice>(options); });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir
index 91c91ba..1afe688 100644
--- a/compiler/plugins/target/ROCM/test/smoketest.mlir
+++ b/compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,7 +2,9 @@
module attributes {
hal.device.targets = [
- #hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
+ #hal.device.target<"hip", [
+ #hal.executable.target<"rocm", "rocm-hsaco-fb">
+ ]> : !hal.device
]
} {
@@ -44,7 +46,9 @@
#loc = loc(unknown)
module attributes {
hal.device.targets = [
- #hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
+ #hal.device.target<"hip", [
+ #hal.executable.target<"rocm", "rocm-hsaco-fb">
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 5973c05..15240f9 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -1,7 +1,7 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
@@ -21,7 +21,6 @@
// GFX941: target = #iree_gpu.target<arch = "gfx941",
// GFX941-SAME: features = "+sramecc,-xnack"
-
stream.executable public @reduce_dispatch {
stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir
index b640d12..44b3208 100644
--- a/compiler/plugins/target/VMVX/test/smoketest.mlir
+++ b/compiler/plugins/target/VMVX/test/smoketest.mlir
@@ -4,7 +4,7 @@
hal.device.targets = [
#hal.device.target<"local", [
#hal.executable.target<"vmvx", "vmvx-bytecode-fb">
- ]>
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
index 32f2485..b839443 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
+++ b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "materialize_homogeneous_encodings.mlir",
"smoketest.mlir",
],
include = ["*.mlir"],
diff --git a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
index ec5576e..300499d 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
+++ b/compiler/plugins/target/VulkanSPIRV/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "materialize_homogeneous_encodings.mlir"
"smoketest.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
similarity index 62%
rename from compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
rename to compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
index c0d1597..037cda0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/materialize_homogeneous_encodings.mlir
@@ -1,42 +1,11 @@
-// RUN: iree-opt --split-input-file --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
-
-#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
-#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]>
-module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
- util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
- %1 = affine.apply #map()[%0#0, %dim]
- %2 = affine.apply #map()[%0#1, %dim_0]
- %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
- ^bb0(%arg1: index, %arg2: index):
- tensor.yield %cst : f32
- } : tensor<?x?xf32> to tensor<?x?xf32>
- %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
- %4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
- util.return %4 : tensor<?x?xf32>
- }
-}
-// CHECK-LABEL: util.func public @lhs_encoding
-// CHECK: tensor.pack
-// CHECK: tensor.unpack
-
-// -----
+// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s
#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]>
+#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> : !hal.device
module attributes {hal.device.targets = [#device_target_vulkan]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
@@ -57,7 +26,7 @@
}
}
-// vulkan uses default materialization patterns which unsets the encodings.
+// Vulkan uses default materialization patterns which unset the encodings.
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: util.return %[[ARG0]]
@@ -69,10 +38,10 @@
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">
-#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]>
-module attributes {hal.device.targets = [#device_target_vulkan, #device_target_llvm_cpu]} {
+#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> : !hal.device
+module attributes {hal.device.targets = [#hal.device.select<[#device_target_vulkan, #device_target_llvm_cpu]> : !hal.device]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
index f8d8159..6ef88a8 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -8,7 +8,7 @@
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
}>
- ]>
+ ]> : !hal.device
]
} {
diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
index 31f361b..69c5ceb 100644
--- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -9,7 +9,7 @@
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
}>
- ]>
+ ]> : !hal.device
]
} {
diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
index e648f09..488555a 100644
--- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
+++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
@@ -934,10 +934,12 @@
if (!getCompilationPhase(compileFrom, compileTo)) {
return false;
}
+
+ // TODO: move to someplace centralized; erroring here is not great.
// InlineStatic (currently) only supports the `vmvx-inline` backend.
if (session.schedulingOptions.executionModel ==
SchedulingOptions::ExecutionModel::InlineStatic) {
- for (auto target : session.halTargetOptions.targets) {
+ for (auto target : session.halTargetOptions.legacyTargetBackends) {
if (target != "vmvx-inline") {
parsedModule->emitError() << "InlineStatic execution model is not "
"compatible with hal target '"
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index a5fd2fe..577cf52 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -26,9 +26,9 @@
static IREE::ABI::InvocationModel
getInvocationModel(Operation *op, IREE::ABI::InvocationModel defaultModel) {
auto modelAttr = op->getAttrOfType<StringAttr>("iree.abi.model");
- if (!modelAttr)
+ if (!modelAttr) {
return defaultModel;
- if (modelAttr == "coarse-fences") {
+ } else if (modelAttr == "coarse-fences") {
return IREE::ABI::InvocationModel::CoarseFences;
} else {
return IREE::ABI::InvocationModel::Sync;
@@ -43,15 +43,22 @@
return type;
}
+// Returns true if the given |attr| is a known ABI attribute that is only used
+// by this pass.
+static bool isABIAttr(NamedAttribute attr) {
+ return attr.getName() == "iree.abi.affinity" ||
+ attr.getName() == "iree.abi.encoding" ||
+ attr.getName() == "iree.abi.model" ||
+ attr.getName() == "iree.abi.output";
+}
+
// Removes all ABI attrs handled by this pass from all dictionaries.
static void stripABIAttrs(SmallVectorImpl<DictionaryAttr> &allAttrs) {
for (auto &attrDict : allAttrs) {
SmallVector<NamedAttribute> attrs;
attrs.reserve(attrDict.size());
for (auto attr : attrDict) {
- // TODO(benvanik): faster lookup.
- if (attr.getName() != "iree.abi.output" &&
- attr.getName() != "iree.abi.encoding") {
+ if (!isABIAttr(attr)) {
attrs.push_back(attr);
}
}
@@ -59,6 +66,31 @@
}
}
+// Removes all ABI attrs from the |op| and its args/results.
+static void stripABIAttrs(FunctionOpInterface op) {
+ NamedAttrList attrs;
+ for (auto attr : op->getAttrs()) {
+ if (!isABIAttr(attr)) {
+ attrs.push_back(attr);
+ }
+ }
+ op->setAttrs(attrs);
+
+ SmallVector<DictionaryAttr> argAttrs;
+ op.getAllArgAttrs(argAttrs);
+ stripABIAttrs(argAttrs);
+ op.setAllArgAttrs(argAttrs);
+ SmallVector<DictionaryAttr> resultAttrs;
+ op.getAllResultAttrs(resultAttrs);
+ stripABIAttrs(resultAttrs);
+ op.setAllResultAttrs(resultAttrs);
+}
+
+template <typename T>
+static T fallback(T optionalValue, T defaultValue) {
+ return optionalValue ? optionalValue : defaultValue;
+}
+
// Creates the corresponding wrapper function for the given import function.
static IREE::Util::FuncOp
createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel,
@@ -89,12 +121,7 @@
argAttrDict.push_back(nullptr); // signal
break;
}
-
- // Update the import type and propagate back the attributes we may have
- // modified above.
importOp.setType(newImportType);
- importOp.setAllArgAttrs(argAttrDict);
- importOp.setAllResultAttrs(resultAttrDict);
auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
@@ -117,6 +144,12 @@
// users mark their functions 'nosideeffects' to avoid the host wait.
const bool hasSideEffects = !importOp->hasAttr("nosideeffects");
+ // Fetch and normalize any explicitly assigned affinity.
+ auto defaultAffinityAttr = importOp->getAttr("iree.abi.affinity");
+ if (defaultAffinityAttr) {
+ importOp->setAttr("stream.affinity", defaultAffinityAttr);
+ }
+
// When running async we insert a barrier on tensor arguments and attach that
// to the fence we pass to the import for waiting. We'll also allocate the
// signal fence that the import must signal when the returned tensors are
@@ -129,15 +162,24 @@
// No fences.
break;
case IREE::ABI::InvocationModel::CoarseFences: {
- // HACK: this is relying on the fact that there's only one HAL device.
- // We should instead have a way of creating fences on the device that
- // is used to produce the tensors we're wrapping.
- //
- // TODO(multi-device): emit get with derived ordinal or lookup with attr. We
- // could always say device 0 for now but could instead look for an
- // iree.abi.affinity/iree.abi.device/etc.
- Value device =
- IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
+ Value device;
+ // TODO(benvanik): support other affinity types.
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(
+ defaultAffinityAttr)) {
+ device = entryBuilder
+ .create<IREE::HAL::DeviceResolveOp>(
+ importOp.getLoc(),
+ entryBuilder.getType<IREE::HAL::DeviceType>(),
+ deviceAffinityAttr)
+ .getResult(0);
+ } else {
+ // HACK: if no device affinity is assigned we use the first device available
+ // at runtime. This is suboptimal but we expect most usage to have affinities
+ // assigned prior to ABI conversion.
+ device =
+ IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
+ }
// When exporting a fence we need to put a barrier between the rest of the
// program and the tensors consumed by the import.
@@ -187,20 +229,25 @@
// NOTE: we insert a barrier on this above if needed so that the wait
// fence will be signaled when the tensor is ready for consumption by the
// import.
- auto encoding =
+ auto encodingAttr =
importOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
- auto exportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
+ auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
arg.getLoc(), newType, arg,
- encoding ? encoding : TypeAttr::get(oldType), /*name=*/nullptr);
- arguments.push_back(exportOp.getTarget());
+ fallback(encodingAttr, TypeAttr::get(oldType)),
+ /*name=*/nullptr,
+ fallback(importOp.getArgAttr(argIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
+ arguments.push_back(tensorExportOp.getTarget());
} else {
arguments.push_back(arg);
}
}
- if (waitFence)
+ if (waitFence) {
arguments.push_back(waitFence);
- if (signalFence)
+ }
+ if (signalFence) {
arguments.push_back(signalFence);
+ }
// Make the call with the updated types.
auto callOp = entryBuilder.create<IREE::Util::CallOp>(importOp.getLoc(),
@@ -225,18 +272,24 @@
// NOTE: we set the import pending on the signal fence from the import
// indicating when the returned tensor is ready for consumption by the
// program.
- auto encoding = importOp.getResultAttrOfType<TypeAttr>(
+ auto encodingAttr = importOp.getResultAttrOfType<TypeAttr>(
resultIndex, "iree.abi.encoding");
- results.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
+ auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
importOp.getLoc(), oldType, result,
- encoding ? encoding : TypeAttr::get(oldType), signalFence,
- /*name=*/nullptr));
+ fallback(encodingAttr, TypeAttr::get(oldType)), signalFence,
+ /*name=*/nullptr,
+ fallback(importOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
+ results.push_back(tensorImportOp);
} else {
results.push_back(result);
}
}
entryBuilder.create<IREE::Util::ReturnOp>(importOp.getLoc(), results);
+
+ stripABIAttrs(importOp);
+
return wrapperOp;
}
@@ -285,8 +338,9 @@
auto wrapperOp = createImportWrapperFunc(
invocationModel, importOp, cast<FunctionType>(importOp.getFunctionType()),
newImportType, privateName);
- if (!wrapperOp)
+ if (!wrapperOp) {
return failure();
+ }
moduleOp.insert(++Block::iterator(importOp), wrapperOp);
// Update the import to the new type and mark it as being converted so we
@@ -302,15 +356,17 @@
static StringAttr inferArgumentName(MLIRContext *context, int index,
DictionaryAttr attrs) {
- if (auto attrName = getNameFromDictAttr(attrs))
+ if (auto attrName = getNameFromDictAttr(attrs)) {
return attrName;
+ }
return StringAttr::get(context, "input" + std::to_string(index));
}
static StringAttr inferResultName(MLIRContext *context, int index,
DictionaryAttr attrs) {
- if (auto attrName = getNameFromDictAttr(attrs))
+ if (auto attrName = getNameFromDictAttr(attrs)) {
return attrName;
+ }
return StringAttr::get(context, "output" + std::to_string(index));
}
@@ -326,8 +382,9 @@
auto shouldIncludeAttr = [](const NamedAttribute &attr) {
return attr.getName().getValue() != "iree.abi.name";
};
- if (!llvm::any_of(attrs, shouldIncludeAttr))
+ if (!llvm::any_of(attrs, shouldIncludeAttr)) {
return;
+ }
os << " {";
llvm::interleaveComma(llvm::make_filter_range(attrs, shouldIncludeAttr), os,
[&](auto argAttr) {
@@ -363,8 +420,9 @@
os << "func @" << publicName;
os << "(";
for (auto arg : exportOp.getArguments()) {
- if (arg.getArgNumber() > 0)
+ if (arg.getArgNumber() > 0) {
os << ", ";
+ }
os << "%";
os << inferArgumentName(exportOp.getContext(), arg.getArgNumber(),
getIOAttr(allArgAttrs, arg.getArgNumber()))
@@ -377,8 +435,9 @@
os << ") -> (";
for (auto [resultNumber, resultType] :
llvm::enumerate(exportOp.getResultTypes())) {
- if (resultNumber > 0)
+ if (resultNumber > 0) {
os << ", ";
+ }
os << "%";
os << inferResultName(exportOp.getContext(), resultNumber,
getIOAttr(allResultAttrs, resultNumber))
@@ -495,6 +554,12 @@
populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
exportOp->removeAttr("iree.reflection");
+ // Fetch and normalize any explicitly assigned affinity.
+ auto defaultAffinityAttr = exportOp->getAttr("iree.abi.affinity");
+ if (defaultAffinityAttr) {
+ exportOp->setAttr("stream.affinity", defaultAffinityAttr);
+ }
+
auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
@@ -504,8 +569,9 @@
for (unsigned i = 0; i < exportOp.getNumArguments(); ++i) {
auto outputAttr =
exportOp.getArgAttrOfType<IntegerAttr>(i, "iree.abi.output");
- if (!outputAttr)
+ if (!outputAttr) {
continue;
+ }
// Today all outputs need to be a !hal.buffer - we could change this
// in the future to be something more generalized.
auto storageArg = entryBlock->getArgument(i);
@@ -542,14 +608,16 @@
entryBlock->getArguments().slice(0, oldExportType.getNumInputs()))) {
auto oldType = oldExportType.getInput(argIndex);
if (llvm::isa<TensorType>(oldType)) {
- auto encoding =
+ auto encodingAttr =
exportOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding");
- auto importOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
+ auto argName = inferArgumentName(entryBuilder.getContext(), argIndex,
+ exportOp.getArgAttrDict(argIndex));
+ auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), oldType, arg,
- encoding ? encoding : TypeAttr::get(oldType), waitFence,
- inferArgumentName(entryBuilder.getContext(), argIndex,
- exportOp.getArgAttrDict(argIndex)));
- arguments.push_back(importOp.getTarget());
+ fallback(encodingAttr, TypeAttr::get(oldType)), waitFence, argName,
+ fallback(exportOp.getArgAttr(argIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
+ arguments.push_back(tensorImportOp.getTarget());
} else {
arguments.push_back(arg);
}
@@ -563,14 +631,17 @@
// Alias results to storage buffers if provided.
for (unsigned resultIndex = 0; resultIndex < asyncResults.size();
++resultIndex) {
- if (!resultStorages[resultIndex])
+ if (!resultStorages[resultIndex]) {
continue;
+ }
auto source = asyncResults[resultIndex];
auto sourceDims = IREE::Util::buildDynamicDimsForValue(
exportOp.getLoc(), source, entryBuilder);
auto aliasOp = entryBuilder.create<IREE::HAL::TensorAliasOp>(
exportOp.getLoc(), source.getType(), source, sourceDims,
- resultStorages[resultIndex], waitFence);
+ resultStorages[resultIndex], waitFence,
+ fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
asyncResults[resultIndex] = cast<OpResult>(aliasOp.getResult());
}
@@ -580,8 +651,9 @@
if (signalFence) {
SmallVector<Value> asyncTensors;
for (auto result : asyncResults) {
- if (llvm::isa<TensorType>(result.getType()))
+ if (llvm::isa<TensorType>(result.getType())) {
asyncTensors.push_back(result);
+ }
}
if (asyncTensors.empty()) {
// TODO(benvanik): maybe use a global timeline? global stores may not
@@ -601,20 +673,27 @@
auto oldType = oldExportType.getResult(resultIndex);
auto newType = newExportType.getResult(resultIndex);
if (llvm::isa<TensorType>(oldType)) {
- auto encoding = exportOp.getResultAttrOfType<TypeAttr>(
+ auto encodingAttr = exportOp.getResultAttrOfType<TypeAttr>(
resultIndex, "iree.abi.encoding");
+ auto resultName =
+ inferResultName(entryBuilder.getContext(), resultIndex,
+ exportOp.getResultAttrDict(resultIndex));
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
result.getLoc(), result, entryBuilder);
- results.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
+ auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), newType, result,
- encoding ? encoding : TypeAttr::get(result.getType()), dynamicDims,
- inferResultName(entryBuilder.getContext(), resultIndex,
- exportOp.getResultAttrDict(resultIndex))));
+ fallback(encodingAttr, TypeAttr::get(result.getType())), dynamicDims,
+ resultName,
+ fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"),
+ defaultAffinityAttr));
+ results.push_back(tensorExportOp);
} else {
results.push_back(result);
}
}
+ stripABIAttrs(exportOp);
+
entryBuilder.create<IREE::Util::ReturnOp>(exportOp.getLoc(), results);
return wrapperOp;
}
@@ -641,8 +720,9 @@
// marshals arguments/results to the original function.
auto wrapperOp =
createExportWrapperFunc(invocationModel, exportOp, publicName);
- if (!wrapperOp)
+ if (!wrapperOp) {
return failure();
+ }
symbolTable.insert(wrapperOp, Block::iterator(exportOp));
return success();
@@ -693,12 +773,15 @@
exportOps.push_back(funcOp);
}
}
+ if (importOps.empty() && exportOps.empty()) {
+ return; // no-op
+ }
SymbolTable symbolTable(moduleOp);
// Create a wrapper function for each imported function.
- // This will preserve the internal types (tensors/etc) but change the import
- // to taking the ABI types and rewrite calls.
+ // This will preserve the internal types (tensors/etc) but change the
+ // import to take the ABI types and rewrite calls.
for (auto importOp : importOps) {
if (failed(wrapImportFunc(getInvocationModel(importOp, invocationModel),
moduleOp, importOp, symbolTable))) {
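The affinity handling added to WrapEntryPoints.cpp follows a simple precedence: a per-argument or per-result `iree.abi.affinity` wins, otherwise the function-level `iree.abi.affinity` applies, and with neither the marshaling op gets no affinity. Below is a minimal, self-contained C++ sketch of that precedence under stated assumptions: `Affinity`, `resolveAffinity`, and `fenceDeviceFor` are hypothetical names, and raw pointers stand in for MLIR attributes (a null pointer plays the role of a missing attribute). It also simplifies the coarse-fences device choice: the pass only resolves an explicit device when the function-level affinity is a `#hal.device.affinity`, and every other case falls back to `resolveAny`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR affinity attribute. A null pointer means
// "no attribute present", mirroring how a null mlir::Attribute tests false.
struct Affinity {
  std::string device;  // e.g. "@dev_a"
};

// Same shape as the pass's helper: return the optional value if it is set
// (truthy), otherwise the default.
template <typename T>
static T fallback(T optionalValue, T defaultValue) {
  return optionalValue ? optionalValue : defaultValue;
}

// Resolves the affinity for one argument/result: a per-argument affinity wins,
// otherwise the function-level one is used; if neither is present the result
// is null (no placement constraint).
static const Affinity *resolveAffinity(const Affinity *argAffinity,
                                       const Affinity *functionAffinity) {
  return fallback(argAffinity, functionAffinity);
}

// Simplified model of the coarse-fences device choice: fences are created on
// the device named by the function-level affinity when present; otherwise the
// wrapper uses whichever device is available first at runtime.
static std::string fenceDeviceFor(const Affinity *functionAffinity) {
  return functionAffinity ? functionAffinity->device
                          : std::string("<first available at runtime>");
}
```

With the attributes from the `@pinnedImport` test in wrap_entry_points_coarse_fences.mlir (argument on `@dev_a`, result on `@dev_b`, function on `@dev_c`), the argument export resolves to `@dev_a`, the result import to `@dev_b`, and both fences are created on `@dev_c`.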
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 3780ee2..72a0441 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -181,6 +181,28 @@
// -----
+// Tests that explicit import affinity specification is carried through to
+// the marshaling ops.
+
+// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view) -> !hal.buffer_view
+util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
+
+// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> {
+// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[ARG_TENSOR]] : tensor<2xi32> -> !hal.buffer_view
+// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32>
+// CHECK: util.return %[[RET_TENSOR]]
+// CHECK: }
+
+// CHECK: util.func private @pinnedCaller(%arg0: tensor
+util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ util.return %0 : tensor<2xi32>
+}
+
+// -----
+
// Tests that imports with encodings specified are propagated to the HAL ops.
// CHECK-LABEL: util.func private @importEncodings(%arg0: !hal.buffer_view) -> !hal.buffer_view
@@ -188,10 +210,10 @@
// CHECK: util.func private @_importEncodings(%[[ARG_TENSOR:.+]]: tensor<?x2xi32>) -> tensor<2x?xi32> {
// CHECK: %[[ARG_DIM:.+]] = tensor.dim %[[ARG_TENSOR]], %c0
-// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view
+// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xf32> as tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view
// CHECK: %[[RET_VIEW:.+]] = util.call @importEncodings(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view
// CHECK: %[[RET_DIM:.+]] = hal.buffer_view.dim<%[[RET_VIEW]] : !hal.buffer_view>[1]
-// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xi32>{%[[RET_DIM]]}
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xf32> as tensor<2x?xi32>{%[[RET_DIM]]}
// CHECK: util.return %[[RET_TENSOR]]
// CHECK: }
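The pass now strips every ABI attribute it consumes (`iree.abi.affinity`, `iree.abi.encoding`, `iree.abi.model`, `iree.abi.output`) once the wrappers are built; the `@pinnedImport` CHECK line above appears to rely on this, since it expects the rewritten import signature with no `iree.abi.affinity` attributes left on it. A minimal C++ sketch of that filtering, assuming a hypothetical `NamedAttr` (a name/value string pair) in place of `mlir::NamedAttribute`:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::NamedAttribute: (name, printed value).
using NamedAttr = std::pair<std::string, std::string>;

// Mirrors the pass's isABIAttr: true for attributes consumed by the wrapping.
static bool isABIAttr(const NamedAttr &attr) {
  return attr.first == "iree.abi.affinity" ||
         attr.first == "iree.abi.encoding" ||
         attr.first == "iree.abi.model" || attr.first == "iree.abi.output";
}

// Returns a copy of |attrs| with ABI attributes removed, preserving order.
static std::vector<NamedAttr>
stripABIAttrs(const std::vector<NamedAttr> &attrs) {
  std::vector<NamedAttr> kept;
  kept.reserve(attrs.size());
  for (const auto &attr : attrs) {
    if (!isABIAttr(attr)) {
      kept.push_back(attr);
    }
  }
  return kept;
}
```

Attributes outside that set, such as `nosideeffects` or `iree.abi.name`, are preserved in their original order.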
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
index 4505a54..f9af505 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
@@ -170,6 +170,40 @@
// -----
+// Tests that explicit import affinity specification is carried through to
+// the marshaling ops.
+
+util.global private @dev_a : !hal.device
+util.global private @dev_b : !hal.device
+util.global private @dev_c : !hal.device
+
+// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view
+util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_b>}) attributes {
+ iree.abi.affinity = #hal.device.affinity<@dev_c>,
+ iree.abi.model = "coarse-fences",
+ nosideeffects
+}
+
+// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> {
+// CHECK-DAG: %[[DEVICE_C:.+]] = hal.device.resolve on(<@dev_c>) : !hal.device
+// CHECK-DAG: %[[ARG_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence
+// CHECK-DAG: %[[ARG_READY:.+]] = hal.tensor.barrier join(%[[ARG_TENSOR]] : tensor<2xi32>) => %[[ARG_FENCE]] : !hal.fence
+// CHECK-DAG: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.affinity<@dev_a>) %[[ARG_READY]] : tensor<2xi32> -> !hal.buffer_view
+// CHECK-DAG: %[[RESULT_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence
+// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]], %[[ARG_FENCE]], %[[RESULT_FENCE]]) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
+// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.affinity<@dev_b>) wait(%[[RESULT_FENCE]]) => %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32>
+// CHECK: util.return %[[RET_TENSOR]]
+// CHECK: }
+
+// CHECK: util.func private @pinnedCaller(%arg0: tensor
+util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
+ util.return %0 : tensor<2xi32>
+}
+
+// -----
+
// Tests a side-effect-free import that doesn't take/return reference types.
// CHECK-LABEL: util.func private @importI32(%arg0: i32, %arg1: !hal.fence, %arg2: !hal.fence) -> i32
diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index 1ec24e8..d21bfe0 100644
--- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -227,7 +227,8 @@
auto dynamicDims = inputDynamicDims.loadDynamicDims(recalculateBuilder);
auto castOp = recalculateBuilder.create<IREE::HAL::TensorImportOp>(
loc, inputValue.getType(), inputPlaceholder, inputValue.getType(),
- dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr);
+ dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr);
inputValue.replaceAllUsesWith(castOp.getTarget());
}
while (entryBlock.getNumArguments() > 0) {
@@ -499,6 +500,8 @@
wrapperFuncOp.setAllResultAttrs(resultAttrDict);
populateReflectionAttrs(entryFuncOp, wrapperFuncOp);
+ if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity"))
+ wrapperFuncOp->setAttr("stream.affinity", affinityAttr);
// Call the entryFuncOp and return the results.
// If we wanted to perform additional work here to invalidate cached shapes
@@ -523,7 +526,8 @@
callOperands.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), inputDynamicDims.tensorType, arg,
TypeAttr::get(inputDynamicDims.tensorType), dynamicDims,
- /*wait_fence=*/Value{}, /*name=*/nullptr));
+ /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr));
}
auto callOp = entryBuilder.create<IREE::Util::CallOp>(
entryFuncOp.getLoc(), entryFuncOp, callOperands);
@@ -539,7 +543,7 @@
}
callResults.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), bufferType, result, outputDynamicDims.tensorType,
- dynamicDims, /*name=*/nullptr));
+ dynamicDims, /*name=*/nullptr, /*affinity=*/nullptr));
for (auto [dynamicDim, globalOp] :
llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) {
globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder);
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index 0efe1f1..fb3b309 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -46,7 +46,7 @@
name = "CommonCPUPasses",
srcs = [
"CPULowerToUKernels.cpp",
- "CPUMaterializeEncodingPass.cpp",
+ "CPUMaterializeEncodings.cpp",
"CPUPrepareUkernels.cpp",
"Passes.cpp",
],
@@ -62,7 +62,9 @@
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index fbae2ac..0b79e7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -43,7 +43,7 @@
"Passes.h"
SRCS
"CPULowerToUKernels.cpp"
- "CPUMaterializeEncodingPass.cpp"
+ "CPUMaterializeEncodings.cpp"
"CPUPrepareUkernels.cpp"
"Passes.cpp"
DEPS
@@ -84,7 +84,9 @@
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Encoding::IR
+ iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Stream::Analysis
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
similarity index 72%
rename from compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
rename to compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
index e81ed85..4edaffa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
@@ -11,7 +11,9 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
@@ -29,15 +31,12 @@
namespace mlir::iree_compiler {
-using namespace IREE::Encoding;
-using IREE::HAL::ExecutableTargetAttr;
-
// Enumerate tile sizes to choose from when no specific architecture is
// targeted. For narrow-{M,N} cases, this only enumerates on narrow M. The
// narrow-N cases are handled by transposition in chooseMatmulTile.
static SmallVector<TileMxNxK>
enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims,
- ExecutableTargetAttr target) {
+ IREE::HAL::ExecutableTargetAttr target) {
// TODO(hanchung): The ukernel path does not support 3d
// codegen.query_tile_sizes op, so we disable dynamic tile shapes for
// batch_matmul.
@@ -59,7 +58,7 @@
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in chooseMatmulTile.
static SmallVector<TileMxNxK>
-enumerateMatmulTileRiscv32(ExecutableTargetAttr target) {
+enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
if (hasUkernel(target)) {
return {
TileMxNxK{8, 8, 4}, // Some reasonable tile shape.
@@ -76,7 +75,8 @@
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in chooseMatmulTile.
static SmallVector<TileMxNxK>
-enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
+enumerateMatmulTileArm64(TypeRange elementTypes,
+ IREE::HAL::ExecutableTargetAttr target) {
// Data-tiling for SVE is not implemented yet.
if (hasFeature(target, "+sve") || hasFeature(target, "+sve2")) {
return {};
@@ -166,7 +166,8 @@
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
// are handled by transposition in chooseMatmulTile.
static SmallVector<TileMxNxK>
-enumerateMatmulTileX86_64(TypeRange elementTypes, ExecutableTargetAttr target) {
+enumerateMatmulTileX86_64(TypeRange elementTypes,
+ IREE::HAL::ExecutableTargetAttr target) {
assert(elementTypes.size() == 3);
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
@@ -376,9 +377,10 @@
return bestRatedTile;
}
-SmallVector<TileMxNxK>
+static SmallVector<TileMxNxK>
enumerateMatmulTileMxNxK(linalg::ContractionDimensions cDims,
- TypeRange elementTypes, ExecutableTargetAttr target) {
+ TypeRange elementTypes,
+ IREE::HAL::ExecutableTargetAttr target) {
if (isVMVXBackend(target)) {
return enumerateMatmulTilesVMVX(cDims, target);
}
@@ -394,41 +396,10 @@
return {};
}
-struct CPUMaterializeEncodingPass
- : public CPUMaterializeEncodingBase<CPUMaterializeEncodingPass> {
- CPUMaterializeEncodingPass() : targetAttr(nullptr) {}
- explicit CPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr attr)
- : targetAttr(attr) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, tensor::TensorDialect,
- IREE::Codegen::IREECodegenDialect>();
- }
- void runOnOperation() override;
-
-private:
- IREE::HAL::ExecutableTargetAttr targetAttr;
-};
-
-struct CPUMaterializeUpperBoundTileSizePass
- : public CPUMaterializeUpperBoundTileSizeBase<
- CPUMaterializeUpperBoundTileSizePass> {
- CPUMaterializeUpperBoundTileSizePass() = default;
- explicit CPUMaterializeUpperBoundTileSizePass(
- ArrayRef<IREE::HAL::ExecutableTargetAttr> attrs)
- : targetAttrs(attrs) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect>();
- }
- void runOnOperation() override;
-
-private:
- SmallVector<IREE::HAL::ExecutableTargetAttr, 4> targetAttrs;
-};
-
-FailureOr<MaterializeEncodingInfo>
+static FailureOr<MaterializeEncodingInfo>
materializeEncodingForTarget(RankedTensorType tensorType,
- ExecutableTargetAttr targetAttr) {
- IREE::Encoding::EncodingAttr encoding =
+ IREE::HAL::ExecutableTargetAttr targetAttr) {
+ auto encoding =
dyn_cast_or_null<IREE::Encoding::EncodingAttr>(tensorType.getEncoding());
if (!encoding) {
return failure();
@@ -464,7 +435,7 @@
}
static MaterializeEncodingFn
-getMaterializeEncodingFn(ExecutableTargetAttr targetAttr) {
+getMaterializeEncodingFn(IREE::HAL::ExecutableTargetAttr targetAttr) {
return
[targetAttr](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
@@ -481,8 +452,8 @@
// executable variant. There, the padding amounts only control the size of
// allocated buffers, so it's OK to over-estimate (only wasting some memory)
// but not under-estimate (would cause buffer overruns) padding amounts.
-static MaterializeEncodingFn
-getUpperBoundMaterializeEncodingFn(ArrayRef<ExecutableTargetAttr> targetAttrs) {
+static MaterializeEncodingFn getUpperBoundMaterializeEncodingFn(
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs) {
return
[targetAttrs](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
@@ -540,73 +511,220 @@
return {};
}
-void CPUMaterializeEncodingPass::runOnOperation() {
- MLIRContext *context = &getContext();
- auto operation = getOperation();
- RewritePatternSet materializeEncodingPattern(context);
- if (!targetAttr)
- targetAttr = ExecutableTargetAttr::lookup(operation);
- auto materializeEncodingFn = getMaterializeEncodingFn(targetAttr);
+static LogicalResult materializeFuncOpEncodings(
+ FunctionOpInterface funcOp,
+ IREE::HAL::ExecutableTargetAttr executableTargetAttr) {
+ RewritePatternSet materializeEncodingPattern(funcOp.getContext());
+ auto materializeEncodingFn = getMaterializeEncodingFn(executableTargetAttr);
if (!materializeEncodingFn) {
- return signalPassFailure();
+ return failure();
}
MaterializeEncodingTypeConverter typeConverter(materializeEncodingFn);
- MaterializeEncodingConversionTarget target(*context);
- auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
+ MaterializeEncodingConversionTarget target(*funcOp.getContext());
+ auto materializeEncodingValueFn =
+ getMaterializeEncodingValueFn(executableTargetAttr);
populateMaterializeEncodingIntoPackUnPackPatterns(materializeEncodingPattern,
target, typeConverter,
materializeEncodingValueFn);
- if (failed(applyPartialConversion(operation, target,
+ if (failed(applyPartialConversion(funcOp, target,
std::move(materializeEncodingPattern)))) {
- operation.emitOpError("materialization failed");
- return signalPassFailure();
+ funcOp.emitOpError("materialization failed");
+ return failure();
}
- // Add patterns to fold pack/unpack ops with pad/extract_slice ops and resolve
- // dims ops.
+ // Add patterns to fold pack/unpack ops with pad/extract_slice ops and
+ // resolve dims ops.
{
- RewritePatternSet patterns(context);
+ RewritePatternSet patterns(funcOp.getContext());
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
- operation.emitOpError("folding patterns failed");
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ funcOp.emitOpError("folding patterns failed");
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+// Returns the executable targets used within |funcOp|.
+//
+// TODO(multi-device): delete this pass and rely on tensor-based analysis to
+// materialize encodings based on where tensors are used. This pass is not able
+// to handle that.
+static std::optional<SetVector<IREE::HAL::ExecutableTargetAttr>>
+getFuncExecutableTargetAttrs(FunctionOpInterface funcOp,
+ IREE::Stream::AffinityAnalysis &affinityAnalysis,
+ IREE::HAL::DeviceAnalysis &deviceAnalysis) {
+ // Get a set of all unique affinities used by resources within the function.
+ SetVector<IREE::Stream::AffinityAttr> uniqueAffinityAttrs;
+ SmallVector<IREE::Stream::AffinityAttr> lookupAffinityAttrs;
+ funcOp.walk([&](Operation *op) {
+ if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) {
+ uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(),
+ lookupAffinityAttrs.end());
+ }
+ lookupAffinityAttrs.clear();
+ });
+
+ // Resolve affinities to executable targets.
+ SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ for (auto affinityAttr : uniqueAffinityAttrs) {
+ deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp,
+ executableTargetAttrs);
+ }
+ return executableTargetAttrs;
+}
+
+struct CPUMaterializeHostEncodingPass
+ : public CPUMaterializeHostEncodingBase<CPUMaterializeHostEncodingPass> {
+ CPUMaterializeHostEncodingPass() = default;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, tensor::TensorDialect,
+ IREE::Codegen::IREECodegenDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Run required analysis passes.
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ // Gather the required executable targets for the function. Note that it's
+ // possible there are more required for ops nested within the function but
+ // this pass is a hack and can't handle that :shrug:.
+ auto executableTargets = getFuncExecutableTargetAttrs(
+ funcOp, affinityAnalysis, deviceAnalysis);
+ if (!executableTargets) {
+ funcOp.emitOpError()
+ << "could not determine executable targets for the function";
+ return signalPassFailure();
+ } else if (executableTargets->empty()) {
+ // Probably no tensors.
+ continue;
+ }
+
+ // HACK: this pass is run on the host _but shouldn't be_. Because it's
+ // run on the host and IREE is a compiler capable of multi-targeting there
+ // may be multiple executable targets at any point in the host program.
+ // This pass can't handle that and assumes it's been checked earlier by
+ // spooky action at a distance. This needs to be fixed.
+ if (executableTargets->size() != 1) {
+ funcOp.emitOpError() << "has multiple executable targets and CPU data "
+ "tiling isn't built to support that";
+ return signalPassFailure();
+ }
+
+ // Materialize encodings within the function.
+ if (failed(
+ materializeFuncOpEncodings(funcOp, executableTargets->front()))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+std::unique_ptr<Pass> createCPUMaterializeHostEncodingPass() {
+ return std::make_unique<CPUMaterializeHostEncodingPass>();
+}
+
+// NOTE: this runs on host modules and executables and has two paths to handle
+// that. It should _not_ be running on both - target-specific codegen passes
+// are not allowed on host programs and it's a big violation of layering that
+// this exists.
+struct CPUMaterializeDeviceEncodingPass
+ : public CPUMaterializeDeviceEncodingBase<
+ CPUMaterializeDeviceEncodingPass> {
+ CPUMaterializeDeviceEncodingPass() = default;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, tensor::TensorDialect,
+ IREE::Codegen::IREECodegenDialect>();
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ auto executableTargetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+ if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr))) {
return signalPassFailure();
}
}
+};
+
+std::unique_ptr<Pass> createCPUMaterializeDeviceEncodingPass() {
+ return std::make_unique<CPUMaterializeDeviceEncodingPass>();
}
-void CPUMaterializeUpperBoundTileSizePass::runOnOperation() {
- MLIRContext *context = &getContext();
- auto operation = getOperation();
- if (targetAttrs.empty()) {
- targetAttrs =
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(operation);
- }
- RewritePatternSet patterns(context);
- MaterializeEncodingFn materializeEncodingFn =
- getUpperBoundMaterializeEncodingFn(targetAttrs);
- if (!materializeEncodingFn) {
- return signalPassFailure();
- }
- populateMaterializeUpperBoundTileSizePatterns(patterns,
- materializeEncodingFn);
- if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
- operation.emitOpError(
- "encoding padding sizes materialization pattern failed");
- return signalPassFailure();
- }
-}
+// NOTE: this runs on host modules.
+struct CPUMaterializeUpperBoundTileSizePass
+ : public CPUMaterializeUpperBoundTileSizeBase<
+ CPUMaterializeUpperBoundTileSizePass> {
+ CPUMaterializeUpperBoundTileSizePass() = default;
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createCPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr targetAttr) {
- return std::make_unique<CPUMaterializeEncodingPass>(targetAttr);
-}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ }
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createCPUMaterializeUpperBoundTileSizePass(
- ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs) {
- return std::make_unique<CPUMaterializeUpperBoundTileSizePass>(targetAttrs);
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Run required analysis passes.
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ // Gather the required executable targets for the function. Note that it's
+ // possible there are more required for ops nested within the function but
+ // this pass is a hack and can't handle that :shrug:.
+ auto executableTargets = getFuncExecutableTargetAttrs(
+ funcOp, affinityAnalysis, deviceAnalysis);
+ if (!executableTargets) {
+ funcOp.emitOpError()
+ << "could not determine executable targets for the function";
+ return signalPassFailure();
+ } else if (executableTargets->empty()) {
+ // Probably no tensors.
+ continue;
+ }
+
+ // Get patterns specialized for the executable targets used by the
+ // function.
+ RewritePatternSet patterns(&getContext());
+ MaterializeEncodingFn materializeEncodingFn =
+ getUpperBoundMaterializeEncodingFn(executableTargets->getArrayRef());
+ if (!materializeEncodingFn)
+ return signalPassFailure();
+ populateMaterializeUpperBoundTileSizePatterns(patterns,
+ materializeEncodingFn);
+
+ // Run patterns on the function.
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ funcOp.emitOpError(
+ "encoding padding sizes materialization pattern failed");
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+std::unique_ptr<Pass> createCPUMaterializeUpperBoundTileSizePass() {
+ return std::make_unique<CPUMaterializeUpperBoundTileSizePass>();
}
} // namespace mlir::iree_compiler
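The host-side passes above pick executable targets per function in two steps. First, `getFuncExecutableTargetAttrs` collects the unique affinities of the ops in the function and resolves each one to the executable targets it requires. Then `CPUMaterializeHostEncodingPass` skips functions with no targets and rejects functions with more than one, while `CPUMaterializeUpperBoundTileSizePass` passes every target along. The sketch below is a standalone C++ model of that control flow only. The string-based `Affinity`/`Target` types, the `OrderedSet` stand-in for `llvm::SetVector`, the affinity-to-target map standing in for `DeviceAnalysis`, and `decideHostAction` are all hypothetical and are not IREE APIs.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-ins: affinities and executable targets are plain strings,
// and each "op" carries the affinities an affinity analysis found for it
// (empty if the lookup found nothing).
using Affinity = std::string;
using Target = std::string;
using OpAffinities = std::vector<Affinity>;

// Insertion-ordered set of unique values, mimicking llvm::SetVector.
template <typename T>
struct OrderedSet {
  std::vector<T> items;
  std::set<T> seen;
  void insert(const T &value) {
    if (seen.insert(value).second) items.push_back(value);
  }
};

// Collect the unique affinities used by all ops, then resolve each affinity
// to the targets it requires via the (hypothetical) device map.
std::vector<Target> gatherFuncTargets(
    const std::vector<OpAffinities> &ops,
    const std::map<Affinity, std::vector<Target>> &deviceTargets) {
  OrderedSet<Affinity> uniqueAffinities;
  for (const OpAffinities &opAffinities : ops)
    for (const Affinity &affinity : opAffinities)
      uniqueAffinities.insert(affinity);

  OrderedSet<Target> targets;
  for (const Affinity &affinity : uniqueAffinities.items) {
    auto it = deviceTargets.find(affinity);
    if (it == deviceTargets.end()) continue;  // This sketch ignores unknowns.
    for (const Target &target : it->second) targets.insert(target);
  }
  return targets.items;
}

enum class HostDecision { kSkip, kMaterialize, kError };

// Mirrors the per-function checks in the host encoding pass: no targets means
// skip, exactly one means materialize, and more than one is an error.
HostDecision decideHostAction(const std::vector<Target> &targets) {
  if (targets.empty()) return HostDecision::kSkip;
  if (targets.size() != 1) return HostDecision::kError;
  return HostDecision::kMaterialize;
}
```

A function whose ops all share one affinity resolves to a single target and is materialized. Mixing a CPU affinity with a GPU affinity in the same function yields two targets, which the host encoding pass refuses to handle.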
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h b/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h
index 3a782ba..25360a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_LLVMCPU_PASS_DETAIL_H_
#define IREE_COMPILER_CODEGEN_LLVMCPU_PASS_DETAIL_H_
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
index c1fe3bf..f5f9a31 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
@@ -13,6 +13,7 @@
#define IREE_COMPILER_CODEGEN_COMMON_CPU_PASSES_H_
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
@@ -22,9 +23,8 @@
/// encoding.set_encoding -> tensor.pack
/// encoding.unset_encoding -> tensor.unpack
/// linalg.matmul -> linalg.mmt4d
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createCPUMaterializeEncodingPass(
- IREE::HAL::ExecutableTargetAttr targetAttr = nullptr);
+std::unique_ptr<Pass> createCPUMaterializeHostEncodingPass();
+std::unique_ptr<Pass> createCPUMaterializeDeviceEncodingPass();
/// Like createLLVMCPUMaterializeEncodingPass, but specifically for
/// encoding.upper_bound_tile_size, converting it to constants.
@@ -41,9 +41,7 @@
/// converts upper_bound_tile_size to some specific constant size (currently 16)
/// that is the largest tile size that we can use in VMVX, and can be adjusted
// as needed.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createCPUMaterializeUpperBoundTileSizePass(
- ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs = {});
+std::unique_ptr<Pass> createCPUMaterializeUpperBoundTileSizePass();
/// Adds CPU bufferization passes to the pipeline.
void addCPUBufferizePasses(OpPassManager &funcPassManager);
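`CPUMaterializeUpperBoundTileSizePass` now hands every executable target a function uses to `getUpperBoundMaterializeEncodingFn`. The comment on that function notes that over-estimating padding only wastes memory, while under-estimating it overruns buffers. The following is a minimal sketch of combining per-target tile sizes conservatively. It assumes an element-wise maximum and power-of-two tile sizes, so the maximum is a multiple of every other candidate. `TileSizes`, `upperBoundTileSizes`, and `roundUp` are hypothetical helpers, not IREE functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical: each target proposes one inner tile size per packed dimension.
using TileSizes = std::vector<int64_t>;

// Element-wise maximum over all targets' proposals, so the padded size is
// never smaller than what any single target would need (assuming
// power-of-two tile sizes).
TileSizes upperBoundTileSizes(const std::vector<TileSizes> &perTarget) {
  TileSizes result;
  for (const TileSizes &sizes : perTarget) {
    if (result.size() < sizes.size()) result.resize(sizes.size(), 1);
    for (std::size_t i = 0; i < sizes.size(); ++i)
      result[i] = std::max(result[i], sizes[i]);
  }
  return result;
}

// Rounds |n| up to the next multiple of |tile|.
int64_t roundUp(int64_t n, int64_t tile) {
  assert(tile > 0);
  return (n + tile - 1) / tile * tile;
}
```

The power-of-two assumption matters. With tiles of 8 and 6, the maximum is 8, and rounding 7 up gives 8. A 6-wide tile, however, would need 12, so a plain maximum would under-estimate the padding in that case.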
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
index 8e30049..6329c53 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
@@ -13,14 +13,20 @@
// Common Passes used for CPU-like backends (keep alphabetical)
//===---------------------------------------------------------------------===//
-def CPUMaterializeEncoding :
- InterfacePass<"iree-codegen-cpu-materialize-encoding", "mlir::FunctionOpInterface"> {
- let summary = "Materialize the encoding for tensor as specified by the backend";
- let constructor = "mlir::iree_compiler::createCPUMaterializeEncodingPass()";
+def CPUMaterializeHostEncoding :
+ Pass<"iree-codegen-cpu-materialize-host-encoding", "mlir::ModuleOp"> {
+ let summary = "Materialize the encoding for tensor as specified by the backend.";
+ let constructor = "mlir::iree_compiler::createCPUMaterializeHostEncodingPass()";
+}
+
+def CPUMaterializeDeviceEncoding :
+ InterfacePass<"iree-codegen-cpu-materialize-device-encoding", "mlir::FunctionOpInterface"> {
+ let summary = "Materialize the encoding for tensor as specified by the backend.";
+ let constructor = "mlir::iree_compiler::createCPUMaterializeDeviceEncodingPass()";
}
def CPUMaterializeUpperBoundTileSize :
- InterfacePass<"iree-codegen-cpu-materialize-upper-bound-tile-size", "mlir::FunctionOpInterface"> {
+ Pass<"iree-codegen-cpu-materialize-upper-bound-tile-size", "mlir::ModuleOp"> {
let summary = "Materialize upper_bound_tile_size to constants.";
let constructor = "mlir::iree_compiler::createCPUMaterializeUpperBoundTileSizePass()";
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index 466bf1f..2a96d39 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
func.func @set_encoding_with_padding_semantics_bf16_x86_64_avx512f() attributes {
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
index c1bb279..4c0fd32 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
@@ -77,7 +77,7 @@
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @fill_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) attributes {
- hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+ hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
@@ -123,7 +123,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @set_encoding_dynamic() attributes {
- hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+ hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -177,7 +177,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @unset_encoding_dynamic() attributes {
- hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+ hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -225,7 +225,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @matmul_lowering_f32f32f32_generic() attributes {
- hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+ hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load[0] : index
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index c585614..281c398 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -28,10 +28,6 @@
namespace mlir::iree_compiler {
-using namespace IREE::Encoding;
-using IREE::Encoding::getEncodingAttr;
-using IREE::HAL::ExecutableTargetAttr;
-
//===---------------------------------------------------------------------===//
// Utility methods
//===---------------------------------------------------------------------===//
@@ -214,11 +210,11 @@
/// For now this takes a `paddingValue` as input. The source is also taken
/// as input so that these could be used with `OpConversionPatterns`.
static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
- RewriterBase &rewriter, SetEncodingOp encodingOp, Value source,
- MaterializeEncodingFn materializeEncodingFn,
+ RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp,
+ Value source, MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType resultType = encodingOp.getResultType();
- auto encoding = getEncodingAttr(resultType);
+ auto encoding = IREE::Encoding::getEncodingAttr(resultType);
if (!encoding) {
return failure();
}
@@ -239,9 +235,6 @@
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
- if (!encoding) {
- return failure();
- }
std::optional<Value> paddingValue;
if (encoding.getRoundDimsToArray().empty()) {
paddingValue = getPaddingValue(source);
@@ -266,8 +259,8 @@
/// The source is taken as input so that these could be used with
/// `OpConversionPatterns`.
static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
- RewriterBase &rewriter, UnsetEncodingOp encodingOp, Value packedValue,
- MaterializeEncodingFn materializeEncodingFn,
+ RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp,
+ Value packedValue, MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType sourceType = encodingOp.getSourceType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
@@ -275,7 +268,7 @@
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
}
- if (isNarrowNResult(getEncodingAttr(sourceType))) {
+ if (isNarrowNResult(IREE::Encoding::getEncodingAttr(sourceType))) {
transposeInPlace(*materializeEncodingInfo);
}
// Create an `tensor.empty` for the result of the unpack operation.
@@ -297,7 +290,8 @@
}
static FailureOr<SmallVector<Value>> lowerUpperBoundTileSizeOpToConstants(
- RewriterBase &rewriter, UpperBoundTileSizeOp upperBoundTileSizeOp,
+ RewriterBase &rewriter,
+ IREE::Encoding::UpperBoundTileSizeOp upperBoundTileSizeOp,
MaterializeEncodingFn materializeEncodingFn) {
Location loc = upperBoundTileSizeOp.getLoc();
RankedTensorType tensorType = upperBoundTileSizeOp.getTensorType();
@@ -340,16 +334,17 @@
auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
auto resultType = cast<RankedTensorType>(outputs[0].getType());
- auto lhsEncoding = getEncodingAttr(lhsType);
- auto rhsEncoding = getEncodingAttr(rhsType);
- auto resultEncoding = getEncodingAttr(resultType);
+ auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
+ auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
+ auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
- if (lhsEncoding.getOperandIndex().getValue() != MATMUL_LHS ||
- rhsEncoding.getOperandIndex().getValue() != MATMUL_RHS ||
- resultEncoding.getOperandIndex().getValue() != MATMUL_RESULT) {
+ if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
+ rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
+ resultEncoding.getOperandIndex().getValue() !=
+ IREE::Encoding::MATMUL_RESULT) {
return failure();
}
@@ -415,7 +410,7 @@
loc, emptyOp.getMixedSizes(), resultType.getElementType());
return newEmptyOp;
}
- if (isNarrowNResult(getEncodingAttr(emptyType))) {
+ if (isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
transposeInPlace(*materializeEncodingInfo);
}
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
@@ -524,7 +519,7 @@
if (failed(encodingInfo)) {
return failure();
}
- if (isNarrowNResult(getEncodingAttr(boundTensorType))) {
+ if (isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
transposeInPlace(*encodingInfo);
}
@@ -731,12 +726,12 @@
/// Convert `set_encoding` op to `pack` op.
struct SetEncodingOpToPackOpConversion
- : public OpMaterializeEncodingPattern<SetEncodingOp> {
+ : public OpMaterializeEncodingPattern<IREE::Encoding::SetEncodingOp> {
using OpMaterializeEncodingPattern<
- SetEncodingOp>::OpMaterializeEncodingPattern;
+ IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
- matchAndRewrite(SetEncodingOp encodingOp, OpAdaptor adaptor,
+ matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
@@ -763,12 +758,12 @@
/// Convert `unset_encoding` op to `unpack` op.
struct UnsetEncodingOpToUnPackOpConversion
- : public OpMaterializeEncodingPattern<UnsetEncodingOp> {
+ : public OpMaterializeEncodingPattern<IREE::Encoding::UnsetEncodingOp> {
using OpMaterializeEncodingPattern<
- UnsetEncodingOp>::OpMaterializeEncodingPattern;
+ IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
- matchAndRewrite(UnsetEncodingOp encodingOp, OpAdaptor adaptor,
+ matchAndRewrite(IREE::Encoding::UnsetEncodingOp encodingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
@@ -797,14 +792,15 @@
/// `materializeEncodingFn` returns a failure, the pattern will materialize it
/// to the same shape.
struct UpperBoundTileSizeToConstantOpConversion
- : public OpRewritePattern<UpperBoundTileSizeOp> {
+ : public OpRewritePattern<IREE::Encoding::UpperBoundTileSizeOp> {
UpperBoundTileSizeToConstantOpConversion(
MLIRContext *context, MaterializeEncodingFn materializeEncodingFn)
- : OpRewritePattern<UpperBoundTileSizeOp>(context),
+ : OpRewritePattern<IREE::Encoding::UpperBoundTileSizeOp>(context),
materializeEncodingFn(materializeEncodingFn) {}
- LogicalResult matchAndRewrite(UpperBoundTileSizeOp upperBoundTileSizeOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(IREE::Encoding::UpperBoundTileSizeOp upperBoundTileSizeOp,
+ PatternRewriter &rewriter) const override {
auto constants = lowerUpperBoundTileSizeOpToConstants(
rewriter, upperBoundTileSizeOp, materializeEncodingFn);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index f7db826..6a48eb7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -750,7 +750,7 @@
// TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added
// way to late and should insted be be done during lowering to LLVM.
.addPass(createExpandF16OpToF32Pass)
- .addPass([&]() { return createCPUMaterializeEncodingPass(); })
+ .addPass(createCPUMaterializeDeviceEncodingPass)
// TODO: Remove the following pass the plumb support for
// #hal.descriptor_type memory space through the stack.
.addPass(createEraseHALDescriptorTypeFromMemRefPass);
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index e3b19f0..42ede9d 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -619,7 +619,8 @@
requestedTargetDevice = resolveTargetDevice(*targetRegistry.value);
hasRequestedTargetDevice =
targetRegistry->getTargetDevice(requestedTargetDevice) != nullptr;
- compileOptions->executableOptions.targets.push_back(requestedTargetDevice);
+ compileOptions->executableOptions.legacyTargetBackends.push_back(
+ requestedTargetDevice);
compileOptions->targetOptions.f32Extension = true;
compileOptions->targetOptions.f64Extension = true;
compileOptions->targetOptions.truncateUnsupportedFloats = false;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
index 5a1227e..dcb0b0f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
@@ -9,8 +9,4 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
-//===----------------------------------------------------------------------===//
-// IREE::Flow::StreamableOpInterface
-//===----------------------------------------------------------------------===//
-
#endif // IREE_DIALECT_FLOW_INTERFACES
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index a2079c3..1a60d1c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -288,7 +288,6 @@
struct ElideRedundantOperandsOfWorkgroupCountFromSliceOp
: OpRewritePattern<DispatchWorkgroupsOp> {
using OpRewritePattern::OpRewritePattern;
-
LogicalResult matchAndRewrite(DispatchWorkgroupsOp op,
PatternRewriter &rewriter) const override {
Region &count = op.getWorkgroupCount();
@@ -369,7 +368,6 @@
// Bubble up the ordinal ops so that all uses go through this operation.
struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
using OpRewritePattern::OpRewritePattern;
-
LogicalResult matchAndRewrite(DispatchWorkloadOrdinalOp ordinalOp,
PatternRewriter &rewriter) const override {
auto blockArg = llvm::dyn_cast<BlockArgument>(ordinalOp.getOperand());
@@ -894,7 +892,6 @@
template <typename CastOpTy>
struct FlattenTensorCastLikeChain : public OpRewritePattern<CastOpTy> {
using OpRewritePattern<CastOpTy>::OpRewritePattern;
-
LogicalResult matchAndRewrite(CastOpTy reshapeOp,
PatternRewriter &rewriter) const override {
// We want the same result value/shape but to source from the ancestor. We
@@ -1157,6 +1154,40 @@
}
//===----------------------------------------------------------------------===//
+// flow.tensor.transfer
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Attempts to identify trivial cases where we locally recognize that a tensor
+// is transferred to the same context it's already on. This does not look across
+// control flow edges or globals and is mostly for simplifying IR that may come
+// in with a transfer on every single tensor.
+struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorTransferOp op,
+ PatternRewriter &rewriter) const override {
+ auto baseValue =
+ IREE::Util::TiedOpInterface::findTiedBaseValue(op.getOperand());
+ if (auto transferOp = dyn_cast_if_present<IREE::Flow::TensorTransferOp>(
+ baseValue.getDefiningOp())) {
+ if (transferOp.getTarget() == op.getTarget()) {
+ rewriter.replaceOp(op, op.getOperand());
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
+void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideRedundantTransfer>(context);
+}
+
+//===----------------------------------------------------------------------===//
// flow.tensor.slice
//===----------------------------------------------------------------------===//
@@ -1294,7 +1325,6 @@
// to be updated to use the source of the cast as the target tensor.
struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> {
using OpRewritePattern<TensorUpdateOp>::OpRewritePattern;
-
LogicalResult matchAndRewrite(TensorUpdateOp updateOp,
PatternRewriter &rewriter) const override {
auto targetCastOp = updateOp.getTarget().getDefiningOp<tensor::CastOp>();
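As a rough illustration of the `ElideRedundantTransfer` canonicalization added to FlowOpFolders.cpp, the sketch below models a toy IR in which every value is identified by its defining op. The `Op`, `OpKind`, `findTiedBaseValue`, and `isRedundantTransfer` names are hypothetical and are not IREE APIs. It assumes a bitcast is a tied op whose result shares its operand's storage, so the walk can see through it to an earlier transfer, mirroring the `@ElideRedundantTransfer` test case.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified IR: value N is the single result of ops[N].
enum class OpKind { Argument, Transfer, Bitcast };

struct Op {
  OpKind kind;
  int operand = -1;    // Index of the input value; -1 for arguments.
  std::string target;  // Only meaningful for Transfer.
};

// Walks up through tied ops (only Bitcast in this toy model) to the value
// that actually owns the storage, loosely modeling
// TiedOpInterface::findTiedBaseValue.
int findTiedBaseValue(const std::vector<Op> &ops, int value) {
  while (ops[value].kind == OpKind::Bitcast) {
    value = ops[value].operand;
  }
  return value;
}

// A transfer is redundant if its (tied) base value was itself produced by a
// transfer to the same target; the op can then be replaced by its operand.
bool isRedundantTransfer(const std::vector<Op> &ops, int transferIdx) {
  const Op &op = ops[transferIdx];
  if (op.kind != OpKind::Transfer) return false;
  int base = findTiedBaseValue(ops, op.operand);
  return ops[base].kind == OpKind::Transfer && ops[base].target == op.target;
}
```

Like the real pattern, this only inspects a purely local producer chain; transfers whose provenance crosses control flow or globals are left untouched.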
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index e0be51c..0e47aa5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1786,6 +1786,20 @@
}
//===----------------------------------------------------------------------===//
+// flow.tensor.transfer
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorTransferOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getOperand()},
+ getArgumentDims())) ||
+ failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getArgumentDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// flow.tensor.slice
//===----------------------------------------------------------------------===//
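`TensorTransferOp::verify` checks the single `argument_dims` list against both the operand and the result type. Below is a minimal sketch of that check. It assumes `verifyOpDynamicDims` compares the number of dynamic-dimension SSA values with the number of `?` dimensions in the type. The `kDynamic` sentinel and the helper names are illustrative stand-ins, not IREE's.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for a "dynamic dimension" marker in a shape.
constexpr int64_t kDynamic = -1;

// Counts the `?` dimensions, e.g. tensor<4x?xf32> is modeled as {4, kDynamic}.
size_t countDynamicDims(const std::vector<int64_t> &shape) {
  return static_cast<size_t>(std::count(shape.begin(), shape.end(), kDynamic));
}

// The one `argument_dims` list must supply exactly one SSA value per dynamic
// dimension of both the operand type and the result type.
bool verifyTransferDims(const std::vector<int64_t> &operandShape,
                        const std::vector<int64_t> &resultShape,
                        size_t numArgumentDims) {
  return countDynamicDims(operandShape) == numArgumentDims &&
         countDynamicDims(resultShape) == numArgumentDims;
}
```

Reusing one dims list for both sides is safe here because the op also declares `AllTypesMatch<["operand", "result"]>`.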
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 30b27c8..938e7ad 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1500,6 +1500,57 @@
let hasFolder = 1;
}
+def FLOW_TensorTransferOp : FLOW_PureOp<"tensor.transfer", [
+ AllTypesMatch<["operand", "result"]>,
+ DeclareOpInterfaceMethods<Util_HoistableOpInterface>,
+ Util_ShapeAwareOp,
+]> {
+ let summary = [{transfers a tensor to a target by copying if needed}];
+ let description = [{
+ Transfers the tensor from whichever context it may be in to the specified
+ target context. If the contexts are compatible and can access each other's
+ memory, the operation may be elided; otherwise it becomes one or more
+ copies (more than one when the tensor must be staged through an
+ intermediate context).
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$operand,
+ FLOW_ShapeDynamicDims:$argument_dims,
+ AnyAttr:$target
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $operand `:` type($result) (`{` $argument_dims^ `}`)?
+ `to` $target
+ attr-dict-with-keyword
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value":$operand, "Attribute":$target),
+ [{
+ build($_builder, $_state,
+ operand.getType(),
+ operand,
+ IREE::Util::buildDynamicDimsForValue($_state.location, operand, $_builder),
+ target);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ bool isHoistableLeafOp() { return false; }
+
+ ValueRange getOperandDynamicDims(unsigned idx) { return getArgumentDims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return getArgumentDims(); }
+ }];
+
+ let hasVerifier = 1;
+ let hasCanonicalizer = 1;
+}
+
def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [
AllRanksMatch<["source", "result"]>,
AllElementTypesMatch<["source", "result"]>,
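The `flow.tensor.transfer` description distinguishes three outcomes: elision when the target can already access the source memory, a single copy, or several copies staged through intermediate contexts. The sketch below models that decision with a hypothetical `Topology` of contexts, shared-memory pairs, and direct copy links. It uses a breadth-first search to find the fewest copies. It illustrates the described semantics under those assumptions and is not how IREE actually lowers the op.

```cpp
#include <cassert>
#include <queue>
#include <set>
#include <utility>
#include <vector>

// Hypothetical model of placement contexts and how data moves between them.
struct Topology {
  int numContexts;
  std::set<std::pair<int, int>> sharedMemory;  // (a, b): a can read b's memory.
  std::set<std::pair<int, int>> copyLinks;     // (a, b): direct copy a -> b.
};

// Returns the contexts the tensor passes through. A single entry means the
// transfer is elided; N entries mean N - 1 copies, with interior entries acting
// as staging contexts. An empty result means the target is unreachable.
std::vector<int> planTransfer(const Topology &topo, int source, int target) {
  if (source == target || topo.sharedMemory.count({target, source})) {
    return {source};  // Target can already use the data in place.
  }
  std::vector<int> parent(topo.numContexts, -1);
  std::vector<bool> seen(topo.numContexts, false);
  std::queue<int> work;
  work.push(source);
  seen[source] = true;
  while (!work.empty()) {
    int ctx = work.front();
    work.pop();
    if (ctx == target) break;
    for (int next = 0; next < topo.numContexts; ++next) {
      if (!seen[next] && topo.copyLinks.count({ctx, next})) {
        seen[next] = true;
        parent[next] = ctx;
        work.push(next);
      }
    }
  }
  if (!seen[target]) return {};
  // Rebuild the path by following parents back from the target.
  std::vector<int> path;
  for (int ctx = target; ctx != -1; ctx = parent[ctx]) {
    path.insert(path.begin(), ctx);
  }
  return path;
}
```

In this toy setup, context 1 shares memory with context 0, while contexts 2 and 3 have no direct link and must stage through context 0.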
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index bcb1dbb..959e398 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -351,8 +351,8 @@
// -----
-// CHECK-LABEL: @cloneConst
-util.func public @cloneConst() -> tensor<4xi32> {
+// CHECK-LABEL: @cloneConstant
+util.func public @cloneConstant() -> tensor<4xi32> {
// CHECK-NEXT: %[[C:.+]] = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%1 = flow.tensor.clone %0 : tensor<4xi32>
@@ -362,8 +362,8 @@
// -----
-// CHECK-LABEL: @cloneConstZeroElements
-util.func public @cloneConstZeroElements() -> tensor<0x2xi32> {
+// CHECK-LABEL: @cloneConstantZeroElements
+util.func public @cloneConstantZeroElements() -> tensor<0x2xi32> {
// CHECK-NEXT: %[[C:.+]] = arith.constant dense<> : tensor<0x2xi32>
%0 = arith.constant dense<> : tensor<0x2xi32>
// CHECK-NOT: flow.tensor.clone
@@ -397,6 +397,21 @@
// -----
+// CHECK-LABEL: @ElideRedundantTransfer
+// CHECK-SAME: (%[[OPERAND:.+]]: tensor<4x?xf32>, %[[DIM:.+]]: index)
+util.func public @ElideRedundantTransfer(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
+ // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %[[OPERAND]]
+ %transfer = flow.tensor.transfer %arg0 : tensor<4x?xf32>{%dim} to "target"
+ // CHECK: %[[BITCAST:.+]] = flow.tensor.bitcast %[[TRANSFER]]
+ %bitcast = flow.tensor.bitcast %transfer : tensor<4x?xf32>{%dim} -> tensor<4x?xi32>{%dim}
+ // CHECK-NOT: flow.tensor.transfer
+ %redundant = flow.tensor.transfer %bitcast : tensor<4x?xi32>{%dim} to "target"
+ // CHECK-NEXT: %[[BITCAST]]
+ util.return %redundant : tensor<4x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: @sliceConst0D
util.func public @sliceConst0D() -> tensor<i32> {
%0 = arith.constant dense<0> : tensor<i32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
index b0a19ad..62d79a1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -7,6 +7,8 @@
util.return %0 : tensor<16xf32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeScalar
util.func public @tensorReshapeScalar(%arg0 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.reshape %arg0 : tensor<f32> -> tensor<f32>
@@ -14,6 +16,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeDynamic
util.func public @tensorReshapeDynamic(%arg0 : tensor<?x4xf32>) -> tensor<?x2xf32> {
%c4 = arith.constant 4 : index
@@ -23,6 +27,8 @@
util.return %0 : tensor<?x2xf32>
}
+// -----
+
// CHECK-LABEL: @tensorReshapeComplex
util.func public @tensorReshapeComplex(%arg0 : tensor<4x4xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
// CHECK-NEXT: flow.tensor.reshape %arg0 : tensor<4x4xcomplex<f32>> -> tensor<16xcomplex<f32>>
@@ -48,6 +54,8 @@
util.return %0 : f32
}
+// -----
+
// CHECK-LABEL: @tensorLoadScalar
util.func public @tensorLoadScalar(%arg0 : tensor<f32>) -> f32 {
// CHECK-NEXT: %0 = flow.tensor.load %arg0 : tensor<f32>
@@ -55,6 +63,8 @@
util.return %0 : f32
}
+// -----
+
// CHECK-LABEL: @tensorLoadDynamic
util.func public @tensorLoadDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index) -> f32 {
%c4 = arith.constant 4 : index
@@ -72,6 +82,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorStoreScalar
util.func public @tensorStoreScalar(%arg0 : f32, %arg1 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.store %arg0, %arg1 : tensor<f32>
@@ -79,6 +91,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorStoreDynamic
util.func public @tensorStoreDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index, %arg3 : f32) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -114,6 +128,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorSplatScalar
util.func public @tensorSplatScalar(%arg0 : f32) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<f32>
@@ -121,6 +137,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorSplatDynamic
util.func public @tensorSplatDynamic(%arg0 : f32) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -138,6 +156,17 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
+// CHECK-LABEL: @tensorTransfer
+util.func public @tensorTransfer(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.transfer %arg0 : tensor<4x4xf32> to "dummy"
+ %0 = flow.tensor.transfer %arg0 : tensor<4x4xf32> to "dummy"
+ util.return %0 : tensor<4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @tensorCloneScalar
util.func public @tensorCloneScalar(%arg0 : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor<f32>
@@ -145,6 +174,8 @@
util.return %0 : tensor<f32>
}
+// -----
+
// CHECK-LABEL: @tensorCloneDynamic
util.func public @tensorCloneDynamic(%arg0 : tensor<?x4xf32>) -> tensor<?x4xf32> {
%c4 = arith.constant 4 : index
@@ -162,6 +193,8 @@
util.return %0 : tensor<2x2xf32>
}
+// -----
+
// CHECK-LABEL: @tensorSliceDynamic
util.func public @tensorSliceDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index) -> tensor<?x2xf32> {
%c2 = arith.constant 2 : index
@@ -180,6 +213,8 @@
util.return %0 : tensor<4x4xf32>
}
+// -----
+
// CHECK-LABEL: @tensorUpdateDynamic
util.func public @tensorUpdateDynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x4xf32>, %arg2 : index, %arg3 : index) -> tensor<?x4xf32> {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
index ecfb86a..3967d50 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
@@ -81,7 +81,8 @@
// hal.tensor.export
auto bufferExportOp = initializerBuilder.create<IREE::HAL::TensorExportOp>(
loc, globalOp.getType(), splatOp.getResult(),
- TypeAttr::get(splatOp.getType()), /*name=*/nullptr);
+ TypeAttr::get(splatOp.getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
// util.optimization_barrier (try to prevent optimizations across the export)
auto barrierOp = initializerBuilder.create<IREE::Util::OptimizationBarrierOp>(
loc, bufferExportOp.getTarget());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
index beb995d..49212b5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
@@ -103,7 +103,8 @@
if (llvm::isa<TensorType>(retVal.getType())) {
auto type = IREE::HAL::BufferViewType::get(context);
auto exportOp = builder.create<IREE::HAL::TensorExportOp>(
- loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr);
+ loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
exports.push_back(exportOp.getResult());
newTypes.push_back(type);
} else {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
index 3caa5b0..92a0208 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
@@ -1,5 +1,7 @@
// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" -split-input-file | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: util.func public @foo(
// CHECK: %[[argA:.*]]: tensor<?x?xf32>, %[[argB:.*]]: tensor<5x10xf32>, %[[argC:.*]]: tensor<10x11xf32>
util.func public @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: tensor<10x11xf32>) -> (tensor<?x?xf32>, tensor<5x11xf32>) {
@@ -21,7 +23,7 @@
flow.return %argA : tensor<?x?xf32>
}
// CHECK: %[[r1:.*]] = flow.dispatch.workgroups(%[[argB]], %[[argC]]) : (tensor<5x10xf32>, tensor<10x11xf32>) -> tensor<5x11xf32>
- // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device>
// CHECK-NEXT: (%[[arg3:.*]]: !flow.dispatch.tensor<readonly:tensor<5x10xf32>>, %[[arg4:.*]]: !flow.dispatch.tensor<readonly:tensor<10x11xf32>>, %[[arg5:.*]]: !flow.dispatch.tensor<writeonly:tensor<5x11xf32>>)
// CHECK-DAG: %[[loadB:.*]] = flow.dispatch.tensor.load %[[arg3]], offsets = [0, 0], sizes = [5, 10], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<5x10xf32>> -> tensor<5x10xf32>
// CHECK-DAG: %[[loadC:.*]] = flow.dispatch.tensor.load %[[arg4]], offsets = [0, 0], sizes = [10, 11], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10x11xf32>> -> tensor<10x11xf32>
@@ -31,7 +33,9 @@
// CHECK: flow.dispatch.tensor.store %[[matmul]], %[[arg5]], offsets = [0, 0], sizes = [5, 11], strides = [1, 1] : tensor<5x11xf32> -> !flow.dispatch.tensor<writeonly:tensor<5x11xf32>>
// CHECK: flow.return
// CHECK: }
- %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes {stream.affinity = #hal.affinity.queue<[0]>} {
+ %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes {
+ stream.affinity = #hal.device.affinity<@device>
+ } {
%zero = arith.constant 0.0 : f32
%0 = tensor.empty() : tensor<5x11xf32>
%1 = linalg.fill ins(%zero : f32) outs(%0 : tensor<5x11xf32>) -> tensor<5x11xf32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
index e3db1b6..57304a1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
@@ -67,11 +67,13 @@
// Tests that any hoistable attrs are propagated to the outlined globals.
+util.global private @device : !hal.device
+
// CHECK: util.global private @__constant_tensor_2xi32
-// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
+// CHECK-SAME: stream.affinity = #hal.device.affinity<@device, [0]>
// CHECK-NEXT: util.func private @set_affinity
util.func private @set_affinity() attributes {
- stream.affinity = #hal.affinity.queue<[0]>
+ stream.affinity = #hal.device.affinity<@device, [0]>
} {
// CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32
%cst = arith.constant dense<[0, 1]> : tensor<2xi32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 0a0f9e5..dd9d651 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -78,6 +78,9 @@
// -----
+util.global private @device_a : !hal.device
+util.global private @device_b : !hal.device
+
// CHECK: flow.executable private @dispatchFn1_dispatch_0
// CHECK-LABEL: util.func public @dispatchFn1
@@ -85,9 +88,9 @@
%x = arith.constant 100 : index
%y = arith.constant 50 : index
// CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0
- // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a>
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes {
- stream.affinity = #hal.affinity.queue<[0]>
+ stream.affinity = #hal.device.affinity<@device_a>
} = (
%arg: !flow.dispatch.tensor<readonly:tensor<8x4xf32>>, %ret: !flow.dispatch.tensor<writeonly:tensor<4x8xf32>>
) {
@@ -103,9 +106,9 @@
%x = arith.constant 100 : index
%y = arith.constant 50 : index
// CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0
- // CHECK-SAME: stream.affinity = #hal.affinity.queue<[1]>
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b>
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes {
- stream.affinity = #hal.affinity.queue<[1]>
+ stream.affinity = #hal.device.affinity<@device_b>
} = (
%arg: !flow.dispatch.tensor<readonly:tensor<8x4xf32>>, %ret: !flow.dispatch.tensor<writeonly:tensor<4x8xf32>>
) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel
new file mode 100644
index 0000000..46e9053
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel
@@ -0,0 +1,35 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "Attributes",
+ srcs = [
+ "DeviceGlobalPVS.cpp",
+ "DeviceTargetPVS.cpp",
+ ],
+ hdrs = [
+ "DeviceGlobalPVS.h",
+ "DeviceTargetPVS.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt
new file mode 100644
index 0000000..8be479f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt
@@ -0,0 +1,34 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Attributes
+ HDRS
+ "DeviceGlobalPVS.h"
+ "DeviceTargetPVS.h"
+ SRCS
+ "DeviceGlobalPVS.cpp"
+ "DeviceTargetPVS.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAnalysis
+ MLIRIR
+ MLIRPass
+ MLIRSupport
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Util::Analysis::DFX
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp
new file mode 100644
index 0000000..8e96e97
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp
@@ -0,0 +1,117 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "iree-hal-device-analysis"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceGlobalValuePVS
+//===----------------------------------------------------------------------===//
+
+const char DeviceGlobalValuePVS::ID = 0;
+
+void DeviceGlobalValuePVS::initializeValue(Value value, DFX::Solver &solver) {
+ assert(isa<IREE::HAL::DeviceType>(value.getType()) &&
+ "only initialize on values of type !hal.device");
+
+ // If the value is a function arg of a public function then we'll never be
+ // able to know (today). We could look for attributes defining device
+ // properties but we can't recover the source device global from them.
+ if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ if (auto funcOp =
+ dyn_cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp())) {
+ if (funcOp.isPublic()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "DeviceGlobalValuePVS: argument to a public function - "
+ "treating as undefined\n");
+ unionAssumedWithUndef();
+ indicatePessimisticFixpoint();
+ return;
+ }
+ }
+ }
+}
+
+ChangeStatus DeviceGlobalValuePVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ // Walk into all producers of the SSA value.
+ // Note that we may end up at multiple global loads of different globals
+ // by walking up through calls/branches/etc.
+ traversalResult |=
+ solver.getExplorer().walkDefiningOps(value, [&](OpResult result) {
+ updateFromDefiningOp(value, result, newState, solver);
+ return WalkResult::advance();
+ });
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+void DeviceGlobalValuePVS::updateFromDefiningOp(Value value, OpResult result,
+ StateType &newState,
+ DFX::Solver &solver) {
+ TypeSwitch<Operation *, void>(result.getOwner())
+ .Case([&](mlir::arith::SelectOp op) {
+ auto &truePVS = solver.getElementFor<DeviceGlobalValuePVS>(
+ *this, Position::forValue(op.getTrueValue()),
+ DFX::Resolution::REQUIRED);
+ auto &falsePVS = solver.getElementFor<DeviceGlobalValuePVS>(
+ *this, Position::forValue(op.getFalseValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= truePVS.getState();
+ newState ^= falsePVS.getState();
+ })
+ .Case([&](IREE::Util::OptimizationBarrierOp op) {
+ auto &sourcePVS = solver.getElementFor<DeviceGlobalValuePVS>(
+ *this, Position::forValue(op.getOperand(0)),
+ DFX::Resolution::REQUIRED);
+ newState ^= sourcePVS.getState();
+ })
+ .Case([&](IREE::Util::GlobalLoadOpInterface op) {
+ auto *globalInfo =
+ solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op);
+ newState.unionAssumed(globalInfo->op);
+ })
+ .Default([&](Operation *op) {});
+}
+
+const std::string DeviceGlobalValuePVS::getAsStr(AsmState &asmState) const {
+ std::string str;
+ llvm::raw_string_ostream sstream(str);
+ sstream << "pvs: ";
+ if (isValidState()) {
+ sstream << "[";
+ if (isUndefContained()) {
+ sstream << "undef, ";
+ }
+ llvm::interleaveComma(getAssumedSet(), sstream,
+ [&](IREE::Util::GlobalOpInterface value) {
+ value.print(sstream, asmState);
+ });
+ sstream << "]";
+ } else {
+ sstream << "(invalid)";
+ }
+ sstream.flush();
+ return str;
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
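`DeviceGlobalValuePVS` tracks which `!hal.device` globals an SSA value may originate from. Global loads contribute their global; `arith.select` and `util.optimization_barrier` forward the union of their inputs; public function arguments are treated as undefined. Below is a simplified, recursive sketch of that propagation. All types are hypothetical, and it assumes an acyclic producer graph, whereas the real element is driven to a fixpoint by the DFX solver.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a potential values state: the device globals a
// value may come from, plus an "undef" bit for sources we cannot see.
struct PotentialGlobals {
  std::set<std::string> globals;
  bool undef = false;

  void unionWith(const PotentialGlobals &other) {
    globals.insert(other.globals.begin(), other.globals.end());
    undef |= other.undef;
  }
};

// Hypothetical producers of a !hal.device value.
enum class Kind { GlobalLoad, Select, Barrier, PublicArg };

struct Producer {
  Kind kind;
  std::string global;         // GlobalLoad: the loaded global's name.
  std::vector<int> operands;  // Select: {true, false}; Barrier: {source}.
};

// Resolves the potential globals for value `v` (value N is ir[N]'s result).
PotentialGlobals resolve(const std::vector<Producer> &ir, int v) {
  const Producer &p = ir[v];
  PotentialGlobals state;
  switch (p.kind) {
    case Kind::GlobalLoad:
      state.globals.insert(p.global);
      break;
    case Kind::Select:
    case Kind::Barrier:
      for (int operand : p.operands) state.unionWith(resolve(ir, operand));
      break;
    case Kind::PublicArg:
      state.undef = true;  // External callers may pass any device.
      break;
  }
  return state;
}
```

A set of size one means the value is known to be a specific device; once `undef` is set, consumers have to treat the device as unknown.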
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h
new file mode 100644
index 0000000..10864d6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h
@@ -0,0 +1,60 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_
+#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_
+
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceGlobalValuePVS (Potential Values State)
+//===----------------------------------------------------------------------===//
+
+// Set of potential globals that provide a !hal.device SSA value.
+// A set size of 1 indicates that the device SSA value is a particular device.
+// Multiple entries indicate that multiple code paths may route to the value
+// with different devices selected.
+class DeviceGlobalValuePVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Util::GlobalOpInterface>,
+ DFX::ValueElement> {
+public:
+ using BaseType = DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Util::GlobalOpInterface>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static DeviceGlobalValuePVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) DeviceGlobalValuePVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "DeviceGlobalValuePVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override;
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ void updateFromDefiningOp(Value value, OpResult result, StateType &newState,
+ DFX::Solver &solver);
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_
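`DeviceTargetGlobalPVS::initializeOperation` in DeviceTargetPVS.cpp resolves a device global's initial value recursively. A `DeviceTargetAttr` contributes itself; a `DeviceFallbackAttr` contributes the set of the referenced global; a `DeviceSelectAttr` contributes the union of its children; anything else leaves the state undefined. The sketch below is a simplified model of that recursion. `DeviceInit`, `Globals`, and `resolveTargets` are hypothetical, `std::nullopt` stands in for "undef plus pessimistic fixpoint", and fallback chains are assumed to be acyclic.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of a device global's initial value.
struct DeviceInit {
  enum class Kind { Target, Fallback, Select, Unknown } kind;
  std::string name;                  // Target: backend; Fallback: global name.
  std::vector<DeviceInit> children;  // Select: candidate devices.
};

using Globals = std::map<std::string, DeviceInit>;

// Collects every target a device global may resolve to, or std::nullopt when
// the analysis would give up (missing fallback, unknown attribute).
std::optional<std::set<std::string>> resolveTargets(const Globals &globals,
                                                    const DeviceInit &init) {
  switch (init.kind) {
    case DeviceInit::Kind::Target:
      return std::set<std::string>{init.name};
    case DeviceInit::Kind::Fallback: {
      auto it = globals.find(init.name);
      if (it == globals.end()) return std::nullopt;  // Fallback not found.
      return resolveTargets(globals, it->second);
    }
    case DeviceInit::Kind::Select: {
      std::set<std::string> all;
      for (const DeviceInit &child : init.children) {
        auto childTargets = resolveTargets(globals, child);
        if (!childTargets) return std::nullopt;  // Any unknown child poisons.
        all.insert(childTargets->begin(), childTargets->end());
      }
      return all;
    }
    case DeviceInit::Kind::Unknown:
      break;
  }
  return std::nullopt;
}
```

In the usage check below, the backend names are just labels, and a select between one target and a fallback to another global yields both targets.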
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp
new file mode 100644
index 0000000..85813f6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp
@@ -0,0 +1,250 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "iree-hal-device-analysis"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceTargetGlobalPVS
+//===----------------------------------------------------------------------===//
+
+const char DeviceTargetGlobalPVS::ID = 0;
+
+void DeviceTargetGlobalPVS::initializeOperation(IREE::Util::GlobalOp globalOp,
+ DFX::Solver &solver) {
+ assert(isa<IREE::HAL::DeviceType>(globalOp.getType()) &&
+ "only initialize on globals of type !hal.device");
+
+ // We only support immutable initialized device globals.
+ // We could track usage up through stores to handle the mutable case but
+ // the compiler does not generate such programs today.
+ auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp);
+ if (!globalInfo || globalInfo->isIndirect || globalOp.isGlobalMutable()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "DeviceTargetGlobalPVS: mutable device globals or those used "
+ "indirectly are not yet implemented\n");
+ unionAssumedWithUndef();
+ indicatePessimisticFixpoint();
+ return;
+ }
+
+ // Use the initial value to populate the potential value set.
+ std::function<bool(Attribute)> unionAttr;
+ unionAttr = [&](Attribute attr) -> bool {
+ return TypeSwitch<Attribute, bool>(attr)
+ .Case<IREE::HAL::DeviceTargetAttr>([&](auto targetAttr) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with target: ";
+ attr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ unionAssumed(targetAttr);
+ return true;
+ })
+ .Case<IREE::HAL::DeviceFallbackAttr>([&](auto fallbackAttr) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with fallback: ";
+ attr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ auto *fallbackInfo = solver.getExplorer().queryGlobalInfoFrom(
+ fallbackAttr.getName().getValue(), globalOp);
+ if (!fallbackInfo) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "DeviceTargetGlobalPVS: !! failed to find fallback global "
+ << fallbackAttr.getName().getValue() << "\n");
+ return false;
+ }
+ auto &fallbackPVS =
+ solver.getOrCreateElementFor<DeviceTargetGlobalPVS>(
+ Position::forOperation(fallbackInfo->op));
+ if (fallbackPVS.isUndefContained()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "DeviceTargetGlobalPVS: !! fallback is undefined\n");
+ return false;
+ }
+ unionAssumed(fallbackPVS.getState());
+ return true;
+ })
+ .Case<IREE::HAL::DeviceSelectAttr>([&](auto selectAttr) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with selected "
+ "child devices: ";
+ attr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ for (auto childAttr : selectAttr.getDevices()) {
+ if (!unionAttr(childAttr)) {
+ return false;
+ }
+ }
+ return true;
+ })
+ .Default(
+ [&](auto attr) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "DeviceTargetGlobalPVS: !! unknown initial value type\n");
+ return false;
+ });
+ };
+ if (auto initialValueAttr = globalOp.getInitialValueAttr()) {
+ if (unionAttr(initialValueAttr)) {
+ indicateOptimisticFixpoint();
+ } else {
+ unionAssumedWithUndef();
+ indicatePessimisticFixpoint();
+ }
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "DeviceTargetGlobalPVS: no initial value, dynamically "
+ "configured devices are not yet implemented\n");
+ unionAssumedWithUndef();
+ indicatePessimisticFixpoint();
+ }
+}
+
+ChangeStatus
+DeviceTargetGlobalPVS::updateOperation(IREE::Util::GlobalOp globalOp,
+ DFX::Solver &solver) {
+ // We only support running on initialized globals today.
+ // We could support walking store/load or other things, though.
+ return ChangeStatus::UNCHANGED;
+}
+
+const std::string DeviceTargetGlobalPVS::getAsStr(AsmState &asmState) const {
+ std::string str;
+ llvm::raw_string_ostream sstream(str);
+ sstream << "pvs: ";
+ if (isValidState()) {
+ sstream << "[";
+ if (isUndefContained()) {
+ sstream << "undef, ";
+ }
+ llvm::interleaveComma(getAssumedSet(), sstream,
+ [&](IREE::HAL::DeviceTargetAttr value) {
+ cast<Attribute>(value).print(sstream);
+ });
+ sstream << "]";
+ } else {
+ sstream << "(invalid)";
+ }
+ sstream.flush();
+ return str;
+}
+
+//===----------------------------------------------------------------------===//
+// DeviceTargetValuePVS
+//===----------------------------------------------------------------------===//
+
+const char DeviceTargetValuePVS::ID = 0;
+
+void DeviceTargetValuePVS::initializeValue(Value value, DFX::Solver &solver) {
+ assert(isa<IREE::HAL::DeviceType>(value.getType()) &&
+ "only initialize on values of type !hal.device");
+
+ // If the value is a function arg of a public function then we'll never be
+ // able to know (today). We could look for attributes defining device
+ // properties but we can't recover a DeviceTargetAttr from them.
+ if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ if (auto funcOp =
+ dyn_cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp())) {
+ if (funcOp.isPublic()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "DeviceTargetValuePVS: argument to a public function - "
+ "treating as undefined\n");
+ unionAssumedWithUndef();
+ indicatePessimisticFixpoint();
+ return;
+ }
+ }
+ }
+}
+
+ChangeStatus DeviceTargetValuePVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ // Walk into all producers of the SSA value.
+ // Note that we may end up at multiple global loads of different globals
+ // by walking up through calls/branches/etc.
+ traversalResult |=
+ solver.getExplorer().walkDefiningOps(value, [&](OpResult result) {
+ updateFromDefiningOp(value, result, newState, solver);
+ return WalkResult::advance();
+ });
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+void DeviceTargetValuePVS::updateFromDefiningOp(Value value, OpResult result,
+ StateType &newState,
+ DFX::Solver &solver) {
+ TypeSwitch<Operation *, void>(result.getOwner())
+ .Case([&](mlir::arith::SelectOp op) {
+ auto &truePVS = solver.getElementFor<DeviceTargetValuePVS>(
+ *this, Position::forValue(op.getTrueValue()),
+ DFX::Resolution::REQUIRED);
+ auto &falsePVS = solver.getElementFor<DeviceTargetValuePVS>(
+ *this, Position::forValue(op.getFalseValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= truePVS.getState();
+ newState ^= falsePVS.getState();
+ })
+ .Case([&](IREE::Util::OptimizationBarrierOp op) {
+ auto &sourcePVS = solver.getElementFor<DeviceTargetValuePVS>(
+ *this, Position::forValue(op.getOperand(0)),
+ DFX::Resolution::REQUIRED);
+ newState ^= sourcePVS.getState();
+ })
+ .Case([&](IREE::Util::GlobalLoadOpInterface op) {
+ auto *globalInfo =
+ solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op);
+ auto &globalPVS = solver.getElementFor<DeviceTargetGlobalPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ newState ^= globalPVS.getState();
+ })
+ .Default([&](Operation *op) {});
+}
+
+const std::string DeviceTargetValuePVS::getAsStr(AsmState &asmState) const {
+ std::string str;
+ llvm::raw_string_ostream sstream(str);
+ sstream << "pvs: ";
+ if (isValidState()) {
+ sstream << "[";
+ if (isUndefContained()) {
+ sstream << "undef, ";
+ }
+ llvm::interleaveComma(getAssumedSet(), sstream,
+ [&](IREE::HAL::DeviceTargetAttr value) {
+ cast<Attribute>(value).print(sstream);
+ });
+ sstream << "]";
+ } else {
+ sstream << "(invalid)";
+ }
+ sstream.flush();
+ return str;
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
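The PVS elements above accumulate potential device targets by unioning states (`newState ^= ...`) and by marking a state with undef when part of the program cannot be seen, such as an argument to a public function. The following is a minimal, standalone C++ sketch of that idea, not the DFX API: `PotentialTargets` and `selectOf` are hypothetical names, and plain strings stand in for `#hal.device.target` attributes.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical, simplified stand-in for a potential values state: a set of
// candidate device target names plus an "undef" bit meaning "may also be
// something the analysis could not see".
struct PotentialTargets {
  std::set<std::string> values;
  bool undef = false;

  // Unions another state into this one, the role `newState ^= other` plays in
  // updateFromDefiningOp.
  PotentialTargets &operator^=(const PotentialTargets &rhs) {
    values.insert(rhs.values.begin(), rhs.values.end());
    undef = undef || rhs.undef;
    return *this;
  }

  // Only a state without undef gives a trustworthy answer to queries.
  bool isFullyKnown() const { return !undef; }
};

// A select of two device values may yield either operand, so its potential
// targets are the union of both operands' potential targets.
PotentialTargets selectOf(const PotentialTargets &trueValue,
                          const PotentialTargets &falseValue) {
  PotentialTargets result;
  result ^= trueValue;
  result ^= falseValue;
  return result;
}
```

A consumer treats any state containing undef as unknown, which is roughly how `DeviceAnalysis::lookupDeviceTargets` falls back to its default device set.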
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h
new file mode 100644
index 0000000..f1b220f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h
@@ -0,0 +1,97 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_
+#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_
+
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceTargetGlobalPVS (Potential Values State)
+//===----------------------------------------------------------------------===//
+
+// Set of potential IREE::HAL::DeviceTargetAttr values for an initialized
+// !hal.device global. When defined, the device global may take on the traits
+// of any of the potential values.
+class DeviceTargetGlobalPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::HAL::DeviceTargetAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOp>> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::HAL::DeviceTargetAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOp>>;
+ using BaseType::BaseType;
+
+ static DeviceTargetGlobalPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) DeviceTargetGlobalPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "DeviceTargetGlobalPVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override;
+
+private:
+ void initializeOperation(IREE::Util::GlobalOp globalOp,
+ DFX::Solver &solver) override;
+ ChangeStatus updateOperation(IREE::Util::GlobalOp globalOp,
+ DFX::Solver &solver) override;
+};
+
+//===----------------------------------------------------------------------===//
+// DeviceTargetValuePVS
+//===----------------------------------------------------------------------===//
+
+// Set of potential values for a !hal.device SSA value.
+// When defined, the value may take on the traits of any of the potential values.
+class DeviceTargetValuePVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::HAL::DeviceTargetAttr>,
+ DFX::ValueElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::HAL::DeviceTargetAttr>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static DeviceTargetValuePVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) DeviceTargetValuePVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "DeviceTargetValuePVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override;
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ void updateFromDefiningOp(Value value, OpResult result, StateType &newState,
+ DFX::Solver &solver);
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel
index f6e18bf..0e2aa4a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel
@@ -16,18 +16,27 @@
name = "Analysis",
srcs = [
"BindingLayout.cpp",
+ "DeviceAnalysis.cpp",
+ "DeviceSet.cpp",
],
hdrs = [
"BindingLayout.h",
+ "DeviceAnalysis.h",
+ "DeviceSet.h",
],
deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis",
+ "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
],
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index f762a1f..fc26729 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -211,6 +211,15 @@
}
}
+bool BindingLayoutAnalysis::hasDispatches() const {
+ for (auto &it : exportInfos) {
+ if (!it.second->dispatchOps.empty()) {
+ return true; // found at least one dispatch
+ }
+ }
+ return false;
+}
+
ArrayRef<IREE::Stream::CmdDispatchOp>
BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const {
auto it = exportInfos.find(exportOp);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
index 1e8704d..7d08959 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_
-#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_
+#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_
+#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
@@ -59,6 +59,9 @@
public:
explicit BindingLayoutAnalysis(Operation *rootOp, SymbolTable &symbolTable);
+ // Whether there are any dispatches in the program.
+ bool hasDispatches() const;
+
// Returns all of the dispatches to the given executable export.
ArrayRef<IREE::Stream::CmdDispatchOp>
getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const {
@@ -94,4 +97,4 @@
} // namespace mlir::iree_compiler::IREE::HAL
-#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_
+#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt
index e25ba34..6e733ac 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt
@@ -15,16 +15,25 @@
Analysis
HDRS
"BindingLayout.h"
+ "DeviceAnalysis.h"
+ "DeviceSet.h"
SRCS
"BindingLayout.cpp"
+ "DeviceAnalysis.cpp"
+ "DeviceSet.cpp"
DEPS
LLVMSupport
MLIRAnalysis
+ MLIRFunctionInterfaces
MLIRIR
MLIRPass
+ MLIRSCFDialect
MLIRSupport
+ iree::compiler::Dialect::HAL::Analysis::Attributes
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::Analysis
+ iree::compiler::Dialect::Util::Analysis::DFX
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp
new file mode 100644
index 0000000..f144e31
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp
@@ -0,0 +1,234 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
+
+#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h"
+#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceAnalysis
+//===----------------------------------------------------------------------===//
+
+DeviceAnalysis::DeviceAnalysis(Operation *rootOp)
+ : explorer(rootOp, TraversalAction::SHALLOW), solver(explorer, allocator) {
+ explorer.setOpInterfaceAction<mlir::FunctionOpInterface>(
+ TraversalAction::RECURSE);
+ explorer.setOpAction<mlir::scf::ForOp>(TraversalAction::RECURSE);
+ explorer.setOpAction<mlir::scf::IfOp>(TraversalAction::RECURSE);
+ explorer.setOpAction<mlir::scf::WhileOp>(TraversalAction::RECURSE);
+ // Ignore the contents of executables (linalg goo, etc).
+ explorer.setOpAction<IREE::HAL::ExecutableOp>(TraversalAction::IGNORE);
+ explorer.initialize();
+}
+
+DeviceAnalysis::~DeviceAnalysis() = default;
+
+LogicalResult DeviceAnalysis::run() {
+ // TODO(multi-device): remove this fallback path when device globals are fully
+ // plumbed through. Today we still have inputs with the hal.device.targets
+ // attribute.
+ if (auto targetsAttr = explorer.getRootOp()->getAttrOfType<ArrayAttr>(
+ "hal.device.targets")) {
+ if (!targetsAttr.empty()) {
+ defaultDeviceSet = DeviceSet(targetsAttr);
+ }
+ }
+
+ // Initialize device globals (in declaration order).
+ for (auto globalOp : explorer.getRootOp()
+ ->getRegion(0)
+ .getOps<IREE::Util::GlobalOpInterface>()) {
+ auto globalType = globalOp.getGlobalType();
+ if (isa<IREE::HAL::DeviceType>(globalType)) {
+ solver.getOrCreateElementFor<DeviceTargetGlobalPVS>(
+ Position::forOperation(globalOp));
+ deviceGlobals.push_back(globalOp);
+ }
+ }
+
+ // Create elements for all !hal.device SSA values up front so that later
+ // queries are simple lookups.
+ explorer.walkValuesOfType<IREE::HAL::DeviceType>([&](Value value) {
+ solver.getOrCreateElementFor<DeviceGlobalValuePVS>(
+ Position::forValue(value));
+ solver.getOrCreateElementFor<DeviceTargetValuePVS>(
+ Position::forValue(value));
+ return WalkResult::advance();
+ });
+
+ return solver.run();
+}
+
+std::optional<SetVector<IREE::Util::GlobalOpInterface>>
+DeviceAnalysis::lookupDeviceGlobals(Value deviceValue) {
+ auto globalPVS = solver.lookupElementFor<DeviceGlobalValuePVS>(
+ Position::forValue(deviceValue));
+ if (!globalPVS || !globalPVS->isValidState() ||
+ globalPVS->isUndefContained()) {
+ return std::nullopt;
+ }
+ SetVector<IREE::Util::GlobalOpInterface> globalOps;
+ for (auto globalOp : globalPVS->getAssumedSet()) {
+ globalOps.insert(globalOp);
+ }
+ return globalOps;
+}
+
+std::optional<DeviceSet>
+DeviceAnalysis::lookupDeviceTargets(Value deviceValue) {
+ auto valuePVS = solver.lookupElementFor<DeviceTargetValuePVS>(
+ Position::forValue(deviceValue));
+ if (!valuePVS || !valuePVS->isValidState() || valuePVS->isUndefContained()) {
+ return defaultDeviceSet;
+ }
+ return DeviceSet(valuePVS->getAssumedSet());
+}
+
+// Gathers the set of target devices that may be active for the given
+// operation into |resultSet|. This recursively walks parent operations until
+// one with the `hal.device.targets` attribute is found.
+//
+// This is a legacy mechanism for performing the search. Newer code should use
+// affinities or !hal.device analysis instead.
+static void gatherLegacyDeviceTargetAttrs(
+ Operation *op, SetVector<IREE::HAL::DeviceTargetAttr> &resultSet) {
+ auto attrId = StringAttr::get(op->getContext(), "hal.device.targets");
+ while (op) {
+ auto targetsAttr = op->getAttrOfType<ArrayAttr>(attrId);
+ if (targetsAttr) {
+ for (auto elementAttr : targetsAttr) {
+ if (auto targetAttr =
+ dyn_cast<IREE::HAL::DeviceTargetAttr>(elementAttr)) {
+ resultSet.insert(targetAttr);
+ } else {
+ // HACK: this legacy approach is deprecated and only preserved for
+ // existing behavior. It's ok to get angry here as users should not be
+ // trying to use this pass prior to device materialization.
+ assert(false &&
+ "legacy hal.device.targets only supports #hal.device.target attributes");
+ }
+ }
+ return;
+ }
+ op = op->getParentOp();
+ }
+ // No devices found; let caller decide what to do.
+}
+
+// Recursively resolves the referenced device into targets.
+void DeviceAnalysis::gatherDeviceTargets(
+ Attribute rootAttr, Operation *fromOp,
+ SetVector<IREE::HAL::DeviceTargetAttr> &resultSet) {
+ SetVector<Attribute> worklist;
+ worklist.insert(rootAttr);
+ do {
+ auto attr = worklist.pop_back_val();
+ if (!TypeSwitch<Attribute, bool>(attr)
+ .Case<SymbolRefAttr>([&](auto symRefAttr) {
+ auto globalOp =
+ explorer.getSymbolTables()
+ .lookupNearestSymbolFrom<IREE::Util::GlobalOpInterface>(
+ fromOp, symRefAttr);
+ assert(globalOp && "global reference must be valid");
+ if (auto initialValueAttr = globalOp.getGlobalInitialValue()) {
+ // Global with a device initialization value we can analyze.
+ worklist.insert(initialValueAttr);
+ return true;
+ } else {
+ return false;
+ }
+ })
+ .Case<IREE::HAL::DeviceTargetAttr>([&](auto targetAttr) {
+ resultSet.insert(targetAttr);
+ return true;
+ })
+ .Case<IREE::HAL::DeviceFallbackAttr>([&](auto fallbackAttr) {
+ worklist.insert(fallbackAttr.getName());
+ return true;
+ })
+ .Case<IREE::HAL::DeviceSelectAttr>([&](auto selectAttr) {
+ worklist.insert(selectAttr.getDevices().begin(),
+ selectAttr.getDevices().end());
+ return true;
+ })
+ .Default([](auto attr) { return false; })) {
+ // Unresolvable references (no initial value or an unknown attribute) fall
+ // back to the defaults by inserting all targets found via the legacy
+ // hal.device.targets attribute.
+ gatherLegacyDeviceTargetAttrs(fromOp, resultSet);
+ return;
+ }
+ } while (!worklist.empty());
+}
+
+void DeviceAnalysis::gatherAllDeviceTargets(
+ SetVector<IREE::HAL::DeviceTargetAttr> &resultSet) {
+ for (auto globalOp : deviceGlobals) {
+ gatherDeviceTargets(FlatSymbolRefAttr::get(globalOp), explorer.getRootOp(),
+ resultSet);
+ }
+}
+
+void DeviceAnalysis::gatherDeviceAffinityTargets(
+ IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp,
+ SetVector<IREE::HAL::DeviceTargetAttr> &resultSet) {
+ // We currently only know how to handle HAL device affinities.
+ // We could support other ones via an interface but instead we just fall back
+ // to default logic if no affinity or an unknown one is found.
+ auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr);
+ if (!deviceAffinityAttr) {
+ gatherLegacyDeviceTargetAttrs(fromOp, resultSet);
+ return;
+ }
+
+ // Recursively resolve the referenced device into targets.
+ gatherDeviceTargets(deviceAffinityAttr.getDevice(), fromOp, resultSet);
+}
+
+void DeviceAnalysis::gatherAllExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet) {
+ SetVector<IREE::HAL::DeviceTargetAttr> deviceTargetSet;
+ gatherAllDeviceTargets(deviceTargetSet);
+ for (auto deviceTargetAttr : deviceTargetSet) {
+ deviceTargetAttr.getExecutableTargets(resultSet);
+ }
+}
+
+void DeviceAnalysis::gatherRequiredExecutableTargets(
+ Operation *forOp, SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet) {
+ // Get the affinity from the op or an ancestor. Note that there may be no
+ // affinity specified at all.
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(forOp);
+
+ // Gather the device targets that are referenced by the affinity.
+ SetVector<IREE::HAL::DeviceTargetAttr> deviceTargetSet;
+ gatherDeviceAffinityTargets(affinityAttr, forOp, deviceTargetSet);
+
+ // Add all executable targets on the device targets.
+ for (auto deviceTargetAttr : deviceTargetSet) {
+ resultSet.insert(deviceTargetAttr.getExecutableTargets().begin(),
+ deviceTargetAttr.getExecutableTargets().end());
+ }
+}
+
+void DeviceAnalysis::gatherRequiredExecutableTargets(
+ IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp,
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet) {
+ SetVector<IREE::HAL::DeviceTargetAttr> deviceTargetAttrs;
+ gatherDeviceAffinityTargets(affinityAttr, fromOp, deviceTargetAttrs);
+ for (auto deviceTargetAttr : deviceTargetAttrs) {
+ deviceTargetAttr.getExecutableTargets(resultSet);
+ }
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
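`DeviceAnalysis::gatherDeviceTargets` resolves a device reference with a worklist: symbol references are replaced by the referenced global's initial value, fallbacks by the global they name, selects by all of their child devices, and concrete targets are collected. Below is a simplified, self-contained sketch of that traversal under stated assumptions: `Kind`, `DeviceAttr`, and `resolveTargets` are hypothetical, attributes are modeled as plain structs, and a missing initial value is reported to the caller instead of consulting `hal.device.targets`. The sketch also records visited globals so that a malformed reference cycle terminates.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of the device attribute kinds handled by
// DeviceAnalysis::gatherDeviceTargets.
enum class Kind {
  Target,    // a concrete device target (like #hal.device.target)
  GlobalRef, // a reference to a !hal.device global
  Fallback,  // a fallback naming another device global
  Select,    // a choice among several child devices
};

struct DeviceAttr {
  Kind kind;
  std::string name;                 // target name or referenced global symbol
  std::vector<DeviceAttr> children; // only used by Kind::Select
};

// Collects every concrete target reachable from |root| into |resultSet|.
// Returns false if a referenced global has no known initial value; the caller
// then decides on a default.
bool resolveTargets(const DeviceAttr &root,
                    const std::map<std::string, DeviceAttr> &globalInitialValues,
                    std::set<std::string> &resultSet) {
  std::vector<DeviceAttr> worklist = {root};
  std::set<std::string> visitedGlobals;
  while (!worklist.empty()) {
    DeviceAttr attr = worklist.back();
    worklist.pop_back();
    switch (attr.kind) {
    case Kind::Target:
      resultSet.insert(attr.name);
      break;
    case Kind::GlobalRef:
    case Kind::Fallback: {
      // Already expanded this global; nothing new to add.
      if (!visitedGlobals.insert(attr.name).second) {
        break;
      }
      auto it = globalInitialValues.find(attr.name);
      if (it == globalInitialValues.end()) {
        return false; // no initial value: cannot resolve statically
      }
      worklist.push_back(it->second);
      break;
    }
    case Kind::Select:
      worklist.insert(worklist.end(), attr.children.begin(),
                      attr.children.end());
      break;
    }
  }
  return true;
}
```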
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h
new file mode 100644
index 0000000..e4f6245
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h
@@ -0,0 +1,104 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_
+
+#include <optional>
+
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceSet.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "llvm/ADT/DenseSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceAnalysis
+//===----------------------------------------------------------------------===//
+
+// Performs whole-program analysis of device traits (limits, configuration, etc)
+// and allows for queries against `!hal.device` values for known traits.
+//
+// Though safe to run at any time, this may not provide meaningful results
+// until after devices have been materialized and the program has been
+// converted into the HAL dialect.
+class DeviceAnalysis {
+public:
+ explicit DeviceAnalysis(Operation *rootOp);
+ ~DeviceAnalysis();
+
+ Explorer &getExplorer() { return explorer; }
+
+ // Runs analysis and populates the device traits map.
+ // May fail if analysis cannot be completed due to unsupported or unknown IR.
+ LogicalResult run();
+
+ // Returns a set of all !hal.device globals in the analyzed root op in the
+ // order they are declared in the root op.
+ ArrayRef<IREE::Util::GlobalOpInterface> getDeviceGlobals() {
+ return deviceGlobals;
+ }
+
+ // Returns a set of possible device globals of the given `!hal.device` value,
+ // if analyzed.
+ std::optional<SetVector<IREE::Util::GlobalOpInterface>>
+ lookupDeviceGlobals(Value deviceValue);
+
+ // Returns a set of possible targets of the given `!hal.device` value, if
+ // analyzed.
+ std::optional<DeviceSet> lookupDeviceTargets(Value deviceValue);
+
+ // Gathers all possible device targets in the root op.
+ // Ordering is undefined.
+ void
+ gatherAllDeviceTargets(SetVector<IREE::HAL::DeviceTargetAttr> &resultSet);
+
+ // Gathers the set of device targets potentially referenced by the given
+ // affinity. Targets are ordered by most likely to least likely.
+ void gatherDeviceAffinityTargets(
+ IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp,
+ SetVector<IREE::HAL::DeviceTargetAttr> &resultSet);
+
+ // Gathers all executable targets from all devices in the root op.
+ // This should generally be avoided in favor of the scoped
+ // gatherRequiredExecutableTargets.
+ void gatherAllExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet);
+
+ // Gathers all executable targets that may be required by the given host op.
+ // This should be called on the most narrowly scoped op possible as multiple
+ // devices may be used within the same function-like op and have different
+ // requirements. This may return a set with more targets than expected.
+ void gatherRequiredExecutableTargets(
+ Operation *forOp, SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet);
+
+ // Gathers all executable targets that may be required for the given affinity.
+ // This should be called on the most narrowly scoped op possible as multiple
+ // devices may be used within the same function-like op and have different
+ // requirements. This may return a set with more targets than expected.
+ void gatherRequiredExecutableTargets(
+ IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp,
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultSet);
+
+private:
+ // Recursively resolves the referenced device into targets.
+ void gatherDeviceTargets(Attribute rootAttr, Operation *fromOp,
+ SetVector<IREE::HAL::DeviceTargetAttr> &resultSet);
+
+ Explorer explorer;
+ llvm::BumpPtrAllocator allocator;
+ DFX::Solver solver;
+ std::optional<DeviceSet> defaultDeviceSet;
+ SmallVector<IREE::Util::GlobalOpInterface> deviceGlobals;
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp
new file mode 100644
index 0000000..b60e5e7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp
@@ -0,0 +1,139 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceSet.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+//===----------------------------------------------------------------------===//
+// DeviceSet
+//===----------------------------------------------------------------------===//
+
+DeviceSet::DeviceSet(ArrayAttr targetsAttr) {
+ for (auto targetAttr :
+ targetsAttr.getAsRange<IREE::HAL::DeviceTargetAttr>()) {
+ targetAttrs.insert(targetAttr);
+ }
+}
+
+DeviceSet::DeviceSet(const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs)
+ : targetAttrs(targetAttrs) {}
+
+DeviceSet::~DeviceSet() = default;
+
+std::optional<SmallVector<IREE::HAL::ExecutableTargetAttr>>
+DeviceSet::getExecutableTargets() const {
+ if (targetAttrs.empty()) {
+ return std::nullopt;
+ }
+ SetVector<IREE::HAL::ExecutableTargetAttr> resultAttrs;
+ for (auto targetAttr : targetAttrs) {
+ targetAttr.getExecutableTargets(resultAttrs);
+ }
+ return llvm::to_vector(resultAttrs);
+}
+
+template <typename AttrT>
+static std::optional<typename AttrT::ValueType> joinConfigAttrs(
+ const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs, StringRef name,
+ std::function<typename AttrT::ValueType(typename AttrT::ValueType,
+ typename AttrT::ValueType)>
+ join) {
+ if (targetAttrs.empty()) {
+ return std::nullopt;
+ }
+ std::optional<typename AttrT::ValueType> result;
+ for (auto targetAttr : targetAttrs) {
+ auto configAttr = targetAttr.getConfiguration();
+ if (!configAttr) {
+ return std::nullopt;
+ }
+ auto valueAttr = configAttr.getAs<AttrT>(name);
+ if (!valueAttr) {
+ return std::nullopt;
+ } else if (!result) {
+ result = valueAttr.getValue();
+ } else {
+ result = join(result.value(), valueAttr.getValue());
+ }
+ }
+ return result;
+}
+
+template <typename AttrT>
+static std::optional<StaticRange<typename AttrT::ValueType>>
+joinConfigStaticRanges(const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs,
+ StringRef name,
+ std::function<StaticRange<typename AttrT::ValueType>(
+ StaticRange<typename AttrT::ValueType>,
+ StaticRange<typename AttrT::ValueType>)>
+ join) {
+ if (targetAttrs.empty()) {
+ return std::nullopt;
+ }
+ std::optional<StaticRange<typename AttrT::ValueType>> result;
+ for (auto targetAttr : targetAttrs) {
+ auto configAttr = targetAttr.getConfiguration();
+ if (!configAttr) {
+ return std::nullopt;
+ }
+ auto valueAttr = configAttr.getAs<AttrT>(name);
+ if (!valueAttr) {
+ return std::nullopt;
+ } else if (!result) {
+ result = valueAttr.getValue();
+ } else {
+ result =
+ join(result.value(),
+ StaticRange<typename AttrT::ValueType>{valueAttr.getValue()});
+ }
+ }
+ return result;
+}
+
+bool DeviceSet::hasConfigAttrAny(StringRef name) const {
+ for (auto targetAttr : targetAttrs) {
+ if (auto configAttr = targetAttr.getConfiguration()) {
+ if (configAttr.get(name)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool DeviceSet::hasConfigAttrAll(StringRef name) const {
+ for (auto targetAttr : targetAttrs) {
+ auto configAttr = targetAttr.getConfiguration();
+ if (!configAttr || !configAttr.get(name)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::optional<bool> DeviceSet::getConfigAttrAnd(StringRef name) const {
+ return joinConfigAttrs<BoolAttr>(
+ targetAttrs, name, [](bool lhs, bool rhs) { return lhs && rhs; });
+}
+
+std::optional<bool> DeviceSet::getConfigAttrOr(StringRef name) const {
+ return joinConfigAttrs<BoolAttr>(
+ targetAttrs, name, [](bool lhs, bool rhs) { return lhs || rhs; });
+}
+
+std::optional<StaticRange<APInt>>
+DeviceSet::getConfigAttrRange(StringRef name) const {
+ return joinConfigStaticRanges<IntegerAttr>(
+ targetAttrs, name, [](StaticRange<APInt> lhs, StaticRange<APInt> rhs) {
+ return StaticRange<APInt>{
+ llvm::APIntOps::smin(lhs.min, rhs.min),
+ llvm::APIntOps::smax(lhs.max, rhs.max),
+ };
+ });
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
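`DeviceSet` answers configuration queries by joining one key across every device target in the set, giving up (returning `std::nullopt`) as soon as one configuration lacks the key. A minimal sketch of the range join used by `getConfigAttrRange`, assuming a hypothetical `Config` map of `int64_t` values in place of `DictionaryAttr`; `Range`, `joinConfigRange`, and the key names in the usage are likewise hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a device target's configuration dictionary.
using Config = std::map<std::string, int64_t>;

// Inclusive [min, max] range over every device's value for one key.
struct Range {
  int64_t min;
  int64_t max;
};

// Joins |name| across |configs|. Returns nullopt if there are no devices or if
// any device does not define |name| (the value is not statically known).
std::optional<Range> joinConfigRange(const std::vector<Config> &configs,
                                     const std::string &name) {
  if (configs.empty()) {
    return std::nullopt;
  }
  std::optional<Range> result;
  for (const Config &config : configs) {
    auto it = config.find(name);
    if (it == config.end()) {
      return std::nullopt;
    }
    if (!result) {
      result = Range{it->second, it->second};
    } else {
      result->min = std::min(result->min, it->second);
      result->max = std::max(result->max, it->second);
    }
  }
  return result;
}
```

`getConfigAttrAnd` and `getConfigAttrOr` follow the same pattern with `&&` and `||` as the join function; the code above compares `APInt` values with signed min/max rather than `int64_t`.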
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h
new file mode 100644
index 0000000..18f7d9f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h
@@ -0,0 +1,57 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_
+#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_
+
+#include <optional>
+
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "llvm/ADT/DenseSet.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+// Provides configuration queries over a set of devices.
+class DeviceSet {
+public:
+ DeviceSet() = default;
+ explicit DeviceSet(ArrayAttr targetsAttr);
+ explicit DeviceSet(const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs);
+ ~DeviceSet();
+
+ // Returns the executable targets that may be used by any device in the set,
+ // or nullopt if the set is empty.
+ std::optional<SmallVector<IREE::HAL::ExecutableTargetAttr>>
+ getExecutableTargets() const;
+
+ // Returns true if any device configuration has an attribute (such as a
+ // UnitAttr) named |name|.
+ bool hasConfigAttrAny(StringRef name) const;
+
+ // Returns true if every device configuration has an attribute (such as a
+ // UnitAttr) named |name|.
+ bool hasConfigAttrAll(StringRef name) const;
+
+ // Returns the AND of boolean attributes of |name| in all devices.
+ // Returns nullopt if the set is empty or any config does not define the key,
+ // indicating that the value is not statically known (runtime dynamic).
+ std::optional<bool> getConfigAttrAnd(StringRef name) const;
+
+ // Returns the OR of boolean attributes of |name| in all devices.
+ // Returns nullopt if the set is empty or any config does not define the key,
+ // indicating that the value is not statically known (runtime dynamic).
+ std::optional<bool> getConfigAttrOr(StringRef name) const;
+
+ // Returns the range of integer attributes of |name| in all devices.
+ // Returns nullopt if the set is empty or any config does not define the key,
+ // indicating that the value is not statically known (runtime dynamic).
+ std::optional<StaticRange<APInt>> getConfigAttrRange(StringRef name) const;
+
+private:
+ DenseSet<IREE::HAL::DeviceTargetAttr> targetAttrs;
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_
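The `DeviceSet` comments above describe how configuration queries aggregate across devices. As a rough sketch of those documented semantics only (the real class operates on `DeviceTargetAttr` configuration dictionaries), the code below models each device as a hypothetical `DeviceConfig` holding plain maps; `configAttrAnd`, `configAttrOr`, and `configAttrRange` are illustrative names. Empty-set behavior is an assumption of this sketch: AND/OR return their identity values and the range returns `nullopt`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one device target's configuration dictionary.
struct DeviceConfig {
  std::map<std::string, bool> bools;
  std::map<std::string, int64_t> ints;
};

// Inclusive [min, max] range of an integer attribute across devices.
struct Int64Range {
  int64_t min;
  int64_t max;
};

// AND of boolean attribute |name| across all devices; nullopt if any device
// leaves it undefined (not statically known).
std::optional<bool> configAttrAnd(const std::vector<DeviceConfig> &devices,
                                  const std::string &name) {
  bool result = true;
  for (const DeviceConfig &device : devices) {
    auto it = device.bools.find(name);
    if (it == device.bools.end()) return std::nullopt;
    result = result && it->second;
  }
  return result;
}

// OR of boolean attribute |name|; same nullopt rule as configAttrAnd.
std::optional<bool> configAttrOr(const std::vector<DeviceConfig> &devices,
                                 const std::string &name) {
  bool result = false;
  for (const DeviceConfig &device : devices) {
    auto it = device.bools.find(name);
    if (it == device.bools.end()) return std::nullopt;
    result = result || it->second;
  }
  return result;
}

// Range of integer attribute |name|; nullopt if any device leaves it
// undefined or if the set is empty.
std::optional<Int64Range>
configAttrRange(const std::vector<DeviceConfig> &devices,
                const std::string &name) {
  std::optional<Int64Range> range;
  for (const DeviceConfig &device : devices) {
    auto it = device.ints.find(name);
    if (it == device.ints.end()) return std::nullopt;
    if (!range) {
      range = Int64Range{it->second, it->second};
    } else {
      range->min = std::min(range->min, it->second);
      range->max = std::max(range->max, it->second);
    }
  }
  return range;
}
```

Returning `nullopt` as soon as one device lacks the key keeps the answer conservative: a partially known value is reported as unknown rather than guessed.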
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
index a17a100..705292b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
@@ -15,6 +15,92 @@
namespace {
+struct ConvertDeviceResolveAnyOp
+ : public OpConversionPattern<IREE::HAL::DeviceResolveOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (adaptor.getAffinity()) {
+ return rewriter.notifyMatchFailure(
+ resolveOp, "only resolving unspecified affinities to any device");
+ }
+
+ auto deviceType = rewriter.getType<IREE::HAL::DeviceType>();
+ Value device;
+ auto resolveDevice = [&]() {
+ if (!device) {
+ device = rewriter.create<IREE::HAL::DevicesGetOp>(
+ resolveOp.getLoc(), deviceType,
+ rewriter.create<arith::ConstantIndexOp>(resolveOp.getLoc(), 0));
+ }
+ return device;
+ };
+
+ SmallVector<Value> results;
+ for (auto resultType : resolveOp.getResultTypes()) {
+ if (isa<IREE::HAL::DeviceType>(resultType)) {
+ results.push_back(resolveDevice());
+ } else if (isa<IREE::HAL::AllocatorType>(resultType)) {
+ results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
+ resolveOp.getLoc(), resolveDevice()));
+ } else if (isa<IntegerType>(resultType)) {
+ results.push_back(rewriter.create<arith::ConstantIntOp>(
+ resolveOp.getLoc(), -1ll, 64));
+ }
+ }
+
+ rewriter.replaceOp(resolveOp, results);
+ return success();
+ }
+};
+
+struct ConvertDeviceResolveAffinityOp
+ : public OpConversionPattern<IREE::HAL::DeviceResolveOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto affinityAttr = adaptor.getAffinityAttr();
+ if (!affinityAttr) {
+ return rewriter.notifyMatchFailure(
+ resolveOp, "only resolving fully specified affinities");
+ }
+ auto flatDeviceAttr = dyn_cast<FlatSymbolRefAttr>(affinityAttr.getDevice());
+ if (!flatDeviceAttr) {
+ return rewriter.notifyMatchFailure(
+ resolveOp, "nested device references not yet supported");
+ }
+
+ auto deviceType = rewriter.getType<IREE::HAL::DeviceType>();
+ Value device;
+ auto resolveDevice = [&]() {
+ if (!device) {
+ device = rewriter.create<IREE::Util::GlobalLoadOp>(
+ resolveOp.getLoc(), deviceType, flatDeviceAttr.getValue(),
+ /*is_immutable=*/true);
+ }
+ return device;
+ };
+
+ SmallVector<Value> results;
+ for (auto resultType : resolveOp.getResultTypes()) {
+ if (isa<IREE::HAL::DeviceType>(resultType)) {
+ results.push_back(resolveDevice());
+ } else if (isa<IREE::HAL::AllocatorType>(resultType)) {
+ results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
+ resolveOp.getLoc(), resolveDevice()));
+ } else if (isa<IntegerType>(resultType)) {
+ results.push_back(rewriter.create<arith::ConstantIntOp>(
+ resolveOp.getLoc(), affinityAttr.getQueueMask(), 64));
+ }
+ }
+
+ rewriter.replaceOp(resolveOp, results);
+ return success();
+ }
+};
+
struct ConvertExecutableCalculateWorkgroupsOp
: public OpConversionPattern<IREE::HAL::ExecutableCalculateWorkgroupsOp> {
using OpConversionPattern::OpConversionPattern;
@@ -43,6 +129,10 @@
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
+ conversionTarget.addIllegalOp<IREE::HAL::DeviceResolveOp>();
+ patterns.insert<ConvertDeviceResolveAnyOp>(typeConverter, context);
+ patterns.insert<ConvertDeviceResolveAffinityOp>(typeConverter, context);
+
conversionTarget.addIllegalOp<IREE::HAL::ExecutableCalculateWorkgroupsOp>();
patterns.insert<ConvertExecutableCalculateWorkgroupsOp>(typeConverter,
context);
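Both `hal.device.resolve` patterns share one shape: each requested result type is mapped to a value, and the device is created lazily so that requesting both a device and its allocator materializes it only once. The sketch below models that with strings standing in for MLIR values; `ResultKind`, `Resolution`, and `resolve` are hypothetical names, and the emitted text only approximates the real IR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical result kinds, standing in for the !hal.device,
// !hal.allocator, and i64 queue affinity result types.
enum class ResultKind { Device, Allocator, QueueAffinity };

struct Resolution {
  std::vector<std::string> values; // textual stand-ins for produced values
  int deviceLoads = 0;             // times the device was materialized
};

// |deviceGlobal| == nullopt models an unspecified affinity (resolve "any").
Resolution resolve(const std::vector<ResultKind> &resultKinds,
                   const std::optional<std::string> &deviceGlobal,
                   int64_t queueMask) {
  Resolution resolution;
  std::optional<std::string> device;
  // Materialize the device lazily and at most once.
  auto getDevice = [&]() -> std::string {
    if (!device) {
      device = deviceGlobal ? "util.global.load immutable @" + *deviceGlobal
                            : std::string("hal.devices.get 0");
      ++resolution.deviceLoads;
    }
    return *device;
  };
  for (ResultKind kind : resultKinds) {
    switch (kind) {
    case ResultKind::Device:
      resolution.values.push_back(getDevice());
      break;
    case ResultKind::Allocator:
      resolution.values.push_back("hal.device.allocator<" + getDevice() + ">");
      break;
    case ResultKind::QueueAffinity:
      // Resolving to "any" device always allows every queue (-1).
      resolution.values.push_back(
          std::to_string(deviceGlobal ? queueMask : int64_t{-1}));
      break;
    }
  }
  return resolution;
}
```

Because the queue affinity case never calls `getDevice`, a resolve that only asks for the mask does not load a device at all.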
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
index 6056652..b38bbbe 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
@@ -15,7 +15,10 @@
iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
- ["pseudo_ops.mlir"],
+ [
+ "device_ops.mlir",
+ "pseudo_ops.mlir",
+ ],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
index 2a57a80..6757109 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "device_ops.mlir"
"pseudo_ops.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir
new file mode 100644
index 0000000..44fb128
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir
@@ -0,0 +1,75 @@
+// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s
+
+// CHECK-LABEL: @deviceResolveAnyDevice
+util.func public @deviceResolveAnyDevice() -> !hal.device {
+ // CHECK-DAG: %[[ANY_ORDINAL:.+]] = arith.constant 0
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %[[ANY_ORDINAL]] : !hal.device
+ %device = hal.device.resolve : !hal.device
+ // CHECK: util.return %[[DEVICE]]
+ util.return %device : !hal.device
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @deviceResolveDevice
+util.func public @deviceResolveDevice() -> !hal.device {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ %device = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device
+ // CHECK: util.return %[[DEVICE]]
+ util.return %device : !hal.device
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @deviceResolveDeviceQueueAffinityAny
+util.func public @deviceResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
+ %device, %queue_affinity_any = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device, i64
+ // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ util.return %device, %queue_affinity_any : !hal.device, i64
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @deviceResolveDeviceQueueAffinity45
+util.func public @deviceResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
+ %device, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64
+ // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
+ util.return %device, %queue_affinity_45 : !hal.device, i64
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @deviceResolveAllocator
+util.func public @deviceResolveAllocator() -> !hal.allocator {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
+ %allocator = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.allocator
+ // CHECK: util.return %[[ALLOCATOR]]
+ util.return %allocator : !hal.allocator
+}
+
+// -----
+
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @deviceResolveAllocatorQueueAffinity45
+util.func public @deviceResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
+ // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
+ %allocator, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64
+ // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]]
+ util.return %allocator, %queue_affinity_45 : !hal.allocator, i64
+}
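The `arith.constant 48` expected for `#hal.device.affinity<@device, [4, 5]>` follows from treating each queue ordinal as one bit of a 64-bit mask: `(1 << 4) | (1 << 5) = 16 + 32 = 48`. An affinity with no queue list resolves to `-1` (all queues), as `@deviceResolveDeviceQueueAffinityAny` checks. A minimal sketch of that arithmetic, assuming ordinals below 64; `queueAffinityMask` is a hypothetical helper, not the `DeviceAffinityAttr` API:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Builds a 64-bit queue affinity mask where each listed ordinal sets one bit.
// An empty list means "any queue": all bits set, i.e. -1 as a signed i64.
int64_t queueAffinityMask(const std::vector<int> &queues) {
  if (queues.empty()) return -1;
  uint64_t mask = 0;
  for (int queue : queues) {
    assert(queue >= 0 && queue < 64 && "queue ordinal must fit in the mask");
    mask |= uint64_t{1} << queue;
  }
  return static_cast<int64_t>(mask);
}
```

The same rule gives the `arith.constant 3` for queues `[0, 1]` in the `channel_ops.mlir` test.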
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
index f8efeba..a911de8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -69,32 +69,6 @@
return constantBuffer;
}
-IREE::VM::RodataOp
-createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp,
- OpBuilder &builder) {
- auto executableOp =
- binaryOp.getOperation()->getParentOfType<IREE::HAL::ExecutableOp>();
- auto insertPoint = builder.saveInsertionPoint();
- builder.setInsertionPoint(builder.getInsertionBlock()->getParentOp());
-
- std::string rodataName = sanitizeSymbolName(
- (executableOp.getName() + "_" + binaryOp.getName()).str());
- auto rodataOp = builder.create<IREE::VM::RodataOp>(
- binaryOp.getLoc(), rodataName, binaryOp.getData());
- rodataOp.setPrivate();
- if (binaryOp.getMimeType().has_value()) {
- rodataOp.setMimeTypeAttr(binaryOp.getMimeTypeAttr());
- }
-
- // TODO(benvanik): should these be page aligned? memcpy fastpath is fine for
- // now.
- rodataOp.setAlignmentAttr(builder.getI64IntegerAttr(16));
-
- builder.restoreInsertionPoint(insertPoint);
-
- return rodataOp;
-}
-
namespace {
class RemoveExecutableOpConversion
@@ -128,9 +102,15 @@
auto executableBinaryOp =
SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>(
createOp, createOp.getExecutableTarget());
- auto rodataOp = createExecutableBinaryRodata(executableBinaryOp, rewriter);
- auto executableRodata = rewriter.createOrFold<IREE::VM::ConstRefRodataOp>(
- createOp.getLoc(), rodataOp);
+ auto executableOp = executableBinaryOp.getOperation()
+ ->getParentOfType<IREE::HAL::ExecutableOp>();
+ std::string rodataName = sanitizeSymbolName(
+ (executableOp.getName() + "_" + executableBinaryOp.getName()).str());
+ auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>(
+ executableBinaryOp.getLoc(),
+ IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()),
+ rewriter.getStringAttr(rodataName), executableBinaryOp.getData(),
+ rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr());
// Get format string as a rodata blob.
auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>(
@@ -151,7 +131,7 @@
SmallVector<Value, 8> callOperands = {
adaptor.getDevice(),
executableFormatStr,
- executableRodata,
+ rodataOp,
constantBuffer,
};
callOperands.append(adaptor.getLayouts().begin(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
index 071643a..0596524 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
@@ -23,11 +23,6 @@
Value createPackedConstantBuffer(Location loc, ValueRange constantValues,
OpBuilder &builder);
-// Creates a vm.rodata containing the contents of a hal.executable.binary.
-IREE::VM::RodataOp
-createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp,
- OpBuilder &builder);
-
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOVM_PATTERNS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
index 9249b56..5dd5341 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
@@ -1,7 +1,5 @@
// RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s
-// CHECK: vm.rodata private @exe_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
-// CHECK: vm.rodata private @exe_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8>
hal.executable @exe {
hal.executable.binary @binary1 attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
@@ -24,7 +22,7 @@
) -> (!hal.executable, !hal.executable) {
// CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_
- // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe_binary1 : !vm.buffer
+ // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
// CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE1:.+]] = vm.call.variadic @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]], [%[[LAYOUT0]], %[[LAYOUT1]]]
@@ -32,7 +30,7 @@
%0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) layouts([%layout0, %layout1]) : !hal.executable
// CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_
- // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe_binary2 : !vm.buffer
+ // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
// CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE2:.+]] = vm.call.variadic @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]], [%[[LAYOUT1]], %[[LAYOUT0]]]
@@ -45,14 +43,12 @@
// -----
-// CHECK: vm.rodata private @exe1_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
hal.executable @exe1 {
hal.executable.binary @binary1 attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
format = "format"
}
}
-// CHECK: vm.rodata private @exe2_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8>
hal.executable @exe2 {
hal.executable.binary @binary2 attributes {
data = dense<[4, 5, 6, 7]> : vector<4xi8>,
@@ -67,17 +63,16 @@
%layout1: !hal.pipeline_layout
) -> (!hal.executable, !hal.executable) {
// CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe1_binary1 : !vm.buffer
+ // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe1_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
%0 = hal.executable.create device(%device : !hal.device) target(@exe1::@binary1) layouts([%layout0, %layout1]) : !hal.executable
// CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe2_binary2 : !vm.buffer
+ // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe2_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
%1 = hal.executable.create device(%device : !hal.device) target(@exe2::@binary2) layouts([%layout1, %layout0]) : !hal.executable
util.return %0, %1 : !hal.executable, !hal.executable
}
// -----
-// CHECK: vm.rodata private @exe_binary {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8>
hal.executable @exe {
hal.executable.binary @binary attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
@@ -95,7 +90,7 @@
%constant0: i32, %constant1: i32
) -> !hal.executable {
// CHECK-DAG: %[[FORMAT:.+]] = vm.rodata.inline "_utf8_format_
- // CHECK-DAG: %[[BINARY:.+]] = vm.const.ref.rodata @exe_binary : !vm.buffer
+ // CHECK-DAG: %[[BINARY:.+]] = vm.rodata.inline "exe_binary" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
// CHECK: %[[CONSTANTS:.+]] = vm.buffer.alloc %c12, %c16 : !vm.buffer
// CHECK-DAG: %[[INDEX0:.+]] = vm.const.i64 0
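With this change the executable binary is no longer hoisted into a module-level `vm.rodata` referenced through `vm.const.ref.rodata`. Instead it is emitted in place as a `vm.rodata.inline` named `<executable>_<binary>` with 16-byte alignment, as the updated CHECK lines show. The naming step might be sketched as below; `sanitizeName` is a hypothetical approximation of `sanitizeSymbolName`, assumed here to map any character other than ASCII letters, digits, and `_` to `_`. The real helper's rules may differ.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical approximation of sanitizeSymbolName: replaces every character
// that is not an ASCII letter, digit, or underscore with '_'.
std::string sanitizeName(const std::string &name) {
  std::string result = name;
  for (char &c : result) {
    if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_') c = '_';
  }
  return result;
}

// Joins executable and binary symbol names as "<executable>_<binary>" and
// sanitizes the result, following the pattern used in the conversion.
std::string executableRodataName(const std::string &executable,
                                 const std::string &binary) {
  return sanitizeName(executable + "_" + binary);
}
```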
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 74c0fc9..a58f32a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -23,24 +23,6 @@
namespace {
-// Returns the device queue affinity mask indicating which device queues the
-// operations are allowed to execute on.
-static Value buildQueueAffinityMask(Location loc,
- IREE::Stream::AffinityAttr affinityAttr,
- Value device, OpBuilder &builder) {
- // Try to find a specified affinity. This may be on the op provided or one of
- // its parent regions.
- if (auto queueAffinityAttr =
- llvm::dyn_cast_if_present<IREE::HAL::AffinityQueueAttr>(
- affinityAttr)) {
- return builder.create<arith::ConstantIntOp>(
- loc, queueAffinityAttr.getMask(), 64);
- }
-
- // No affinity specified; use default (any) affinity.
- return builder.create<arith::ConstantIntOp>(loc, -1, 64);
-}
-
struct ContextResolveOpPattern
: public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -50,33 +32,34 @@
auto resultTypes = llvm::to_vector(resolveOp.getResultTypes());
assert(!resultTypes.empty() && "must have at least one result");
- // TODO(multi-device): emit get with derived ordinal or lookup with attr.
- Value device =
- IREE::HAL::DeviceType::resolveAny(resolveOp.getLoc(), rewriter);
+ // Get the affinity from the op or an ancestor. Note that there may be no
+ // affinity specified at all.
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(resolveOp);
- SmallVector<Value> results;
- if (isa<IREE::HAL::DeviceType>(resultTypes[0])) {
- results.push_back(device);
- } else if (isa<IREE::HAL::AllocatorType>(resultTypes[0])) {
- results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>(
- resolveOp.getLoc(), device));
- } else {
- return rewriter.notifyMatchFailure(
- resolveOp, "unrecognized context resolve types for a HAL target");
- }
- if (resultTypes.size() > 1) {
- if (isa<IntegerType>(resultTypes[1])) {
- results.push_back(buildQueueAffinityMask(
- resolveOp.getLoc(), resolveOp.getAffinityAttr(), device, rewriter));
- } else {
- return rewriter.notifyMatchFailure(
- resolveOp,
- "unrecognized context resolve types for a HAL target (extended)");
- }
+ // If no affinity was specified then resolve as 'any'.
+ if (!affinityAttr) {
+ rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>(
+ resolveOp, resolveOp.getResultTypes(),
+ IREE::HAL::DeviceAffinityAttr{});
+ return success();
}
- rewriter.replaceOp(resolveOp, results);
- return success();
+ // We currently only handle HAL device affinities.
+ // We could make this an interface to select the device and allow users to
+ // provide their own affinities to convert to HAL. In the future users may
+ // also want to provide devices as function arguments post-initialization.
+ // For now we just have one way to specify device globals.
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) {
+ rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>(
+ resolveOp, resolveOp.getResultTypes(), deviceAffinityAttr);
+ return success();
+ }
+
+ resolveOp.emitOpError() << "failed to resolve affinity: only HAL device "
+ "affinities are supported";
+ return rewriter.notifyMatchFailure(
+ resolveOp, "only HAL device affinities are supported");
}
};
@@ -675,7 +658,7 @@
// make this difficult. For now we assume each stream region being lowered
// has a singular affinity that may itself reference multiple devices in the
// future but currently uniquely identifies a device.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp);
// Get the device handle we're executing against in this execution region.
// Note that this is a dynamic value: we have to treat the device as unknown
@@ -698,54 +681,67 @@
caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp));
});
- // Select the variant index.
- Value selectedIndex = buildIfElseTree(
- loc, caseExportOps.size(),
- [&](Location loc, size_t i, OpBuilder &builder) {
- auto exportOp = caseExportOps[i].second;
- auto variantOp =
- exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
- return variantOp.buildCondition(device, rewriter);
- },
- rewriter);
-
- // Allow each variant to define how it is dispatched.
- auto switchOp = rewriter.replaceOpWithNewOp<scf::IndexSwitchOp>(
- dispatchOp, TypeRange{}, selectedIndex, caseIndices,
- caseIndices.size());
- for (size_t i = 0; i < caseExportOps.size(); ++i) {
- auto entryPointAttr = caseExportOps[i].first;
- auto exportOp = caseExportOps[i].second;
- auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
- auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
-
+ auto recordDispatch = [&](SymbolRefAttr entryPointAttr,
+ IREE::HAL::ExecutableExportOp exportOp,
+ OpBuilder &builder) {
// Record push constants and buffer bindings.
recordParameters(loc, affinityAttr, device, commandBuffer, exportOp,
- dispatchOp, adaptor, caseBuilder);
+ dispatchOp, adaptor, builder);
// Dispatch with a target-specific workgroup count.
- auto caseWorkgroupCount = exportOp.calculateWorkgroupCount(
- loc, device, adaptor.getWorkload(), caseBuilder);
- Value executable = caseBuilder.create<IREE::HAL::ExecutableLookupOp>(
- loc, caseBuilder.getType<IREE::HAL::ExecutableType>(), device,
+ auto workgroupCount = exportOp.calculateWorkgroupCount(
+ loc, device, adaptor.getWorkload(), builder);
+ Value executable = builder.create<IREE::HAL::ExecutableLookupOp>(
+ loc, builder.getType<IREE::HAL::ExecutableType>(), device,
entryPointAttr.getRootReference().getValue());
- Value ordinal = caseBuilder.create<IREE::HAL::ExecutableExportOrdinalOp>(
- loc, caseBuilder.getIndexType(), entryPointAttr);
- auto flags = caseBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>(
+ Value ordinal = builder.create<IREE::HAL::ExecutableExportOrdinalOp>(
+ loc, builder.getIndexType(), entryPointAttr);
+ auto flags = builder.getAttr<IREE::HAL::DispatchFlagsAttr>(
IREE::HAL::DispatchFlags::None);
- caseBuilder.create<IREE::HAL::CommandBufferDispatchOp>(
- loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0],
- caseWorkgroupCount[1], caseWorkgroupCount[2], flags);
+ return builder.create<IREE::HAL::CommandBufferDispatchOp>(
+ loc, commandBuffer, executable, ordinal, workgroupCount[0],
+ workgroupCount[1], workgroupCount[2], flags);
+ };
- caseBuilder.create<scf::YieldOp>(loc);
+ // If there is only one variant we can emit that directly without a
+ // conditional check. The same result should occur later on but it saves
+ // a lot of IR during generation if we know we can avoid it.
+ if (caseExportOps.size() == 1) {
+ auto [entryPointAttr, exportOp] = caseExportOps.front();
+ rewriter.replaceOp(dispatchOp,
+ recordDispatch(entryPointAttr, exportOp, rewriter));
+ } else {
+ // Select the variant index.
+ Value selectedIndex = buildIfElseTree(
+ loc, caseExportOps.size(),
+ [&](Location loc, size_t i, OpBuilder &builder) {
+ auto exportOp = caseExportOps[i].second;
+ auto variantOp =
+ exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ return variantOp.buildCondition(device, rewriter);
+ },
+ rewriter);
+
+ // Allow each variant to define how it is dispatched.
+ auto switchOp = rewriter.create<scf::IndexSwitchOp>(
+ loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size());
+ for (size_t i = 0; i < caseExportOps.size(); ++i) {
+ auto [entryPointAttr, exportOp] = caseExportOps[i];
+ auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
+ auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
+ recordDispatch(entryPointAttr, exportOp, caseBuilder);
+ caseBuilder.create<scf::YieldOp>(loc);
+ }
+
+ // Fallback for no available variant. Today we just no-op as executable
+ // loading should have already failed.
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+ auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
+ defaultBuilder.create<scf::YieldOp>(loc);
+
+ rewriter.replaceOp(dispatchOp, switchOp);
}
- // Fallback for no available variant. Today we just no-op as executable
- // loading should have already failed.
- auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
- auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
- defaultBuilder.create<scf::YieldOp>(loc);
-
return success();
}
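The reworked dispatch lowering records a single-variant dispatch directly and only builds an `scf.index_switch` (with an empty default case) when there are several variants. That control flow can be sketched as follows, assuming variants are tested in list order and the first matching condition wins; the real lowering evaluates each variant's `buildCondition` through `buildIfElseTree`, and `Variant` and `selectDispatch` are hypothetical names.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical executable variant: a runtime condition (e.g. a device format
// query) and the export that is dispatched when the variant is selected.
struct Variant {
  std::string exportName;
  std::function<bool()> condition;
};

// One variant: dispatch unconditionally without evaluating its condition.
// Several variants: pick the first whose condition holds; nullopt models the
// no-op default case (executable loading should already have failed).
std::optional<std::string> selectDispatch(const std::vector<Variant> &variants) {
  if (variants.size() == 1) return variants.front().exportName;
  for (const Variant &variant : variants) {
    if (variant.condition()) return variant.exportName;
  }
  return std::nullopt;
}
```

Skipping the condition in the single-variant case mirrors the lowering's shortcut, which avoids generating query and switch IR that later passes would fold away anyway.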
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
index 8b628da..5c685c0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
@@ -22,7 +22,7 @@
namespace mlir::iree_compiler {
Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -34,7 +34,7 @@
std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(Operation *op,
OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -46,7 +46,7 @@
}
Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
@@ -58,7 +58,7 @@
std::tuple<Value, Value>
lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op);
auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>(
op->getLoc(),
TypeRange{
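Switching these helpers from `AffinityAttr::lookup` to `lookupOrDefault` still allows no affinity at all, which `ContextResolveOpPattern` then resolves as "any" device. The sketch below shows an ancestor walk with a fallback. The `Op` struct, `lookupAffinityOrDefault`, and the source of the default are all assumptions, since this diff does not show the body of `lookupOrDefault`.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical op model: an optional affinity attribute and a parent link.
struct Op {
  std::optional<std::string> affinity;
  const Op *parent = nullptr;
};

// Returns the affinity on |op| or its nearest ancestor that has one. If no
// ancestor specifies one, returns |defaultAffinity|, which may itself be
// nullopt (meaning "unspecified").
std::optional<std::string>
lookupAffinityOrDefault(const Op *op,
                        const std::optional<std::string> &defaultAffinity) {
  for (const Op *current = op; current; current = current->parent) {
    if (current->affinity) return current->affinity;
  }
  return defaultAffinity;
}
```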
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
index 3f88bd1..bb2108f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
@@ -1,15 +1,17 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @channel_create
// CHECK-SAME: () -> !hal.channel
util.func public @channel_create() -> !stream.channel {
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} : !hal.device
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant 3
// CHECK-DAG: %[[ID:.+]] = util.null : !util.buffer
// CHECK-DAG: %[[GROUP:.+]] = util.buffer.constant : !util.buffer = "group"
// CHECK-DAG: %[[DEFAULT:.+]] = arith.constant -1
// CHECK: %[[CHANNEL:.+]] = hal.channel.create device(%[[DEVICE]] : !hal.device) affinity(%[[AFFINITY]]) flags(0) id(%[[ID]]) group(%[[GROUP]]) rank(%[[DEFAULT]]) count(%[[DEFAULT]]) : !hal.channel
- %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) group("group") : !stream.channel
+ %channel = stream.channel.create on(#hal.device.affinity<@device, [0, 1]>) group("group") : !stream.channel
// CHECK: util.return %[[CHANNEL]]
util.return %channel : !stream.channel
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
index 7cdd991..941c15b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
@@ -3,12 +3,14 @@
// Today all memory control operations are ignored and we're just left with
// the normal sequential execution barriers.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdMemoryControl
util.func public @cmdMemoryControl(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
// CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]]
stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
// CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]]
@@ -22,13 +24,15 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdFill
util.func public @cmdFill(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c255_i32 = arith.constant 255 : i32
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
// CHECK-NEXT: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%arg0 : !hal.buffer)[%c0, %c128]
// CHECK-SAME: pattern(%c255_i32 : i32)
@@ -41,12 +45,14 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdCopy
util.func public @cmdCopy(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<staging>{%arg3}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<staging>{%arg3}) {
// CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: source(%arg0 : !hal.buffer)[%c0]
// CHECK-SAME: target(%arg2 : !hal.buffer)[%c0]
@@ -60,12 +66,14 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdCollective
util.func public @cmdCollective(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<transient>, %arg3: index, %arg4: !stream.channel) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<transient>{%arg3}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<transient>{%arg3}) {
// Out-of-place all-reduce:
// CHECK-NEXT: hal.command_buffer.collective
@@ -127,12 +135,14 @@
// than we actually need and guard a lot more work than we otherwise would need
// to.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdExecute
util.func public @cmdExecute(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
stream.cmd.concurrent {
// CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]]
stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
@@ -166,10 +176,6 @@
#executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
-#device_target_cpu = #hal.device.target<"llvm-cpu", [
- #executable_target_aarch64,
- #executable_target_x86_64
-]>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<4, storage_buffer>
@@ -219,6 +225,8 @@
}
}
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdDispatch
util.func public @cmdDispatch(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<external>, %arg3: index) -> !stream.timepoint {
%c0 = arith.constant 0 : index
@@ -229,7 +237,7 @@
%c5_i32 = arith.constant 5 : i32
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
// Switch for each executable variant by checking conditions and ranking:
// CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer>
// CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64")
@@ -297,6 +305,8 @@
// Tests conversion of streamable calls and function declarations.
// Expect a command buffer and a buffer + offset + length for each resource.
+util.global private @device : !hal.device
+
// CHECK: util.func private @cmdFunc(%arg0: !hal.command_buffer, %arg1: !hal.buffer, %arg2: index, %arg3: index, %arg4: i32, %arg5: !hal.buffer, %arg6: index, %arg7: index, %arg8: !custom.type, %arg9: !hal.buffer, %arg10: index, %arg11: index)
stream.cmd.func private @cmdFunc(%arg0[%arg1 for %arg2]: !stream.resource<*>, %arg3: i32, %arg4[%arg5 for %arg6]: !stream.resource<*>, %arg7: !custom.type, %arg8[%arg9 for %arg10]: !stream.resource<*>)
@@ -310,7 +320,7 @@
// CHECK-DAG: %[[SIZE2:.+]] = arith.constant 102
%size2 = arith.constant 102 : index
// CHECK: %[[COMMAND_BUFFER:.+]] = hal.command_buffer.create
- %timepoint = stream.cmd.execute with(%arg0 as %stream0: !stream.resource<external>{%size0}, %arg2 as %stream1: !stream.resource<external>{%size1}, %arg4 as %stream2: !stream.resource<external>{%size2}) {
+ %timepoint = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %stream0: !stream.resource<external>{%size0}, %arg2 as %stream1: !stream.resource<external>{%size1}, %arg4 as %stream2: !stream.resource<external>{%size2}) {
// CHECK: util.call @cmdFunc(%[[COMMAND_BUFFER]], %arg0, %c0, %[[SIZE0]], %arg1, %arg2, %c0, %[[SIZE1]], %arg3, %arg4, %c0, %[[SIZE2]]) :
// CHECK-SAME: (!hal.command_buffer, !hal.buffer, index, index, i32, !hal.buffer, index, index, !custom.type, !hal.buffer, index, index) -> ()
stream.cmd.call @cmdFunc(ro %stream0[%c0 for %size0], %arg1, rw %stream1[%c0 for %size1], %arg3, wo %stream2[%c0 for %size2]) : (!stream.resource<external>{%size0}, i32, !stream.resource<external>{%size1}, !custom.type, !stream.resource<external>{%size2}) -> ()
@@ -324,12 +334,14 @@
// appropriate queue affinity mask. The final affinity is the result of ORing
// the target affinities (0b01 | 0b10 = 0b11 = 3).
+util.global private @device : !hal.device
+
// CHECK-LABEL: @cmdExecuteAffinities
util.func public @cmdExecuteAffinities(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- %0 = stream.cmd.execute on(#hal.affinity.queue<[0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device, [0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
} => !stream.timepoint
// CHECK: hal.device.queue.execute
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
index 5d73951..c60d0da 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -1,19 +1,14 @@
// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s
-// CHECK-LABEL: @contextResolveAllocator
-util.func public @contextResolveAllocator() -> !hal.allocator {
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
- // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
- %allocator = stream.context.resolve : !hal.allocator
- // CHECK: util.return %[[ALLOCATOR]]
- util.return %allocator : !hal.allocator
-}
+// NOTE: the hal.device.resolve lowering in HAL-to-HAL does most of the work.
-// -----
+util.global private @device : !hal.device
-// CHECK-LABEL: @contextResolveDevice
-util.func public @contextResolveDevice() -> !hal.device {
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-LABEL: @contextResolveDefaultDevice
+util.func public @contextResolveDefaultDevice() -> !hal.device attributes {
+ stream.affinity = #hal.device.affinity<@device>
+} {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
%device = stream.context.resolve : !hal.device
// CHECK: util.return %[[DEVICE]]
util.return %device : !hal.device
@@ -21,22 +16,14 @@
// -----
-// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny
-util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
- // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
- %device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
- // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
- util.return %device, %queue_affinity_any : !hal.device, i64
-}
+util.global private @device : !hal.device
-// -----
-
-// CHECK-LABEL: @contextResolveDeviceQueueAffinity45
-util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-LABEL: @contextResolveAllocatorQueueAffinity45
+util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.device, !hal.allocator, i64) {
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
// CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
- %device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
- // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]]
- util.return %device, %queue_affinity_45 : !hal.device, i64
+ %device, %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, !hal.allocator, i64
+ // CHECK: util.return %[[DEVICE]], %[[ALLOCATOR]], %[[QUEUE_AFFINITY]]
+ util.return %device, %allocator, %queue_affinity_45 : !hal.device, !hal.allocator, i64
}
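The queue affinities in these lowering tests are plain 64-bit bitmasks, with one bit per queue ordinal:

- `#hal.device.affinity<@device, [0, 1]>` ORs to `0b01 | 0b10 = 0b11 = 3`, per the `@cmdExecuteAffinities` comment.
- `[4, 5]` becomes the `arith.constant 48 : i64` checked in `@contextResolveAllocatorQueueAffinity45`.
- An affinity with no queue list lowers to `affinity(%c-1_i64)`, i.e. any queue. The removed `#hal.affinity.queue<*>` parser also set the mask to `-1`.

Below is a minimal C++ sketch of that arithmetic, assuming 64-bit masks. `queueAffinityMask`, `joinQueueAffinities`, and `kAnyQueueAffinity` are hypothetical names chosen for illustration, not IREE APIs.

```cpp
#include <cstdint>
#include <initializer_list>

// Hypothetical constant (not an IREE API): "any queue" sets all 64 bits,
// which reads as -1 when stored in a signed i64 such as %c-1_i64.
constexpr std::int64_t kAnyQueueAffinity = -1;

// Hypothetical helper: sets one bit per queue ordinal, so {4, 5} -> 0b110000.
// Ordinals outside [0, 64) cannot be represented in the mask and are skipped.
constexpr std::int64_t queueAffinityMask(std::initializer_list<int> ordinals) {
  std::uint64_t mask = 0;
  for (int ordinal : ordinals) {
    if (ordinal >= 0 && ordinal < 64) {
      mask |= std::uint64_t{1} << ordinal;
    }
  }
  return static_cast<std::int64_t>(mask);
}

// Hypothetical helper: joining two queue affinities ORs their masks.
constexpr std::int64_t joinQueueAffinities(std::int64_t lhs, std::int64_t rhs) {
  return lhs | rhs;
}

// Compile-time checks mirroring the constants used in the tests above.
static_assert(queueAffinityMask({0, 1}) == 3, "0b01 | 0b10 == 0b11");
static_assert(queueAffinityMask({4, 5}) == 48, "bits 4 and 5: 16 + 32");
static_assert(joinQueueAffinities(0b01, 0b10) == 3, "ORing target affinities");
```

In this sketch, joining any mask with `kAnyQueueAffinity` yields `-1` again, because every bit is already set.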
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
index 1182ee4..efa925a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
@@ -1,44 +1,50 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @file_constant
// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer)
util.func public @file_constant(%buffer: !util.buffer) {
%c0 = arith.constant 0 : index
%c1088 = arith.constant 1088 : index
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK: = hal.ex.file.from_memory device(%[[DEVICE]] : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file
- %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file
+ %file = stream.file.constant on(#hal.device.affinity<@device>) %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file
util.return
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @file_read
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer)
util.func public @file_read(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource<variable>) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%c1088 = arith.constant 1088 : index
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK: %[[SIGNAL:.+]] = hal.fence.create
// CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0)
- %signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} => !stream.timepoint
+ %signal = stream.file.read on(#hal.device.affinity<@device>) await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} => !stream.timepoint
// CHECK: util.return %[[SIGNAL]]
util.return %signal : !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @file_write
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer)
util.func public @file_write(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource<variable>) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%c1088 = arith.constant 1088 : index
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK: %[[SIGNAL:.+]] = hal.fence.create
// CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0)
- %signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint
+ %signal = stream.file.write on(#hal.device.affinity<@device>) await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint
// CHECK: util.return %[[SIGNAL]]
util.return %signal : !stream.timepoint
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
index 6af93ee..09f7046 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
@@ -1,18 +1,22 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @resourceAlloc
util.func public @resourceAlloc(%arg0: index) -> !stream.resource<transient> {
// CHECK: %[[RET0:.+]] = hal.allocator.allocate
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
// CHECK-SAME: : !hal.buffer{%arg0}
- %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg0}
+ %0 = stream.resource.alloc uninitialized on(#hal.device.affinity<@device>) : !stream.resource<transient>{%arg0}
// CHECK: util.return %[[RET0]]
util.return %0 : !stream.resource<transient>
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @resourceAlloca
// CHECK-SAME: (%[[SIZE:.+]]: index)
util.func public @resourceAlloca(%size: index) -> (!stream.resource<transient>, !stream.timepoint) {
@@ -26,13 +30,15 @@
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
// CHECK-SAME: : !hal.buffer{%[[SIZE]]}
- %0:2 = stream.resource.alloca uninitialized : !stream.resource<transient>{%size} => !stream.timepoint
+ %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) : !stream.resource<transient>{%size} => !stream.timepoint
// CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]]
util.return %0#0, %0#1 : !stream.resource<transient>, !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @resourceAllocaAwait
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence)
util.func public @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource<transient>, !stream.timepoint) {
@@ -45,13 +51,15 @@
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
// CHECK-SAME: : !hal.buffer{%[[SIZE]]}
- %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
+ %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
// CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]]
util.return %0#0, %0#1 : !stream.resource<transient>, !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @resourceDealloca
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer)
util.func public @resourceDealloca(%size: index, %resource: !stream.resource<transient>) -> !stream.timepoint {
@@ -62,14 +70,14 @@
// CHECK-SAME: wait(%[[WAIT_FENCE]])
// CHECK-SAME: signal(%[[SIGNAL_FENCE]])
// CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer)
- %0 = stream.resource.dealloca %resource : !stream.resource<transient>{%size} => !stream.timepoint
+ %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) %resource : !stream.resource<transient>{%size} => !stream.timepoint
// CHECK: util.return %[[SIGNAL_FENCE]]
util.return %0 : !stream.timepoint
}
// -----
-// TODO(#9572): implement stream ordered allocations.
+util.global private @device : !hal.device
// CHECK-LABEL: @resourceDeallocaAwait
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence)
@@ -80,7 +88,7 @@
// CHECK-SAME: wait(%[[WAIT_FENCE]])
// CHECK-SAME: signal(%[[SIGNAL_FENCE]])
// CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer)
- %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource<transient>{%size} => !stream.timepoint
+ %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%await_timepoint) => %resource : !stream.resource<transient>{%size} => !stream.timepoint
// CHECK: util.return %[[SIGNAL_FENCE]]
util.return %0 : !stream.timepoint
}
@@ -97,6 +105,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @resourceTryMap
util.func public @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource<constant>) {
%c0 = arith.constant 0 : index
@@ -105,7 +115,7 @@
// CHECK-SAME: source(%arg0 : !util.buffer)[%c0, %c128]
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}SharingImmutable") : i1, !hal.
- %did_map, %mapping = stream.resource.try_map %arg0[%c0] : !util.buffer -> i1, !stream.resource<constant>{%c128}
+ %did_map, %mapping = stream.resource.try_map on(#hal.device.affinity<@device>) %arg0[%c0] : !util.buffer -> i1, !stream.resource<constant>{%c128}
// CHECK: util.return %[[DID_IMPORT]], %[[IMPORTED]]
util.return %did_map, %mapping : i1, !stream.resource<constant>
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
index 8a7b691..007f457 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
@@ -42,12 +42,14 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @timepointChainExternal
// CHECK-SAME: (%[[TIMEPOINT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence)
util.func public @timepointChainExternal(%timepoint: !stream.timepoint, %signal: !hal.fence) {
- // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[TIMEPOINT]]) signal(%[[SIGNAL]])
- stream.timepoint.chain_external %timepoint => (%signal : !hal.fence)
+ stream.timepoint.chain_external on(#hal.device.affinity<@device>) %timepoint => (%signal : !hal.fence)
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir
index 1dbcc24..5805f71 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir
@@ -1,5 +1,7 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tensorImportBuffer
util.func public @tensorImportBuffer(%arg0: !hal.buffer, %arg1: index) -> !stream.resource<external> {
%c20 = arith.constant 20 : index
@@ -10,7 +12,7 @@
// CHECK-SAME: minimum_length(%c20)
// CHECK-SAME: type(DeviceVisible)
// CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}")
- %0 = stream.tensor.import %arg0 : !hal.buffer -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
+ %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
// CHECK: util.return %arg0
util.return %0 : !stream.resource<external>
}
@@ -21,6 +23,8 @@
// when lowering into the stream dialect; here we only care about the storage
// buffer itself.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tensorImportBufferView
util.func public @tensorImportBufferView(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> {
%c20 = arith.constant 20 : index
@@ -32,23 +36,27 @@
// CHECK-SAME: minimum_length(%c20)
// CHECK-SAME: type(DeviceVisible)
// CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}")
- %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
+ %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
// CHECK: util.return %[[BUFFER]]
util.return %0 : !stream.resource<external>
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tensorExportBuffer
util.func public @tensorExportBuffer(%arg0: !stream.resource<external>, %arg1: index) -> !hal.buffer {
%c200 = arith.constant 200 : index
- %0 = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer
+ %0 = stream.tensor.export on(#hal.device.affinity<@device>) %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer
// CHECK: util.return %arg0 : !hal.buffer
util.return %0 : !hal.buffer
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tensorExportBufferView
util.func public @tensorExportBufferView(%arg0: !stream.resource<external>, %arg1: index) -> !hal.buffer_view {
%c200 = arith.constant 200 : index
@@ -60,7 +68,7 @@
// CHECK-SAME: type(%[[ELEMENT_TYPE]])
// CHECK-SAME: encoding(%[[ENCODING_TYPE]])
// CHECK-SAME: : !hal.buffer_view
- %0 = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view
+ %0 = stream.tensor.export on(#hal.device.affinity<@device>) %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view
// CHECK: util.return %[[VIEW]]
util.return %0 : !hal.buffer_view
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
index 1ff5f24..fe32e8b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -12,6 +12,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Parser/Parser.h"
// clang-format off: must be included after all LLVM/MLIR headers.
@@ -23,7 +24,7 @@
namespace mlir::iree_compiler::IREE::HAL {
//===----------------------------------------------------------------------===//
-// Enum utilities
+// Utilities
//===----------------------------------------------------------------------===//
template <typename AttrType>
@@ -85,270 +86,6 @@
}
//===----------------------------------------------------------------------===//
-// #hal.device.target<*>
-//===----------------------------------------------------------------------===//
-
-// static
-DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context,
- StringRef deviceID) {
- // TODO(benvanik): query default configuration from the target backend.
- return get(context, StringAttr::get(context, deviceID),
- DictionaryAttr::get(context), {});
-}
-
-// static
-Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) {
- StringAttr deviceIDAttr;
- DictionaryAttr configAttr;
- SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
- // `<"device-id"`
- if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) {
- return {};
- }
- // `, `
- if (succeeded(p.parseOptionalComma())) {
- if (succeeded(p.parseOptionalLSquare())) {
- // `[targets, ...]` (optional)
- do {
- IREE::HAL::ExecutableTargetAttr executableTargetAttr;
- if (failed(p.parseAttribute(executableTargetAttr)))
- return {};
- executableTargetAttrs.push_back(executableTargetAttr);
- } while (succeeded(p.parseOptionalComma()));
- if (failed(p.parseRSquare()))
- return {};
- } else {
- // `{config dict}` (optional)
- if (failed(p.parseAttribute(configAttr)))
- return {};
- // `, [targets, ...]` (optional)
- if (succeeded(p.parseOptionalComma())) {
- if (failed(p.parseLSquare()))
- return {};
- do {
- IREE::HAL::ExecutableTargetAttr executableTargetAttr;
- if (failed(p.parseAttribute(executableTargetAttr)))
- return {};
- executableTargetAttrs.push_back(executableTargetAttr);
- } while (succeeded(p.parseOptionalComma()));
- if (failed(p.parseRSquare()))
- return {};
- }
- }
- }
- // `>`
- if (failed(p.parseGreater())) {
- return {};
- }
- return get(p.getContext(), deviceIDAttr, configAttr, executableTargetAttrs);
-}
-
-void DeviceTargetAttr::print(AsmPrinter &p) const {
- auto &os = p.getStream();
- os << "<";
- p.printAttribute(getDeviceID());
- auto configAttr = getConfiguration();
- if (configAttr && !configAttr.empty()) {
- os << ", ";
- p.printAttribute(configAttr);
- }
- auto executableTargetAttrs = getExecutableTargets();
- if (!executableTargetAttrs.empty()) {
- os << ", [";
- llvm::interleaveComma(executableTargetAttrs, os,
- [&](auto executableTargetAttr) {
- p.printAttribute(executableTargetAttr);
- });
- os << "]";
- }
- os << ">";
-}
-
-std::string DeviceTargetAttr::getSymbolNameFragment() {
- return sanitizeSymbolName(getDeviceID().getValue().lower());
-}
-
-bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) {
- auto configAttr = getConfiguration();
- return configAttr && configAttr.get(name);
-}
-
-// static
-SmallVector<IREE::HAL::DeviceTargetAttr, 4>
-DeviceTargetAttr::lookup(Operation *op) {
- auto attrId = mlir::StringAttr::get(op->getContext(), "hal.device.targets");
- while (op) {
- auto targetsAttr = op->getAttrOfType<ArrayAttr>(attrId);
- if (targetsAttr) {
- SmallVector<IREE::HAL::DeviceTargetAttr, 4> result;
- for (auto targetAttr : targetsAttr) {
- result.push_back(llvm::cast<IREE::HAL::DeviceTargetAttr>(targetAttr));
- }
- return result;
- }
- op = op->getParentOp();
- }
- return {}; // No devices found; let caller decide what to do.
-}
-
-// Returns a set of all configuration attributes from all device targets with
-// a configuration set. Targets with no configuration set are ignored.
-static SmallVector<DictionaryAttr> lookupOptionalConfigAttrs(Operation *op) {
- auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op);
- if (targetAttrs.empty())
- return {};
- SmallVector<DictionaryAttr> configAttrs;
- for (auto targetAttr : targetAttrs) {
- auto configAttr = targetAttr.getConfiguration();
- if (configAttr)
- configAttrs.push_back(configAttr);
- }
- return configAttrs;
-}
-
-void DeviceTargetAttr::getExecutableTargets(
- SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) {
- for (auto attr : getExecutableTargets()) {
- resultAttrs.insert(attr);
- }
-}
-
-// Returns a set of all configuration attributes from all device targets.
-// Returns nullopt if any target is missing a configuration attribute.
-static std::optional<SmallVector<DictionaryAttr>>
-lookupRequiredConfigAttrs(Operation *op) {
- auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op);
- if (targetAttrs.empty())
- return std::nullopt;
- SmallVector<DictionaryAttr> configAttrs;
- for (auto targetAttr : targetAttrs) {
- auto configAttr = targetAttr.getConfiguration();
- if (!configAttr)
- return std::nullopt;
- configAttrs.push_back(configAttr);
- }
- return configAttrs;
-}
-
-template <typename AttrT>
-static std::optional<typename AttrT::ValueType> joinConfigAttrs(
- ArrayRef<DictionaryAttr> configAttrs, StringRef name,
- std::function<typename AttrT::ValueType(typename AttrT::ValueType,
- typename AttrT::ValueType)>
- join) {
- if (configAttrs.empty())
- return std::nullopt;
- auto firstValue = configAttrs.front().getAs<AttrT>(name);
- if (!firstValue)
- return std::nullopt;
- auto result = firstValue.getValue();
- for (auto configAttr : configAttrs.drop_front(1)) {
- auto value = configAttr.getAs<AttrT>(name);
- if (!value)
- return std::nullopt;
- result = join(result, value.getValue());
- }
- return result;
-}
-
-template <typename AttrT>
-static std::optional<StaticRange<typename AttrT::ValueType>>
-joinConfigStaticRanges(ArrayRef<DictionaryAttr> configAttrs, StringRef name,
- std::function<StaticRange<typename AttrT::ValueType>(
- StaticRange<typename AttrT::ValueType>,
- StaticRange<typename AttrT::ValueType>)>
- join) {
- if (configAttrs.empty())
- return std::nullopt;
- auto firstValue = configAttrs.front().getAs<AttrT>(name);
- if (!firstValue)
- return std::nullopt;
- StaticRange<typename AttrT::ValueType> result{firstValue.getValue()};
- for (auto configAttr : configAttrs.drop_front(1)) {
- auto value = configAttr.getAs<AttrT>(name);
- if (!value)
- return std::nullopt;
- result =
- join(result, StaticRange<typename AttrT::ValueType>{value.getValue()});
- }
- return result;
-}
-
-// static
-bool DeviceTargetAttr::lookupConfigAttrAny(Operation *op, StringRef name) {
- auto configAttrs = lookupOptionalConfigAttrs(op);
- if (configAttrs.empty())
- return false;
- for (auto configAttr : configAttrs) {
- if (configAttr.get(name))
- return true;
- }
- return false;
-}
-
-// static
-bool DeviceTargetAttr::lookupConfigAttrAll(Operation *op, StringRef name) {
- auto configAttrs = lookupRequiredConfigAttrs(op);
- if (!configAttrs)
- return false;
- for (auto configAttr : *configAttrs) {
- if (!configAttr.get(name))
- return false;
- }
- return true;
-}
-
-// static
-std::optional<bool> DeviceTargetAttr::lookupConfigAttrAnd(Operation *op,
- StringRef name) {
- auto configAttrs = lookupRequiredConfigAttrs(op);
- if (!configAttrs)
- return std::nullopt;
- return joinConfigAttrs<BoolAttr>(
- configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs && rhs; });
-}
-
-// static
-std::optional<bool> DeviceTargetAttr::lookupConfigAttrOr(Operation *op,
- StringRef name) {
- auto configAttrs = lookupRequiredConfigAttrs(op);
- if (!configAttrs)
- return std::nullopt;
- return joinConfigAttrs<BoolAttr>(
- configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs || rhs; });
-}
-
-// static
-std::optional<StaticRange<APInt>>
-DeviceTargetAttr::lookupConfigAttrRange(Operation *op, StringRef name) {
- auto configAttrs = lookupRequiredConfigAttrs(op);
- if (!configAttrs)
- return std::nullopt;
- return joinConfigStaticRanges<IntegerAttr>(
- configAttrs.value(), name,
- [](StaticRange<APInt> lhs, StaticRange<APInt> rhs) {
- return StaticRange<APInt>{
- llvm::APIntOps::smin(lhs.min, rhs.min),
- llvm::APIntOps::smax(lhs.max, rhs.max),
- };
- });
-}
-
-// static
-SmallVector<ExecutableTargetAttr, 4>
-DeviceTargetAttr::lookupExecutableTargets(Operation *op) {
- SmallVector<ExecutableTargetAttr, 4> resultAttrs;
- for (auto deviceTargetAttr : lookup(op)) {
- for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) {
- if (!llvm::is_contained(resultAttrs, executableTargetAttr)) {
- resultAttrs.push_back(executableTargetAttr);
- }
- }
- }
- return resultAttrs;
-}
-
-//===----------------------------------------------------------------------===//
// #hal.executable.target<*>
//===----------------------------------------------------------------------===//
@@ -674,25 +411,426 @@
}
//===----------------------------------------------------------------------===//
-// #hal.affinity.queue<*>
+// #hal.device.alias<*>
//===----------------------------------------------------------------------===//
// static
-Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) {
- int64_t mask = 0;
- // `<`
- if (failed(p.parseLess()))
+DeviceAliasAttr DeviceAliasAttr::get(MLIRContext *context, StringRef deviceID) {
+ return get(context, IREE::HAL::DeviceType::get(context),
+ StringAttr::get(context, deviceID), std::nullopt,
+ DictionaryAttr::get(context));
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.target<*>
+//===----------------------------------------------------------------------===//
+
+// static
+DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context,
+ StringRef deviceID) {
+ // TODO(benvanik): query default configuration from the target backend.
+ return get(context, StringAttr::get(context, deviceID),
+ DictionaryAttr::get(context), {});
+}
+
+// static
+Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) {
+ StringAttr deviceIDAttr;
+ DictionaryAttr configAttr;
+ SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ // `<"device-id"`
+ if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) {
return {};
- // `*` (any)
- if (succeeded(p.parseOptionalStar())) {
- mask = -1;
+ }
+ // `, `
+ if (succeeded(p.parseOptionalComma())) {
+ if (succeeded(p.parseOptionalLSquare())) {
+ // `[targets, ...]` (optional)
+ do {
+ IREE::HAL::ExecutableTargetAttr executableTargetAttr;
+ if (failed(p.parseAttribute(executableTargetAttr))) {
+ return {};
+ }
+ executableTargetAttrs.push_back(executableTargetAttr);
+ } while (succeeded(p.parseOptionalComma()));
+ if (failed(p.parseRSquare())) {
+ return {};
+ }
+ } else {
+ // `{config dict}` (optional)
+ if (failed(p.parseAttribute(configAttr))) {
+ return {};
+ }
+ // `, [targets, ...]` (optional)
+ if (succeeded(p.parseOptionalComma())) {
+ if (failed(p.parseLSquare())) {
+ return {};
+ }
+ do {
+ IREE::HAL::ExecutableTargetAttr executableTargetAttr;
+ if (failed(p.parseAttribute(executableTargetAttr))) {
+ return {};
+ }
+ executableTargetAttrs.push_back(executableTargetAttr);
+ } while (succeeded(p.parseOptionalComma()));
+ if (failed(p.parseRSquare())) {
+ return {};
+ }
+ }
+ }
+ }
+ // `>`
+ if (failed(p.parseGreater())) {
+ return {};
+ }
+ return get(p.getContext(), deviceIDAttr, configAttr, executableTargetAttrs);
+}
+
+void DeviceTargetAttr::print(AsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << "<";
+ p.printAttribute(getDeviceID());
+ auto configAttr = getConfiguration();
+ if (configAttr && !configAttr.empty()) {
+ os << ", ";
+ p.printAttribute(configAttr);
+ }
+ auto executableTargetAttrs = getExecutableTargets();
+ if (!executableTargetAttrs.empty()) {
+ os << ", [";
+ llvm::interleaveComma(executableTargetAttrs, os,
+ [&](auto executableTargetAttr) {
+ p.printAttribute(executableTargetAttr);
+ });
+ os << "]";
+ }
+ os << ">";
+}
+
+std::string DeviceTargetAttr::getSymbolNameFragment() {
+ std::string name = getDeviceID().getValue().lower();
+ if (auto ordinalAttr =
+ dyn_cast_if_present<IntegerAttr>(getConfigurationAttr("ordinal"))) {
+ name += "_";
+ name += std::to_string(ordinalAttr.getInt());
+ name += "_"; // can't have trailing numbers
+ }
+ return sanitizeSymbolName(name);
+}
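`getSymbolNameFragment` lowercases the device ID. When an `ordinal` configuration entry is present, it appends `_<ordinal>_`, so the fragment never ends in a digit. The sketch below shows only that string handling, written as standalone C++. `symbolNameFragment` and `sanitizeSymbolNameSketch` are hypothetical names. The sanitizer is a simplified assumption that maps anything other than alphanumerics and `_` to `_`; it is not IREE's `sanitizeSymbolName`.

```cpp
#include <cassert>
#include <cctype>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical, simplified sanitizer (assumption): maps every character that
// is not alphanumeric or '_' to '_'. IREE's sanitizeSymbolName may differ.
std::string sanitizeSymbolNameSketch(std::string name) {
  for (char &c : name) {
    if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_') c = '_';
  }
  return name;
}

// Lowercases the device ID and appends "_<ordinal>_" when an ordinal is set.
// The trailing '_' keeps the fragment from ending in a digit.
std::string symbolNameFragment(const std::string &deviceID,
                               std::optional<std::int64_t> ordinal) {
  std::string name;
  for (char c : deviceID) {
    name += static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  }
  if (ordinal) {
    name += "_";
    name += std::to_string(*ordinal);
    name += "_";
  }
  return sanitizeSymbolNameSketch(name);
}
```

As the header comment notes, the result is only pseudo-unique. Callers still need to deduplicate when two targets produce the same fragment.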
+
+bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) {
+ auto configAttr = getConfiguration();
+ return configAttr && configAttr.get(name);
+}
+
+Attribute DeviceTargetAttr::getConfigurationAttr(StringRef name) {
+ if (auto configAttr = getConfiguration()) {
+ return configAttr.get(name);
+ }
+ return {};
+}
+
+void DeviceTargetAttr::getExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) {
+ for (auto attr : getExecutableTargets()) {
+ resultAttrs.insert(attr);
+ }
+}
+
+void IREE::HAL::DeviceTargetAttr::printStatusDescription(
+ llvm::raw_ostream &os) const {
+ mlir::cast<Attribute>(this)->print(os, /*elideType=*/true);
+}
+
+// Produces a while-loop that enumerates each device available and tries to
+// match it against the target information. SCF is... not very wieldy, but this
+// is effectively:
+// ```
+// %device_count = hal.devices.count : index
+// %result:3 = scf.while(%i = 0, %match_ordinal = 0, %device = null) {
+// %is_null = util.cmp.eq %device, null : !hal.device
+// %in_bounds = arith.cmpi slt %i, %device_count : index
+// %continue_while = arith.andi %is_null, %in_bounds : i1
+// scf.condition(%continue_while) %i, %match_ordinal, %device
+// : index, index, !hal.device
+// } do {
+// %device_i = hal.devices.get %i : !hal.device
+// %device_match = <<buildDeviceMatch>>(%device_i)
+// %ordinal_match = arith.cmpi eq %match_ordinal, %device_ordinal : index
+// %is_match = arith.andi %device_match, %ordinal_match : i1
+// %try_device = arith.select %is_match, %device_i, null : !hal.device
+// %next_i = arith.addi %i, %c1 : index
+// %match_adv = arith.select %device_match, %c1, %c0 : index
+// %next_match_ordinal = arith.addi %match_ordinal, %match_adv : index
+// scf.yield %next_i, %next_match_ordinal, %try_device
+// : index, index, !hal.device
+// }
+// ```
+// Upon completion %result#2 contains the device (or null).
+// If the target configuration specifies an `ordinal`, matching devices are
+// skipped until the match with that ordinal is reached.
+Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration(
+ Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch,
+ OpBuilder &builder) const {
+ // Device configuration can control selection beyond just the match
+ // expression.
+ auto configAttr = getConfiguration();
+ IntegerAttr deviceOrdinalAttr =
+ configAttr ? configAttr.getAs<IntegerAttr>("ordinal") : IntegerAttr{};
+
+ // Defers to the target backend to build the device match or does a simple
+ // fallback for unregistered backends (usually for testing, but may be used
+ // as a way to bypass validation for out-of-tree experiments).
+ auto buildDeviceMatch = [&](Location loc, Value device,
+ OpBuilder &builder) -> Value {
+ // Ask the target backend to build the match expression. It may opt to
+ // let the default handling take care of things.
+ Value match = buildDeviceTargetMatch(loc, device, *this, builder);
+ if (match)
+ return match;
+ return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch(
+ loc, device, getDeviceID(), getExecutableTargets(), builder);
+ };
+
+ // Enumerate all devices and match the first one found (if any).
+ Type indexType = builder.getIndexType();
+ Type deviceType = builder.getType<IREE::HAL::DeviceType>();
+ Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value nullDevice = builder.create<IREE::Util::NullOp>(loc, deviceType);
+ Value deviceOrdinal = deviceOrdinalAttr
+ ? builder.create<arith::ConstantIndexOp>(
+ loc, deviceOrdinalAttr.getInt())
+ : c0;
+ Value deviceCount = builder.create<IREE::HAL::DevicesCountOp>(loc, indexType);
+ auto whileOp = builder.create<scf::WhileOp>(
+ loc,
+ TypeRange{
+ /*i=*/indexType,
+ /*match_ordinal=*/indexType,
+ /*device=*/deviceType,
+ },
+ ValueRange{
+ /*i=*/c0,
+ /*match_ordinal=*/c0,
+ /*device=*/nullDevice,
+ },
+ [&](OpBuilder &beforeBuilder, Location loc, ValueRange operands) {
+ Value isNull = beforeBuilder.create<IREE::Util::CmpEQOp>(
+ loc, operands[/*device=*/2], nullDevice);
+ Value inBounds = beforeBuilder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, operands[/*i=*/0], deviceCount);
+ Value continueWhile =
+ beforeBuilder.create<arith::AndIOp>(loc, isNull, inBounds);
+ beforeBuilder.create<scf::ConditionOp>(loc, continueWhile, operands);
+ },
+ [&](OpBuilder &afterBuilder, Location loc, ValueRange operands) {
+ // Check whether the device is a match.
+ Value device = afterBuilder.create<IREE::HAL::DevicesGetOp>(
+ loc, deviceType, operands[/*i=*/0]);
+ Value isDeviceMatch = buildDeviceMatch(loc, device, afterBuilder);
+
+ // Check whether this matching device's ordinal is the requested ordinal
+ // out of all matching devices.
+ Value isOrdinalMatch = afterBuilder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, operands[/*match_ordinal=*/1],
+ deviceOrdinal);
+ Value nextMatchOrdinal = afterBuilder.create<arith::AddIOp>(
+ loc, operands[/*match_ordinal=*/1],
+ afterBuilder.create<arith::SelectOp>(loc, isDeviceMatch, c1, c0));
+
+ // Break if the device and ordinal match, otherwise continue with null.
+ Value isMatch = afterBuilder.create<arith::AndIOp>(loc, isDeviceMatch,
+ isOrdinalMatch);
+ Value tryDevice = afterBuilder.create<arith::SelectOp>(
+ loc, isMatch, device, nullDevice);
+
+ Value nextI =
+ afterBuilder.create<arith::AddIOp>(loc, operands[/*i=*/0], c1);
+ afterBuilder.create<scf::YieldOp>(
+ loc, ValueRange{
+ /*i=*/nextI,
+ /*match_ordinal=*/nextMatchOrdinal,
+ /*device=*/tryDevice,
+ });
+ });
+ return whileOp.getResult(/*device=*/2);
+}
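The `scf.while` built above is hard to follow in builder form. As an illustration only, here is the same selection logic as plain C++ over a hypothetical list of device IDs. The sketch assumes exact string equality in place of `buildDeviceMatch`; the real match is delegated to the target backend or to the ID/format query fallback. `Device` and `selectDevice` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a runtime device; only its ID is modeled.
struct Device {
  std::string id;
};

// Walks devices in order like the scf.while loop. |matchOrdinal| counts the
// matching devices seen so far, and the loop stops at the match whose count
// equals |ordinal|. Returns the selected index, or std::nullopt for "null".
std::optional<std::size_t> selectDevice(const std::vector<Device> &devices,
                                        const std::string &targetID,
                                        std::size_t ordinal = 0) {
  std::size_t matchOrdinal = 0;
  for (std::size_t i = 0; i < devices.size(); ++i) {
    bool isDeviceMatch = devices[i].id == targetID;  // assumed exact match
    bool isOrdinalMatch = matchOrdinal == ordinal;
    if (isDeviceMatch && isOrdinalMatch) {
      return i;  // device is non-null, so the loop condition fails next
    }
    if (isDeviceMatch) {
      ++matchOrdinal;  // only matching devices advance the count
    }
  }
  return std::nullopt;  // ran out of devices: the result stays null
}
```

With devices `cpu, vulkan, cpu, vulkan`, the default ordinal 0 for `vulkan` selects index 1 and ordinal 1 selects index 3. Ordinal 2 yields null because only two devices match.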
+
+// static
+Value DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch(
+ Location loc, Value device, StringRef deviceIDPattern,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder) {
+ // Match first on the device ID, as that's the top-level filter.
+ Value idMatch = IREE::HAL::DeviceQueryOp::createI1(
+ loc, device, "hal.device.id", deviceIDPattern, builder);
+
+ // If there are executable formats defined we should check at least one of
+ // them is supported.
+ if (executableTargetAttrs.empty()) {
+ return idMatch; // just device ID
} else {
+ auto ifOp = builder.create<scf::IfOp>(loc, builder.getI1Type(), idMatch,
+ true, true);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ Value anyFormatMatch = buildExecutableFormatMatch(
+ loc, device, executableTargetAttrs, thenBuilder);
+ thenBuilder.create<scf::YieldOp>(loc, anyFormatMatch);
+ auto elseBuilder = ifOp.getElseBodyBuilder();
+ Value falseValue = elseBuilder.create<arith::ConstantIntOp>(loc, 0, 1);
+ elseBuilder.create<scf::YieldOp>(loc, falseValue);
+ return ifOp.getResult(0);
+ }
+}
+
+// static
+Value DeviceTargetAttr::buildExecutableFormatMatch(
+ Location loc, Value device,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder) {
+ if (executableTargetAttrs.empty())
+ return builder.create<arith::ConstantIntOp>(loc, 1, 1);
+ Value anyFormatMatch;
+ for (auto executableTargetAttr : executableTargetAttrs) {
+ Value formatMatch = IREE::HAL::DeviceQueryOp::createI1(
+ loc, device, "hal.executable.format",
+ executableTargetAttr.getFormat().getValue(), builder);
+ if (!anyFormatMatch) {
+ anyFormatMatch = formatMatch;
+ } else {
+ anyFormatMatch =
+ builder.create<arith::OrIOp>(loc, anyFormatMatch, formatMatch);
+ }
+ }
+ return anyFormatMatch;
+}
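The two helpers above combine into one rule. The device ID is the top-level filter. If executable targets are listed, at least one of their formats must also be supported, and an empty list matches trivially. The sketch below models the device's answers to the `hal.device.id` and `hal.executable.format` queries as a hypothetical `DeviceInfo` struct. It uses exact ID equality, whereas the real query receives a device ID *pattern*. The format strings in the checks are made-up test data.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of what a device reports for the `hal.device.id` and
// `hal.executable.format` queries.
struct DeviceInfo {
  std::string id;
  std::set<std::string> formats;
};

// OR of the per-format queries; an empty list trivially matches.
bool executableFormatMatch(const DeviceInfo &device,
                           const std::vector<std::string> &formats) {
  if (formats.empty()) return true;
  bool anyFormatMatch = false;
  for (const auto &format : formats) {
    anyFormatMatch = anyFormatMatch || device.formats.count(format) != 0;
  }
  return anyFormatMatch;
}

// Formats are only consulted when the ID matches, mirroring the scf.if that
// yields false in its else branch.
bool deviceIDAndExecutableFormatsMatch(const DeviceInfo &device,
                                       const std::string &deviceID,
                                       const std::vector<std::string> &formats) {
  if (device.id != deviceID) return false;
  return executableFormatMatch(device, formats);
}
```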
+
+//===----------------------------------------------------------------------===//
+// #hal.device.ordinal<*>
+//===----------------------------------------------------------------------===//
+
+void IREE::HAL::DeviceOrdinalAttr::printStatusDescription(
+ llvm::raw_ostream &os) const {
+ mlir::cast<Attribute>(this)->print(os, /*elideType=*/true);
+}
+
+Value IREE::HAL::DeviceOrdinalAttr::buildDeviceEnumeration(
+ Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch,
+ OpBuilder &builder) const {
+ return builder.create<IREE::HAL::DevicesGetOp>(
+ loc, getType(),
+ builder.create<arith::ConstantIndexOp>(loc, getOrdinal()));
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.fallback<*>
+//===----------------------------------------------------------------------===//
+
+void IREE::HAL::DeviceFallbackAttr::printStatusDescription(
+ llvm::raw_ostream &os) const {
+ mlir::cast<Attribute>(this)->print(os, /*elideType=*/true);
+}
+
+Value IREE::HAL::DeviceFallbackAttr::buildDeviceEnumeration(
+ Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch,
+ OpBuilder &builder) const {
+ // TODO(benvanik): hal.device.cast if needed - may need to look up the global
+ // to do it as we don't encode what the device is here in a way that is
+ // guaranteed to be consistent.
+ return builder.create<IREE::Util::GlobalLoadOp>(loc, getType(),
+ getName().getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.select<*>
+//===----------------------------------------------------------------------===//
+
+// static
+DeviceSelectAttr DeviceSelectAttr::get(MLIRContext *context,
+ ArrayRef<Attribute> values) {
+ return DeviceSelectAttr::get(context, IREE::HAL::DeviceType::get(context),
+ ArrayAttr::get(context, values));
+}
+
+// static
+LogicalResult
+DeviceSelectAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
+ Type type, ArrayAttr devicesAttr) {
+ if (devicesAttr.empty())
+ return emitError() << "must have at least one device to select";
+ for (auto deviceAttr : devicesAttr) {
+ if (!mlir::isa<IREE::HAL::DeviceAliasAttr>(deviceAttr) &&
+ !mlir::isa<IREE::HAL::DeviceInitializationAttrInterface>(deviceAttr)) {
+ return emitError() << "can only select between #hal.device.alias, "
+ "#hal.device.target, #hal.device.ordinal, "
+ "#hal.device.fallback, or other device "
+ "initialization attributes";
+ }
+ }
+ // TODO(benvanik): when !hal.device is parameterized we should check that the
+ // type is compatible with the entries.
+ return success();
+}
+
+void IREE::HAL::DeviceSelectAttr::printStatusDescription(
+ llvm::raw_ostream &os) const {
+ // TODO(benvanik): print something easier to read (newline per device, etc).
+ mlir::cast<Attribute>(this)->print(os, /*elideType=*/true);
+}
+
+// Builds a recursive nest of try-else blocks for each device specified.
+Value IREE::HAL::DeviceSelectAttr::buildDeviceEnumeration(
+ Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch,
+ OpBuilder &builder) const {
+ Type deviceType = builder.getType<IREE::HAL::DeviceType>();
+ Value nullDevice = builder.create<IREE::Util::NullOp>(loc, deviceType);
+ std::function<Value(ArrayRef<IREE::HAL::DeviceInitializationAttrInterface>,
+ OpBuilder &)>
+ buildTry;
+ buildTry =
+ [&](ArrayRef<IREE::HAL::DeviceInitializationAttrInterface> deviceAttrs,
+ OpBuilder &tryBuilder) -> Value {
+ auto deviceAttr = deviceAttrs.front();
+ Value tryDevice = deviceAttr.buildDeviceEnumeration(
+ loc, buildDeviceTargetMatch, tryBuilder);
+ if (deviceAttrs.size() == 1)
+ return tryDevice; // termination case
+ Value isNull =
+ tryBuilder.create<IREE::Util::CmpEQOp>(loc, tryDevice, nullDevice);
+ auto ifOp =
+ tryBuilder.create<scf::IfOp>(loc, deviceType, isNull, true, true);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ Value tryChainDevice = buildTry(deviceAttrs.drop_front(1), thenBuilder);
+ thenBuilder.create<scf::YieldOp>(loc, tryChainDevice);
+ auto elseBuilder = ifOp.getElseBodyBuilder();
+ elseBuilder.create<scf::YieldOp>(loc, tryDevice);
+ return ifOp.getResult(0);
+ };
+ SmallVector<IREE::HAL::DeviceInitializationAttrInterface> deviceAttrs(
+ getDevices().getAsRange<IREE::HAL::DeviceInitializationAttrInterface>());
+ return buildTry(deviceAttrs, builder);
+}
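`buildTry` nests one `scf.if` per remaining option. Each option is tried only if every earlier option produced null, and the last option's result is returned as-is, even if it is null. The sketch below mirrors that recursion in plain C++, with each option modeled as a hypothetical callable that returns a device name or `std::nullopt`. `DeviceOption` and `selectFirst` are illustrative names, not IREE APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical: each option "initializes" a device or yields nullopt (null).
using DeviceOption = std::function<std::optional<std::string>()>;

// Mirrors buildTry: evaluate the current option. If it is null and more
// options remain, recurse into the rest (the scf.if "then" branch);
// otherwise keep it (the "else" branch).
std::optional<std::string> selectFirst(const std::vector<DeviceOption> &options,
                                       std::size_t index = 0) {
  assert(index < options.size() && "verifier requires at least one device");
  std::optional<std::string> tryDevice = options[index]();
  if (index + 1 == options.size()) return tryDevice;  // termination case
  if (!tryDevice) return selectFirst(options, index + 1);
  return tryDevice;
}
```

Because later options sit inside the "then" region, they are never evaluated once an earlier option succeeds. That matters when an option performs runtime device queries.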
+
+//===----------------------------------------------------------------------===//
+// #hal.device.affinity<*>
+//===----------------------------------------------------------------------===//
+
+// static
+Attribute DeviceAffinityAttr::parse(AsmParser &p, Type type) {
+ // `<@device`
+ StringAttr deviceName;
+ int64_t queueMask = -1;
+ if (failed(p.parseLess()) || failed(p.parseSymbolName(deviceName)))
+ return {};
+ if (succeeded(p.parseOptionalComma())) {
// `[`queue_bit[, ...] `]`
+ queueMask = 0;
if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
int64_t i = 0;
if (failed(p.parseInteger(i)))
return failure();
- mask |= 1ll << i;
+ queueMask |= 1ll << i;
return success();
}))) {
return {};
@@ -701,19 +839,18 @@
// `>`
if (failed(p.parseGreater()))
return {};
- return get(p.getContext(), mask);
+ return get(p.getContext(), FlatSymbolRefAttr::get(deviceName), queueMask);
}
-void AffinityQueueAttr::print(AsmPrinter &p) const {
+void DeviceAffinityAttr::print(AsmPrinter &p) const {
auto &os = p.getStream();
os << "<";
- int64_t mask = getMask();
- if (mask == -1) {
- os << "*";
- } else {
- os << "[";
- for (int i = 0, j = 0; i < sizeof(mask) * 8; ++i) {
- if (mask & (1ll << i)) {
+ os << getDevice();
+ int64_t queueMask = getQueueMask();
+ if (queueMask != -1) {
+ os << ", [";
+ for (int i = 0, j = 0; i < sizeof(queueMask) * 8; ++i) {
+ if (queueMask & (1ll << i)) {
if (j++ > 0)
os << ", ";
os << i;
@@ -724,45 +861,169 @@
os << ">";
}
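The queue list in `#hal.device.affinity<@device, [bits...]>` is stored as a bitmask. The parser sets `1 << bit` for each listed bit, and the default of -1 (all bits set) means "any queue" and is omitted when printing. The sketch below reproduces that round trip in plain C++ outside of MLIR. `queueMaskFromBits`, `printAffinity`, and `kAnyQueue` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// -1 (all bits set) is the default "any queue" mask and is not printed.
constexpr std::int64_t kAnyQueue = -1;

// Builds a mask from queue bit indices, as the square-bracket list parser does.
std::int64_t queueMaskFromBits(const std::vector<int> &bits) {
  std::int64_t mask = 0;
  for (int bit : bits) {
    assert(bit >= 0 && bit < 64 && "queue bit out of range");
    mask |= std::int64_t{1} << bit;
  }
  return mask;
}

// Formats like the printer: "<@dev>" for any queue, "<@dev, [0, 2]>" otherwise.
std::string printAffinity(const std::string &device, std::int64_t queueMask) {
  std::string s = "<@" + device;
  if (queueMask != kAnyQueue) {
    s += ", [";
    for (int i = 0, j = 0; i < 64; ++i) {
      if (queueMask & (std::int64_t{1} << i)) {
        if (j++ > 0) s += ", ";
        s += std::to_string(i);
      }
    }
    s += "]";
  }
  return s + ">";
}
```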
-bool AffinityQueueAttr::isExecutableWith(
+bool DeviceAffinityAttr::isExecutableWith(
IREE::Stream::AffinityAttr other) const {
if (!other)
return true;
- // Only compatible with other queue affinities today. When we extend the
- // attributes to specify device targets we'd want to check here.
- auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other);
- if (!otherQueueAttr)
+ // Only compatible with the same exact devices today. We could support a
+ // peering model to allow operations to move across devices in a peered set
+ // but that may be best done at higher levels and avoided once we get to the
+ // "are these the same device" stage.
+ auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other);
+ if (!otherAffinityAttr || getDevice() != otherAffinityAttr.getDevice())
return false;
// If this affinity is a subset of the target affinity then it can execute
// with it.
- if ((getMask() & otherQueueAttr.getMask()) == getMask())
+ if ((getQueueMask() & otherAffinityAttr.getQueueMask()) == getQueueMask())
return true;
// Otherwise not compatible.
return false;
}
IREE::Stream::AffinityAttr
-AffinityQueueAttr::joinOR(IREE::Stream::AffinityAttr other) const {
+DeviceAffinityAttr::joinOR(IREE::Stream::AffinityAttr other) const {
if (!other)
return *this;
if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) {
return nullptr;
}
- auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other);
- return AffinityQueueAttr::get(getContext(),
- getMask() | otherQueueAttr.getMask());
+ auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other);
+ return DeviceAffinityAttr::get(getContext(), getDevice(),
+ getQueueMask() |
+ otherAffinityAttr.getQueueMask());
}
IREE::Stream::AffinityAttr
-AffinityQueueAttr::joinAND(IREE::Stream::AffinityAttr other) const {
+DeviceAffinityAttr::joinAND(IREE::Stream::AffinityAttr other) const {
if (!other)
return *this;
if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) {
return nullptr;
}
- auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other);
- return AffinityQueueAttr::get(getContext(),
- getMask() & otherQueueAttr.getMask());
+ auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other);
+ return DeviceAffinityAttr::get(getContext(), getDevice(),
+ getQueueMask() &
+ otherAffinityAttr.getQueueMask());
+}
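In the methods above, `isExecutableWith` requires the same device and treats the queue mask as a set: `this` is executable with `other` when `(mask & otherMask) == mask`, that is, when its queues are a subset. The joins then union (`joinOR`) or intersect (`joinAND`) the masks. The sketch below models this with a hypothetical value-type `Affinity` struct. It makes two simplifying assumptions: null affinities are omitted, and the `canExecuteTogether` gate used by the real joins is replaced by a plain same-device check.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical value-type stand-in for #hal.device.affinity.
struct Affinity {
  std::string device;
  std::int64_t queueMask = -1;  // -1: any queue
};

// Same device, and this queue set is a subset of |other|'s queue set.
bool isExecutableWith(const Affinity &self, const Affinity &other) {
  if (self.device != other.device) return false;
  return (self.queueMask & other.queueMask) == self.queueMask;
}

// Simplified joins: only the same-device requirement is checked here
// (assumption; the real code calls canExecuteTogether first).
std::optional<Affinity> joinOR(const Affinity &a, const Affinity &b) {
  if (a.device != b.device) return std::nullopt;
  return Affinity{a.device, a.queueMask | b.queueMask};
}

std::optional<Affinity> joinAND(const Affinity &a, const Affinity &b) {
  if (a.device != b.device) return std::nullopt;
  return Affinity{a.device, a.queueMask & b.queueMask};
}
```

Because "any queue" is all bits set, a single-queue affinity is always executable with a wildcard affinity on the same device, but not the reverse.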
+
+bool DeviceAffinityAttr::isLegalToInline(Operation *inlineSite,
+ Operation *inlinable) const {
+ // Look up the affinity of the inlining target site and only allow inlining if
+ // it matches exactly. We could make a decision as to whether we allow
+ // inlining when queues are subsets (so if the target site allows any queue
+ // and the inlinable allows queue 2 then allow, etc). In the future we may
+ // want to allow util.scope restrictions within the inline target to keep
+ // queue specification tighter but today most queue masks are wildcarded
+ // anyway.
+ auto targetAffinityAttr = IREE::Stream::AffinityAttr::lookup(inlineSite);
+ return *this == targetAffinityAttr;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.promise<*>
+//===----------------------------------------------------------------------===//
+
+// static
+Attribute DevicePromiseAttr::parse(AsmParser &p, Type type) {
+ // `<@device`
+ StringAttr deviceName;
+ int64_t queueMask = -1;
+ if (failed(p.parseLess()) || failed(p.parseSymbolName(deviceName)))
+ return {};
+ if (succeeded(p.parseOptionalComma())) {
+ // `[`queue_bit[, ...] `]`
+ queueMask = 0;
+ if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t i = 0;
+ if (failed(p.parseInteger(i)))
+ return failure();
+ queueMask |= 1ll << i;
+ return success();
+ }))) {
+ return {};
+ }
+ }
+ // `>`
+ if (failed(p.parseGreater()))
+ return {};
+ return get(p.getContext(), deviceName, queueMask);
+}
+
+void DevicePromiseAttr::print(AsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << "<@";
+ os << getDevice().getValue();
+ int64_t queueMask = getQueueMask();
+ if (queueMask != -1) {
+ os << ", [";
+ for (int i = 0, j = 0; i < sizeof(queueMask) * 8; ++i) {
+ if (queueMask & (1ll << i)) {
+ if (j++ > 0)
+ os << ", ";
+ os << i;
+ }
+ }
+ os << "]";
+ }
+ os << ">";
+}
+
+bool DevicePromiseAttr::isExecutableWith(
+ IREE::Stream::AffinityAttr other) const {
+ if (!other)
+ return true;
+ // Only compatible with the same exact devices today. We could support a
+ // peering model to allow operations to move across devices in a peered set
+ // but that may be best done at higher levels and avoided once we get to the
+ // "are these the same device" stage.
+ auto otherPromiseAttr = llvm::dyn_cast_if_present<DevicePromiseAttr>(other);
+ if (!otherPromiseAttr || getDevice() != otherPromiseAttr.getDevice())
+ return false;
+ // If this affinity is a subset of the target affinity then it can execute
+ // with it.
+ if ((getQueueMask() & otherPromiseAttr.getQueueMask()) == getQueueMask())
+ return true;
+ // Otherwise not compatible.
+ return false;
+}
+
+IREE::Stream::AffinityAttr
+DevicePromiseAttr::joinOR(IREE::Stream::AffinityAttr other) const {
+ if (!other)
+ return *this;
+ if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) {
+ return nullptr;
+ }
+ auto otherPromiseAttr = llvm::dyn_cast_if_present<DevicePromiseAttr>(other);
+ return DevicePromiseAttr::get(getContext(), getDevice(),
+ getQueueMask() |
+ otherPromiseAttr.getQueueMask());
+}
+
+IREE::Stream::AffinityAttr
+DevicePromiseAttr::joinAND(IREE::Stream::AffinityAttr other) const {
+ if (!other)
+ return *this;
+ if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) {
+ return nullptr;
+ }
+ auto otherPromiseAttr = llvm::dyn_cast_if_present<DevicePromiseAttr>(other);
+ return DevicePromiseAttr::get(getContext(), getDevice(),
+ getQueueMask() &
+ otherPromiseAttr.getQueueMask());
+}
+
+bool DevicePromiseAttr::isLegalToInline(Operation *inlineSite,
+ Operation *inlinable) const {
+ // Look up the affinity of the inlining target site and only allow inlining if
+ // it matches exactly. We could make a decision as to whether we allow
+ // inlining when queues are subsets (so if the target site allows any queue
+ // and the inlinable allows queue 2 then allow, etc). In the future we may
+ // want to allow util.scope restrictions within the inline target to keep
+ // queue specification tighter but today most queue masks are wildcarded
+ // anyway.
+ auto targetAffinityAttr = IREE::Stream::AffinityAttr::lookup(inlineSite);
+ return *this == targetAffinityAttr;
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 9d85020..9511cf8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -477,95 +477,6 @@
"HAL binding array attribute">;
//===----------------------------------------------------------------------===//
-// #hal.device.target<*>
-//===----------------------------------------------------------------------===//
-
-def HAL_DeviceTargetAttr :
- AttrDef<HAL_Dialect, "DeviceTarget"> {
- let mnemonic = "device.target";
- let summary = [{generic device target specification}];
- let description = [{
- Specifies the properties of a target runtime device.
- Target devices are specified with a canonical identifier matching those used
- by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support
- several target executable formats specified with `#hal.executable.target`.
- An optional configuration dictionary allows for overriding backend defaults.
-
- Example:
- ```mlir
- #hal.device.target<"llvm-cpu", {
- device_configuration = ...
- }, [
- #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">,
- #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
- ]>
- ```
- }];
- let parameters = (ins
- AttrParameter<"StringAttr", "">:$deviceID,
- AttrParameter<"DictionaryAttr", "">:$configuration,
- ArrayRefParameter<"ExecutableTargetAttr", "">:$executable_targets
- );
- let builders = [
- AttrBuilder<(ins "StringRef":$deviceID)>,
- ];
-
- let extraClassDeclaration = [{
- // Returns a symbol-compatible name that pseudo-uniquely identifies this
- // target. Callers must perform deduplication when required.
- std::string getSymbolNameFragment();
-
- // Returns true if there's an attribute with the given name in the
- // configuration dictionary.
- bool hasConfigurationAttr(StringRef name);
-
- // Returns a list of target devices that may be active for the given
- // operation. This will recursively walk parent operations until one with
- // the `hal.device.targets` attribute is found.
- static SmallVector<DeviceTargetAttr, 4> lookup(Operation *op);
-
- // Returns true if there is any UnitAttr with |name| in any device
- // configuration for the given |op|.
- static bool lookupConfigAttrAny(Operation *op, StringRef name);
-
- // Returns true if all device configurations found for the given |op| have
- // a UnitAttr with |name|.
- static bool lookupConfigAttrAll(Operation *op, StringRef name);
-
- // Returns the AND of boolean attributes of |name| in all device
- // configurations found for the given |op|.
- // Returns nullopt if any config does not have the key defined indicating
- // that it's not statically known/runtime dynamic.
- static std::optional<bool>
- lookupConfigAttrAnd(Operation *op, StringRef name);
-
- // Returns the OR of boolean attributes of |name| in all device
- // configurations found for the given |op|.
- // Returns nullopt if any config does not have the key defined indicating
- // that it's not statically known/runtime dynamic.
- static std::optional<bool>
- lookupConfigAttrOr(Operation *op, StringRef name);
-
- // Returns the range of integer attributes of |name| in all device
- // configurations found for the given |op|.
- // Returns nullopt if any config does not have the key defined indicating
- // that it's not statically known/runtime dynamic.
- static std::optional<StaticRange<APInt>>
- lookupConfigAttrRange(Operation *op, StringRef name);
-
- // Returns zero or more executable targets that this device supports.
- void getExecutableTargets(
- SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs);
-
- // Returns a list of all target executable configurations that may be
- // required for the given operation.
- static SmallVector<IREE::HAL::ExecutableTargetAttr, 4>
- lookupExecutableTargets(Operation *op);
- }];
- let hasCustomAssemblyFormat = 1;
-}
-
-//===----------------------------------------------------------------------===//
// #hal.executable.target<*>
//===----------------------------------------------------------------------===//
@@ -743,41 +654,318 @@
}
//===----------------------------------------------------------------------===//
-// #hal.affinity.queue<*>
+// #hal.device.alias<*>
//===----------------------------------------------------------------------===//
-def HAL_AffinityQueueAttr : AttrDef<HAL_Dialect, "AffinityQueue", [
+def HAL_DeviceAliasAttr : AttrDef<HAL_Dialect, "DeviceAlias", [
+ TypedAttrInterface,
+]> {
+ let mnemonic = "device.alias";
+ let summary = [{device target named alias}];
+ let description = [{
+ Specifies a device target by named alias whose configuration will be
+ expanded based on compiler configuration and flags. Any configuration
+ provided on the alias overrides the defaults produced by that expansion.
+
+ Example:
+ ```mlir
+ // Default `vulkan` device:
+ #hal.device.alias<"vulkan"> : !hal.device
+ // Default `vulkan` device with configuration overrides:
+ #hal.device.alias<"vulkan", {
+ device_config = 123 : index
+ }> : !hal.device
+ // The 4th default `vulkan` device detected at runtime (zero-based ordinal 3):
+ #hal.device.alias<"vulkan"[3]> : !hal.device
+ ```
+ }];
+
+ let parameters = (ins
+ AttributeSelfTypeParameter<"">:$type,
+ AttrParameter<"StringAttr", "">:$deviceID,
+ OptionalParameter<"std::optional<int64_t>", "">:$ordinal,
+ OptionalParameter<"DictionaryAttr", "">:$configuration
+ );
+
+ let builders = [
+ AttrBuilder<(ins "StringRef":$deviceID)>,
+ ];
+
+ let assemblyFormat = [{
+ `<`
+ $deviceID
+ `` (`[` $ordinal^ `]`)?
+ (`,` $configuration^)?
+ `>`
+ }];
+
+ let extraClassDeclaration = [{
+ Type getType() { return IREE::HAL::DeviceType::get(getContext()); }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.target<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceTargetAttr : AttrDef<HAL_Dialect, "DeviceTarget", [
+ DeclareAttrInterfaceMethods<HAL_DeviceInitializationAttrInterface>,
+]> {
+ let mnemonic = "device.target";
+ let summary = [{generic device target specification}];
+ let description = [{
+ Specifies the properties of a target runtime device.
+ Target devices are specified with a canonical identifier matching those used
+ by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support
+ several target executable formats specified with `#hal.executable.target`.
+ An optional configuration dictionary allows for overriding backend defaults.
+
+ If used to initialize a device global, returns the first device matching the
+ target requirements or null if no devices match. An optional `ordinal`
+ configuration entry selects the N-th (zero-based) matching device instead
+ and is used to choose between multiple homogeneous devices.
+
+ Example:
+ ```mlir
+ #hal.device.target<"llvm-cpu", {
+ device_configuration = ...
+ }, [
+ #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">,
+ #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
+ ]> : !hal.device
+ ```
+ }];
+ let parameters = (ins
+ AttrParameter<"StringAttr", "">:$deviceID,
+ AttrParameter<"DictionaryAttr", "">:$configuration,
+ ArrayRefParameter<"ExecutableTargetAttr", "">:$executable_targets
+ );
+ let builders = [
+ AttrBuilder<(ins "StringRef":$deviceID)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Type getType() { return IREE::HAL::DeviceType::get(getContext()); }
+
+ // Returns a symbol-compatible name that pseudo-uniquely identifies this
+ // target. Callers must perform deduplication when required.
+ std::string getSymbolNameFragment();
+
+ // Returns true if there's an attribute with the given name in the
+ // configuration dictionary.
+ bool hasConfigurationAttr(StringRef name);
+ // Returns the configuration attribute with the given name if found.
+ Attribute getConfigurationAttr(StringRef name);
+
+ // Returns zero or more executable targets that this device supports.
+ void getExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs);
+
+ // Builds an expression that returns an i1 indicating whether the given
+ // |device| matches the device ID string pattern and executable target
+ // requirements.
+ static Value buildDeviceIDAndExecutableFormatsMatch(
+ Location loc, Value device, StringRef deviceIDPattern,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder);
+
+ // Builds a match expression that returns an i1 indicating whether the given
+ // |device| supports any one of the |executableTargetAttrs|.
+ static Value buildExecutableFormatMatch(
+ Location loc, Value device,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder);
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.ordinal<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceOrdinalAttr : AttrDef<HAL_Dialect, "DeviceOrdinal", [
+ DeclareAttrInterfaceMethods<HAL_DeviceInitializationAttrInterface>,
+]> {
+ let mnemonic = "device.ordinal";
+ let summary = [{specifies a device by runtime registration ordinal}];
+ let description = [{
+ Represents the device registered with the runtime at the given ordinal,
+ where ordinal 0 is the first device registered. Returns null during
+ initialization if the device ordinal is out of range.
+ }];
+
+ let parameters = (ins
+ AttributeSelfTypeParameter<"">:$type,
+ AttrParameter<"int64_t", "">:$ordinal
+ );
+
+ let assemblyFormat = [{
+ `<` $ordinal `>`
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.fallback<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceFallbackAttr : AttrDef<HAL_Dialect, "DeviceFallback", [
+ DeclareAttrInterfaceMethods<HAL_DeviceInitializationAttrInterface>,
+]> {
+ let mnemonic = "device.fallback";
+ let summary = [{specifies a reference to another device}];
+ let description = [{
+ Specifies by symbol a device that has already been initialized.
+ Returns null during initialization if the device specified as a fallback is
+ null.
+ }];
+
+ let parameters = (ins
+ AttributeSelfTypeParameter<"">:$type,
+ AttrParameter<"FlatSymbolRefAttr", "">:$name
+ );
+
+ let assemblyFormat = [{
+ `<` $name `>`
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.select<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceSelectAttr : AttrDef<HAL_Dialect, "DeviceSelect", [
+ DeclareAttrInterfaceMethods<HAL_DeviceInitializationAttrInterface>,
+]> {
+ let mnemonic = "device.select";
+ let summary = [{selects a device from one or more options}];
+ let description = [{
+ Selects a HAL device at runtime by either enumerating and querying for
+ target support or matching the given existing device by affinity.
+ Options are evaluated in the order listed and the first one that yields a
+ device is selected. Fails during initialization if no device can be
+ selected.
+
+ Examples:
+ ```mlir
+ // Selects a single device matching the given target.
+ #hal.device.select<[
+ #hal.device.target<"..."> : !hal.device
+ ]> : !hal.device
+ // Selects a specific device with the given symbol.
+ #hal.device.select<[
+ #hal.device.fallback<@device_0> : !hal.device
+ ]> : !hal.device
+ // Selects a specific device by ordinal as registered at runtime.
+ #hal.device.select<[
+ #hal.device.ordinal<0> : !hal.device
+ ]> : !hal.device
+ // Selects an optional device if available and otherwise falls back to @fallback.
+ #hal.device.select<[
+ #hal.device.target<"some_optional_device"> : !hal.device,
+ #hal.device.fallback<@fallback> : !hal.device
+ ]> : !hal.device
+ ```
+ }];
+
+ let parameters = (ins
+ AttributeSelfTypeParameter<"">:$type,
+ AttrParameter<"ArrayAttr", "">:$devices
+ );
+
+ let builders = [
+ AttrBuilder<(ins
+ "ArrayRef<Attribute>":$values
+ )>,
+ ];
+
+ let assemblyFormat = [{
+ `<` $devices `>`
+ }];
+
+ let genVerifyDecl = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.affinity<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceAffinityAttr : AttrDef<HAL_Dialect, "DeviceAffinity", [
DeclareAttrInterfaceMethods<Stream_AffinityAttr, [
"isExecutableWith",
"joinOR",
"joinAND",
]>,
Util_HoistableAttrInterface,
+ DeclareAttrInterfaceMethods<Util_InliningPolicyAttrInterface, [
+ "isLegalToInline",
+ ]>,
]> {
- let mnemonic = "affinity.queue";
- let summary = [{specifies a set of allowed queues for an operation}];
+ let mnemonic = "device.affinity";
+ let summary = [{specifies a named device and optional queue affinity}];
let description = [{
- WIP; see [#10765](https://github.com/iree-org/iree/issues/10765).
- This may change in the future to either be a nested attribute on a larger
- affinity struct or be defined by an implementation of the affinity attr
- interface. For now this allows higher levels of the stack to specify
- queues such that the stream dialect can understand them and they can be
- lowered into the HAL dialect.
-
Specifies that an annotated operation or scope is only allowed to execute on
- the set of queues (0-64) provided. Operations will not run on other queues.
+ a specific device and, optionally, on a subset of its queues (ordinals
+ 0-63). Operations will not run on other queues. If the queue mask is
+ omitted then any queue on the device may execute the specified operations.
Example:
```mlir
- // any queue
- #hal.affinity.queue<*>
- // queues 4 and 5
- #hal.affinity.queue<[4, 5]>
+ // Any queue on @device_a.
+ #hal.device.affinity<@device_a>
+ // Queues 4 and 5 on @device_b.
+ #hal.device.affinity<@device_b, [4, 5]>
```
}];
let parameters = (ins
- AttrParameter<"int64_t", "">:$mask
+ AttrParameter<"SymbolRefAttr", "">:$device,
+ AttrParameter<"int64_t", "">:$queue_mask
+ );
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.device.promise<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_DevicePromiseAttr : AttrDef<HAL_Dialect, "DevicePromise", [
+ DeclareAttrInterfaceMethods<Stream_AffinityAttr, [
+ "isExecutableWith",
+ "joinOR",
+ "joinAND",
+ ]>,
+ DeclareAttrInterfaceMethods<Util_InliningPolicyAttrInterface, [
+ "isLegalToInline",
+ ]>,
+]> {
+ let mnemonic = "device.promise";
+ let summary = [{promises a named device and optional queue affinity}];
+ let description = [{
+ Specifies that an annotated operation or scope is only allowed to execute on
+ a specific device that has not yet been declared and, optionally, on a
+ subset of its queues (ordinals 0-63). Operations will not run on other
+ queues. If the queue mask is omitted then any queue on the device may
+ execute the specified operations.
+
+ This is used in input programs to assign operations to particular devices
+ before those devices are declared. It lets frontends reference device
+ categories in the programs they produce while deferring the concrete
+ device specifications until later in compilation. Verification is
+ performed as part of the ResolveDevicePromisesPass.
+
+ Example:
+ ```mlir
+ // Any queue on @device_a, whatever it resolves to once declared.
+ #hal.device.promise<@device_a>
+ // Queues 4 and 5 on @device_b, whatever it resolves to once declared.
+ #hal.device.promise<@device_b, [4, 5]>
+ ```
+ }];
+
+ let parameters = (ins
+ AttrParameter<"StringAttr", "">:$device,
+ AttrParameter<"int64_t", "">:$queue_mask
);
let hasCustomAssemblyFormat = 1;
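Both `#hal.device.affinity` and `#hal.device.promise` carry their queue list as a single 64-bit `queue_mask`. The sketch below illustrates how a list such as `[4, 5]` could map onto that mask, assuming bit *i* represents queue *i* and that an omitted list is represented as all bits set (which is what `-1` looks like in the attribute's `int64_t` storage). The helper names `buildQueueMask`, `isQueueAllowed`, and `kAnyQueue` are hypothetical and not part of IREE.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>

// Hypothetical helper: builds a 64-bit queue mask from queue ordinals, in the
// spirit of #hal.device.affinity<@device_b, [4, 5]>. Ordinal i sets bit i, so
// only ordinals 0-63 can be represented.
uint64_t buildQueueMask(std::initializer_list<int> queues) {
  uint64_t mask = 0;
  for (int queue : queues) {
    if (queue < 0 || queue >= 64) {
      throw std::out_of_range("queue ordinal must be in [0, 63]");
    }
    mask |= uint64_t{1} << queue;
  }
  return mask;
}

// Assumption: an omitted queue list ("any queue") is modeled as all bits set.
constexpr uint64_t kAnyQueue = ~uint64_t{0};

// Returns true if |queue| is permitted by |mask|.
bool isQueueAllowed(uint64_t mask, int queue) {
  return queue >= 0 && queue < 64 && ((mask >> queue) & 1u) != 0;
}
```

With this encoding `[4, 5]` becomes `0x30`, and a mask of all ones admits every queue on the device.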
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
index b1df6ff..fb38415 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
@@ -8,24 +8,53 @@
#define IREE_DIALECT_HAL_INTERFACES
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
-def HAL_MatchAttrInterface :
- AttrInterface<"MatchAttrInterface"> {
+//===----------------------------------------------------------------------===//
+// IREE::HAL::DeviceInitializationAttrInterface
+//===----------------------------------------------------------------------===//
+
+def HAL_DeviceInitializationAttrInterface :
+ AttrInterface<"DeviceInitializationAttrInterface", [
+ TypedAttrInterface,
+ ]> {
let description = [{
- An attribute that can be used in `hal.*.match.*` expressions.
- Each attribute defines some subexpression that can be expanded to one or
- more operations that performs the actual query and matching logic.
+ Interface for attributes controlling device initialization.
}];
let methods = [
InterfaceMethod<
- [{
- Builds a set of operations that evaluate to a boolean (i1) value
- indicating whether the expression tree represented by the match
- attribute is true for the given value.
+ /*desc=*/[{
+ Prints a string description of the initialization specification for
+ inclusion in error messages. The description may contain internal
+ newlines but should not end with a newline.
}],
- "Value", "buildConditionExpression",
- (ins "Location":$loc, "Value":$device, "OpBuilder":$builder)
+ /*retTy=*/"void",
+ /*methodName=*/"printStatusDescription",
+ /*args=*/(ins "llvm::raw_ostream &":$os),
+ /*methodBody=*/[{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Builds a `util.initializer` body responsible for initializing a device
+ global. Returns the device value that should be stored into the global.
+ The name provided is an informal identifier that can be used to produce
+ user-level error messages that reference the device.
+
+ The provided `buildDeviceTargetMatch` function will be called with a
+ `!hal.device` SSA value and a device target specification and should
+ return an `i1` value indicating whether the given device matches the
+ specification. If the device always matches (rare!) a null value may
+ be returned.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"buildDeviceEnumeration",
+ /*args=*/(ins
+ "Location":$loc,
+ "IREE::HAL::BuildDeviceTargetMatchFn":$buildDeviceTargetMatch,
+ "OpBuilder &":$builder
+ ),
+ /*methodBody=*/[{}]
>,
];
}
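`buildDeviceEnumeration` emits IR that picks a device at runtime, and `#hal.device.select` tries its options in order. The following host-side model of those selection semantics is a sketch under stated assumptions: devices are reduced to an ID string, the `MatchFn` callback stands in for `BuildDeviceTargetMatchFn` but evaluates eagerly instead of building IR, and the types `Device`, `TargetOption`, `OrdinalOption`, `FallbackOption`, and `selectDevice` are hypothetical names, not IREE APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical runtime model: a device is identified only by its ID string.
struct Device {
  std::string id;
};

// Option kinds loosely mirroring #hal.device.target, #hal.device.ordinal and
// #hal.device.fallback.
struct TargetOption {
  std::string deviceID;
};
struct OrdinalOption {
  int64_t ordinal;
};
struct FallbackOption {
  std::string symbol;
};
using SelectOption = std::variant<TargetOption, OrdinalOption, FallbackOption>;

// Stand-in for BuildDeviceTargetMatchFn that answers the match directly.
using MatchFn = std::function<bool(const Device &, const TargetOption &)>;

// Tries each option in order and returns the first device produced. Returns
// std::nullopt when nothing matches (the real initializer fails instead).
std::optional<Device>
selectDevice(const std::vector<SelectOption> &options,
             const std::vector<Device> &registered,
             const std::map<std::string, std::optional<Device>> &initialized,
             const MatchFn &match) {
  for (const SelectOption &option : options) {
    std::optional<Device> result;
    if (const auto *target = std::get_if<TargetOption>(&option)) {
      // Enumerate registered devices and take the first that matches.
      for (const Device &device : registered) {
        if (match(device, *target)) {
          result = device;
          break;
        }
      }
    } else if (const auto *ord = std::get_if<OrdinalOption>(&option)) {
      // Out-of-range ordinals produce no device.
      if (ord->ordinal >= 0 &&
          ord->ordinal < static_cast<int64_t>(registered.size())) {
        result = registered[static_cast<std::size_t>(ord->ordinal)];
      }
    } else if (const auto *fallback = std::get_if<FallbackOption>(&option)) {
      // Reuse an already-initialized device; it may itself be null.
      auto it = initialized.find(fallback->symbol);
      if (it != initialized.end()) {
        result = it->second;
      }
    }
    if (result) {
      return result;
    }
  }
  return std::nullopt;
}
```

For example, selecting `[TargetOption{"cuda"}, FallbackOption{"main"}]` on a host with only `local-task` and `vulkan` devices falls through to whatever `@main` was initialized to.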
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index d082847..bf596ce 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -90,12 +90,14 @@
for (auto source : op.getSources()) {
auto it =
uniqueSources.insert(std::make_pair(source, orderedSources.size()));
- if (it.second)
+ if (it.second) {
orderedSources.push_back(source);
+ }
resultMapping.push_back(it.first->second);
}
- if (orderedSources.size() == op.getSources().size())
+ if (orderedSources.size() == op.getSources().size()) {
return failure();
+ }
auto newOp = rewriter.create<TensorBarrierOp>(op.getLoc(), orderedSources,
op.getSignalFence());
SmallVector<Value> newResults;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index b8c0bab..538b32d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -431,15 +431,16 @@
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
- TypeAttr targetEncoding, StringAttr name) {
+ TypeAttr targetEncoding, StringAttr name,
+ Attribute affinity) {
build(builder, result, resultType, source, targetEncoding,
- /*waitFence=*/Value{}, name);
+ /*waitFence=*/Value{}, name, affinity);
}
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
TypeAttr targetEncoding, Value waitFence,
- StringAttr name) {
+ StringAttr name, Attribute affinity) {
auto shapedType = llvm::cast<ShapedType>(resultType);
assert((isa<IREE::HAL::BufferViewType>(source.getType()) ||
shapedType.hasStaticShape()) &&
@@ -454,20 +455,7 @@
builder.getIndexAttr(i)));
}
build(builder, result, resultType, source, targetEncoding, dynamicDims,
- waitFence, name);
-}
-
-Value TensorImportOp::getTiedResult(unsigned resultIndex) {
- return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
-}
-
-::std::optional<unsigned>
-TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) {
- return {0}; // source
-}
-
-SmallVector<int64_t> TensorImportOp::getTiedResultOperandIndices() {
- return {0}; // source
+ waitFence, name, affinity);
}
static LogicalResult verifyTypeStorageCompatibility(Operation *op,
@@ -530,23 +518,12 @@
void TensorExportOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
- TypeAttr sourceEncoding, StringAttr name) {
+ TypeAttr sourceEncoding, StringAttr name,
+ Attribute affinity) {
auto dynamicDims =
IREE::Util::buildDynamicDimsForValue(result.location, source, builder);
- build(builder, result, resultType, source, sourceEncoding, dynamicDims, name);
-}
-
-Value TensorExportOp::getTiedResult(unsigned resultIndex) {
- return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
-}
-
-::std::optional<unsigned>
-TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) {
- return {0}; // source
-}
-
-SmallVector<int64_t> TensorExportOp::getTiedResultOperandIndices() {
- return {0}; // source
+ build(builder, result, resultType, source, sourceEncoding, dynamicDims, name,
+ affinity);
}
LogicalResult TensorExportOp::verify() {
@@ -1060,6 +1037,23 @@
}
//===----------------------------------------------------------------------===//
+// hal.device.resolve
+//===----------------------------------------------------------------------===//
+
+void DeviceResolveOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ for (auto result : getResults()) {
+ if (isa<IREE::HAL::DeviceType>(result.getType())) {
+ setNameFn(result, "device");
+ } else if (isa<IREE::HAL::AllocatorType>(result.getType())) {
+ setNameFn(result, "allocator");
+ } else if (isa<IntegerType>(result.getType())) {
+ setNameFn(result, "queue_affinity");
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
// hal.device.allocator
//===----------------------------------------------------------------------===//
@@ -1480,17 +1474,9 @@
Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) {
// Base case dependent on target information.
- // TODO(multi-device): condition on device target ID and other queries that
- // may be useful for disambiguating two devices that support the same
- // executable targets. Today executable targets are unique per device target
- // but that need not always be the case.
- auto i1Type = builder.getI1Type();
- Value selected = builder
- .create<IREE::HAL::DeviceQueryOp>(
- getLoc(), i1Type, i1Type, device,
- builder.getStringAttr("hal.executable.format"),
- getTarget().getFormat(), builder.getZeroAttr(i1Type))
- .getValue();
+ Value selected = IREE::HAL::DeviceQueryOp::createI1(
+ getLoc(), device, "hal.executable.format",
+ getTarget().getFormat().getValue(), builder);
// Factor in variant condition region, if any.
auto conditionOp = getConditionOp();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 599c1ff..dff5438 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -98,11 +98,6 @@
def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Util_TiedOpInterface, [
- "getTiedResult",
- "getTiedResultOperandIndex",
- "getTiedResultOperandIndices",
- ]>,
Util_ShapeAwareOp,
]> {
let summary = [{imports a tensor from a HAL buffer view}];
@@ -125,13 +120,15 @@
TypeAttr:$target_encoding,
HAL_ShapeDynamicDims:$target_dims,
Optional<HAL_Fence>:$wait_fence,
- OptionalAttr<StrAttr>:$name
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTensor:$target
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
(`wait` `(` $wait_fence^ `)` `=` `` `>`)?
$source
($name^)?
@@ -145,14 +142,16 @@
"Type":$resultType,
"Value":$source,
"TypeAttr":$targetEncoding,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
OpBuilder<(ins
"Type":$resultType,
"Value":$source,
"TypeAttr":$targetEncoding,
"Value":$waitFence,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
];
@@ -167,11 +166,6 @@
}
def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [
- DeclareOpInterfaceMethods<Util_TiedOpInterface, [
- "getTiedResult",
- "getTiedResultOperandIndex",
- "getTiedResultOperandIndices",
- ]>,
Util_ShapeAwareOp,
]> {
let summary = [{exports a tensor to a HAL buffer view}];
@@ -190,13 +184,15 @@
AnyTensor:$source,
TypeAttr:$source_encoding,
HAL_ShapeDynamicDims:$source_dims,
- OptionalAttr<StrAttr>:$name
+ OptionalAttr<StrAttr>:$name,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$target
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
$source
($name^)?
`:`
@@ -211,7 +207,8 @@
"Type":$resultType,
"Value":$source,
"TypeAttr":$sourceEncoding,
- "StringAttr":$name
+ "StringAttr":$name,
+ "Attribute":$affinity
)>,
];
@@ -273,13 +270,15 @@
AnyTensor:$source,
HAL_ShapeDynamicDims:$source_dims,
AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$storage,
- Optional<HAL_Fence>:$wait_fence
+ Optional<HAL_Fence>:$wait_fence,
+ OptionalAttr<AnyAttr>:$affinity
);
let results = (outs
AnyTensor:$result
);
let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
(`wait` `(` $wait_fence^ `)` `=` `` `>`)?
$source `:` type($source) (`{` $source_dims^ `}`)?
`to`
@@ -1604,6 +1603,41 @@
let opDocGroup = OpGroupDeviceOps in {
+def HAL_DeviceResolveOp : HAL_PureOp<"device.resolve", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{resolves device handles based on affinity}];
+ let description = [{
+ Resolves the HAL device, allocator, and/or queue affinity associated with
+ the given affinity attribute.
+
+ Examples:
+ ```mlir
+ // Returns a HAL device.
+ %device = hal.device.resolve on(#something) : !hal.device
+ // Returns a HAL device, allocator, and (optional) queue affinity.
+ %device, %allocator, %queue_affinity = hal.device.resolve on(#something) : !hal.device, !hal.allocator, i64
+ // Returns a HAL allocator and (optional) queue affinity.
+ %allocator, %queue_affinity = hal.device.resolve on(#something) : !hal.allocator, i64
+ // Returns "any" device. Should only be used as a fallback.
+ %device = hal.device.resolve : !hal.device
+ ```
+ }];
+
+ let arguments = (ins
+ OptionalAttr<HAL_DeviceAffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[
+ HAL_Device,
+ HAL_Allocator,
+ I64,
+ ]>>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ attr-dict `:` type($results)
+ }];
+}
+
def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
]> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index ef67024..ef6417d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -33,6 +33,13 @@
namespace mlir::iree_compiler::IREE::HAL {
+class DeviceTargetAttr;
+class TargetRegistry;
+
+using BuildDeviceTargetMatchFn = std::function<Value(
+ Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder)>;
+
#include "iree/compiler/Dialect/HAL/IR/HALAttrInterfaces.h.inc" // IWYU pragma: export
#include "iree/compiler/Dialect/HAL/IR/HALOpInterfaces.h.inc" // IWYU pragma: export
#include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.h.inc" // IWYU pragma: export
@@ -113,7 +120,8 @@
struct DeviceType
: public Type::TypeBase<DeviceType, Type, TypeStorage,
- mlir::OpTrait::IREE::Util::ImplicitlyCaptured> {
+ mlir::OpTrait::IREE::Util::ImplicitlyCaptured,
+ IREE::Util::ReferenceTypeInterface::Trait> {
using Base::Base;
static constexpr StringLiteral name = "hal.device";
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
index 95ea673..ec0cbdd 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
@@ -1,4 +1,5 @@
// RUN: iree-opt --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --inline --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s --check-prefix=CHECK-INLINE
// CHECK-LABEL: descriptor_set_layout_binding.basic
"descriptor_set_layout_binding.basic"() {
@@ -60,25 +61,117 @@
// -----
-// CHECK-LABEL: "device.targets"
-"device.targets"() {
- // CHECK-SAME: target_0 = #hal.device.target<"a">
- target_0 = #hal.device.target<"a">,
- // CHECK-SAME: target_1 = #hal.device.target<"b", {config}>,
- target_1 = #hal.device.target<"b", {config}>,
- // CHECK-SAME: target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]>,
- target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]>,
- // CHECK-SAME: target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]>
- target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]>
+// CHECK-LABEL: "device.aliases"
+"device.aliases"() {
+ // CHECK-SAME: alias_0 = #hal.device.alias<"a"> : !hal.device
+ alias_0 = #hal.device.alias<"a"> : !hal.device,
+ // CHECK-SAME: alias_1 = #hal.device.alias<"b", {}> : !hal.device
+ alias_1 = #hal.device.alias<"b", {}> : !hal.device,
+ // CHECK-SAME: alias_2 = #hal.device.alias<"c"[4]> : !hal.device
+ alias_2 = #hal.device.alias<"c"[4]> : !hal.device,
+ // CHECK-SAME: alias_3 = #hal.device.alias<"d", {config = 123 : index}>
+ alias_3 = #hal.device.alias<"d", {config = 123 : index}> : !hal.device
} : () -> ()
// -----
-"affinity.queue"() {
- // CHECK: any = #hal.affinity.queue<*>
- any = #hal.affinity.queue<*>,
- // CHECK: q0 = #hal.affinity.queue<[0]>
- q0 = #hal.affinity.queue<[0]>,
- // CHECK: q123 = #hal.affinity.queue<[1, 2, 3]>
- q123 = #hal.affinity.queue<[1, 2, 3]>
+// CHECK-LABEL: "device.targets"
+"device.targets"() {
+ // CHECK-SAME: target_0 = #hal.device.target<"a"> : !hal.device
+ target_0 = #hal.device.target<"a"> : !hal.device,
+ // CHECK-SAME: target_1 = #hal.device.target<"b", {config}> : !hal.device,
+ target_1 = #hal.device.target<"b", {config}> : !hal.device,
+ // CHECK-SAME: target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device,
+ target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device,
+ // CHECK-SAME: target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device
+ target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device
} : () -> ()
+
+// -----
+
+// CHECK: util.global private @device_a = #hal.device.target<"a"> : !hal.device
+util.global private @device_a = #hal.device.target<"a"> : !hal.device
+// CHECK: util.global private @device_0 = #hal.device.ordinal<0> : !hal.device
+util.global private @device_0 = #hal.device.ordinal<0> : !hal.device
+
+// -----
+
+// CHECK: util.global private @main = #hal.device.select<[
+// CHECK-SAME: #hal.device.target<"a"> : !hal.device
+// CHECK-SAME: ]> : !hal.device
+util.global private @main = #hal.device.select<[
+ #hal.device.target<"a"> : !hal.device
+]> : !hal.device
+// CHECK: util.global private @optional = #hal.device.select<[
+// CHECK-SAME: #hal.device.target<"b"> : !hal.device,
+// CHECK-SAME: #hal.device.ordinal<1> : !hal.device,
+// CHECK-SAME: #hal.device.fallback<@main> : !hal.device
+// CHECK-SAME: ]> : !hal.device
+util.global private @optional = #hal.device.select<[
+ #hal.device.target<"b"> : !hal.device,
+ #hal.device.ordinal<1> : !hal.device,
+ #hal.device.fallback<@main> : !hal.device
+]> : !hal.device
+
+// -----
+
+util.global private @device : !hal.device
+"device.affinity"() {
+ // CHECK: device_any = #hal.device.affinity<@device>
+ device_any = #hal.device.affinity<@device>,
+ // CHECK: device_queue_0 = #hal.device.affinity<@device, [0]>
+ device_queue_0 = #hal.device.affinity<@device, [0]>,
+ // CHECK: device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]>
+ device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]>
+} : () -> ()
+
+// -----
+
+"device.promise"() {
+ // CHECK: device_any = #hal.device.promise<@device>
+ device_any = #hal.device.promise<@device>,
+ // CHECK: device_queue_0 = #hal.device.promise<@device, [0]>
+ device_queue_0 = #hal.device.promise<@device, [0]>,
+ // CHECK: device_queue_123 = #hal.device.promise<@device, [1, 2, 3]>
+ device_queue_123 = #hal.device.promise<@device, [1, 2, 3]>
+} : () -> ()
+
+// -----
+
+// Tests that differing device affinities block inlining.
+// Here @inline_target uses the default affinity specified on the module, so
+// only functions that also use the default affinity or specify a matching
+// affinity will be inlined into it. The #hal.device.affinity attribute
+// controls this behavior; in the future we could allow inlining across
+// compatible devices, the same device on differing queues, etc.
+
+builtin.module attributes {
+ stream.affinity = #hal.device.affinity<@device_a>
+} {
+ util.global private @device_a : !hal.device
+ util.global private @device_b : !hal.device
+ // CHECK-INLINE: util.func public @inline_target
+ util.func public @inline_target() -> (i32, i32) {
+ // CHECK-INLINE-NOT: util.call @compat_inlinable
+ // CHECK-INLINE: %[[A:.+]] = arith.constant 0
+ %a = util.call @compat_inlinable() : () -> i32
+ // CHECK-INLINE: %[[B:.+]] = util.call @noncompat_inlinable
+ %b = util.call @noncompat_inlinable() : () -> i32
+ // CHECK-INLINE: util.return %[[A]], %[[B]]
+ util.return %a, %b : i32, i32
+ }
+ // CHECK-INLINE-NOT: util.func private @compat_inlinable
+ util.func private @compat_inlinable() -> i32 attributes {
+ stream.affinity = #hal.device.affinity<@device_a>
+ } {
+ %c0 = arith.constant 0 : i32
+ util.return %c0 : i32
+ }
+ // CHECK-INLINE: util.func private @noncompat_inlinable
+ util.func private @noncompat_inlinable() -> i32 attributes {
+ stream.affinity = #hal.device.affinity<@device_b>
+ } {
+ %c1 = arith.constant 1 : i32
+ util.return %c1 : i32
+ }
+}
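The inlining test above expects `@compat_inlinable` to be inlined and `@noncompat_inlinable` to be kept. A minimal model of that observable rule, assuming an affinity can be reduced to the device symbol it names (queue masks ignored) and that a function without `stream.affinity` inherits the module's; `Affinity`, `effectiveAffinity`, and this free-function `isLegalToInline` are hypothetical and only mirror the behavior the test checks, not how the attribute interface is implemented.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical model: an affinity is the device symbol it names, and
// std::nullopt means "no explicit stream.affinity on this op".
using Affinity = std::optional<std::string>;

// A function's effective affinity is its own if present, otherwise the one
// inherited from the enclosing module.
Affinity effectiveAffinity(const Affinity &own, const Affinity &moduleDefault) {
  return own ? own : moduleDefault;
}

// Inlining is allowed only when caller and callee resolve to the same device.
bool isLegalToInline(const Affinity &callee, const Affinity &caller,
                     const Affinity &moduleDefault) {
  return effectiveAffinity(callee, moduleDefault) ==
         effectiveAffinity(caller, moduleDefault);
}
```

With a module default of `device_a`, a callee pinned to `device_a` (or with no affinity) is inlinable into a default-affinity caller, while one pinned to `device_b` is not.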
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
index 6da4866..7408ad9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -144,9 +144,10 @@
// CHECK-SAME: %[[DEVICE:.+]]: !hal.device,
// CHECK-SAME: %[[LAYOUT0:.+]]: !hal.pipeline_layout,
// CHECK-SAME: %[[LAYOUT1:.+]]: !hal.pipeline_layout
-util.func public @executable_create(%device: !hal.device,
- %layout0: !hal.pipeline_layout,
- %layout1: !hal.pipeline_layout) {
+util.func public @executable_create(
+ %device: !hal.device,
+ %layout0: !hal.pipeline_layout,
+ %layout1: !hal.pipeline_layout) {
// CHECK: = hal.executable.create
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: target(@exe::@binary1)
@@ -163,16 +164,17 @@
// CHECK-SAME: %[[DEVICE:.+]]: !hal.device,
// CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout,
// CHECK-SAME: %[[LAYOUT1:.+]]: !hal.descriptor_set_layout
-util.func public @pipeline_layout_create(%device: !hal.device,
- %layout0: !hal.descriptor_set_layout,
- %layout1: !hal.descriptor_set_layout) {
+util.func public @pipeline_layout_create(
+ %device: !hal.device,
+ %layout0: !hal.descriptor_set_layout,
+ %layout1: !hal.descriptor_set_layout) {
// CHECK: hal.pipeline_layout.create
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: push_constants(1)
// CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT1]]]) : !hal.pipeline_layout
%0 = hal.pipeline_layout.create device(%device : !hal.device)
- push_constants(1)
- layouts([%layout0, %layout1]) : !hal.pipeline_layout
+ push_constants(1)
+ layouts([%layout0, %layout1]) : !hal.pipeline_layout
util.return
}
@@ -197,8 +199,9 @@
// CHECK-LABEL: @unresolved_workload
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device,
// CHECK-SAME: %[[WORKLOAD_0:.+]]: index, %[[WORKLOAD_1:.+]]: index)
-util.func public @unresolved_workload(%device: !hal.device,
- %workload_0: index, %workload_1: index) -> (index, index, index) {
+util.func public @unresolved_workload(
+ %device: !hal.device,
+ %workload_0: index, %workload_1: index) -> (index, index, index) {
// CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] =
// CHECK-SAME: hal.executable.calculate_workgroups
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
index 302ad62..e482a49 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
@@ -87,7 +87,7 @@
Value LocalDevice::buildDeviceTargetMatch(
Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
OpBuilder &builder) const {
- return buildDeviceIDAndExecutableFormatsMatch(
+ return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch(
loc, device, "local*", targetAttr.getExecutableTargets(), builder);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
index 1695c50..2e51311 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
@@ -7,8 +7,6 @@
#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
namespace mlir::iree_compiler::IREE::HAL {
@@ -16,56 +14,9 @@
Value TargetDevice::buildDeviceTargetMatch(
Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
OpBuilder &builder) const {
- return buildDeviceIDAndExecutableFormatsMatch(
+ return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch(
loc, device, targetAttr.getDeviceID(), targetAttr.getExecutableTargets(),
builder);
}
-Value buildDeviceIDAndExecutableFormatsMatch(
- Location loc, Value device, StringRef deviceIDPattern,
- ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
- OpBuilder &builder) {
- // Match first on the device ID, as that's the top-level filter.
- Value idMatch = IREE::HAL::DeviceQueryOp::createI1(
- loc, device, "hal.device.id", deviceIDPattern, builder);
-
- // If there are executable formats defined we should check at least one of
- // them is supported.
- if (executableTargetAttrs.empty()) {
- return idMatch; // just device ID
- } else {
- auto ifOp = builder.create<scf::IfOp>(loc, builder.getI1Type(), idMatch,
- true, true);
- auto thenBuilder = ifOp.getThenBodyBuilder();
- Value anyFormatMatch = buildExecutableFormatMatch(
- loc, device, executableTargetAttrs, thenBuilder);
- thenBuilder.create<scf::YieldOp>(loc, anyFormatMatch);
- auto elseBuilder = ifOp.getElseBodyBuilder();
- Value falseValue = elseBuilder.create<arith::ConstantIntOp>(loc, 0, 1);
- elseBuilder.create<scf::YieldOp>(loc, falseValue);
- return ifOp.getResult(0);
- }
-}
-
-Value buildExecutableFormatMatch(
- Location loc, Value device,
- ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
- OpBuilder &builder) {
- if (executableTargetAttrs.empty())
- return builder.create<arith::ConstantIntOp>(loc, 1, 1);
- Value anyFormatMatch;
- for (auto executableTargetAttr : executableTargetAttrs) {
- Value formatMatch = IREE::HAL::DeviceQueryOp::createI1(
- loc, device, "hal.executable.format",
- executableTargetAttr.getFormat().getValue(), builder);
- if (!anyFormatMatch) {
- anyFormatMatch = formatMatch;
- } else {
- anyFormatMatch =
- builder.create<arith::OrIOp>(loc, anyFormatMatch, formatMatch);
- }
- }
- return anyFormatMatch;
-}
-
} // namespace mlir::iree_compiler::IREE::HAL
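The removed `buildDeviceIDAndExecutableFormatsMatch` and `buildExecutableFormatMatch` helpers (now static members of `DeviceTargetAttr`) build IR that first checks the device ID and, only if that matches, ORs together one `hal.executable.format` query per executable target; an empty target list counts as a match. The sketch below evaluates the same boolean logic directly on the host. `DeviceInfo` and the function names are hypothetical, the format strings in the usage are illustrative, and the ID matcher only understands a single trailing `*` (as in `"local*"`), which is an assumption about the runtime's pattern syntax.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical description of what a runtime device reports.
struct DeviceInfo {
  std::string id;
  std::vector<std::string> executableFormats;
};

// Stand-in for the hal.device.id query; supports only a trailing '*'.
bool matchesIDPattern(const std::string &id, const std::string &pattern) {
  if (!pattern.empty() && pattern.back() == '*') {
    std::string prefix = pattern.substr(0, pattern.size() - 1);
    return id.compare(0, prefix.size(), prefix) == 0;
  }
  return id == pattern;
}

// Mirrors buildExecutableFormatMatch: true if no formats are required or if
// the device supports at least one of them.
bool supportsAnyFormat(const DeviceInfo &device,
                       const std::vector<std::string> &formats) {
  if (formats.empty()) {
    return true;
  }
  return std::any_of(formats.begin(), formats.end(),
                     [&](const std::string &format) {
                       return std::find(device.executableFormats.begin(),
                                        device.executableFormats.end(),
                                        format) !=
                              device.executableFormats.end();
                     });
}

// Mirrors buildDeviceIDAndExecutableFormatsMatch: the short-circuiting && plays
// the role of the scf.if that skips format queries when the ID does not match.
bool matchesDeviceIDAndFormats(const DeviceInfo &device,
                               const std::string &pattern,
                               const std::vector<std::string> &formats) {
  return matchesIDPattern(device.id, pattern) &&
         supportsAnyFormat(device, formats);
}
```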
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
index 2ce53f4..f9cb567 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
@@ -48,21 +48,6 @@
// various stages.
};
-// Builds an expression that returns an i1 indicating whether the given
-// |device| matches the device ID string pattern and executable target
-// requirements.
-Value buildDeviceIDAndExecutableFormatsMatch(
- Location loc, Value device, StringRef deviceIDPattern,
- ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
- OpBuilder &builder);
-
-// Builds a match expression that returns an i1 indicating whether the given
-// |device| supports any one of the |executableTargetAttrs|.
-Value buildExecutableFormatMatch(
- Location loc, Value device,
- ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
- OpBuilder &builder);
-
} // namespace mlir::iree_compiler::IREE::HAL
#endif // IREE_COMPILER_DIALECT_HAL_TARGET_TARGETDEVICE_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
index c00cb63..26fcae5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
@@ -22,10 +22,21 @@
// initialized, so targetBackendsFlags needs to be here to be initialized
// first.
binder.list<std::string>(
- "iree-hal-target-backends", targets,
+ "iree-hal-target-backends", legacyTargetBackends,
llvm::cl::desc("Target backends for executable compilation."),
llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory));
+ binder.list<std::string>("iree-hal-target-device", targetDevices,
+ llvm::cl::desc("Target device specifications."),
+ llvm::cl::ZeroOrMore,
+ llvm::cl::cat(halTargetOptionsCategory));
+ binder.opt<std::string>(
+ "iree-hal-default-device", defaultDevice,
+ llvm::cl::desc("Which device is considered the default when no device "
+ "affinity is specified: either the device name (when names "
+ "are specified) or the numeric ordinal of the device."),
+ llvm::cl::cat(halTargetOptionsCategory));
+
binder.opt<int>(
"iree-hal-executable-debug-level", debugLevel,
llvm::cl::desc("Debug level for executable translation (0-3)"),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
index 711e0e1..08601ea 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
@@ -17,8 +17,35 @@
// TODO(benvanik): remove this and replace with the pass pipeline options.
// Controls executable translation targets.
struct TargetOptions {
- // TODO(benvanik): multiple targets of the same type, etc.
- std::vector<std::string> targets;
+ // TODO(benvanik): remove the legacy flag once users are switched to devices.
+ std::vector<std::string> legacyTargetBackends;
+
+ // Specifies target devices to assign to the program. May be omitted if the
+ // program already has devices assigned or no devices are required (host
+ // program not using the HAL).
+ //
+ // Two devices, one the local host device and the other a Vulkan device:
+ // `local`, `vulkan`
+ //
+ // One device that selects Vulkan if available and otherwise falls back to
+ // the local host device:
+ // `vulkan,local`
+ //
+ // Two CUDA devices selected by runtime ordinal; at runtime two --device=
+ // flags are required to configure both devices:
+ // `cuda[0]`, `cuda[1]`
+ //
+ // A fully-defined target specification:
+ // `#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>`
+ //
+ // Named device that can be referenced by #hal.device.promise<@some_name>:
+ // `some_name=vulkan`
+ std::vector<std::string> targetDevices;
+
+ // Which device is considered the default when no device affinity is specified
+ // on a particular operation. Accepts string names matching those specified
+ // in the target devices list or numeric ordinals if names were omitted.
+ std::string defaultDevice;
// Coarse debug level for executable translation across all targets.
// Each target backend can use this to control its own flags, with values
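The `defaultDevice` option above accepts either a device name (when the target devices were declared with `name=` prefixes) or a numeric ordinal (when they were not). Below is a minimal standalone sketch of that lookup rule. `resolveDefaultDevice` is a hypothetical helper, not an IREE function. It assumes names are either all present or all absent, which is the rule `--iree-hal-assign-target-devices` enforces. How the compiler actually consumes the option is not shown in this diff.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper (not IREE API): maps a --iree-hal-default-device value
// to an index into the target device list. |names| has one entry per device
// and is assumed to be either all non-empty (every device used `name=`) or all
// empty (no device was named).
std::optional<std::size_t>
resolveDefaultDevice(const std::vector<std::string> &names,
                     const std::string &defaultDevice) {
  if (names.empty() || defaultDevice.empty()) return std::nullopt;

  // Named devices: the value must match one of the names exactly.
  if (!names.front().empty()) {
    for (std::size_t i = 0; i < names.size(); ++i) {
      if (names[i] == defaultDevice) return i;
    }
    return std::nullopt;
  }

  // Unnamed devices: the value must be a decimal ordinal within range.
  std::size_t ordinal = 0;
  for (char c : defaultDevice) {
    if (!std::isdigit(static_cast<unsigned char>(c))) return std::nullopt;
    ordinal = ordinal * 10 + static_cast<std::size_t>(c - '0');
    // The ordinal never shrinks as digits are added, so stop as soon as it is
    // out of range (this also rules out overflow on long inputs).
    if (ordinal >= names.size()) return std::nullopt;
  }
  return ordinal;
}
```

Returning `std::nullopt` for an empty value is a choice made by this sketch. The real pipeline may pick a default device differently.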
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp
new file mode 100644
index 0000000..38b0af7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp
@@ -0,0 +1,117 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_ASSIGNLEGACYTARGETDEVICESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-assign-legacy-target-devices
+//===----------------------------------------------------------------------===//
+
+struct AssignLegacyTargetDevicesPass
+ : public IREE::HAL::impl::AssignLegacyTargetDevicesPassBase<
+ AssignLegacyTargetDevicesPass> {
+ using IREE::HAL::impl::AssignLegacyTargetDevicesPassBase<
+ AssignLegacyTargetDevicesPass>::AssignLegacyTargetDevicesPassBase;
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // If no targets are specified we can't do anything - another pass earlier
+ // in the pipeline will have had to add the targets.
+ if (targetBackends.empty()) {
+ return;
+ }
+
+ // Check to see if targets are already specified and if so then no-op the
+ // pass so that we don't mess with whatever the user intended.
+ auto existingTargetsAttr =
+ moduleOp->getAttrOfType<ArrayAttr>("hal.device.targets");
+ if (existingTargetsAttr) {
+ return;
+ }
+
+ // If there are any device globals declared then bail as it means the user
+ // has already materialized the devices they want.
+ for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
+ if (isa<IREE::HAL::DeviceType>(globalOp.getGlobalType())) {
+ return;
+ }
+ }
+
+ llvm::SmallDenseSet<Attribute> targetAttrSet;
+ SmallVector<Attribute> targetAttrs;
+ for (const auto &targetBackendName : targetBackends) {
+ auto targetBackend = targetRegistry->getTargetBackend(targetBackendName);
+ if (!targetBackend) {
+ auto diagnostic = emitError(moduleOp.getLoc())
+ << "target backend '" << targetBackendName
+ << "' not registered; registered backends: [";
+ llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(),
+ diagnostic);
+ diagnostic << "]";
+ return signalPassFailure();
+ }
+ auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID();
+ auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName);
+ if (!targetDevice) {
+ auto diagnostic = emitError(moduleOp.getLoc())
+ << "target device '" << targetDeviceName
+ << "' not registered; registered devices: [";
+ llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(),
+ diagnostic);
+ diagnostic << "]";
+ return signalPassFailure();
+ }
+
+ // Ask the target backend for its default device specification attribute.
+ auto targetAttr = targetDevice->getDefaultDeviceTarget(
+ moduleOp.getContext(), *targetRegistry.value);
+ if (!targetAttr) {
+ emitError(moduleOp.getLoc()) << "no default device targets available";
+ return signalPassFailure();
+ }
+ if (!targetAttrSet.contains(targetAttr)) {
+ targetAttrSet.insert(targetAttr);
+ targetAttrs.push_back(targetAttr);
+ }
+ }
+
+ Attribute targetsAttr;
+ if (targetAttrs.size() == 1) {
+ targetsAttr = targetAttrs.front();
+ } else {
+ targetsAttr =
+ IREE::HAL::DeviceSelectAttr::get(moduleOp.getContext(), targetAttrs);
+ }
+ moduleOp->setAttr("hal.device.targets",
+ ArrayAttr::get(moduleOp.getContext(), targetsAttr));
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
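The legacy pass above maps each `--iree-hal-target-backends=` entry to its backend's default device. It drops duplicates while keeping first-seen order, then stores the result as a single device: the lone target if there is only one, otherwise a `DeviceSelectAttr` over all of them. As a result, several legacy backends yield one device that chooses among targets, not several devices.

The following minimal model of that bookkeeping uses plain strings in place of MLIR attributes. `assignLegacyDevices` and `LegacyAssignment` are hypothetical names. The backend-to-device table is passed in rather than hard-coded, because the real values come from each backend's `getLegacyDefaultDeviceID()`.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of the legacy assignment (not IREE API), with plain
// strings standing in for MLIR attributes.
struct LegacyAssignment {
  bool ok = false;
  std::string error;
  std::vector<std::string> devices;  // unique targets, first-seen order
  bool isSelect = false;             // several targets: wrapped in one select
};

LegacyAssignment
assignLegacyDevices(const std::vector<std::string> &backends,
                    const std::map<std::string, std::string> &backendToDevice) {
  LegacyAssignment result;
  std::set<std::string> seen;  // plays the role of the SmallDenseSet
  for (const auto &backend : backends) {
    auto it = backendToDevice.find(backend);
    if (it == backendToDevice.end()) {
      result.error = "target backend '" + backend + "' not registered";
      return result;
    }
    // insert().second is true only the first time a device target is seen,
    // so duplicates are dropped without disturbing the original order.
    if (seen.insert(it->second).second) result.devices.push_back(it->second);
  }
  // A single target is used directly; several become one selection device.
  result.isSelect = result.devices.size() > 1;
  result.ok = true;
  return result;
}
```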
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp
index 7e0b5d4..4892e44 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp
@@ -1,4 +1,4 @@
-// Copyright 2021 The IREE Authors
+// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -31,79 +32,236 @@
// --iree-hal-assign-target-devices
//===----------------------------------------------------------------------===//
+// Strips leading and trailing whitespace from |value|.
+static StringRef stripWhitespace(StringRef value) {
+ while (!value.empty() && llvm::isSpace(value.front())) {
+ value = value.drop_front(1);
+ }
+ while (!value.empty() && llvm::isSpace(value.back())) {
+ value = value.drop_back(1);
+ }
+ return value;
+}
+
+// Strips leading and trailing double quotes from |value| if both exist.
+static StringRef stripQuotes(StringRef value) {
+ value = stripWhitespace(value);
+ StringRef unquoted = value;
+ if (unquoted.consume_front("\"") && unquoted.consume_back("\"")) {
+ return stripWhitespace(unquoted);
+ }
+ return value;
+}
+
+// Consumes a leading `name=` literal.
+// Returns the `name` and leaves remaining characters after `=` in |value|.
+// Returns an empty string if no name literal is present.
+static StringRef consumeNameLiteral(StringRef &value) {
+ value = stripWhitespace(value);
+ const size_t splitIdx = value.find('=');
+ if (splitIdx == std::string::npos) {
+ return "";
+ }
+ for (size_t i = 0; i < splitIdx; ++i) {
+ const char c = value[i];
+ if (!llvm::isAlnum(c) && c != '_') {
+ return "";
+ }
+ }
+ const StringRef name = value.substr(0, splitIdx);
+ value = stripWhitespace(value.substr(splitIdx + 1));
+ return stripWhitespace(name);
+}
+
+// Consumes the first portion of |value| corresponding to a device alias.
+// Expects: `abc` or `abc[123]` (and allows `"abc"[123]`).
+// Only valid literals will be parsed (a-z0-9_).
+// Returns the device ID and optional ordinal. All other unconsumed characters
+// will remain in |value| upon return.
+static std::pair<StringRef, std::optional<int64_t>>
+consumeAliasLiteral(StringRef &value) {
+ value = stripWhitespace(value);
+ const size_t splitIdx = value.find(',');
+ StringRef part =
+ splitIdx == std::string::npos ? value : value.substr(0, splitIdx);
+
+ StringRef deviceID = part;
+ std::optional<int64_t> ordinal;
+
+ const size_t ordinalIdx = part.find('[');
+ if (ordinalIdx != std::string::npos) {
+ deviceID = part.substr(0, ordinalIdx);
+ StringRef ordinalStr = part.substr(ordinalIdx + 1);
+ APInt ordinalInt;
+ if (!ordinalStr.consumeInteger(10, ordinalInt)) {
+ ordinal = ordinalInt.getSExtValue();
+ }
+ }
+
+ value = stripWhitespace(value.substr(part.size()));
+ return std::make_pair(stripQuotes(deviceID), ordinal);
+}
+
+struct TargetSpec {
+ StringAttr name;
+ TypedAttr attr;
+};
+
+// Parses the user-provided string into a target spec.
+//
+// Supports attributes:
+// #hal.device.alias<...>
+// #hal.device.target<...>
+// #hal.device.select<...>
+// #hal.device.fallback<...>
+// Supports convenience shorthand:
+// ...,... -> #hal.device.select<[...,...]>
+// target -> #hal.device.alias<"target">
+// target[0] -> #hal.device.alias<"target"[0]>
+// "target"[0] -> #hal.device.alias<"target"[0]>
+// Supports name= prefixes:
+// name=... -> ...
+static FailureOr<TargetSpec> parseTargetSpec(Location loc,
+ StringRef targetSpecStr) {
+ auto *context = loc.getContext();
+ targetSpecStr = stripQuotes(targetSpecStr);
+
+ // Check for a name prefix and strip it from the spec.
+ StringRef name = consumeNameLiteral(targetSpecStr);
+ StringAttr nameAttr =
+ name.empty() ? StringAttr{} : StringAttr::get(context, name);
+
+ // Parse the spec attributes.
+ SmallVector<Attribute> attrs;
+ while (!targetSpecStr.empty()) {
+ TypedAttr typedAttr;
+ if (targetSpecStr.starts_with('#')) {
+ // MLIR attribute.
+ size_t numRead = 0;
+ auto parsedAttr = mlir::parseAttribute(targetSpecStr, context,
+ /*type=*/nullptr, &numRead);
+ if (!parsedAttr) {
+ return mlir::emitError(loc) << "failed to parse target spec prefix `"
+ << targetSpecStr << "`";
+ }
+ typedAttr = dyn_cast<TypedAttr>(parsedAttr);
+ if (!typedAttr) {
+ return mlir::emitError(loc) << "unexpected target attribute type: "
+ "expected a `!hal.device` but got `"
+ << parsedAttr << "`";
+ }
+ targetSpecStr = stripWhitespace(targetSpecStr.substr(numRead));
+ } else {
+ // Alias string.
+ auto [deviceID, ordinal] = consumeAliasLiteral(targetSpecStr);
+ typedAttr = IREE::HAL::DeviceAliasAttr::get(
+ context, IREE::HAL::DeviceType::get(context),
+ StringAttr::get(context, deviceID), ordinal, DictionaryAttr{});
+ }
+
+ if (!typedAttr || !isa<IREE::HAL::DeviceType>(typedAttr.getType())) {
+ return mlir::emitError(loc) << "unexpected target attribute type: "
+ "expected a `!hal.device` but got `"
+ << typedAttr.getType() << "`";
+ }
+ attrs.push_back(typedAttr);
+
+ if (targetSpecStr.empty()) {
+ break; // done
+ } else if (!targetSpecStr.starts_with(',')) {
+ return mlir::emitError(loc)
+ << "unexpected additional characters after parsing an element: `"
+ << targetSpecStr << "`";
+ }
+ targetSpecStr = targetSpecStr.substr(1); // strip ,
+ }
+
+ if (attrs.empty()) {
+ return mlir::emitError(loc) << "expected one or more target attributes";
+ } else if (attrs.size() == 1) {
+ return TargetSpec{nameAttr, cast<TypedAttr>(attrs.front())};
+ } else {
+ return TargetSpec{nameAttr,
+ IREE::HAL::DeviceSelectAttr::get(context, attrs)};
+ }
+}
+
struct AssignTargetDevicesPass
: public IREE::HAL::impl::AssignTargetDevicesPassBase<
AssignTargetDevicesPass> {
using IREE::HAL::impl::AssignTargetDevicesPassBase<
AssignTargetDevicesPass>::AssignTargetDevicesPassBase;
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<IREE::HAL::HALDialect>();
- for (auto &targetBackend : targetRegistry->getTargetBackends(
- targetRegistry->getRegisteredTargetBackends())) {
- targetBackend->getDependentDialects(registry);
- }
- }
void runOnOperation() override {
auto moduleOp = getOperation();
- // Check to see if targets are already specified.
- auto existingTargetsAttr =
- moduleOp->getAttrOfType<ArrayAttr>("hal.device.targets");
- if (existingTargetsAttr) {
- // Targets already exist on the module; no-op the pass so that we don't
- // mess with whatever the user intended.
- return;
- }
-
// If no targets are specified we can't do anything - another pass earlier
// in the pipeline will have had to add the targets.
- if (targetBackends.empty()) {
- emitRemark(moduleOp.getLoc())
- << "no target HAL target backends specified during assignment";
+ if (targetDevices.empty()) {
return;
}
- llvm::SmallDenseSet<Attribute> targetAttrSet;
- SmallVector<Attribute> targetAttrs;
- for (const auto &targetBackendName : targetBackends) {
- auto targetBackend = targetRegistry->getTargetBackend(targetBackendName);
- if (!targetBackend) {
- std::string backends;
- llvm::raw_string_ostream os(backends);
- llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(), os,
- [&os](const std::string &name) { os << name; });
- emitError(moduleOp.getLoc())
- << "target backend '" << targetBackendName
- << "' not registered; registered backends: " << os.str();
- signalPassFailure();
- return;
- }
- auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID();
- auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName);
- if (!targetDevice) {
- std::string devices;
- llvm::raw_string_ostream os(devices);
- llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(), os,
- [&os](const std::string &name) { os << name; });
- emitError(moduleOp.getLoc())
- << "target device '" << targetDeviceName
- << "' not registered; registered devices: " << os.str();
- signalPassFailure();
- return;
- }
+ // Check to see if targets are already specified and if so then no-op the
+ // pass so that we don't mess with whatever the user intended.
+ if (moduleOp->hasAttr("hal.device.targets")) {
+ return;
+ }
- // Ask the target backend for its default device specification attribute.
- auto targetAttr = targetDevice->getDefaultDeviceTarget(
- moduleOp.getContext(), *targetRegistry.value);
- if (!targetAttrSet.contains(targetAttr)) {
- targetAttrSet.insert(targetAttr);
- targetAttrs.push_back(targetAttr);
+ // If there are any device globals declared then bail as it means the user
+ // has already materialized the devices they want.
+ for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
+ if (isa<IREE::HAL::DeviceType>(globalOp.getGlobalType())) {
+ return;
}
}
- moduleOp->setAttr("hal.device.targets",
- ArrayAttr::get(moduleOp.getContext(), targetAttrs));
+ // Parse each spec and validate correctness.
+ bool hasAnyNamed = false;
+ bool hasAnyUnnamed = false;
+ SmallVector<TargetSpec> targetSpecs;
+ for (auto &targetDevice : targetDevices) {
+ auto targetSpecOr = parseTargetSpec(moduleOp.getLoc(), targetDevice);
+ if (failed(targetSpecOr)) {
+ return signalPassFailure();
+ }
+ if (targetSpecOr->name) {
+ hasAnyNamed = true;
+ } else {
+ hasAnyUnnamed = true;
+ }
+ targetSpecs.push_back(*targetSpecOr);
+ }
+
+ // If any spec has a name assigned then all must have names assigned.
+ if (hasAnyNamed && hasAnyUnnamed) {
+ emitError(moduleOp.getLoc())
+ << "if any target device spec has a name then all must be named";
+ return signalPassFailure();
+ }
+
+ if (hasAnyNamed) {
+ // NOTE: duplicate names are allowed; later specs override earlier ones.
+ llvm::MapVector<StringAttr, Attribute> deviceAttrMap;
+ for (auto targetSpec : targetSpecs) {
+ assert(targetSpec.name && "all devices must be named");
+ deviceAttrMap[targetSpec.name] = targetSpec.attr;
+ }
+ SmallVector<NamedAttribute> deviceAttrs;
+ for (auto [name, value] : deviceAttrMap) {
+ deviceAttrs.push_back(NamedAttribute(name, value));
+ }
+ moduleOp->setAttr(
+ "hal.device.targets",
+ DictionaryAttr::get(moduleOp.getContext(), deviceAttrs));
+ } else {
+ SmallVector<Attribute> deviceAttrs;
+ for (auto [name, value] : targetSpecs) {
+ assert(!name && "no devices may have names");
+ deviceAttrs.push_back(value);
+ }
+ moduleOp->setAttr("hal.device.targets",
+ ArrayAttr::get(moduleOp.getContext(), deviceAttrs));
+ }
}
};
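The shorthand grammar accepted by `parseTargetSpec` can be approximated without MLIR. It consists of an optional `name=` prefix followed by comma-separated `id` or `id[ordinal]` elements. One element yields that device; several yield a `DeviceSelectAttr`.

This sketch makes a few simplifying assumptions:
- It uses C++17 `std::string_view` in place of `llvm::StringRef`.
- It skips `#...` attribute forms entirely, since the pass hands those to `mlir::parseAttribute`.
- It reports errors as `std::nullopt` instead of diagnostics.

`parseShorthand`, `consumeName`, `AliasSpec`, `ParsedSpec`, and `namingIsConsistent` are illustrative names, not IREE API. `consumeName` returns an empty name when a non-identifier character precedes the `=`. That behavior keeps an `=` inside an attribute body from being mistaken for a name prefix.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Illustrative approximation of the shorthand parsing in parseTargetSpec.
static std::string_view stripWhitespace(std::string_view v) {
  while (!v.empty() && std::isspace(static_cast<unsigned char>(v.front())))
    v.remove_prefix(1);
  while (!v.empty() && std::isspace(static_cast<unsigned char>(v.back())))
    v.remove_suffix(1);
  return v;
}

// Returns `name` from a leading `name=` (alphanumerics and `_` only) and
// advances |v| past the `=`. Returns an empty name and leaves |v| untouched
// otherwise, so an `=` inside an attribute body is never taken as a prefix.
static std::string_view consumeName(std::string_view &v) {
  v = stripWhitespace(v);
  std::size_t eq = v.find('=');
  if (eq == std::string_view::npos) return {};
  for (std::size_t i = 0; i < eq; ++i) {
    char c = v[i];
    if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_') return {};
  }
  std::string_view name = v.substr(0, eq);
  v = stripWhitespace(v.substr(eq + 1));
  return name;
}

struct AliasSpec {
  std::string deviceID;
  std::optional<std::int64_t> ordinal;
};

struct ParsedSpec {
  std::string name;                 // empty when there was no `name=` prefix
  std::vector<AliasSpec> elements;  // more than one element means a select
};

// Parses `[name=]id[N](,id[N])*`, returning std::nullopt on malformed input.
static std::optional<ParsedSpec> parseShorthand(std::string_view spec) {
  ParsedSpec result;
  result.name = std::string(consumeName(spec));
  while (true) {
    std::size_t comma = spec.find(',');
    std::string_view part = stripWhitespace(spec.substr(0, comma));
    std::size_t bracket = part.find('[');
    std::string_view id = stripWhitespace(part.substr(0, bracket));
    if (id.size() >= 2 && id.front() == '"' && id.back() == '"')
      id = stripWhitespace(id.substr(1, id.size() - 2));  // allow "id"[N]
    if (id.empty()) return std::nullopt;  // extra check of this sketch only

    AliasSpec alias{std::string(id), std::nullopt};
    if (bracket != std::string_view::npos) {
      // Like consumeAliasLiteral, read leading digits and ignore the `]`.
      std::int64_t value = 0;
      bool sawDigit = false;
      for (std::size_t i = bracket + 1;
           i < part.size() && std::isdigit(static_cast<unsigned char>(part[i]));
           ++i) {
        if (value > (INT64_MAX - 9) / 10) return std::nullopt;  // overflow
        value = value * 10 + (part[i] - '0');
        sawDigit = true;
      }
      if (sawDigit) alias.ordinal = value;
    }
    result.elements.push_back(alias);
    if (comma == std::string_view::npos) break;
    spec = spec.substr(comma + 1);
  }
  return result;
}

// Mirrors the pass's rule that either every spec is named or none is.
static bool namingIsConsistent(const std::vector<ParsedSpec> &specs) {
  std::size_t named = 0;
  for (const auto &s : specs) named += s.name.empty() ? 0 : 1;
  return named == 0 || named == specs.size();
}
```

Like `consumeAliasLiteral`, this sketch does not validate the closing `]`. Unlike the pass, it also rejects empty device IDs, which is an extra check of its own.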
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index 33e9561..e4c7cc5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -15,6 +15,7 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
+ "AssignLegacyTargetDevices.cpp",
"AssignTargetDevices.cpp",
"CaptureExecutableSources.cpp",
"ConfigureExecutables.cpp",
@@ -23,22 +24,26 @@
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"FixupLegacySync.cpp",
+ "InitializeDevices.cpp",
"LinkExecutables.cpp",
"MaterializeDispatchInstrumentation.cpp",
"MaterializeInterfaces.cpp",
"MaterializeResourceCaches.cpp",
+ "MaterializeTargetDevices.cpp",
"MemoizeDeviceQueries.cpp",
"Passes.cpp",
"Passes.h.inc",
"PreprocessExecutables.cpp",
"PruneExecutables.cpp",
"RepeatDispatches.cpp",
+ "ResolveDeviceAliases.cpp",
+ "ResolveDevicePromises.cpp",
"ResolveExportOrdinals.cpp",
"SerializeExecutables.cpp",
"StripExecutableContents.cpp",
"SubstituteExecutables.cpp",
"TranslateExecutables.cpp",
- "VerifyTargetEnvironment.cpp",
+ "VerifyDevices.cpp",
],
hdrs = [
"Passes.h",
@@ -71,6 +76,7 @@
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:AffineTransforms",
"@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 6cce442..9af3134 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@
HDRS
"Passes.h"
SRCS
+ "AssignLegacyTargetDevices.cpp"
"AssignTargetDevices.cpp"
"CaptureExecutableSources.cpp"
"ConfigureExecutables.cpp"
@@ -24,28 +25,33 @@
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"FixupLegacySync.cpp"
+ "InitializeDevices.cpp"
"LinkExecutables.cpp"
"MaterializeDispatchInstrumentation.cpp"
"MaterializeInterfaces.cpp"
"MaterializeResourceCaches.cpp"
+ "MaterializeTargetDevices.cpp"
"MemoizeDeviceQueries.cpp"
"Passes.cpp"
"Passes.h.inc"
"PreprocessExecutables.cpp"
"PruneExecutables.cpp"
"RepeatDispatches.cpp"
+ "ResolveDeviceAliases.cpp"
+ "ResolveDevicePromises.cpp"
"ResolveExportOrdinals.cpp"
"SerializeExecutables.cpp"
"StripExecutableContents.cpp"
"SubstituteExecutables.cpp"
"TranslateExecutables.cpp"
- "VerifyTargetEnvironment.cpp"
+ "VerifyDevices.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRAffineToStandard
MLIRAffineTransforms
MLIRArithDialect
+ MLIRAsmParser
MLIRBufferizationDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
index 06d2366..eedb427 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -49,6 +49,7 @@
: public IREE::HAL::impl::ConvertToHALPassBase<ConvertToHALPass> {
void runOnOperation() override {
auto *context = &getContext();
+ auto moduleOp = getOperation();
// Gather all interfaces from registered dialects.
// These will perform the tensor->buffer mapping for their ops.
@@ -64,8 +65,7 @@
HALTypeConverter typeConverter(conversionInterfaces);
HALConversionTarget conversionTarget(context, typeConverter);
- RewritePatternSet patterns(&getContext());
-
+ RewritePatternSet patterns(context);
populateHALToHALPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilToHALPatterns(context, conversionTarget, typeConverter,
@@ -84,13 +84,14 @@
// NOTE: we allow ops that we don't know about to allow custom dialects
// that don't need anything HAL-specific to pass through.
- if (failed(applyPartialConversion(getOperation(), conversionTarget,
+ if (failed(applyPartialConversion(moduleOp, conversionTarget,
std::move(patterns)))) {
return signalPassFailure();
}
// Cleanup conversion attributes used for spooky action at a distance.
- for (auto executableOp : getOperation().getOps<IREE::HAL::ExecutableOp>()) {
+ moduleOp->removeAttr("stream.affinity.default");
+ for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
for (auto exportOp :
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index c2487b8..7ad1f4a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -7,6 +7,7 @@
#include <memory>
#include <utility>
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
@@ -74,7 +75,14 @@
for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp);
+ auto affinityAttr = dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(
+ IREE::Stream::AffinityAttr::lookup(dispatchOp));
+ if (!affinityAttr) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "skipping dispatch because it has no affinity specified\n");
+ return;
+ }
auto workloadValues = dispatchOp.getWorkload();
SmallVector<unsigned> workload;
@@ -84,7 +92,7 @@
if (!matchPattern(workloadValue, m_ConstantInt(&workloadConstValue))) {
LLVM_DEBUG({
auto firstEntryPoint = *dispatchOp.getEntryPointRefs().begin();
- llvm::dbgs() << "Skipping dispatch of entry point `"
+ llvm::dbgs() << "skipping dispatch of entry point `"
<< firstEntryPoint << "` (non-constant workload)\n";
});
return;
@@ -123,7 +131,7 @@
APInt resourceLengthInt;
if (!matchPattern(resourceLength,
m_ConstantInt(&resourceLengthInt))) {
- LLVM_DEBUG(llvm::dbgs() << "Skipping dispatch of entry point `"
+ LLVM_DEBUG(llvm::dbgs() << "skipping dispatch of entry point `"
<< entryPointAttr
<< "` (non-constant resource length)\n";);
return;
@@ -157,12 +165,30 @@
return map;
}
+static std::pair<Value, Value>
+getDeviceAndQueueAffinity(Location loc, IREE::Stream::AffinityAttr affinityAttr,
+ OpBuilder &builder) {
+ if (auto deviceAffinityAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) {
+ auto resolveOp = builder.create<IREE::HAL::DeviceResolveOp>(
+ loc,
+ TypeRange{
+ builder.getType<IREE::HAL::DeviceType>(),
+ builder.getI64Type(),
+ },
+ deviceAffinityAttr);
+ return std::make_pair(resolveOp.getResult(0), resolveOp.getResult(1));
+ }
+ auto device = IREE::HAL::DeviceType::resolveAny(loc, builder);
+ auto queueAffinity = builder.create<arith::ConstantIntOp>(loc, -1, 64);
+ return std::make_pair(device, queueAffinity);
+}
+
// Appends a global hal.buffer initialized to the size required for all
// of the bindings in |dispatchParams| (plus alignment).
-static IREE::Util::GlobalOp
-appendGlobalBuffer(Location loc, StringRef baseName,
- const DispatchParams &dispatchParams,
- OpBuilder &moduleBuilder) {
+static IREE::Util::GlobalOp appendGlobalBuffer(
+ Location loc, StringRef baseName, const DispatchParams &dispatchParams,
+ IREE::Stream::AffinityAttr affinityAttr, OpBuilder &moduleBuilder) {
// Create a global to hold the HAL buffer.
auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
loc, (baseName + "_buffer").str(),
@@ -183,12 +209,12 @@
auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock());
IndexSet indexSet(loc, initBuilder);
- // TODO(multi-device): support multiple devices in benchmark generation.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder);
+ // Resolve allocator for the benchmark device.
+ auto [device, queueAffinity] =
+ getDeviceAndQueueAffinity(loc, affinityAttr, initBuilder);
auto allocator =
initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).getResult();
- auto queueAffinity = initBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal;
auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer |
IREE::HAL::BufferUsageBitfield::DispatchStorage;
@@ -226,8 +252,8 @@
}
// Add a global variable holding an initialized buffer for the dispatch IO.
- auto bufferGlobalOp =
- appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder);
+ auto bufferGlobalOp = appendGlobalBuffer(loc, baseName, dispatchParams,
+ affinityAttr, moduleBuilder);
// Create an exported benchmark function that runs the dispatches.
auto funcType =
@@ -253,10 +279,9 @@
auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
- // TODO(multi-device): support multiple devices in benchmark generation.
- // For now we should just use the affinityAttr to resolve the device.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder);
- Value queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
+ // Resolve device for this particular benchmark.
+ auto [device, queueAffinity] =
+ getDeviceAndQueueAffinity(loc, affinityAttr, funcBuilder);
// Create and begin command buffer.
// TODO(benvanik): reuse the command buffer (initialize once and store).
@@ -402,19 +427,22 @@
static mlir::OwningOpRef<mlir::ModuleOp>
buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp,
IREE::HAL::ExecutableVariantOp sourceVariantOp,
- const DispatchParamsMap &dispatchParamsMap) {
+ const DispatchParamsMap &dispatchParamsMap,
+ DeviceAnalysis &deviceAnalysis) {
// Empty module with default name.
// We could use the original module name here to make tracking nicer.
mlir::OwningOpRef<mlir::ModuleOp> moduleOp =
mlir::ModuleOp::create(sourceExecutableOp.getLoc());
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp->getBody());
- // Copy over the device targets from the original module.
- // TODO(benvanik): filter this by the target of the variant.
- moduleOp->getOperation()->setAttr(
- "hal.device.targets",
- sourceExecutableOp->getParentOfType<mlir::ModuleOp>()->getAttr(
- "hal.device.targets"));
+ // Copy over the devices from the original module. Not all of the devices
+ // may be used and we should prune them. Even better would be to generate
+ // one module per device that dispatches are made on so that users can
+ // isolate individual devices. For now all device globals are cloned
+ // unconditionally.
+ for (auto globalOp : deviceAnalysis.getDeviceGlobals()) {
+ moduleBuilder.clone(*globalOp.getOperation());
+ }
// Clone the executable variant into the new module.
auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
@@ -489,6 +517,21 @@
auto moduleName = moduleOp.getName().value_or("module");
SymbolTable symbolTable(moduleOp);
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run()))
+ return signalPassFailure();
+ if (deviceAnalysis.getDeviceGlobals().empty()) {
+ mlir::emitRemark(moduleOp.getLoc())
+ << "Executable benchmarks were requested but no devices were "
+ "declared in the module.\n";
+ return;
+ } else if (deviceAnalysis.getDeviceGlobals().size() != 1) {
+ mlir::emitWarning(moduleOp.getLoc())
+ << "Executable benchmarks were requested but there are multiple "
+ "devices in the module and the pass does not support that yet.\n";
+ return;
+ }
+
// Analyze the module to find dispatch parameters.
// This is a full walk of all stream.cmd.dispatch ops and will handle
// filtering out dispatches that have dynamic parameters we don't
@@ -511,8 +554,8 @@
for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
- auto benchmarkModuleOp =
- buildBenchmarkModule(executableOp, variantOp, dispatchParamsMap);
+ auto benchmarkModuleOp = buildBenchmarkModule(
+ executableOp, variantOp, dispatchParamsMap, deviceAnalysis);
if (!benchmarkModuleOp)
continue;
auto fileName = (moduleName + "_" + executableOp.getName() + "_" +
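With this change, the benchmark dump only runs when exactly one device global exists. It emits a remark when there are none and a warning when there are several. Each benchmark then resolves its device and queue affinity from the dispatch's affinity, falling back to any device with queue affinity `-1` (the value the old code always used) when no device affinity is present.

Below is a hypothetical, MLIR-free model of those two decisions. The enum, the structs, and the `"<any>"` placeholder are inventions of this sketch. The real code builds an `IREE::HAL::DeviceResolveOp` rather than returning values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical, MLIR-free model of the benchmark pass's device handling.
enum class BenchmarkGate { kNoDevices, kMultipleDevices, kRun };

// Mirrors the checks on deviceAnalysis.getDeviceGlobals().size().
BenchmarkGate gateOnDeviceCount(std::size_t deviceGlobalCount) {
  if (deviceGlobalCount == 0) return BenchmarkGate::kNoDevices;  // remark
  if (deviceGlobalCount != 1) return BenchmarkGate::kMultipleDevices;  // warning
  return BenchmarkGate::kRun;
}

struct DeviceAffinity {
  std::string deviceGlobal;  // e.g. the symbol of a device global
  std::int64_t queueMask;    // bitmask of queues the work may run on
};

struct ResolvedDevice {
  std::string device;  // "<any>" stands in for resolving any available device
  std::int64_t queueAffinity;
};

// Mirrors getDeviceAndQueueAffinity: use the affinity when present, otherwise
// fall back to any device with queue affinity -1 (all bits set).
ResolvedDevice resolveDeviceAndQueue(
    const std::optional<DeviceAffinity> &affinity) {
  if (affinity) return {affinity->deviceGlobal, affinity->queueMask};
  return {"<any>", -1};
}
```

The dispatch walk now skips dispatches that lack a `DeviceAffinityAttr`, so the fallback branch appears to be mostly defensive.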
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp
index 45bf830..7f6a18e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
@@ -149,12 +150,15 @@
void runOnOperation() override {
auto moduleOp = getOperation();
- // See if any devices are marked as requiring the legacy_sync behavior.
- // If any single device does we must uniformly apply the fixups.
- if (!IREE::HAL::DeviceTargetAttr::lookupConfigAttrAny(moduleOp,
- "legacy_sync")) {
- return;
- }
+ // Analyze the module to determine which devices need the behavior.
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run()))
+ return signalPassFailure();
+ auto isLegacySync = [&](Value deviceValue) {
+ auto deviceSet = deviceAnalysis.lookupDeviceTargets(deviceValue);
+ return deviceSet.has_value() ? deviceSet->hasConfigAttrAny("legacy_sync")
+ : false;
+ };
// This could use an interface but it'd be better to remove the need for
// this pass instead.
@@ -162,19 +166,39 @@
funcOp.walk([&](Operation *op) {
TypeSwitch<Operation *, void>(op)
.Case([&](IREE::HAL::CommandBufferCreateOp op) {
- makeAllowInlineExecution(op);
+ if (isLegacySync(op.getDevice())) {
+ makeAllowInlineExecution(op);
+ }
})
.Case([&](IREE::HAL::DeviceQueueAllocaOp op) {
- insertWaitIfNeeded(op, op.getWaitFenceMutable(),
- op.getSignalFence());
+ if (isLegacySync(op.getDevice())) {
+ insertWaitIfNeeded(op, op.getWaitFenceMutable(),
+ op.getSignalFence());
+ }
})
.Case([&](IREE::HAL::DeviceQueueDeallocaOp op) {
- insertWaitIfNeeded(op, op.getWaitFenceMutable(),
- op.getSignalFence());
+ if (isLegacySync(op.getDevice())) {
+ insertWaitIfNeeded(op, op.getWaitFenceMutable(),
+ op.getSignalFence());
+ }
+ })
+ .Case([&](IREE::HAL::DeviceQueueReadOp op) {
+ if (isLegacySync(op.getDevice())) {
+ insertWaitIfNeeded(op, op.getWaitFenceMutable(),
+ op.getSignalFence());
+ }
+ })
+ .Case([&](IREE::HAL::DeviceQueueWriteOp op) {
+ if (isLegacySync(op.getDevice())) {
+ insertWaitIfNeeded(op, op.getWaitFenceMutable(),
+ op.getSignalFence());
+ }
})
.Case([&](IREE::HAL::DeviceQueueExecuteOp op) {
- insertWaitIfNeeded(op, op.getWaitFenceMutable(),
- op.getSignalFence());
+ if (isLegacySync(op.getDevice())) {
+ insertWaitIfNeeded(op, op.getWaitFenceMutable(),
+ op.getSignalFence());
+ }
});
});
}
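The hunk above changes `FixupLegacySync` from "apply everywhere if any device in the module has `legacy_sync`" to a per-op decision based on the device each op targets. The following is a minimal, MLIR-free sketch of that decision, not the pass itself. It assumes devices can be modeled by their global names and each possible target by a set of config keys. `DeviceTargets` is a hypothetical stand-in for `DeviceAnalysis::lookupDeviceTargets`, and `QueueOp` and `fixupLegacySync` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical: config keys of each target a !hal.device value may resolve to.
using TargetConfigs = std::vector<std::set<std::string>>;

// Hypothetical stand-in for DeviceAnalysis: device global name -> possible
// targets. A missing entry models a device the analysis could not resolve.
using DeviceTargets = std::map<std::string, TargetConfigs>;

// Mirrors the isLegacySync lambda: true only if the device resolves and at
// least one of its possible targets carries the "legacy_sync" config key.
bool isLegacySync(const DeviceTargets &analysis, const std::string &device) {
  auto it = analysis.find(device);
  if (it == analysis.end()) return false;  // unresolved device: no fixup
  for (const auto &config : it->second) {
    if (config.count("legacy_sync")) return true;
  }
  return false;
}

// Hypothetical queue op: the device it targets and whether a wait was added.
struct QueueOp {
  std::string device;
  bool waitInserted = false;
};

// Applies the fixup only to ops whose own device needs it, rather than to
// every op in the module whenever any single device needs it.
void fixupLegacySync(const DeviceTargets &analysis, std::vector<QueueOp> &ops) {
  for (auto &op : ops) {
    if (isLegacySync(analysis, op.device)) op.waitInserted = true;
  }
}
```

With this shape, a module mixing a `legacy_sync` device with an asynchronous one keeps the asynchronous device's queue ops untouched.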
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp
new file mode 100644
index 0000000..fae26c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp
@@ -0,0 +1,111 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_INITIALIZEDEVICESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+// Converts an initialized device global to one with a util.initializer that
+// performs the device initialization. The initializer is added immediately
+// following the global in its parent op.
+static void initializeDeviceGlobal(
+ IREE::Util::GlobalOpInterface globalOp,
+ IREE::HAL::DeviceInitializationAttrInterface initialValue,
+ const IREE::HAL::TargetRegistry &targetRegistry) {
+ auto loc = globalOp.getLoc();
+
+ // Clear the initial value as we'll be initializing from the initializer.
+ globalOp.setGlobalInitialValue({});
+
+ // Build a new util.initializer.
+ OpBuilder moduleBuilder(globalOp);
+ moduleBuilder.setInsertionPointAfter(globalOp);
+ auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
+ auto *block = moduleBuilder.createBlock(&initializerOp.getBody());
+ auto initializerBuilder = OpBuilder::atBlockBegin(block);
+
+ // Get the device from the attribute builder; note that it may be null.
+ Value enumeratedDevice = initialValue.buildDeviceEnumeration(
+ loc,
+ [&](Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder) {
+ auto targetDevice =
+ targetRegistry.getTargetDevice(targetAttr.getDeviceID());
+ return targetDevice ? targetDevice->buildDeviceTargetMatch(
+ loc, device, targetAttr, builder)
+ : Value{};
+ },
+ initializerBuilder);
+
+ // Check if the device is null and error out. We could support optional
+ // devices that are allowed to be null, but we don't support that anywhere
+ // else in the compiler today and may never want to. When selecting from
+ // multiple devices, queries can be used to detect which device was selected
+ // and those will be memoized.
+ Value nullDevice = initializerBuilder.create<IREE::Util::NullOp>(
+ loc, enumeratedDevice.getType());
+ Value isNull = initializerBuilder.create<IREE::Util::CmpEQOp>(
+ loc, enumeratedDevice, nullDevice);
+ initializerBuilder.create<scf::IfOp>(
+ loc, isNull, [&](OpBuilder &thenBuilder, Location thenLoc) {
+ Value status = thenBuilder.create<arith::ConstantIntOp>(
+ thenLoc, static_cast<int64_t>(IREE::Util::StatusCode::NotFound),
+ 32);
+ std::string str;
+ {
+ llvm::raw_string_ostream os(str);
+ os << "HAL device `" << globalOp.getGlobalName().getValue()
+ << "` not found or unavailable: ";
+ initialValue.printStatusDescription(os);
+ }
+ thenBuilder.create<IREE::Util::StatusCheckOkOp>(thenLoc, status, str);
+ thenBuilder.create<scf::YieldOp>(thenLoc);
+ });
+
+ // Store the device back to the global to complete initialization.
+ globalOp.createStoreOp(loc, enumeratedDevice, initializerBuilder);
+ initializerBuilder.create<IREE::Util::ReturnOp>(loc);
+}
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-initialize-devices
+//===----------------------------------------------------------------------===//
+
+struct InitializeDevicesPass
+ : public IREE::HAL::impl::InitializeDevicesPassBase<InitializeDevicesPass> {
+ using IREE::HAL::impl::InitializeDevicesPassBase<
+ InitializeDevicesPass>::InitializeDevicesPassBase;
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
+ auto initialValue =
+ dyn_cast_if_present<IREE::HAL::DeviceInitializationAttrInterface>(
+ globalOp.getGlobalInitialValue());
+ if (initialValue) {
+ initializeDeviceGlobal(globalOp, initialValue, *targetRegistry.value);
+ }
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
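`InitializeDevicesPass` replaces a device global's initial value with a `util.initializer` that enumerates devices, checks for null, and fails with a `NotFound` status. The sketch below models what that generated initializer does at runtime, in plain C++. It assumes candidate target IDs are tried in order and that a device matches on an exact ID comparison. The real matching is delegated to each `TargetDevice::buildDeviceTargetMatch` and is not modeled here. `AvailableDevice`, `InitResult`, `initializeDeviceGlobal`, the placeholder IDs, and the comma-joined candidate list (standing in for `printStatusDescription`) are all hypothetical.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical runtime view of a device: just an identifying string.
struct AvailableDevice {
  std::string id;
};

// Hypothetical result of the initializer: the selected device index, or the
// message the util.status.check_ok would report when the device is null.
struct InitResult {
  std::optional<size_t> deviceIndex;
  std::string error;
};

// Tries each candidate in order and selects the first matching device. If
// nothing matches, the enumerated device is "null" and we produce a NotFound
// style message shaped like the one the pass emits.
InitResult initializeDeviceGlobal(const std::string &globalName,
                                  const std::vector<std::string> &candidates,
                                  const std::vector<AvailableDevice> &devices) {
  for (const auto &candidate : candidates) {
    for (size_t i = 0; i < devices.size(); ++i) {
      if (devices[i].id == candidate) return {i, ""};
    }
  }
  std::string msg =
      "HAL device `" + globalName + "` not found or unavailable: ";
  // Hypothetical stand-in for printStatusDescription(): list the candidates.
  for (size_t i = 0; i < candidates.size(); ++i) {
    msg += (i ? ", " : "") + candidates[i];
  }
  return {std::nullopt, msg};
}
```

Failing loudly in the initializer matches the comment in the pass: optional (nullable) devices are deliberately not supported elsewhere in the compiler.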
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 1f5eb15..221f754 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Dialect/HAL/Analysis/BindingLayout.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
@@ -43,64 +44,98 @@
Attribute,
SmallVector<std::pair<Attribute, IREE::HAL::ExecutableTargetAttr>>>;
+// Map of operations (executables, dispatches, etc.) to the executable targets
+// required by those operations based on usage. If missing or empty, the
+// default set should be used.
+using RequiredExecutableTargets =
+ DenseMap<Operation *, SetVector<IREE::HAL::ExecutableTargetAttr>>;
+
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
+static SymbolRefAttr
+makeExportSymbolRefAttr(IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableVariantOp variantOp,
+ IREE::HAL::ExecutableExportOp exportOp) {
+ return SymbolRefAttr::get(executableOp.getNameAttr(),
+ {
+ FlatSymbolRefAttr::get(variantOp.getNameAttr()),
+ FlatSymbolRefAttr::get(exportOp.getNameAttr()),
+ });
+}
+
static void setApplicableObjects(Operation *sourceOp,
IREE::HAL::ExecutableVariantOp targetOp) {
auto objectsAttr = sourceOp->getAttrOfType<IREE::HAL::ExecutableObjectsAttr>(
"hal.executable.objects");
- if (!objectsAttr)
+ if (!objectsAttr) {
return;
+ }
auto objects = objectsAttr.getApplicableObjects(targetOp.getTarget());
- if (!objects)
+ if (!objects) {
return;
+ }
targetOp.setObjectsAttr(*objects);
}
-// Returns a set of executable targets required by any dispatch to the given
-// executable. Not all exports may be dispatched on the targets.
-// If the |executableOp| is public then targets specified on the module will be
-// used in addition to any from the dispatches.
-template <typename OpT>
-static SmallVector<IREE::HAL::ExecutableTargetAttr>
-gatherExecutableTargetAttrs(SymbolOpInterface executableOp,
- llvm::iterator_range<OpT> exportOps,
- const BindingLayoutAnalysis &layoutAnalysis) {
- llvm::SetVector<IREE::HAL::ExecutableTargetAttr,
- SmallVector<IREE::HAL::ExecutableTargetAttr>>
- targetAttrsSet;
- if (executableOp.isPublic()) {
- for (auto targetAttr :
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(executableOp)) {
- targetAttrsSet.insert(targetAttr);
- }
- }
- for (auto exportOp : exportOps) {
- for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
- for (auto targetAttr :
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) {
- targetAttrsSet.insert(targetAttr);
+template <typename ExecutableOpT, typename ExportOpT>
+static void
+buildRequiredExecutableTypeTargetsMap(ModuleOp moduleOp,
+ DeviceAnalysis &deviceAnalysis,
+ BindingLayoutAnalysis &layoutAnalysis,
+ RequiredExecutableTargets &resultMap) {
+ // NOTE: we populate every entry before processing so that references into
+ // the map stay stable (inserting into a DenseMap may rehash).
+ for (auto executableOp : moduleOp.template getOps<ExecutableOpT>()) {
+ (void)resultMap[executableOp];
+ for (auto exportOp : executableOp.template getOps<ExportOpT>()) {
+ for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+ (void)resultMap[dispatchOp];
}
}
}
- auto targetAttrs = targetAttrsSet.takeVector();
- llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) {
- return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment();
- });
- return targetAttrs;
+ for (auto executableOp : moduleOp.template getOps<ExecutableOpT>()) {
+ auto &executableTargetAttrs = resultMap[executableOp];
+ for (auto exportOp : executableOp.template getOps<ExportOpT>()) {
+ for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+ auto &dispatchTargetAttrs = resultMap[dispatchOp];
+ deviceAnalysis.gatherRequiredExecutableTargets(dispatchOp,
+ dispatchTargetAttrs);
+ executableTargetAttrs.insert(dispatchTargetAttrs.begin(),
+ dispatchTargetAttrs.end());
+ }
+ }
+ if (executableOp.isPublic()) {
+ // Public executables need all possible targets.
+ deviceAnalysis.gatherAllExecutableTargets(executableTargetAttrs);
+ }
+ }
+}
+
+// Builds a map of executable and dispatch ops to the executable targets that
+// may be required.
+static RequiredExecutableTargets
+buildRequiredExecutableTargetsMap(ModuleOp moduleOp,
+ DeviceAnalysis &deviceAnalysis,
+ BindingLayoutAnalysis &layoutAnalysis) {
+ RequiredExecutableTargets resultMap;
+ buildRequiredExecutableTypeTargetsMap<IREE::HAL::ExecutableSourceOp,
+ IREE::HAL::ExecutableExportOp>(
+ moduleOp, deviceAnalysis, layoutAnalysis, resultMap);
+ buildRequiredExecutableTypeTargetsMap<IREE::Stream::ExecutableOp,
+ IREE::Stream::ExecutableExportOp>(
+ moduleOp, deviceAnalysis, layoutAnalysis, resultMap);
+ return resultMap;
}
// Updates the target entry point symbols of |dispatchOp| to the expanded set of
// variant exports in |exportExpansions|.
-static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
- const ExportExpansions &exportExpansions) {
- DenseSet<IREE::HAL::ExecutableTargetAttr> requiredTargetAttrs;
- for (auto targetAttr :
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) {
- requiredTargetAttrs.insert(targetAttr);
- }
+static void
+updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp,
+ const ExportExpansions &exportExpansions,
+ RequiredExecutableTargets &requiredExecutableTargets) {
+ auto &requiredTargetAttrs = requiredExecutableTargets[dispatchOp];
SmallVector<Attribute> newAttrs;
for (auto oldAttr : dispatchOp.getEntryPointRefs()) {
auto it = exportExpansions.find(oldAttr);
@@ -109,9 +144,13 @@
continue;
}
for (auto [newAttr, targetAttr] : it->second) {
- // Filter the new expansions to only those used by the dispatch.
- if (requiredTargetAttrs.contains(targetAttr))
+ // Filter the new expansions to only those used by the dispatch; an empty
+ // filter means no requirements were found, so all expansions are kept.
+ if (requiredTargetAttrs.empty()) {
newAttrs.push_back(newAttr);
+ } else if (requiredTargetAttrs.contains(targetAttr)) {
+ newAttrs.push_back(newAttr);
+ }
}
}
dispatchOp.setEntryPointsAttr(
@@ -122,26 +161,20 @@
// hal.executable.source materialization
//===----------------------------------------------------------------------===//
-SymbolRefAttr makeExportSymbolRefAttr(IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableVariantOp variantOp,
- IREE::HAL::ExecutableExportOp exportOp) {
- return SymbolRefAttr::get(executableOp.getNameAttr(),
- {
- FlatSymbolRefAttr::get(variantOp.getNameAttr()),
- FlatSymbolRefAttr::get(exportOp.getNameAttr()),
- });
-}
-
-static void
-materializeExecutableFromSourceOp(IREE::HAL::ExecutableSourceOp sourceOp,
- BindingLayoutAnalysis &layoutAnalysis) {
+static void materializeExecutableFromSourceOp(
+ IREE::HAL::ExecutableSourceOp sourceOp,
+ BindingLayoutAnalysis &layoutAnalysis,
+ RequiredExecutableTargets &requiredExecutableTargets) {
// Gather the required executable targets based on the dispatches to exports
// in the source op.
- auto targetAttrs = gatherExecutableTargetAttrs(
- sourceOp, sourceOp.getOps<IREE::HAL::ExecutableExportOp>(),
- layoutAnalysis);
- if (targetAttrs.empty())
+ SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs(
+ requiredExecutableTargets[sourceOp].getArrayRef());
+ if (targetAttrs.empty()) {
return;
+ }
+ llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) {
+ return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment();
+ });
// Create the op that will contain the translated executable.
OpBuilder moduleBuilder(sourceOp);
@@ -180,8 +213,9 @@
// Clone any target-specific object files specified.
if (auto objectsAttr = sourceOp.getObjectsAttr()) {
auto objects = objectsAttr.getApplicableObjects(targetAttr);
- if (objects)
+ if (objects) {
targetVariantOp.setObjectsAttr(*objects);
+ }
}
// Clone inner module contents.
@@ -194,7 +228,8 @@
// Update all dispatch sites to reference the new expanded variants.
for (auto exportOp : sourceExportOps) {
for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
- updateDispatchTargets(dispatchOp, exportExpansions);
+ updateDispatchTargets(dispatchOp, exportExpansions,
+ requiredExecutableTargets);
}
}
@@ -312,8 +347,9 @@
}
unsigned resourceIdx = 0;
for (auto arg : entryBlock->getArguments()) {
- if (!llvm::isa<IREE::Stream::BindingType>(arg.getType()))
- continue;
+ if (!llvm::isa<IREE::Stream::BindingType>(arg.getType())) {
+ continue; // unhandled arg type (primitive/etc)
+ }
auto setBinding = resourceMap[resourceIdx++];
auto setLayoutAttr = layoutAttr.getSetLayouts()[setBinding.first];
auto bindingAttr = setLayoutAttr.getBindings()[setBinding.second];
@@ -332,7 +368,8 @@
static LogicalResult
declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
IREE::HAL::ExecutableOp targetExecutableOp,
- const BindingLayoutAnalysis &layoutAnalysis) {
+ const BindingLayoutAnalysis &layoutAnalysis,
+ RequiredExecutableTargets &requiredExecutableTargets) {
auto variantOps =
targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
@@ -348,8 +385,9 @@
if (auto sourceModuleOp = sourceExecutableOp.getInnerModule()) {
sourceFuncOp = sourceModuleOp.lookupSymbol<mlir::func::FuncOp>(
exportOp.getFunctionRef());
- if (failed(verifyEntryPointTypes(sourceFuncOp)))
+ if (failed(verifyEntryPointTypes(sourceFuncOp))) {
return failure();
+ }
}
// Lookup to see if a layout was specified already. If not we'll perform
@@ -441,7 +479,8 @@
// Update all dispatch sites to reference the new expanded variants.
for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
- updateDispatchTargets(dispatchOp, exportExpansions);
+ updateDispatchTargets(dispatchOp, exportExpansions,
+ requiredExecutableTargets);
}
}
@@ -513,8 +552,9 @@
assert(exportOp &&
"must have an entry point corresponding to the parent func");
auto workgroupSizeAttr = exportOp.getWorkgroupSizeAttr();
- if (!workgroupSizeAttr)
+ if (!workgroupSizeAttr) {
return failure();
+ }
uint64_t dimIdx = sizeOp.getDimension().getZExtValue();
auto dimAttr = workgroupSizeAttr[dimIdx];
@@ -554,11 +594,34 @@
SymbolTable symbolTable(moduleOp);
BindingLayoutAnalysis layoutAnalysis(moduleOp, symbolTable);
+ // Run required analysis passes.
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // If no devices were defined and there are dispatches in the program then
+ // error out. This provides a better error message than if we were to allow
+ // this pass to no-op and then fail during conversion later on.
+ if (layoutAnalysis.hasDispatches() &&
+ deviceAnalysis.getDeviceGlobals().empty()) {
+ mlir::emitError(moduleOp.getLoc())
+ << "no HAL devices defined in the module; use the module-level "
+ "hal.device.targets attribute, the --iree-hal-target-device= "
+ "flag, or provide inputs with global !hal.devices defined";
+ return signalPassFailure();
+ }
+
+ // Gather the required executable targets per executable and dispatch site.
+ auto requiredExecutableTargets = buildRequiredExecutableTargetsMap(
+ moduleOp, deviceAnalysis, layoutAnalysis);
+
// Handle any hand-authored executables; these only need variant expansion
// and no layout analysis as the user specified the layout themselves.
for (auto sourceOp : llvm::make_early_inc_range(
moduleOp.getOps<IREE::HAL::ExecutableSourceOp>())) {
- materializeExecutableFromSourceOp(sourceOp, layoutAnalysis);
+ materializeExecutableFromSourceOp(sourceOp, layoutAnalysis,
+ requiredExecutableTargets);
}
// Processes all executables within the input module and produce the
@@ -568,17 +631,22 @@
for (auto sourceOp : llvm::make_early_inc_range(
moduleOp.getOps<IREE::Stream::ExecutableOp>())) {
auto exportOps = sourceOp.getOps<IREE::Stream::ExecutableExportOp>();
- if (exportOps.empty())
+ if (exportOps.empty()) {
continue;
+ }
// Gather a list of all #hal.executable.targets that we should produce
// variants for based on the dispatches performed. Not all exports may be
// used on any particular target but we let future DCE/pruning passes
// remove them instead of modifying the inner modules here.
- auto targetAttrs =
- gatherExecutableTargetAttrs(sourceOp, exportOps, layoutAnalysis);
- if (targetAttrs.empty())
- continue;
+ SmallVector<IREE::HAL::ExecutableTargetAttr> targetAttrs(
+ requiredExecutableTargets[sourceOp].getArrayRef());
+ if (targetAttrs.empty()) {
+ continue;
+ }
+ llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) {
+ return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment();
+ });
// Create the op that will contain the translated executable.
OpBuilder builder = OpBuilder::atBlockEnd(moduleOp.getBody());
@@ -605,8 +673,8 @@
}
// Define interfaces for each exported function based on analysis.
- if (failed(
- declareEntryPointOps(sourceOp, executableOp, layoutAnalysis))) {
+ if (failed(declareEntryPointOps(sourceOp, executableOp, layoutAnalysis,
+ requiredExecutableTargets))) {
return signalPassFailure();
}
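`MaterializeInterfaces` now computes a `RequiredExecutableTargets` map up front: each dispatch site gets the targets its device may require, each executable gets the union over its dispatch sites, and public executables get all targets. `updateDispatchTargets` then keeps only the variant exports whose target the dispatch requires, keeping all of them when the set is empty. The sketch below models both steps with standard containers. It assumes targets can be identified by their symbol name fragment. `Executable`, `buildRequiredTargets`, `filterExpansions`, and the target names are hypothetical. A `std::set` keeps targets sorted by name, which plays the role of the `stable_sort` on `getSymbolNameFragment()`. Because `std::map` nodes never move, the pre-population step that the `DenseMap` version needs is unnecessary here.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical: targets named by symbol fragment (e.g. "vulkan_spirv_fb").
using TargetSet = std::set<std::string>;

// Hypothetical executable: its name, visibility, and dispatch sites using it.
struct Executable {
  std::string name;
  bool isPublic = false;
  std::vector<std::string> dispatchSites;
};

// Records each dispatch site's required targets, unions them into the
// executable, and gives public executables every known target.
std::map<std::string, TargetSet> buildRequiredTargets(
    const std::vector<Executable> &executables,
    const std::map<std::string, TargetSet> &dispatchTargets,
    const TargetSet &allTargets) {
  std::map<std::string, TargetSet> result;
  for (const auto &executable : executables) {
    auto &executableTargets = result[executable.name];
    for (const auto &site : executable.dispatchSites) {
      auto it = dispatchTargets.find(site);
      if (it == dispatchTargets.end()) continue;
      result[site] = it->second;
      executableTargets.insert(it->second.begin(), it->second.end());
    }
    if (executable.isPublic) {
      executableTargets.insert(allTargets.begin(), allTargets.end());
    }
  }
  return result;
}

// Mirrors updateDispatchTargets: keep expansions whose target is required,
// or all of them when the required set is empty (no filter available).
std::vector<std::string> filterExpansions(
    const std::vector<std::pair<std::string, std::string>> &expansions,
    const TargetSet &required) {
  std::vector<std::string> kept;
  for (const auto &[exportRef, target] : expansions) {
    if (required.empty() || required.count(target)) kept.push_back(exportRef);
  }
  return kept;
}
```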
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 9761580..de22093 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -7,11 +7,13 @@
#include <memory>
#include <utility>
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
@@ -20,6 +22,9 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
+#define DEBUG_TYPE "iree-hal-materialize-resource-caches"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
namespace mlir::iree_compiler::IREE::HAL {
#define GEN_PASS_DEF_MATERIALIZERESOURCECACHESPASS
@@ -27,315 +32,683 @@
namespace {
-// TODO(multi-device): rewrite this to shard resources per device.
+//===----------------------------------------------------------------------===//
+// --iree-hal-materialize-resource-caches
+//===----------------------------------------------------------------------===//
+
+struct DescriptorSetLayout {
+ // All locations that use the layout.
+ SetVector<Location> locs;
+ // Value within the initializer once materialized.
+ Value initializerValue;
+};
+using DescriptorSetLayoutKey =
+ std::pair<ArrayAttr, IREE::HAL::DescriptorSetLayoutFlags>;
+
+struct PipelineLayout {
+ // All locations that use the layout.
+ SetVector<Location> locs;
+ // Lookup ops for this layout.
+ SmallVector<IREE::HAL::PipelineLayoutLookupOp> lookupOps;
+ // Global once materialized.
+ IREE::Util::GlobalOpInterface globalOp;
+ // Value within the initializer once materialized.
+ Value initializerValue;
+};
+
+struct Executable {
+ // All locations that use the executable.
+ SetVector<Location> locs;
+ // Executable representing the program to load.
+ IREE::HAL::ExecutableOp executableOp;
+ // Lookup ops for this executable.
+ SmallVector<IREE::HAL::ExecutableLookupOp> lookupOps;
+ // Global once materialized.
+ IREE::Util::GlobalOpInterface globalOp;
+};
+
+struct DeviceResources {
+ DeviceResources() = default;
+ explicit DeviceResources(IREE::Util::GlobalOpInterface deviceOp)
+ : deviceOp(deviceOp) {}
+
+ // Global !hal.device.
+ IREE::Util::GlobalOpInterface deviceOp;
+
+ // Fallback devices that should be checked for resources.
+ // These are derived from the transitive set of #hal.device.fallback attrs.
+ SetVector<DeviceResources *> fallbackDeviceResources;
+
+ // Descriptor set layouts used on the device, keyed by [bindingAttrs, flags].
+ llvm::MapVector<DescriptorSetLayoutKey, DescriptorSetLayout>
+ descriptorSetLayouts;
+ // Pipeline layouts used on the device, keyed by layout attr.
+ llvm::MapVector<IREE::HAL::PipelineLayoutAttr, PipelineLayout>
+ pipelineLayouts;
+ // Executables used on the device, keyed by name.
+ llvm::MapVector<StringAttr, Executable> executables;
+};
+
+static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) {
+ StringRef deviceName = deviceOp.getGlobalName().getValue();
+ if (deviceName.starts_with("__")) {
+ // Already prefixed.
+ return deviceName.str();
+ }
+ auto prefixedName = "__" + deviceName;
+ return prefixedName.str();
+}
+
+static void declareDevicePipelineLayout(IREE::Util::GlobalOpInterface deviceOp,
+ PipelineLayout &pipelineLayout,
+ size_t pipelineLayoutIndex,
+ OpBuilder &moduleBuilder) {
+ // Create global in the module.
+ auto symbolName = getDeviceNamePrefix(deviceOp) + "_pipeline_layout_" +
+ std::to_string(pipelineLayoutIndex);
+ LLVM_DEBUG(DBGS() << "+ creating device `"
+ << deviceOp.getGlobalName().getValue()
+ << "` pipeline global `" << symbolName << "`\n");
+ auto layoutType = moduleBuilder.getType<PipelineLayoutType>();
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ moduleBuilder.getFusedLoc(llvm::to_vector(pipelineLayout.locs)),
+ symbolName,
+ /*isMutable=*/false, layoutType);
+ globalOp.setPrivate();
+ pipelineLayout.globalOp = globalOp;
+
+ // Replace lookups with the global.
+ for (auto lookupOp : pipelineLayout.lookupOps) {
+ LLVM_DEBUG({
+ DBGS() << " - replacing lookup: ";
+ lookupOp.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ OpBuilder lookupBuilder(lookupOp);
+ auto loadedValue =
+ pipelineLayout.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder)
+ .getLoadedGlobalValue();
+ lookupOp.replaceAllUsesWith(loadedValue);
+ lookupOp.erase();
+ }
+ pipelineLayout.lookupOps.clear();
+}
+
+static void declareDeviceExecutable(IREE::Util::GlobalOpInterface deviceOp,
+ Executable &executable,
+ size_t executableIndex,
+ OpBuilder &moduleBuilder) {
+ // Create global in the module.
+ auto symbolName = (getDeviceNamePrefix(deviceOp) + "_executable_" +
+ std::to_string(executableIndex) + "_" +
+ executable.executableOp.getName())
+ .str();
+ LLVM_DEBUG(DBGS() << "+ creating device `"
+ << deviceOp.getGlobalName().getValue()
+ << "` executable global `" << symbolName << "`\n");
+ auto executableType = moduleBuilder.getType<IREE::HAL::ExecutableType>();
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ moduleBuilder.getFusedLoc(llvm::to_vector(executable.locs)), symbolName,
+ /*isMutable=*/false, executableType);
+ globalOp.setPrivate();
+ executable.globalOp = globalOp;
+
+ // Replace lookups with the global.
+ for (auto lookupOp : executable.lookupOps) {
+ LLVM_DEBUG({
+ DBGS() << " - replacing lookup: ";
+ lookupOp.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ OpBuilder lookupBuilder(lookupOp);
+ auto loadedValue =
+ executable.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder)
+ .getLoadedGlobalValue();
+ lookupOp.replaceAllUsesWith(loadedValue);
+ lookupOp.erase();
+ }
+ executable.lookupOps.clear();
+}
+
+static DescriptorSetLayoutKey
+getDescriptorSetLayoutKey(IREE::HAL::DescriptorSetLayoutAttr setLayoutAttr) {
+ auto bindingAttrs =
+ llvm::to_vector_of<Attribute>(setLayoutAttr.getBindings());
+ return DescriptorSetLayoutKey{
+ ArrayAttr::get(setLayoutAttr.getContext(), bindingAttrs),
+ setLayoutAttr.getFlags().value_or(
+ IREE::HAL::DescriptorSetLayoutFlags::None),
+ };
+}
+
+// Inlines a constant block as a function in |moduleBuilder| and then inserts
+// a call to it in |callerBuilder|.
+static SmallVector<Value> inlineConstantBlockOp(
+ StringRef funcName, IREE::HAL::ExecutableConstantBlockOp blockOp,
+ OpBuilder &moduleBuilder, OpBuilder &callerBuilder, Value callerDevice) {
+ LLVM_DEBUG(DBGS() << "- inlining constant block `" << funcName << "`\n");
+
+ // Create the function with the region contents of the constant block.
+ auto funcOp = moduleBuilder.create<IREE::Util::FuncOp>(
+ blockOp.getLoc(), funcName, blockOp.getFunctionType());
+ funcOp.setPrivate();
+ IRMapping mapping;
+ blockOp.getRegion().cloneInto(&funcOp.getRegion(), mapping);
+
+ // Replace the hal.return with a util.return.
+ for (auto returnOp :
+ llvm::make_early_inc_range(funcOp.getOps<IREE::HAL::ReturnOp>())) {
+ OpBuilder(returnOp).create<IREE::Util::ReturnOp>(returnOp.getLoc(),
+ returnOp.getOperands());
+ returnOp.erase();
+ }
+
+ // Create the call passing in the device if needed.
+ SmallVector<Value> callOperands;
+ if (funcOp.getNumArguments() > 0) {
+ callOperands.push_back(callerDevice);
+ }
+ auto callOp = callerBuilder.create<IREE::Util::CallOp>(blockOp.getLoc(),
+ funcOp, callOperands);
+ return llvm::to_vector_of<Value>(callOp.getResults());
+}
+
+static Value initializeExecutable(DeviceResources &deviceResources,
+ Executable &executable,
+ OpBuilder &moduleBuilder,
+ Value initializerDevice,
+ OpBuilder &initializerBuilder) {
+ auto loc = executable.globalOp.getLoc();
+ auto executableType = moduleBuilder.getType<IREE::HAL::ExecutableType>();
+
+ // Create a switch statement with a case for each variant.
+ // Each case should then cache only executables which contain a matching
+ // ExecutableVariantOp.
+ // Afterwards, canonicalization will take care of de-duping/etc.
+ SmallVector<int64_t> caseIndices;
+ SmallVector<IREE::HAL::ExecutableVariantOp> caseVariantOps;
+ for (auto variantOp :
+ executable.executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
+ caseIndices.push_back(caseIndices.size());
+ caseVariantOps.push_back(variantOp);
+ }
+
+ // Select the variant index.
+ Value selectedIndex = buildIfElseTree(
+ loc, caseVariantOps.size(),
+ [&](Location loc, size_t i, OpBuilder &builder) {
+ return caseVariantOps[i].buildCondition(initializerDevice, builder);
+ },
+ initializerBuilder);
+
+ // Allow each variant to define how it is loaded and what pipeline it has.
+ auto switchOp = initializerBuilder.create<scf::IndexSwitchOp>(
+ loc, executableType, selectedIndex, caseIndices, caseIndices.size());
+ for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) {
+ auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
+ auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
+
+ // Gather each of the pipeline layouts needed for each entry point in
+ // the executable.
+ SmallVector<Value> pipelineLayoutValues;
+ for (auto exportOp : variantOp.getExportOps()) {
+ auto &pipelineLayout =
+ deviceResources.pipelineLayouts[exportOp.getLayoutAttr()];
+ pipelineLayoutValues.push_back(pipelineLayout.initializerValue);
+ }
+
+ // Inline constant initializer from the variant.
+ // We want these to all happen inside of this device switch case; they'll
+ // get deduplicated/hoisted if possible in future canonicalization passes.
+ SmallVector<Value> constantValues;
+ for (auto [blockIndex, blockOp] :
+ llvm::enumerate(variantOp.getConstantBlockOps())) {
+ auto blockName = (executable.globalOp.getGlobalName().getValue() +
+ "_constant_block_" + std::to_string(blockIndex))
+ .str();
+ constantValues.append(inlineConstantBlockOp(
+ blockName, blockOp, moduleBuilder, caseBuilder, initializerDevice));
+ }
+
+ Value executableValue =
+ caseBuilder.createOrFold<IREE::HAL::ExecutableCreateOp>(
+ loc, executableType, initializerDevice,
+ SymbolRefAttr::get(
+ executable.executableOp.getSymNameAttr(),
+ {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
+ pipelineLayoutValues, constantValues);
+
+ caseBuilder.create<scf::YieldOp>(loc, executableValue);
+ }
+
+ // Fallback for no available variant.
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+ auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
+ Value status = defaultBuilder.create<arith::ConstantIntOp>(
+ loc, static_cast<int>(IREE::Util::StatusCode::Unavailable), 32);
+ {
+ std::string errorStr;
+ llvm::raw_string_ostream errorStream(errorStr);
+ errorStream << "HAL device `"
+ << deviceResources.deviceOp.getGlobalName().getValue()
+ << "` does not support any variant of executable `"
+ << executable.executableOp.getName()
+ << "`; available formats: [";
+ llvm::interleaveComma(caseVariantOps, errorStream, [&](auto variantOp) {
+ errorStream << variantOp.getTargetAttr().getFormat().getValue();
+ });
+ errorStream << "]";
+ defaultBuilder.create<IREE::Util::StatusCheckOkOp>(loc, status, errorStr);
+ }
+ auto nullValue =
+ defaultBuilder.createOrFold<IREE::Util::NullOp>(loc, executableType);
+ defaultBuilder.create<scf::YieldOp>(loc, nullValue);
+
+ return switchOp.getResult(0);
+}
+
+static void initializeDeviceResources(DeviceResources &deviceResources,
+ OpBuilder &moduleBuilder,
+ Value initializerDevice,
+ OpBuilder &initializerBuilder) {
+ // Initialize all descriptor set layouts for use by the pipeline layouts.
+ auto setLayoutType = initializerBuilder.getType<DescriptorSetLayoutType>();
+ for (auto [i, it] : llvm::enumerate(deviceResources.descriptorSetLayouts)) {
+ auto [bindingAttrs, flags] = it.first;
+ auto &descriptorSetLayout = it.second;
+ descriptorSetLayout.initializerValue =
+ initializerBuilder.createOrFold<IREE::HAL::DescriptorSetLayoutCreateOp>(
+ initializerBuilder.getFusedLoc(
+ llvm::to_vector(descriptorSetLayout.locs)),
+ setLayoutType, initializerDevice, flags, bindingAttrs);
+ }
+
+ // Initialize all pipeline layouts required for executable creation.
+ auto pipelineLayoutType = initializerBuilder.getType<PipelineLayoutType>();
+ for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) {
+ auto &[layoutAttr, pipelineLayout] = it;
+ SmallVector<Value> setLayoutValues;
+ for (auto setLayoutAttr : layoutAttr.getSetLayouts()) {
+ auto key = getDescriptorSetLayoutKey(setLayoutAttr);
+ setLayoutValues.push_back(
+ deviceResources.descriptorSetLayouts[key].initializerValue);
+ }
+ pipelineLayout.initializerValue =
+ initializerBuilder.createOrFold<IREE::HAL::PipelineLayoutCreateOp>(
+ pipelineLayout.globalOp.getLoc(), pipelineLayoutType,
+ initializerDevice,
+ initializerBuilder.getIndexAttr(layoutAttr.getPushConstants()),
+ setLayoutValues);
+ pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(),
+ pipelineLayout.initializerValue,
+ initializerBuilder);
+ }
+
+ // Initialize all executables.
+ for (auto [i, it] : llvm::enumerate(deviceResources.executables)) {
+ auto &[executableName, executable] = it;
+ executable.globalOp.createStoreOp(
+ executable.globalOp.getLoc(),
+ initializeExecutable(deviceResources, executable, moduleBuilder,
+ initializerDevice, initializerBuilder),
+ initializerBuilder);
+ }
+}
+
+static void reuseFallbackDeviceResources(DeviceResources &deviceResources,
+ DeviceResources &fallbackResources,
+ Value initializerDevice,
+ OpBuilder &initializerBuilder) {
+ // Load fallback pipeline layouts for all layouts required by this device.
+ for (auto &[layoutAttr, pipelineLayout] : deviceResources.pipelineLayouts) {
+ auto fallbackGlobalOp =
+ fallbackResources.pipelineLayouts[layoutAttr].globalOp;
+ assert(fallbackGlobalOp && "should have created global");
+ Value fallbackPipelineLayout =
+ fallbackGlobalOp
+ .createLoadOp(pipelineLayout.globalOp.getLoc(), initializerBuilder)
+ .getLoadedGlobalValue();
+ pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(),
+ fallbackPipelineLayout,
+ initializerBuilder);
+ }
+
+ // Load fallback executables for all executables required by this device.
+ for (auto &[executableName, executable] : deviceResources.executables) {
+ auto fallbackGlobalOp =
+ fallbackResources.executables[executable.executableOp.getNameAttr()]
+ .globalOp;
+ assert(fallbackGlobalOp && "should have created global");
+ Value fallbackExecutable =
+ fallbackGlobalOp
+ .createLoadOp(executable.globalOp.getLoc(), initializerBuilder)
+ .getLoadedGlobalValue();
+ executable.globalOp.createStoreOp(executable.globalOp.getLoc(),
+ fallbackExecutable, initializerBuilder);
+ }
+}
+
+static void buildDeviceResourceInitializer(DeviceResources &deviceResources,
+ OpBuilder &moduleBuilder) {
+ auto loc = deviceResources.deviceOp.getLoc();
+ auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
+ OpBuilder initializerBuilder =
+ OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
+ Value initializerDevice =
+ deviceResources.deviceOp.createLoadOp(loc, initializerBuilder)
+ .getLoadedGlobalValue();
+
+ // If there are any fallbacks then we reference their resources when our
+ // device matches one of them; otherwise we initialize our own.
+ if (deviceResources.fallbackDeviceResources.empty()) {
+ initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice,
+ initializerBuilder);
+ } else {
+ SmallVector<int64_t> caseIndices;
+ Value selectedIndex = buildIfElseTree(
+ loc, deviceResources.fallbackDeviceResources.size(),
+ [&](Location loc, size_t i, OpBuilder &caseBuilder) {
+ caseIndices.push_back(caseIndices.size());
+ auto *fallbackResources = deviceResources.fallbackDeviceResources[i];
+ Value fallbackDevice =
+ fallbackResources->deviceOp.createLoadOp(loc, caseBuilder)
+ .getLoadedGlobalValue();
+ return caseBuilder.create<IREE::Util::CmpEQOp>(loc, initializerDevice,
+ fallbackDevice);
+ },
+ initializerBuilder);
+ auto switchOp = initializerBuilder.create<scf::IndexSwitchOp>(
+ loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size());
+ for (auto [fallbackResources, caseRegion] :
+ llvm::zip_equal(deviceResources.fallbackDeviceResources,
+ switchOp.getCaseRegions())) {
+ auto &caseBlock = caseRegion.emplaceBlock();
+ auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
+ reuseFallbackDeviceResources(deviceResources, *fallbackResources,
+ initializerDevice, caseBuilder);
+ caseBuilder.create<scf::YieldOp>(loc);
+ }
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+ auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
+ initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice,
+ defaultBuilder);
+ defaultBuilder.create<scf::YieldOp>(loc);
+ }
+
+ initializerBuilder.create<IREE::Util::ReturnOp>(loc);
+}
+
+// Returns zero or more device globals that may act as fallbacks for the
+// given device, if analyzed. The result is in selection order.
+static std::optional<SetVector<IREE::Util::GlobalOpInterface>>
+getDeviceFallbackGlobals(IREE::Util::GlobalOpInterface deviceGlobal,
+ SymbolTable &symbolTable) {
+ SetVector<IREE::Util::GlobalOpInterface> resultSet;
+ auto processAttr = [&](Attribute attr) {
+ if (!attr)
+ return true; // ignore uninitialized devices
+ return TypeSwitch<Attribute, bool>(attr)
+ .Case<IREE::HAL::DeviceOrdinalAttr>([](auto attr) { return true; })
+ .Case<IREE::HAL::DeviceTargetAttr>([](auto attr) { return true; })
+ .Case<IREE::HAL::DeviceFallbackAttr>([&](auto fallbackAttr) {
+ resultSet.insert(symbolTable.lookup<IREE::Util::GlobalOpInterface>(
+ fallbackAttr.getName().getValue()));
+ return true;
+ })
+ .Default([](auto attr) { return false; });
+ };
+ auto initialValue = deviceGlobal.getGlobalInitialValue();
+ if (auto selectAttr =
+ dyn_cast_if_present<IREE::HAL::DeviceSelectAttr>(initialValue)) {
+ for (auto deviceAttr : selectAttr.getDevices()) {
+ if (!processAttr(deviceAttr)) {
+ // Fails if unsupported/unhandled device attribute type.
+ return std::nullopt;
+ }
+ }
+ } else {
+ if (!processAttr(initialValue)) {
+ // Fails if unsupported/unhandled device attribute type.
+ return std::nullopt;
+ }
+ }
+ return resultSet;
+}
+
+static LogicalResult gatherDeviceResources(
+ ModuleOp &moduleOp, SymbolTable &symbolTable,
+ DeviceAnalysis &deviceAnalysis,
+ llvm::MapVector<Attribute, DeviceResources> &allDeviceResources) {
+ // Allocate storage for the resource sets.
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) {
+ LLVM_DEBUG(DBGS() << "Gathering device `"
+ << deviceOp.getGlobalName().getValue()
+ << "` resources...\n");
+ allDeviceResources.try_emplace(deviceOp.getGlobalName(),
+ DeviceResources(deviceOp));
+ }
+
+ // Link fallbacks between the resources.
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) {
+ auto fallbackOps = getDeviceFallbackGlobals(deviceOp, symbolTable);
+ if (!fallbackOps) {
+ return deviceOp->emitOpError()
+ << "analysis failed on device; currently analysis must succeed";
+ }
+ auto &deviceResources = allDeviceResources[deviceOp.getGlobalName()];
+ for (auto fallbackOp : *fallbackOps) {
+ LLVM_DEBUG(DBGS() << "* linking to fallback `"
+ << fallbackOp.getGlobalName().getValue() << "`\n");
+ deviceResources.fallbackDeviceResources.insert(
+ &allDeviceResources[fallbackOp.getGlobalName()]);
+ }
+ }
+
+ // Find all relevant ops. If we don't find any we skip the pass as it's
+ // likely it's already been run. We could fix the pass to better support
+ // partial materialization but there are no use cases for that today.
+ auto tryGetDeviceResources = [&](Operation *op,
+ Value device) -> DeviceResources * {
+ auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(device);
+ if (!deviceGlobals || deviceGlobals->size() != 1) {
+ op->emitOpError() << "analysis failed on device; currently analysis "
+ "must succeed with a single device";
+ return nullptr;
+ }
+ auto deviceOp = deviceGlobals->front();
+ return &allDeviceResources.find(deviceOp.getGlobalName())->second;
+ };
+ for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
+ for (auto &block : funcOp.getFunctionBody()) {
+ if (block
+ .walk([&](Operation *op) -> WalkResult {
+ if (auto lookupOp = dyn_cast<PipelineLayoutLookupOp>(op)) {
+ auto *deviceResources =
+ tryGetDeviceResources(lookupOp, lookupOp.getDevice());
+ if (!deviceResources) {
+ return WalkResult::interrupt();
+ }
+ auto layoutAttr = lookupOp.getLayoutAttr();
+ LLVM_DEBUG(DBGS()
+ << "+ requiring pipeline layout from lookup: `"
+ << layoutAttr << "`\n");
+ auto &pipelineLayout =
+ deviceResources->pipelineLayouts[layoutAttr];
+ pipelineLayout.locs.insert(lookupOp.getLoc());
+ pipelineLayout.lookupOps.push_back(lookupOp);
+ for (auto setLayoutAttr : layoutAttr.getSetLayouts()) {
+ LLVM_DEBUG(
+ DBGS()
+ << "+ requiring descriptor set layout from lookup: `"
+ << setLayoutAttr << "`\n");
+ auto key = getDescriptorSetLayoutKey(setLayoutAttr);
+ auto &setLayout =
+ deviceResources->descriptorSetLayouts[key];
+ setLayout.locs.insert(lookupOp.getLoc());
+ }
+ } else if (auto lookupOp = dyn_cast<ExecutableLookupOp>(op)) {
+ auto *deviceResources =
+ tryGetDeviceResources(lookupOp, lookupOp.getDevice());
+ if (!deviceResources) {
+ return WalkResult::interrupt();
+ }
+ auto executableAttr = lookupOp.getExecutableAttr().getAttr();
+ LLVM_DEBUG(DBGS() << "+ requiring executable from lookup: `"
+ << executableAttr.getValue() << "`\n");
+ auto &executable =
+ deviceResources->executables[executableAttr];
+ executable.locs.insert(lookupOp.getLoc());
+ executable.lookupOps.push_back(lookupOp);
+ }
+ return WalkResult::advance();
+ })
+ .wasInterrupted()) {
+ return failure();
+ }
+ }
+ }
+
+ // Resolve the executables referenced by lookup ops and gather their layouts.
+ for (auto &[deviceName, deviceResources] : allDeviceResources) {
+ for (auto &[executableName, executable] : deviceResources.executables) {
+ executable.executableOp =
+ symbolTable.lookup<IREE::HAL::ExecutableOp>(executableName);
+ for (auto variantOp :
+ executable.executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
+ for (auto exportOp : variantOp.getExportOps()) {
+ auto layoutAttr = exportOp.getLayoutAttr();
+ LLVM_DEBUG(DBGS() << "+ requiring pipeline layout from export: `"
+ << layoutAttr << "`\n");
+ auto &pipelineLayout = deviceResources.pipelineLayouts[layoutAttr];
+ pipelineLayout.locs.insert(exportOp.getLoc());
+ for (auto setLayoutAttr : layoutAttr.getSetLayouts()) {
+ LLVM_DEBUG(DBGS()
+ << "+ requiring descriptor set layout from export: `"
+ << setLayoutAttr << "`\n");
+ auto key = getDescriptorSetLayoutKey(setLayoutAttr);
+ auto &setLayout = deviceResources.descriptorSetLayouts[key];
+ setLayout.locs.insert(exportOp.getLoc());
+ }
+ }
+ }
+ }
+ }
+
+ // Merge all resources that may be used by way of fallbacks into each fallback
+ // device. We could make this optional to improve startup performance by
+ // marking these as optional and creating them on demand but that's more
+ // complex. For now we just always ensure the resources are available even if
+ // they end up unused.
+ for (auto &[deviceName, deviceResources] :
+ llvm::reverse(allDeviceResources)) {
+ for (auto *fallbackResources : deviceResources.fallbackDeviceResources) {
+ LLVM_DEBUG(
+ DBGS() << "-> requiring fallback resources from device `"
+ << fallbackResources->deviceOp.getGlobalName().getValue()
+ << "`\n");
+ for (auto [setKey, setLayout] : deviceResources.descriptorSetLayouts) {
+ auto &fallbackSetLayout =
+ fallbackResources->descriptorSetLayouts[setKey];
+ fallbackSetLayout.locs.insert(setLayout.locs.begin(),
+ setLayout.locs.end());
+ }
+ for (auto [layoutAttr, pipelineLayout] :
+ deviceResources.pipelineLayouts) {
+ auto &fallbackPipelineLayout =
+ fallbackResources->pipelineLayouts[layoutAttr];
+ fallbackPipelineLayout.locs.insert(pipelineLayout.locs.begin(),
+ pipelineLayout.locs.end());
+ }
+ for (auto [executableName, executable] : deviceResources.executables) {
+ auto &fallbackExecutable =
+ fallbackResources->executables[executableName];
+ fallbackExecutable.locs.insert(executable.locs.begin(),
+ executable.locs.end());
+ fallbackExecutable.executableOp = executable.executableOp;
+ }
+ }
+ }
+
+ return success();
+}
+
struct MaterializeResourceCachesPass
: public IREE::HAL::impl::MaterializeResourceCachesPassBase<
MaterializeResourceCachesPass> {
void runOnOperation() override {
auto moduleOp = getOperation();
- if (moduleOp.getBody()->empty())
- return;
- moduleBuilder = OpBuilder(&moduleOp.getBody()->front());
+ SymbolTable symbolTable(moduleOp);
- // Find all relevant ops. If we don't find any we skip the pass as it's
- // likely it's already been run. We could fix the pass to better support
- // partial materialization but there's no use cases for that today.
- auto executableOps = llvm::to_vector<8>(moduleOp.getOps<ExecutableOp>());
- SmallVector<IREE::HAL::PipelineLayoutLookupOp> pipelineLayoutLookupOps;
- SmallVector<IREE::HAL::ExecutableLookupOp> executableLookupOps;
- for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
- for (auto &block : funcOp.getFunctionBody()) {
- block.walk([&](Operation *op) {
- if (auto lookupOp = dyn_cast<PipelineLayoutLookupOp>(op)) {
- pipelineLayoutLookupOps.push_back(lookupOp);
- } else if (auto lookupOp = dyn_cast<ExecutableLookupOp>(op)) {
- executableLookupOps.push_back(lookupOp);
- }
- });
+ // Analyze the module to determine which devices are used where.
+ LLVM_DEBUG(DBGS() << "Running device analysis...\n");
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // Build a table of all resources used by all devices in the program.
+ LLVM_DEBUG(DBGS() << "Gathering device resources...\n");
+ llvm::MapVector<Attribute, DeviceResources> allDeviceResources;
+ if (failed(gatherDeviceResources(moduleOp, symbolTable, deviceAnalysis,
+ allDeviceResources))) {
+ return signalPassFailure();
+ }
+
+ // Materialize resources for each device (if any) and replace lookups.
+ for (auto &[nameAttr, deviceResources] : allDeviceResources) {
+ LLVM_DEBUG(DBGS() << "Materializing device `"
+ << deviceResources.deviceOp.getGlobalName().getValue()
+ << "` resources...\n");
+ // Skip devices with no resources.
+ if (deviceResources.pipelineLayouts.empty() &&
+ deviceResources.executables.empty()) {
+ LLVM_DEBUG(DBGS() << "~ skipping device with no resources\n");
+ continue;
}
- }
- if (pipelineLayoutLookupOps.empty() && executableLookupOps.empty()) {
- return;
+
+ // TODO(benvanik): proper insertion order if devices are initialized via
+ // an initializer. Today this assumes the device hasn't been materialized
+ // yet if there are any lookups to it.
+ if (!deviceResources.deviceOp.getGlobalInitialValue()) {
+ deviceResources.deviceOp.emitOpError()
+ << "is expected to be initialized with an attribute and not yet "
+ "via a util.initializer";
+ return signalPassFailure();
+ }
+
+ // Declare globals for each pipeline layout and executable and replace all
+ // lookup ops to reference them.
+ OpBuilder moduleBuilder(moduleOp);
+ moduleBuilder.setInsertionPointAfter(deviceResources.deviceOp);
+ for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) {
+ auto &[layoutAttr, pipelineLayout] = it;
+ declareDevicePipelineLayout(deviceResources.deviceOp, pipelineLayout, i,
+ moduleBuilder);
+ }
+ for (auto [i, it] : llvm::enumerate(deviceResources.executables)) {
+ auto &[executableName, executable] = it;
+ declareDeviceExecutable(deviceResources.deviceOp, executable, i,
+ moduleBuilder);
+ }
+
+ // Create an initializer after the declared globals.
+ buildDeviceResourceInitializer(deviceResources, moduleBuilder);
}
- // Declare all layouts used by the executables. This will ensure that the
- // initialization order is correct as any pipeline layout needed (and its
- // dependencies) will be created prior to the executable cache below. The
- // other nice thing is that we get ordering similar to the executable
- // variables above.
- for (auto executableOp : executableOps) {
+ // Remove ops that are no longer required after materialization.
+ for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
for (auto variantOp :
executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
- for (auto exportOp : variantOp.getExportOps()) {
- definePipelineLayoutOp(exportOp.getLoc(), exportOp.getLayout());
+ if (auto conditionOp = variantOp.getConditionOp()) {
+ conditionOp.erase();
+ }
+ for (auto blockOp :
+ llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
+ blockOp.erase();
}
}
}
-
- // Declare executable variables so that we can reference them during lookup
- // replacement.
- for (auto executableOp : executableOps) {
- defineExecutableOp(executableOp);
- }
-
- // Generate cached resource singletons and replace lookup ops with direct
- // loads from variables.
- for (auto lookupOp : pipelineLayoutLookupOps) {
- replacePipelineLayoutLookupOp(lookupOp);
- }
- for (auto lookupOp : executableLookupOps) {
- replaceExecutableLookupOp(lookupOp);
- }
}
-
-private:
- IREE::Util::GlobalOp
- defineDescriptorSetLayoutOp(Location loc, ArrayAttr bindingAttrs,
- IREE::HAL::DescriptorSetLayoutFlags flags) {
- std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags> key = {
- bindingAttrs, flags};
- auto existingIt = descriptorSetLayoutCache_.find(key);
- if (existingIt != descriptorSetLayoutCache_.end()) {
- return existingIt->second;
- }
-
- auto symbolName = (StringRef("_descriptor_set_layout_") +
- std::to_string(nextUniqueDescriptorSetLayoutId++))
- .str();
-
- auto layoutType = DescriptorSetLayoutType::get(loc.getContext());
- auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, symbolName,
- /*isMutable=*/false, layoutType);
- globalOp.setPrivate();
- descriptorSetLayoutCache_.try_emplace(key, globalOp);
-
- auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
- OpBuilder blockBuilder =
- OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
- // TODO(multi-device): pass in resolve info to the call and reuse.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
- Value layout = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
- loc, layoutType, device, flags, bindingAttrs);
- globalOp.createStoreOp(loc, layout, blockBuilder);
- blockBuilder.create<IREE::Util::ReturnOp>(loc);
-
- return globalOp;
- }
-
- IREE::Util::GlobalOp
- definePipelineLayoutOp(Location loc,
- IREE::HAL::PipelineLayoutAttr layoutAttr) {
- auto existingIt = pipelineLayoutCache_.find(layoutAttr);
- if (existingIt != pipelineLayoutCache_.end()) {
- return existingIt->second;
- }
-
- // First lookup (or create) all the required descriptor sets. This ensures
- // they end up in the proper initialization order.
- SmallVector<IREE::Util::GlobalOp> setLayoutGlobalOps;
- for (auto setLayoutAttr : layoutAttr.getSetLayouts()) {
- SmallVector<Attribute> bindingAttrs;
- for (auto bindingAttr : setLayoutAttr.getBindings()) {
- bindingAttrs.push_back(bindingAttr);
- }
- setLayoutGlobalOps.push_back(defineDescriptorSetLayoutOp(
- loc, ArrayAttr::get(loc.getContext(), bindingAttrs),
- setLayoutAttr.getFlags().value_or(
- IREE::HAL::DescriptorSetLayoutFlags::None)));
- }
-
- auto symbolName = (StringRef("_pipeline_layout_") +
- std::to_string(nextUniquePipelineLayoutId++))
- .str();
-
- auto layoutType = PipelineLayoutType::get(loc.getContext());
- auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, symbolName, /*isMutable=*/false, layoutType);
- globalOp.setPrivate();
- pipelineLayoutCache_.try_emplace(layoutAttr, globalOp);
-
- auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
- OpBuilder blockBuilder =
- OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
- SmallVector<Value> setLayoutValues;
- for (auto setLayoutGlobalOp : setLayoutGlobalOps) {
- setLayoutValues.push_back(
- setLayoutGlobalOp.createLoadOp(loc, blockBuilder)
- .getLoadedGlobalValue());
- }
- // TODO(multi-device): pass in resolve info to the call and reuse.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
- Value layout = blockBuilder.createOrFold<PipelineLayoutCreateOp>(
- loc, layoutType, device,
- blockBuilder.getIndexAttr(layoutAttr.getPushConstants()),
- setLayoutValues);
- globalOp.createStoreOp(loc, layout, blockBuilder);
- blockBuilder.create<IREE::Util::ReturnOp>(loc);
-
- return globalOp;
- }
-
- void defineExecutableOp(ExecutableOp executableOp) {
- auto loc = executableOp.getLoc();
- auto symbolName =
- (StringRef("_executable_") + executableOp.getSymName()).str();
-
- auto executableType = ExecutableType::get(executableOp.getContext());
- auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, symbolName, /*isMutable=*/false, executableType);
- globalOp.setPrivate();
- executableCache_.try_emplace(executableOp.getSymName(), globalOp);
-
- auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
- OpBuilder blockBuilder =
- OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
- // TODO(multi-device): pass in resolve info to the call and reuse.
- Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
-
- // Create a switch statement with a case for each variant.
- // Each case should then cache only executables which contain a matching
- // ExecutableVariantOp.
- // Afterwards, canonicalization will take care of de-duping/etc.
- SmallVector<int64_t> caseIndices;
- SmallVector<IREE::HAL::ExecutableVariantOp> caseVariantOps;
- for (auto variantOp :
- executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
- caseIndices.push_back(caseIndices.size());
- caseVariantOps.push_back(variantOp);
- }
-
- // Select the variant index.
- Value selectedIndex = buildIfElseTree(
- loc, caseVariantOps.size(),
- [&](Location loc, size_t i, OpBuilder &builder) {
- return caseVariantOps[i].buildCondition(device, builder);
- },
- blockBuilder);
-
- // Allow each variant to define how it is loaded and what pipeline it has.
- auto switchOp = blockBuilder.create<scf::IndexSwitchOp>(
- loc, executableType, selectedIndex, caseIndices, caseIndices.size());
- for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) {
- auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
- auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
-
- // Gather each of the pipeline layouts needed for each entry point in
- // the executable.
- SmallVector<Value, 8> pipelineLayoutValues;
- for (auto exportOp : variantOp.getExportOps()) {
- auto pipelineLayoutGlobalOp =
- definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout());
- pipelineLayoutValues.push_back(
- pipelineLayoutGlobalOp.createLoadOp(loc, caseBuilder)
- .getLoadedGlobalValue());
- }
-
- // Inline constant initializer from the variant.
- // We want these to all happen inside of this device switch case; they'll
- // get deduplicated/hoisted if possible in future canonicalization passes.
- SmallVector<Value> constantValues;
- for (auto blockOp :
- llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
- constantValues.append(
- inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, device));
- blockOp.erase();
- }
-
- Value executable = caseBuilder.createOrFold<ExecutableCreateOp>(
- loc, executableType, device,
- SymbolRefAttr::get(executableOp.getSymNameAttr(),
- {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
- pipelineLayoutValues, constantValues);
-
- caseBuilder.create<scf::YieldOp>(loc, executable);
- }
-
- // Fallback for no available variant.
- auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
- auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
- Value status = defaultBuilder.create<arith::ConstantIntOp>(
- loc, static_cast<int>(IREE::Util::StatusCode::Unavailable), 32);
- defaultBuilder.create<IREE::Util::StatusCheckOkOp>(
- loc, status,
- "none of the executable binaries in the module are supported by the "
- "runtime");
- auto nullValue =
- defaultBuilder.createOrFold<IREE::Util::NullOp>(loc, executableType);
- defaultBuilder.create<scf::YieldOp>(loc, nullValue);
-
- auto executableValue = switchOp.getResult(0);
- globalOp.createStoreOp(loc, executableValue, blockBuilder);
- blockBuilder.create<IREE::Util::ReturnOp>(loc);
- }
-
- // Inlines a constant block as a function in |moduleBuilder| and then inserts
- // a call to it in |callerBuilder|.
- SmallVector<Value> inlineConstantBlockOp(ExecutableConstantBlockOp blockOp,
- OpBuilder &moduleBuilder,
- OpBuilder &callerBuilder,
- Value device) {
- // Create the function with the region contents of the constant block.
- auto funcName = (StringRef("__constant_block_") +
- std::to_string(nextUniqueConstantBlockId++))
- .str();
- auto funcOp = moduleBuilder.create<IREE::Util::FuncOp>(
- blockOp.getLoc(), funcName, blockOp.getFunctionType());
- funcOp.setPrivate();
- funcOp.getRegion().takeBody(blockOp.getRegion());
-
- // Replace the hal.return with a func.return.
- for (auto returnOp :
- llvm::make_early_inc_range(funcOp.getOps<IREE::HAL::ReturnOp>())) {
- OpBuilder(returnOp).create<IREE::Util::ReturnOp>(returnOp.getLoc(),
- returnOp.getOperands());
- returnOp.erase();
- }
-
- // Create the call passing in the device if needed.
- SmallVector<Value> callOperands;
- if (funcOp.getNumArguments() > 0) {
- callOperands.push_back(device);
- }
- auto callOp = callerBuilder.create<IREE::Util::CallOp>(
- blockOp.getLoc(), funcOp, callOperands);
-
- return llvm::map_to_vector(callOp.getResults(),
- [](OpResult result) -> Value { return result; });
- }
-
- void replacePipelineLayoutLookupOp(PipelineLayoutLookupOp &lookupOp) {
- OpBuilder builder(lookupOp);
- auto globalOp =
- definePipelineLayoutOp(lookupOp.getLoc(), lookupOp.getLayout());
- auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder)
- .getLoadedGlobalValue();
- lookupOp.replaceAllUsesWith(loadedValue);
- lookupOp.erase();
- }
-
- void replaceExecutableLookupOp(ExecutableLookupOp &lookupOp) {
- OpBuilder builder(lookupOp);
- auto executableIt = executableCache_.find(lookupOp.getExecutable());
- assert(executableIt != executableCache_.end() &&
- "executable must have been cached");
- auto globalOp = executableIt->second;
- auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder)
- .getLoadedGlobalValue();
- lookupOp.replaceAllUsesWith(loadedValue);
- lookupOp.erase();
- }
-
- OpBuilder moduleBuilder{static_cast<MLIRContext *>(nullptr)};
- DenseMap<std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags>,
- IREE::Util::GlobalOp>
- descriptorSetLayoutCache_;
- DenseMap<Attribute, IREE::Util::GlobalOp> pipelineLayoutCache_;
- DenseMap<StringRef, IREE::Util::GlobalOp> executableCache_;
-
- int nextUniqueConstantBlockId = 0;
- int nextUniquePipelineLayoutId = 0;
- int nextUniqueDescriptorSetLayoutId = 0;
};
} // namespace
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
new file mode 100644
index 0000000..5d70f1b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
@@ -0,0 +1,234 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_MATERIALIZETARGETDEVICESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-materialize-target-devices
+//===----------------------------------------------------------------------===//
+
+// Returns the canonical name for a device by ordinal:
+// device ordinal `N` -> `@__device_N`
+static FlatSymbolRefAttr makeDefaultDeviceOrdinalRef(MLIRContext *context,
+ int64_t ordinal) {
+ return FlatSymbolRefAttr::get(
+ context, (StringRef("__device_") + std::to_string(ordinal)).str());
+}
+
+// Returns the canonical name for a device by name:
+// device name `NAME` -> `@NAME`
+static FlatSymbolRefAttr makeDefaultDeviceNameRef(MLIRContext *context,
+ StringRef name) {
+ return FlatSymbolRefAttr::get(context, name);
+}
+
+// Returns a symbol ref constructed to reference the specified device.
+// Supports:
+// integer attrs: device ordinal `N` -> `@__device_N`
+// string attrs: device name `NAME` -> `@NAME`
+static FailureOr<FlatSymbolRefAttr>
+makeDefaultDeviceAttrRef(Attribute defaultDeviceAttr) {
+ if (auto stringAttr = dyn_cast<StringAttr>(defaultDeviceAttr)) {
+ return makeDefaultDeviceNameRef(stringAttr.getContext(), stringAttr);
+ } else if (auto integerAttr = dyn_cast<IntegerAttr>(defaultDeviceAttr)) {
+ return makeDefaultDeviceOrdinalRef(integerAttr.getContext(),
+ integerAttr.getInt());
+ }
+ return failure();
+}
+
+// Creates a named device global with the given attribute.
+static FailureOr<FlatSymbolRefAttr>
+createDeviceGlobal(Location loc, StringAttr name, Attribute targetAttr,
+ OpBuilder &moduleBuilder) {
+ auto deviceType = moduleBuilder.getType<IREE::HAL::DeviceType>();
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ loc, name, /*isMutable=*/false, deviceType);
+ globalOp.setPrivate();
+
+ TypedAttr attrValue;
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(targetAttr)) {
+ if (arrayAttr.size() == 1) {
+ auto typedAttr = dyn_cast<TypedAttr>(arrayAttr.getValue().front());
+ if (typedAttr && isa<IREE::HAL::DeviceType>(typedAttr.getType())) {
+ // Don't care exactly what the attribute is, only that it's a device.
+ attrValue = typedAttr;
+ }
+ } else {
+ // Expand arrays to selects.
+ attrValue = moduleBuilder.getAttr<IREE::HAL::DeviceSelectAttr>(deviceType,
+ arrayAttr);
+ }
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(targetAttr)) {
+ if (isa<IREE::HAL::DeviceType>(typedAttr.getType())) {
+ // Don't care exactly what the attribute is, only that it's a device.
+ attrValue = typedAttr;
+ }
+ }
+ if (!attrValue) {
+ return mlir::emitError(loc)
+ << "module has invalid device targets specified; "
+ "expected hal.device.targets to be an array of !hal.device "
+ "initialization attributes or a dictionary with named values";
+ }
+
+ globalOp.setInitialValueAttr(attrValue);
+ return FlatSymbolRefAttr::get(globalOp);
+}
+
+// Creates one or more device globals based on the specified targets and returns
+// the "default" device (usually just the first one specified).
+static FailureOr<FlatSymbolRefAttr> createDeviceGlobals(mlir::ModuleOp moduleOp,
+ Attribute targetsAttr) {
+ auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
+
+ FlatSymbolRefAttr firstDeviceRef;
+ if (auto dictAttr = dyn_cast<DictionaryAttr>(targetsAttr)) {
+ for (auto namedTargetsAttr : dictAttr.getValue()) {
+ auto deviceRefOr =
+ createDeviceGlobal(moduleOp.getLoc(), namedTargetsAttr.getName(),
+ namedTargetsAttr.getValue(), moduleBuilder);
+ if (failed(deviceRefOr)) {
+ return failure();
+ } else if (!firstDeviceRef) {
+ firstDeviceRef = *deviceRefOr;
+ }
+ }
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(targetsAttr)) {
+ for (auto [i, ordinalTargetsAttr] : llvm::enumerate(arrayAttr.getValue())) {
+ auto deviceRefOr =
+ createDeviceGlobal(moduleOp.getLoc(),
+ moduleBuilder.getStringAttr(
+ StringRef("__device_") + std::to_string(i)),
+ ordinalTargetsAttr, moduleBuilder);
+ if (failed(deviceRefOr)) {
+ return failure();
+ } else if (!firstDeviceRef) {
+ firstDeviceRef = *deviceRefOr;
+ }
+ }
+ } else {
+ return moduleOp.emitError()
+ << "unexpected `hal.device.targets` attribute; must be a dictionary "
+ "of named devices or an array of devices to use by ordinal";
+ }
+
+ return firstDeviceRef;
+}
+
+// Assigns the default device affinity to all top level ops that don't already
+// have one set.
+static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp,
+ FlatSymbolRefAttr defaultDeviceRef) {
+ auto affinityAttr = IREE::HAL::DeviceAffinityAttr::get(
+ moduleOp.getContext(), defaultDeviceRef, /*queue_mask=*/-1ll);
+
+ // Default on the module that applies to any ops that don't otherwise have a
+ // placement. Ideally we never need this but some programs may take/return no
+ // tensors or have tensors come from unattributed containers (lists/dicts).
+ moduleOp->setAttr("stream.affinity.default", affinityAttr);
+
+ // Set all public function args/results to route through the default device
+ // unless they've already been assigned.
+ auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity");
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ if (funcOp.isPublic()) {
+ for (auto arg : funcOp.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ if (!funcOp.getArgAttr(arg.getArgNumber(), affinityName)) {
+ funcOp.setArgAttr(arg.getArgNumber(), affinityName, affinityAttr);
+ }
+ }
+ }
+ for (auto result : llvm::enumerate(funcOp.getResultTypes())) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.value())) {
+ if (!funcOp.getResultAttr(result.index(), affinityName)) {
+ funcOp.setResultAttr(result.index(), affinityName, affinityAttr);
+ }
+ }
+ }
+ }
+ }
+}
+
+struct MaterializeTargetDevicesPass
+ : public IREE::HAL::impl::MaterializeTargetDevicesPassBase<
+ MaterializeTargetDevicesPass> {
+ using IREE::HAL::impl::MaterializeTargetDevicesPassBase<
+ MaterializeTargetDevicesPass>::MaterializeTargetDevicesPassBase;
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Only materialize devices if there's a module-level attribute specified.
+ FlatSymbolRefAttr defaultDeviceRef;
+ auto deviceTargetAttrs = moduleOp->getAttr("hal.device.targets");
+ if (deviceTargetAttrs) {
+ moduleOp->removeAttr("hal.device.targets");
+
+ // Create the globals and get the default device.
+ auto firstDeviceOr = createDeviceGlobals(moduleOp, deviceTargetAttrs);
+ if (failed(firstDeviceOr)) {
+ // Fails if the attributes are invalid.
+ return signalPassFailure();
+ }
+ defaultDeviceRef = *firstDeviceOr;
+ }
+
+ // Select the default device from what the user specified or from the first
+ // created.
+ auto defaultDeviceAttr = moduleOp->getAttr("hal.device.default");
+ if (defaultDeviceAttr) {
+ // Always prefer the explicitly specified default device.
+ moduleOp->removeAttr("hal.device.default");
+ auto defaultDeviceRefOr = makeDefaultDeviceAttrRef(defaultDeviceAttr);
+ if (failed(defaultDeviceRefOr)) {
+ moduleOp.emitError() << "invalid `hal.device.default` value, must be "
+ "an ordinal or a name";
+ return signalPassFailure();
+ }
+ defaultDeviceRef = *defaultDeviceRefOr;
+ } else if (!defaultDevice.empty()) {
+ // Fall back to the option specified, if any.
+ long long defaultDeviceOrdinal = 0;
+ if (!llvm::getAsSignedInteger(defaultDevice, 10, defaultDeviceOrdinal)) {
+ defaultDeviceRef =
+ makeDefaultDeviceOrdinalRef(&getContext(), defaultDeviceOrdinal);
+ } else {
+ defaultDeviceRef =
+ makeDefaultDeviceNameRef(&getContext(), defaultDevice);
+ }
+ }
+
+ // Assign affinities to all top level ops that don't already have one set.
+ if (defaultDeviceRef) {
+ assignDefaultDeviceAffinity(moduleOp, defaultDeviceRef);
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
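The default-device selection in `MaterializeTargetDevicesPass` above follows a fixed priority: an explicit `hal.device.default` module attribute wins, then the `defaultDevice` pass option (treated as an ordinal when it parses as an integer, otherwise as a name), and finally the first device global created from `hal.device.targets`. The standalone C++ sketch below models only that priority, using plain strings instead of MLIR attributes. `DeviceRef`, `ordinalDeviceNames`, `parseOrdinal`, and `selectDefaultDevice` are hypothetical names introduced for illustration, and `parseOrdinal` only approximates `llvm::getAsSignedInteger`.

```cpp
// Hypothetical model of default-device selection in
// MaterializeTargetDevicesPass; plain strings stand in for MLIR symbol refs.
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// A device reference: either a runtime ordinal or a device global name.
using DeviceRef = std::variant<long long, std::string>;

// Array form of `hal.device.targets`: entry i becomes the global `__device_<i>`.
std::vector<std::string> ordinalDeviceNames(std::size_t count) {
  std::vector<std::string> names;
  for (std::size_t i = 0; i < count; ++i) {
    names.push_back("__device_" + std::to_string(i));
  }
  return names;
}

// Approximation of llvm::getAsSignedInteger(str, 10, out): an optional '-'
// followed by 1-18 decimal digits (length capped so std::stoll cannot overflow).
std::optional<long long> parseOrdinal(const std::string &text) {
  std::size_t start = (!text.empty() && text[0] == '-') ? 1 : 0;
  std::size_t digitCount = text.size() - start;
  if (digitCount == 0 || digitCount > 18) return std::nullopt;
  for (std::size_t i = start; i < text.size(); ++i) {
    if (!std::isdigit(static_cast<unsigned char>(text[i]))) return std::nullopt;
  }
  return std::stoll(text);
}

// Priority mirrors the pass: explicit `hal.device.default` attribute, then the
// `defaultDevice` option, then the first device created from the targets
// (`createdDevices` is in creation order). Returns nullopt when nothing
// applies, in which case the pass assigns no default affinities.
std::optional<DeviceRef> selectDefaultDevice(
    const std::vector<std::string> &createdDevices,
    const std::optional<DeviceRef> &moduleDefaultAttr,
    const std::string &defaultDeviceOption) {
  if (moduleDefaultAttr) return *moduleDefaultAttr;
  if (!defaultDeviceOption.empty()) {
    if (auto ordinal = parseOrdinal(defaultDeviceOption)) {
      return DeviceRef(*ordinal);
    }
    return DeviceRef(defaultDeviceOption);
  }
  if (!createdDevices.empty()) return DeviceRef(createdDevices.front());
  return std::nullopt;
}
```

Modeling the result as a variant keeps the ordinal and name cases distinct, matching the separate `makeDefaultDeviceOrdinalRef` and `makeDefaultDeviceNameRef` paths in the pass.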
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 9857a33..096b7bf 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -6,10 +6,12 @@
#include <utility>
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Utils/StringUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -27,6 +29,36 @@
// --iree-hal-memoize-device-queries
//===----------------------------------------------------------------------===//
+// All queries for a particular !hal.device global.
+struct DeviceQueries {
+ // Global !hal.device.
+ IREE::Util::GlobalOpInterface deviceOp;
+ // [category, key, default] used for lookup/indexing.
+ SmallVector<Attribute> queryKeys;
+ // Ops performing queries against the device by [category, key, default].
+ DenseMap<Attribute, SmallVector<IREE::HAL::DeviceQueryOp>> queryOps;
+};
+
+// A query being replaced by global lookups.
+struct Query {
+ Query(Location loc) : loc(loc) {}
+ Location loc;
+ IREE::Util::GlobalOp okGlobalOp;
+ IREE::Util::GlobalOp valueGlobalOp;
+ StringAttr categoryAttr;
+ StringAttr keyAttr;
+ TypedAttr defaultValueAttr;
+};
+
+static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) {
+ StringRef deviceName = deviceOp.getGlobalName().getValue();
+ if (deviceName.starts_with("__")) {
+ return deviceName.str();
+ }
+ auto prefixedName = "__" + deviceName;
+ return prefixedName.str();
+}
+
// NOTE: this implementation is just for a single active device. As we start to
// support multiple devices we'll need to change this to be per-device.
struct MemoizeDeviceQueriesPass
@@ -35,28 +67,50 @@
void runOnOperation() override {
auto moduleOp = getOperation();
+ // Analyze the module to determine which devices are used where.
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // Prepare device table indexed by symbol name.
+ DenseMap<Attribute, DeviceQueries> allDeviceQueries;
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) {
+ allDeviceQueries[deviceOp.getGlobalName()].deviceOp = deviceOp;
+ }
+
// Find all query ops we want to memoize and group them together.
// This lets us easily replace all usages of a match with a single variable.
- SmallVector<Attribute> deviceQueryKeys;
- DenseMap<Attribute, std::vector<IREE::HAL::DeviceQueryOp>> deviceQueryOps;
for (auto callableOp : moduleOp.getOps<mlir::CallableOpInterface>()) {
callableOp.walk([&](IREE::HAL::DeviceQueryOp queryOp) {
+ // Try to find the device this query is made on. If analysis failed or the
+ // query may target more than one device then we can't memoize it today.
+ auto deviceGlobals =
+ deviceAnalysis.lookupDeviceGlobals(queryOp.getDevice());
+ if (!deviceGlobals || deviceGlobals->size() != 1)
+ return WalkResult::advance();
+ IREE::Util::GlobalOpInterface deviceGlobalOp = deviceGlobals->front();
+
+ // Construct key used to dedupe/lookup the query.
auto fullKey = ArrayAttr::get(
moduleOp.getContext(),
{
- // TODO(multi-device): add attr key on device resolve source.
StringAttr::get(moduleOp.getContext(),
queryOp.getCategory() + queryOp.getKey()),
queryOp.getDefaultValue().has_value()
? queryOp.getDefaultValueAttr()
: Attribute{},
});
- auto lookup = deviceQueryOps.try_emplace(
- fullKey, std::vector<IREE::HAL::DeviceQueryOp>{});
+
+ // Track the query on the device.
+ auto &deviceQueries = allDeviceQueries[deviceGlobalOp.getGlobalName()];
+ auto lookup = deviceQueries.queryOps.try_emplace(
+ fullKey, SmallVector<IREE::HAL::DeviceQueryOp>{});
if (lookup.second) {
- deviceQueryKeys.push_back(std::move(fullKey));
+ deviceQueries.queryKeys.push_back(std::move(fullKey));
}
lookup.first->second.push_back(queryOp);
+
return WalkResult::advance();
});
}
@@ -64,54 +118,83 @@
// Create each query variable and replace the uses with loads.
SymbolTable symbolTable(moduleOp);
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
- for (auto queryKey : llvm::enumerate(deviceQueryKeys)) {
- auto queryOps = deviceQueryOps[queryKey.value()];
- auto anyQueryOp = queryOps.front();
- auto queryType = anyQueryOp.getValue().getType();
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) {
+ auto &deviceQueries = allDeviceQueries[deviceOp.getGlobalName()];
+ if (deviceQueries.queryKeys.empty()) {
+ // No queries against this device.
+ continue;
+ }
- // Merge all the locs as we are deduping the original query ops.
- auto fusedLoc = moduleBuilder.getFusedLoc(llvm::map_to_vector(
- queryOps, [&](Operation *op) { return op->getLoc(); }));
+ // Create an ok/value global pair per unique query key on the device.
+ SmallVector<Query> queries;
+ moduleBuilder.setInsertionPointAfter(deviceOp);
+ for (auto [i, queryKey] : llvm::enumerate(deviceQueries.queryKeys)) {
+ auto &queryOps = deviceQueries.queryOps[queryKey];
+ auto queryLoc = moduleBuilder.getFusedLoc(llvm::map_to_vector(
+ queryOps, [&](auto queryOp) { return queryOp.getLoc(); }));
- // The initializer will perform the query once and store it in the
- // variable.
- std::string variableName =
- "_device_query_" + std::to_string(queryKey.index());
- auto valueGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- fusedLoc, variableName,
- /*isMutable=*/false, queryType);
- symbolTable.insert(valueGlobalOp);
- valueGlobalOp.setPrivate();
- auto okGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- fusedLoc, variableName + "_ok",
- /*isMutable=*/false, moduleBuilder.getI1Type());
- symbolTable.insert(okGlobalOp);
- okGlobalOp.setPrivate();
+ // Create a global for the ok flag and the queried value.
+ // TODO(benvanik): create a better name based on the key.
+ auto anyQueryOp = queryOps.front();
+ auto queryType = anyQueryOp.getValue().getType();
+ std::string variableName =
+ getDeviceNamePrefix(deviceOp) + "_query_" + std::to_string(i) +
+ "_" + sanitizeSymbolName(anyQueryOp.getCategory()) + "_" +
+ sanitizeSymbolName(anyQueryOp.getKey());
+ auto okGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ queryLoc, variableName + "_ok",
+ /*isMutable=*/false, moduleBuilder.getI1Type());
+ symbolTable.insert(okGlobalOp);
+ okGlobalOp.setPrivate();
+ auto valueGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ queryLoc, variableName,
+ /*isMutable=*/false, queryType);
+ symbolTable.insert(valueGlobalOp);
+ valueGlobalOp.setPrivate();
+ // Stash the globals for initialization.
+ Query query(queryLoc);
+ query.okGlobalOp = okGlobalOp;
+ query.valueGlobalOp = valueGlobalOp;
+ query.categoryAttr = anyQueryOp.getCategoryAttr();
+ query.keyAttr = anyQueryOp.getKeyAttr();
+ query.defaultValueAttr = anyQueryOp.getDefaultValueAttr();
+ queries.push_back(query);
+
+ // Replace all queries with loads of the global values.
+ for (auto queryOp : queryOps) {
+ OpBuilder replaceBuilder(queryOp);
+ auto okLoadOp =
+ okGlobalOp.createLoadOp(queryOp.getLoc(), replaceBuilder);
+ auto resultLoadOp =
+ valueGlobalOp.createLoadOp(queryOp.getLoc(), replaceBuilder);
+ queryOp.replaceAllUsesWith(ValueRange{
+ okLoadOp.getLoadedGlobalValue(),
+ resultLoadOp.getLoadedGlobalValue(),
+ });
+ queryOp.erase();
+ }
+ }
+
+ // Create an initializer for the device where we will perform all queries.
+ auto fusedLoc = moduleBuilder.getFusedLoc(
+ llvm::map_to_vector(queries, [&](auto &query) { return query.loc; }));
auto initializerOp =
moduleBuilder.create<IREE::Util::InitializerOp>(fusedLoc);
auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
- // TODO(multi-device): pass in resolve info to the call and reuse.
- Value device = IREE::HAL::DeviceType::resolveAny(fusedLoc, funcBuilder);
- auto queryOp = funcBuilder.create<IREE::HAL::DeviceQueryOp>(
- fusedLoc, funcBuilder.getI1Type(), queryType, device,
- anyQueryOp.getCategoryAttr(), anyQueryOp.getKeyAttr(),
- anyQueryOp.getDefaultValueAttr());
- okGlobalOp.createStoreOp(fusedLoc, queryOp.getOk(), funcBuilder);
- valueGlobalOp.createStoreOp(fusedLoc, queryOp.getValue(), funcBuilder);
- funcBuilder.create<IREE::Util::ReturnOp>(fusedLoc);
-
- for (auto queryOp : queryOps) {
- OpBuilder replaceBuilder(queryOp);
- auto okLoadOp = okGlobalOp.createLoadOp(fusedLoc, replaceBuilder);
- auto resultLoadOp =
- valueGlobalOp.createLoadOp(fusedLoc, replaceBuilder);
- queryOp.replaceAllUsesWith(ValueRange{
- okLoadOp.getLoadedGlobalValue(),
- resultLoadOp.getLoadedGlobalValue(),
- });
- queryOp.erase();
+ Value device =
+ deviceOp.createLoadOp(fusedLoc, funcBuilder).getLoadedGlobalValue();
+ for (auto [i, queryKey] : llvm::enumerate(deviceQueries.queryKeys)) {
+ auto &query = queries[i];
+ auto queryOp = funcBuilder.create<IREE::HAL::DeviceQueryOp>(
+ fusedLoc, funcBuilder.getI1Type(),
+ query.valueGlobalOp.getGlobalType(), device, query.categoryAttr,
+ query.keyAttr, query.defaultValueAttr);
+ query.okGlobalOp.createStoreOp(fusedLoc, queryOp.getOk(), funcBuilder);
+ query.valueGlobalOp.createStoreOp(fusedLoc, queryOp.getValue(),
+ funcBuilder);
}
+ funcBuilder.create<IREE::Util::ReturnOp>(fusedLoc);
}
}
};
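`MemoizeDeviceQueriesPass` now groups `hal.device.query` ops per device global, dedupes them by category, key, and default value, and names the resulting globals from the device name plus the sanitized category and key. The sketch below models only that grouping and naming in standalone C++. `QuerySite`, `sanitize`, `devicePrefix`, and `memoizedGlobalNames` are hypothetical; `sanitize` is an assumed stand-in for `sanitizeSymbolName`, whose exact rules are not shown in this diff. The sketch keys on a tuple rather than on the concatenated category+key string the pass uses, which avoids any ambiguity between, for example, "ab"+"c" and "a"+"bc".

```cpp
// Hypothetical model of per-device query deduplication and global naming in
// MemoizeDeviceQueriesPass. Not IREE code.
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

// A hal.device.query reduced to the fields used for deduplication.
struct QuerySite {
  std::vector<std::string> candidateDevices;  // device globals from analysis
  std::string category;
  std::string key;
  std::optional<std::string> defaultValue;
};

// Assumed stand-in for sanitizeSymbolName: keeps [A-Za-z0-9_] and replaces
// every other character with '_'.
std::string sanitize(std::string text) {
  for (char &c : text) {
    if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_') c = '_';
  }
  return text;
}

// Same rule as getDeviceNamePrefix: ensure the name starts with "__".
std::string devicePrefix(const std::string &device) {
  return device.rfind("__", 0) == 0 ? device : "__" + device;
}

// Maps each device to the value-global names of its unique queries, in
// first-seen order. Queries that may target several devices are skipped.
std::map<std::string, std::vector<std::string>> memoizedGlobalNames(
    const std::vector<QuerySite> &sites) {
  using Key = std::tuple<std::string, std::string, std::optional<std::string>>;
  std::map<std::string, std::vector<Key>> uniqueKeys;
  for (const QuerySite &site : sites) {
    if (site.candidateDevices.size() != 1) continue;  // not memoizable
    auto &keys = uniqueKeys[site.candidateDevices.front()];
    Key key{site.category, site.key, site.defaultValue};
    if (std::find(keys.begin(), keys.end(), key) == keys.end()) {
      keys.push_back(key);
    }
  }
  std::map<std::string, std::vector<std::string>> names;
  for (const auto &[device, keys] : uniqueKeys) {
    for (std::size_t i = 0; i < keys.size(); ++i) {
      names[device].push_back(devicePrefix(device) + "_query_" +
                              std::to_string(i) + "_" +
                              sanitize(std::get<0>(keys[i])) + "_" +
                              sanitize(std::get<1>(keys[i])));
    }
  }
  return names;
}
```

In the pass itself the matching `_ok` global shares this name with an `_ok` suffix, and a single initializer per device performs all of that device's queries.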
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 5fe1fd6..54a87b0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -27,10 +27,6 @@
namespace {
struct TransformOptions : public PassPipelineOptions<TransformOptions> {
- // TODO(benvanik): replace the global iree-hal-target-backends flag with this.
- // ListOption<std::string> targets{
- // *this, "targets", llvm::cl::desc("One or more HAL devices to target."),
- // llvm::cl::ZeroOrMore};
Option<bool> serializeExecutables{
*this,
"serialize-executables",
@@ -181,6 +177,42 @@
}
//===----------------------------------------------------------------------===//
+// --iree-hal-device-assignment-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildHALDeviceAssignmentPassPipeline(
+ OpPassManager &passManager, const TargetRegistry &targetRegistry,
+ const AssignmentOptions &assignmentOptions) {
+ // The HAL must know its targets early on in the process. These passes
+ // discover/derive/specify the target devices and annotate the module with
+ // that information. This allows subsequent passes to look up which devices
+ // they are targeting.
+ if (!assignmentOptions.legacyTargetBackends.empty()) {
+ // Today we just assign devices from parameters but we should instead be
+ // performing analysis at the flow level and then doing magic device
+ // database lookups here.
+ passManager.addPass(IREE::HAL::createAssignLegacyTargetDevicesPass(
+ {&targetRegistry, assignmentOptions.legacyTargetBackends}));
+ }
+ if (!assignmentOptions.targetDevices.empty()) {
+ passManager.addPass(IREE::HAL::createAssignTargetDevicesPass(
+ {assignmentOptions.targetDevices}));
+ }
+
+ // Create globals for each device (if needed).
+ passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass(
+ {assignmentOptions.defaultDevice}));
+
+ // Resolve #hal.device.promise and #hal.device.alias attributes.
+ passManager.addPass(IREE::HAL::createResolveDevicePromisesPass());
+ passManager.addPass(
+ IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry}));
+
+ // Verify devices are valid.
+ passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry}));
+}
+
+//===----------------------------------------------------------------------===//
// --iree-hal-configuration-pipeline
//===----------------------------------------------------------------------===//
@@ -196,23 +228,12 @@
// and initial interface analysis (we rely on CSE and such having been run).
addCleanupPatterns(passManager);
- //----------------------------------------------------------------------------
- // Device assignment and interface materialization
- //----------------------------------------------------------------------------
+ // Verify devices are valid.
+ passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry}));
- // The HAL must know its targets early on in the process. This pass discovers/
- // derives/specifies the target devices and annotates the module with that
- // information. This allows subsequent passes to lookup which devices they are
- // targeting.
- if (!targetOptions.targets.empty()) {
- // Today we just assign devices from parameters but we should instead be
- // performing analysis at the flow level and then doing magic device
- // database lookups here.
- passManager.addPass(IREE::HAL::createAssignTargetDevicesPass(
- {&targetRegistry, targetOptions.targets}));
- }
- passManager.addPass(
- IREE::HAL::createVerifyTargetEnvironmentPass({targetRegistry}));
+ //----------------------------------------------------------------------------
+ // Device-specific interface materialization
+ //----------------------------------------------------------------------------
// Add dispatch instrumentation prior to materializing interfaces so we can
// more easily mutate the stream dispatch ops and exports.
@@ -271,16 +292,25 @@
// Device assignment and interface materialization
//----------------------------------------------------------------------------
- if (hooks.beforePhase)
+ if (hooks.beforePhase) {
hooks.beforePhase(PipelinePhase::ExecutableSources, passManager);
+ }
if (compileFrom < PipelinePhase::ExecutableSources) {
+ AssignmentOptions assignmentOptions;
+ assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends;
+ assignmentOptions.targetDevices = targetOptions.targetDevices;
+ assignmentOptions.defaultDevice = targetOptions.defaultDevice;
+ buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
+ assignmentOptions);
buildHALConfigurationPassPipeline(passManager, targetRegistry,
targetOptions, hooks);
- FunctionLikeNest(passManager).addPass([]() {
- return createCPUMaterializeUpperBoundTileSizePass();
- });
+ // HACK: this should not be here and will be going away. It exists for
+ // lowering iree_linalg_ext.upper_bound_tile_size ops that exist on the
+ // host. We should be using stream ops for performing such calculations that
+ // we can attach affinities to and understand what devices are being used.
+ passManager.addPass(createCPUMaterializeUpperBoundTileSizePass());
// Preprocess executables using an external tool. The tool may mutate one or
// more variants and even insert or remove variants.
@@ -290,17 +320,20 @@
}
}
- if (hooks.afterPhase)
+ if (hooks.afterPhase) {
hooks.afterPhase(PipelinePhase::ExecutableSources, passManager);
- if (compileTo == PipelinePhase::ExecutableSources)
+ }
+ if (compileTo == PipelinePhase::ExecutableSources) {
return;
+ }
//----------------------------------------------------------------------------
// Executable translation
//----------------------------------------------------------------------------
- if (hooks.beforePhase)
+ if (hooks.beforePhase) {
hooks.beforePhase(PipelinePhase::ExecutableConfigurations, passManager);
+ }
if (compileFrom < PipelinePhase::ExecutableConfigurations) {
// Select a translation strategy for each hal.executable.variant and
@@ -343,10 +376,12 @@
}
}
- if (hooks.afterPhase)
+ if (hooks.afterPhase) {
hooks.afterPhase(PipelinePhase::ExecutableConfigurations, passManager);
- if (compileTo == PipelinePhase::ExecutableConfigurations)
+ }
+ if (compileTo == PipelinePhase::ExecutableConfigurations) {
return;
+ }
// TODO(benvanik): move translation after conversion; today translation
// inserts the workgroup count logic we need to convert but we could instead
@@ -359,8 +394,9 @@
// After this point the executables are opaque blobs and we cannot change
// their interfaces.
- if (hooks.beforePhase)
+ if (hooks.beforePhase) {
hooks.beforePhase(PipelinePhase::ExecutableTargets, passManager);
+ }
if (compileFrom < PipelinePhase::ExecutableTargets) {
passManager.addNestedPass<IREE::HAL::ExecutableOp>(
@@ -376,10 +412,12 @@
IREE::HAL::createCaptureExecutableSourcesPass({"2.translated"}));
}
- if (hooks.afterPhase)
+ if (hooks.afterPhase) {
hooks.afterPhase(PipelinePhase::ExecutableTargets, passManager);
- if (compileTo == PipelinePhase::ExecutableTargets)
+ }
+ if (compileTo == PipelinePhase::ExecutableTargets) {
return;
+ }
// Substitute hal.executables we've translated with those specified on the
// command line. This developer feature allows for splicing in hand-authored
@@ -456,6 +494,14 @@
FunctionLikeNest(passManager)
.addPass(IREE::HAL::createElideRedundantCommandsPass);
+ // Initialize device globals now that we've done the analysis that is easier
+ // with them in their original target specification.
+ passManager.addPass(IREE::HAL::createInitializeDevicesPass({targetRegistry}));
+
+ // Combine the initializers we emitted during resource cache
+ // materialization.
+ passManager.addPass(IREE::Util::createCombineInitializersPass());
+
// TODO: Maybe this should be a part of Affine lowering pass.
// Remove if it is added there.
// https://github.com/llvm/llvm-project/issues/78458
@@ -468,10 +514,6 @@
// SimplifyGlobalAccesses are currently broken with scf present.
FunctionLikeNest(passManager).addPass(mlir::createConvertSCFToCFPass);
- // Combine the initializers we emitted during resource cache
- // materialization.
- passManager.addPass(IREE::Util::createCombineInitializersPass());
-
//----------------------------------------------------------------------------
// Executable serialization
//----------------------------------------------------------------------------
@@ -558,6 +600,14 @@
registerPasses();
// Pipelines.
+ PassPipelineRegistration<AssignmentOptions>(
+ "iree-hal-device-assignment-pipeline",
+ "Runs HAL target device assignment pipeline.",
+ [](OpPassManager &passManager,
+ const AssignmentOptions &assignmentOptions) {
+ buildHALDeviceAssignmentPassPipeline(
+ passManager, TargetRegistry::getGlobal(), assignmentOptions);
+ });
PassPipelineRegistration<>("iree-hal-configuration-pipeline",
"Runs HAL target configuration pipeline.",
[](OpPassManager &passManager) {
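The new `buildHALDeviceAssignmentPassPipeline` adds the two assignment passes only when their corresponding options are non-empty; the remaining passes always run. The sketch below returns the pass names it would schedule for a given set of options, which can help when reasoning about what a particular flag combination does. `AssignmentOptionsModel` and `deviceAssignmentPasses` are hypothetical stand-ins, not IREE APIs; the pass names are taken from `Passes.td` in this change.

```cpp
// Hypothetical model of the pass order in buildHALDeviceAssignmentPassPipeline.
#include <cassert>
#include <string>
#include <vector>

// Simplified mirror of AssignmentOptions (not the real options struct).
struct AssignmentOptionsModel {
  std::vector<std::string> legacyTargetBackends;
  std::vector<std::string> targetDevices;
  std::string defaultDevice;  // forwarded to materialization; no effect on order
};

// Returns the pass names the pipeline would add, in order.
std::vector<std::string> deviceAssignmentPasses(
    const AssignmentOptionsModel &options) {
  std::vector<std::string> passes;
  // Optional: only scheduled when the corresponding flag was provided.
  if (!options.legacyTargetBackends.empty()) {
    passes.push_back("iree-hal-assign-legacy-target-devices");
  }
  if (!options.targetDevices.empty()) {
    passes.push_back("iree-hal-assign-target-devices");
  }
  // Always scheduled: create globals, resolve promises/aliases, verify.
  passes.push_back("iree-hal-materialize-target-devices");
  passes.push_back("iree-hal-resolve-device-promises");
  passes.push_back("iree-hal-resolve-device-aliases");
  passes.push_back("iree-hal-verify-devices");
  return passes;
}
```

`buildHALConfigurationPassPipeline` also begins by running `iree-hal-verify-devices`, so devices are verified both at the end of assignment and at the start of configuration.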
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
index ec5c9c5..e3231c3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -46,9 +46,41 @@
std::function<void(PipelinePhase phase, OpPassManager &)> afterPhase;
};
+struct AssignmentOptions : public PassPipelineOptions<AssignmentOptions> {
+ // TODO(benvanik): remove the legacy flag once users are switched to devices.
+ ListOption<std::string> legacyTargetBackends{
+ *this,
+ "legacy-target-backends",
+ llvm::cl::desc("DEPRECATED: Target backend names."),
+ llvm::cl::ZeroOrMore,
+ };
+ ListOption<std::string> targetDevices{
+ *this,
+ "target-devices",
+ llvm::cl::desc("Target device specifications."),
+ llvm::cl::ZeroOrMore,
+ };
+ Option<std::string> defaultDevice{
+ *this,
+ "default-device",
+ llvm::cl::desc("Which device is considered the default when no device "
+ "affinity is specified. Either the device name when names "
+ "are specified or the numeric ordinal of the device."),
+ llvm::cl::init(""),
+ };
+};
+
+// Assigns devices from flags and coarse module-level specification.
+// Frontends are encouraged to create and assign devices themselves in order to
+// support more complex configurations (multiple devices, fallbacks, etc).
+void buildHALDeviceAssignmentPassPipeline(
+ OpPassManager &passManager, const TargetRegistry &targetRegistry,
+ const AssignmentOptions &assignmentOptions);
+
// Adds a set of passes to the given pass manager that run the head of the HAL
-// pipeline to assign devices, materialize interfaces, and translate
-// executables. The host portion of the program is annotated but not modified.
+// pipeline to materialize interfaces, import externally specified executables,
+// and translate executables. The host portion of the program is annotated but
+// not modified.
void buildHALConfigurationPassPipeline(OpPassManager &passManager,
const TargetRegistry &targetRegistry,
const TargetOptions &targetOptions,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
index 188340b..f5e345f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
@@ -42,26 +42,8 @@
// Device management
//===----------------------------------------------------------------------===//
-def VerifyTargetEnvironmentPass :
- Pass<"iree-hal-verify-target-environment", "mlir::ModuleOp"> {
- let summary = "Verifies that the target execution environment is valid.";
- let description = [{
- Verifies that the target execution environment is valid.
- `#hal.device.target` and `#hal.executable.target` attribute placement and
- definition will be checked that they reference the available target backends
- and that they are structurally valid.
- }];
- let options = [
- Option<
- "targetRegistry", "target-registry",
- "llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
- >,
- ];
-}
-
-def AssignTargetDevicesPass :
- Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> {
+def AssignLegacyTargetDevicesPass :
+ Pass<"iree-hal-assign-legacy-target-devices", "mlir::ModuleOp"> {
let summary = "Assigns the HAL devices the module will target to the given list of targets.";
let description = [{
Assigns target HAL devices to the module based on the given list.
@@ -70,7 +52,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
ListOption<
"targetBackends", "targetBackends",
@@ -78,6 +60,131 @@
"List of target backends to assign as device targets."
>,
];
+ let dependentDialects = [
+ "IREE::HAL::HALDialect",
+ ];
+}
+
+def AssignTargetDevicesPass :
+ Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> {
+ let summary = "Assigns the HAL devices the module will target to the given list of target specifications.";
+ let description = [{
+ Assigns target HAL devices to the module based on the given list of target
+ specifications.
+
+ Targets can be specified in several ways depending on whether there are
+ multiple devices, named devices, or devices imported from external files.
+ Human-friendly device aliases can be used as shorthand for
+ `IREE::HAL::TargetDevice` implementations providing their own configuration.
+ The aliases are identical to those used by `#hal.device.alias<>`.
+
+ If multiple targets are specified they will be available as multiple
+ distinct devices. A single device may also select from one or more targets,
+ in which case the first target enumerated that matches at runtime is
+ selected. For example, a `gpu` device may select between CUDA, HIP, or
+ Vulkan at runtime based on what kind of device the user has and which HAL
+ implementations were compiled into the runtime.
+
+ Examples using the canonical flag:
+ ```mlir
+ // Two devices, one the local host device and the other a Vulkan device:
+ --iree-hal-target-device=local
+ --iree-hal-target-device=vulkan
+
+ // One device that uses Vulkan if available and otherwise falls back to the
+ // local host device:
+ --iree-hal-target-device=vulkan,local
+
+ // Two CUDA devices selected by runtime ordinal; at runtime two --device=
+ // flags are required to configure both devices:
+ --iree-hal-target-device=cuda[0]
+ --iree-hal-target-device=cuda[1]
+
+ // A fully-defined target specification:
+ --iree-hal-target-device=#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>
+
+ // Named device that can be referenced by #hal.device.promise<@some_name>:
+ --iree-hal-target-device=some_name=vulkan
+ ```
+ }];
+ let options = [
+ ListOption<
+ "targetDevices", "targetDevices",
+ "std::string",
+ "List of target device specifications."
+ >,
+ ];
+ let dependentDialects = [
+ "IREE::HAL::HALDialect",
+ ];
+}
+
+def MaterializeTargetDevicesPass :
+ Pass<"iree-hal-materialize-target-devices", "mlir::ModuleOp"> {
+ let summary = "Materializes global device handles based on a `hal.device.targets` spec.";
+ let description = [{
+ Materializes global `!hal.device` ops for the devices specified by the
+ `hal.device.targets` attribute on the module. An optional default device can
+ be specified and is assigned to ops that do not already have a device affinity.
+ }];
+ let options = [
+ Option<
+ "defaultDevice", "defaultDevice",
+ "std::string", "",
+ "Which device is considered the default when no device affinity is specified."
+ >,
+ ];
+ let dependentDialects = [
+ "IREE::HAL::HALDialect",
+ "IREE::Util::UtilDialect",
+ ];
+}
+
+def ResolveDevicePromisesPass :
+ Pass<"iree-hal-resolve-device-promises", "mlir::ModuleOp"> {
+ let summary = "Resolves `#hal.device.promise` attributes to their devices.";
+ let description = [{
+ Resolves promised device affinities to the materialized device globals they
+ reference and verifies that every promise is resolved.
+ }];
+ let dependentDialects = [
+ "IREE::HAL::HALDialect",
+ ];
+}
+
+def ResolveDeviceAliasesPass :
+ Pass<"iree-hal-resolve-device-aliases", "mlir::ModuleOp"> {
+ let summary = "Resolves `#hal.device.alias` attributes to their expanded configurations.";
+ let description = [{
+ Resolves device aliases to the concrete targets using defaults, flags, and
+ registered device configurations.
+ }];
+ let options = [
+ Option<
+ "targetRegistry", "target-registry",
+ "llvm::cl::TargetRegistryRef", "",
+ "Target registry containing the list of available devices and backends."
+ >,
+ ];
+ let dependentDialects = [
+ "IREE::HAL::HALDialect",
+ ];
+}
+
+def VerifyDevicesPass :
+ Pass<"iree-hal-verify-devices", "mlir::ModuleOp"> {
+ let summary = "Verifies that all devices can be targeted with the available compiler plugins.";
+ let description = [{
+ Verifies that `#hal.device.target` and `#hal.executable.target` attributes
+ reference targets that are registered with the compiler.
+ }];
+ let options = [
+ Option<
+ "targetRegistry", "target-registry",
+ "llvm::cl::TargetRegistryRef", "",
+ "Target registry containing the list of available devices and backends."
+ >,
+ ];
}
def FixupLegacySyncPass :
@@ -213,7 +320,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
];
}
@@ -229,7 +336,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
Option<
"target", "target",
@@ -251,7 +358,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
];
}
@@ -268,7 +375,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
Option<
"target", "target",
@@ -300,7 +407,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
];
}
@@ -318,7 +425,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
Option<
"target", "target",
@@ -354,7 +461,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
Option<
"debugLevel", "debug-level",
@@ -386,7 +493,7 @@
Option<
"targetRegistry", "target-registry",
"llvm::cl::TargetRegistryRef", "",
- "Target backend registry containing the list of available backends."
+ "Target registry containing the list of available devices and backends."
>,
Option<
"target", "target",
@@ -439,6 +546,28 @@
];
}
+def InitializeDevicesPass :
+ Pass<"iree-hal-initialize-devices", "mlir::ModuleOp"> {
+ let summary = "Initializes global device handles based on their specification.";
+ let description = [{
+ Initializes each global `!hal.device` based on the specification attribute
+ by building initializers that enumerate and select the appropriate device.
+ }];
+ let options = [
+ Option<
+ "targetRegistry", "target-registry",
+ "llvm::cl::TargetRegistryRef", "",
+ "Target registry containing the list of available devices and backends."
+ >,
+ ];
+ let dependentDialects = [
+ "mlir::arith::ArithDialect",
+ "mlir::scf::SCFDialect",
+ "IREE::HAL::HALDialect",
+ "IREE::Util::UtilDialect",
+ ];
+}
+
def MaterializeResourceCachesPass :
Pass<"iree-hal-materialize-resource-caches", "mlir::ModuleOp"> {
let summary = "Materializes cached globals for device resources.";
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp
new file mode 100644
index 0000000..0108aee
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp
@@ -0,0 +1,134 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_RESOLVEDEVICEALIASESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-resolve-device-aliases
+//===----------------------------------------------------------------------===//
+
+static FailureOr<Attribute>
+resolveAliasAttr(Operation *forOp, IREE::HAL::DeviceAliasAttr aliasAttr,
+ const TargetRegistry &targetRegistry) {
+ // Lookup device in the registry.
+ auto targetDevice =
+ targetRegistry.getTargetDevice(aliasAttr.getDeviceID().getValue());
+ if (!targetDevice) {
+ auto diagnostic = forOp->emitError();
+ diagnostic << "unregistered device alias " << aliasAttr.getDeviceID()
+ << "; ensure it is linked into the compiler (available = [ ";
+ for (const auto &targetName : targetRegistry.getRegisteredTargetDevices()) {
+ diagnostic << "'" << targetName << "' ";
+ }
+ diagnostic << "])";
+ return diagnostic;
+ }
+
+ // Query the default device target.
+ auto defaultAttr =
+ targetDevice->getDefaultDeviceTarget(forOp->getContext(), targetRegistry);
+ assert(defaultAttr && "expected a default device target attr");
+
+ // Merge in any additional configuration from the alias attr.
+ if (aliasAttr.getOrdinal().has_value() ||
+ (aliasAttr.getConfiguration() && !aliasAttr.getConfiguration().empty())) {
+ NamedAttrList configAttrs;
+ if (auto defaultConfigAttr = defaultAttr.getConfiguration()) {
+ for (auto existingAttr : defaultConfigAttr) {
+ configAttrs.push_back(existingAttr);
+ }
+ }
+ if (auto overrideConfigAttr = aliasAttr.getConfiguration()) {
+ for (auto overrideAttr : overrideConfigAttr) {
+ configAttrs.set(overrideAttr.getName(), overrideAttr.getValue());
+ }
+ }
+ if (aliasAttr.getOrdinal().has_value()) {
+ configAttrs.set("ordinal",
+ IntegerAttr::get(IndexType::get(forOp->getContext()),
+ aliasAttr.getOrdinal().value()));
+ }
+ defaultAttr = IREE::HAL::DeviceTargetAttr::get(
+ forOp->getContext(), defaultAttr.getDeviceID(),
+ DictionaryAttr::get(forOp->getContext(), configAttrs),
+ defaultAttr.getExecutableTargets());
+ }
+
+ return defaultAttr;
+}
+
+static FailureOr<Attribute>
+resolveNestedAliasAttrs(Operation *forOp, Attribute attr,
+ const TargetRegistry &targetRegistry) {
+ if (auto aliasAttr = dyn_cast<IREE::HAL::DeviceAliasAttr>(attr)) {
+ return resolveAliasAttr(forOp, aliasAttr, targetRegistry);
+ } else if (auto selectAttr = dyn_cast<IREE::HAL::DeviceSelectAttr>(attr)) {
+ SmallVector<Attribute> resolvedAttrs;
+ bool didChange = false;
+ for (auto deviceAttr : selectAttr.getDevices()) {
+ auto resolvedAttr =
+ resolveNestedAliasAttrs(forOp, deviceAttr, targetRegistry);
+ if (failed(resolvedAttr)) {
+ return failure();
+ }
+ didChange = didChange || *resolvedAttr != deviceAttr;
+ resolvedAttrs.push_back(*resolvedAttr);
+ }
+ return didChange ? IREE::HAL::DeviceSelectAttr::get(attr.getContext(),
+ resolvedAttrs)
+ : attr;
+ } else {
+ return attr; // pass-through
+ }
+}
+
+struct ResolveDeviceAliasesPass
+ : public IREE::HAL::impl::ResolveDeviceAliasesPassBase<
+ ResolveDeviceAliasesPass> {
+ using IREE::HAL::impl::ResolveDeviceAliasesPassBase<
+ ResolveDeviceAliasesPass>::ResolveDeviceAliasesPassBase;
+ void runOnOperation() override {
+ // Walk all device globals and resolve any aliases found.
+ auto moduleOp = getOperation();
+ for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
+ if (!isa<IREE::HAL::DeviceType>(globalOp.getGlobalType())) {
+ continue;
+ }
+ auto initialValue = globalOp.getGlobalInitialValue();
+ if (!initialValue) {
+ continue;
+ }
+ auto resolvedValue = resolveNestedAliasAttrs(globalOp, initialValue,
+ *targetRegistry.value);
+ if (failed(resolvedValue)) {
+ return signalPassFailure();
+ }
+ globalOp.setGlobalInitialValue(*resolvedValue);
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
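The precedence used by `resolveAliasAttr` above can be sketched outside MLIR. This is a minimal standalone C++17 sketch under stated assumptions: `DeviceAlias`, `DeviceTarget`, `Registry`, and `resolveAlias` are hypothetical stand-ins for the `#hal.device.alias` and `#hal.device.target` attributes and the `TargetRegistry`, and configuration values are simplified to strings. It follows the same order as the pass: registry defaults first, then alias configuration overriding them, then the ordinal written last.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for #hal.device.alias<"id"[ordinal], {config}>.
struct DeviceAlias {
  std::string deviceID;
  std::optional<int64_t> ordinal;
  std::map<std::string, std::string> configuration;
};

// Hypothetical stand-in for #hal.device.target<"id", {config}, [targets]>.
struct DeviceTarget {
  std::string deviceID;
  std::map<std::string, std::string> configuration;
  std::vector<std::string> executableTargets;
};

// Hypothetical registry: device ID -> that device's default target.
using Registry = std::map<std::string, DeviceTarget>;

// Returns std::nullopt for an unregistered alias (the pass emits an
// "unregistered device alias" error listing the available devices there).
std::optional<DeviceTarget> resolveAlias(const DeviceAlias &alias,
                                         const Registry &registry) {
  auto it = registry.find(alias.deviceID);
  if (it == registry.end()) {
    return std::nullopt;
  }
  // Start from a copy of the default target; executable targets are kept.
  DeviceTarget resolved = it->second;
  assert(resolved.deviceID == alias.deviceID &&
         "registry entries are keyed by their own device ID");
  // Alias-provided keys override the defaults.
  for (const auto &[key, value] : alias.configuration) {
    resolved.configuration[key] = value;
  }
  // The ordinal is applied last, so it wins over any "ordinal" key above.
  if (alias.ordinal.has_value()) {
    resolved.configuration["ordinal"] = std::to_string(*alias.ordinal);
  }
  return resolved;
}
```

Because the ordinal is set after the overrides, an alias such as `#hal.device.alias<"device_a"[1]>` selects ordinal 1 even if its configuration dictionary also carries an `ordinal` key.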
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp
new file mode 100644
index 0000000..413c107
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp
@@ -0,0 +1,155 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_RESOLVEDEVICEPROMISESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-resolve-device-promises
+//===----------------------------------------------------------------------===//
+
+struct ResolveDevicePromisesPass
+ : public IREE::HAL::impl::ResolveDevicePromisesPassBase<
+ ResolveDevicePromisesPass> {
+ using IREE::HAL::impl::ResolveDevicePromisesPassBase<
+ ResolveDevicePromisesPass>::ResolveDevicePromisesPassBase;
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Resolves a #hal.device.promise attr to a #hal.device.affinity. Fails if
+ // the referenced device is not found.
+ SymbolTable symbolTable(moduleOp);
+ auto resolvePromise = [&](Operation *fromOp,
+ IREE::HAL::DevicePromiseAttr promiseAttr)
+ -> FailureOr<IREE::Stream::AffinityAttr> {
+ auto deviceOp =
+ symbolTable.lookupNearestSymbolFrom<IREE::Util::GlobalOpInterface>(
+ fromOp, promiseAttr.getDevice());
+ if (!deviceOp) {
+ return fromOp->emitOpError()
+ << "references a promised device that was not declared: "
+ << promiseAttr;
+ }
+ return cast<IREE::Stream::AffinityAttr>(
+ IREE::HAL::DeviceAffinityAttr::get(&getContext(),
+ FlatSymbolRefAttr::get(deviceOp),
+ promiseAttr.getQueueMask()));
+ };
+
+ // Resolves any #hal.device.promise attr on the op.
+ auto resolvePromiseAttrs = [&](Operation *op, DictionaryAttr attrDict)
+ -> std::optional<std::pair<DictionaryAttr, WalkResult>> {
+ bool didReplaceAny = false;
+ auto newDict = dyn_cast_if_present<DictionaryAttr>(attrDict.replace(
+ [&](Attribute attr)
+ -> std::optional<std::pair<Attribute, WalkResult>> {
+ if (auto promiseAttr =
+ dyn_cast_if_present<IREE::HAL::DevicePromiseAttr>(attr)) {
+ auto resolvedAttrOr = resolvePromise(op, promiseAttr);
+ if (failed(resolvedAttrOr)) {
+ return std::make_pair(attr, WalkResult::interrupt());
+ }
+ didReplaceAny = true;
+ return std::make_pair(resolvedAttrOr.value(),
+ WalkResult::advance());
+ }
+ return std::nullopt;
+ }));
+ if (newDict) {
+ return std::make_pair(newDict, didReplaceAny ? WalkResult::advance()
+ : WalkResult::skip());
+ } else {
+ return std::make_pair(attrDict, WalkResult::interrupt());
+ }
+ };
+ auto resolveAllPromiseAttrs =
+ [&](Operation *op,
+ MutableArrayRef<DictionaryAttr> attrDicts) -> WalkResult {
+ bool didReplaceAny = false;
+ for (auto &attrDict : attrDicts) {
+ auto resolveState = resolvePromiseAttrs(op, attrDict);
+ if (!resolveState) {
+ // Failed to resolve while recursively replacing.
+ return WalkResult::interrupt();
+ } else if (!resolveState->second.wasSkipped()) {
+ // Performed a replacement.
+ attrDict = resolveState->first;
+ didReplaceAny = true;
+ }
+ }
+ return didReplaceAny ? WalkResult::advance() : WalkResult::skip();
+ };
+ auto resolvePromisesOnOp = [&](Operation *op) -> WalkResult {
+ auto opAttrs = op->getAttrDictionary();
+ if (opAttrs) {
+ auto resolveState = resolvePromiseAttrs(op, opAttrs);
+ if (!resolveState) {
+ // Failed to resolve while recursively replacing.
+ return WalkResult::interrupt();
+ } else if (!resolveState->second.wasSkipped()) {
+ // Performed a replacement.
+ op->setAttrs(resolveState->first);
+ }
+ }
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ SmallVector<DictionaryAttr> argAttrs;
+ funcOp.getAllArgAttrs(argAttrs);
+ auto argStatus = resolveAllPromiseAttrs(op, argAttrs);
+ if (argStatus.wasInterrupted()) {
+ return argStatus;
+ } else if (!argStatus.wasSkipped()) {
+ funcOp.setAllArgAttrs(argAttrs);
+ }
+ SmallVector<DictionaryAttr> resultAttrs;
+ funcOp.getAllResultAttrs(resultAttrs);
+ auto resultStatus = resolveAllPromiseAttrs(op, resultAttrs);
+ if (resultStatus.wasInterrupted()) {
+ return resultStatus;
+ } else if (!resultStatus.wasSkipped()) {
+ funcOp.setAllResultAttrs(resultAttrs);
+ }
+ }
+ return WalkResult::advance();
+ };
+
+ // Walk the entire module and replace promises.
+ // We skip any symbol table op as all devices are top-level only.
+ if (resolvePromisesOnOp(moduleOp).wasInterrupted()) {
+ return signalPassFailure();
+ }
+ if (moduleOp
+ .walk([&](Operation *op) {
+ if (op->hasTrait<OpTrait::SymbolTable>()) {
+ return WalkResult::skip(); // ignore isolated ops
+ }
+ return resolvePromisesOnOp(op);
+ })
+ .wasInterrupted()) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
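The per-dictionary logic in `ResolveDevicePromisesPass` (turn each `#hal.device.promise` into a `#hal.device.affinity` on the same device and queue mask, fail on an undeclared device, and report whether anything changed) can be sketched in plain C++17. This is an illustration only: `DevicePromise`, `DeviceAffinity`, `Attr`, `AttrDict`, `Status`, and `resolvePromises` are hypothetical, a `std::map` stands in for `DictionaryAttr`, and a set of names stands in for the symbol table lookup.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <variant>

// Hypothetical stand-ins for #hal.device.promise<@dev> and
// #hal.device.affinity<@dev>; both carry a device name and a queue mask.
struct DevicePromise {
  std::string device;
  int64_t queueMask;
};
struct DeviceAffinity {
  std::string device;
  int64_t queueMask;
};
using Attr = std::variant<std::string, DevicePromise, DeviceAffinity>;
using AttrDict = std::map<std::string, Attr>;

// Mirrors the three WalkResult outcomes the pass relies on.
enum class Status { Advance, Skip, Interrupt };

// Advance: at least one promise was replaced; Skip: nothing to do;
// Interrupt: a promise named an undeclared device (the caller fails the pass).
Status resolvePromises(AttrDict &dict,
                       const std::set<std::string> &declaredDevices) {
  bool didReplaceAny = false;
  for (auto &entry : dict) {
    Attr &attr = entry.second;
    const auto *promise = std::get_if<DevicePromise>(&attr);
    if (!promise) {
      continue;  // non-promise attributes pass through untouched
    }
    if (declaredDevices.count(promise->device) == 0) {
      return Status::Interrupt;
    }
    // The right-hand side is built before assignment, so reading through
    // `promise` here is safe.
    attr = DeviceAffinity{promise->device, promise->queueMask};
    didReplaceAny = true;
  }
  return didReplaceAny ? Status::Advance : Status::Skip;
}
```

The `Skip`/`Advance` distinction matters because the pass only writes attributes back (`op->setAttrs`, `setAllArgAttrs`, `setAllResultAttrs`) when a replacement actually happened.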
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
new file mode 100644
index 0000000..6ac9cc8
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
@@ -0,0 +1,179 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_VERIFYDEVICESPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-verify-devices
+//===----------------------------------------------------------------------===//
+
+static void printAvailable(InFlightDiagnostic &diagnostic,
+ const TargetRegistry &targetRegistry) {
+ diagnostic << "available devices = [";
+ llvm::interleaveComma(targetRegistry.getRegisteredTargetDevices(),
+ diagnostic);
+ diagnostic << "], available backends = [";
+ llvm::interleaveComma(targetRegistry.getRegisteredTargetBackends(),
+ diagnostic);
+ diagnostic << "]";
+}
+
+static LogicalResult
+verifyDeviceTargetAttr(Operation *deviceOp,
+ IREE::HAL::DeviceTargetAttr deviceTargetAttr,
+ const TargetRegistry &targetRegistry) {
+ auto targetDevice =
+ targetRegistry.getTargetDevice(deviceTargetAttr.getDeviceID().getValue());
+ if (!targetDevice) {
+ auto diagnostic = deviceOp->emitError();
+ diagnostic << "unregistered target device "
+ << deviceTargetAttr.getDeviceID()
+ << "; ensure it is linked into the compiler (available = [ ";
+ for (const auto &targetName : targetRegistry.getRegisteredTargetDevices()) {
+ diagnostic << "'" << targetName << "' ";
+ }
+ diagnostic << "])";
+ return diagnostic;
+ }
+
+ for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) {
+ auto targetBackend = targetRegistry.getTargetBackend(
+ executableTargetAttr.getBackend().getValue());
+ if (!targetBackend) {
+ auto diagnostic = deviceOp->emitError();
+ diagnostic << "unregistered target backend "
+ << executableTargetAttr.getBackend()
+ << "; ensure it is linked into the compiler (available = [ ";
+ for (const auto &targetName :
+ targetRegistry.getRegisteredTargetBackends()) {
+ diagnostic << "'" << targetName << "' ";
+ }
+ diagnostic << "])";
+ return diagnostic;
+ }
+ }
+
+ return success();
+}
+
+static LogicalResult verifyAttr(Operation *deviceOp, Attribute attr,
+ const TargetRegistry &targetRegistry) {
+ return TypeSwitch<Attribute, LogicalResult>(attr)
+ .Case<IREE::HAL::DeviceTargetAttr>([&](auto deviceTargetAttr) {
+ return verifyDeviceTargetAttr(deviceOp, deviceTargetAttr,
+ targetRegistry);
+ })
+ .Case<IREE::HAL::DeviceSelectAttr>([&](auto deviceSelectAttr) {
+ for (auto attr : deviceSelectAttr.getDevices().getValue()) {
+ if (failed(verifyAttr(deviceOp, attr, targetRegistry))) {
+ return failure();
+ }
+ }
+ return success();
+ })
+ .Default([&](auto attr) {
+ return success(); // probably fallback/ordinal/etc - can't verify
+ });
+}
+
+struct VerifyDevicesPass
+ : public IREE::HAL::impl::VerifyDevicesPassBase<VerifyDevicesPass> {
+ using IREE::HAL::impl::VerifyDevicesPassBase<
+ VerifyDevicesPass>::VerifyDevicesPassBase;
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Devices are required if we need to convert host code or executables.
+ // If we only have hal.executables as input then we can bypass this.
+ // We could extend this check to be a bit smarter at the risk of false
+ // negatives - today this is just handling the standalone hal.executable
+ // compilation workflow.
+ bool anyNonExecutableOps = false;
+ for (auto &op : moduleOp.getOps()) {
+ if (!isa<IREE::HAL::ExecutableOp>(op)) {
+ anyNonExecutableOps = true;
+ break;
+ }
+ }
+ if (!anyNonExecutableOps) {
+ return;
+ }
+
+ // Analyze the module to find all devices.
+ DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // Devices are only required if we have dialects we may lower into device
+ // code. For now checking for tensor types is probably sufficient though we
+ // may want a pluggable way to decide this (e.g. dialect/type/op
+ // interfaces).
+ auto isTensor = [](Type type) { return isa<TensorType>(type); };
+ bool anyTensors = false;
+ for (auto &op : moduleOp.getOps()) {
+ if (op.hasTrait<OpTrait::IREE::Util::ObjectLike>()) {
+ continue; // ignore executables
+ }
+ op.walk([&](Operation *childOp) {
+ if (llvm::any_of(childOp->getOperandTypes(), isTensor) ||
+ llvm::any_of(childOp->getResultTypes(), isTensor)) {
+ anyTensors = true;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
+ // TODO(multi-device): the logic above is insufficient; we only need devices
+ // if the program will end up requiring them but we don't know that here.
+ // We have to wait until we've lowered to the point where we do require a
+ // device _and_ we actually want one (aren't compiling a non-HAL program).
+ // We could probably have an op interface, better output from the pass that
+ // requires the devices, etc. For now we error out in HAL conversion when we
+ // try to resolve devices.
+ if (false && anyTensors && deviceAnalysis.getDeviceGlobals().empty()) {
+ auto diagnostic = moduleOp.emitError();
+ diagnostic
+ << "no HAL devices defined in the module; use the module-level "
+ "hal.device.targets attribute, the --iree-hal-target-device= "
+ "flag, or provide inputs with global !hal.devices defined; ";
+ printAvailable(diagnostic, *targetRegistry.value);
+ return signalPassFailure();
+ }
+
+ // Walk all devices and verify them.
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) {
+ if (auto initialValue = deviceOp.getGlobalInitialValue()) {
+ if (failed(verifyAttr(deviceOp, initialValue, *targetRegistry.value))) {
+ return signalPassFailure();
+ }
+ }
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
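The recursion in `verifyAttr` (check a target's device ID and each executable backend against the registry, recurse into every device of a select, and accept anything else unverified) can be sketched with plain C++17 types. The names `DeviceSpec`, `Registry`, `contains`, and `verifySpec` are hypothetical, and the registry is reduced to two lists of registered names; error strings are shortened versions of the ones the pass emits.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the initial value of a !hal.device global.
struct DeviceSpec {
  enum class Kind { Target, Select, Other };
  Kind kind = Kind::Other;
  std::string deviceID;              // used by Kind::Target
  std::vector<std::string> backends; // used by Kind::Target
  std::vector<DeviceSpec> devices;   // used by Kind::Select
};

// Hypothetical registry: names of linked-in devices and backends.
struct Registry {
  std::vector<std::string> devices;
  std::vector<std::string> backends;
};

static bool contains(const std::vector<std::string> &names,
                     const std::string &name) {
  return std::find(names.begin(), names.end(), name) != names.end();
}

// Returns the first error found, or std::nullopt if everything that can be
// checked is registered.
std::optional<std::string> verifySpec(const DeviceSpec &spec,
                                      const Registry &registry) {
  switch (spec.kind) {
  case DeviceSpec::Kind::Target:
    if (!contains(registry.devices, spec.deviceID)) {
      return "unregistered target device '" + spec.deviceID + "'";
    }
    for (const auto &backend : spec.backends) {
      if (!contains(registry.backends, backend)) {
        return "unregistered target backend '" + backend + "'";
      }
    }
    return std::nullopt;
  case DeviceSpec::Kind::Select:
    // Every candidate of a select must verify, including nested selects.
    for (const auto &child : spec.devices) {
      if (auto error = verifySpec(child, registry)) {
        return error;
      }
    }
    return std::nullopt;
  case DeviceSpec::Kind::Other:
    return std::nullopt; // e.g. fallback/ordinal references: nothing to check
  }
  return std::nullopt;
}
```

As in the pass, a select fails as soon as any one of its candidates names an unregistered device or backend.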
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp
deleted file mode 100644
index 7362b92..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <memory>
-#include <utility>
-
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler::IREE::HAL {
-
-#define GEN_PASS_DEF_VERIFYTARGETENVIRONMENTPASS
-#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// --iree-hal-verify-target-environment
-//===----------------------------------------------------------------------===//
-
-struct VerifyTargetEnvironmentPass
- : public IREE::HAL::impl::VerifyTargetEnvironmentPassBase<
- VerifyTargetEnvironmentPass> {
- using IREE::HAL::impl::VerifyTargetEnvironmentPassBase<
- VerifyTargetEnvironmentPass>::VerifyTargetEnvironmentPassBase;
- void runOnOperation() override {
- auto moduleOp = getOperation();
-
- // Targets are required if we need to convert host code or executables.
- // If we only have hal.executables as input then we can bypass this.
- // We could extend this check to be a bit smarter at the risk of false
- // negatives - today this is just handling the standalone hal.executable
- // compilation workflow.
- bool anyNonExecutableOps = false;
- for (auto &op : moduleOp.getOps()) {
- if (!isa<IREE::HAL::ExecutableOp>(op)) {
- anyNonExecutableOps = true;
- break;
- }
- }
- if (!anyNonExecutableOps)
- return;
-
- // Must have targets specified.
- auto targetsAttr = moduleOp->getAttrOfType<ArrayAttr>("hal.device.targets");
- if (!targetsAttr || targetsAttr.empty()) {
- auto diagnostic = moduleOp.emitError();
- diagnostic
- << "no HAL target devices specified on the module (available = [ ";
- for (const auto &targetName :
- targetRegistry->getRegisteredTargetBackends()) {
- diagnostic << "'" << targetName << "' ";
- }
- diagnostic << "])";
- signalPassFailure();
- return;
- }
-
- // Verify each target is registered.
- for (auto attr : targetsAttr) {
- auto deviceTargetAttr = llvm::dyn_cast<IREE::HAL::DeviceTargetAttr>(attr);
- if (!deviceTargetAttr) {
- moduleOp.emitError() << "invalid target attr type: " << attr;
- signalPassFailure();
- return;
- }
-
- auto targetDevice = targetRegistry->getTargetDevice(
- deviceTargetAttr.getDeviceID().getValue());
- if (!targetDevice) {
- auto diagnostic = moduleOp.emitError();
- diagnostic
- << "unregistered target device " << deviceTargetAttr.getDeviceID()
- << "; ensure it is linked in to the compiler (available = [ ";
- for (const auto &targetName :
- targetRegistry->getRegisteredTargetDevices()) {
- diagnostic << "'" << targetName << "' ";
- }
- diagnostic << "])";
- signalPassFailure();
- return;
- }
-
- for (auto executableTargetAttr :
- deviceTargetAttr.getExecutableTargets()) {
- auto targetBackend = targetRegistry->getTargetBackend(
- executableTargetAttr.getBackend().getValue());
- if (!targetBackend) {
- auto diagnostic = moduleOp.emitError();
- diagnostic
- << "unregistered target backend "
- << executableTargetAttr.getBackend()
- << "; ensure it is linked in to the compiler (available = [ ";
- for (const auto &targetName :
- targetRegistry->getRegisteredTargetBackends()) {
- diagnostic << "'" << targetName << "' ";
- }
- diagnostic << "])";
- signalPassFailure();
- return;
- }
- }
- }
- }
-};
-
-} // namespace
-
-} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
index db9a2d49..a2a704d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "assign_legacy_target_devices.mlir",
"assign_target_devices.mlir",
"capture_executable_sources.mlir",
"convert_to_hal.mlir",
@@ -23,17 +24,21 @@
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
+ "initialize_devices.mlir",
"materialize_dispatch_instrumentation.mlir",
"materialize_interfaces.mlir",
"materialize_resource_caches.mlir",
+ "materialize_target_devices.mlir",
"memoize_device_queries.mlir",
"preprocess_executables.mlir",
"prune_executables.mlir",
"repeat_dispatches.mlir",
+ "resolve_device_aliases.mlir",
+ "resolve_device_promises.mlir",
"resolve_export_ordinals.mlir",
"strip_executable_contents.mlir",
"substitute_executables.mlir",
- "verify_target_environment.mlir",
+ "verify_devices.mlir",
],
include = ["*.mlir"],
exclude = ["substitute_executables_replacement.mlir"],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index ae4322d..28c81ad 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "assign_legacy_target_devices.mlir"
"assign_target_devices.mlir"
"capture_executable_sources.mlir"
"convert_to_hal.mlir"
@@ -21,17 +22,21 @@
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
+ "initialize_devices.mlir"
"materialize_dispatch_instrumentation.mlir"
"materialize_interfaces.mlir"
"materialize_resource_caches.mlir"
+ "materialize_target_devices.mlir"
"memoize_device_queries.mlir"
"preprocess_executables.mlir"
"prune_executables.mlir"
"repeat_dispatches.mlir"
+ "resolve_device_aliases.mlir"
+ "resolve_device_promises.mlir"
"resolve_export_ordinals.mlir"
"strip_executable_contents.mlir"
"substitute_executables.mlir"
- "verify_target_environment.mlir"
+ "verify_devices.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir
new file mode 100644
index 0000000..1bdb7af
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir
@@ -0,0 +1,42 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ
+
+// TARGET-1: #device_target_local = #hal.device.target<"local"
+
+// TARGET-2: #device_target_local = #hal.device.target<"local"
+// TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline"
+
+// TARGET-EQ: #device_target_local = #hal.device.target<"local"
+
+// CHECK: module
+// TARGET-0: @module {
+// TARGET-1: @module attributes {
+// TARGET-1-SAME: hal.device.targets = [#device_target_local]
+// TARGET-2: @module attributes {
+// TARGET-2-SAME: hal.device.targets = [#hal.device.select<[#device_target_local, #device_target_vmvx_inline]> : !hal.device]
+// TARGET-EQ: @module attributes {
+// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]}
+module @module {}
+
+// -----
+
+// The pass is a no-op when targets are already specified.
+
+// CHECK: #device_target_foo = #hal.device.target<"foo"
+// CHECK: module @module attributes {hal.device.targets = [#device_target_foo]}
+module @module attributes {
+ hal.device.targets = [#hal.device.target<"foo">]
+} {}
+
+// -----
+
+// The pass does nothing when one or more devices have already been defined.
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+module @module {
+ // CHECK: @existing_device
+ util.global private @existing_device : !hal.device
+}
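The CHECK lines in this test describe how legacy `targetBackends` values expand into devices. `vmvx` yields a `"local"` device target. `vmvx,vmvx-inline` yields one `#hal.device.select` over two targets. A repeated backend collapses to a single device, and no backends add no `hal.device.targets` at all. The pass source is not part of this section, so the following is only a C++17 sketch of that observable behavior. It assumes a hypothetical `kDefaultDeviceForBackend` table limited to the two backends the test exercises; `legacyDevicesFor` is likewise hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical backend -> default device table, limited to the two entries
// exercised by the test (vmvx -> "local", vmvx-inline -> "vmvx-inline").
const std::map<std::string, std::string> kDefaultDeviceForBackend = {
    {"vmvx", "local"},
    {"vmvx-inline", "vmvx-inline"},
};

// Returns device IDs in first-seen order with duplicates dropped. One entry
// would become a plain #hal.device.target; two or more would become a single
// #hal.device.select over them.
std::vector<std::string>
legacyDevicesFor(const std::vector<std::string> &backends) {
  std::vector<std::string> devices;
  for (const auto &backend : backends) {
    auto it = kDefaultDeviceForBackend.find(backend);
    if (it == kDefaultDeviceForBackend.end()) {
      continue; // this sketch skips unknown backends (real behavior not shown)
    }
    if (std::find(devices.begin(), devices.end(), it->second) ==
        devices.end()) {
      devices.push_back(it->second);
    }
  }
  return devices;
}
```

An empty backend list produces no devices, matching the `TARGET-0` case where the module is left without a `hal.device.targets` attribute.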
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
index 46c0a5e..adce11f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
@@ -1,31 +1,45 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ
-
-// TARGET-1: #device_target_local = #hal.device.target<"local"
-
-// TARGET-2: #device_target_local = #hal.device.target<"local"
-// TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline"
-
-// TARGET-EQ: #device_target_local = #hal.device.target<"local"
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a,device_b})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a[0],device_a[1]})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ORDINALS
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.target<"local">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ATTR
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.alias<"device_a">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ALIAS
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices={"device_a,#hal.device.alias<"device_b">"}})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a=#hal.device.alias<"device_a">,"device_bc=device_b,#hal.device.alias<"device_c">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT-MULTI
// CHECK: module
-// TARGET-0: @module {
-// TARGET-1: @module attributes {
-// TARGET-1-SAME: hal.device.targets = [#device_target_local]
-// TARGET-2: @module attributes {
-// TARGET-2-SAME: hal.device.targets = [#device_target_local, #device_target_vmvx_inline]}
-// TARGET-EQ: @module attributes {
-// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]}
-module @module {}
+// TARGET-0-NOT: hal.device.targets
+// TARGET-1: hal.device.targets = [#hal.device.alias<"device"> : !hal.device]
+// TARGET-2: hal.device.targets = [#hal.device.alias<"device_a"> : !hal.device, #hal.device.alias<"device_b"> : !hal.device]}
+// TARGET-ORDINALS: hal.device.targets = [#hal.device.alias<"device_a"[0]> : !hal.device, #hal.device.alias<"device_a"[1]> : !hal.device]}
+// TARGET-ATTR: hal.device.targets = [#hal.device.target<"local"> : !hal.device]
+// TARGET-ALIAS: hal.device.targets = [#hal.device.alias<"device_a"> : !hal.device]
+// TARGET-SELECT: hal.device.targets = [#hal.device.select<[#hal.device.alias<"device_a"> : !hal.device, #hal.device.alias<"device_b"> : !hal.device]> : !hal.device]
+// TARGET-SELECT-MULTI: hal.device.targets = {
+// TARGET-SELECT-MULTI-SAME: device_a = #hal.device.alias<"device_a"> : !hal.device,
+// TARGET-SELECT-MULTI-SAME: device_bc = #hal.device.select<[#hal.device.alias<"device_b"> : !hal.device, #hal.device.alias<"device_c"> : !hal.device]> : !hal.device
+// TARGET-SELECT-MULTI-SAME: }
+module @module {
+ util.global private @tensor_global : tensor<4xf32>
+}
// -----
// The pass is a no-op when targets are already specified.
-// CHECK: #device_target_foo = #hal.device.target<"foo"
-// CHECK: module @module attributes {hal.device.targets = [#device_target_foo]}
+// CHECK: module @module attributes {
+// CHECK-SAME: hal.device.targets = [#hal.device.target<"foo"> : !hal.device]
module @module attributes {
- hal.device.targets = [#hal.device.target<"foo">]
+ hal.device.targets = [#hal.device.target<"foo"> : !hal.device]
} {}
+
+// -----
+
+// The pass does nothing when one or more devices have already been defined.
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+module @module {
+ // CHECK: @existing_device
+ util.global private @existing_device : !hal.device
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index d504c5e..1de9b90 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -2,6 +2,8 @@
// Tests an end-to-end simple single-dispatch `dispatch(arg0, arg1) -> result`.
+util.global private @device : !hal.device
+
#executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
@@ -33,9 +35,7 @@
// CHECK: hal.executable private @ex
hal.executable private @ex {
hal.executable.variant public @embedded_elf_aarch64 target(#executable_target_embedded_elf_aarch64) {
- hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_0) attributes {
- translation_info = #iree_codegen.translation_info<CPUDefault>
- } {
+ hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_0) {
^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
%c1 = arith.constant 1 : index
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
@@ -53,8 +53,7 @@
#hal.interface.binding<0, 4>,
#hal.interface.binding<1, 5>,
#hal.interface.binding<1, 6>
- ],
- translation_info = #iree_codegen.translation_info<CPUDefault>
+ ]
} {
^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
%c1 = arith.constant 1 : index
@@ -69,7 +68,9 @@
// CHECK: util.func public @simpleDispatch
// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
-util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {
+ stream.affinity = #hal.device.affinity<@device>
+} {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c16 = arith.constant 16 : index
@@ -79,8 +80,7 @@
// CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer
- // (annoyingly out of order)
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device : !hal.device
// CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
// CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
index b42b38f..bd6d630 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir
@@ -1,12 +1,15 @@
-// RUN: iree-opt --split-input-file --iree-hal-dump-executable-benchmarks %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-hal-dump-executable-benchmarks %s --verify-diagnostics | FileCheck %s
// Tests dumping executable benchmarks to stdout - it's more common to use files
// but this is much easier to test with lit.
+// Ensure devices are copied and made available:
#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
-#device_target_cpu = #hal.device.target<"llvm-cpu", [
+// CHECK: util.global private @device
+util.global private @device = #hal.device.target<"llvm-cpu", [
#executable_target_embedded_elf_x86_64
-]>
+]> : !hal.device
+
#pipeline_layout_0 = #hal.pipeline.layout<push_constants = 2, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
@@ -21,144 +24,268 @@
]>
]>
-module attributes {hal.device.targets = [#device_target_cpu]} {
+// Executable should be dumped:
+// CHECK: hal.executable private @ex0
+hal.executable private @ex0 {
+ hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout_0) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
+ }
+ }
- // Executable should be dumped:
- // CHECK: hal.executable private @ex0
- hal.executable private @ex0 {
- hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
- hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout_0) attributes {
- translation_info = #iree_codegen.translation_info<CPUDefault>
- } {
- ^bb0(%device: !hal.device, %arg0: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @dispatch0() {
- func.return
- }
- }
-
- hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout_1) attributes {
- translation_info = #iree_codegen.translation_info<CPUDefault>
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index):
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
- %1 = arith.addi %0, %arg1 : index
- hal.return %1, %c1, %c1 : index, index, index
- }
- builtin.module {
- func.func @dispatch1() {
- func.return
- }
+ hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout_1) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index):
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ %1 = arith.addi %0, %arg1 : index
+ hal.return %1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch1() {
+ func.return
}
}
}
+}
- // ===========================================================================
- // @dispatch0 benchmark logic:
- // ===========================================================================
+// ===========================================================================
+// @dispatch0 benchmark logic:
+// ===========================================================================
- // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer
- // CHECK-NEXT: util.initializer {
- // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> affinity(%{{.+}}) type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768}
- // CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer
+// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer
+// CHECK-NEXT: util.initializer {
+// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> affinity(%{{.+}}) type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768}
+// CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer
- // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch0_512(%arg0: i32)
- // CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
- // CHECK: %[[BATCH_SIZE:.+]] = arith.index_cast %arg0 : i32 to index
+// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch0_512(%arg0: i32)
+// CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
+// CHECK: %[[BATCH_SIZE:.+]] = arith.index_cast %arg0 : i32 to index
- // Create command buffer:
- // CHECK: %[[CMD:.+]] = hal.command_buffer.create
+// Create command buffer:
+// CHECK: %[[CMD:.+]] = hal.command_buffer.create
- // Setup dispatch constants and bindings:
- // CHECK: hal.command_buffer.push_constants<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout) offset(0) values([%c100_i32, %c200_i32]) : i32, i32
- // CHECK: %[[BUFFER:.+]] = util.global.load @ex0_embedded_elf_x86_64_dispatch0_512_buffer
- // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout)[%c0] bindings([
- // CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c0, %c32],
- // CHECK-NEXT: %c1 = (%[[BUFFER]] : !hal.buffer)[%c256, %c32],
- // CHECK-NEXT: %c2 = (%[[BUFFER]] : !hal.buffer)[%c512, %c32]
- // CHECK-NEXT: ])
+// Set up dispatch constants and bindings:
+// CHECK: hal.command_buffer.push_constants<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout) offset(0) values([%c100_i32, %c200_i32]) : i32, i32
+// CHECK: %[[BUFFER:.+]] = util.global.load @ex0_embedded_elf_x86_64_dispatch0_512_buffer
+// CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout)[%c0] bindings([
+// CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c0, %c32],
+// CHECK-NEXT: %c1 = (%[[BUFFER]] : !hal.buffer)[%c256, %c32],
+// CHECK-NEXT: %c2 = (%[[BUFFER]] : !hal.buffer)[%c512, %c32]
+// CHECK-NEXT: ])
- // Calculate the workgroup count, which we leave symbolic until after
- // translation:
- // CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] =
- // CHECK-SAME: hal.executable.calculate_workgroups
- // CHECK-SAME: target(@ex0::@embedded_elf_x86_64::@dispatch0)
- // CHECK-SAME: workload([%c512])
+// Calculate the workgroup count, which we leave symbolic until after
+// translation:
+// CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] =
+// CHECK-SAME: hal.executable.calculate_workgroups
+// CHECK-SAME: target(@ex0::@embedded_elf_x86_64::@dispatch0)
+// CHECK-SAME: workload([%c512])
- // Get executable and target ordinal (outside of the loop).
- // CHECK-DAG: %[[EXECUTABLE:.+]] = hal.executable.lookup device({{.+}}) executable(@ex0) : !hal.executable
- // CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch0) : index
+// Get executable and target ordinal (outside of the loop).
+// CHECK-DAG: %[[EXECUTABLE:.+]] = hal.executable.lookup device({{.+}}) executable(@ex0) : !hal.executable
+// CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch0) : index
- // Dispatch up to batch size dispatches:
- // CHECK: scf.for %{{.+}} = %c0 to %[[BATCH_SIZE]] step %c1 {
- // CHECK-NEXT: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXECUTABLE:.+]] : !hal.executable)[%[[ORDINAL_0]]] workgroups([%[[WORKGROUP_X]], %[[WORKGROUP_Y]], %[[WORKGROUP_Z]]])
- // CHECK-NEXT: hal.command_buffer.execution_barrier
- // CHECK-NEXT: }
+// Dispatch up to batch size dispatches:
+// CHECK: scf.for %{{.+}} = %c0 to %[[BATCH_SIZE]] step %c1 {
+// CHECK-NEXT: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXECUTABLE:.+]] : !hal.executable)[%[[ORDINAL_0]]] workgroups([%[[WORKGROUP_X]], %[[WORKGROUP_Y]], %[[WORKGROUP_Z]]])
+// CHECK-NEXT: hal.command_buffer.execution_barrier
+// CHECK-NEXT: }
- // Submit and wait for dispatches to complete:
- // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer>
- // CHECK: hal.fence.await
+// Submit and wait for dispatches to complete:
+// CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer>
+// CHECK: hal.fence.await
- // ===========================================================================
- // @dispatch1 benchmark logic (note two deduplicated dispatches):
- // ===========================================================================
+// ===========================================================================
+// @dispatch1 benchmark logic (note two deduplicated dispatches):
+// ===========================================================================
- // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_512x1_buffer : !hal.buffer
- // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_512x1(%arg0: i32)
- // CHECK: %[[ORDINAL_1A:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index
- // CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1A]]]
+// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_512x1_buffer : !hal.buffer
+// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_512x1(%arg0: i32)
+// CHECK: %[[ORDINAL_1A:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index
+// CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1A]]]
- // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_128x32_buffer : !hal.buffer
- // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_128x32(%arg0: i32)
- // CHECK: %[[ORDINAL_1B:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index
- // CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1B]]]
+// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_128x32_buffer : !hal.buffer
+// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_128x32(%arg0: i32)
+// CHECK: %[[ORDINAL_1B:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index
+// CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1B]]]
- util.func public @main(%dynamic_arg: i32) -> !stream.timepoint {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c32 = arith.constant 32 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c512 = arith.constant 512 : index
- %c100_i32 = arith.constant 100 : i32
- %c200_i32 = arith.constant 200 : i32
- %c300_i32 = arith.constant 300 : i32
- %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%c128} => !stream.timepoint
- %6 = stream.cmd.execute await(%result_timepoint) => with(%result as %result_capture: !stream.resource<transient>{%c128}) {
- // Dispatches with static and dynamic args.
- stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c100_i32, %c200_i32 : i32, i32) {
- ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
- rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
- rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+util.func public @main(%dynamic_arg: i32) -> !stream.timepoint attributes {
+ stream.affinity = #hal.device.affinity<@device>
+} {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c512 = arith.constant 512 : index
+ %c100_i32 = arith.constant 100 : i32
+ %c200_i32 = arith.constant 200 : i32
+ %c300_i32 = arith.constant 300 : i32
+ %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%c128} => !stream.timepoint
+ %6 = stream.cmd.execute await(%result_timepoint) => with(%result as %result_capture: !stream.resource<transient>{%c128}) {
+ // Dispatches with static and dynamic args.
+ stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c100_i32, %c200_i32 : i32, i32) {
+ ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+ // NOTE: today the dynamic args will prevent us from generating
+ // benchmarks. We could handle this better by tracking alignment and such.
+ stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c300_i32, %dynamic_arg : i32, i32) {
+ ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+
+ // Multiple dispatches to a single entry point.
+ // Dispatches are deduplicated and the two 128x32x1 should combine.
+ stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c512, %c1] {
+ ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
+ ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
+ ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
+ rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+ } => !stream.timepoint
+ %39 = stream.resource.dealloca await(%6) => %result : !stream.resource<transient>{%c128} => !stream.timepoint
+ util.return %39 : !stream.timepoint
+}
+
+// -----
+// expected-warning@-2 {{multiple devices in the module}}
+
+// Tests that multiple devices fail today.
+// We should be creating one benchmark per executable with only the dispatches
+// used by that executable.
+
+#executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
+#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
+util.global private @device_a = #hal.device.target<"llvm-cpu", [
+ #executable_target_embedded_elf_aarch64
+]> : !hal.device
+util.global private @device_b = #hal.device.target<"llvm-cpu", [
+ #executable_target_embedded_elf_x86_64
+]> : !hal.device
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>
+ ]>
+]>
+
+hal.executable private @ex_0 {
+ hal.executable.variant public @variant_a target(#executable_target_embedded_elf_aarch64) {
+ hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
}
- // NOTE: today the dynamic args will prevent us from generating
- // benchmarks. We could handle this better by tracking alignment and such.
- stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c300_i32, %dynamic_arg : i32, i32) {
- ro %result_capture[%c0 for %c32] : !stream.resource<transient>{%c128},
- rw %result_capture[%c32 for %c32] : !stream.resource<transient>{%c128},
- rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
+ }
+ hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch1() {
+ func.return
}
-
- // Multiple dispatches to a single entry point.
- // Dispatches are deduplicated and the two 128x32x1 should combine.
- stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c512, %c1] {
- ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
- rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
- }
- stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
- ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
- rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
- }
- stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] {
- ro %result_capture[%c0 for %c64] : !stream.resource<transient>{%c128},
- rw %result_capture[%c64 for %c32] : !stream.resource<transient>{%c128}
- }
- } => !stream.timepoint
- %39 = stream.resource.dealloca await(%6) => %result : !stream.resource<transient>{%c128} => !stream.timepoint
- util.return %39 : !stream.timepoint
+ }
}
+ hal.executable.variant public @variant_b target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
+ }
+ }
+ hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch1() {
+ func.return
+ }
+ }
+ }
+}
+hal.executable private @ex_1 {
+ hal.executable.variant public @variant_b target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
+ }
+ }
+ }
+}
+
+util.func public @main(%resource_a_arg: !stream.resource<transient>, %resource_b_arg: !stream.resource<transient>) -> (!stream.timepoint, !stream.timepoint) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c512 = arith.constant 512 : index
+ %tp_a = stream.cmd.execute on(#hal.device.affinity<@device_a>) with(%resource_a_arg as %resource_a: !stream.resource<transient>{%c128}) {
+ stream.cmd.dispatch @ex_0::@variant_a::@dispatch0[%c512] {
+ rw %resource_a[%c0 for %c32] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c512] {
+ rw %resource_a[%c0 for %c64] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c128] {
+ rw %resource_a[%c0 for %c64] : !stream.resource<transient>{%c128}
+ }
+ } => !stream.timepoint
+ %tp_b = stream.cmd.execute on(#hal.device.affinity<@device_b>) with(%resource_b_arg as %resource_b: !stream.resource<transient>{%c128}) {
+ stream.cmd.dispatch @ex_0::@variant_a::@dispatch0[%c512] {
+ rw %resource_b[%c0 for %c32] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c512] {
+ rw %resource_b[%c0 for %c64] : !stream.resource<transient>{%c128}
+ }
+ stream.cmd.dispatch @ex_0::@variant_b::@dispatch0[%c128] {
+ rw %resource_b[%c0 for %c64] : !stream.resource<transient>{%c128}
+ }
+ } => !stream.timepoint
+ util.return %tp_a, %tp_b : !stream.timepoint, !stream.timepoint
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir
index 458e841..5de1c9c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir
@@ -4,9 +4,6 @@
// but this is much easier to test with lit.
#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
-#device_target_cpu = #hal.device.target<"llvm-cpu", [
- #executable_target_embedded_elf_x86_64
-]>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
@@ -15,46 +12,42 @@
]>
]>
-module attributes {hal.device.targets = [#device_target_cpu]} {
-
- // CHECK: hal.executable public @ex0
- hal.executable private @ex0 {
- // We expect local outputs with attributes inlined:
- // CHECK-NEXT: hal.executable.variant {{.+}} target(<"llvm-cpu"
- hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
- hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes {
- translation_info = #iree_codegen.translation_info<CPUDefault>
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
- hal.return %0, %c1, %c1 : index, index, index
- }
- builtin.module {
- func.func @dispatch0() {
- func.return
- }
+// CHECK: hal.executable public @ex0
+hal.executable private @ex0 {
+ // We expect local outputs with attributes inlined:
+ // CHECK-NEXT: hal.executable.variant {{.+}} target(<"llvm-cpu"
+ hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch0() {
+ func.return
}
}
}
+}
- // CHECK: hal.executable private @ex1
- hal.executable private @ex1 {
- hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
- hal.executable.export public @dispatch1 ordinal(0) layout(#pipeline_layout) attributes {
- translation_info = #iree_codegen.translation_info<CPUDefault>
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
- hal.return %0, %c1, %c1 : index, index, index
- }
- builtin.module {
- func.func @dispatch1() {
- func.return
- }
+// CHECK: hal.executable private @ex1
+hal.executable private @ex1 {
+ hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) {
+ hal.executable.export public @dispatch1 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @dispatch1() {
+ func.return
}
}
}
-
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
index 29de091..d217b47 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
@@ -1,12 +1,16 @@
// RUN: iree-opt --split-input-file --iree-hal-fixup-legacy-sync %s | FileCheck %s
-// Tests that command buffers that are reusable don't execute inline.
-// Reusable + inline is not a valid combination.
-
-module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
-// CHECK-LABEL: @command_buffer_reusable
-util.func public @command_buffer_reusable(%device: !hal.device, %affinity: i64) {
- // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("None")
+// TODO(multi-device): remove once device globals are used. This is a fallback
+// path during the transition.
+module attributes {
+ hal.device.targets = [
+ #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+ ]
+} {
+// CHECK-LABEL: @default_device_targets
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64)
+util.func public @default_device_targets(%device: !hal.device, %affinity: i64) {
+ // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None")
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
@@ -14,43 +18,78 @@
// -----
+// Tests that unknown devices (here passed as an arg on a public function)
+// don't trigger the pass, as we default to non-legacy behavior.
+
+// CHECK-LABEL: @unknown_device
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64)
+util.func public @unknown_device(%device: !hal.device, %affinity: i64) {
+ // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None")
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
+ util.return
+}
+
+// -----
+
+// Tests that command buffers that are reusable don't execute inline.
+// Reusable + inline is not a valid combination.
+
+util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+
+// CHECK-LABEL: @command_buffer_reusable
+util.func public @command_buffer_reusable(%affinity: i64) {
+ // CHECK: %[[DEVICE:.+]] = util.global.load @device
+ %device = util.global.load @device : !hal.device
+ // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None")
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
+ util.return
+}
+
+// -----
+
// Tests that one-shot command buffers are allowed to execute inline.
-module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
+util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+
// CHECK-LABEL: @command_buffer_oneshot
-util.func public @command_buffer_oneshot(%device: !hal.device, %affinity: i64) {
- // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("OneShot|AllowInlineExecution")
+util.func public @command_buffer_oneshot(%affinity: i64) {
+ // CHECK: %[[DEVICE:.+]] = util.global.load @device
+ %device = util.global.load @device : !hal.device
+ // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("OneShot|AllowInlineExecution")
%cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
-} // module
// -----
// Tests for a no-op if there are no devices requiring legacy mode.
-module attributes {hal.device.targets = [
+util.global private @device = #hal.device.select<[
#hal.device.target<"local", {}>,
#hal.device.target<"vulkan", {}>
-]} {
+]> : !hal.device
+
// CHECK-LABEL: @legacy_mode_not_required
-util.func public @legacy_mode_not_required(%device: !hal.device, %affinity: i64) {
- // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode(OneShot)
+util.func public @legacy_mode_not_required(%affinity: i64) {
+ // CHECK: %[[DEVICE:.+]] = util.global.load @device
+ %device = util.global.load @device : !hal.device
+ // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode(OneShot)
%cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
-} // module
// -----
-// Tests that any device requiring legacy_sync will trigger the pass.
+// Tests that any device requiring legacy_sync in a set will trigger the pass.
-module attributes {hal.device.targets = [
+util.global private @device = #hal.device.select<[
#hal.device.target<"local", {}>,
#hal.device.target<"vulkan", {legacy_sync}>
-]} {
+]> : !hal.device
+
// CHECK-LABEL: @mixed_legacy_mode_required
-util.func public @mixed_legacy_mode_required(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+util.func public @mixed_legacy_mode_required(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+ %device = util.global.load @device : !hal.device
%affinity = arith.constant 1 : i64
// CHECK: hal.fence.await
// CHECK: hal.device.queue.execute
@@ -61,17 +100,50 @@
commands([%cmd])
util.return
}
-} // module
+
+// -----
+
+// Tests that only devices with legacy_sync trigger the pass.
+
+util.global private @device_async = #hal.device.target<"local", {}> : !hal.device
+util.global private @device_sync = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+
+// CHECK-LABEL: @mixed_legacy_mode_scoped
+util.func public @mixed_legacy_mode_scoped(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+ // CHECK-DAG: %[[DEVICE_ASYNC:.+]] = util.global.load @device_async
+ %device_async = util.global.load @device_async : !hal.device
+ // CHECK-DAG: %[[DEVICE_SYNC:.+]] = util.global.load @device_sync
+ %device_sync = util.global.load @device_sync : !hal.device
+ %affinity = arith.constant 1 : i64
+ // CHECK-NOT: hal.fence.await
+ // CHECK: hal.device.queue.execute<%[[DEVICE_ASYNC]]
+ // CHECK-NOT: hal.fence.await
+ hal.device.queue.execute<%device_async : !hal.device>
+ affinity(%affinity)
+ wait(%wait) signal(%signal)
+ commands([%cmd])
+ // CHECK: hal.fence.await
+ // CHECK: hal.device.queue.execute<%[[DEVICE_SYNC]]
+ // CHECK: hal.fence.await
+ hal.device.queue.execute<%device_sync : !hal.device>
+ affinity(%affinity)
+ wait(%wait) signal(%signal)
+ commands([%cmd])
+ util.return
+}
// -----
// Tests that queued operations get the appropriate waits before/after.
-module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
+util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+
// CHECK-LABEL: @blocking_execute
-// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence)
-util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence)
+util.func public @blocking_execute(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
%affinity = arith.constant 1 : i64
+ // CHECK: %[[DEVICE:.+]] = util.global.load @device
+ %device = util.global.load @device : !hal.device
// CHECK-DAG: %[[NULL:.+]] = util.null : !hal.fence
// CHECK-DAG: hal.fence.await until([%[[WAIT]]])
// CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device>
@@ -84,16 +156,18 @@
commands([%cmd])
util.return
}
-} // module
// -----
// Tests that waits are not inserted if they already exist.
-module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
+util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device
+
// CHECK-LABEL: @blocking_execute
-// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence)
-util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence)
+util.func public @blocking_execute(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
+ // CHECK: %[[DEVICE:.+]] = util.global.load @device
+ %device = util.global.load @device : !hal.device
// CHECK-NEXT: %[[TIMEOUT:.+]] = arith.constant 100
%timeout = arith.constant 100 : i32
// CHECK-NEXT: hal.fence.await until([%[[WAIT]]]) timeout_millis(%[[TIMEOUT]])
@@ -114,4 +188,3 @@
// CHECK-NEXT: util.return
util.return
}
-} // module
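The legacy-sync fixup tests above exercise three behaviors:

- One-shot command buffers on a `legacy_sync` device gain `AllowInlineExecution`.
- A device global defined as a `#hal.device.select` set triggers the fixup if any member has `legacy_sync`.
- Queue submissions on such devices are bracketed by blocking `hal.fence.await` ops, while separately declared devices without the flag are left alone.

The sketch below is a hypothetical Python model of those expectations, not IREE's pass implementation. `DeviceTarget`, `command_buffer_mode`, and `queue_execute` are illustrative names. It also assumes the submission's own wait fence becomes null once the host has already blocked on it; the `util.null : !hal.fence` check hints at this, but the hunk is truncated.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class DeviceTarget:
    backend: str
    legacy_sync: bool = False


def requires_legacy_sync(targets: Sequence[DeviceTarget]) -> bool:
    # A device global is either a single target or a #hal.device.select set;
    # any member carrying legacy_sync makes the whole device need the fixup.
    return any(t.legacy_sync for t in targets)


def command_buffer_mode(targets: Sequence[DeviceTarget], mode: str) -> str:
    # Assumed rule: one-shot command buffers on legacy devices may run inline.
    flags = mode.split("|")
    if (requires_legacy_sync(targets) and "OneShot" in flags
            and "AllowInlineExecution" not in flags):
        flags.append("AllowInlineExecution")
    return "|".join(flags)


def queue_execute(targets: Sequence[DeviceTarget], wait: str,
                  signal: str) -> Tuple[tuple, ...]:
    if not requires_legacy_sync(targets):
        # Asynchronous devices keep their fences untouched.
        return (("queue.execute", wait, signal),)
    # Block on the wait fence first, submit with no wait fence (assumed null),
    # then block until the submission signals.
    return (
        ("fence.await", wait),
        ("queue.execute", None, signal),
        ("fence.await", signal),
    )
```

In `@mixed_legacy_mode_scoped`, `@device_async` and `@device_sync` are separate globals. In this model that corresponds to calling `queue_execute` once with `[DeviceTarget("local")]` and once with `[DeviceTarget("vulkan", True)]`; only the second call gains awaits.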
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir
new file mode 100644
index 0000000..3920a4d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir
@@ -0,0 +1,119 @@
+// RUN: iree-opt --split-input-file --iree-hal-initialize-devices --cse %s | FileCheck %s
+
+// Tests that #hal.device.ordinal<*> gets the device with the given ordinal.
+
+// CHECK: util.global private @device_123 : !hal.device
+util.global private @device_123 = #hal.device.ordinal<123> : !hal.device
+
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %c123
+// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device
+// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[DEVICE]], %[[NULL_DEVICE]]
+// CHECK-NEXT: scf.if %[[IS_NULL]] {
+// CHECK: util.status.check_ok %c5_i32, "HAL device `device_123` not found or unavailable: #hal.device.ordinal<123>"
+// CHECK: util.global.store %[[DEVICE]], @device_123
+
+// -----
+
+// Tests that #hal.device.fallback<*> references the specified device global.
+
+util.global private @device_base : !hal.device
+
+// CHECK: util.global private @device_fallback : !hal.device
+util.global private @device_fallback = #hal.device.fallback<@device_base> : !hal.device
+
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[DEVICE:.+]] = util.global.load @device_base : !hal.device
+// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[DEVICE]], %{{.+}}
+// CHECK-NEXT: scf.if %[[IS_NULL]] {
+// CHECK: util.status.check_ok %c5_i32, "HAL device `device_fallback` not found or unavailable: #hal.device.fallback<@device_base>"
+// CHECK: util.global.store %[[DEVICE]], @device_fallback
+
+// -----
+
+// Tests that #hal.device.target<*> enumerates all devices and tries to match
+// a particular target with the given ordinal. The ordinal allows for multiple
+// devices of the same type to be differentiated.
+
+// CHECK: util.global private @device_a : !hal.device
+util.global private @device_a = #hal.device.target<"a", {
+ ordinal = 2 : index
+}, [
+ #hal.executable.target<"backend0", "format0">,
+ #hal.executable.target<"backend1", "format1">
+]> : !hal.device
+
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device
+// CHECK-DAG: %[[DEVICE_COUNT:.+]] = hal.devices.count
+// CHECK: %[[WHILE:.+]]:3 = scf.while (%arg0 = %c0, %arg1 = %c0, %arg2 = %[[NULL_DEVICE]])
+// CHECK-DAG: %[[IS_DEVICE_NULL:.+]] = util.cmp.eq %arg2, %[[NULL_DEVICE]]
+// CHECK-DAG: %[[IS_END:.+]] = arith.cmpi slt, %arg0, %[[DEVICE_COUNT]]
+// CHECK-DAG: %[[CONTINUE:.+]] = arith.andi %[[IS_DEVICE_NULL]], %[[IS_END]]
+// CHECK-NEXT: scf.condition(%[[CONTINUE]]) %arg0, %arg1, %arg2
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg0: index, %arg1: index, %arg2: !hal.device)
+// CHECK-DAG: %[[DEVICE_N:.+]] = hal.devices.get %arg0 : !hal.device
+
+// NOTE: this is the fallback path for device matching unregistered targets.
+// Real targets can have much more complex logic if they so choose.
+// CHECK-DAG: %{{.+}}, %[[ID_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.device.id" :: "a")
+// CHECK-NEXT: %[[IS_DEVICE_MATCH:.+]] = scf.if %[[ID_MATCH]] -> (i1) {
+// CHECK-DAG: %{{.+}}, %[[FORMAT0_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format0")
+// CHECK-DAG: %{{.+}}, %[[FORMAT1_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format1")
+// CHECK-DAG: %[[FORMAT_MATCH_OR:.+]] = arith.ori %[[FORMAT0_MATCH]], %[[FORMAT1_MATCH]]
+// CHECK-DAG: scf.yield %[[FORMAT_MATCH_OR]]
+// CHECK-NEXT: } else {
+// CHECK-DAG: scf.yield %false
+
+// Check that a matching device is selected only if it has the requested ordinal;
+// otherwise it is skipped and the search continues with the next device.
+// CHECK-DAG: %[[IS_ORDINAL_MATCH:.+]] = arith.cmpi eq, %arg1, %c2
+// CHECK-DAG: %[[NEXT_MATCH_ADVANCE:.+]] = arith.select %[[IS_DEVICE_MATCH]], %c1, %c0
+// CHECK-DAG: %[[NEXT_MATCH_ORDINAL:.+]] = arith.addi %arg1, %[[NEXT_MATCH_ADVANCE]]
+
+// CHECK-DAG: %[[IS_MATCH:.+]] = arith.andi %[[IS_DEVICE_MATCH]], %[[IS_ORDINAL_MATCH]]
+// CHECK-DAG: %[[YIELD_DEVICE:.+]] = arith.select %[[IS_MATCH]], %[[DEVICE_N]], %[[NULL_DEVICE]]
+// CHECK-DAG: %[[NEXT_I:.+]] = arith.addi %arg0, %c1
+// CHECK-NEXT: scf.yield %[[NEXT_I]], %[[NEXT_MATCH_ORDINAL]], %[[YIELD_DEVICE]]
+
+// Error out if no device was found because at least one match is required.
+// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[WHILE]]#2, %[[NULL_DEVICE]]
+// CHECK-NEXT: scf.if %[[IS_NULL]] {
+// CHECK: util.status.check_ok %c5_i32, "HAL device `device_a` not found or unavailable: #hal.device.target<{{.+}}>"
+// CHECK: util.global.store %[[WHILE]]#2, @device_a
+
+// -----
+
+// Tests that #hal.device.select<*> expands to a chain of ifs.
+
+util.global private @fallback : !hal.device
+
+// CHECK: util.global private @selected : !hal.device
+util.global private @selected = #hal.device.select<[
+ #hal.device.ordinal<2> : !hal.device,
+ #hal.device.ordinal<1> : !hal.device,
+ #hal.device.fallback<@fallback> : !hal.device
+]> : !hal.device
+
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device
+// CHECK-DAG: %[[DEVICE_2:.+]] = hal.devices.get %c2
+// CHECK-DAG: %[[NOT_DEVICE_2:.+]] = util.cmp.eq %[[DEVICE_2]], %[[NULL_DEVICE]]
+// CHECK-NEXT: %[[IF_0:.+]] = scf.if %[[NOT_DEVICE_2]]
+// CHECK-DAG: %[[DEVICE_1:.+]] = hal.devices.get %c1
+// CHECK-DAG: %[[NOT_DEVICE_1:.+]] = util.cmp.eq %[[DEVICE_1]], %[[NULL_DEVICE]]
+// CHECK-NEXT: %[[IF_1:.+]] = scf.if %[[NOT_DEVICE_1]]
+// CHECK-DAG: %[[DEVICE_FALLBACK:.+]] = util.global.load @fallback
+// CHECK-NEXT: scf.yield %[[DEVICE_FALLBACK]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: scf.yield %[[DEVICE_1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[IF_1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: scf.yield %[[DEVICE_2]]
+// CHECK-NEXT: }
+// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[IF_0]], %[[NULL_DEVICE]]
+// CHECK-NEXT: scf.if %[[IS_NULL]] {
+// CHECK: util.status.check_ok %c5_i32, "HAL device `selected` not found or unavailable: #hal.device.select<{{.+}}>"
+// CHECK: util.global.store %[[IF_0]], @selected
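The `initialize_devices.mlir` checks describe how each device attribute is resolved at initialization time:

- `#hal.device.ordinal<N>` fetches device N.
- `#hal.device.fallback<@g>` reuses another global.
- `#hal.device.target<...>` walks every device, counting matches until the running count equals the requested `ordinal`.
- `#hal.device.select<[...]>` tries candidates in order, lazily.

A null result fails with "not found or unavailable". The following is a hypothetical Python model of that control flow, written only to mirror the `scf.while`/`scf.if` structure the CHECK lines expect. `RuntimeDevice`, `by_ordinal`, `by_target`, `select`, and `initialize` are illustrative names, and out-of-range ordinals are assumed to yield null.

```python
from dataclasses import dataclass
from typing import Callable, FrozenSet, Optional, Sequence


@dataclass(frozen=True)
class RuntimeDevice:
    device_id: str             # what the "hal.device.id" query would report
    formats: FrozenSet[str]    # executable formats the device accepts


Resolver = Callable[[], Optional[RuntimeDevice]]


def by_ordinal(devices: Sequence[RuntimeDevice], ordinal: int) -> Resolver:
    # #hal.device.ordinal<N>: assumed to yield None when N is out of range.
    return lambda: devices[ordinal] if 0 <= ordinal < len(devices) else None


def by_target(devices: Sequence[RuntimeDevice], device_id: str,
              formats: Sequence[str], ordinal: int = 0) -> Resolver:
    # #hal.device.target<...>: i walks all devices, match_ordinal counts the
    # devices matched so far, and the loop exits once one is selected.
    def resolve() -> Optional[RuntimeDevice]:
        i, match_ordinal, selected = 0, 0, None
        while selected is None and i < len(devices):
            dev = devices[i]
            # Formats are only checked when the id matches (the scf.if).
            is_device_match = dev.device_id == device_id and any(
                f in dev.formats for f in formats)
            if is_device_match and match_ordinal == ordinal:
                selected = dev
            match_ordinal += 1 if is_device_match else 0
            i += 1
        return selected
    return resolve


def select(candidates: Sequence[Resolver]) -> Resolver:
    # #hal.device.select<[...]>: later candidates are only tried if earlier
    # ones produced null, like the nested scf.if chain.
    def resolve() -> Optional[RuntimeDevice]:
        for candidate in candidates:
            dev = candidate()
            if dev is not None:
                return dev
        return None
    return resolve


def initialize(name: str, resolver: Resolver) -> RuntimeDevice:
    dev = resolver()
    if dev is None:
        raise RuntimeError(f"HAL device `{name}` not found or unavailable")
    return dev
```

A `#hal.device.fallback<@device_base>` candidate corresponds to a resolver that simply returns the already-initialized `device_base` value, e.g. `lambda: device_base`.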
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir
index 607c876..e86f141 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir
@@ -4,7 +4,7 @@
#hal.device.target<"llvm-cpu", [
#hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
- ]>
+ ]> : !hal.device
]} {
// Instrumentation storage buffer allocated at startup (defaults to 64MB + footer):
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
index a688bdd..d350e0e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
@@ -2,69 +2,71 @@
// Tests an executable with a workgroup count region specified.
-module attributes {hal.device.targets = [
- #hal.device.target<"llvm-cpu", [
- #hal.executable.target<"llvm-cpu", "arm_64">,
- #hal.executable.target<"llvm-cpu", "x86_64">
- ]>
-]} {
- // CHECK: #pipeline_layout = #hal.pipeline.layout<
- // CHECK-SAME: push_constants = 1
- // CHECK-SAME: sets = [
- // CHECK-SAME: <0, bindings = [
- // CHECK-SAME: <0, storage_buffer, ReadOnly>
- // CHECK-SAME: <1, storage_buffer, ReadOnly>
- // CHECK-SAME: <2, storage_buffer>
+// The default device when none is specified.
+// Functions and scopes can override the target device.
+util.global private @default_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "arm_64">,
+ #hal.executable.target<"llvm-cpu", "x86_64">
+]> : !hal.device
- // CHECK: hal.executable private @ex
- // CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64
- // CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
- // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
- // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
- // CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
- // CHECK-NEXT: }
- // CHECK: builtin.module
- // CHECK-NEXT: func.func private @extern_func()
- // CHECK-NEXT: func.func @entry
- // CHECK: hal.executable.variant public @x86_64 target(#executable_target_x86_64
- // CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
- // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
- // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
- // CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
- // CHECK-NEXT: }
- // CHECK: builtin.module
- // CHECK-NEXT: func.func private @extern_func()
+// CHECK: #pipeline_layout = #hal.pipeline.layout<
+// CHECK-SAME: push_constants = 1
+// CHECK-SAME: sets = [
+// CHECK-SAME: <0, bindings = [
+// CHECK-SAME: <0, storage_buffer, ReadOnly>
+// CHECK-SAME: <1, storage_buffer, ReadOnly>
+// CHECK-SAME: <2, storage_buffer>
- // CHECK-NEXT: func.func @entry
- stream.executable private @ex {
- stream.executable.export public @entry workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
- stream.return %arg0, %arg1, %arg0 : index, index, index
- }
- builtin.module {
- func.func private @extern_func()
- func.func @entry(%operand: i32, %arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) {
- return
- }
+// CHECK: hal.executable private @ex
+// CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64
+// CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
+// CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
+// CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
+// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
+// CHECK-NEXT: }
+// CHECK: builtin.module
+// CHECK-NEXT: func.func private @extern_func()
+// CHECK-NEXT: func.func @entry
+// CHECK: hal.executable.variant public @x86_64 target(#executable_target_x86_64
+// CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout)
+// CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]
+// CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index):
+// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
+// CHECK-NEXT: }
+// CHECK: builtin.module
+// CHECK-NEXT: func.func private @extern_func()
+
+// CHECK-NEXT: func.func @entry
+stream.executable private @ex {
+ stream.executable.export public @entry workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
+ stream.return %arg0, %arg1, %arg0 : index, index, index
+ }
+ builtin.module {
+ func.func private @extern_func()
+ func.func @entry(%operand: i32, %arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) {
+ return
}
}
- util.func public @main(%arg0: !stream.resource<constant>, %arg1: !stream.resource<transient>, %arg2: index, %arg3: i32) -> !stream.resource<transient> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg2}
- %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<constant>{%arg2}, %arg1 as %arg5: !stream.resource<transient>{%arg2}, %0 as %arg6: !stream.resource<transient>{%arg2}) {
- // CHECK: stream.cmd.dispatch
- // CHECK-SAME: @ex::@arm_64::@entry
- // CHECK-SAME: @ex::@x86_64::@entry
- stream.cmd.dispatch @ex::@entry[%c1, %c2](%arg3 : i32) {
- ro %arg4[%c0 for %arg2] : !stream.resource<constant>{%arg2},
- ro %arg5[%c0 for %arg2] : !stream.resource<transient>{%arg2},
- wo %arg6[%c0 for %arg2] : !stream.resource<transient>{%arg2}
- }
- } => !stream.timepoint
- %2 = stream.timepoint.await %1 => %0 : !stream.resource<transient>{%arg2}
- util.return %2 : !stream.resource<transient>
- }
+}
+util.func public @main(%arg0: !stream.resource<constant>, %arg1: !stream.resource<transient>, %arg2: index, %arg3: i32) -> !stream.resource<transient> attributes {
+ stream.affinity = #hal.device.affinity<@default_device>
+} {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg2}
+ %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<constant>{%arg2}, %arg1 as %arg5: !stream.resource<transient>{%arg2}, %0 as %arg6: !stream.resource<transient>{%arg2}) {
+ // CHECK: stream.cmd.dispatch
+ // CHECK-SAME: @ex::@arm_64::@entry
+ // CHECK-SAME: @ex::@x86_64::@entry
+ stream.cmd.dispatch @ex::@entry[%c1, %c2](%arg3 : i32) {
+ ro %arg4[%c0 for %arg2] : !stream.resource<constant>{%arg2},
+ ro %arg5[%c0 for %arg2] : !stream.resource<transient>{%arg2},
+ wo %arg6[%c0 for %arg2] : !stream.resource<transient>{%arg2}
+ }
+ } => !stream.timepoint
+ %2 = stream.timepoint.await %1 => %0 : !stream.resource<transient>{%arg2}
+ util.return %2 : !stream.resource<transient>
}
// -----
@@ -72,68 +74,67 @@
// Tests that executable variants are expanded based on what devices they are
// dispatched on.
-module attributes {
- // The default device when none is specified.
- // Functions and scopes can override the target device.
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "arm_64">,
- #hal.executable.target<"llvm-cpu", "x86_64">
- ]>
- ]
+// The default device when none is specified.
+// Functions and scopes can override the target device.
+util.global private @default_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "arm_64">,
+ #hal.executable.target<"llvm-cpu", "x86_64">
+]> : !hal.device
+util.global private @riscv_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "riscv_32">
+]> : !hal.device
+
+// CHECK: hal.executable private @ex
+// CHECK: hal.executable.variant public @arm_64
+// CHECK: hal.executable.variant public @riscv_32
+// CHECK: hal.executable.variant public @x86_64
+stream.executable private @ex {
+ stream.executable.export public @entry workgroups() -> (index, index, index) {
+ %c1 = arith.constant 1 : index
+ stream.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @entry(%arg0: !stream.binding {stream.alignment = 64 : index}) {
+ return
+ }
+ }
+}
+
+// This function uses the default HAL device targeting arm_64 and x86_64.
+// CHECK-LABEL: @using_default
+util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+ stream.affinity = #hal.device.affinity<@default_device>
} {
- // CHECK: hal.executable private @ex
- // CHECK: hal.executable.variant public @arm_64
- // CHECK: hal.executable.variant public @riscv_32
- // CHECK: hal.executable.variant public @x86_64
- stream.executable private @ex {
- stream.executable.export public @entry workgroups() -> (index, index, index) {
- %c1 = arith.constant 1 : index
- stream.return %c1, %c1, %c1 : index, index, index
+ %c0 = arith.constant 0 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.dispatch
+ // CHECK-SAME: @ex::@arm_64::@entry
+ // CHECK-NOT: @ex::@riscv_32::@entry
+ // CHECK-SAME: @ex::@x86_64::@entry
+ stream.cmd.dispatch @ex::@entry {
+ rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
}
- builtin.module {
- func.func @entry(%arg0: !stream.binding {stream.alignment = 64 : index}) {
- return
- }
+ } => !stream.timepoint
+ util.return %0 : !stream.timepoint
+}
+
+// This function is specialized to run only on riscv_32 and should
+// not get assigned the arm_64/x86_64 variant entry points.
+// CHECK-LABEL: @using_specialized
+util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+ stream.affinity = #hal.device.affinity<@riscv_device>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.dispatch
+ // CHECK-NOT: @ex::@arm_64::@entry
+ // CHECK-SAME: @ex::@riscv_32::@entry
+ // CHECK-NOT: @ex::@x86_64::@entry
+ stream.cmd.dispatch @ex::@entry {
+ rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
}
- }
- // This function uses the default HAL device targeting arm_64 and x86_64.
- // CHECK-LABEL: @using_default
- util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
- %c0 = arith.constant 0 : index
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
- // CHECK: stream.cmd.dispatch
- // CHECK-SAME: @ex::@arm_64::@entry
- // CHECK-NOT: @ex::@riscv_32::@entry
- // CHECK-SAME: @ex::@x86_64::@entry
- stream.cmd.dispatch @ex::@entry {
- rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
- }
- } => !stream.timepoint
- util.return %0 : !stream.timepoint
- }
- // This function is specialized to only run on only riscv_32 and should
- // not get assigned the arm_64/x86_64 variant entry points.
- // CHECK-LABEL: @using_specialized
- util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "riscv_32">
- ]>
- ]
- } {
- %c0 = arith.constant 0 : index
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
- // CHECK: stream.cmd.dispatch
- // CHECK-NOT: @ex::@arm_64::@entry
- // CHECK-SAME: @ex::@riscv_32::@entry
- // CHECK-NOT: @ex::@x86_64::@entry
- stream.cmd.dispatch @ex::@entry {
- rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
- }
- } => !stream.timepoint
- util.return %0 : !stream.timepoint
- }
+ } => !stream.timepoint
+ util.return %0 : !stream.timepoint
}
// -----
@@ -143,69 +144,68 @@
// hand-authored code or other dialects that perform interface assignment
// themselves.
-module attributes {
- // The default device when none is specified.
- // Functions and scopes can override the target device.
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "arm_64">,
- #hal.executable.target<"llvm-cpu", "x86_64">
+// The default device when none is specified.
+// Functions and scopes can override the target device.
+util.global private @default_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "arm_64">,
+ #hal.executable.target<"llvm-cpu", "x86_64">
+]> : !hal.device
+util.global private @riscv_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "riscv_32">
+]> : !hal.device
+
+// CHECK: hal.executable private @ex
+// CHECK: hal.executable.variant public @arm_64
+// CHECK: hal.executable.variant public @riscv_32
+// CHECK: hal.executable.variant public @x86_64
+hal.executable.source private @ex {
+ hal.executable.export public @entry layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>
]>
- ]
-} {
- // CHECK: hal.executable private @ex
- // CHECK: hal.executable.variant public @arm_64
- // CHECK: hal.executable.variant public @riscv_32
- // CHECK: hal.executable.variant public @x86_64
- hal.executable.source private @ex {
- hal.executable.export public @entry layout(#hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>
- ]>
- ]>)
- builtin.module {
- func.func @entry() {
- return
- }
+ ]>)
+ builtin.module {
+ func.func @entry() {
+ return
}
}
- // This function uses the default HAL device targeting arm_64 and x86_64.
- // CHECK-LABEL: @using_default
- util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
- %c0 = arith.constant 0 : index
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
- // CHECK: stream.cmd.dispatch
- // CHECK-SAME: @ex::@arm_64::@entry
- // CHECK-NOT: @ex::@riscv_32::@entry
- // CHECK-SAME: @ex::@x86_64::@entry
- stream.cmd.dispatch @ex::@entry {
- rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
- }
- } => !stream.timepoint
- util.return %0 : !stream.timepoint
- }
- // This function is specialized to only run on only riscv_32 and should
- // not get assigned the arm_64/x86_64 variant entry points.
- // CHECK-LABEL: @using_specialized
- util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "riscv_32">
- ]>
- ]
- } {
- %c0 = arith.constant 0 : index
- %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
- // CHECK: stream.cmd.dispatch
- // CHECK-NOT: @ex::@arm_64::@entry
- // CHECK-SAME: @ex::@riscv_32::@entry
- // CHECK-NOT: @ex::@x86_64::@entry
- stream.cmd.dispatch @ex::@entry {
- rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
- }
- } => !stream.timepoint
- util.return %0 : !stream.timepoint
- }
+}
+
+// This function uses the default HAL device targeting arm_64 and x86_64.
+// CHECK-LABEL: @using_default
+util.func public @using_default(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+ stream.affinity = #hal.device.affinity<@default_device>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.dispatch
+ // CHECK-SAME: @ex::@arm_64::@entry
+ // CHECK-NOT: @ex::@riscv_32::@entry
+ // CHECK-SAME: @ex::@x86_64::@entry
+ stream.cmd.dispatch @ex::@entry {
+ rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+ }
+ } => !stream.timepoint
+ util.return %0 : !stream.timepoint
+}
+
+// This function is specialized to run only on riscv_32 and should
+// not get assigned the arm_64/x86_64 variant entry points.
+// CHECK-LABEL: @using_specialized
+util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
+ stream.affinity = #hal.device.affinity<@riscv_device>
+} {
+ %c0 = arith.constant 0 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.dispatch
+ // CHECK-NOT: @ex::@arm_64::@entry
+ // CHECK-SAME: @ex::@riscv_32::@entry
+ // CHECK-NOT: @ex::@x86_64::@entry
+ stream.cmd.dispatch @ex::@entry {
+ rw %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+ }
+ } => !stream.timepoint
+ util.return %0 : !stream.timepoint
}
// -----
@@ -213,14 +213,15 @@
// Tests that a hal.executable.source op gets expanded to all default targets
// when it's public in addition to any ones from dispatch sites.
-module attributes {
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "arm_64">,
- #hal.executable.target<"llvm-cpu", "x86_64">
- ]>
- ]
-} {
+module {
+ util.global private @primary_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "arm_64">,
+ #hal.executable.target<"llvm-cpu", "x86_64">
+ ]> : !hal.device
+ util.global private @riscv_device = #hal.device.target<"cpu", [
+ #hal.executable.target<"llvm-cpu", "riscv_32">
+ ]> : !hal.device
+
// CHECK: hal.executable public @ex
// CHECK: hal.executable.variant public @arm_64
// CHECK: hal.executable.variant public @riscv_32
@@ -239,11 +240,7 @@
}
// CHECK-LABEL: @using_specialized
util.func public @using_specialized(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint attributes {
- hal.device.targets = [
- #hal.device.target<"cpu", [
- #hal.executable.target<"llvm-cpu", "riscv_32">
- ]>
- ]
+ stream.affinity = #hal.device.affinity<@riscv_device>
} {
%c0 = arith.constant 0 : index
%0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
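The `materialize_interfaces.mlir` tests now key variant expansion off `stream.affinity = #hal.device.affinity<@device>` rather than `hal.device.targets` attributes. They check three things:

- The executable receives the union of variants required by all dispatch sites.
- A public `hal.executable.source` also receives the default device's targets.
- Each `stream.cmd.dispatch` references only the variants of its own device.

Below is a hypothetical Python model of that bookkeeping; `expand_variants` and its parameters are illustrative and not an IREE API. Sorting is used only to make the output deterministic.

```python
from typing import Dict, Iterable, List, Mapping, Tuple


def expand_variants(
    device_targets: Mapping[str, List[str]],
    dispatch_affinities: Mapping[str, str],
    default_devices: Iterable[str] = (),
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Hypothetical model of per-affinity executable variant expansion.

    device_targets: device global -> executable targets it lists.
    dispatch_affinities: dispatch site -> device global it is affinitized to.
    default_devices: devices whose targets are always added (e.g. for a
        public hal.executable.source).
    """
    # Each dispatch site only references the variants of its own device.
    per_dispatch = {
        site: sorted(device_targets[device])
        for site, device in dispatch_affinities.items()
    }
    # The executable gets the union of every variant any site needs.
    variants = set()
    for targets in per_dispatch.values():
        variants.update(targets)
    for device in default_devices:
        variants.update(device_targets[device])
    return sorted(variants), per_dispatch
```

With `@default_device` listing `arm_64`/`x86_64` and `@riscv_device` listing `riscv_32`, this yields all three variants on `@ex`. The `@using_specialized` dispatch references only `riscv_32`, matching the `CHECK-NOT` lines.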
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index 287822a..4e562f6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -1,28 +1,34 @@
// RUN: iree-opt --split-input-file --iree-hal-materialize-resource-caches %s | FileCheck %s
-// CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-
-// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK: util.global private @device = #hal.device.ordinal<0>
+util.global private @device = #hal.device.ordinal<0> : !hal.device
+// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-NEXT: util.initializer {
-// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
-// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create
-// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
-// CHECK-SAME: push_constants(1)
-// CHECK-SAME: layouts([%[[SET0]]]) : !hal.pipeline_layout
-// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK-DAG: %[[DEVICE:.+]] = util.global.load @device
+// CHECK-DAG: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create
+// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
+// CHECK-SAME: flags("None")
+// CHECK-SAME: bindings([
+// CHECK-SAME: #hal.descriptor_set.binding<0, storage_buffer>,
+// CHECK-SAME: #hal.descriptor_set.binding<1, storage_buffer>
+// CHECK-SAME: ]) : !hal.descriptor_set_layout
+// CHECK-NEXT: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.create
+// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
+// CHECK-SAME: push_constants(1)
+// CHECK-SAME: layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout
+// CHECK-NEXT: util.global.store %[[PIPELINE_LAYOUT]], @__device_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-LABEL: @exeLayoutLookup
-util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout {
- // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout
- %0 = hal.pipeline_layout.lookup device(%device : !hal.device)
- layout(#hal.pipeline.layout<push_constants = 1, sets = [
+util.func public @exeLayoutLookup() -> !hal.pipeline_layout {
+ %device = util.global.load @device : !hal.device
+ // CHECK: %[[LOADED_LAYOUT:.+]] = util.global.load @__device_pipeline_layout_0 : !hal.pipeline_layout
+ %0 = hal.pipeline_layout.lookup device(%device : !hal.device) layout(#hal.pipeline.layout<push_constants = 1, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>) : !hal.pipeline_layout
- // CHECK-NEXT: util.return %[[LAYOUT]]
+ // CHECK-NEXT: util.return %[[LOADED_LAYOUT]]
util.return %0 : !hal.pipeline_layout
}
@@ -42,28 +48,25 @@
]>
]>
-// TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/?
-// - If there is no matching hal.executable.variant then the executable will not be cached
-hal.executable @exe {
+// CHECK: hal.executable private @exe
+hal.executable private @exe {
+ // CHECK: hal.executable.variant public @vmvx
hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) {
+ // CHECK-NOT: hal.executable.condition
hal.executable.condition(%device: !hal.device) -> i1 {
%ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1
hal.return %selected : i1
}
- hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) attributes {
- workgroup_size = [32 : index, 1 : index, 1 : index]
- }
- hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0) attributes {
- workgroup_size = [32 : index, 1 : index, 1 : index]
- }
- hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1) attributes {
- workgroup_size = [32 : index, 1 : index, 1 : index]
- }
+ hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0)
+ hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0)
+ hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1)
+ // CHECK-NOT: hal.executable.constant.block
hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") {
%c123 = arith.constant 123 : i32
%c456 = arith.constant 456 : i32
hal.return %c123, %c456 : i32, i32
}
+ // CHECK-NOT: hal.executable.constant.block
hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" {
%ok, %query = hal.device.query<%device : !hal.device> key("sys" :: "baz") : i1, i32
cf.cond_br %ok, ^bb_ok, ^bb_fail
@@ -76,16 +79,27 @@
}
}
-// CHECK-DAG: util.global private @_descriptor_set_layout_0
-// CHECK-DAG: util.global private @_pipeline_layout_0
-// CHECK-DAG: util.global private @_descriptor_set_layout_1
-// CHECK-DAG: util.global private @_pipeline_layout_1
+// CHECK: util.global private @device = #hal.device.ordinal<0>
+util.global private @device = #hal.device.ordinal<0> : !hal.device
-// CHECK: util.global private @_executable_exe : !hal.executable
-// CHECK-NEXT: util.initializer {
+// Cached resources for the device.
+// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK: util.global private @__device_pipeline_layout_1 : !hal.pipeline_layout
+// CHECK: util.global private @__device_executable_0_exe : !hal.executable
+
+// Device initializer for all resources used with the device:
+// CHECK: util.initializer
+// CHECK: %[[DEVICE:.+]] = util.global.load @device
+
+// Create pipeline layouts (and required descriptor set layouts):
+// CHECK: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device)
+// CHECK: %[[SET_LAYOUT_1:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device)
+// CHECK: %[[PIPELINE_LAYOUT_0:.+]] = hal.pipeline_layout.create device(%[[DEVICE]] : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout
+// CHECK: util.global.store %[[PIPELINE_LAYOUT_0]], @__device_pipeline_layout_0
+// CHECK: %[[PIPELINE_LAYOUT_1:.+]] = hal.pipeline_layout.create device(%[[DEVICE]] : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_1]]]) : !hal.pipeline_layout
+// CHECK: util.global.store %[[PIPELINE_LAYOUT_1]], @__device_pipeline_layout_1
// Switch on the supported formats:
-// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb")
// CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 {
// CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature")
@@ -98,20 +112,15 @@
// CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable
// CHECK: case 0 {
-// Dependent layouts:
-// CHECK: %[[LAYOUT0:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout
-// CHECK: %[[LAYOUT0_2:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout
-// CHECK: %[[LAYOUT1:.+]] = util.global.load @_pipeline_layout_1 : !hal.pipeline_layout
-
// Constant block initializers:
-// CHECK: %[[CONST_01:.+]]:2 = util.call @__constant_block_0()
-// CHECK: %[[CONST_2:.+]] = util.call @__constant_block_1(%[[DEVICE]])
+// CHECK: %[[CONST_01:.+]]:2 = util.call @__device_executable_0_exe_constant_block_0()
+// CHECK: %[[CONST_2:.+]] = util.call @__device_executable_0_exe_constant_block_1(%[[DEVICE]])
// Executable creation:
// CHECK: %[[EXE:.+]] = hal.executable.create
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: target(@exe::@vmvx)
-// CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT0_2]], %[[LAYOUT1]]])
+// CHECK-SAME: layouts([%[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_1]]])
// CHECK-SAME: constants([%[[CONST_01]]#0, %[[CONST_01]]#1, %[[CONST_2]]])
// CHECK-SAME: : !hal.executable
@@ -119,18 +128,18 @@
// CHECK: }
// CHECK: default {
// CHECK: %[[C14:.+]] = arith.constant 14 : i32
-// CHECK: util.status.check_ok %[[C14]], "none of the executable binaries in the module are supported by the runtime"
+// CHECK: util.status.check_ok %[[C14]], "HAL device `device` does not support any variant of executable `exe`; available formats: [vmvx-bytecode-fb]"
// CHECK: %[[NULL:.+]] = util.null : !hal.executable
// CHECK: scf.yield %[[NULL]] : !hal.executable
// CHECK: }
-// CHECK: util.global.store %[[RET]], @_executable_exe : !hal.executable
+// CHECK: util.global.store %[[RET]], @__device_executable_0_exe : !hal.executable
-// Inlined constant block functions (here we ensure all blocks are cloned):
-// CHECK: util.func private @__constant_block_0() -> (i32, i32)
+// Constant block functions (here we ensure all blocks are cloned):
+// CHECK: util.func private @__device_executable_0_exe_constant_block_0() -> (i32, i32)
// CHECK-DAG: %[[C0:.+]] = arith.constant 123
// CHECK-DAG: %[[C1:.+]] = arith.constant 456
// CHECK: util.return %[[C0]], %[[C1]]
-// CHECK: util.func private @__constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32
+// CHECK: util.func private @__device_executable_0_exe_constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32
// CHECK: %[[OK:.+]], %[[VALUE:.+]] = hal.device.query<%[[BLOCK_DEVICE]] : !hal.device> key("sys" :: "baz")
// CHECK: cf.cond_br %[[OK]], ^bb1, ^bb2
// CHECK: ^bb1:
@@ -140,16 +149,172 @@
// CHECK: util.return %[[DUMMY]]
// CHECK-LABEL: @exeLookup
-util.func public @exeLookup(%device : !hal.device) -> !hal.executable {
- // CHECK: %[[EXE:.+]] = util.global.load @_executable_exe : !hal.executable
+util.func public @exeLookup() -> !hal.executable {
+ %device = util.global.load @device : !hal.device
+ // CHECK: %[[EXE:.+]] = util.global.load @__device_executable_0_exe : !hal.executable
%0 = hal.executable.lookup device(%device : !hal.device)
- executable(@exe) : !hal.executable
+ executable(@exe) : !hal.executable
// CHECK-NEXT: util.return %[[EXE]]
util.return %0 : !hal.executable
}
// -----
+// Tests that fallback resources are reused instead of being created again
+// when a device selects a fallback.
+
+// CHECK: hal.executable private @exe
+hal.executable private @exe {
+ // CHECK: hal.executable.variant public @vmvx
+ hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) {
+ // CHECK-NOT: hal.executable.condition
+ hal.executable.condition(%device: !hal.device) -> i1 {
+ %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1
+ hal.return %selected : i1
+ }
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+ ]>)
+ // CHECK-NOT: hal.executable.constant.block
+ hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") {
+ %c123 = arith.constant 123 : i32
+ %c456 = arith.constant 456 : i32
+ hal.return %c123, %c456 : i32, i32
+ }
+ }
+}
+
+// CHECK: util.global private @primary_device
+util.global private @primary_device = #hal.device.ordinal<0> : !hal.device
+// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0
+// CHECK-NEXT: util.global private @__primary_device_executable_0_exe
+// CHECK-NEXT: util.initializer
+// CHECK: util.global.load @primary_device
+// CHECK: hal.descriptor_set_layout.create
+// CHECK: hal.pipeline_layout.create
+// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0
+// CHECK: hal.executable.create
+// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe
+// CHECK: util.func private @__primary_device_executable_0_exe_constant_block_0
+
+// CHECK: util.global private @optional_device
+util.global private @optional_device = #hal.device.select<[
+ #hal.device.ordinal<1> : !hal.device,
+ #hal.device.fallback<@primary_device> : !hal.device
+]> : !hal.device
+// CHECK-NEXT: util.global private @__optional_device_pipeline_layout_0
+// CHECK-NEXT: util.global private @__optional_device_executable_0_exe
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[OPTIONAL_DEVICE:.+]] = util.global.load @optional_device
+// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device
+// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE]], %[[PRIMARY_DEVICE]]
+// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]]
+// CHECK-DAG: scf.index_switch %[[INDEX]]
+// CHECK: case 0
+// CHECK: %[[PRIMARY_LAYOUT:.+]] = util.global.load @__primary_device_pipeline_layout_0
+// CHECK: util.global.store %[[PRIMARY_LAYOUT]], @__optional_device_pipeline_layout_0
+// CHECK: %[[PRIMARY_EXE:.+]] = util.global.load @__primary_device_executable_0_exe
+// CHECK: util.global.store %[[PRIMARY_EXE]], @__optional_device_executable_0_exe
+// CHECK: default
+// CHECK: hal.descriptor_set_layout.create
+// CHECK: hal.pipeline_layout.create
+// CHECK: util.global.store {{.+}}, @__optional_device_pipeline_layout_0
+// CHECK: hal.executable.create
+// CHECK: util.global.store {{.+}}, @__optional_device_executable_0_exe
+// CHECK: util.func private @__optional_device_executable_0_exe_constant_block_0
+
+// CHECK-LABEL: @fallbackLookup
+util.func public @fallbackLookup() -> (!hal.executable, !hal.executable) {
+ %primary_device = util.global.load @primary_device : !hal.device
+ // CHECK: %[[PRIMARY_EXE_LOOKUP:.+]] = util.global.load @__primary_device_executable_0_exe
+ %0 = hal.executable.lookup device(%primary_device : !hal.device)
+ executable(@exe) : !hal.executable
+ %optional_device = util.global.load @optional_device : !hal.device
+ // CHECK: %[[OPTIONAL_EXE_LOOKUP:.+]] = util.global.load @__optional_device_executable_0_exe
+ %1 = hal.executable.lookup device(%optional_device : !hal.device)
+ executable(@exe) : !hal.executable
+ util.return %0, %1 : !hal.executable, !hal.executable
+}
+
+// -----
+
+// Tests that resources used only by optional devices force the resources to
+// be created on fallbacks. This isn't optimal, as we should really only create
+// them if the fallback is selected, but that's more complex than it's worth
+// today given the limited usage of fallbacks.
+
+hal.executable private @exe {
+ hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) {
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>
+ ]>
+ ]>)
+ }
+}
+
+// CHECK-LABEL: util.global private @primary_device
+util.global private @primary_device = #hal.device.ordinal<0> : !hal.device
+// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0
+// CHECK-NEXT: util.global private @__primary_device_executable_0_exe
+// CHECK-NEXT: util.initializer
+// CHECK: util.global.load @primary_device
+// CHECK: hal.descriptor_set_layout.create
+// CHECK: hal.pipeline_layout.create
+// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0
+// CHECK: hal.executable.create
+// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe
+
+// CHECK-LABEL: util.global private @optional_device_0
+util.global private @optional_device_0 = #hal.device.select<[
+ #hal.device.ordinal<1> : !hal.device,
+ #hal.device.fallback<@primary_device> : !hal.device
+]> : !hal.device
+// CHECK-NEXT: util.global private @__optional_device_0_pipeline_layout_0
+// CHECK-NEXT: util.global private @__optional_device_0_executable_0_exe
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0
+// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device
+// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_0]], %[[PRIMARY_DEVICE]]
+// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]]
+// CHECK-DAG: scf.index_switch %[[INDEX]]
+// CHECK: util.global.load @__primary_device_pipeline_layout_0
+// CHECK: util.global.store {{.+}}, @__optional_device_0_pipeline_layout_0
+// CHECK: util.global.load @__primary_device_executable_0_exe
+// CHECK: util.global.store {{.+}}, @__optional_device_0_executable_0_exe
+
+// CHECK-LABEL: util.global private @optional_device_1
+util.global private @optional_device_1 = #hal.device.select<[
+ #hal.device.ordinal<2> : !hal.device,
+ #hal.device.fallback<@optional_device_0> : !hal.device
+]> : !hal.device
+// CHECK-NEXT: util.global private @__optional_device_1_pipeline_layout_0
+// CHECK-NEXT: util.global private @__optional_device_1_executable_0_exe
+// CHECK-NEXT: util.initializer
+// CHECK-DAG: %[[OPTIONAL_DEVICE_1:.+]] = util.global.load @optional_device_1
+// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0
+// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_1]], %[[OPTIONAL_DEVICE_0]]
+// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]]
+// CHECK-DAG: scf.index_switch %[[INDEX]]
+// CHECK: util.global.load @__optional_device_0_pipeline_layout_0
+// CHECK: util.global.store {{.+}}, @__optional_device_1_pipeline_layout_0
+// CHECK: util.global.load @__optional_device_0_executable_0_exe
+// CHECK: util.global.store {{.+}}, @__optional_device_1_executable_0_exe
+
+// CHECK-LABEL: @fallbackOnlyLookup
+util.func public @fallbackOnlyLookup() -> !hal.executable {
+ %optional_device_1 = util.global.load @optional_device_1 : !hal.device
+ // CHECK: util.global.load @__optional_device_1_executable_0_exe
+ %0 = hal.executable.lookup device(%optional_device_1 : !hal.device)
+ executable(@exe) : !hal.executable
+ util.return %0 : !hal.executable
+}
+
+// -----
+
// Tests that materialization no-ops when resource caches have already been
// materialized. Today this is rather simplistic and just bails if the names
// match with the expectation being that users are mostly just running through
@@ -164,6 +329,8 @@
]>
]>
+util.global private @device : !hal.device
+
util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
util.initializer {
%c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
new file mode 100644
index 0000000..6360abe
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
@@ -0,0 +1,116 @@
+// RUN: iree-opt --split-input-file --iree-hal-materialize-target-devices %s --verify-diagnostics | FileCheck %s
+
+// expected-error@+1 {{invalid device targets specified}}
+module @module attributes {
+ hal.device.targets = [
+ "wrong_type"
+ ]
+} {
+ util.func private @func() -> ()
+}
+
+// -----
+
+// Modules without anything that needs an environment are OK as-is.
+
+// CHECK: module @module
+module @module {
+ // CHECK-NEXT: hal.executable private @exe
+ hal.executable private @exe {
+ // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64
+ hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"backend", "format", {}>) {}
+ }
+}
+
+// -----
+
+// Valid input with proper attributes for a single device.
+
+// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a"
+#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]>
+// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b"
+#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]>
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_0>
+module @module attributes {
+ hal.device.targets = [
+ #hal.device.select<[#device_a, #device_b]> : !hal.device
+ ]
+} {
+ // CHECK: util.global private @__device_0 = #hal.device.select<[
+ // CHECK-SAME: #[[DEVICE_A]],
+ // CHECK-SAME: #[[DEVICE_B]]
+ // CHECK-SAME: ]> : !hal.device
+}
+
+// -----
+
+// Multiple devices using device names.
+
+// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a"
+#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]>
+// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b"
+#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]>
+// CHECK: #[[DEVICE_C:.+]] = #hal.device.target<"device_c"
+#device_c = #hal.device.target<"device_c", [#hal.executable.target<"backend_c", "format_c">]>
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_a>
+module @module attributes {
+ hal.device.targets = {
+ device_a = #device_a,
+ device_bc = [#device_b, #device_c]
+ }
+} {
+ // CHECK: util.global private @device_a = #[[DEVICE_A]]
+ // CHECK: util.global private @device_bc = #hal.device.select<[#[[DEVICE_B]], #[[DEVICE_C]]]>
+}
+
+// -----
+
+// Default device selection by name.
+
+// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a"
+#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]>
+// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b"
+#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]>
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_b>
+module @module attributes {
+ hal.device.targets = {
+ device_a = #device_a,
+ device_b = #device_b
+ },
+ hal.device.default = "device_b"
+} {
+ // CHECK: util.global private @device_a
+ // CHECK: util.global private @device_b
+}
+
+// -----
+
+// Default device selection by ordinal.
+
+// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a"
+#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]>
+// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b"
+#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]>
+
+// CHECK: module @module
+// CHECK-NOT: hal.device.targets
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_1>
+module @module attributes {
+ hal.device.targets = [
+ #device_a,
+ #device_b
+ ],
+ hal.device.default = 1 : index
+} {
+ // CHECK: util.global private @__device_0
+ // CHECK: util.global private @__device_1
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
index 5211bd9..16dfd5c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -1,38 +1,71 @@
// RUN: iree-opt --split-input-file --iree-hal-memoize-device-queries --canonicalize %s | FileCheck %s
-// CHECK: util.global private @_device_query_0 : i1
-// CHECK-NEXT: util.global private @_device_query_0_ok : i1
-// CHECK-NEXT: util.initializer {
-// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
-// CHECK-NEXT: %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false
-// CHECK-NEXT: util.global.store %[[OK0]], @_device_query_0_ok : i1
-// CHECK-NEXT: util.global.store %[[VALUE0]], @_device_query_0 : i1
+// Tests that unknown devices (here passed as an arg on a public function) don't
+// get memoized.
-// CHECK: util.global private @_device_query_1 : i1
-// CHECK-NEXT: util.global private @_device_query_1_ok : i1
-// CHECK-NEXT: util.initializer {
-// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
-// CHECK-NEXT: %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
-// CHECK-NEXT: util.global.store %[[OK1]], @_device_query_1_ok : i1
-// CHECK-NEXT: util.global.store %[[VALUE1]], @_device_query_1 : i1
+// CHECK-LABEL: @unknown_device
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
+util.func public @unknown_device(%device: !hal.device) -> i1 {
+ // CHECK-NEXT: hal.device.query<%[[DEVICE]]
+ %id0_ok, %id0 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false
+ util.return %id0 : i1
+}
-// CHECK: util.global private @_device_query_2
+// -----
-// CHECK-LABEL: util.func public @device_matchers
-util.func public @device_matchers(%device : !hal.device) -> (i1, i1, i1, i1, i1, i1) {
- // Same queries (same variables):
- // CHECK-NEXT: = util.global.load @_device_query_0_ok : i1
- // CHECK-NEXT: = util.global.load @_device_query_0 : i1
- %id0_a_ok, %id0_a = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false
- // CHECK-NEXT: = util.global.load @_device_query_0_ok : i1
- // CHECK-NEXT: = util.global.load @_device_query_0 : i1
- %id0_b_ok, %id0_b = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false
+// Tests that multiple possible devices disable memoization.
+// TODO(multi-device): enable propagation of queried values across the program.
+// We should be able to track back to each global, memoize there, then pass
+// through the value as a normal SSA value.
- // Same query but with different defaults (different variables):
- // CHECK-NEXT: = util.global.load @_device_query_1 : i1
- %id1_a_ok, %id1_a = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
- // CHECK-NEXT: = util.global.load @_device_query_2 : i1
- %id1_b_ok, %id1_b = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = true
+util.global private @device_a : !hal.device
+util.global private @device_b : !hal.device
- util.return %id0_a_ok, %id0_a, %id0_b_ok, %id0_b, %id1_a, %id1_b : i1, i1, i1, i1, i1, i1
+// CHECK-LABEL: @multi_device_not_memoized
+util.func public @multi_device_not_memoized(%cond: i1) -> i1 {
+ // CHECK-DAG: %[[DEVICE_A:.+]] = util.global.load @device_a
+ %device_a = util.global.load @device_a : !hal.device
+ // CHECK-DAG: %[[DEVICE_B:.+]] = util.global.load @device_b
+ %device_b = util.global.load @device_b : !hal.device
+ // CHECK: %[[DEVICE_AB:.+]] = arith.select %{{.+}}, %[[DEVICE_A]], %[[DEVICE_B]]
+ %device_ab = arith.select %cond, %device_a, %device_b : !hal.device
+ // CHECK-NEXT: hal.device.query<%[[DEVICE_AB]]
+ %id0_ok, %id0 = hal.device.query<%device_ab : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false
+ util.return %id0 : i1
+}
+
+// -----
+
+// Tests basic hoisting of device queries up to an initializer per device.
+
+// CHECK: util.global private @device
+util.global private @device : !hal.device
+// CHECK-NEXT: util.global private @__device_query_0_hal_device_id_id0_ok : i1
+// CHECK-NEXT: util.global private @__device_query_0_hal_device_id_id0 : i1
+// CHECK-NEXT: util.global private @__device_query_1_hal_device_id_id1_ok : i1
+// CHECK-NEXT: util.global private @__device_query_1_hal_device_id_id1 : i1
+// CHECK-NEXT: util.initializer
+// CHECK: %[[DEVICE:.+]] = util.global.load @device : !hal.device
+// CHECK: %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false
+// CHECK: util.global.store %[[OK0]], @__device_query_0_hal_device_id_id0_ok : i1
+// CHECK: util.global.store %[[VALUE0]], @__device_query_0_hal_device_id_id0 : i1
+// CHECK: %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
+// CHECK: util.global.store %[[OK1]], @__device_query_1_hal_device_id_id1_ok : i1
+// CHECK: util.global.store %[[VALUE1]], @__device_query_1_hal_device_id_id1 : i1
+
+// CHECK: @single_device_memoized_0
+util.func public @single_device_memoized_0() -> (i1, i1) {
+ %device = util.global.load @device : !hal.device
+ // CHECK-NEXT: = util.global.load @__device_query_0_hal_device_id_id0_ok : i1
+ // CHECK-NEXT: = util.global.load @__device_query_0_hal_device_id_id0 : i1
+ %id0_ok, %id0 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false
+ util.return %id0_ok, %id0 : i1, i1
+}
+// CHECK: @single_device_memoized_1
+util.func public @single_device_memoized_1() -> (i1, i1) {
+ %device = util.global.load @device : !hal.device
+ // CHECK-NEXT: = util.global.load @__device_query_1_hal_device_id_id1_ok : i1
+ // CHECK-NEXT: = util.global.load @__device_query_1_hal_device_id_id1 : i1
+ %id1_ok, %id1 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
+ util.return %id1_ok, %id1 : i1, i1
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir
new file mode 100644
index 0000000..82a45cc
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir
@@ -0,0 +1,41 @@
+// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s
+
+// CHECK: util.global private @device
+// CHECK-SAME: #hal.device.target<"local"
+// CHECK-SAME: extra_config = 4 : index
+// CHECK-SAME: #hal.executable.target<"vmvx"
+util.global private @device = #hal.device.alias<"vmvx", {
+ extra_config = 4 : index
+}> : !hal.device
+
+// -----
+
+// CHECK: util.global private @device_ordinal
+// CHECK-SAME: #hal.device.target<"local"
+// CHECK-SAME: ordinal = 123 : index
+// CHECK-SAME: #hal.executable.target<"vmvx"
+util.global private @device_ordinal = #hal.device.alias<"vmvx"[123]> : !hal.device
+
+// -----
+
+// CHECK: util.global private @device_select
+// CHECK-SAME: #hal.device.select<[
+// CHECK-SAME: #hal.device.target<"local", {ordinal = 0 : index}
+// CHECK-SAME: #hal.device.target<"local", {ordinal = 1 : index}
+util.global private @device_select = #hal.device.select<[
+ #hal.device.alias<"vmvx"[0]> : !hal.device,
+ #hal.device.alias<"vmvx"[1]> : !hal.device
+]> : !hal.device
+
+// -----
+
+// expected-error@+1 {{unregistered device alias "__unregistered__"}}
+util.global private @device_unregistered = #hal.device.alias<"__unregistered__"> : !hal.device
+
+// -----
+
+// expected-error@+1 {{unregistered device alias "__unregistered__"}}
+util.global private @device_select_unregistered = #hal.device.select<[
+ #hal.device.alias<"vmvx"> : !hal.device,
+ #hal.device.alias<"__unregistered__"> : !hal.device
+]> : !hal.device
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir
new file mode 100644
index 0000000..2b5985f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir
@@ -0,0 +1,43 @@
+// RUN: iree-opt --split-input-file --iree-hal-resolve-device-promises %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s
+
+// Resolves device promises.
+
+// CHECK: module @module
+module @module attributes {
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device0, [1, 2, 3]>
+ stream.affinity = #hal.device.promise<@device0, [1, 2, 3]>
+} {
+ util.global private @device0 = #hal.device.target<"vmvx"> : !hal.device
+ util.global private @device1 = #hal.device.target<"vmvx"> : !hal.device
+ // CHECK: util.func private @func
+ util.func private @func(%arg0: tensor<i32> {
+ // CHECK-SAME: arg.affinity = #hal.device.affinity<@device1>
+ arg.affinity = #hal.device.promise<@device1>
+ }) -> (tensor<i32> {
+ // CHECK-SAME: result.affinity = #hal.device.affinity<@device1>
+ result.affinity = #hal.device.promise<@device1>
+ }) attributes {
+ // CHECK-SAME: func.affinity = #hal.device.affinity<@device1>
+ func.affinity = #hal.device.promise<@device1>
+ } {
+ // CHECK: util.return
+ util.return {
+ // CHECK-SAME: some.affinities = [#hal.device.affinity<@device0>, #hal.device.affinity<@device1>]
+ some.affinities = [#hal.device.promise<@device0>, #hal.device.promise<@device1>]
+ } %arg0 : tensor<i32>
+ }
+}
+
+// -----
+
+// Verifies that promised devices exist.
+
+module @module {
+ util.global private @device = #hal.device.target<"vmvx"> : !hal.device
+ // expected-error@+1 {{op references a promised device that was not declared}}
+ util.func private @func() -> () attributes {
+ stream.affinity = #hal.device.promise<@unknown_device>
+ } {
+ util.return
+ }
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
new file mode 100644
index 0000000..b4e2264
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt --split-input-file --iree-hal-verify-devices %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s
+
+// Tests that modules without tensors don't need devices.
+
+module @module {
+ // CHECK: util.func private @func
+ util.func private @func() -> ()
+}
+
+// -----
+
+// TODO(multi-device): find a way to verify that devices exist if they need to.
+// Currently the check is disabled as it's difficult to tell if a device will be
+// needed by the time we get to the HAL layer: plugins may absorb things, etc.
+// NO-expected-errorx@+1 {{no HAL devices defined in the module}}
+module @module {
+ util.func private @func() -> () {
+ arith.constant dense<1.0> : tensor<4xf32>
+ util.return
+ }
+}
+
+// -----
+
+module @module {
+ // expected-error@+1 {{unregistered target device "__unregistered__"}}
+ util.global private @device = #hal.device.target<"__unregistered__"> : !hal.device
+ util.func private @func() -> () attributes {
+ stream.affinity = #hal.device.affinity<@device>
+ }
+}
+
+// -----
+
+module @module {
+ // expected-error@+1 {{unregistered target device "__unregistered__"}}
+ util.global private @device = #hal.device.select<[
+ #hal.device.target<"vmvx"> : !hal.device,
+ #hal.device.target<"__unregistered__"> : !hal.device
+ ]> : !hal.device
+ util.func private @func() -> () attributes {
+ stream.affinity = #hal.device.affinity<@device>
+ }
+}
+
+// -----
+
+// Valid input with proper attributes.
+
+// CHECK: module @module
+module @module {
+ util.global private @device = #hal.device.target<"vmvx"> : !hal.device
+ util.global private @optional = #hal.device.fallback<@device> : !hal.device
+ util.global private @ordinal = #hal.device.ordinal<0> : !hal.device
+ util.global private @selected = #hal.device.select<[
+ #hal.device.target<"llvm-cpu"> : !hal.device,
+ #hal.device.target<"vmvx"> : !hal.device
+ ]> : !hal.device
+ util.func private @func() -> () attributes {
+ stream.affinity = #hal.device.affinity<@device>
+ }
+}
+
+// -----
+
+// Modules without anything that needs an environment are OK.
+
+// CHECK: module @module
+module @module {
+ hal.executable private @exe {
+ hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}>) {}
+ }
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir
deleted file mode 100644
index 81f1e1c..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir
+++ /dev/null
@@ -1,54 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-hal-verify-target-environment %s --verify-diagnostics | FileCheck %s
-
-// expected-error@+1 {{no HAL target devices specified}}
-module @module {
- util.func private @func() -> ()
-}
-
-// -----
-
-// expected-error@+1 {{no HAL target devices specified}}
-module @module attributes {hal.device.targets = []} {
- util.func private @func() -> ()
-}
-
-// -----
-
-// expected-error@+1 {{invalid target attr type}}
-module @module attributes {hal.device.targets = ["wrong_type"]} {
- util.func private @func() -> ()
-}
-
-// -----
-
-// expected-error@+1 {{unregistered target device "foo"}}
-module @module attributes {hal.device.targets = [#hal.device.target<"foo">]} {
- util.func private @func() -> ()
-}
-
-// -----
-
-// Valid input with proper attributes.
-
-// CHECK: #device_target_vmvx = #hal.device.target<"vmvx">
-#device_target_vmvx = #hal.device.target<"vmvx">
-
-// CHECK: module @module attributes {hal.device.targets = [#device_target_vmvx]}
-module @module attributes {hal.device.targets = [#device_target_vmvx]} {
- util.func private @func() -> ()
-}
-
-// -----
-
-// Modules without anything that needs an environment are OK.
-
-#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}>
-
-// CHECK: module @module
-module @module {
- // CHECK-NEXT: hal.executable private @exe
- hal.executable private @exe {
- // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64
- hal.executable.variant public @embedded_elf_arm_64 target(#executable_target) {}
- }
-}
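`Affinity.cpp` models the affinity of each value, global, and op as a DFX potential-values state: a set of assumed affinities that may also contain `undef` and that the analysis treats as failed once it is no longer valid. The sketch below is a rough, hypothetical illustration of how `tryLookupGlobalAffinity` and `trySelectLeadAffinity` consume such a state. It uses plain strings for affinities and assumes two affinities are compatible only when they are equal. `AffinitySetState`, `tryLookupAffinities`, and `selectLeadAffinity` are illustrative names; the real `DFX::PotentialValuesState` and `AffinityAttr::areCompatible` are richer than this.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a potential-values state whose
// values are affinities (modeled here as plain strings).
struct AffinitySetState {
  bool valid = true;            // false once a pessimistic fixpoint is reached
  bool undefContained = false;  // true if some producer/consumer is unknown
  std::set<std::string> assumed;

  void unionAssumed(const std::string &affinity) { assumed.insert(affinity); }
  void unionAssumedWithUndef() { undefContained = true; }
  void indicatePessimisticFixpoint() { valid = false; }
};

// Follows the shape of tryLookupGlobalAffinity: fail on an invalid or undef
// state, fall back to a default when nothing was assumed, and otherwise
// return all assumed affinities in a deterministic order.
bool tryLookupAffinities(const AffinitySetState &state,
                         const std::optional<std::string> &defaultAffinity,
                         std::vector<std::string> &affinities) {
  if (!state.valid || state.undefContained) {
    return false;  // Analysis failed.
  }
  if (state.assumed.empty()) {
    // No affinity was derived; use the default if one exists.
    if (!defaultAffinity) {
      return false;
    }
    affinities.push_back(*defaultAffinity);
    return true;
  }
  // std::set iterates in sorted order, standing in for the explicit sort the
  // real code needs because its sets have run-to-run ordering variance.
  affinities.assign(state.assumed.begin(), state.assumed.end());
  return true;
}

// Follows the shape of trySelectLeadAffinity, assuming that two affinities
// are compatible only if they are equal.
std::optional<std::string>
selectLeadAffinity(const std::vector<std::string> &affinities) {
  if (affinities.empty()) {
    return std::nullopt;
  }
  for (std::size_t i = 1; i < affinities.size(); ++i) {
    if (affinities[i] != affinities.front()) {
      return std::nullopt;
    }
  }
  return affinities.front();
}
```

Because the sketch's set holds distinct strings and compatibility is plain equality, any multi-element result has no lead affinity; the real `areCompatible` check can accept distinct but compatible affinities.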
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
new file mode 100644
index 0000000..8351c91
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
@@ -0,0 +1,1051 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+
+#define DEBUG_TYPE "iree-util-dfx"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static const std::string getAffinitySetAsStr(
+ const DFX::PotentialValuesState<IREE::Stream::AffinityAttr> &state,
+ AsmState &asmState) {
+ std::string str;
+ llvm::raw_string_ostream sstream(str);
+ sstream << "pvs: ";
+ if (state.isValidState()) {
+ sstream << "[";
+ if (state.isUndefContained()) {
+ sstream << "undef, ";
+ }
+ llvm::interleaveComma(state.getAssumedSet(), sstream,
+ [&](IREE::Stream::AffinityAttr value) {
+ cast<Attribute>(value).print(sstream);
+ });
+ sstream << "]";
+ } else {
+ sstream << "(invalid)";
+ }
+ sstream.flush();
+ return str;
+}
+
+//===----------------------------------------------------------------------===//
+// Analysis elements
+//===----------------------------------------------------------------------===//
+
+class ValueProducerAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static ValueProducerAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) ValueProducerAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override {
+ return "ValueProducerAffinityPVS";
+ }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ void updateFromUse(Value value, OpOperand &operand, StateType &newState,
+ DFX::Solver &solver);
+
+ // Operations that the value is pinned to.
+ SetVector<Operation *> pinnedOps;
+};
+const char ValueProducerAffinityPVS::ID = 0;
+
+class GlobalAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOpInterface>> {
+public:
+ using BaseType = DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOpInterface>>;
+ using BaseType::BaseType;
+
+ static GlobalAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) GlobalAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "GlobalAffinityPVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) override;
+ ChangeStatus updateOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) override;
+};
+const char GlobalAffinityPVS::ID = 0;
+
+class OpAffinityPVS : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::OperationElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::OperationElement>;
+ using BaseType::BaseType;
+
+ static OpAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) OpAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "OpAffinityPVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeOperation(Operation *op, DFX::Solver &solver) override;
+ ChangeStatus updateOperation(Operation *op, DFX::Solver &solver) override;
+};
+const char OpAffinityPVS::ID = 0;
+
+//===----------------------------------------------------------------------===//
+// ValueConsumerAffinityPVS
+//===----------------------------------------------------------------------===//
+
+class ValueConsumerAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static ValueConsumerAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) ValueConsumerAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override {
+ return "ValueConsumerAffinityPVS";
+ }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ TraversalResult updateFromUse(Value value, OpOperand &operand,
+ StateType &newState, DFX::Solver &solver);
+};
+const char ValueConsumerAffinityPVS::ID = 0;
+
+void ValueConsumerAffinityPVS::initializeValue(Value value,
+ DFX::Solver &solver) {}
+
+ChangeStatus ValueConsumerAffinityPVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ // Walk into all consumers of the SSA value.
+ // Note that we may end up at multiple global stores of different globals
+ // by walking down through calls/branches/etc.
+ traversalResult |= solver.getExplorer().walkTransitiveUses(
+ value,
+ [&](OpOperand &operand) {
+ traversalResult |= updateFromUse(value, operand, newState, solver);
+ return WalkResult::advance();
+ },
+ (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES));
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+TraversalResult ValueConsumerAffinityPVS::updateFromUse(Value value,
+ OpOperand &operand,
+ StateType &newState,
+ DFX::Solver &solver) {
+ // If the value is consumed by an affinity-aware op then we can directly use
+ // the affinity specified on the op. A majority of the values we care about at
+ // the stream level are consumed by affinity-aware ops, and earlier in the
+ // pipeline, dialects may have transfer ops that define affinities we can
+ // anchor on.
+ if (auto affinityOp =
+ dyn_cast<IREE::Stream::AffinityOpInterface>(operand.getOwner())) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(operand.getOwner()),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueConsumerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using consumer affinity from ";
+ operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ }
+
+ // If the consumer op has the operand tied to one or more results then we walk
+ // through to track the transitive consumers. When this analysis runs we are
+ // usually still prior to baking out copy-on-write behavior so it's possible
+ // that the results of the tied operation end up in different places.
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(operand.getOwner())) {
+ auto tiedResults = tiedOp.getOperandTiedResults(operand.getOperandNumber());
+ for (auto tiedResult : tiedResults) {
+ auto resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(tiedResult), DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueConsumerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity referencing tied operand ";
+ operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " result ";
+ tiedResult.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ resultPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= resultPVS;
+ }
+ }
+
+ // Handle consumers that are not affinity aware - this should handle any
+ // control flow ops so that we can track values that flow through the program.
+ return TypeSwitch<Operation *, TraversalResult>(operand.getOwner())
+ .Case([&](mlir::arith::SelectOp op) {
+ auto &resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getResult()),
+ DFX::Resolution::REQUIRED);
+ newState ^= resultPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::BranchOpInterface op) {
+ return solver.getExplorer().walkOutgoingBranchOperandArguments(
+ op, operand.getOperandNumber(),
+ [&](Block *targetBlock, BlockArgument arg) {
+ auto &argUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(arg), DFX::Resolution::OPTIONAL);
+ newState ^= argUsage;
+ return WalkResult::advance();
+ });
+ })
+ .Case([&](mlir::scf::ForOp op) {
+ if (operand.getOperandNumber() >= op.getNumControlOperands()) {
+ int64_t blockIdx =
+ operand.getOperandNumber() - op.getNumControlOperands();
+ auto &beforeUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getRegionIterArg(blockIdx)),
+ DFX::Resolution::REQUIRED);
+ newState ^= beforeUsage.getState();
+ }
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::WhileOp op) {
+ auto &beforeUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op.getBeforeBody()->getArgument(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= beforeUsage.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::ConditionOp op) {
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op->getParentOp()->getResult(operand.getOperandNumber() - 1)),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ if (auto whileOp =
+ dyn_cast_or_null<mlir::scf::WhileOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ whileOp.getAfter().getArgument(operand.getOperandNumber() - 1));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ }
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::YieldOp op) {
+ if (isa<mlir::scf::IfOp>(op->getParentOp())) {
+ auto &operandUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(op->getOperand(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= operandUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op->getParentOp()->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else if (auto whileOp =
+ dyn_cast<mlir::scf::WhileOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ whileOp.getBefore().getArgument(operand.getOperandNumber()));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ whileOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else if (auto forOp = dyn_cast<mlir::scf::ForOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ forOp.getRegionIterArg(operand.getOperandNumber()));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(forOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else {
+ assert(false && "unhandled scf yield parent");
+ return TraversalResult::INCOMPLETE;
+ }
+ })
+ .Case([&](IREE::Util::ReturnOp op) {
+ return solver.getExplorer().walkIncomingCalls(
+ op->getParentOfType<mlir::CallableOpInterface>(),
+ [&](mlir::CallOpInterface callOp) {
+ auto &argUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ callOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::OPTIONAL);
+ newState ^= argUsage;
+ return WalkResult::advance();
+ });
+ })
+ .Case([&](IREE::Util::OptimizationBarrierOp op) {
+ auto &resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= resultPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](IREE::Util::GlobalStoreOpInterface op) {
+ auto *globalInfo =
+ solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ newState ^= globalPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Default([&](Operation *op) { return TraversalResult::COMPLETE; });
+}
+
+//===----------------------------------------------------------------------===//
+// ValueProducerAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void ValueProducerAffinityPVS::initializeValue(Value value,
+ DFX::Solver &solver) {
+ solver.getExplorer().walkDefiningOps(value, [&](OpResult result) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ return WalkResult::skip();
+ }
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getOwner())) {
+ if (affinityOp.pinsValueAffinity()) {
+ pinnedOps.insert(result.getOwner());
+ }
+ }
+ return WalkResult::advance();
+ });
+ solver.getExplorer().walkTransitiveUses(value, [&](OpOperand &operand) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(operand.get().getType())) {
+ return WalkResult::skip();
+ }
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ operand.getOwner())) {
+ if (affinityOp.pinsValueAffinity()) {
+ pinnedOps.insert(operand.getOwner());
+ }
+ }
+ return WalkResult::advance();
+ });
+}
+
+ChangeStatus ValueProducerAffinityPVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+
+ // If there are any ops that produce or consume the value and pin it to a
+ // specific affinity then we take those directly and ignore all others.
+ if (!pinnedOps.empty()) {
+ for (auto pinnedOp : pinnedOps) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(pinnedOp), DFX::Resolution::REQUIRED);
+ newState ^= opPVS;
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ }
+
+ // We special case some ops that act as barriers in the program. This prevents
+ // us from walking past boundaries where doing so is not profitable; for
+ // example, globals are usually stored in contexts independent of where they
+ // are consumed.
+ if (auto barrierOp = dyn_cast_if_present<IREE::Util::OptimizationBarrierOp>(
+ value.getDefiningOp())) {
+ auto operand =
+ barrierOp.getOperand(cast<OpResult>(value).getResultNumber());
+ auto &operandPVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using barrier op operand as ";
+ operandPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= operandPVS;
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ } else if (auto loadOp =
+ dyn_cast_if_present<IREE::Util::GlobalLoadOpInterface>(
+ value.getDefiningOp())) {
+ auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom(
+ loadOp.getGlobalName(), loadOp);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using global op affinity from "
+ << loadOp.getGlobalName() << " as ";
+ globalPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= globalPVS.getState();
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ }
+
+ // Walk the program up into any possible producers of the value.
+ auto traversalResult = TraversalResult::COMPLETE;
+ traversalResult |= solver.getExplorer().walkDefiningOps(
+ value,
+ [&](OpResult result) {
+ if (isa<CallOpInterface>(result.getOwner())) {
+ return WalkResult::advance();
+ }
+
+ // If coming from an affinity-aware op that pins the value storage to a
+ // particular affinity, then that affinity overrides all other logic.
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getDefiningOp())) {
+ if (affinityOp.pinsValueAffinity()) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(affinityOp),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity assuming pinned affinity from ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ newState.indicateOptimisticFixpoint();
+ return WalkResult::advance();
+ }
+ }
+
+ // If the result value is tied to an operand of the defining op then
+ // inherit the operand affinity.
+ if (auto tiedOp = dyn_cast_if_present<IREE::Util::TiedOpInterface>(
+ result.getDefiningOp())) {
+ auto operand = tiedOp.getTiedResultOperand(result);
+ if (operand) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::OPTIONAL);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity referencing tied operand ";
+ operand.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ valuePVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= valuePVS;
+ return WalkResult::advance();
+ }
+ }
+
+ // If the value is produced by the defining op then assume that the
+ // execution affinity dictates the result affinity.
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getDefiningOp())) {
+ auto &opPVS = solver.getOrCreateElementFor<OpAffinityPVS>(
+ Position::forOperation(result.getOwner()), *this,
+ DFX::Resolution::OPTIONAL, /*forceUpdate=*/false,
+ /*updateAfterInit=*/false);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using op affinity from result ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ return WalkResult::advance();
+ }
+
+ // Special handling for specific ops.
+ TypeSwitch<Operation *>(result.getOwner())
+ .Case<IREE::Util::GlobalLoadOpInterface>([&](auto loadOp) {
+ auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom(
+ loadOp.getGlobalName(), loadOp);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs()
+ << " affinity using global op affinity from result ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ globalPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= globalPVS.getState();
+ })
+ .Case<mlir::arith::SelectOp>([&](auto op) {
+ auto &truePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(op.getTrueValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= truePVS.getState();
+ auto &falsePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(op.getFalseValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= falsePVS.getState();
+ })
+ .Default([&](auto op) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(result), DFX::Resolution::OPTIONAL);
+ newState ^= valuePVS;
+ });
+ return WalkResult::advance();
+ },
+ (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES));
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void GlobalAffinityPVS::initializeOperation(
+ IREE::Util::GlobalOpInterface globalOp, DFX::Solver &solver) {
+ // If an affinity is explicitly specified we take that over all analysis.
+ if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "[GlobalAffinityPVS] global @"
+ << globalOp.getGlobalName().getValue()
+ << " affinity explicitly specified as ";
+ affinityAttr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ unionAssumed(affinityAttr);
+ indicateOptimisticFixpoint();
+ return;
+ }
+}
+
+ChangeStatus
+GlobalAffinityPVS::updateOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ const auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp);
+ if (globalInfo->isIndirect) {
+ traversalResult = TraversalResult::INCOMPLETE;
+ }
+
+ // Traverse all transitive uses of the global.
+ // We try to place globals where they are used as the common case is weights
+ // or parameters that are read more frequently than they are written.
+ // The reasoning is that a program writing a global more often than it reads
+ // it is doing unneeded work; otherwise there is at least one read per write,
+ // so reads dominate and are what placement should favor.
+ bool anyLoads = false;
+ for (auto loadOp : globalInfo->getLoads()) {
+ anyLoads = true;
+ auto &valuePVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(loadOp.getLoadedGlobalValue()),
+ DFX::Resolution::OPTIONAL);
+ if (valuePVS.isValidState()) {
+ newState ^= valuePVS;
+ }
+ }
+
+ // If there were no loads then take the affinity from stores.
+ // This is not common but can arise in tests or where the globals may be used
+ // to model side-effecting behavior.
+ if (!anyLoads) {
+ for (auto storeOp : globalInfo->getStores()) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(storeOp.getStoredGlobalValue()),
+ DFX::Resolution::OPTIONAL);
+ if (valuePVS.isValidState()) {
+ newState ^= valuePVS;
+ }
+ }
+ }
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// OpAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void OpAffinityPVS::initializeOperation(Operation *op, DFX::Solver &solver) {
+ // If an affinity is explicitly specified we take that over all analysis.
+ if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "[OpAffinityPVS] op ";
+ op->getName().print(llvm::dbgs());
+ llvm::dbgs() << " affinity explicitly specified as ";
+ affinityAttr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ unionAssumed(affinityAttr);
+ indicateOptimisticFixpoint();
+ return;
+ }
+}
+
+ChangeStatus OpAffinityPVS::updateOperation(Operation *op,
+ DFX::Solver &solver) {
+ StateType newState;
+
+ const bool consumesAny = llvm::any_of(
+ op->getOperandTypes(), +[](Type type) {
+ return isa<IREE::Stream::AffinityTypeInterface>(type);
+ });
+ if (consumesAny) {
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::REQUIRED);
+ newState ^= valuePVS;
+ }
+ }
+ } else {
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ auto &valuePVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(result), DFX::Resolution::REQUIRED);
+ newState ^= valuePVS;
+ }
+ }
+ }
+
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// AffinityAnalysis
+//===----------------------------------------------------------------------===//
+
+// Tries to find a default affinity specified on an ancestor of |fromOp| and
+// adds it to |affinities|. Returns true if an affinity was found.
+static bool tryLookupDefaultAffinity(
+ Operation *fromOp,
+ SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ while (fromOp) {
+ auto affinityAttr = fromOp->getAttrOfType<IREE::Stream::AffinityAttr>(
+ "stream.affinity.default");
+ if (affinityAttr) {
+ affinities.push_back(affinityAttr);
+ return true;
+ }
+ fromOp = fromOp->getParentOp();
+ }
+ return false;
+}
+
+// Returns the first affinity if all affinities are compatible and otherwise
+// returns nullptr.
+static IREE::Stream::AffinityAttr
+trySelectLeadAffinity(ArrayRef<IREE::Stream::AffinityAttr> affinities) {
+ if (affinities.empty()) {
+ return {};
+ }
+ auto leadAffinityAttr = affinities.front();
+ for (size_t i = 1; i < affinities.size(); ++i) {
+ if (!IREE::Stream::AffinityAttr::areCompatible(affinities[i],
+ leadAffinityAttr)) {
+ return {};
+ }
+ }
+ return leadAffinityAttr;
+}
+
+// Sorts |affinities| in the natural affinity sort order.
+// We unfortunately have to do this as the PVS elements we source from are
+// unsorted.
+static void
+sortAffinities(SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ // HACK: this should probably do a type id ordering followed by a
+ // type-specific ordering (interface compare method?). We just need this to be
+ // stable as the affinities come from multiple DenseSets that have run-to-run
+ // ordering variance. This is very inefficient but is only used when there are
+ // multiple possible affinities and we try to avoid that anyway.
+ if (affinities.size() <= 1) {
+ return;
+ }
+ llvm::stable_sort(affinities, [](IREE::Stream::AffinityAttr lhs,
+ IREE::Stream::AffinityAttr rhs) {
+ std::string lhsStr;
+ llvm::raw_string_ostream lhsStream(lhsStr);
+ lhs.print(lhsStream);
+ std::string rhsStr;
+ llvm::raw_string_ostream rhsStream(rhsStr);
+ rhs.print(rhsStream);
+ return lhsStr < rhsStr;
+ });
+}
+
+AffinityAnalysis::AffinityAnalysis(Operation *rootOp)
+ : explorer(rootOp, TraversalAction::RECURSE), solver(explorer, allocator) {
+ explorer.setOpInterfaceAction<mlir::FunctionOpInterface>(
+ TraversalAction::RECURSE);
+
+ explorer.setDialectAction<mlir::scf::SCFDialect>(TraversalAction::RECURSE);
+
+ explorer.setDialectAction<IREE::Stream::StreamDialect>(
+ TraversalAction::RECURSE);
+ explorer.setOpAction<IREE::Stream::ExecutableOp>(TraversalAction::IGNORE);
+
+ explorer.initialize();
+}
+
+AffinityAnalysis::~AffinityAnalysis() = default;
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupGlobalAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupGlobalAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupGlobalAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto globalPVS =
+ solver.lookupElementFor<GlobalAffinityPVS>(Position::forOperation(op));
+ if (!globalPVS || !globalPVS->isValidState() ||
+ globalPVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (globalPVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : globalPVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupExecutionAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupExecutionAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto opPVS =
+ solver.lookupElementFor<OpAffinityPVS>(Position::forOperation(op));
+ if (!opPVS || !opPVS->isValidState() || opPVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (opPVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : opPVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::inferExecutionAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryInferExecutionAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryInferExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ return tryLookupExecutionAffinity(op, affinities);
+ }
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr> opPVS;
+ const bool consumesAny = llvm::any_of(
+ op->getOperandTypes(), +[](Type type) {
+ return isa<IREE::Stream::AffinityTypeInterface>(type);
+ });
+ if (consumesAny) {
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ auto valuePVS = solver.lookupElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(operand), nullptr, DFX::Resolution::REQUIRED);
+ if (valuePVS && valuePVS->isValidState()) {
+ opPVS.unionAssumed(valuePVS->getState());
+ } else {
+ return false;
+ }
+ }
+ }
+ } else {
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ auto valuePVS = solver.lookupElementFor<ValueConsumerAffinityPVS>(
+ Position::forValue(result), nullptr, DFX::Resolution::REQUIRED);
+ if (valuePVS && valuePVS->isValidState()) {
+ opPVS.unionAssumed(valuePVS->getState());
+ } else {
+ return false;
+ }
+ }
+ }
+ }
+ if (!opPVS.isValidState() || opPVS.isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (opPVS.getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : opPVS.getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupResourceAffinity(Value value) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupResourceAffinity(value, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupResourceAffinity(
+ Value value, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto valuePVS = solver.lookupElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(value));
+ if (!valuePVS || !valuePVS->isValidState() || valuePVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (valuePVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(value.getParentBlock()->getParentOp(),
+ affinities);
+ }
+ for (auto affinityAttr : valuePVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+LogicalResult AffinityAnalysis::run() {
+ // Initialize globals so that we can assign them affinity.
+ explorer.forEachGlobal([&](const auto *globalInfo) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(
+ globalInfo->op.getGlobalType())) {
+ solver.getOrCreateElementFor<GlobalAffinityPVS>(
+ Position::forOperation(globalInfo->op));
+ }
+ });
+
+ // Initialize op execution affinities for any ops that use tracked types.
+ //
+ // TODO(benvanik): avoid doing this initialization for the entire module and
+ // instead rely on DFX to automatically populate the required abstract values.
+ // There's some missing logic in the element initialization, though, and by
+ // initializing all values we side-step that and work with test programs that
+ // may not have I/O edges that we could easily latch on to here.
+ explorer.forEachFunctionLikeOp([&](FunctionOpInterface funcOp) {
+ for (auto &block : funcOp.getBlocks()) {
+ for (auto arg : block.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(arg));
+ }
+ }
+ }
+ funcOp.walk([&](Operation *op) {
+ if (auto regionOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ for (auto &region : regionOp->getRegions()) {
+ for (auto arg : region.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(arg));
+ }
+ }
+ }
+ }
+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ solver.getOrCreateElementFor<OpAffinityPVS>(Position::forOperation(op));
+ }
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(result));
+ }
+ }
+ });
+ });
+
+ if (failed(solver.run())) {
+ return failure(); // did not converge
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "\n\n[Analysis] affinity analysis results for the whole module:\n";
+ solver.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ return success();
+}
+
+} // namespace mlir::iree_compiler::IREE::Stream
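A standalone sketch can make the contract of the two helpers above (`trySelectLeadAffinity` and `sortAffinities`) easier to see. It swaps in a hypothetical `Affinity` struct for `IREE::Stream::AffinityAttr` and assumes that two affinities are compatible only when they name the same device. The real `AffinityAttr::areCompatible` rules are not modeled here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for IREE::Stream::AffinityAttr: a device name plus a
// queue bitmask. This is NOT the real attribute, only a model for illustration.
struct Affinity {
  std::string device;
  uint64_t queueMask = ~0ull;
  // Stands in for AffinityAttr::print(): a stable textual form used as the key.
  std::string str() const { return device + "/" + std::to_string(queueMask); }
};

// Assumption: affinities are compatible when they name the same device.
bool areCompatible(const Affinity &a, const Affinity &b) {
  return a.device == b.device;
}

// Mirrors trySelectLeadAffinity: the first entry is returned only if every
// other entry is compatible with it; otherwise there is no single answer.
std::optional<Affinity>
trySelectLeadAffinity(const std::vector<Affinity> &affinities) {
  if (affinities.empty()) return std::nullopt;
  const Affinity &lead = affinities.front();
  for (size_t i = 1; i < affinities.size(); ++i) {
    if (!areCompatible(affinities[i], lead)) return std::nullopt;
  }
  return lead;
}

// Mirrors sortAffinities: order by printed form so the result does not depend
// on hash-set iteration order. Like the original, it re-prints on every
// comparison, which is simple but not cheap.
void sortAffinities(std::vector<Affinity> &affinities) {
  if (affinities.size() <= 1) return;
  std::stable_sort(affinities.begin(), affinities.end(),
                   [](const Affinity &lhs, const Affinity &rhs) {
                     return lhs.str() < rhs.str();
                   });
}
```

Sorting before selecting the lead is what makes "chosen in an unspecified way" at least deterministic from run to run. A cheaper variant would compute each sort key once rather than on every comparison.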
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h
new file mode 100644
index 0000000..3642a53
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h
@@ -0,0 +1,102 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
+#define IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+//===----------------------------------------------------------------------===//
+// Affinity analysis
+//===----------------------------------------------------------------------===//
+
+// Performs whole-program analysis of resource and tensor value affinity.
+// All `!stream.resource` and `tensor` SSA values will be analyzed and their
+// affinities where used will be available for querying via the lookup
+// functions.
+class AffinityAnalysis {
+public:
+ explicit AffinityAnalysis(Operation *rootOp);
+ ~AffinityAnalysis();
+
+ // Runs analysis and populates the affinity state queried by the lookups.
+ // May fail if analysis cannot be completed due to unsupported or unknown IR.
+ LogicalResult run();
+
+ // Returns the affinity of the global |op| based on its loads.
+ // The global storage should be allocated with this affinity and available for
+ // fast access from any compatible affinity.
+ //
+ // If an explicit affinity is provided via a stream.affinity attribute then
+ // that will be used in place of analysis. If there is more than one consumer
+ // (such as multiple loads) with differing affinities or analysis fails then
+ // no affinity is returned. If all affinities are compatible, one will be
+ // chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupGlobalAffinity(Operation *op);
+
+ // Populates all potential affinities of the global |op| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupGlobalAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of the executable |op| based on the op-specific rules
+ // as to whether its operands or results control placement. The operation
+ // should be scheduled to execute with this affinity and efficiently consume
+ // or produce resources that share a compatible affinity.
+ //
+ // If an explicit affinity is provided via stream.affinity attrs or the
+ // affinity op interface then that will be used in place of analysis. If there
+ // are multiple incompatible affinities or analysis fails, no affinity is
+ // returned. If all are compatible, one is chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupExecutionAffinity(Operation *op);
+
+ // Populates all potential execution affinities of |op| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of |op| as if it were executable even if it is not.
+ // This relies on analysis of operands and results having resolved and
+ // otherwise returns nullptr indicating the op has no assumed affinity.
+ IREE::Stream::AffinityAttr inferExecutionAffinity(Operation *op);
+
+ // Populates all inferred potential execution affinities of |op| in
+ // |affinities|. This relies on analysis of operands and results having
+ // resolved. If nothing was inferred, the nearest default affinity (if any) is
+ // used instead.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryInferExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of |value| based on its producers.
+ // The resource should be allocated with this affinity and be usable by any
+ // compatible affinity.
+ //
+ // If there is more than one producer of the value (such as multiple callers)
+ // with differing affinities or analysis fails then no affinity is returned.
+ // If all affinities are compatible, one will be chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupResourceAffinity(Value value);
+
+ // Populates all potential affinities of |value| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupResourceAffinity(
+ Value value, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+private:
+ Explorer explorer;
+ llvm::BumpPtrAllocator allocator;
+ DFX::Solver solver;
+};
+
+} // namespace mlir::iree_compiler::IREE::Stream
+
+#endif // IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
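The `lookup*`/`tryLookup*` pairs declared in this header share one shape. The `try` variant returns false when analysis failed. When analysis converged to an empty set, it falls back to a `stream.affinity.default` attribute on an enclosing op. Otherwise it returns every candidate, and the non-`try` variant reduces that list to a single affinity. The sketch below models that flow with hypothetical `Op` and `PotentialValues` types (not IREE APIs). It assumes the default search starts at the op itself and that only identical strings are compatible.

```cpp
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical op: a parent link plus string attributes keyed by name.
struct Op {
  const Op *parent = nullptr;
  std::map<std::string, std::string> attrs;
};

// Hypothetical model of a solved potential-values state (PVS).
struct PotentialValues {
  bool valid = true;
  bool undefContained = false;
  std::set<std::string> assumed; // std::set already iterates in sorted order
};

// Models tryLookupDefaultAffinity: the nearest "stream.affinity.default" wins.
bool tryLookupDefault(const Op *op, std::vector<std::string> &out) {
  for (; op; op = op->parent) {
    auto it = op->attrs.find("stream.affinity.default");
    if (it != op->attrs.end()) {
      out.push_back(it->second);
      return true;
    }
  }
  return false; // no default anywhere: treated like an analysis failure
}

// Models tryLookup*Affinity: false means the set of affinities is unknown.
bool tryLookup(const PotentialValues &pvs, const Op *op,
               std::vector<std::string> &out) {
  if (!pvs.valid || pvs.undefContained) return false; // analysis failed
  if (pvs.assumed.empty()) return tryLookupDefault(op, out);
  out.assign(pvs.assumed.begin(), pvs.assumed.end());
  return true;
}

// Models lookup*Affinity: one affinity, or nothing if candidates conflict.
std::optional<std::string> lookup(const PotentialValues &pvs, const Op *op) {
  std::vector<std::string> affinities;
  if (!tryLookup(pvs, op, affinities) || affinities.empty()) return std::nullopt;
  for (size_t i = 1; i < affinities.size(); ++i) {
    // Assumption: only identical affinities are compatible in this model.
    if (affinities[i] != affinities.front()) return std::nullopt;
  }
  return affinities.front();
}
```

Callers that need to handle every candidate (for example, to emit per-device variants) would use the `try` form; callers that need a single placement decision use the reducing form.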
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
index 4e1421b..3cbb5b5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
@@ -15,12 +15,14 @@
iree_compiler_cc_library(
name = "Analysis",
srcs = [
+ "Affinity.cpp",
"Partitioning.cpp",
"Partitioning/ReferencePartitioning.cpp",
"ResourceHazards.cpp",
"ResourceUsage.cpp",
],
hdrs = [
+ "Affinity.h",
"Partitioning.h",
"ResourceHazards.h",
"ResourceUsage.h",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
index f1b0fc8..c2dd74c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
@@ -14,10 +14,12 @@
NAME
Analysis
HDRS
+ "Affinity.h"
"Partitioning.h"
"ResourceHazards.h"
"ResourceUsage.h"
SRCS
+ "Affinity.cpp"
"Partitioning.cpp"
"Partitioning/ReferencePartitioning.cpp"
"ResourceHazards.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
index 5ed2ff8..93fcd37 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
@@ -58,9 +58,9 @@
for (auto *op : ops) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
if (!IREE::Stream::AffinityAttr::areCompatible(
- affinity, affinityOp.getAffinity())) {
+ affinity, affinityOp.getAffinityAttr())) {
return op->emitError("op affinity ")
- << affinityOp.getAffinity()
+ << affinityOp.getAffinityAttr()
<< " is not compatible with the partition affinity " << affinity;
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
index b86ff61..a4fff96 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
@@ -54,8 +54,8 @@
DenseSet<Operation *> clonedOps;
void insert(Operation *op) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- affinity = affinity ? affinity.joinAND(affinityOp.getAffinity())
- : affinityOp.getAffinity();
+ affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
+ : affinityOp.getAffinityAttr();
}
ops.insert(op);
}
@@ -109,7 +109,7 @@
IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- affinityAttr = affinityOp.getAffinity();
+ affinityAttr = affinityOp.getAffinityAttr();
}
LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index 7728ce8..4ff656c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -11,15 +11,12 @@
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
-#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
-#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -418,10 +415,14 @@
// TODO(benvanik): remove kFavorTransients.
bool isSourceExternal = !sourceUsage.isAssumed(NOT_EXTERNAL);
bool isTargetInternal = isAssumed(NOT_EXTERNAL);
- if (kFavorTransients && isSourceExternal && isTargetInternal) {
+ bool deviceChange =
+ op.getSourceAffinityAttr() != op.getResultAffinityAttr();
+ if ((kFavorTransients || deviceChange) && isSourceExternal &&
+ isTargetInternal) {
LLVM_DEBUG({
- llvm::dbgs() << "[ValueResourceUsage] skipping forward prop of "
- "external into internal: ";
+ llvm::dbgs()
+ << "[ValueResourceUsage] skipping forward prop of external "
+ "into internal due to kFavorTransients/device-change: ";
op.print(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << "\n";
});
@@ -531,7 +532,6 @@
*this,
Position::forValue(op.getBeforeBody()->getArgument(operandIdx)),
DFX::Resolution::REQUIRED);
-
getState() ^= beforeUsage.getState();
})
.Case([&](mlir::scf::ConditionOp op) {
@@ -564,29 +564,30 @@
Position::forValue(op->getParentOp()->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
- } else if (auto whileOp =
- dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
+ } else if (auto whileOp = dyn_cast<scf::WhileOp>(op->getParentOp())) {
auto value =
Position::forValue(whileOp.getBefore().getArgument(operandIdx));
auto &valueUsage = solver.getElementFor<ValueResourceUsage>(
*this, value, DFX::Resolution::REQUIRED);
getState() ^= valueUsage.getState();
- } else if (auto forOp =
- dyn_cast_or_null<scf::ForOp>(op->getParentOp())) {
+ auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(whileOp->getResult(operandIdx)),
+ DFX::Resolution::REQUIRED);
+ getState() ^= parentUsage.getState();
+ } else if (auto forOp = dyn_cast<scf::ForOp>(op->getParentOp())) {
auto value = Position::forValue(forOp.getRegionIterArg(operandIdx));
auto &valueUsage = solver.getElementFor<ValueResourceUsage>(
*this, value, DFX::Resolution::REQUIRED);
getState() ^= valueUsage.getState();
-
auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(forOp->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
} else {
- assert(false && "Unsupported test case");
+ assert(false && "unhandled scf yield parent");
}
})
- .Case([&](mlir::func::ReturnOp op) {
+ .Case([&](IREE::Util::ReturnOp op) {
auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
@@ -736,11 +737,14 @@
// TODO(benvanik): remove kFavorTransients.
bool isSourceInternal = isAssumed(NOT_EXTERNAL);
bool isTargetExternal = !resultUsage.isAssumed(NOT_EXTERNAL);
- if (kFavorTransients && isSourceInternal && isTargetExternal) {
+ bool deviceChange =
+ op.getSourceAffinityAttr() != op.getResultAffinityAttr();
+ if ((kFavorTransients || deviceChange) && isSourceInternal &&
+ isTargetExternal) {
LLVM_DEBUG({
llvm::dbgs()
<< "[ValueResourceUsage] skipping back prop of external into "
- "internal due to kFavorTransients: ";
+ "internal due to kFavorTransients/device-change: ";
op.print(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << "\n";
});
@@ -867,11 +871,8 @@
// });
// Initialize all SSA values we can do just with trivial search.
- explorer.walkValues([&](Value value) {
- if (llvm::isa<IREE::Stream::ResourceType>(value.getType())) {
- solver.getOrCreateElementFor<ValueResourceUsage>(
- Position::forValue(value));
- }
+ explorer.walkValuesOfType<IREE::Stream::ResourceType>([&](Value value) {
+ solver.getOrCreateElementFor<ValueResourceUsage>(Position::forValue(value));
return WalkResult::advance();
});
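The `deviceChange` condition added to both propagation directions in `ResourceUsage.cpp` can be read as a small predicate. External/internal usage is not propagated across a transfer when transients are favored or when the transfer moves between affinities, and only for the external/internal combination named in the debug messages. A minimal sketch, using a hypothetical `TransferOp` with string affinities in place of the real attribute comparison:

```cpp
#include <string>

// Hypothetical model of a stream.async.transfer: only its two affinities.
struct TransferOp {
  std::string sourceAffinity;
  std::string resultAffinity;
};

// Forward direction: external source usage would flow into an internal result.
bool shouldSkipForwardProp(const TransferOp &op, bool favorTransients,
                           bool isSourceExternal, bool isTargetInternal) {
  bool deviceChange = op.sourceAffinity != op.resultAffinity;
  return (favorTransients || deviceChange) && isSourceExternal &&
         isTargetInternal;
}

// Backward direction: external result usage would flow into an internal source.
bool shouldSkipBackProp(const TransferOp &op, bool favorTransients,
                        bool isSourceInternal, bool isTargetExternal) {
  bool deviceChange = op.sourceAffinity != op.resultAffinity;
  return (favorTransients || deviceChange) && isSourceInternal &&
         isTargetExternal;
}
```

The effect is that a cross-device transfer acts as a boundary for the external/internal lifetime, so each side can pick its own storage.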
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
index fbc0e51..10d1456 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@
],
deps = [
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
index 05bbb79..cc472aa 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 060aeb8..31d6151 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -25,51 +25,52 @@
// size of operands must be queried from the input resource.
static Value buildResultSizeOf(Location loc, Value tensorValue,
ValueRange dynamicDims,
+ IREE::Stream::AffinityAttr affinityAttr,
ConversionPatternRewriter &rewriter) {
// TODO(benvanik): see if we can stash this on the side to avoid expensive
// materialization of a bunch of redundant IR.
- return rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ return rewriter.create<IREE::Stream::TensorSizeOfOp>(
loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()),
- dynamicDims,
- IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp()));
+ dynamicDims, affinityAttr);
}
struct ConvertTensorConstantOp
- : public OpConversionPattern<IREE::Flow::TensorConstantOp> {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorConstantOp> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
// Capture the tensor constant strongly typed with constant lifetime.
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
constantOp.getLoc(), constantType,
convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr);
+ TypeAttr::get(constantOp.getType()), ValueRange{},
+ executionAffinityAttr);
// Transfer to unknown lifetime.
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr);
return success();
}
};
struct ConvertTensorDynamicConstantOp
- : public OpConversionPattern<IREE::Flow::TensorDynamicConstantOp> {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorDynamicConstantOp> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorDynamicConstantOp constantOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorDynamicConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto attrType = dyn_cast<RankedTensorType>(constantOp.getValue().getType());
if (!attrType)
return failure();
@@ -91,22 +92,21 @@
}
// Capture the tensor constant strongly typed with constant lifetime.
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
constantOp.getLoc(), constantType,
convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(resultType), dynamicDims, affinityAttr);
+ TypeAttr::get(resultType), dynamicDims, executionAffinityAttr);
// Transfer to unknown lifetime.
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr);
return success();
}
};
@@ -114,202 +114,300 @@
// Reshapes and bitcasts become clones here to preserve shape/element type
// information (which may become actual transfers depending on source/target
// shape) - they'll be elided if not needed.
+//
+// NOTE: we transfer to the target before cloning. This may not be optimal as
+// the clone might otherwise have been elided on the producer side, but we
+// leave that for future copy elision to determine.
template <typename CastOpTy>
-struct ConvertTensorCastLikeOp : public OpConversionPattern<CastOpTy> {
- using OpConversionPattern<CastOpTy>::OpConversionPattern;
+struct ConvertTensorCastLikeOp
+ : public AffinityAwareConversionPattern<CastOpTy> {
+ using AffinityAwareConversionPattern<
+ CastOpTy>::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(CastOpTy op, typename CastOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto resultAffinityAttr = this->lookupResultAffinity(op.getResult());
+ auto source = this->transferTensorOperand(op.getLoc(), op.getSource(),
+ adaptor.getSource(),
+ resultAffinityAttr, rewriter);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ resultAffinityAttr, rewriter);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, op.getResult().getType(),
- adaptor.getResultDims(), resultSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getResultDims(), resultSize, resultAffinityAttr);
return success();
}
};
struct ConvertTensorAllocaOp
- : public OpConversionPattern<IREE::Flow::TensorAllocaOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorAllocaOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncAllocaOp>(
- op, unknownType, resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ op, unknownType, resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorEmptyOp
- : public OpConversionPattern<IREE::Flow::TensorEmptyOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorEmptyOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorEmptyOp>(
op, unknownType, op.getResult().getType(), adaptor.getResultDims(),
- resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorSplatOp
- : public OpConversionPattern<IREE::Flow::TensorSplatOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorSplatOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorSplatOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ executionAffinityAttr, rewriter);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>(
op, unknownType, adaptor.getValue(), op.getResult().getType(),
- adaptor.getResultDims(), resultSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getResultDims(), resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorCloneOp
- : public OpConversionPattern<IREE::Flow::TensorCloneOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorCloneOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::Flow::TensorCloneOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorCloneOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto operand = transferTensorOperand(op.getLoc(), op.getOperand(),
+ adaptor.getOperand(),
+ executionAffinityAttr, rewriter);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- auto operand =
- consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
op, unknownType, operand.resource, op.getOperand().getType(),
op.getArgumentDims(), operand.resourceSize, op.getResult().getType(),
- adaptor.getArgumentDims(), operand.resourceSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getArgumentDims(), operand.resourceSize, executionAffinityAttr);
+ return success();
+ }
+};
+
+struct ConvertTensorTransferOp
+ : public AffinityOpConversionPattern<IREE::Flow::TensorTransferOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorTransferOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!executionAffinityAttr) {
+ return rewriter.notifyMatchFailure(op, "invalid stream affinity attr");
+ }
+ auto operand = resolveTensorOperand(op.getLoc(), op.getOperand(),
+ adaptor.getOperand(), rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ op, unknownType, operand.resource, operand.resourceSize,
+ operand.resourceSize,
+ /*source_affinity=*/operand.affinity,
+ /*result_affinity=*/executionAffinityAttr);
return success();
}
};
struct ConvertTensorSliceOp
- : public OpConversionPattern<IREE::Flow::TensorSliceOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ : public AffinityOpConversionPattern<IREE::Flow::TensorSliceOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
- auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
- op.getResultDims(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(),
+ executionAffinityAttr, rewriter);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::TensorSliceOp>(
op, unknownType, source.resource, op.getSource().getType(),
op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(),
adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(),
- resultSize, IREE::Stream::AffinityAttr::lookup(op));
+ resultSize, executionAffinityAttr);
return success();
}
};
struct ConvertTensorUpdateOp
- : public OpConversionPattern<IREE::Flow::TensorUpdateOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto update =
- consumeTensorOperand(op.getLoc(), adaptor.getUpdate(), rewriter);
+ : public AffinityOpConversionPattern<IREE::Flow::TensorUpdateOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto target =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
+ auto update =
+ transferTensorOperand(op.getLoc(), op.getUpdate(), adaptor.getUpdate(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorUpdateOp>(
op, target.resource.getType(), target.resource,
op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize,
adaptor.getStartIndices(), update.resource, op.getUpdate().getType(),
- op.getUpdateDims(), update.resourceSize,
- IREE::Stream::AffinityAttr::lookup(op));
+ op.getUpdateDims(), update.resourceSize, executionAffinityAttr);
return success();
}
};
+static bool isScalarTensor(RankedTensorType type) {
+ if (type.getRank() == 0)
+ return true; // tensor<i32>
+ if (!type.hasStaticShape())
+ return false; // tensor<...?...xi32>
+ int64_t elementCount = 1;
+ for (int64_t dim : type.getShape())
+ elementCount *= dim;
+ return elementCount == 1; // tensor<1xi32> or tensor<1x1x1xi32>
+}
+
struct ConvertTensorLoadOp
- : public OpConversionPattern<IREE::Flow::TensorLoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorLoadOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resultType = getTypeConverter()->convertType(op.getResult().getType());
- auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ auto source = resolveTensorOperand(op.getLoc(), op.getSource(),
+ adaptor.getSource(), rewriter);
+ // If the source is not a staging resource then we need to transfer it to
+ // a staging resource. We slice out just what is being loaded so that we
+ // don't transfer the entire tensor. If loading multiple values from the
+ // same tensor we'll either want to have batched that before this point
+ // by loading an entire buffer or after by coalescing the slices.
+ //
+ // If already a staging resource then we can fast-path load the value.
auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::Staging);
- auto loadSource = source.resource;
- if (source.resource.getType() != stagingType) {
- loadSource = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
- op.getLoc(), stagingType, source.resource, source.resourceSize,
- source.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/nullptr);
+ auto resultType = getTypeConverter()->convertType(op.getResult().getType());
+ if (source.resource.getType() == stagingType) {
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
+ op, resultType, source.resource, op.getSource().getType(),
+ adaptor.getSourceDims(), source.resourceSize, adaptor.getIndices());
+ return success();
}
+ // Scalar tensors get transferred without slicing.
+ auto sourceEncoding = op.getSource().getType();
+ if (isScalarTensor(sourceEncoding)) {
+ auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), stagingType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
+ op, resultType, transferOp.getResult(), sourceEncoding,
+ adaptor.getSourceDims(), transferOp.getResultSize(),
+ adaptor.getIndices());
+ return success();
+ }
+
+ // Slice out the individual element value.
+ IndexSet indexSet(op.getLoc(), rewriter);
+ indexSet.populate(adaptor.getIndices());
+ SmallVector<Value> sliceIndices;
+ SmallVector<Value> sliceLengths;
+ SmallVector<Value> loadIndices;
+ SmallVector<int64_t> resultDims;
+ for (auto index : adaptor.getIndices()) {
+ // TODO(benvanik): support larger buffer slices.
+ sliceIndices.push_back(index);
+ sliceLengths.push_back(indexSet.get(1));
+ loadIndices.push_back(indexSet.get(0));
+ resultDims.push_back(1);
+ }
+ auto resultEncoding =
+ RankedTensorType::get(resultDims, sourceEncoding.getElementType(),
+ sourceEncoding.getEncoding());
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
+ op.getLoc(), resultEncoding, ValueRange{}, source.affinity);
+ auto sliceOp = rewriter.create<IREE::Stream::TensorSliceOp>(
+ op.getLoc(), source.resource.getType(), source.resource, sourceEncoding,
+ adaptor.getSourceDims(), source.resourceSize, sliceIndices,
+ sliceLengths, resultEncoding, ValueRange{}, resultSize,
+ source.affinity);
+ auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), stagingType, sliceOp.getResult(), sliceOp.getResultSize(),
+ sliceOp.getResultSize(),
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
- op, resultType, loadSource, op.getSource().getType(),
- op.getSourceDims(), source.resourceSize, adaptor.getIndices());
+ op, resultType, transferOp.getResult(), sliceOp.getResultEncoding(),
+ sliceOp.getResultEncodingDims(), transferOp.getResultSize(),
+ loadIndices);
return success();
}
};
struct ConvertTensorStoreOp
- : public OpConversionPattern<IREE::Flow::TensorStoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorStoreOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto target =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ auto target = resolveTensorOperand(op.getLoc(), op.getTarget(),
+ adaptor.getTarget(), rewriter);
+ // If the target is a staging resource then we can directly store into it
+ // with a fast-path. Otherwise we need to stage an upload.
auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::Staging);
- auto storeTarget = target.resource;
- if (target.resource.getType() != stagingType) {
- storeTarget = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
- op.getLoc(), stagingType, storeTarget, target.resourceSize,
- target.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/nullptr);
+ if (target.resource.getType() == stagingType) {
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorStoreOp>(
+ op, target.resource.getType(), target.resource,
+ op.getTarget().getType(), adaptor.getTargetDims(),
+ target.resourceSize, adaptor.getIndices(), adaptor.getValue());
+ return success();
}
- auto newOp = rewriter.create<IREE::Stream::TensorStoreOp>(
- op.getLoc(), storeTarget.getType(), storeTarget,
- op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize,
- adaptor.getIndices(), adaptor.getValue());
-
- Value newResult = newOp.getResult();
- if (target.resource.getType() != stagingType) {
- newResult = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
- op.getLoc(), target.resource.getType(), newResult,
- target.resourceSize, target.resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op));
- }
- rewriter.replaceOp(op, {newResult});
-
+ // Use fill to store the value.
+ // TODO(benvanik): support larger buffer slices (stage + update).
+ IndexSet indexSet(op.getLoc(), rewriter);
+ indexSet.populate(adaptor.getIndices());
+ SmallVector<Value> lengths(adaptor.getIndices().size(), indexSet.get(1));
+ auto targetEncoding = op.getTarget().getType();
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorFillOp>(
+ op, target.resource, targetEncoding, adaptor.getTargetDims(),
+ target.resourceSize, adaptor.getIndices(), lengths, adaptor.getValue(),
+ target.affinity);
return success();
}
};
struct ConvertTensorTraceOp
- : public OpConversionPattern<IREE::Flow::TensorTraceOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::Flow::TensorTraceOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Flow::TensorTraceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -318,8 +416,8 @@
SmallVector<Attribute> resourceEncodings;
for (auto [tensorOperand, resourceOperand] :
llvm::zip_equal(op.getValues(), adaptor.getValues())) {
- auto source =
- consumeTensorOperand(op.getLoc(), resourceOperand, rewriter);
+ auto source = resolveTensorOperand(op.getLoc(), tensorOperand,
+ resourceOperand, rewriter);
auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::Staging);
auto traceSource = source.resource;
@@ -327,13 +425,14 @@
traceSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), stagingType, source.resource, source.resourceSize,
source.resourceSize,
- /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op),
- /*result_affinity=*/nullptr);
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/source.affinity);
}
resources.push_back(traceSource);
resourceSizes.push_back(source.resourceSize);
resourceEncodings.push_back(TypeAttr::get(tensorOperand.getType()));
}
+
rewriter.replaceOpWithNewOp<IREE::Stream::TensorTraceOp>(
op, adaptor.getKey(), resources, resourceSizes,
rewriter.getArrayAttr(resourceEncodings), adaptor.getValueDims());
@@ -342,16 +441,18 @@
};
struct ConvertChannelDefaultOp
- : public OpConversionPattern<IREE::Flow::ChannelDefaultOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::Flow::ChannelDefaultOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::Stream::ChannelCreateOp>(
- op, /*id=*/Value{},
+ op,
+ /*id=*/Value{},
/*group=*/adaptor.getGroupAttr(),
/*rank=*/Value{},
- /*count=*/Value{}, IREE::Stream::AffinityAttr::lookup(op));
+ /*count=*/Value{}, executionAffinityAttr);
return success();
}
};
@@ -393,164 +494,190 @@
};
struct ConvertAllGatherOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllGatherOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getSource().getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllGather,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllGatherOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllGather,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertAllReduceOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllReduceOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllReduce,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllReduceOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllReduce,
static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertAllToAllOp
- : public OpConversionPattern<IREE::Flow::CollectiveAllToAllOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getSource().getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::AllToAll,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllToAllOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::AllToAll,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
-struct ConvertReduceScatterOp
- : public OpConversionPattern<IREE::Flow::CollectiveReduceScatterOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::ReduceScatter,
+struct ConvertReduceScatterOp : public AffinityOpConversionPattern<
+ IREE::Flow::CollectiveReduceScatterOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::ReduceScatter,
static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()),
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/mlir::Value(), executionAffinityAttr);
return success();
}
};
struct ConvertCollectiveSendRecvOp
- : public OpConversionPattern<IREE::Flow::CollectiveSendRecvOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto shape = llvm::cast<ShapedType>(op.getType());
- auto collectiveAttr = IREE::Stream::CollectiveAttr::get(
- op.getContext(), IREE::Stream::CollectiveKind::SendRecv,
+ : public AffinityOpConversionPattern<IREE::Flow::CollectiveSendRecvOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>(
+ IREE::Stream::CollectiveKind::SendRecv,
/*reduction=*/std::nullopt,
static_cast<IREE::Stream::CollectiveElementType>(op.getElementType()));
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto elementCount = rewriter.create<arith::ConstantIndexOp>(
- op.getLoc(), shape.getNumElements());
+ op.getLoc(), op.getType().getNumElements());
auto newTargetCast =
- consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(),
+ executionAffinityAttr, rewriter);
auto newSourceCast =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Pack send, recv into param. The values are checked to be within the
// 16-bit range during lowering to Flow dialect.
@@ -567,25 +694,31 @@
auto param = rewriter.create<arith::OrIOp>(op.getLoc(), hi, lo);
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
- op, collectiveAttr, newTargetCast.resource,
+ op, collectiveAttr,
+ /*target=*/newTargetCast.resource,
/*target_size=*/newTargetCast.resourceSize,
/*target_offset=*/zeroOffset,
/*target_end=*/newTargetCast.resourceSize,
- /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
+ /*target_length=*/newTargetCast.resourceSize,
+ /*source=*/newSourceCast.resource,
/*source_size=*/newSourceCast.resourceSize,
- /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
- /*source_length=*/newSourceCast.resourceSize, elementCount,
- adaptor.getChannel(),
- /*param=*/param, IREE::Stream::AffinityAttr::lookup(op));
+ /*source_offset=*/zeroOffset,
+ /*source_end=*/newSourceCast.resourceSize,
+ /*source_length=*/newSourceCast.resourceSize,
+ /*element_count=*/elementCount,
+ /*channel=*/adaptor.getChannel(),
+ /*param=*/param, executionAffinityAttr);
return success();
}
};
-struct ConvertDispatchOp : public OpConversionPattern<IREE::Flow::DispatchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+struct ConvertDispatchOp
+ : public AffinityOpConversionPattern<IREE::Flow::DispatchOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::DispatchOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -600,7 +733,8 @@
llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
if (llvm::isa<ShapedType>(oldOperand.getType())) {
auto newOperandCast =
- consumeTensorOperand(op.getLoc(), newOperand, rewriter);
+ transferTensorOperand(op.getLoc(), oldOperand, newOperand,
+ executionAffinityAttr, rewriter);
newOperand = newOperandCast.resource;
dispatchOperandSizes.push_back(newOperandCast.resourceSize);
operandSizes.push_back(newOperandCast.resourceSize);
@@ -632,8 +766,9 @@
} else {
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
- resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, rewriter));
+ resultSizes.push_back(
+ buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
+ executionAffinityAttr, rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -642,7 +777,7 @@
op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(),
dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets,
dispatchOperandEnds, dispatchOperandLengths, resultSizes,
- adaptor.getTiedOperandsAttr(), IREE::Stream::AffinityAttr::lookup(op));
+ adaptor.getTiedOperandsAttr(), executionAffinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
@@ -658,8 +793,8 @@
// Tensors become resources without sizes. The default type converter
// adds the size so we bypass that here. We may want to allow the user
// to override the lifetime with attributes, too.
- return IREE::Stream::ResourceType::get(type.getContext(),
- IREE::Stream::Lifetime::Unknown);
+ return rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Unknown);
}
return getTypeConverter()->convertType(type);
};
@@ -683,11 +818,12 @@
}
};
-struct ConvertCallOp : public OpConversionPattern<IREE::Flow::CallOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+struct ConvertCallOp : public AffinityOpConversionPattern<IREE::Flow::CallOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::Flow::CallOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -702,7 +838,8 @@
llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
if (llvm::isa<ShapedType>(oldOperand.getType())) {
auto newOperandCast =
- consumeTensorOperand(op.getLoc(), newOperand, rewriter);
+ transferTensorOperand(op.getLoc(), oldOperand, newOperand,
+ executionAffinityAttr, rewriter);
newOperand = newOperandCast.resource;
callOperandSizes.push_back(newOperandCast.resourceSize);
operandSizes.push_back(newOperandCast.resourceSize);
@@ -734,8 +871,9 @@
} else {
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
- resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
- resultDynamicDims, rewriter));
+ resultSizes.push_back(
+ buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
+ executionAffinityAttr, rewriter));
resultTypes.push_back(unknownType);
}
}
@@ -744,7 +882,7 @@
op, resultTypes, adaptor.getCalleeAttr(), callOperands,
callOperandSizes, callOperandOffsets, callOperandEnds,
callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(),
- IREE::Stream::AffinityAttr::lookup(op));
+ executionAffinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
@@ -961,26 +1099,30 @@
} // namespace
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
patterns
.insert<ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
- ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp,
- ConvertTensorLoadOp, ConvertTensorStoreOp, ConvertTensorTraceOp>(
- typeConverter, context);
- patterns.insert<ConvertChannelDefaultOp, ConvertChannelSplitOp,
- ConvertChannelRankOp, ConvertChannelCountOp>(typeConverter,
- context);
+ ConvertTensorCloneOp, ConvertTensorTransferOp,
+ ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
+ ConvertTensorStoreOp, ConvertTensorTraceOp>(
+ typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertChannelDefaultOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertChannelSplitOp, ConvertChannelRankOp,
+ ConvertChannelCountOp>(typeConverter, context);
patterns
.insert<ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
- ConvertAllToAllOp, ConvertCollectiveSendRecvOp>(typeConverter,
- context);
- patterns.insert<ConvertDispatchOp>(typeConverter, context);
- patterns.insert<ConvertFuncOp, ConvertCallOp>(typeConverter, context);
+ ConvertAllToAllOp, ConvertCollectiveSendRecvOp>(
+ typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertDispatchOp>(typeConverter, context, affinityAnalysis);
+ patterns.insert<ConvertFuncOp>(typeConverter, context);
+ patterns.insert<ConvertCallOp>(typeConverter, context, affinityAnalysis);
patterns.insert<ConvertExecutableOp>(typeConverter, context);
patterns.insert<
ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp,
@@ -993,10 +1135,11 @@
patterns.insert<ConvertReturnOp>(typeConverter, context);
}
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
// Disallow all flow ops besides the ones we pass through (today).
// We don't have a stream-equivalent of several of the dispatch-level flow
// ops as the codegen backends directly touch them and so long as we have both
@@ -1006,7 +1149,8 @@
conversionTarget.addLegalOp<IREE::Stream::ExecutableOp>();
conversionTarget.markOpRecursivelyLegal<IREE::Stream::ExecutableOp>();
- populateFlowToStreamConversionPatterns(context, typeConverter, patterns);
+ populateFlowToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
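The converted patterns in this file no longer call `IREE::Stream::AffinityAttr::lookup(op)` themselves. They derive from `AffinityOpConversionPattern` and override `matchAndRewriteOnAffinity`, which receives the execution affinity as an extra argument. The base class is not part of this diff, so the following is only a minimal, MLIR-free sketch of how such a template-method split could work. It assumes the base `matchAndRewrite` queries an analysis for the op's execution affinity and forwards the result. All types here (`FakeOp`, `FakeAnalysis`, `AffinityPatternSketch`, `RecordingPattern`) are hypothetical stand-ins.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-ins for an op and an affinity attribute.
struct FakeOp {
  std::string name;
};
using Affinity = std::optional<std::string>;  // nullopt == unknown affinity

// Hypothetical analysis: maps op names to the device they execute on.
struct FakeAnalysis {
  std::map<std::string, std::string> executionAffinities;
  Affinity lookupExecutionAffinity(const FakeOp &op) const {
    auto it = executionAffinities.find(op.name);
    if (it == executionAffinities.end()) return std::nullopt;
    return it->second;
  }
};

// Base class: holds the analysis, resolves the affinity once, and forwards
// it to the subclass hook (template-method pattern).
class AffinityPatternSketch {
 public:
  explicit AffinityPatternSketch(const FakeAnalysis *analysis)
      : analysis(analysis) {}
  virtual ~AffinityPatternSketch() = default;

  bool matchAndRewrite(const FakeOp &op) const {
    return matchAndRewriteOnAffinity(op,
                                     analysis->lookupExecutionAffinity(op));
  }

 protected:
  virtual bool matchAndRewriteOnAffinity(const FakeOp &op,
                                         Affinity executionAffinity) const = 0;

 private:
  const FakeAnalysis *analysis;
};

// A subclass only sees the already-resolved affinity.
class RecordingPattern : public AffinityPatternSketch {
 public:
  using AffinityPatternSketch::AffinityPatternSketch;
  mutable Affinity lastAffinity;

 protected:
  bool matchAndRewriteOnAffinity(const FakeOp &,
                                 Affinity executionAffinity) const override {
    lastAffinity = executionAffinity;
    return true;
  }
};
```

Centralizing the lookup in the base class means each pattern receives the affinity through one code path instead of re-deriving it from op attributes.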
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
index ad2c95a..0379be2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform flow->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateFlowToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
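The updated header forward-declares `IREE::Stream::AffinityAnalysis` instead of including its definition. This works because the populate functions only take it by pointer. The self-contained sketch below shows the same technique with hypothetical names (`demo::stream::Analysis`, `demo::populatePatterns`). In a real project the declaration and the definition would live in separate files.

```cpp
#include <cassert>
#include <vector>

// "Header" part: a forward declaration is enough to name a pointer type.
namespace demo::stream {
class Analysis;
}  // namespace demo::stream

namespace demo {
// Only a pointer is passed, so the full class definition is not needed here.
void populatePatterns(stream::Analysis *analysis, std::vector<int> &patterns);
}  // namespace demo

// "Implementation" part: the full definition is only required where members
// are actually used.
namespace demo::stream {
class Analysis {
 public:
  int patternCount() const { return 3; }
};
}  // namespace demo::stream

namespace demo {
void populatePatterns(stream::Analysis *analysis, std::vector<int> &patterns) {
  for (int i = 0; i < analysis->patternCount(); ++i) patterns.push_back(i);
}
}  // namespace demo
```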
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
index da75704..063389f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
@@ -46,18 +46,23 @@
// -----
+util.global private @device_a : !hal.device
+util.global private @device_b : !hal.device
+
// CHECK-LABEL: @dispatchAffinity
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) {
- // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[0]>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
- // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+ // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
+ // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
+ // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
%0 = flow.dispatch @ex::@entry0(%input) {
- stream.affinity = #hal.affinity.queue<[0]>
+ stream.affinity = #hal.device.affinity<@device_a>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
- // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[1]>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
- // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+ // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
+ // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
+ // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
%1 = flow.dispatch @ex::@entry1(%input) {
- stream.affinity = #hal.affinity.queue<[1]>
+ stream.affinity = #hal.device.affinity<@device_b>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1}
// return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]]
util.return %0, %1 : tensor<?x?x1024xf32>, tensor<?x?x1024xf32>
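In the updated `@dispatchAffinity` test, `%input` has no source affinity (the expected `stream.async.transfer` ops carry only a `to(...)` clause). As a result, each dispatch pinned to a different `#hal.device.affinity` receives its own transfer of the input. The rule appears to be "transfer whenever the operand's affinity differs from the required execution affinity", as in `transferTensorOperand` in `PatternUtils.cpp`. The following is a minimal C++ model of that rule, assuming hypothetical types that record emitted ops as strings.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

using Affinity = std::optional<std::string>;  // nullopt == unknown

// Hypothetical model of a converted operand: a resource name plus affinity.
struct Operand {
  std::string resource;
  Affinity affinity;
};

// Records a transfer into |ops| only when the operand is not already on the
// required affinity, then returns the operand as the consumer will see it.
Operand transferIfNeeded(const Operand &operand, const std::string &required,
                         std::vector<std::string> &ops) {
  if (operand.affinity == required) return operand;
  std::string result = operand.resource + "_" + required;
  ops.push_back(result + " = transfer " + operand.resource + " to " + required);
  return {result, required};
}
```

Applying it to the same input for two dispatches on different devices yields two independent transfers. An operand that is already on the required device passes through unchanged.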
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index 9a1272f..ee68211 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -136,6 +136,19 @@
// -----
+util.global private @device : !hal.device
+
+// CHECK-LABEL: @tensorTransfer
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index)
+util.func public @tensorTransfer(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
+ // CHECK: %[[TRANSFER:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device>) !stream.resource<*>{%[[INPUT_SIZE]]}
+ %transfer = flow.tensor.transfer %input : tensor<?x128xi8>{%dim0} to #hal.device.affinity<@device>
+ // CHECK: util.return %[[TRANSFER]], %[[INPUT_SIZE]]
+ util.return %transfer : tensor<?x128xi8>
+}
+
+// -----
+
// CHECK-LABEL: @tensorSlice
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
util.func public @tensorSlice(%input : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> {
@@ -171,14 +184,28 @@
util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[T0:.+]] = stream.async.transfer
- // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]}
- // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource<staging>{%[[SOURCE_SIZE]]}
- // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c1] : tensor<2x3xi32> in !stream.resource<staging>{%[[SOURCE_SIZE]]} -> i32
- %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32> attributes {
- stream.affinity = #hal.affinity.queue<[0, 1]>
- }
- // CHECK: util.return %[[T1]]
+ // CHECK: %[[SLICE_SIZE:.+]] = stream.tensor.sizeof tensor<1x1xi32>
+ // CHECK: %[[SLICE:.+]] = stream.tensor.slice %[[SOURCE]][%c0, %c1 for %c1, %c1] : tensor<2x3xi32> in !stream.resource<*>{%[[SOURCE_SIZE]]} -> tensor<1x1xi32> in !stream.resource<*>{%[[SLICE_SIZE]]}
+ // CHECK: %[[STAGING:.+]] = stream.async.transfer
+ // CHECK-SAME: %[[SLICE]] : !stream.resource<*>{%[[SLICE_SIZE]]}
+ // CHECK-SAME: !stream.resource<staging>{%[[SLICE_SIZE]]}
+ // CHECK: %[[VALUE:.+]] = stream.tensor.load %[[STAGING]][%c0, %c0] : tensor<1x1xi32> in !stream.resource<staging>{%[[SLICE_SIZE]]} -> i32
+ %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32>
+ // CHECK: util.return %[[VALUE]]
+ util.return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @tensorLoadScalar
+// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index)
+util.func public @tensorLoadScalar(%source : tensor<i32>) -> i32 {
+ // CHECK: %[[STAGING:.+]] = stream.async.transfer
+ // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]}
+ // CHECK-SAME: !stream.resource<staging>{%[[SOURCE_SIZE]]}
+ // CHECK: %[[VALUE:.+]] = stream.tensor.load %[[STAGING]] : tensor<i32> in !stream.resource<staging>{%[[SOURCE_SIZE]]} -> i32
+ %0 = flow.tensor.load %source : tensor<i32>
+ // CHECK: util.return %[[VALUE]]
util.return %0 : i32
}
@@ -189,22 +216,29 @@
util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %c9 = arith.constant 9 : i32
- // CHECK: %[[T0:.+]] = stream.async.transfer %[[TARGET]] : !stream.resource<*>{%[[TARGET_SIZE]]}
- // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource<staging>{%[[TARGET_SIZE]]}
- // CHECK: %[[T1:.+]] = stream.tensor.store %c9_i32, %[[T0]][%c0, %c1] :
- // CHECK-SAME: i32 -> tensor<2x3xi32> in %[[T0]] as !stream.resource<staging>{%[[TARGET_SIZE]]}
- // CHECK: %[[T2:.+]] = stream.async.transfer %[[T1]] : !stream.resource<staging>{%[[TARGET_SIZE]]} ->
- // CHECK-SAME: to(#hal.affinity.queue<[0, 1]>) !stream.resource<*>{%[[TARGET_SIZE]]}
- %0 = flow.tensor.store %c9, %target[%c0, %c1] : tensor<2x3xi32> attributes {
- stream.affinity = #hal.affinity.queue<[0, 1]>
- }
- // CHECK: util.return %[[T2]]
+ // CHECK: %[[VALUE:.+]] = arith.constant 9
+ %value = arith.constant 9 : i32
+ // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]][%c0, %c1 for %c1, %c1] : i32 -> tensor<2x3xi32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]}
+ %0 = flow.tensor.store %value, %target[%c0, %c1] : tensor<2x3xi32>
+ // CHECK: util.return %[[FILL]]
util.return %0 : tensor<2x3xi32>
}
// -----
+// CHECK-LABEL: @tensorStoreScalar
+// CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index)
+util.func public @tensorStoreScalar(%target : tensor<i32>) -> tensor<i32> {
+ // CHECK: %[[VALUE:.+]] = arith.constant 9
+ %value = arith.constant 9 : i32
+ // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]] : i32 -> tensor<i32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]}
+ %0 = flow.tensor.store %value, %target : tensor<i32>
+ // CHECK: util.return %[[FILL]]
+ util.return %0 : tensor<i32>
+}
+
+// -----
+
// CHECK-LABEL: @tensorTrace
// CHECK-SAME: (%[[TENSOR0:.+]]: !stream.resource<*>, %[[TENSOR0_SIZE:.+]]: index, %[[TENSOR1:.+]]: !stream.resource<*>, %[[TENSOR1_SIZE:.+]]: index, %[[TENSOR1_DIM0:.+]]: index, %[[TENSOR1_DIM2:.+]]: index)
util.func public @tensorTrace(%tensor0: tensor<5xf32>, %tensor1: tensor<?x3x?xi32>, %tensor1_dim0: index, %tensor1_dim2: index) {
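The rewritten `@tensorLoad` expectations first slice a `tensor<1x1xi32>` out of the source with `stream.tensor.slice`, and only then issue the `stream.async.transfer` to `staging`. Previously the whole `tensor<2x3xi32>` was transferred. The arithmetic below assumes a dense, row-major layout with no encoding padding (actual sizes come from `stream.tensor.sizeof` and may differ) and shows how much data the staging transfer would move in each case.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Byte size of a dense, unpadded tensor: product of dims times element size.
// This ignores any target-specific encoding stream.tensor.sizeof may apply.
int64_t denseByteSize(const std::vector<int64_t> &shape, int64_t elementBytes) {
  int64_t elements = 1;
  for (int64_t dim : shape) elements *= dim;
  return elements * elementBytes;
}

// Bytes that would cross into staging memory for a single-element load.
int64_t stagingBytesWholeTensor() {
  return denseByteSize({2, 3}, sizeof(int32_t));
}
int64_t stagingBytesSlicedElement() {
  return denseByteSize({1, 1}, sizeof(int32_t));
}
```

Under these assumptions the sliced path moves 4 bytes instead of 24. The rank-0 `@tensorLoadScalar` case has nothing to slice, so it transfers the source directly.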
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
index 2acd3ba..76eef8b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -21,11 +21,12 @@
// %1 = stream.tensor.import %0 : !hal.buffer_view ->
// tensor<4xf32> in !stream.resource<*>
struct ConvertTensorImportOp
- : public OpConversionPattern<IREE::HAL::TensorImportOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorImportOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSource().getType();
auto targetType = op.getTargetEncoding();
if (!llvm::isa<IREE::HAL::BufferType>(sourceType) &&
@@ -50,23 +51,22 @@
}
// Import (buffer view to stream resource).
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::External);
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(),
- affinityAttr);
+ executionAffinityAttr);
Value resource = rewriter.create<IREE::Stream::TensorImportOp>(
op.getLoc(), resultType, adaptor.getSource(), TypeAttr::get(targetType),
- adaptor.getTargetDims(), resultSize, affinityAttr);
+ adaptor.getTargetDims(), resultSize, executionAffinityAttr);
// Await the fence, if needed. When not specified the resource is assumed to
// be immediately available.
if (auto waitFence = op.getWaitFence()) {
Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>(
op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(),
- ValueRange{waitFence}, affinityAttr);
+ ValueRange{waitFence}, executionAffinityAttr);
resource = rewriter
.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), ValueRange{resource},
@@ -76,8 +76,9 @@
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
- op, unknownType, resource, resultSize, resultSize, affinityAttr,
- affinityAttr);
+ op, unknownType, resource, resultSize, resultSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*target_affinity=*/executionAffinityAttr);
return success();
}
@@ -121,11 +122,12 @@
// %1 = stream.tensor.export %0 : tensor<4xf32> in !stream.resource<*> ->
// !hal.buffer_view
struct ConvertTensorExportOp
- : public OpConversionPattern<IREE::HAL::TensorExportOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorExportOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSourceEncoding();
auto targetType = op.getTarget().getType();
if (!llvm::isa<IREE::HAL::BufferType>(targetType) &&
@@ -133,9 +135,9 @@
return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
}
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Exporting a produced value - transfer our source value to an externally
// usable resource and directly export it. This will cause an allocation.
@@ -145,13 +147,14 @@
if (source.resource.getType() != externalType) {
exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), externalType, source.resource, source.resourceSize,
- source.resourceSize, affinityAttr, affinityAttr);
+ source.resourceSize, /*source_affinity=*/source.affinity,
+ /*target_affinity=*/executionAffinityAttr);
}
// Export (stream resource to buffer view).
rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
op, targetType, exportSource, TypeAttr::get(sourceType),
- adaptor.getSourceDims(), source.resourceSize, affinityAttr);
+ adaptor.getSourceDims(), source.resourceSize, executionAffinityAttr);
return success();
}
};
@@ -168,25 +171,23 @@
// %update = stream.async.update %0, %storage[...]
// %2 = stream.async.slice %update[...]
struct ConvertTensorAliasOp
- : public OpConversionPattern<IREE::HAL::TensorAliasOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ : public AffinityOpConversionPattern<IREE::HAL::TensorAliasOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ IREE::HAL::TensorAliasOp op, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
auto sourceType = op.getSource().getType();
auto source =
- consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
-
- // All operations (if any) will happen on the device specified by the alias
- // as that indicates the affinity of the storage.
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
+ transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(),
+ executionAffinityAttr, rewriter);
// Query the target storage buffer length; we will only populate up to
// what is required for the output.
- auto storageSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value storageSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(),
- affinityAttr);
+ executionAffinityAttr);
// Import the target storage as a resource that we can use as an update
// target. We overwrite the contents and just cast the storage to the
@@ -196,7 +197,7 @@
auto importOp = rewriter.create<IREE::Stream::TensorImportOp>(
op.getLoc(), externalType, adaptor.getStorage(),
TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize,
- affinityAttr);
+ executionAffinityAttr);
// Await the fence, if needed. When not specified the storage is assumed to
// be immediately available.
@@ -204,7 +205,7 @@
if (auto waitFence = op.getWaitFence()) {
Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>(
op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(),
- ValueRange{waitFence}, affinityAttr);
+ ValueRange{waitFence}, executionAffinityAttr);
storage = rewriter
.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), ValueRange{storage},
@@ -217,7 +218,7 @@
auto updateOp = rewriter.create<IREE::Stream::AsyncUpdateOp>(
op.getLoc(), externalType, storage, storageSize, zeroOffset,
source.resourceSize, source.resource, source.resourceSize,
- affinityAttr);
+ executionAffinityAttr);
// Slice out the value from the updated tensor.
// This preserves the use-def chain but is almost always elided by aliasing
@@ -225,14 +226,14 @@
auto sliceOp = rewriter.create<IREE::Stream::AsyncSliceOp>(
op.getLoc(), externalType, updateOp.getResult(),
updateOp.getTargetSize(), zeroOffset, source.resourceSize,
- source.resourceSize, affinityAttr);
+ source.resourceSize, executionAffinityAttr);
// Transfer to match original lifetime (if needed).
Value result = sliceOp.getResult();
if (source.resource.getType() != result.getType()) {
result = rewriter.create<IREE::Stream::AsyncTransferOp>(
op.getLoc(), source.resource.getType(), result, source.resourceSize,
- source.resourceSize, affinityAttr, affinityAttr);
+ source.resourceSize, executionAffinityAttr, executionAffinityAttr);
}
rewriter.replaceOp(op, result);
@@ -250,28 +251,38 @@
// %t01 = stream.timepoint.join max(%t0, %t1)
// stream.timepoint.export %t01 => %fence
struct ConvertTensorBarrierOp
- : public OpConversionPattern<IREE::HAL::TensorBarrierOp> {
- using OpConversionPattern::OpConversionPattern;
+ : public AffinityAwareConversionPattern<IREE::HAL::TensorBarrierOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::HAL::TensorBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
auto timepointType = rewriter.getType<IREE::Stream::TimepointType>();
+ IREE::Stream::AffinityAttr anyAffinityAttr;
SmallVector<Value> signaledResources;
SmallVector<Value> signaledTimepoints;
- for (auto sourceResource : adaptor.getSources()) {
- auto source = consumeTensorOperand(op.getLoc(), sourceResource, rewriter);
+ for (auto [sourceTensor, sourceResource] :
+ llvm::zip_equal(op.getSources(), adaptor.getSources())) {
+ auto source = resolveTensorOperand(op.getLoc(), sourceTensor,
+ sourceResource, rewriter);
auto barrierOp = rewriter.create<IREE::Stream::TimepointBarrierOp>(
sourceResource.getLoc(), source.resource.getType(), timepointType,
- source.resource, source.resourceSize, affinityAttr);
+ source.resource, source.resourceSize, source.affinity);
signaledResources.push_back(barrierOp.getResult());
signaledTimepoints.push_back(barrierOp.getResultTimepoint());
+
+ // When joining from multiple affinities we need to pick one to perform
+ // the chain. For now we use the affinity of the last tensor in the hope
+ // that the final signal can be performed on the affinity that is running.
+ // We should probably instead set this after timepoint propagation so
+ // that it is guaranteed to happen on the final signal when not acting as
+ // a join.
+ anyAffinityAttr = source.affinity;
}
Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join(
op.getLoc(), signaledTimepoints, rewriter);
rewriter.create<IREE::Stream::TimepointChainExternalOp>(
op.getLoc(), joinedTimepoint, ValueRange{adaptor.getSignalFence()},
- affinityAttr);
+ anyAffinityAttr);
rewriter.replaceOp(op, signaledResources);
return success();
}
@@ -279,21 +290,27 @@
} // namespace
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
typeConverter.addConversion(
[](IREE::HAL::BufferViewType type) { return type; });
- patterns.insert<ConvertTensorImportOp>(typeConverter, context);
- patterns.insert<ConvertTensorExportOp>(typeConverter, context);
- patterns.insert<ConvertTensorAliasOp>(typeConverter, context);
- patterns.insert<ConvertTensorBarrierOp>(typeConverter, context);
+ patterns.insert<ConvertTensorImportOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorExportOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorAliasOp>(typeConverter, context,
+ affinityAnalysis);
+ patterns.insert<ConvertTensorBarrierOp>(typeConverter, context,
+ affinityAnalysis);
}
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
// Allow executables through without modification.
conversionTarget.addLegalOp<IREE::HAL::ExecutableOp>();
conversionTarget.markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();
@@ -309,7 +326,8 @@
typeConverter.isLegal(op.getTarget().getType());
});
- populateHALToStreamConversionPatterns(context, typeConverter, patterns);
+ populateHALToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
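`ConvertTensorBarrierOp` no longer applies a single affinity to the whole op. Each source gets a `stream.timepoint.barrier` on its own resolved affinity, the timepoints are joined, and the external chain uses the affinity of the last source (which the in-code comment describes as a stopgap). The following is a minimal model of that selection logic, assuming hypothetical types in place of the stream ops.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

using Affinity = std::optional<std::string>;  // nullopt == unknown

struct Source {
  std::string resource;
  Affinity affinity;
};

struct BarrierPlan {
  std::vector<std::string> barriers;  // one barrier per source
  std::string joinedTimepoint;        // join over all per-source timepoints
  Affinity chainAffinity;             // affinity used for the external chain
};

BarrierPlan planBarrier(const std::vector<Source> &sources) {
  BarrierPlan plan;
  std::string join = "join(";
  for (std::size_t i = 0; i < sources.size(); ++i) {
    const Source &source = sources[i];
    // Each barrier runs on the affinity of its own source.
    plan.barriers.push_back(source.resource + "@" +
                            source.affinity.value_or("?"));
    join += (i ? ", t" : "t") + std::to_string(i);
    // Last one wins, mirroring the "affinity of the last tensor" choice.
    plan.chainAffinity = source.affinity;
  }
  plan.joinedTimepoint = join + ")";
  return plan;
}
```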
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
index ed2a3c0..f3e955d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform hal->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateHALToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateHALToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
index 6bb26f1..fee06f2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
@@ -23,13 +24,59 @@
return attr;
}
+IREE::Stream::AffinityAttr
+tryLookupGlobalAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) {
+ return affinityAnalysis->lookupGlobalAffinity(op);
+}
+
+IREE::Stream::AffinityAttr
+tryLookupExecutionAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) {
+ assert(llvm::isa<IREE::Stream::AffinityOpInterface>(op) &&
+ "must be an affinity op");
+ return affinityAnalysis->lookupExecutionAffinity(op);
+}
+
+IREE::Stream::AffinityAttr
+tryLookupResultAffinity(Value value,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis) {
+ return affinityAnalysis->lookupResourceAffinity(value);
+}
+
+static std::pair<Value, Value>
+resolveTensorOperand(Location loc, Value convertedOperand, OpBuilder &builder) {
+ auto operandType = convertedOperand.getType();
+ if (llvm::isa<IREE::Stream::ResourceType>(operandType)) {
+ // Prior to https://reviews.llvm.org/D111620 this is the path we'd take;
+ // the tensor operands would be remapped into their new resource types.
+ // This is still possible during rewriting if we ourselves produce a new
+ // resource type, but the automatic materialization will go down the
+ // unrealized_conversion_cast path below.
+ return std::make_pair(convertedOperand,
+ builder.createOrFold<IREE::Stream::ResourceSizeOp>(
+ loc, builder.getIndexType(), convertedOperand));
+ } else if (auto castOp =
+ convertedOperand
+ .getDefiningOp<mlir::UnrealizedConversionCastOp>()) {
+ // We only have a single tensor type conversion and it expands to (resource,
+ // size) so that's all we look for here.
+ assert(castOp.getNumOperands() == 2 && "expected (resource, size)");
+ return std::make_pair(castOp.getOperand(0), castOp.getOperand(1));
+ }
+ assert(false &&
+ "unexpected operand; expected either an IREE::Stream::ResourceType or "
+ "the result of an mlir::UnrealizedConversionCastOp");
+ return std::make_pair(Value{}, Value{});
+}
+
void expandResourceOperand(Location loc, Value operand,
SmallVectorImpl<Value> &newOperands,
OpBuilder &builder) {
if (llvm::isa<TensorType>(operand.getType())) {
- auto value = consumeTensorOperand(loc, operand, builder);
- newOperands.push_back(value.resource);
- newOperands.push_back(value.resourceSize);
+ auto [resource, resourceSize] = resolveTensorOperand(loc, operand, builder);
+ newOperands.push_back(resource);
+ newOperands.push_back(resourceSize);
} else if (llvm::isa<IREE::Stream::ResourceType>(operand.getType())) {
newOperands.push_back(operand);
newOperands.push_back(
@@ -49,34 +96,28 @@
return expandedOperands;
}
-ConvertedTensor consumeTensorOperand(Location loc, Value operand,
- OpBuilder &builder) {
- auto operandType = operand.getType();
- if (llvm::isa<IREE::Stream::ResourceType>(operandType)) {
- // Prior to https://reviews.llvm.org/D111620 this is the path we'd take;
- // the tensor operands would be remapped into their new resource types.
- // This is still possible during rewriting if we ourselves produce a new
- // resource type, but the automatic materialization will go down the
- // unrealized_conversion_cast path below.
- return {
- operand,
- builder.createOrFold<IREE::Stream::ResourceSizeOp>(
- loc, builder.getIndexType(), operand),
- };
- } else if (auto castOp =
- operand.getDefiningOp<mlir::UnrealizedConversionCastOp>()) {
- // We only have a single tensor type conversion and it expands to (resource,
- // size) so that's all we look for here.
- assert(castOp.getNumOperands() == 2 && "expected (resource, size)");
- return {
- castOp.getOperand(0),
- castOp.getOperand(1),
- };
+ConvertedTensor resolveTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) {
+ auto [resource, resourceSize] =
+ resolveTensorOperand(loc, convertedOperand, builder);
+ auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand);
+ return {affinityAttr, resource, resourceSize};
+}
+
+ConvertedTensor transferTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) {
+ auto [resource, resourceSize] =
+ resolveTensorOperand(loc, convertedOperand, builder);
+ auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand);
+ if (affinityAttr != requiredAffinityAttr) {
+ resource = builder.create<IREE::Stream::AsyncTransferOp>(
+ loc, resource.getType(), resource, resourceSize, resourceSize,
+ affinityAttr, requiredAffinityAttr);
}
- assert(false &&
- "unexpected operand; expected either a IREE::Stream::ResourceType or "
- "the result of a mlir::UnrealizedConversionCastOp");
- return ConvertedTensor();
+ return {requiredAffinityAttr, resource, resourceSize};
}
} // namespace mlir::iree_compiler
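`transferTensorOperand` only materializes a `stream.async.transfer` when the analyzed affinity of the original operand differs from the required one. Either way, it reports the required affinity on the returned `ConvertedTensor`. The following is a minimal MLIR-free sketch of that decision, not IREE code. `Affinity`, `Resource`, `Builder`, and `transferIfNeeded` are hypothetical stand-ins, and a null `AffinityAttr` is modeled as an empty `std::optional`:

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for IREE::Stream::AffinityAttr; std::nullopt plays
// the role of a null attribute (affinity could not be determined).
using Affinity = std::optional<std::string>;

// Hypothetical stand-in for a converted tensor's (resource, size) pair.
struct Resource {
  std::string name;
  long size;
};

struct ConvertedTensor {
  Affinity affinity;
  Resource resource;
};

// Records the transfers that would be emitted instead of creating
// stream.async.transfer ops.
struct Builder {
  std::vector<std::string> transfers;
};

// Mirrors the control flow of transferTensorOperand: transfer only when the
// known affinity differs from the required one (a null affinity counts as
// different from any non-null one), and always return the required affinity.
ConvertedTensor transferIfNeeded(const Resource &resource,
                                 const Affinity &actual,
                                 const Affinity &required, Builder &builder) {
  Resource result = resource;
  if (actual != required) {
    builder.transfers.push_back(resource.name + ": " + actual.value_or("?") +
                                " -> " + required.value_or("?"));
    result.name = resource.name + ".transferred";
  }
  return {required, result};
}
```

Comparing the two optionals directly matches the `affinityAttr != requiredAffinityAttr` check. Two unknown affinities compare equal and produce no transfer.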
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
index fd9249e..43cfbb0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
@@ -7,37 +7,123 @@
#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Converts a supported attribute type to the corresponding stream dialect
// value. Returns the provided value if it is natively supported.
TypedAttr convertAttributeToStream(TypedAttr attr);
-void expandResourceOperand(Location loc, Value operand,
- SmallVectorImpl<Value> &newOperands,
- OpBuilder &builder);
+IREE::Stream::AffinityAttr
+tryLookupGlobalAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis);
+IREE::Stream::AffinityAttr
+tryLookupExecutionAffinity(Operation *op,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis);
+IREE::Stream::AffinityAttr
+tryLookupResultAffinity(Value value,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis);
-SmallVector<Value> expandResourceOperands(Location loc, ValueRange operands,
- ConversionPatternRewriter &rewriter);
-
-// https://reviews.llvm.org/D111620 broke 1->N type expansion during dialect
-// conversion. It inserts unrealized_conversion_casts but then passes the
-// illegal source dialect types for pattern operands, meaning that even though
-// we say tensors are illegal the patterns get the new remapped values as
-// tensors. This, naturally, breaks everything. To work around this we have this
-// helper that tries to peek through the unrealized_conversion_casts and get out
-// the actual values we expected to see from the conversion (and did before that
-// change).
+// A tensor operand resolved to the stream resource that stores it.
struct ConvertedTensor {
+ // Optional affinity of the resource at the time it is consumed.
+ // May be nullptr if the affinity could not be determined.
+ IREE::Stream::AffinityAttr affinity;
+ // Resource storing the tensor.
Value resource;
+ // Size of the resource in bytes.
Value resourceSize;
};
-ConvertedTensor consumeTensorOperand(Location loc, Value operand,
- OpBuilder &builder);
+
+void expandResourceOperand(Location loc, Value convertedOperand,
+ SmallVectorImpl<Value> &newOperands,
+ OpBuilder &builder);
+SmallVector<Value> expandResourceOperands(Location loc,
+ ValueRange convertedOperands,
+ ConversionPatternRewriter &rewriter);
+
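+// Resolves |convertedOperand| into its (resource, size) pair and looks up the
+// affinity of |originalOperand| using |affinityAnalysis|.
+ConvertedTensor resolveTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder);
+// Like resolveTensorOperand, but inserts a stream.async.transfer to
+// |requiredAffinityAttr| when the resolved affinity differs from it.
+ConvertedTensor transferTensorOperand(
+ Location loc, Value originalOperand, Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder);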
+
+template <typename OpT>
+struct AffinityAwareConversionPattern : public OpConversionPattern<OpT> {
+public:
+ AffinityAwareConversionPattern(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<OpT>(typeConverter, context, benefit),
+ affinityAnalysis(affinityAnalysis) {}
+
+ IREE::Stream::AffinityAnalysis *getAffinityAnalysis() const {
+ return affinityAnalysis;
+ }
+
+protected:
+ ConvertedTensor resolveTensorOperand(Location loc, Value originalOperand,
+ Value convertedOperand,
+ OpBuilder &builder) const {
+ return mlir::iree_compiler::resolveTensorOperand(
+ loc, originalOperand, convertedOperand, affinityAnalysis, builder);
+ }
+
+ ConvertedTensor
+ transferTensorOperand(Location loc, Value originalOperand,
+ Value convertedOperand,
+ IREE::Stream::AffinityAttr requiredAffinityAttr,
+ OpBuilder &builder) const {
+ return mlir::iree_compiler::transferTensorOperand(
+ loc, originalOperand, convertedOperand, requiredAffinityAttr,
+ affinityAnalysis, builder);
+ }
+
+ IREE::Stream::AffinityAttr lookupResultAffinity(Value originalResult) const {
+ return mlir::iree_compiler::tryLookupResultAffinity(originalResult,
+ affinityAnalysis);
+ }
+
+ IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr;
+};
+
+template <typename OpT>
+struct AffinityOpConversionPattern
+ : public AffinityAwareConversionPattern<OpT> {
+public:
+ AffinityOpConversionPattern(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ PatternBenefit benefit = 1)
+ : AffinityAwareConversionPattern<OpT>(typeConverter, context,
+ affinityAnalysis, benefit) {}
+
+protected:
+ virtual LogicalResult matchAndRewriteOnAffinity(
+ OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const = 0;
+
+private:
+ LogicalResult
+ matchAndRewrite(OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override final {
+ auto executionAffinityAttr =
+ tryLookupExecutionAffinity(op, this->getAffinityAnalysis());
+ return matchAndRewriteOnAffinity(op, adaptor, executionAffinityAttr,
+ rewriter);
+ }
+};
} // namespace mlir::iree_compiler
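`AffinityOpConversionPattern` follows the non-virtual interface (template method) idiom. `matchAndRewrite` is `final` and private. It looks up the execution affinity once and forwards it to the pure virtual `matchAndRewriteOnAffinity`, so derived patterns such as `ConvertTensorConstantOp` cannot skip the lookup. Below is a standalone sketch of that class shape only. `Op`, the map used as `AffinityAnalysis`, and the empty string standing in for a null affinity are all hypothetical. This sketch also tolerates a null analysis; whether `tryLookupExecutionAffinity` does is not shown in this diff.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-ins for an operation and the affinity analysis.
struct Op {
  std::string name;
};
using AffinityAnalysis = std::map<std::string, std::string>;

class PatternBase {
public:
  virtual ~PatternBase() = default;
  virtual bool matchAndRewrite(const Op &op) const = 0;
};

class AffinityPattern : public PatternBase {
public:
  explicit AffinityPattern(const AffinityAnalysis *analysis)
      : analysis(analysis) {}

  // Sealed entry point: every derived pattern receives the looked-up
  // affinity, or an empty string when none is known (standing in for a
  // null AffinityAttr).
  bool matchAndRewrite(const Op &op) const final {
    std::string affinity;
    if (analysis) {
      auto it = analysis->find(op.name);
      if (it != analysis->end())
        affinity = it->second;
    }
    return matchAndRewriteOnAffinity(op, affinity);
  }

protected:
  virtual bool matchAndRewriteOnAffinity(const Op &op,
                                         const std::string &affinity) const = 0;

private:
  const AffinityAnalysis *analysis = nullptr;
};

// Example derived pattern that only records the affinity it was handed.
class RecordingPattern : public AffinityPattern {
public:
  using AffinityPattern::AffinityPattern;
  mutable std::string lastAffinity;

protected:
  bool matchAndRewriteOnAffinity(const Op &op,
                                 const std::string &affinity) const override {
    lastAffinity = affinity;
    return !op.name.empty();
  }
};
```

Sealing the entry point keeps the affinity lookup policy in one place. Derived patterns only decide what to do with the affinity they are given.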
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
index 646d55e..38ad9ee 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
@@ -15,8 +15,6 @@
iree_compiler_cc_library(
name = "StandardToStream",
srcs = [
- "ConvertConstantOps.cpp",
- "ConvertStructuralOps.cpp",
"Patterns.cpp",
],
hdrs = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
index b910c60..3def716 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
@@ -16,8 +16,6 @@
HDRS
"Patterns.h"
SRCS
- "ConvertConstantOps.cpp"
- "ConvertStructuralOps.cpp"
"Patterns.cpp"
DEPS
LLVMSupport
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
deleted file mode 100644
index 5ff99f7..0000000
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
-#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-struct ConvertTensorConstantOp : public OpConversionPattern<arith::ConstantOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle tensor types - other arith.constant types (like i32) are
- // ignored.
- if (!llvm::isa<TensorType>(constantOp.getType()))
- return failure();
-
- Type constantType = IREE::Stream::ResourceType::get(
- getContext(), IREE::Stream::Lifetime::Constant);
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
- auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
- constantOp.getLoc(), constantType,
- convertAttributeToStream(constantOp.getValue()),
- TypeAttr::get(constantOp.getType()),
- /*result_encoding_dims=*/ValueRange{}, affinityAttr);
-
- Type unknownType = IREE::Stream::ResourceType::get(getContext());
- auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
- constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
- rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
- constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
- /*source_affinity=*/affinityAttr,
- /*result_affinity=*/affinityAttr);
- return success();
- }
-};
-
-} // namespace
-
-void populateStandardConstantToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
- conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>(
- [](arith::ConstantOp op) {
- return !llvm::isa<TensorType>(op.getType());
- });
-
- patterns.insert<ConvertTensorConstantOp>(typeConverter, context);
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
deleted file mode 100644
index 5b29504..0000000
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
+++ /dev/null
@@ -1,406 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
-#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-struct BranchOpConversion : public OpConversionPattern<mlir::cf::BranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands = expandResourceOperands(
- op.getLoc(), adaptor.getDestOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
- expandedOperands);
- return success();
- }
-};
-
-struct CondBranchOpConversion
- : public OpConversionPattern<mlir::cf::CondBranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto trueDestOperands = expandResourceOperands(
- op.getLoc(), adaptor.getTrueDestOperands(), rewriter);
- auto falseDestOperands = expandResourceOperands(
- op.getLoc(), adaptor.getFalseDestOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
- op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands,
- op.getFalseDest(), falseDestOperands);
- return success();
- }
-};
-
-static ValueRange asValueRange(ArrayRef<Value> values) { return values; }
-
-struct SwitchOpConversion : public OpConversionPattern<mlir::cf::SwitchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto defaultOperands = expandResourceOperands(
- op.getLoc(), adaptor.getDefaultOperands(), rewriter);
- auto caseOperands = llvm::to_vector(
- llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) {
- return expandResourceOperands(op.getLoc(), operands, rewriter);
- }));
- rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>(
- op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands,
- op.getCaseValuesAttr(), op.getCaseDestinations(),
- llvm::to_vector(llvm::map_range(caseOperands, asValueRange)));
- return success();
- }
-};
-
-struct SelectOpConversion : public OpConversionPattern<mlir::arith::SelectOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle selects where the operands are tensors (resources).
- if (!llvm::isa<TensorType>(op.getTrueValue().getType()))
- return failure();
- auto trueOperand =
- consumeTensorOperand(op.getLoc(), adaptor.getTrueValue(), rewriter);
- auto falseOperand =
- consumeTensorOperand(op.getLoc(), adaptor.getFalseValue(), rewriter);
- auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>(
- op.getLoc(), adaptor.getCondition(), trueOperand.resource,
- falseOperand.resource);
- auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>(
- op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize,
- falseOperand.resourceSize);
- rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
- op, adaptor.getTrueValue().getType(),
- ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()});
- return success();
- }
-};
-
-struct ScfIfOpConversion : public OpConversionPattern<mlir::scf::IfOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- // Create a new call that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original call as
- // the result counts differ.
- auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes,
- op.getCondition());
-
- ifOp.getThenRegion().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(),
- ifOp.getThenRegion().end());
-
- ifOp.getElseRegion().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(),
- ifOp.getElseRegion().end());
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = ifOp.getResult(result.newIndex + 0);
- auto resourceSize = ifOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(ifOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfForOpConversion : public OpConversionPattern<mlir::scf::ForOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto &typeConverter = *getTypeConverter();
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- auto &block = op.getRegion().front();
- TypeConverter::SignatureConversion newSignature(block.getNumArguments());
- for (auto arg : llvm::enumerate(block.getArgumentTypes())) {
- if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(),
- newSignature))) {
- return failure();
- }
- }
-
- // Create a new loop that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original loop as
- // the result counts differ.
- auto forOp = rewriter.create<mlir::scf::ForOp>(
- op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
- adaptor.getStep(), expandedOperands);
-
- // Inline the block and update the block arguments.
- rewriter.eraseBlock(forOp.getBody());
- rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(),
- forOp.getRegion().end());
- if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = forOp.getResult(result.newIndex + 0);
- auto resourceSize = forOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(forOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfWhileOpConversion : public OpConversionPattern<mlir::scf::WhileOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto &typeConverter = *getTypeConverter();
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
-
- // Expand any resource results to resource + size.
- SmallVector<Type> expandedTypes;
- struct Result {
- size_t originalIndex;
- size_t newIndex;
- Type newType;
- };
- SmallVector<Result> resultMap;
- for (auto originalType : llvm::enumerate(op.getResultTypes())) {
- SmallVector<Type> newTypes;
- if (failed(getTypeConverter()->convertType(originalType.value(),
- newTypes))) {
- return rewriter.notifyMatchFailure(op,
- "unable to convert result types");
- }
- resultMap.push_back(
- Result{originalType.index(), expandedTypes.size(), newTypes.front()});
- expandedTypes.append(newTypes);
- }
-
- TypeConverter::SignatureConversion newSignature(op.getNumOperands());
- for (auto argType : llvm::enumerate(op.getOperandTypes())) {
- if (failed(typeConverter.convertSignatureArg(
- argType.index(), argType.value(), newSignature))) {
- return failure();
- }
- }
-
- // Create a new call that takes the expanded input operands and returns the
- // expanded output results. We can't directly replace the original call as
- // the result counts differ.
- auto whileOp = rewriter.create<mlir::scf::WhileOp>(
- op.getLoc(), expandedTypes, expandedOperands);
-
- // Inline the `before` block and update the block arguments.
- whileOp.getBefore().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(),
- whileOp.getBefore().end());
- if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Inline the `after` block and update the block arguments.
- whileOp.getAfter().getBlocks().clear();
- rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(),
- whileOp.getAfter().end());
- if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- // Tie all resource results together so we end up with 1:1 results with the
- // original op.
- SmallVector<Value> results;
- for (auto result : resultMap) {
- if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
- auto oldType = op.getResult(result.originalIndex).getType();
- auto resource = whileOp.getResult(result.newIndex + 0);
- auto resourceSize = whileOp.getResult(result.newIndex + 1);
- results.push_back(rewriter
- .create<mlir::UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{oldType},
- ValueRange{resource, resourceSize})
- .getResult(0));
- } else {
- results.push_back(whileOp.getResult(result.newIndex));
- }
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-struct ScfConditionOpConversion
- : public OpConversionPattern<mlir::scf::ConditionOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
- op, adaptor.getCondition(), expandedOperands);
- return success();
- }
-};
-
-struct ScfYieldOpConversion : public OpConversionPattern<mlir::scf::YieldOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Expand any resource operands to resource + size.
- auto expandedOperands =
- expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
- rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands);
- return success();
- }
-};
-
-} // namespace
-
-template <typename OpT>
-static inline void addGenericLegalOp(ConversionTarget &conversionTarget,
- TypeConverter &typeConverter) {
- conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) {
- return llvm::all_of(
- op->getOperandTypes(),
- [&typeConverter](Type t) { return typeConverter.isLegal(t); }) &&
- llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) {
- return typeConverter.isLegal(t);
- });
- });
-}
-
-void populateStandardStructuralToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
- conversionTarget.addLegalOp<mlir::ModuleOp>();
-
- // We need to rewrite certain types on operands/results so use the default
- // dynamic legality checker to force any ops using such types to run through
- // our patterns.
-
- addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter);
- patterns
- .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
- typeConverter, context);
-
- addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter);
- patterns.insert<SelectOpConversion>(typeConverter, context);
-
- addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter);
- addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter);
- patterns
- .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion,
- ScfWhileOpConversion, ScfYieldOpConversion>(typeConverter,
- context);
-}
-
-} // namespace mlir::iree_compiler
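The `scf.if`, `scf.for`, and `scf.while` conversions above share the same `resultMap` bookkeeping, and so does the `scf.if` version re-added in `Patterns.cpp`. Each tensor result expands to two new results, (resource, size), so `newIndex` is the running count of expanded types. Later, `newIndex + 0` and `newIndex + 1` are tied back into one value with an `unrealized_conversion_cast`. The following is a standalone sketch of just that index arithmetic. The `Kind` enum is a hypothetical stand-in for the type converter (tensor to two types, anything else to one):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for result types: tensors convert 1->2
// (resource, size); every other type converts 1->1.
enum class Kind { Tensor, Other };

struct Result {
  std::size_t originalIndex;  // position in the original op's results
  std::size_t newIndex;       // first position in the expanded op's results
  bool isResource;            // true if newIndex and newIndex + 1 are
                              // (resource, size)
};

// Builds the same mapping as the resultMap loops in the structural
// conversions: newIndex is the number of expanded types emitted so far.
std::vector<Result> buildResultMap(const std::vector<Kind> &originalTypes,
                                   std::size_t &expandedCount) {
  std::vector<Result> resultMap;
  expandedCount = 0;
  for (std::size_t i = 0; i < originalTypes.size(); ++i) {
    bool isTensor = originalTypes[i] == Kind::Tensor;
    resultMap.push_back(Result{i, expandedCount, isTensor});
    expandedCount += isTensor ? 2 : 1;
  }
  return resultMap;
}
```

As an example, take an op whose results are `(i32, tensor, tensor, index)`. The expanded op has six results, with the two tensors starting at positions 1 and 3 and the trailing `index` at position 5.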
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
index 1725fb0..9924fd2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
@@ -6,26 +6,419 @@
#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
-void populateStandardConstantToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+namespace {
-void populateStandardStructuralToStreamPatterns(
- MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+struct ConvertTensorConstantOp
+ : public AffinityOpConversionPattern<arith::ConstantOp> {
+ using AffinityOpConversionPattern::AffinityOpConversionPattern;
+ LogicalResult matchAndRewriteOnAffinity(
+ arith::ConstantOp constantOp, OpAdaptor adaptor,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle tensor types - other arith.constant types (like i32) are
+ // ignored.
+ if (!llvm::isa<TensorType>(constantOp.getType())) {
+ return failure();
+ }
+
+ auto constantType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Constant);
+ auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
+ constantOp.getLoc(), constantType,
+ convertAttributeToStream(constantOp.getValue()),
+ TypeAttr::get(constantOp.getType()),
+ /*result_encoding_dims=*/ValueRange{}, executionAffinityAttr);
+
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr);
+ return success();
+ }
+};
+
+struct BranchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::BranchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getDestOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
+ expandedOperands);
+ return success();
+ }
+};
+
+struct CondBranchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::CondBranchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto trueDestOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getTrueDestOperands(), rewriter);
+ auto falseDestOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getFalseDestOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
+ op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands,
+ op.getFalseDest(), falseDestOperands);
+ return success();
+ }
+};
+
+static ValueRange asValueRange(ArrayRef<Value> values) { return values; }
+
+struct SwitchOpConversion
+ : public AffinityAwareConversionPattern<mlir::cf::SwitchOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto defaultOperands = expandResourceOperands(
+ op.getLoc(), adaptor.getDefaultOperands(), rewriter);
+ auto caseOperands = llvm::to_vector(
+ llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) {
+ return expandResourceOperands(op.getLoc(), operands, rewriter);
+ }));
+ rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>(
+ op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands,
+ op.getCaseValuesAttr(), op.getCaseDestinations(),
+ llvm::to_vector(llvm::map_range(caseOperands, asValueRange)));
+ return success();
+ }
+};
+
+struct SelectOpConversion
+ : public AffinityAwareConversionPattern<mlir::arith::SelectOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle selects where the operands are tensors (resources).
+ if (!llvm::isa<TensorType>(op.getTrueValue().getType()))
+ return failure();
+ auto trueOperand = resolveTensorOperand(op.getLoc(), op.getTrueValue(),
+ adaptor.getTrueValue(), rewriter);
+ auto falseOperand = resolveTensorOperand(op.getLoc(), op.getFalseValue(),
+ adaptor.getFalseValue(), rewriter);
+ auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>(
+ op.getLoc(), adaptor.getCondition(), trueOperand.resource,
+ falseOperand.resource);
+ auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>(
+ op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize,
+ falseOperand.resourceSize);
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, adaptor.getTrueValue().getType(),
+ ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()});
+ return success();
+ }
+};
+
+struct ScfIfOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::IfOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource results to resource + size.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ // Create a new scf.if that returns the expanded output results and move the
+ // original then/else regions into it. We can't directly replace the original
+ // op as the result counts differ.
+ auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes,
+ op.getCondition());
+
+ ifOp.getThenRegion().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(),
+ ifOp.getThenRegion().end());
+
+ ifOp.getElseRegion().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(),
+ ifOp.getElseRegion().end());
+
+ // Tie all resource results together so we end up with 1:1 results with the
+ // original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = ifOp.getResult(result.newIndex + 0);
+ auto resourceSize = ifOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(ifOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfForOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::ForOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter);
+
+ // Expand any resource results to resource + size.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ auto &block = op.getRegion().front();
+ TypeConverter::SignatureConversion newSignature(block.getNumArguments());
+ for (auto arg : llvm::enumerate(block.getArgumentTypes())) {
+ if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(),
+ newSignature))) {
+ return failure();
+ }
+ }
+
+ // Create a new loop that takes the expanded input operands and returns the
+ // expanded output results. We can't directly replace the original loop as
+ // the result counts differ.
+ auto forOp = rewriter.create<mlir::scf::ForOp>(
+ op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
+ adaptor.getStep(), expandedOperands);
+
+ // Inline the block and update the block arguments.
+ rewriter.eraseBlock(forOp.getBody());
+ rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(),
+ forOp.getRegion().end());
+ if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Tie all resource results together so we end up with 1:1 results with the
+ // original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = forOp.getResult(result.newIndex + 0);
+ auto resourceSize = forOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(forOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfWhileOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::WhileOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
+
+ // Expand any resource results to resource + size.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ TypeConverter::SignatureConversion newSignature(op.getNumOperands());
+ for (auto argType : llvm::enumerate(op.getOperandTypes())) {
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), newSignature))) {
+ return failure();
+ }
+ }
+
+ // Create a new scf.while that takes the expanded input operands and returns
+ // the expanded output results. We can't directly replace the original op as
+ // the result counts differ.
+ auto whileOp = rewriter.create<mlir::scf::WhileOp>(
+ op.getLoc(), expandedTypes, expandedOperands);
+
+ // Inline the `before` block and update the block arguments.
+ whileOp.getBefore().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(),
+ whileOp.getBefore().end());
+ if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Inline the `after` block and update the block arguments.
+ whileOp.getAfter().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(),
+ whileOp.getAfter().end());
+ if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ // Tie all resource results together so we end up with 1:1 results with the
+ // original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = whileOp.getResult(result.newIndex + 0);
+ auto resourceSize = whileOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(whileOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ScfConditionOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::ConditionOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
+ op, adaptor.getCondition(), expandedOperands);
+ return success();
+ }
+};
+
+struct ScfYieldOpConversion
+ : public AffinityAwareConversionPattern<mlir::scf::YieldOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands);
+ return success();
+ }
+};
+
+template <typename OpT>
+static inline void addGenericLegalOp(ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter) {
+ conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) {
+ return llvm::all_of(
+ op->getOperandTypes(),
+ [&typeConverter](Type t) { return typeConverter.isLegal(t); }) &&
+ llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) {
+ return typeConverter.isLegal(t);
+ });
+ });
+}
+
+} // namespace
void populateStandardToStreamConversionPatterns(
MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
typeConverter.addConversion([](IndexType type) { return type; });
typeConverter.addConversion([](IntegerType type) { return type; });
typeConverter.addConversion([](FloatType type) { return type; });
@@ -35,10 +428,38 @@
conversionTarget.addIllegalOp<memref::DimOp, memref::RankOp, tensor::DimOp,
tensor::RankOp>();
- populateStandardConstantToStreamPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateStandardStructuralToStreamPatterns(context, conversionTarget,
- typeConverter, patterns);
+ conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>(
+ [](arith::ConstantOp op) {
+ return !llvm::isa<TensorType>(op.getType());
+ });
+ patterns.insert<ConvertTensorConstantOp>(typeConverter, context,
+ affinityAnalysis);
+
+ conversionTarget.addLegalOp<mlir::ModuleOp>();
+
+ // We need to rewrite certain types on operands/results so use the default
+ // dynamic legality checker to force any ops using such types to run through
+ // our patterns.
+
+ addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter);
+ patterns
+ .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+ typeConverter, context, affinityAnalysis);
+
+ addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter);
+ patterns.insert<SelectOpConversion>(typeConverter, context, affinityAnalysis);
+
+ addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter);
+ patterns
+ .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion,
+ ScfWhileOpConversion, ScfYieldOpConversion>(
+ typeConverter, context, affinityAnalysis);
}
} // namespace mlir::iree_compiler
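The scf.if/scf.for/scf.while conversions above all build the same `resultMap`: each converted result records where it starts in the expanded result list, so that a resource and its size can later be re-tied into one value with an `UnrealizedConversionCastOp`. The standalone sketch below models only that index bookkeeping. `TypeKind`, `ResultMapping`, `convertType`, and `buildResultMap` are hypothetical names, not IREE or MLIR APIs, and the sketch assumes (as the `newIndex + 1` size lookup in the patterns implies) that the type converter expands a tensor into exactly a resource followed by an index size.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for MLIR types (not an IREE/MLIR API).
enum class TypeKind { Tensor, Resource, Index, Other };

// Mirrors the local |Result| struct used by the scf conversion patterns.
struct ResultMapping {
  size_t originalIndex;  // Result index on the original op.
  size_t newIndex;       // Index of the first expanded result on the new op.
  TypeKind newType;      // First converted type (Resource for tensors).
};

// Assumed conversion rule: tensor -> {resource, index size}; others unchanged.
static std::vector<TypeKind> convertType(TypeKind type) {
  if (type == TypeKind::Tensor) return {TypeKind::Resource, TypeKind::Index};
  return {type};
}

// Appends the converted types to |expandedTypes| and records, per original
// result, the position where its converted values begin.
static std::vector<ResultMapping>
buildResultMap(const std::vector<TypeKind> &originalTypes,
               std::vector<TypeKind> &expandedTypes) {
  std::vector<ResultMapping> resultMap;
  for (size_t i = 0; i < originalTypes.size(); ++i) {
    std::vector<TypeKind> newTypes = convertType(originalTypes[i]);
    resultMap.push_back({i, expandedTypes.size(), newTypes.front()});
    expandedTypes.insert(expandedTypes.end(), newTypes.begin(), newTypes.end());
  }
  return resultMap;
}
```

For results `(scalar, tensor, tensor)` the expanded list is `(scalar, resource, index, resource, index)`; the second tensor maps to `newIndex == 3`, so its size is read from expanded result 4.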
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
index 3314dcf..112e602 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
@@ -10,6 +10,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform standard/builtin->stream
@@ -17,7 +21,9 @@
// provided |typeConverter|.
void populateStandardToStreamConversionPatterns(
MLIRContext *context, ConversionTarget &conversionTarget,
- TypeConverter &typeConverter, RewritePatternSet &patterns);
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index 4d1fa5f..35e1ca8 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -67,8 +67,9 @@
}
};
-struct CallOpConversion : public OpConversionPattern<IREE::Util::CallOp> {
- using OpConversionPattern::OpConversionPattern;
+struct CallOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::CallOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Util::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -122,8 +123,9 @@
}
};
-struct ReturnOpConversion : public OpConversionPattern<IREE::Util::ReturnOp> {
- using OpConversionPattern::OpConversionPattern;
+struct ReturnOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::ReturnOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
LogicalResult
matchAndRewrite(IREE::Util::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -142,6 +144,7 @@
struct ExpandedGlobalResource {
IREE::Util::GlobalOp resourceOp;
IREE::Util::GlobalOp resourceSizeOp;
+ IREE::Stream::AffinityAttr affinityAttr;
};
struct GlobalExpansionState {
@@ -163,13 +166,16 @@
public:
BaseGlobalConversionPattern(
std::shared_ptr<GlobalExpansionState> expansionState,
- TypeConverter &typeConverter, MLIRContext *context,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern<T>(typeConverter, context, benefit),
- expansionState(std::move(expansionState)) {}
+ expansionState(std::move(expansionState)),
+ affinityAnalysis(affinityAnalysis) {}
protected:
mutable std::shared_ptr<GlobalExpansionState> expansionState;
+ IREE::Stream::AffinityAnalysis *affinityAnalysis;
};
struct GlobalOpExpansion
@@ -230,9 +236,13 @@
globalOp.getIsMutable(), indexType, std::optional<TypedAttr>{});
resourceSizeOp.setVisibility(globalOp.getVisibility());
+ // Resolve the affinity of the global.
+ // We require this to be a single value today that is usually chosen from
+ // consumers (we take the hit on transfer from producers if needed).
+ auto affinityAttr = tryLookupGlobalAffinity(globalOp, affinityAnalysis);
+
// Materialize the initializer if we need to setup a tensor-like constant.
if (tensorInitializerRequired) {
- auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp);
auto initializerOp =
rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc());
auto *entryBlock = rewriter.createBlock(&initializerOp.getBody());
@@ -265,6 +275,7 @@
expansionState->globalMap[globalOp.getSymName()] = ExpandedGlobalResource{
resourceOp,
resourceSizeOp,
+ affinityAttr,
};
return success();
@@ -289,7 +300,7 @@
auto &expandedGlobal = expandedGlobalIt->getSecond();
// Insert a load/transfer to the unknown resource lifetime.
- auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto resource =
rewriter
.create<IREE::Util::GlobalLoadOp>(
@@ -303,8 +314,8 @@
.getResult();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
loadOp, unknownType, resource, resourceSize, resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
+ /*source_affinity=*/expandedGlobal.affinityAttr,
+ /*result_affinity=*/expandedGlobal.affinityAttr);
return success();
}
@@ -330,12 +341,14 @@
// Insert a transfer/store to the global with unknown lifetime. Lifetime
// refinement will make this go away if possible.
auto value =
- consumeTensorOperand(storeOp.getLoc(), adaptor.getValue(), rewriter);
+ resolveTensorOperand(storeOp.getLoc(), storeOp.getValue(),
+ adaptor.getValue(), affinityAnalysis, rewriter);
assert(expandedGlobal.resourceOp && "Missing resource op");
auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
storeOp.getLoc(), expandedGlobal.resourceOp.getType(), value.resource,
- value.resourceSize, value.resourceSize, /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
+ value.resourceSize, value.resourceSize,
+ /*source_affinity=*/value.affinity,
+ /*result_affinity=*/expandedGlobal.affinityAttr);
rewriter.replaceOpWithNewOp<IREE::Util::GlobalStoreOp>(
storeOp, transferOp.getResult(),
expandedGlobal.resourceOp.getSymName());
@@ -347,30 +360,59 @@
}
};
+struct OptimizationBarrierOpConversion
+ : public AffinityAwareConversionPattern<IREE::Util::OptimizationBarrierOp> {
+ using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
+ LogicalResult
+ matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> newOperands;
+ for (auto [originalOperand, convertedOperand] :
+ llvm::zip_equal(op.getOperands(), adaptor.getOperands())) {
+ if (isa<TensorType>(convertedOperand.getType())) {
+ newOperands.push_back(resolveTensorOperand(op.getLoc(), originalOperand,
+ convertedOperand, rewriter)
+ .resource);
+ } else {
+ newOperands.push_back(convertedOperand);
+ }
+ }
+ rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op,
+ newOperands);
+ return success();
+ }
+};
+
} // namespace
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns
- .insert<FuncOpSignatureConversion, CallOpConversion, ReturnOpConversion>(
- typeConverter, context);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
+ patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
+ patterns.insert<CallOpConversion, ReturnOpConversion>(typeConverter, context,
+ affinityAnalysis);
auto expansionState = std::make_shared<GlobalExpansionState>();
// TODO(#7432): add indirect global expansion support to streams.
patterns
.insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>(
- expansionState, typeConverter, context);
+ expansionState, typeConverter, affinityAnalysis, context);
patterns.add<GenericConvertTypesPattern<IREE::Util::GlobalOp>,
GenericConvertTypesPattern<IREE::Util::GlobalLoadOp>,
GenericConvertTypesPattern<IREE::Util::GlobalStoreOp>>(
typeConverter, context);
+
+ patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context,
+ affinityAnalysis,
+ /*benefit=*/2);
}
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns) {
typeConverter.addConversion([=](IREE::Util::PtrType type,
SmallVectorImpl<Type> &resultTypes) {
// Expand pointers to tensors to [resource, sizeof resource] pointers.
@@ -432,7 +474,8 @@
return typeConverter.isLegal(op.getResultTypes());
});
- populateUtilToStreamConversionPatterns(context, typeConverter, patterns);
+ populateUtilToStreamConversionPatterns(context, typeConverter,
+ affinityAnalysis, patterns);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
index 5673c74..56fcca2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
@@ -11,18 +11,24 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+namespace mlir::iree_compiler::IREE::Stream {
+class AffinityAnalysis;
+} // namespace mlir::iree_compiler::IREE::Stream
+
namespace mlir::iree_compiler {
// Populates conversion patterns that perform util->stream conversion.
// These patterns ensure that nested types are run through the provided
// |typeConverter|.
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-void populateUtilToStreamConversionPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis,
+ RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
index bfcca44..4c8fb8d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -504,6 +504,7 @@
}
def Stream_Resource : TypeDef<Stream_Dialect, "Resource", [
+ Stream_AffinityType,
Util_ReferenceType,
Util_SizeAwareType,
DeclareTypeInterfaceMethods<Util_GlobalType, [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
index f34003d..2b686b7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
@@ -97,7 +97,13 @@
// Returns an affinity active for the given operation.
// This will recursively walk parent operations until one with the
// `stream.affinity` attribute is found.
- static AffinityAttr lookup(Operation *op);
+ static AffinityAttr lookup(Operation *fromOp);
+
+ // Returns an affinity active for the given operation or the fallback
+ // default if none is specified.
+ // This will recursively walk parent operations until one with the
+ // `stream.affinity` attribute is found.
+ static AffinityAttr lookupOrDefault(Operation *fromOp);
// TODO(benvanik): replace with more fine-grained compatibility checks.
// "Compatible" can mean a lot of things: are they cache-coherent, are they
@@ -116,10 +122,24 @@
}
//===----------------------------------------------------------------------===//
+// IREE::Stream::AffinityTypeInterface
+//===----------------------------------------------------------------------===//
+
+def Stream_AffinityType : TypeInterface<"AffinityTypeInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
+ let description = [{
+ Indicates that a type represents a resource whose affinity is tracked.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Stream::AffinityOpInterface
//===----------------------------------------------------------------------===//
def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
TBD. Used to denote a stream affinity for ops and specify the kind of
environment the ops are expected run in.
@@ -142,13 +162,27 @@
>,
InterfaceMethod<
/*desc=*/[{
+ Returns true if the operands and results should be pinned to the
+ affinity of the op. This overrides all automatic placement logic.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"pinsValueAffinity",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
Returns the stream affinity for the op, indicating where it should run.
}],
/*retTy=*/"IREE::Stream::AffinityAttr",
- /*methodName=*/"getAffinity",
+ /*methodName=*/"getAffinityAttr",
/*args=*/(ins),
- /*methodBody=*/[{
- return dyn_cast_or_null<IREE::Stream::AffinityAttr>($_self->getAttr("affinity"));
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return dyn_cast_or_null<IREE::Stream::AffinityAttr>($_op->getAttr("affinity"));
}]
>,
InterfaceMethod<
@@ -156,11 +190,12 @@
Sets the stream affinity for the op, indicating where it should run.
}],
/*retTy=*/"void",
- /*methodName=*/"setAffinity",
+ /*methodName=*/"setAffinityAttr",
/*args=*/(ins "IREE::Stream::AffinityAttr":$value),
- /*methodBody=*/[{
- if (value) $_self->setAttr("affinity", value);
- else $_self->removeAttr("affinity");
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (value) $_op->setAttr("affinity", value);
+ else $_op->removeAttr("affinity");
}]
>,
];
@@ -171,6 +206,8 @@
//===----------------------------------------------------------------------===//
def Stream_StreamableOp : OpInterface<"StreamableOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that can be asynchronous executed in a streaming context.
}];
@@ -210,6 +247,8 @@
//===----------------------------------------------------------------------===//
def Stream_AsyncAccessOp : OpInterface<"AsyncAccessOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for stream.async.* ops that access subviews of resources.
This allows for some basic analysis and is only valid prior to allocation.
@@ -238,6 +277,8 @@
//===----------------------------------------------------------------------===//
def Stream_SubviewEffectOp : OpInterface<"SubviewEffectOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that operate on subviews of resources used to query the
memory effects for subviews on operands.
@@ -256,6 +297,8 @@
//===----------------------------------------------------------------------===//
def Stream_TimelineOp : OpInterface<"TimelineOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that operate in an ordered sequence defined by timepoints.
}];
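`AffinityAttr::lookup` and the new `lookupOrDefault` resolve an op's affinity by walking outward through parent operations until one carries a `stream.affinity` attribute. The sketch below models that walk on a toy op tree. `OpModel`, `lookupAffinity`, and `lookupAffinityOrDefault` are hypothetical; the sketch assumes the walk checks the starting op itself before its parents, and because this diff does not show where `lookupOrDefault` obtains its fallback, the fallback is passed in explicitly here.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical op model: an optional `stream.affinity` value plus a parent link.
struct OpModel {
  const OpModel *parent = nullptr;
  std::optional<std::string> affinity;
};

// Returns the nearest affinity on |fromOp| or any of its ancestors.
std::optional<std::string> lookupAffinity(const OpModel *fromOp) {
  for (const OpModel *op = fromOp; op != nullptr; op = op->parent) {
    if (op->affinity) return op->affinity;
  }
  return std::nullopt;
}

// Like lookupAffinity but substitutes |fallback| when no ancestor has one.
// Assumption: the real lookupOrDefault derives its fallback internally.
std::string lookupAffinityOrDefault(const OpModel *fromOp,
                                    const std::string &fallback) {
  return lookupAffinity(fromOp).value_or(fallback);
}
```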
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index dbe9abc..9360eea 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1183,7 +1183,7 @@
return failure();
// Definitely empty if here.
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
constantOp.getLoc(), rewriter.getIndexType(),
TypeAttr::get(constantOp.getResultEncoding()),
constantOp.getResultEncodingDims(), constantOp.getAffinityAttr());
@@ -1219,7 +1219,7 @@
}
auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
- auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>(
constantOp.getLoc(), rewriter.getIndexType(),
TypeAttr::get(constantOp.getResultEncoding()),
constantOp.getResultEncodingDims(), constantOp.getAffinityAttr());
@@ -1231,7 +1231,7 @@
constantOp, constantOp.getResult().getType(), splatOp.getResult(),
resultSize, resultSize,
/*source_affinity=*/constantOp.getAffinityAttr(),
- /*result_affinity=*/nullptr);
+ /*result_affinity=*/constantOp.getAffinityAttr());
return success();
}
};
@@ -1452,9 +1452,9 @@
LogicalResult matchAndRewrite(AsyncConstantOp constantOp,
PatternRewriter &rewriter) const override {
auto value = dyn_cast<ElementsAttr>(constantOp.getValue());
- if (!value || !value.isSplat())
+ if (!value || !value.isSplat()) {
return failure();
-
+ }
auto splatElementAttr =
llvm::dyn_cast<SplatElementsAttr>(value).getSplatValue<TypedAttr>();
auto splatValue = rewriter.create<arith::ConstantOp>(
@@ -3263,16 +3263,8 @@
if (dominanceInfo.dominates(use.getOwner(), op))
continue;
auto awaitOp = dyn_cast<TimepointAwaitOp>(use.getOwner());
- if (!awaitOp ||
- !AffinityAttr::areCompatible(
- llvm::dyn_cast_if_present<AffinityAttr>(op.getAffinityAttr()),
- llvm::dyn_cast_if_present<AffinityAttr>(
- awaitOp.getAffinityAttr()))) {
- // Can't combine if the affinities differ as the wait semantics are
- // load-bearing. Probably. They really shouldn't be.
- // TODO(benvanik): remove affinity from stream.timepoint.await.
+ if (!awaitOp)
continue;
- }
// Ensure all dependencies of the await op are available.
if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) {
// One or more operands is defined after op so we can't merge.
@@ -3299,9 +3291,6 @@
}
auto newOp = rewriter.create<TimepointAwaitOp>(
op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint());
- if (op.getAffinity().has_value()) {
- newOp.setAffinityAttr(op.getAffinityAttr());
- }
// Replace covered ops with the new results.
unsigned resultIdx = 0;
@@ -3349,9 +3338,6 @@
// Create replacement op with deduped operands/results.
auto newOp = rewriter.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint());
- if (op.getAffinity().has_value()) {
- newOp.setAffinityAttr(op.getAffinityAttr());
- }
// Replace all duplicate results with the base results.
for (auto &replacement : replacements) {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index d9da282..358d1fd 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2019,6 +2019,60 @@
return success();
}
+IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
+ auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
+ auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
+ if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
+ resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // TODO(multi-device): figure out how to model staging->staging transfers.
+ return getSourceAffinityAttr();
+ } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If source is staging then the op should execute on the consumer.
+ return getResultAffinityAttr();
+ } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If result is staging then the op should execute on the producer.
+ return getSourceAffinityAttr();
+ } else {
+ // Default to result affinity.
+ return getResultAffinityAttr();
+ }
+}
+
+void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) {
+ auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
+ auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
+ if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
+ resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // TODO(multi-device): figure out how to model staging->staging transfers.
+ if (value) {
+ setSourceAffinityAttr(value);
+ } else {
+ removeSourceAffinityAttr();
+ }
+ } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If source is staging then the op should execute on the consumer.
+ if (value) {
+ setResultAffinityAttr(value);
+ } else {
+ removeResultAffinityAttr();
+ }
+ } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If result is staging then the op should execute on the producer.
+ if (value) {
+ setSourceAffinityAttr(value);
+ } else {
+ removeSourceAffinityAttr();
+ }
+ } else {
+ // Default to result affinity.
+ if (value) {
+ setResultAffinityAttr(value);
+ } else {
+ removeResultAffinityAttr();
+ }
+ }
+}
+
void AsyncTransferOp::getAsyncAccessRanges(
SmallVectorImpl<AsyncAccessRange> &ranges) {
ranges.push_back({ResourceAccessBitfield::Read, getSource(), Value{},
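The staging rule in `AsyncTransferOp::getAffinityAttr`/`setAffinityAttr` above can be modeled outside MLIR. What follows is a hypothetical, self-contained sketch: `Lifetime`, `Affinity`, and `TransferSketch` are illustrative stand-ins rather than IREE types, and an affinity is reduced to an optional device name. The point it shows is that the getter and the setter must pick the same slot (source or result) for every lifetime combination.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-ins; not IREE::Stream::Lifetime / AffinityAttr.
enum class Lifetime { External, Staging, Transient, Variable, Constant };
using Affinity = std::optional<std::string>;  // nullopt == no affinity

struct TransferSketch {
  Lifetime sourceLifetime;
  Lifetime resultLifetime;
  Affinity sourceAffinity;
  Affinity resultAffinity;

  // Mirrors the branch order of AsyncTransferOp::getAffinityAttr.
  bool usesSourceAffinity() const {
    const bool sourceStaging = sourceLifetime == Lifetime::Staging;
    const bool resultStaging = resultLifetime == Lifetime::Staging;
    if (sourceStaging && resultStaging) {
      return true;  // staging->staging: still a TODO upstream
    }
    if (sourceStaging) {
      return false;  // upload: execute on the consumer
    }
    if (resultStaging) {
      return true;  // download: execute on the producer
    }
    return false;  // default to the result affinity
  }

  Affinity getAffinity() const {
    return usesSourceAffinity() ? sourceAffinity : resultAffinity;
  }

  // Passing std::nullopt clears the selected slot, like remove*AffinityAttr.
  void setAffinity(Affinity value) {
    (usesSourceAffinity() ? sourceAffinity : resultAffinity) = std::move(value);
  }
};
```

For example, a staging-to-transient upload reports the result (consumer) affinity, while a transient-to-staging download reports, and on `setAffinity(std::nullopt)` clears, the source (producer) affinity. Written out this way, the four branches collapse to "use the source slot if and only if the result is staging", which may be a useful invariant to keep in mind if the staging-to-staging TODO is resolved.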
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index c994e65..871e3bb 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -86,10 +86,7 @@
let opDocGroup = OpGroupResourceOps in {
def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Util_SizeAwareOp,
AlwaysSpeculatable,
MemoryEffects<[MemAlloc]>,
@@ -123,8 +120,8 @@
);
let assemblyFormat = [{
- (`on` `(` $affinity^ `)`)?
(`uninitialized` $uninitialized^)?
+ (`on` `(` $affinity^ `)`)?
attr-dict `:`
type($result) `{` $storage_size `}`
}];
@@ -148,10 +145,7 @@
}
def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_TimelineOp,
Util_SizeAwareOp,
AlwaysSpeculatable,
@@ -209,10 +203,7 @@
}
def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_TimelineOp,
Util_SizeAwareOp,
MemoryEffects<[MemFree]>,
@@ -645,10 +636,7 @@
def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [
AttrSizedOperandSegments,
AllTypesMatch<["results"]>,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -702,10 +690,7 @@
}
def Stream_ParameterReadOp : Stream_Op<"parameter.read", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -757,10 +742,7 @@
}
def Stream_ParameterWriteOp : Stream_Op<"parameter.write", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -813,10 +795,7 @@
def Stream_ParameterGatherOp : Stream_Op<"parameter.gather", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -872,10 +851,7 @@
def Stream_ParameterScatterOp : Stream_Op<"parameter.scatter", [
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -982,10 +958,7 @@
}
def Stream_FileReadOp : Stream_Op<"file.read", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -1040,10 +1013,7 @@
}
def Stream_FileWriteOp : Stream_Op<"file.write", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_CmdPhaseOp,
Stream_TimelineOp,
Util_SizeAwareOp,
@@ -1550,7 +1520,7 @@
let assemblyFormat = [{
(`on` `(` $affinity^ `)`)?
- $value `,` $target `[` $start_indices `for` $lengths `]` `:`
+ $value `,` $target (`[` $start_indices `for` $lengths^ `]`)? `:`
type($value)
`->`
$target_encoding (`` `{` $target_encoding_dims^ `}`)?
@@ -1772,21 +1742,18 @@
} // OpGroupTensorOps
//===----------------------------------------------------------------------===//
-// Resource transfer ops
+// Async (stream.async*) ops
//===----------------------------------------------------------------------===//
-def OpGroupResourceTransferOps : OpDocGroup {
- let summary = "Resource transfer ops";
+def OpGroupAsyncOps : OpDocGroup {
+ let summary = "Async ops";
let description = "";
}
-let opDocGroup = OpGroupResourceTransferOps in {
+let opDocGroup = OpGroupAsyncOps in {
def Stream_AsyncAllocaOp : Stream_Op<"async.alloca", [
- DeclareOpInterfaceMethods<Stream_AffinityOp, [
- "getAffinity",
- "setAffinity",
- ]>,
+ Stream_AffinityOp,
Stream_AsyncPhaseOp,
DeclareOpInterfaceMethods<Stream_StreamableOp, [
"isMetadata",
@@ -2250,7 +2217,10 @@
}
def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
- Stream_AffinityOp,
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinityAttr",
+ "setAffinityAttr",
+ ]>,
Stream_AsyncPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_AsyncAccessOp, [
@@ -2460,7 +2430,7 @@
let hasCanonicalizer = 1;
}
-} // OpGroupResourceTransferOps
+} // OpGroupAsyncOps
//===----------------------------------------------------------------------===//
// Async control flow ops
@@ -2855,15 +2825,15 @@
} // OpGroupAsyncControlFlowOps
//===----------------------------------------------------------------------===//
-// Explicit command ops
+// Explicit command (stream.cmd.*) ops
//===----------------------------------------------------------------------===//
-def OpGroupExplicitCommandOps : OpDocGroup {
+def OpGroupCmdOps : OpDocGroup {
let summary = "Explicit command ops";
let description = "";
}
-let opDocGroup = OpGroupExplicitCommandOps in {
+let opDocGroup = OpGroupCmdOps in {
def Stream_CmdFlushOp : Stream_Op<"cmd.flush", [
Stream_CmdPhaseOp,
@@ -3531,7 +3501,7 @@
let hasCanonicalizer = 1;
}
-} // OpGroupExplicitCommandOps
+} // OpGroupCmdOps
//===----------------------------------------------------------------------===//
// Synchronization ops
@@ -3753,7 +3723,6 @@
def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
AttrSizedOperandSegments,
- Stream_AffinityOp,
Stream_TimelineOp,
Util_SizeAwareOp,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
@@ -3777,8 +3746,7 @@
Stream_StagingResource,
]>>:$resource_operands,
Variadic<Stream_Size>:$resource_operand_sizes,
- Stream_Timepoint:$await_timepoint,
- OptionalAttr<Stream_AffinityAttr>:$affinity
+ Stream_Timepoint:$await_timepoint
);
let results = (outs
Variadic<AnyTypeOf<[
@@ -3788,7 +3756,6 @@
);
let assemblyFormat = [{
- (`on` `(` $affinity^ `)`)?
$await_timepoint `=` `` `>`
$resource_operands `:`
custom<ShapedTypeList>(type($resource_operands),
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 7ce79c9..82b8609 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -274,7 +274,7 @@
return attr;
// See if the affinity specified provides a resource configuration.
if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
- auto affinityAttr = affinityOp.getAffinity();
+ auto affinityAttr = affinityOp.getAffinityAttr();
if (affinityAttr) {
auto attr = affinityAttr.getResourceConfigAttr();
if (attr)
@@ -335,20 +335,40 @@
// #stream.affinity
//===----------------------------------------------------------------------===//
-AffinityAttr AffinityAttr::lookup(Operation *op) {
- auto attrId = StringAttr::get(op->getContext(), "stream.affinity");
- while (op) {
- if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
- auto affinity = affinityOp.getAffinity();
- if (affinity)
+// static
+AffinityAttr AffinityAttr::lookup(Operation *fromOp) {
+ auto attrId = StringAttr::get(fromOp->getContext(), "stream.affinity");
+ while (fromOp) {
+ if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(fromOp)) {
+ if (auto affinity = affinityOp.getAffinityAttr()) {
return affinity;
+ }
}
- auto attr = op->getAttrOfType<AffinityAttr>(attrId);
- if (attr)
+ if (auto attr = fromOp->getAttrOfType<AffinityAttr>(attrId)) {
return attr;
- op = op->getParentOp();
+ }
+ fromOp = fromOp->getParentOp();
}
- return {}; // No affinity found; let caller decide what to do.
+ // No affinity found; let caller decide what to do.
+ return {};
+}
+
+// static
+AffinityAttr AffinityAttr::lookupOrDefault(Operation *fromOp) {
+ if (auto affinityAttr = AffinityAttr::lookup(fromOp)) {
+ return affinityAttr; // found a specified affinity
+ }
+ auto attrId =
+ StringAttr::get(fromOp->getContext(), "stream.affinity.default");
+ while (fromOp) {
+ if (auto affinityAttr =
+ fromOp->getAttrOfType<IREE::Stream::AffinityAttr>(attrId)) {
+ return affinityAttr;
+ }
+ fromOp = fromOp->getParentOp();
+ }
+ // No affinity or default found; let caller decide what to do.
+ return {};
}
// static
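The difference between `AffinityAttr::lookup` and the new `AffinityAttr::lookupOrDefault` above is easiest to see on a toy op tree. The sketch below assumes a hypothetical `OpSketch` node in place of `mlir::Operation`, with `interfaceAffinity` standing in for `AffinityOpInterface::getAffinityAttr()` and `attrs` for discardable attributes. It models one ordering detail of the real code: an explicit `stream.affinity` anywhere in the ancestor chain wins over a `stream.affinity.default` on a closer op, because the default walk only starts after the full explicit walk fails.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Hypothetical minimal op tree; not mlir::Operation.
struct OpSketch {
  OpSketch *parent = nullptr;
  std::optional<std::string> interfaceAffinity;  // AffinityOpInterface result
  std::map<std::string, std::string> attrs;      // discardable attributes
};

// Walks from |op| to the root returning the first explicit affinity.
std::optional<std::string> lookupAffinity(const OpSketch *op) {
  for (; op; op = op->parent) {
    if (op->interfaceAffinity) {
      return op->interfaceAffinity;
    }
    auto it = op->attrs.find("stream.affinity");
    if (it != op->attrs.end()) {
      return it->second;
    }
  }
  return std::nullopt;  // let the caller decide
}

// Falls back to the nearest stream.affinity.default only if no explicit
// affinity exists anywhere on the path to the root.
std::optional<std::string> lookupAffinityOrDefault(const OpSketch *op) {
  if (auto affinity = lookupAffinity(op)) {
    return affinity;
  }
  for (; op; op = op->parent) {
    auto it = op->attrs.find("stream.affinity.default");
    if (it != op->attrs.end()) {
      return it->second;
    }
  }
  return std::nullopt;
}
```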
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index 42b8424..d69e226 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -69,9 +69,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamAttrInterfaces.h.inc" // IWYU pragma: export
-namespace mlir::iree_compiler::IREE::Stream {
#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.h.inc" // IWYU pragma: export
-} // namespace mlir::iree_compiler::IREE::Stream
// clang-format off: must be included after all LLVM/MLIR headers.
#define GET_TYPEDEF_CLASSES
@@ -99,8 +97,12 @@
const AsyncAccessRange &rhs);
};
+} // namespace mlir::iree_compiler::IREE::Stream
+
#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export
+namespace mlir::iree_compiler::IREE::Stream {
+
//===----------------------------------------------------------------------===//
// custom<ParameterReference>($scope, $key)
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
index 50c7c26..2d33e5e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
@@ -82,6 +82,8 @@
// This covers all_gather, all_reduce, and reduce_scatter variants.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @asyncCollectiveAllGather
util.func private @asyncCollectiveAllGather(
// CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel,
@@ -95,8 +97,8 @@
%recv = stream.async.alloca : !stream.resource<*>{%recv_size}
// CHECK: = stream.async.collective<all_gather : f32>[%[[COUNT]]]
%0 = stream.async.collective<all_gather : f32>[%count]
- // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]])
- on(#hal.affinity.queue<[0]>) channel(%channel)
+ // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]])
+ on(#hal.device.affinity<@device>) channel(%channel)
// CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]],
%send[%c0 to %send_size for %send_size],
// CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] :
@@ -110,6 +112,8 @@
// This covers broadcast and reduce variants.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @asyncCollectiveBroadcast
util.func private @asyncCollectiveBroadcast(
// CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel,
@@ -125,8 +129,8 @@
%recv = stream.async.alloca : !stream.resource<*>{%recv_size}
// CHECK: = stream.async.collective<broadcast : f32>[%[[COUNT]]]
%0 = stream.async.collective<broadcast : f32>[%count]
- // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]]) source(%[[RANK]])
- on(#hal.affinity.queue<[0]>) channel(%channel) source(%rank)
+ // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]]) source(%[[RANK]])
+ on(#hal.device.affinity<@device>) channel(%channel) source(%rank)
// CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]],
%send[%c0 to %send_size for %send_size],
// CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] :
@@ -147,10 +151,12 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @asyncTransferAffinities
util.func private @asyncTransferAffinities(%arg0: !stream.resource<constant>, %arg1: index) -> !stream.resource<constant> {
- // CHECK: = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource<constant>{%arg1}
- %0 = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource<constant>{%arg1}
+ // CHECK: = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource<constant>{%arg1}
+ %0 = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource<constant>{%arg1}
util.return %0 : !stream.resource<constant>
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir
index 486a03f..a465546 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir
@@ -1,10 +1,12 @@
// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @channel_create
// CHECK-SAME: (%[[RANK:.+]]: index, %[[COUNT:.+]]: index)
util.func private @channel_create(%rank: index, %count: index) {
- // CHECK: %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel
- %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%rank) count(%count) : !stream.channel
+ // CHECK: %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel
+ %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%rank) count(%count) : !stream.channel
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
index ab523ec..950643a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
@@ -1,12 +1,14 @@
// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @context_resolve
util.func private @context_resolve() {
// CHECK: = stream.context.resolve : !hal.allocator
%allocator = stream.context.resolve : !hal.allocator
- // CHECK: = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
- %device1, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
- // CHECK: = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
- %device0, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
+ // CHECK: = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64
+ %device1, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64
+ // CHECK: = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64
+ %device0, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp
new file mode 100644
index 0000000..62b9db2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp
@@ -0,0 +1,127 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+#define GEN_PASS_DEF_ANNOTATEAFFINITIESPASS
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-stream-annotate-affinities
+//===----------------------------------------------------------------------===//
+
+static void annotateOp(Operation *op,
+ ArrayRef<IREE::Stream::AffinityAttr> affinities) {
+ auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op);
+ if (!affinityOp || !affinityOp.requiresAffinity()) {
+ return;
+ }
+ if (!affinities.empty()) {
+ op->setAttr("stream.affinities",
+ ArrayAttr::get(op->getContext(),
+ llvm::to_vector_of<Attribute>(affinities)));
+ }
+}
+
+static void annotateGlobalOp(IREE::Util::GlobalOpInterface globalOp,
+ AffinityAnalysis &affinityAnalysis) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(globalOp.getGlobalType())) {
+ return;
+ }
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupGlobalAffinity(globalOp, affinities)) {
+ annotateOp(globalOp, affinities);
+ }
+}
+
+static void annotateOperandsAndResults(Operation *op,
+ AffinityAnalysis &affinityAnalysis) {
+ auto emptyArray = ArrayAttr::get(op->getContext(), {});
+ SmallVector<Attribute> operandAttrs;
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupResourceAffinity(operand, affinities)) {
+ operandAttrs.push_back(ArrayAttr::get(
+ op->getContext(), llvm::to_vector_of<Attribute>(affinities)));
+ } else {
+ operandAttrs.push_back(emptyArray);
+ }
+ }
+ }
+ SmallVector<Attribute> resultAttrs;
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupResourceAffinity(result, affinities)) {
+ resultAttrs.push_back(ArrayAttr::get(
+ op->getContext(), llvm::to_vector_of<Attribute>(affinities)));
+ } else {
+ resultAttrs.push_back(emptyArray);
+ }
+ }
+ }
+ if (!operandAttrs.empty()) {
+ op->setAttr("stream.affinities.operands",
+ ArrayAttr::get(op->getContext(), operandAttrs));
+ }
+ if (!resultAttrs.empty()) {
+ op->setAttr("stream.affinities.results",
+ ArrayAttr::get(op->getContext(), resultAttrs));
+ }
+}
+
+static void annotateFuncOp(FunctionOpInterface funcOp,
+ AffinityAnalysis &affinityAnalysis) {
+ funcOp.walk([&](Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupExecutionAffinity(op, affinities)) {
+ annotateOp(op, affinities);
+ }
+ annotateOperandsAndResults(op, affinityAnalysis);
+ });
+}
+
+struct AnnotateAffinitiesPass
+ : public IREE::Stream::impl::AnnotateAffinitiesPassBase<
+ AnnotateAffinitiesPass> {
+ void runOnOperation() override {
+ // Run affinity analysis on the whole module.
+ AffinityAnalysis affinityAnalysis(getOperation());
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // Annotate all ops with derived affinities.
+ for (auto &op : getOperation().getOps()) {
+ if (op.hasTrait<OpTrait::IREE::Util::ObjectLike>())
+ continue;
+ if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
+ annotateGlobalOp(globalOp, affinityAnalysis);
+ } else if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ annotateFuncOp(funcOp, affinityAnalysis);
+ }
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Stream
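The attribute layout that `annotateOperandsAndResults` writes into `stream.affinities.operands` and `stream.affinities.results` can be sketched as follows. This assumes a hypothetical `ValueSketch` that records whether a value has an `AffinityTypeInterface` type and what the analysis returned (`std::nullopt` standing in for a failed `tryLookupResourceAffinity`). One consequence the sketch makes visible: entries are appended only for affinity-typed values, so entry *i* corresponds to the *i*-th affinity-typed operand, not necessarily to operand *i*.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

using AffinityList = std::vector<std::string>;

// Hypothetical value model; not mlir::Value.
struct ValueSketch {
  bool isAffinityTyped;                          // isa<AffinityTypeInterface>
  std::optional<AffinityList> analyzedAffinities;  // nullopt: lookup failed
};

// Mirrors the loop shape in annotateOperandsAndResults: one entry per
// affinity-typed value, an empty list when the analysis could not decide,
// and no entry at all for values of other types.
std::vector<AffinityList> buildAffinityArray(
    const std::vector<ValueSketch> &values) {
  std::vector<AffinityList> out;
  for (const auto &value : values) {
    if (!value.isAffinityTyped) {
      continue;
    }
    out.push_back(value.analyzedAffinities.value_or(AffinityList{}));
  }
  return out;
}
```

As in the pass, a caller would only attach the attribute when the returned array is non-empty.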
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
index 6471943..d2f326c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
@@ -15,6 +15,7 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
+ "AnnotateAffinities.cpp",
"AnnotateDispatchArguments.cpp",
"ConvertToStream.cpp",
"DumpStatistics.cpp",
@@ -37,6 +38,7 @@
"ScheduleConcurrency.cpp",
"ScheduleExecution.cpp",
"SpecializeDispatches.cpp",
+ "VerifyAffinities.cpp",
"VerifyAsyncAccessRanges.cpp",
"VerifyLowerings.cpp",
],
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 4f1a114..5eb3d27 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@
HDRS
"Passes.h"
SRCS
+ "AnnotateAffinities.cpp"
"AnnotateDispatchArguments.cpp"
"ConvertToStream.cpp"
"DumpStatistics.cpp"
@@ -38,6 +39,7 @@
"ScheduleConcurrency.cpp"
"ScheduleExecution.cpp"
"SpecializeDispatches.cpp"
+ "VerifyAffinities.cpp"
"VerifyAsyncAccessRanges.cpp"
"VerifyLowerings.cpp"
DEPS
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index 11873a2..91f5c0f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h"
#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h"
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
@@ -39,79 +40,6 @@
namespace {
-// Builds a stream.tensor.import op that imports an external tensor value into
-// a stream resource. |consumingOps| will be populated with all ops that consume
-// the original |sourceTensor| and that should not be replaced with the returned
-// value.
-static Value buildTensorImportOp(Location loc, Value sourceTensor,
- Type targetType,
- SmallPtrSetImpl<Operation *> &consumingOps,
- OpBuilder &builder) {
- // Gather dynamic dimensions from the input value.
- auto dynamicDims =
- IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder);
-
- // Compute the size of the tensor once in the stream resource.
- // This may differ from the external encoding of the tensor as imports are
- // a transfer operation that may need to reformat the tensor.
- auto encodingAttr = TypeAttr::get(sourceTensor.getType());
- auto resultSize = builder.createOrFold<IREE::Stream::TensorSizeOfOp>(
- loc, builder.getIndexType(), encodingAttr, dynamicDims,
- /*affinity=*/nullptr);
-
- // Associate the external SSA value, encoding, and shape information with the
- // stream resource. When lowering we'll then have all the metadata required
- // even after erasing it all on the resource.
- auto externalType = builder.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- auto importOp = builder.create<IREE::Stream::TensorImportOp>(
- loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
- /*affinity=*/nullptr);
- consumingOps.insert(importOp);
-
- // If needed insert a transfer to the target lifetime.
- Value result = importOp.getResult();
- if (targetType != externalType) {
- result = builder
- .create<IREE::Stream::AsyncTransferOp>(
- loc, targetType, result, resultSize, resultSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr)
- .getResult();
- }
-
- auto castOp = builder.create<mlir::UnrealizedConversionCastOp>(
- loc, sourceTensor.getType(), ValueRange{result, resultSize});
- consumingOps.insert(castOp);
- return castOp.getResult(0);
-}
-
-// Builds a stream.tensor.export op that exports a stream resource into an
-// external tensor value.
-static Value buildTensorExportOp(Location loc, Value sourceValue,
- TensorType targetType, ValueRange dynamicDims,
- OpBuilder &builder) {
- auto source = consumeTensorOperand(loc, sourceValue, builder);
-
- // If needed insert a transfer to external resource lifetime.
- auto externalType = builder.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- if (source.resource.getType() != externalType) {
- source.resource = builder.create<IREE::Stream::AsyncTransferOp>(
- loc, externalType, source.resource, source.resourceSize,
- source.resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
- }
-
- // Associate the stream resource and external encoding and shape information.
- auto newOp = builder.create<IREE::Stream::TensorExportOp>(
- loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims,
- source.resourceSize,
- /*affinity=*/nullptr);
- return newOp.getResult();
-}
-
// Returns true if |op| has tensor I/O that is not yet imported/exported using
// the stream ops that capture encodings and shapes.
static bool doesOperationNeedWrapping(Operation *op) {
@@ -123,8 +51,9 @@
operand.getDefiningOp());
}) ||
llvm::any_of(op->getResults(), [](Value result) {
- if (!isa<TensorType>(result.getType()))
+ if (!isa<TensorType>(result.getType())) {
return false;
+ }
return !llvm::all_of(result.getUsers(),
llvm::IsaPred<TensorImportOp>);
});
@@ -133,13 +62,19 @@
// Fallback handler for unknown ops taking/returning tensors that need to be
// marshaled into/out of stream resource types.
struct GenericResourcePattern : public ConversionPattern {
- GenericResourcePattern(MLIRContext *context, TypeConverter &converter)
- : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+ GenericResourcePattern(MLIRContext *context, TypeConverter &converter,
+ IREE::Stream::AffinityAnalysis *affinityAnalysis)
+ : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context),
+ affinityAnalysis(affinityAnalysis) {}
+
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!doesOperationNeedWrapping(op))
+ if (!doesOperationNeedWrapping(op)) {
return failure();
+ }
+
+ auto executionAffinityAttr = affinityAnalysis->inferExecutionAffinity(op);
// Export resources into tensor operands for the op to consume.
SmallVector<Value> newOperands;
@@ -154,10 +89,14 @@
auto tensorType = dyn_cast<TensorType>(oldOperand.getType());
assert(tensorType && "must have a tensor type to map to a resource");
+ auto exportAffinityAttr =
+ affinityAnalysis->lookupResourceAffinity(oldOperand);
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
op->getLoc(), oldOperand, rewriter);
newOperands.push_back(buildTensorExportOp(
- op->getLoc(), newOperand, tensorType, dynamicDims, rewriter));
+ op->getLoc(), oldOperand, newOperand, tensorType, dynamicDims,
+ exportAffinityAttr ? exportAffinityAttr : executionAffinityAttr,
+ rewriter));
}
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
@@ -165,46 +104,112 @@
rewriter.setInsertionPointAfter(op);
for (auto result : op->getResults()) {
auto tensorType = dyn_cast<TensorType>(result.getType());
- if (!tensorType)
+ if (!tensorType) {
continue;
+ }
+ auto importAffinityAttr =
+ affinityAnalysis->lookupResourceAffinity(result);
auto dynamicDims =
IREE::Util::buildDynamicDimsForValue(op->getLoc(), result, rewriter);
SmallPtrSet<Operation *, 4> consumingOps;
auto importedValue = buildTensorImportOp(
op->getLoc(), result, rewriter.getType<IREE::Stream::ResourceType>(),
- consumingOps, rewriter);
+ consumingOps,
+ importAffinityAttr ? importAffinityAttr : executionAffinityAttr,
+ rewriter);
result.replaceAllUsesExcept(importedValue, consumingOps);
}
return success();
}
-};
-namespace {
-struct OptimizationBarrierOpConversion
- : public OpConversionPattern<IREE::Util::OptimizationBarrierOp> {
- using OpConversionPattern<
- IREE::Util::OptimizationBarrierOp>::OpConversionPattern;
+ // Builds a stream.tensor.export op that exports a stream resource into an
+ // external tensor value.
+ Value buildTensorExportOp(Location loc, Value originalValue,
+ Value convertedValue, TensorType targetType,
+ ValueRange dynamicDims,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ OpBuilder &builder) const {
+ auto source =
+ transferTensorOperand(loc, originalValue, convertedValue,
+ executionAffinityAttr, affinityAnalysis, builder);
- LogicalResult
- matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> newOperands;
- for (Value v : adaptor.getOperands()) {
- if (isa<TensorType>(v.getType())) {
- newOperands.push_back(
- consumeTensorOperand(op.getLoc(), v, rewriter).resource);
- } else {
- newOperands.push_back(v);
- }
+ // If needed insert a transfer to external resource lifetime.
+ auto externalType = builder.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ if (source.resource.getType() != externalType) {
+ source.resource = builder.create<IREE::Stream::AsyncTransferOp>(
+ loc, externalType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/source.affinity,
+ /*result_affinity=*/executionAffinityAttr);
}
- rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op,
- newOperands);
- return success();
+
+ // Associate the stream resource and external encoding and shape
+ // information.
+ auto newOp = builder.create<IREE::Stream::TensorExportOp>(
+ loc, targetType, source.resource, TypeAttr::get(targetType),
+ dynamicDims, source.resourceSize, executionAffinityAttr);
+ return newOp.getResult();
}
+
+ // Builds a stream.tensor.import op that imports an external tensor value into
+ // a stream resource. |consumingOps| will be populated with all ops that
+ // consume the original |sourceTensor| and that should not be replaced with
+ // the returned value.
+ Value buildTensorImportOp(Location loc, Value sourceTensor, Type targetType,
+ SmallPtrSetImpl<Operation *> &consumingOps,
+ IREE::Stream::AffinityAttr executionAffinityAttr,
+ OpBuilder &builder) const {
+ // Gather dynamic dimensions from the input value.
+ auto dynamicDims =
+ IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder);
+
+ // Compute the size of the tensor once in the stream resource.
+ // This may differ from the external encoding of the tensor as imports are
+ // a transfer operation that may need to reformat the tensor.
+ auto encodingAttr = TypeAttr::get(sourceTensor.getType());
+ Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>(
+ loc, builder.getIndexType(), encodingAttr, dynamicDims,
+ executionAffinityAttr);
+
+ // Associate the external SSA value, encoding, and shape information with
+ // the stream resource. When lowering we'll then have all the metadata
+ // required even after erasing it all on the resource.
+ auto externalType = builder.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto importOp = builder.create<IREE::Stream::TensorImportOp>(
+ loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
+ executionAffinityAttr);
+ consumingOps.insert(importOp);
+
+ // If needed, insert a transfer to the target lifetime.
+ Value result = importOp.getResult();
+ if (targetType != externalType) {
+ result = builder
+ .create<IREE::Stream::AsyncTransferOp>(
+ loc, targetType, result, resultSize, resultSize,
+ /*source_affinity=*/executionAffinityAttr,
+ /*result_affinity=*/executionAffinityAttr)
+ .getResult();
+ }
+
+ auto castOp = builder.create<mlir::UnrealizedConversionCastOp>(
+ loc, sourceTensor.getType(), ValueRange{result, resultSize});
+ consumingOps.insert(castOp);
+ return castOp.getResult(0);
+ }
+
+ IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr;
};
-} // namespace
+
+static void stripAffinityAttrs(ModuleOp moduleOp) {
+ auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity");
+ for (auto &op : moduleOp.getOps()) {
+ op.removeDiscardableAttr(affinityName);
+ }
+}
//===----------------------------------------------------------------------===//
// --iree-stream-conversion
@@ -215,6 +220,13 @@
void runOnOperation() override {
auto *context = &getContext();
+ // Run affinity analysis so that the required producer/consumer affinities
+ // for all SSA values we'll use during conversion are available.
+ AffinityAnalysis affinityAnalysis(getOperation());
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+
TypeConverter typeConverter;
ConversionTarget conversionTarget(getContext());
RewritePatternSet patterns(&getContext());
@@ -227,10 +239,9 @@
// Allow unknown types to pass through; these come from custom dialects that
// may be mixed into the IR we are converting.
typeConverter.addConversion([=](Type type) -> Type {
- // convert flow.channel into stream.channel
- if (llvm::isa<IREE::Flow::ChannelType>(type))
+ if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
-
+ }
return !llvm::isa<TensorType>(type) ? type : Type{};
});
@@ -267,21 +278,20 @@
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
patterns);
- populateUtilToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
+ populateUtilToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
- populateStandardToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateFlowToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
- populateHALToStreamConversionPatterns(context, conversionTarget,
- typeConverter, patterns);
+ populateStandardToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
+ populateFlowToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
+ populateHALToStreamConversionPatterns(
+ context, conversionTarget, typeConverter, &affinityAnalysis, patterns);
conversionTarget.markUnknownOpDynamicallyLegal(
[&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); });
- patterns.insert<GenericResourcePattern>(context, typeConverter);
- patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context,
- /*benefit=*/2);
+ patterns.insert<GenericResourcePattern>(context, typeConverter,
+ &affinityAnalysis);
// NOTE: we allow ops that we don't know about to allow custom dialects
// that don't need anything Stream-specific to pass through.
@@ -290,6 +300,9 @@
std::move(patterns)))) {
return signalPassFailure();
}
+
+ // Strip affinity attributes as they are no longer required.
+ stripAffinityAttrs(getOperation());
}
};
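Both `buildTensorImportOp` and `buildTensorExportOp` above follow the same rule. The import or export itself works on an `External`-lifetime resource, and a `stream.async.transfer` is added only when the other side needs a different lifetime. The sketch below models that decision for the import path in plain C++. It uses a hypothetical `Lifetime` enum and a hypothetical `importOpSequence` helper, neither of which is IREE's API, and returns the op names the lowering would emit:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for IREE::Stream::Lifetime (illustrative only).
enum class Lifetime { Unknown, External, Staging, Transient, Variable, Constant };

// Hypothetical helper: lists the ops an import lowering shaped like
// buildTensorImportOp emits for a given target resource lifetime.
std::vector<std::string> importOpSequence(Lifetime targetLifetime) {
  // The size is computed first; the import itself always yields an
  // External-lifetime resource.
  std::vector<std::string> ops = {"stream.tensor.sizeof",
                                  "stream.tensor.import"};
  // Only pay for a transfer when the consumer expects another lifetime.
  if (targetLifetime != Lifetime::External) {
    ops.push_back("stream.async.transfer");
  }
  // The cast back to the original tensor type is always created.
  ops.push_back("builtin.unrealized_conversion_cast");
  return ops;
}
```

With this model, `importOpSequence(Lifetime::External)` yields three ops and `importOpSequence(Lifetime::Transient)` yields four, with the transfer in third position.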
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
index 7e9c231..6732b97 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
@@ -399,7 +399,8 @@
if (sourceType != targetType &&
sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
LLVM_DEBUG(llvm::dbgs()
- << " - clone source is a constant; cannot elide\n");
+ << " - clone is a resource lifetime cast (" << sourceType
+ << " to " << targetType << "); cannot elide\n");
return false;
}
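The updated debug message describes what this check actually rejects. A clone whose source is a `Constant` resource and whose result type differs is a lifetime cast, not a redundant copy. The sketch below models only this one condition; the real pass performs other checks as well. It assumes a hypothetical `Lifetime` enum and represents each resource type by its lifetime alone:

```cpp
#include <cassert>

// Hypothetical stand-in for IREE::Stream::Lifetime (illustrative only).
enum class Lifetime { Unknown, External, Staging, Transient, Variable, Constant };

// Returns true when a clone changes the lifetime of a Constant resource.
// Such a clone is a lifetime cast and must be kept rather than elided.
bool isConstantLifetimeCast(Lifetime sourceLifetime, Lifetime targetLifetime) {
  return sourceLifetime != targetLifetime &&
         sourceLifetime == Lifetime::Constant;
}
```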
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
index f3cd827..2db2cd6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
@@ -33,7 +33,9 @@
// Emplacement
//===----------------------------------------------------------------------===//
-static void replaceUsesAndTransfer(Value oldValue, Value newValue) {
+static void
+replaceUsesAndTransfer(Value oldValue, Value newValue,
+ IREE::Stream::AffinityAttr usageAffinityAttr) {
assert(isa<IREE::Stream::ResourceType>(oldValue.getType()));
assert(isa<IREE::Stream::ResourceType>(newValue.getType()));
if (oldValue.getType() == newValue.getType()) {
@@ -44,8 +46,8 @@
builder.setInsertionPointAfterValue(newValue);
Value newValueSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
newValue.getLoc(), newValue, builder);
- IREE::Stream::AffinityAttr sourceAffinity;
- IREE::Stream::AffinityAttr resultAffinity;
+ IREE::Stream::AffinityAttr sourceAffinity = usageAffinityAttr;
+ IREE::Stream::AffinityAttr resultAffinity = usageAffinityAttr;
Value transferValue = builder.create<IREE::Stream::AsyncTransferOp>(
newValue.getLoc(), oldValue.getType(), newValue, newValueSize,
newValueSize, sourceAffinity, resultAffinity);
@@ -74,7 +76,7 @@
break;
}
- // Find potential.
+ // Find potential update to place the dispatch result into.
Value targetResource;
Value targetResourceSize;
Value targetOffset;
@@ -82,12 +84,22 @@
Value targetLength;
Value targetResult;
Value targetResultSize;
+ Attribute targetAffinityAttr;
Operation *userOp = *result.user_begin();
if (auto updateOp = dyn_cast<IREE::Stream::AsyncUpdateOp>(userOp)) {
if (updateOp.getUpdate() != result) {
// TODO(#14566): continue if sparse emplacement on multiple results.
break;
}
+
+ // Currently only allow exactly matching affinities.
+ // TODO(multi-device): memory compatibility - if compatible then allow.
+ if (updateOp.getAffinityAttr() != dispatchOp.getAffinityAttr()) {
+ continue;
+ }
+
+ // Try to move all SSA values required into the appropriate place.
+ // TODO(benvanik): undo this if there's a failure (or record/roll-back).
if (!IREE::Util::tryMoveProducerBefore(updateOp.getUpdateSize(),
dispatchOp) ||
!IREE::Util::tryMoveProducerBefore(updateOp.getTargetSize(),
@@ -102,6 +114,7 @@
// TODO(#14566): continue if sparse emplacement on multiple results.
break;
}
+
targetResource = updateOp.getTarget();
if (targetResource.getDefiningOp() == dispatchOp) {
// NOTE: we may have already replaced the update target with one of our
@@ -115,6 +128,7 @@
targetLength = updateOp.getUpdateSize();
targetResult = updateOp.getResult();
targetResultSize = updateOp.getTargetSize();
+ targetAffinityAttr = updateOp.getAffinityAttr();
}
if (!targetResource) {
// TODO(#14566): continue if sparse emplacement on multiple results.
@@ -136,7 +150,7 @@
dispatchOp.getResultSizesMutable().assign(resultSizes);
// Replace users with the result of the dispatch op.
- replaceUsesAndTransfer(targetResult, result);
+ replaceUsesAndTransfer(targetResult, result, dispatchOp.getAffinityAttr());
userOp->erase();
didChange = true;
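The new guard in the emplacement loop folds a dispatch result into a `stream.async.update` only when both ops carry exactly the same affinity. Memory compatibility across devices is left as a TODO. The sketch below models that rule with a hypothetical `Affinity` type, in which an empty `std::optional` stands in for a null `AffinityAttr`. As with MLIR attributes, two null values compare equal:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical model of an affinity: a device name, or std::nullopt when no
// affinity is assigned (standing in for a null AffinityAttr).
using Affinity = std::optional<std::string>;

// Mirrors the exact-match rule: only emplace when the update and the dispatch
// agree on affinity. Compatible-but-different devices are not considered.
bool canEmplaceIntoUpdate(const Affinity &updateAffinity,
                          const Affinity &dispatchAffinity) {
  return updateAffinity == dispatchAffinity;
}
```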
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
index 34c0ef8..5b7d3a9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
@@ -103,7 +103,7 @@
IREE::Stream::AffinityAttr affinity;
if (auto affinityOp =
dyn_cast<IREE::Stream::AffinityOpInterface>(tiedOp.getOperation())) {
- affinity = affinityOp.getAffinity();
+ affinity = affinityOp.getAffinityAttr();
}
// Clones each operand that is tied to a result and it may be required.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index a0861e7..31a5bb6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -16,6 +16,12 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
+static llvm::cl::opt<bool> clAnnotateInputAffinities(
+ "iree-stream-annotate-input-affinities",
+ llvm::cl::desc("Annotates all tensor/resource affinities on the input to "
+ "the pipeline for debugging."),
+ llvm::cl::init(false));
+
namespace mlir::iree_compiler::IREE::Stream {
using FunctionLikeNest =
@@ -68,6 +74,13 @@
// Conversion
//----------------------------------------------------------------------------
+ // Annotate all ops/resources with the analyzed affinities.
+ // This should have no behavioral changes during conversion but allows for
+ // debugging of analysis errors in end-user tooling.
+ if (clAnnotateInputAffinities) {
+ passManager.addPass(IREE::Stream::createAnnotateAffinitiesPass());
+ }
+
// Converts from all input dialects into various levels of the stream dialect.
// Tensor-like things go to stream.tensor.* ops while lower level buffer-like
// things will go to stream.async.* ops.
@@ -81,6 +94,9 @@
// Constant/variable optimization
//----------------------------------------------------------------------------
+ // Run inlining after having baked out affinities.
+ passManager.addPass(mlir::createInlinerPass());
+
// Cleanup globals that were created during conversion.
addCleanupPatterns(passManager);
@@ -95,6 +111,16 @@
// TODO(benvanik): compute affinities for executables.
// TODO(benvanik): annotate all dispatches with preferred executable affinity.
// TODO(benvanik): DFA to specify all value affinities and pin dispatches.
+
+ // TODO(multi-device): it's really nice to be able to verify here but it
+ // prevents compiling to stream without devices specified or continuation at
+ // various phases. It'd be nice to find a way to enable this when the user
+ // expects it to work and otherwise not.
+ //
+ // Verify that all ops that may require affinities have them assigned or
+ // available (on a parent scope, etc). This allows subsequent passes to trust
+ // that an affinity lookup will always return a valid affinity.
+ // passManager.addPass(IREE::Stream::createVerifyAffinitiesPass());
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 83f1d0f..f5ee39f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -457,6 +457,11 @@
// Diagnostics
//===----------------------------------------------------------------------===//
+def AnnotateAffinitiesPass :
+ Pass<"iree-stream-annotate-affinities", "mlir::ModuleOp"> {
+ let summary = "Annotates affinities on all ops for debugging.";
+}
+
def DumpStatisticsPass :
Pass<"iree-stream-dump-statistics", "mlir::ModuleOp"> {
let summary = "Dumps stream dialect usage information to a file.";
@@ -486,6 +491,11 @@
let summary = "Verifies that input dialects are supported by the streams dialect.";
}
+def VerifyAffinitiesPass :
+ Pass<"iree-stream-verify-affinities", "mlir::ModuleOp"> {
+ let summary = "Verifies that all operations have affinities assigned (directly or indirectly).";
+}
+
def VerifyLoweringToTensorsPass :
Pass<"iree-stream-verify-lowering-to-tensors", "mlir::ModuleOp"> {
let summary = "Verifies that input dialects are converted to stream.tensor.* ops.";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
index bc73616..02c2bb0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp
@@ -65,7 +65,7 @@
// Returns either the affinity of |op| or nullptr.
static IREE::Stream::AffinityAttr getOpAffinity(Operation *op) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
- return affinityOp.getAffinity();
+ return affinityOp.getAffinityAttr();
}
return {};
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 1bec564..c8510a6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -668,7 +668,8 @@
return llvm::cast<IREE::Stream::ResourceType>(value.getType())
.getLifetime() == IREE::Stream::Lifetime::Staging;
};
- auto currentAffinityAttr = IREE::Stream::AffinityAttr::lookup(asyncOp);
+ auto currentAffinityAttr =
+ IREE::Stream::AffinityAttr::lookupOrDefault(asyncOp);
bool transferIn = asyncOp.getSourceAffinityAttr() != currentAffinityAttr ||
isStaging(asyncOp.getSource());
bool transferOut = asyncOp.getResultAffinityAttr() != currentAffinityAttr ||
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
index f8278ff..c850c3b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
@@ -152,8 +152,8 @@
// want to preserve those as long as possible.
if (auto affinityOp =
dyn_cast<IREE::Stream::AffinityOpInterface>(clonedOp)) {
- if (affinityOp.getAffinity() == partition->affinity) {
- affinityOp.setAffinity(nullptr);
+ if (affinityOp.getAffinityAttr() == partition->affinity) {
+ affinityOp.setAffinityAttr(nullptr);
}
}
@@ -275,9 +275,6 @@
auto awaitOp = builder.create<IREE::Stream::TimepointAwaitOp>(
executeOp.getLoc(), newResult, newResultSize,
executeOp.getResultTimepoint());
- if (executeOp.getAffinity().has_value()) {
- awaitOp.setAffinityAttr(executeOp.getAffinityAttr());
- }
// Explicitly copy the Value since it is marked as const.
Value toBeDeleted = oldResult;
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp
new file mode 100644
index 0000000..7579244
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp
@@ -0,0 +1,70 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+#define GEN_PASS_DEF_VERIFYAFFINITIESPASS
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"
+
+namespace {
+
+// Verifies that |op| has an affinity assigned on itself or a parent.
+static LogicalResult
+verifyAffinityAssigned(IREE::Stream::AffinityOpInterface op) {
+ if (!op.requiresAffinity()) {
+ return success(); // does not require an affinity
+ } else if (IREE::Stream::AffinityAttr::lookupOrDefault(op)) {
+ return success(); // has an affinity
+ }
+ return op->emitOpError()
+ << "does not have an affinity assigned; ensure that the op or some "
+ "ancestor of it has a valid execution affinity assigned";
+}
+
+//===----------------------------------------------------------------------===//
+// --iree-stream-verify-affinities
+//===----------------------------------------------------------------------===//
+
+struct VerifyAffinitiesPass
+ : public IREE::Stream::impl::VerifyAffinitiesPassBase<
+ VerifyAffinitiesPass> {
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ if (moduleOp
+ .walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (isa<mlir::ModuleOp>(op)) {
+ return WalkResult::advance();
+ }
+ if (auto affinityOp =
+ dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ if (failed(verifyAffinityAssigned(affinityOp))) {
+ return WalkResult::interrupt();
+ }
+ }
+ return (op->hasTrait<OpTrait::IREE::Util::ObjectLike>() ||
+ op->hasTrait<OpTrait::SymbolTable>())
+ ? WalkResult::skip()
+ : WalkResult::advance();
+ })
+ .wasInterrupted())
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Stream
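The verifier accepts an op in two cases: the op does not require an affinity, or `lookupOrDefault` finds one on the op or an ancestor. The sketch below models that ancestor walk with a hypothetical `ToyOp` struct in place of MLIR's `Operation`. It ignores any default-affinity fallback that `lookupOrDefault` may apply, as well as the pass's skipping of nested symbol tables and object-like ops:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an operation (not the MLIR Operation API): it may
// carry an affinity and may be nested under a parent op.
struct ToyOp {
  std::optional<std::string> affinity;
  const ToyOp *parent = nullptr;
  bool requiresAffinity = true;
};

// Returns the nearest affinity found on |op| or one of its ancestors.
std::optional<std::string> lookupAffinity(const ToyOp *op) {
  for (; op != nullptr; op = op->parent) {
    if (op->affinity) {
      return op->affinity;
    }
  }
  return std::nullopt;
}

// Mirrors verifyAffinityAssigned: ops that do not require an affinity pass,
// and ops that do must find one on themselves or an ancestor.
bool isAffinityAssigned(const ToyOp &op) {
  return !op.requiresAffinity || lookupAffinity(&op).has_value();
}
```

For example, an op nested in a function annotated with `dev_a` inherits that affinity, while an unannotated top-level op that requires one fails verification.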
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
index 1f2104a..362d672 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "annotate_affinities.mlir",
"annotate_dispatch_arguments.mlir",
"convert_to_stream.mlir",
"dump_statistics.mlir",
@@ -43,6 +44,7 @@
"schedule_concurrency.mlir",
"schedule_execution.mlir",
"specialize_dispatches.mlir",
+ "verify_affinities.mlir",
"verify_async_access_ranges.mlir",
],
include = ["*.mlir"],
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 2e2294a..fe83ee6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "annotate_affinities.mlir"
"annotate_dispatch_arguments.mlir"
"convert_to_stream.mlir"
"dump_statistics.mlir"
@@ -41,6 +42,7 @@
"schedule_concurrency.mlir"
"schedule_execution.mlir"
"specialize_dispatches.mlir"
+ "verify_affinities.mlir"
"verify_async_access_ranges.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir
new file mode 100644
index 0000000..c3e1f1e
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir
@@ -0,0 +1,1549 @@
+// RUN: iree-opt --split-input-file --iree-stream-annotate-affinities %s | FileCheck %s
+
+// Tests that we can track affinity through optimization barriers. They're meant
+// to block optimization but we really can't do much if we don't track affinity.
+// We could change this in the future but tests would be harder to write and
+// there's not a lot that can be done with an unassigned resource.
+
+// CHECK-LABEL: @optimization_barrier_consumer
+util.func private @optimization_barrier_consumer() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.optimization_barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_dno = util.optimization_barrier %cst : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst_dno : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @optimization_barrier_producer
+util.func private @optimization_barrier_producer() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.optimization_barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a_dno = util.optimization_barrier %cst_a : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a_dno : tensor<1xi32>
+}
+
+// -----
+
+// Tests that constant-like ops get placed with their consumer(s).
+// We want to replicate constants where they are consumed instead of performing
+// transfers at runtime to move them around; by placing them with their
+// consumers we know early on when replication is needed.
+
+// CHECK-LABEL: @constant_op
+util.func private @constant_op() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %cst_a, %cst_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that splats (not constant-like, but with no consumed values) are placed
+// with their consumer(s). Splats are always best rematerialized where they are
+// consumed to avoid allocating/transferring a bunch of repeated values.
+
+// CHECK-LABEL: @splat_op
+util.func private @splat_op() -> tensor<1xi32> {
+ %splat_value = arith.constant 123 : i32
+ // CHECK: flow.tensor.splat
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %splat = flow.tensor.splat %splat_value : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %splat_a = flow.tensor.transfer %splat : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %splat_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that imported tensor placement is inherited.
+// Frontends can use this to declare where they expect their arguments to
+// be living at the time the functions are invoked. Imports do not perform
+// transfers so we must use whatever is declared.
+
+// CHECK-LABEL: @imported_tensor
+util.func public @imported_tensor(%buffer_view: !hal.buffer_view, %fence: !hal.fence) -> tensor<1xi32> {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tensor = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%fence) => %buffer_view "input" : !hal.buffer_view -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %tensor : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops exported to buffers are properly placed.
+// Frontends can use this to explicitly define where exported tensors must live.
+// With consumer-placed ops like constants or splats we place them directly on
+// the export target.
+
+// CHECK-LABEL: @exported_constant
+util.func public @exported_constant(%fence: !hal.fence) -> !hal.buffer_view {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_ready = hal.tensor.barrier join(%cst : tensor<1xi32>) => %fence : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_a>) %cst_ready "output" : tensor<1xi32> -> !hal.buffer_view
+ util.return %buffer_view : !hal.buffer_view
+}
+
+// -----
+
+// Tests that producer-placed ops exported to buffers get the appropriate
+// affinity on both devices. Frontends can use this to explicitly define where
+// exported tensors must live. Transfers may need to be inserted in order to
+// respect the required affinities. Note here that the operand to the export
+// is on @dev_a instead of the requested @dev_b.
+
+// CHECK-LABEL: @exported_producer
+util.func public @exported_producer(%fence: !hal.fence) -> !hal.buffer_view {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.clone
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %clone_a = flow.tensor.clone %cst_a : tensor<1xi32>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %clone_ready_a = hal.tensor.barrier join(%clone_a : tensor<1xi32>) => %fence : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_b>) %clone_ready_a "output" : tensor<1xi32> -> !hal.buffer_view
+ // CHECK: util.return
+ util.return %buffer_view : !hal.buffer_view
+}
+
+// -----
+
+// Tests in-place aliased storage for results.
+// Frontends require that the storage be placed as indicated even if that means
+// introducing transfers such that the operation is not in-place.
+
+// CHECK-LABEL: @aliased_storage
+util.func public @aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg_a = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %ret_b = flow.dispatch @dispatch(%arg_a) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: hal.tensor.alias
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %alias_b = hal.tensor.alias on(#hal.device.promise<@dev_b>) %ret_b : tensor<4xi32> to %storage : !hal.buffer
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ hal.tensor.barrier join(%alias_b : tensor<4xi32>) => %fence : !hal.fence
+ util.return
+}
+
+// -----
+
+// Tests aliased storage through tied dispatches.
+
+// CHECK-LABEL: @tied_aliased_storage
+util.func public @tied_aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch0
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t0 = flow.dispatch @dispatch0(%cst) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.dispatch @dispatch1
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t1 = flow.dispatch @dispatch1(%t0) : (tensor<4xi32>) -> %t0
+ // CHECK: hal.tensor.alias
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %alias = hal.tensor.alias on(#hal.device.promise<@dev_b>) %t1 : tensor<4xi32> to %storage : !hal.buffer
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ hal.tensor.barrier join(%alias : tensor<4xi32>) => %fence : !hal.fence
+ util.return
+}
+
+// -----
+
+// Tests that consumer-placed ops that pass through tied ops get attributed to
+// a single consumer.
+
+// CHECK-LABEL: @tied_constant
+util.func private @tied_constant() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_a = flow.tensor.transfer %tied : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %tied_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops that pass through tied ops get attributed to
+// transitive consumers. This is not ideal but allows the application of
+// replication policies.
+
+// CHECK-LABEL: @tied_constant_multi_consumer
+util.func private @tied_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %tied_0 = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %tied_1 = flow.dispatch @b(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that properly transferring consumer-placed values prior to multiple
+// tied uses doesn't pollute the execution affinity of ops after the transfers.
+// Note that the constant will still have multiple affinities to allow for
+// policies that replicate the constant.
+
+// CHECK-LABEL: @tied_transfer_constant_multi_consumer
+util.func private @tied_transfer_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0 = flow.dispatch @a(%cst_a) : (tensor<1xi32>) -> %cst_a
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1 = flow.dispatch @b(%cst_b) : (tensor<1xi32>) -> %cst_b
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that implicitly placed consumers use their transfer execution affinity.
+
+// CHECK-LABEL: @transfer_execution_affinity
+util.func private @transfer_execution_affinity() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %dispatch_b = flow.dispatch @dispatch(%cst_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %dispatch_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that explicitly placed consumers use their explicit execution affinity.
+
+// CHECK-LABEL: @explicit_execution_affinity
+util.func private @explicit_execution_affinity() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %dispatch_b = flow.dispatch @dispatch(%cst_a) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %dispatch_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumers of operands with multiple affinities inherit those
+// affinities for execution. This allows policies to determine where they want
+// to execute out of the resources they may be consuming.
+
+// CHECK-LABEL: @consume_multi_affinities
+util.func private @consume_multi_affinities() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %dispatch_ab = flow.dispatch @dispatch(%cst_a, %cst_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %dispatch_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that globals are placed where they are loaded.
+
+// CHECK: util.global private @consumed_global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private @consumed_global_a : tensor<1xi32>
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load = util.global.load @consumed_global_a : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %load_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that a global loaded from two locations is attributed to both
+// affinities. This allows policies to decide whether to replicate the global.
+
+// CHECK: util.global private @consumed_global_ab
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+util.global private @consumed_global_ab : tensor<1xi32>
+util.func private @consumer_fn_a() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_ab
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %load = util.global.load @consumed_global_ab : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %load_a : tensor<1xi32>
+}
+util.func private @consumer_fn_b() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_ab
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %load = util.global.load @consumed_global_ab : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %load_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops track through global loads.
+
+// CHECK: util.global private mutable @global_b
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+util.global private mutable @global_b : tensor<1xi32>
+util.func private @producer_fn() {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.global.store %cst_a, @global_b : tensor<1xi32>
+ util.return
+}
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load = util.global.load @global_b : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %load_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that globals that are only stored take the fallback placement of
+// their producer. This is silly but can arise prior to the global optimization
+// passes that may elide such globals.
+
+// CHECK: util.global private mutable @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private mutable @global_a : tensor<1xi32>
+util.func private @producer_fn() {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.global.store %cst_a, @global_a : tensor<1xi32>
+ util.return
+}
+
+// -----
+
+// Tests that consumers that take on the affinity of their operands track the
+// affinities of the globals they load.
+
+// CHECK: util.global private @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private @global_a {stream.affinity = #hal.device.promise<@dev_a>} : tensor<1xi32>
+// CHECK: util.global private @global_b
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+util.global private @global_b {stream.affinity = #hal.device.promise<@dev_b>} : tensor<1xi32>
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load @global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = util.global.load @global_a : tensor<1xi32>
+ // CHECK: util.global.load @global_b
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = util.global.load @global_b : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %result_ab = flow.dispatch @dispatch(%load_a, %load_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %result_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests a global update tick that operates on the global from multiple
+// affinities.
+
+// CHECK: util.global private mutable @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private mutable @global_a {stream.affinity = #hal.device.promise<@dev_a>} = dense<123> : tensor<1xi32>
+util.func private @step(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.global.load @global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = util.global.load @global_a : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<2xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ %result_b:2 = flow.dispatch @dispatch(%load_a, %arg0_b) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<2xi32>)
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.global.store %result_b#0, @global_a : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %result_b#1 : tensor<2xi32>
+}
+
+// -----
+
+// Tests that constants passed through selects are placed on the consumer.
+
+// CHECK-LABEL: @select_constants_consumed
+util.func private @select_constants_consumed(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_123 = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_456 = flow.tensor.constant dense<456> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = arith.select %cond, %cst_123, %cst_456 : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed operands passed through selects are tracked on consumers.
+
+// CHECK-LABEL: @select_constants_placed
+util.func private @select_constants_placed(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst_ab = arith.select %cond, %cst_a, %cst_b : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %cst_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that a callee that does not touch an argument still tracks the
+// affinity through it.
+
+// CHECK-LABEL: @passthrough_caller
+util.func private @passthrough_caller() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.call @passthrough_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_a = util.call @passthrough_callee(%cst_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %result_a : tensor<1xi32>
+}
+// CHECK: util.func private @passthrough_callee
+util.func private @passthrough_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %arg0 : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed arguments passed to callees get placed based on
+// callee usage.
+
+// CHECK-LABEL: @consumer_placement_caller
+util.func private @consumer_placement_caller() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.call @consumer_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_a = util.call @consumer_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %result_a : tensor<1xi32>
+}
+// CHECK: util.func private @consumer_placement_callee
+util.func private @consumer_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %arg0_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that multiple potential affinities are propagated across call edges.
+
+// CHECK-LABEL: @select_caller
+util.func private @select_caller(%arg0: tensor<1xi32>, %cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @select_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %result_ab = util.call @select_callee(%arg0_a, %cond) : (tensor<1xi32>, i1) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %result_ab : tensor<1xi32>
+}
+// CHECK: util.func private @select_callee
+util.func private @select_callee(%arg0_a: tensor<1xi32>, %cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<123> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %select_ab = arith.select %cond, %arg0_a, %cst_b : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %select_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops are propagated across call edges.
+
+// CHECK-LABEL: @consumer_multi_placement_caller
+util.func private @consumer_multi_placement_caller() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.call @consumer_multi_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %result_0_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_0_a = flow.tensor.transfer %result_0_c : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @consumer_multi_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %result_1_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %result_1_b = flow.tensor.transfer %result_1_c : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %result_0_a, %result_1_b : tensor<1xi32>, tensor<1xi32>
+}
+// CHECK: util.func private @consumer_multi_placement_callee
+util.func private @consumer_multi_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %arg0_c = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %arg0_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that operand/result affinities are tracked across call edges.
+
+// CHECK-LABEL: @dispatch_fn_a
+util.func private @dispatch_fn_a() -> tensor<4xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a_0
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %2 = flow.dispatch @dispatch_a_0(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.call @dispatch_fn_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %3 = util.call @dispatch_fn_b(%2) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %4 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a_1
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %5 = flow.dispatch @dispatch_a_1(%4) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %5 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_b
+util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %1 : tensor<4xi32>
+}
+
+// -----
+
+// Tests a realistic call graph with explicit transfers.
+
+// CHECK-LABEL: @dispatch_fn_a
+util.func private @dispatch_fn_a() -> tensor<4xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @dispatch_fn_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %2 = util.call @dispatch_fn_b(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.call @dispatch_fn_c
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %3 = util.call @dispatch_fn_c(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %4 = flow.tensor.transfer %2 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %5 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %6 = flow.dispatch @dispatch_a(%4, %5) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %6 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_b
+util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %1 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_c
+util.func private @dispatch_fn_c(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_c>
+ // CHECK: flow.dispatch @dispatch_c
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %1 = flow.dispatch @dispatch_c(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %1 : tensor<4xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops are tracked across branch edges.
+
+// CHECK-LABEL: @cfg_branch_constant_consumed
+util.func private @cfg_branch_constant_consumed() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that producer-placed ops are tracked across branch edges.
+
+// CHECK-LABEL: @cfg_branch_dispatch_produced
+util.func private @cfg_branch_dispatch_produced() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst_a : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %dispatch_a = flow.dispatch @dispatch(%bb1_arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %dispatch_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that back edges on loops track affinity changes.
+
+// CHECK-LABEL: @cfg_loop_back_edge
+util.func private @cfg_loop_back_edge() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst_a : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %bb1_arg0_b = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.call @step
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ %cond = util.call @step(%bb1_arg0_b) : (tensor<1xi32>) -> i1
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%bb1_arg0 : tensor<1xi32>), ^bb2(%bb1_arg0_b : tensor<1xi32>)
+^bb2(%bb2_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %bb2_arg0_c = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %bb2_arg0_c : tensor<1xi32>
+}
+util.func private @step(tensor<1xi32>) -> i1
+
+// -----
+
+// Tests that conditional branches acting as selects propagate both affinities.
+
+// CHECK-LABEL: @cfg_cond_branch_select
+util.func private @cfg_cond_branch_select(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%cst_a : tensor<1xi32>), ^bb1(%cst_b : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %bb1_arg0 : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops consumed through conditional branches acting
+// as selects get placed on all targets.
+
+// CHECK-LABEL: @cfg_cond_branch_select_consumer
+util.func private @cfg_cond_branch_select_consumer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%cst : tensor<1xi32>), ^bb2(%cst : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+^bb2(%bb2_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %cst_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that an scf.if capturing consumer-placed ops tracks their affinity into
+// the nested regions.
+
+// CHECK-LABEL: @scf_if_capture_consumer
+util.func private @scf_if_capture_consumer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.if
+ %cst_ab = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_a : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %cst_b : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %cst_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that an scf.if capturing explicitly placed ops tracks the affinity of
+// their results into consumers.
+
+// CHECK-LABEL: @scf_if_capture_producer
+util.func private @scf_if_capture_producer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.if
+ %cst_bc = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %cst_b : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %cst_c = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ scf.yield %cst_c : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]]
+ util.return %cst_bc : tensor<1xi32>
+}
+
+// -----
+
+// Tests that an scf.if returning unassigned consumer-placed ops has their
+// affinity tracked across scf.yield ops and assigned based on the consumer.
+
+// CHECK-LABEL: @scf_if_consumer_yield
+util.func private @scf_if_consumer_yield(%cond: i1) -> tensor<1xi32> {
+ // CHECK: scf.if
+ %cst = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_0 = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_0 : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_1 = flow.tensor.constant dense<456> : tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_1 : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on their use in the scf.for body.
+
+// CHECK-LABEL: @scf_for_consumer_body_transfer
+util.func private @scf_for_consumer_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %for : tensor<1xi32>
+}
+
+// -----
+
+// Tests that affinities are tracked across scf.for ops with transfers/explicit
+// affinities on the loop edges.
+
+// CHECK-LABEL: @scf_for_boundary_transfer
+util.func private @scf_for_boundary_transfer() -> (tensor<1xi32>, tensor<1xi32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for:2 = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst, %arg1 = %cst) -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ scf.yield %t, %arg1 : tensor<1xi32>, tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %for_0_b = flow.tensor.transfer %for#0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %for_1_b = flow.tensor.transfer %for#1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ util.return %for_0_b, %for_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that transfers track through iter_args.
+
+// CHECK-LABEL: @scf_for_body_transfer
+util.func private @scf_for_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %for_c = flow.tensor.transfer %for : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %for_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through iter_args to consumers in scf.for
+// bodies.
+
+// CHECK-LABEL: @scf_for_capture_producer
+util.func private @scf_for_capture_producer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> {
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %for : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on their use in the scf.while body.
+
+// CHECK-LABEL: @scf_while_consumer_body_transfer
+util.func private @scf_while_consumer_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on the consumers of the
+// scf.while result.
+
+// CHECK-LABEL: @scf_while_consumer_result_transfer
+util.func private @scf_while_consumer_result_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %while_a = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that transfers track through scf.while bodies.
+
+// CHECK-LABEL: @scf_while_body_transfer
+util.func private @scf_while_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %while_c = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %while_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through to consumers in scf.while conditions.
+
+// CHECK-LABEL: @scf_while_capture_producer_condition
+util.func private @scf_while_capture_producer_condition() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0_a[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through to consumers in scf.while bodies.
+
+// CHECK-LABEL: @scf_while_capture_producer_body
+util.func private @scf_while_capture_producer_body() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests a realistic program with ABI ops.
+
+// CHECK-LABEL: @simple_program
+util.func public @simple_program(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg1) => %arg0 "input0" : !hal.buffer_view -> tensor<1xi32>
+ // CHECK: util.call @_simple_program
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = util.call @_simple_program(%0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %2 = flow.tensor.transfer %1 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %3 = hal.tensor.barrier join(%2 : tensor<1xi32>) => %arg2 : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %4 = hal.tensor.export on(#hal.device.promise<@dev_a>) %3 "output0" : tensor<1xi32> -> !hal.buffer_view
+ util.return %4 : !hal.buffer_view
+}
+// CHECK: util.func private @_simple_program
+util.func private @_simple_program(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: util.call @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = util.call @dispatch_a(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.tensor.transfer %0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.call @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %2 = util.call @dispatch_b(%1) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %2 : tensor<1xi32>
+}
+// CHECK: util.func private @dispatch_a
+util.func private @dispatch_a(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<[1]> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.dispatch @dispatch_a(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %0 : tensor<1xi32>
+}
+// CHECK: util.func private @dispatch_b
+util.func private @dispatch_b(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<[2]> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.dispatch @dispatch_b(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %0 : tensor<1xi32>
+}
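In the `scf.while` expectations above, the loop results are annotated with both `@dev_a` and `@dev_b`. That is what you get if a value's possible placements are treated as a set that grows by union until it stops changing. The sketch below models only that idea. It assumes the body's transfer function is monotone. `AffinitySet`, `join`, and `loopCarriedAffinities` are hypothetical names and are not IREE's affinity analysis API.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical model: the set of devices a value may live on.
using AffinitySet = std::set<std::string>;

// Lattice join: the union of two affinity sets.
AffinitySet join(const AffinitySet &a, const AffinitySet &b) {
  AffinitySet out = a;
  out.insert(b.begin(), b.end());
  return out;
}

// Affinities of a loop-carried value. `init` is the affinity of the initial
// operand. `bodyResult` maps the affinities of the body argument to the
// affinities of the value yielded back to the condition region. With a
// monotone `bodyResult` the set only grows, so the loop terminates.
template <typename BodyFn>
AffinitySet loopCarriedAffinities(const AffinitySet &init, BodyFn bodyResult) {
  AffinitySet carried = init;
  while (true) {
    AffinitySet next = join(init, bodyResult(carried));
    if (next == carried) return carried;  // fixed point reached
    carried = next;
  }
}
```

If the body transfers to `@dev_b`, the carried value may be on either device. If the body stays put, as in the capture tests, the carried value keeps only the producer's affinity.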
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
index 14e8fb2..ed1f338 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
@@ -16,8 +16,8 @@
stream.executable private @rebaseBindingsEx {
stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig}
builtin.module {
- // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding,
- // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index)
+ // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding,
+ // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index)
util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %operand: index) {
%c0 = arith.constant 0 : index
%c20 = arith.constant 20 : index
@@ -39,7 +39,7 @@
}
}
}
-// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index)
+// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index)
util.func public @rebaseBindings(%operand: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -97,8 +97,8 @@
stream.executable private @deduplicateBindingsEx {
stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig}
builtin.module {
- // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding,
- // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index)
+ // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding,
+ // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index)
util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %binding_c: !stream.binding, %operand: index) {
%c0 = arith.constant 0 : index
%c20 = arith.constant 20 : index
@@ -127,7 +127,7 @@
}
}
}
-// CHECK: util.func public @deduplicateBindings(%[[OPERAND:.+]]: index)
+// CHECK: util.func public @deduplicateBindings(%[[OPERAND:.+]]: index)
util.func public @deduplicateBindings(%operand: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
index a3b4ef6..a1509e3 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
@@ -110,13 +110,15 @@
// TODO(#11249): support in-place collectives - when supported this will become
// a negative test as we'd expect %send_recv to be used for both operands.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tiedCollectivesTODO
// CHECK-SAME: (%[[CHANNEL:.+]]: !stream.channel, %[[SEND_RECV:.+]]: !stream.resource<*>, %[[SEND_SIZE:.+]]: index, %[[RECV_SIZE:.+]]: index, %[[COUNT:.+]]: index)
util.func private @tiedCollectivesTODO(%channel: !stream.channel, %send_recv: !stream.resource<*>, %send_size: index, %recv_size: index, %count: index) -> !stream.resource<*> {
%c0 = arith.constant 0 : index
- // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.affinity.queue<[0]>) %[[SEND_RECV]]
+ // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.device.affinity<@device>) %[[SEND_RECV]]
// CHECK: %[[ALL_GATHER:.+]] = stream.async.collective<all_gather : f32>[%[[COUNT]]]
- %0 = stream.async.collective<all_gather : f32>[%count] on(#hal.affinity.queue<[0]>) channel(%channel)
+ %0 = stream.async.collective<all_gather : f32>[%count] on(#hal.device.affinity<@device>) channel(%channel)
// CHECK-SAME: %[[SEND_RECV]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]],
%send_recv[%c0 to %send_size for %send_size],
// CHECK-SAME: %[[RECV_CLONE]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] :
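In the collective test above, the in-place receive operand gets a `stream.async.clone` of `%send_recv` because the same resource is also read as the send operand. The sketch below shows a generic version of that copy-on-write rule. It is not the pass itself. It only inspects the operands of a single op, while a fuller analysis would also consider other users of the value. `Operand` and `materializeCopyOnWrite` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One operand of a hypothetical op: the SSA value id it reads and whether the
// op writes that operand in place (a result is tied to it).
struct Operand {
  int value;
  bool tiedInPlace;
};

// Copy-on-write sketch. If a tied (in-place) operand's value is also read
// through another operand of the same op, that operand is redirected to a
// fresh clone so the in-place write cannot clobber the other read.
// `nextValueId` hands out ids for the inserted clones. Returns the clone count.
int materializeCopyOnWrite(std::vector<Operand> &operands, int &nextValueId) {
  int cloneCount = 0;
  for (std::size_t i = 0; i < operands.size(); ++i) {
    if (!operands[i].tiedInPlace) continue;
    bool readElsewhere = false;
    for (std::size_t j = 0; j < operands.size(); ++j) {
      if (j != i && operands[j].value == operands[i].value) readElsewhere = true;
    }
    if (readElsewhere) {
      operands[i].value = nextValueId++;  // models inserting a clone op
      ++cloneCount;
    }
  }
  return cloneCount;
}
```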
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
index 00f5c32..8266ca5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
@@ -223,27 +223,29 @@
// execution region. We expect them to be placed into packed slices and
// allocated with the async stream-ordered alloca/dealloca ops.
+util.global private @device : !hal.device
+
// CHECK-LABEL: @locals
// CHECK-SAME: (%[[SIZE0:.+]]: index, %[[SIZE1:.+]]: index, %[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint)
util.func public @locals(%size0: index, %size1: index, %await_timepoint: !stream.timepoint) -> !stream.timepoint {
%c254_i32 = arith.constant 254 : i32
%c255_i32 = arith.constant 255 : i32
- // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.affinity.queue<[0]>) slices({
+ // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.device.affinity<@device>) slices({
// CHECK-NEXT: [0, 0] = %[[SIZE0]],
// CHECK-NEXT: [1, 1] = %[[SIZE1]]
// CHECK-NEXT: })
- // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.affinity.queue<[0]>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint
+ // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint
// CHECK-NEXT: %[[AWAIT_JOIN:.+]] = stream.timepoint.join max(%[[AWAIT_TIMEPOINT]], %[[ALLOCA_TIMEPOINT]])
- // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.affinity.queue<[0]>) await(%[[AWAIT_JOIN]])
+ // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.device.affinity<@device>) await(%[[AWAIT_JOIN]])
// CHECK-SAME: with(%[[ALLOCA]] as %[[CAPTURE:.+]]: !stream.resource<transient>{%[[SLICES]]#0})
- %result_timepoint = stream.async.execute on(#hal.affinity.queue<[0]>) await(%await_timepoint) => with() {
+ %result_timepoint = stream.async.execute on(#hal.device.affinity<@device>) await(%await_timepoint) => with() {
// CHECK: stream.cmd.fill %c254_i32, %[[CAPTURE]][%[[SLICES]]#1 for %[[SIZE0]]] : i32 -> !stream.resource<transient>{%[[SLICES]]#0}
%0 = stream.async.splat %c254_i32 : i32 -> !stream.resource<transient>{%size0}
// CHECK: stream.cmd.fill %c255_i32, %[[CAPTURE]][%[[SLICES]]#2 for %[[SIZE1]]] : i32 -> !stream.resource<transient>{%[[SLICES]]#0}
%1 = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%size1}
stream.yield
} => !stream.timepoint
- // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.affinity.queue<[0]>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint
+ // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint
// CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[DEALLOCA_TIMEPOINT]], %[[EXEC_TIMEPOINT]]) => !stream.timepoint
// CHECK: util.return %[[JOIN]]
util.return %result_timepoint : !stream.timepoint
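The `@locals` test packs two transients into one `stream.resource.alloca`. Their slices are given the lifetime intervals `[0, 0]` and `[1, 1]`. The sketch below shows one way such packing can work: a first-fit placement that lets slices with non-overlapping lifetimes share storage. This is an assumption for illustration only and is not necessarily what `stream.resource.pack` emits. `Slice` and `packSlices` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Slice {
  int64_t lifetimeStart;  // first step at which the slice is live (inclusive)
  int64_t lifetimeEnd;    // last step at which the slice is live (inclusive)
  int64_t size;           // size in bytes
  int64_t offset;         // assigned by packSlices
};

static int64_t alignTo(int64_t value, int64_t alignment) {
  return (value + alignment - 1) / alignment * alignment;
}

// First-fit packing. Each slice goes at the lowest aligned offset that does not
// overlap, in memory, any already placed slice whose lifetime overlaps its own.
// Returns the aligned total allocation size.
int64_t packSlices(std::vector<Slice> &slices, int64_t alignment) {
  int64_t totalSize = 0;
  for (std::size_t i = 0; i < slices.size(); ++i) {
    int64_t offset = 0;
    bool moved = true;
    while (moved) {  // offset only increases, so this terminates
      moved = false;
      for (std::size_t j = 0; j < i; ++j) {
        const Slice &other = slices[j];
        bool livesOverlap = slices[i].lifetimeStart <= other.lifetimeEnd &&
                            other.lifetimeStart <= slices[i].lifetimeEnd;
        bool memOverlaps = offset < other.offset + other.size &&
                           other.offset < offset + slices[i].size;
        if (livesOverlap && memOverlaps) {
          offset = alignTo(other.offset + other.size, alignment);
          moved = true;
        }
      }
    }
    slices[i].offset = offset;
    totalSize = std::max(totalSize, offset + slices[i].size);
  }
  return alignTo(totalSize, alignment);
}
```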
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
index 3ccd781..dcdc586 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
@@ -38,6 +38,9 @@
// Dispatches with the same affinities should be placed into the same execution
// regions.
+util.global private @device_a : !hal.device
+util.global private @device_b : !hal.device
+
// CHECK-LABEL: @partitioningWithAffinities
// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>)
util.func public @partitioningWithAffinities(%arg0: !stream.resource<external>) -> !stream.resource<external> {
@@ -48,31 +51,30 @@
%c255_i32 = arith.constant 255 : i32
// CHECK: %[[TRANSIENTS:.+]]:2, %[[TIMEPOINT0:.+]] = stream.async.execute
- // CHECK-SAME: on(#hal.affinity.queue<[0]>)
+ // CHECK-SAME: on(#hal.device.affinity<@device_a>)
// CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource<external>{%c20})
// CHECK-SAME: -> (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) {
// CHECK-NEXT: %[[SPLAT:.+]] = stream.async.splat
%splat = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c1280}
// CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}])
- %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280}
+ %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280}
// CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}])
- %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20}
+ %dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20}
// CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[DISPATCH1]]
// CHECK-NEXT: } => !stream.timepoint
// CHECK: %[[RESULT:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute
- // CHECK-SAME: on(#hal.affinity.queue<[1]>)
+ // CHECK-SAME: on(#hal.device.affinity<@device_b>)
// CHECK-SAME: await(%[[TIMEPOINT0]])
// CHECK-SAME: with(%[[TRANSIENTS]]#0 as %[[TRANSIENT0_CAPTURE:.+]]: !stream.resource<transient>{%c1280},
// CHECK-SAME: %[[TRANSIENTS]]#1 as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource<transient>{%c20})
// CHECK-SAME: -> !stream.resource<external>{%c20}
// CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}])
- %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}
+ %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}
// CHECK-NEXT: stream.yield %[[DISPATCH2]]
// CHECK-NEXT: } => !stream.timepoint
// CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await
- // CHECK-SAME: on(#hal.affinity.queue<[1]>)
// CHECK-SAME: %[[TIMEPOINT1]] => %[[RESULT]] : !stream.resource<external>{%c20}
// CHECK-NEXT: util.return %[[READY]]
util.return %dispatch2 : !stream.resource<external>
@@ -84,6 +86,10 @@
// dependencies. Unrelated dispatches with differing affinities should end up
// in concurrently executable regions.
+util.global private @device_a : !hal.device
+util.global private @device_b : !hal.device
+util.global private @device_c : !hal.device
+
// CHECK-LABEL: @partitioningWithConcurrentAffinities
// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>)
util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource<external>) -> !stream.resource<external> {
@@ -94,23 +100,23 @@
%c255_i32 = arith.constant 255 : i32
// CHECK: %[[TRANSIENT0:.+]], %[[TIMEPOINT0:.+]] = stream.async.execute
- // CHECK-SAME: on(#hal.affinity.queue<[0]>)
+ // CHECK-SAME: on(#hal.device.affinity<@device_a>)
// CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE0:.+]]: !stream.resource<external>{%c20})
// CHECK-SAME: !stream.resource<transient>{%c1280}
// CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat
%splat = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c1280}
// CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE0]][{{.+}}], %[[SPLAT0]][{{.+}}])
- %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280}
+ %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280}
// CHECK-NEXT: stream.yield %[[DISPATCH0]]
// CHECK-NEXT: } => !stream.timepoint
// CHECK: %[[TRANSIENT1:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute
- // CHECK-SAME: on(#hal.affinity.queue<[1]>)
+ // CHECK-SAME: on(#hal.device.affinity<@device_b>)
// CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE1:.+]]: !stream.resource<external>{%c20})
// CHECK-SAME: -> !stream.resource<transient>{%c20} {
// CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat
// CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE1]][{{.+}}], %[[SPLAT1]][{{.+}}])
- %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20}
+ %dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20}
// CHECK-NEXT: stream.yield %[[DISPATCH1]]
// CHECK-NEXT: } => !stream.timepoint
@@ -121,12 +127,11 @@
// CHECK-SAME: with(%[[TRANSIENT0]] as %[[TRANSIENT0_CAPTURE:.+]]: !stream.resource<transient>{%c1280},
// CHECK-SAME: %[[TRANSIENT1]] as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource<transient>{%c20})
// CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}])
- %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[2]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}
+ %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_c>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}
// CHECK-NEXT: stream.yield %[[DISPATCH2]]
// CHECK-NEXT: } => !stream.timepoint
// CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await
- // CHECK-SAME: on(#hal.affinity.queue<[2]>)
// CHECK-SAME: %[[TIMEPOINT2]] => %[[RESULT]] : !stream.resource<external>{%c20}
// CHECK-NEXT: util.return %[[READY]]
util.return %dispatch2 : !stream.resource<external>
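The two partitioning tests above expect dispatches with the same affinity to share an execution region, and dispatches on differing affinities to get their own regions. The sketch below models that as a greedy pass over ops in program order. Regions stay in creation order and may depend only on earlier regions. It does not model the splat that IREE duplicates into the second region of `@partitioningWithConcurrentAffinities` (the `SPLAT0`/`SPLAT1` captures). `Op` and `partitionByAffinity` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct Op {
  std::string affinity;   // device the op is placed on
  std::vector<int> deps;  // indices of earlier ops whose results it uses
};

// Greedy partitioning. An op joins the newest region with the same affinity
// if that region is not older than any region holding one of its
// dependencies. Otherwise it opens a new region, which keeps region
// dependencies acyclic. Returns the region index of each op.
std::vector<int> partitionByAffinity(const std::vector<Op> &ops) {
  std::vector<int> regionOf(ops.size(), -1);
  std::vector<std::string> regionAffinity;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    int minRegion = 0;  // must land at or after every dependency's region
    for (int dep : ops[i].deps) minRegion = std::max(minRegion, regionOf[dep]);
    int target = -1;
    for (int r = static_cast<int>(regionAffinity.size()) - 1; r >= 0; --r) {
      if (regionAffinity[r] == ops[i].affinity) {
        target = r;
        break;
      }
    }
    if (target < minRegion) {  // no compatible region: open a new one
      target = static_cast<int>(regionAffinity.size());
      regionAffinity.push_back(ops[i].affinity);
    }
    regionOf[i] = target;
  }
  return regionOf;
}
```

For the first test (splat, `dispatch_0`, and `dispatch_1` on `@device_a`, then `dispatch_2` on `@device_b`), this yields two regions, which matches the two `stream.async.execute` ops in the CHECK lines.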
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir
new file mode 100644
index 0000000..ee29810
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-opt --iree-stream-verify-affinities --split-input-file %s --verify-diagnostics | FileCheck %s
+
+// Tests that affinities on ops are checked.
+
+// CHECK-LABEL: @affinityOnOp
+util.func public @affinityOnOp(%size: index) {
+ // CHECK: stream.async.alloca
+ %0 = stream.async.alloca on(#hal.device.promise<@device>) : !stream.resource<transient>{%size}
+ util.return
+}
+
+// -----
+
+// Tests that affinities on ancestor ops are allowed.
+
+// CHECK-LABEL: @affinityOnAncestorOp
+util.func public @affinityOnAncestorOp(%size: index) attributes {
+ stream.affinity = #hal.device.promise<@device>
+} {
+ // CHECK: stream.async.alloca
+ %0 = stream.async.alloca : !stream.resource<transient>{%size}
+ util.return
+}
+
+// -----
+
+// Tests that ops with no affinities fail.
+
+util.func public @missingAffinity(%size: index) {
+ // expected-error @+1 {{does not have an affinity assigned}}
+ %0 = stream.async.alloca : !stream.resource<transient>{%size}
+ util.return
+}
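The three `verify_affinities.mlir` cases accept an affinity set on the op itself or on an enclosing op, such as the `util.func`. They reject an op with neither. The sketch below shows that ancestor lookup on a hypothetical `Op` struct that stands in for MLIR operations. It illustrates the rule the tests check and is not the pass implementation.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an IR operation: an optional `stream.affinity`
// value plus a pointer to the enclosing (parent) op.
struct Op {
  std::optional<std::string> affinity;
  const Op *parent;
};

// Returns the affinity on `op` itself or, failing that, on the nearest ancestor
// that has one. std::nullopt corresponds to the
// "does not have an affinity assigned" diagnostic.
std::optional<std::string> lookupAffinity(const Op *op) {
  for (; op != nullptr; op = op->parent) {
    if (op->affinity) return op->affinity;
  }
  return std::nullopt;
}
```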
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
index c25458c..bb46c83 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
@@ -320,7 +320,8 @@
return WalkResult::advance();
}
-TraversalResult Explorer::walkValues(ValueWalkFn fn) {
+TraversalResult Explorer::walkAllValues(ValueWalkFn fn,
+ std::optional<TypeID> typeID) {
LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkValues ]]\n");
TraversalResult result = TraversalResult::COMPLETE;
@@ -357,7 +358,8 @@
LLVM_DEBUG(llvm::dbgs() << " + entering callable region @"
<< getRegionName(callableRegion) << "\n");
- auto emitResult = recursiveWalkValues(callableOp, visitedValues, fn);
+ auto emitResult =
+ recursiveWalkValues(callableOp, visitedValues, fn, typeID);
if (emitResult.wasInterrupted())
break;
if (emitResult.wasSkipped())
@@ -384,7 +386,8 @@
WalkResult Explorer::recursiveWalkValues(Operation *parentOp,
DenseSet<Value> &visitedValues,
- const ValueWalkFn &fn) {
+ const ValueWalkFn &fn,
+ std::optional<TypeID> typeID) {
auto parentAction = getTraversalAction(parentOp);
if (parentAction == TraversalAction::IGNORE) {
LLVM_DEBUG(llvm::dbgs()
@@ -396,6 +399,8 @@
LLVM_DEBUG(llvm::dbgs()
<< " + processing op results " << getOpName(parentOp) << "\n");
for (auto result : parentOp->getResults()) {
+ if (typeID.has_value() && result.getType().getTypeID() != *typeID)
+ continue;
if (visitedValues.insert(result).second) {
LLVM_DEBUG({
llvm::dbgs() << " == emitting value ";
@@ -425,6 +430,8 @@
llvm::dbgs() << " arguments\n";
});
for (auto arg : block.getArguments()) {
+ if (typeID.has_value() && arg.getType().getTypeID() != *typeID)
+ continue;
if (visitedValues.insert(arg).second) {
LLVM_DEBUG({
llvm::dbgs() << " == emitting block arg ";
@@ -437,7 +444,7 @@
}
}
for (auto &op : block) {
- auto opResult = recursiveWalkValues(&op, visitedValues, fn);
+ auto opResult = recursiveWalkValues(&op, visitedValues, fn, typeID);
if (opResult.wasInterrupted())
return WalkResult::interrupt();
}
@@ -672,7 +679,8 @@
// traversal algorithm separated from the policy here. This would let us
// reuse the traversal for other kinds of walks that are more specific (like
// only getting the ops or values instead of both, etc).
-TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) {
+TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn,
+ TraversalBehavior options) {
// Fast-path short-circuit for constants, which are like 25% of all IR.
if (value.getDefiningOp() &&
value.getDefiningOp()->hasTrait<OpTrait::ConstantLike>()) {
@@ -849,15 +857,17 @@
// If the op is tied we may need to walk up to the operand the result is
// tied to.
- if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
- auto tiedOperand = tiedOp.getTiedResultOperand(resultValue);
- if (tiedOperand) {
- LLVM_DEBUG({
- llvm::dbgs() << " + queuing tied operand ";
- tiedOperand.printAsOperand(llvm::dbgs(), asmState);
- llvm::dbgs() << "\n";
- });
- worklist.insert(tiedOperand);
+ if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) {
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
+ auto tiedOperand = tiedOp.getTiedResultOperand(resultValue);
+ if (tiedOperand) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " + queuing tied operand ";
+ tiedOperand.printAsOperand(llvm::dbgs(), asmState);
+ llvm::dbgs() << "\n";
+ });
+ worklist.insert(tiedOperand);
+ }
}
}
@@ -884,7 +894,8 @@
return result;
}
-TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn) {
+TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
+ TraversalBehavior options) {
LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkTransitiveUses ]]\n");
TraversalResult result = TraversalResult::COMPLETE;
@@ -1083,15 +1094,17 @@
// If the op is tied we may need to walk down to the results the operand
// is tied to (multiple results can tie the same operand).
- if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(ownerOp)) {
- for (auto tiedResult :
- tiedOp.getOperandTiedResults(use.getOperandNumber())) {
- LLVM_DEBUG({
- llvm::dbgs() << " + queuing tied result ";
- tiedResult.printAsOperand(llvm::dbgs(), asmState);
- llvm::dbgs() << "\n";
- });
- worklist.insert(tiedResult);
+ if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) {
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(ownerOp)) {
+ for (auto tiedResult :
+ tiedOp.getOperandTiedResults(use.getOperandNumber())) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " + queuing tied result ";
+ tiedResult.printAsOperand(llvm::dbgs(), asmState);
+ llvm::dbgs() << "\n";
+ });
+ worklist.insert(tiedResult);
+ }
}
}
@@ -1142,14 +1155,18 @@
return result;
}
-TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn) {
+TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn,
+ TraversalBehavior options) {
DenseSet<Operation *> visitedOwners;
- return walkTransitiveUses(value, [&](OpOperand &use) {
- if (visitedOwners.insert(use.getOwner()).second) {
- return fn(use.getOwner());
- }
- return WalkResult::advance();
- });
+ return walkTransitiveUses(
+ value,
+ [&](OpOperand &use) {
+ if (visitedOwners.insert(use.getOwner()).second) {
+ return fn(use.getOwner());
+ }
+ return WalkResult::advance();
+ },
+ options);
}
} // namespace mlir::iree_compiler
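The `TraversalBehavior options` parameter threaded through `walkDefiningOps`, `walkTransitiveUses`, and `walkTransitiveUsers` above is a bit-enum. When it contains `DONT_WALK_TIED_VALUES`, the walk stops at tied operands and results instead of following them. The self-contained sketch below shows that effect on a chain of tied values. The enum mirrors the shape of the one in `Explorer.h`. `walkTiedChain` and its `tiedOperandOf` table are hypothetical simplifications, not the `Explorer` API.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Illustrative copy of the bit-enum pattern used in Explorer.h.
enum class TraversalBehavior : uint32_t {
  DEFAULT = 0u,
  DONT_WALK_TIED_VALUES = 1u << 0,
};
inline TraversalBehavior operator|(TraversalBehavior lhs,
                                   TraversalBehavior rhs) {
  return static_cast<TraversalBehavior>(static_cast<uint32_t>(lhs) |
                                        static_cast<uint32_t>(rhs));
}
inline bool bitEnumContains(TraversalBehavior bits, TraversalBehavior bit) {
  return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
}

// Hypothetical model: `tiedOperandOf[v]` is the value that v is tied to, or -1
// if v is not tied. Starting from `start`, walks up through tied operands,
// unless the options ask to stop, and returns every value visited.
std::set<int> walkTiedChain(const std::vector<int> &tiedOperandOf, int start,
                            TraversalBehavior options) {
  std::set<int> visited;
  std::vector<int> worklist = {start};
  while (!worklist.empty()) {
    int value = worklist.back();
    worklist.pop_back();
    if (!visited.insert(value).second) continue;
    if (bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES))
      continue;  // do not move through the tie
    int tied = tiedOperandOf[value];
    if (tied >= 0) worklist.push_back(tied);
  }
  return visited;
}
```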
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
index 000aa2f..35ee12a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
@@ -37,6 +37,31 @@
IGNORE,
};
+enum class TraversalBehavior : uint32_t {
+ // When traversing defining ops any tied result will move through its tied
+ // operand. When traversing uses any tied operand will move through its tied
+ // results (as many as are tied to the operand).
+ DEFAULT = 0u,
+ // Don't traverse through tied operands or results.
+ DONT_WALK_TIED_VALUES = 1 << 0u,
+};
+inline TraversalBehavior operator~(TraversalBehavior value) {
+ return static_cast<TraversalBehavior>(~static_cast<uint32_t>(value));
+}
+inline TraversalBehavior operator|(TraversalBehavior lhs,
+ TraversalBehavior rhs) {
+ return static_cast<TraversalBehavior>(static_cast<uint32_t>(lhs) |
+ static_cast<uint32_t>(rhs));
+}
+inline TraversalBehavior operator&(TraversalBehavior lhs,
+ TraversalBehavior rhs) {
+ return static_cast<TraversalBehavior>(static_cast<uint32_t>(lhs) &
+ static_cast<uint32_t>(rhs));
+}
+inline bool bitEnumContains(TraversalBehavior bits, TraversalBehavior bit) {
+ return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+}
+
// Boolean operations on TraversalResult behave as though `INCOMPLETE` is
// truthy to allow for |='ing results.
enum class TraversalResult {
@@ -229,7 +254,15 @@
TraversalResult walk(OperationWalkFn fn);
// Walks all unique SSA values nested within the root op.
- TraversalResult walkValues(ValueWalkFn fn);
+ TraversalResult walkValues(ValueWalkFn fn) {
+ return walkAllValues(fn, std::nullopt);
+ }
+ // Walks all unique SSA values nested within the root op that have the given
+ // type.
+ template <typename OpT>
+ TraversalResult walkValuesOfType(ValueWalkFn fn) {
+ return walkAllValues(fn, OpT::getTypeID());
+ }
// Walks all unique SSA values used/defined by |op| and all nested regions.
TraversalResult walkValues(Operation *op, ValueWalkFn fn);
@@ -305,7 +338,9 @@
// Walk %2: [%2 of producer.b]
// Walk @some_user::%arg0: [%0 of producer.a]
// Walk @some_user::ret0: [%2 of producer.b]
- TraversalResult walkDefiningOps(Value value, ResultWalkFn fn);
+ TraversalResult
+ walkDefiningOps(Value value, ResultWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
// Randomly walks uses of |value| and any transitive alias of |value|.
// The uses may come from any part of the program.
@@ -326,13 +361,17 @@
// Walk %arg0: [%arg0 of producer.a]
// Walk %0: [%0 of call @some_user, %arg0 of producer.b]
// Walk %2: [%2 of return, %1 of return]
- TraversalResult walkTransitiveUses(Value value, UseWalkFn fn);
+ TraversalResult
+ walkTransitiveUses(Value value, UseWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
// Randomly walks uses of |value| and any transitive alias of |value| and
// returns each owner operation once. As a value may be used multiple times
// by a single operation this is equivalent to a walkTransitiveUses with
// deduplication on the owner of the use.
- TraversalResult walkTransitiveUsers(Value value, OperationWalkFn fn);
+ TraversalResult
+ walkTransitiveUsers(Value value, OperationWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
private:
// Maps callee callable region -> call sites.
@@ -341,10 +380,13 @@
void initializeGlobalInfos();
void initializeInverseCallGraph();
+ TraversalResult walkAllValues(ValueWalkFn fn, std::optional<TypeID> typeID);
+
WalkResult recursiveWalk(Operation *parentOp, const OperationWalkFn &fn);
WalkResult recursiveWalkValues(Operation *parentOp,
DenseSet<Value> &visitedValues,
- const ValueWalkFn &fn);
+ const ValueWalkFn &fn,
+ std::optional<TypeID> typeID = std::nullopt);
Operation *rootOp = nullptr;
AsmState asmState;
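The new `TraversalBehavior` enum is a scoped bitmask. The free operators make the options composable, and `bitEnumContains` tests whether a given flag is set. Below is a minimal standalone sketch of the same pattern. `WalkFlags`, `DONT_CROSS_CALLS`, `containsFlag`, and `tiedResultsToQueue` are hypothetical names used only for illustration and are not part of the Explorer API.

```cpp
#include <cstdint>

// Hypothetical flags enum in the style of TraversalBehavior: each option is a
// distinct bit so several options can be OR'd together.
enum class WalkFlags : std::uint32_t {
  DEFAULT = 0u,
  DONT_WALK_TIED_VALUES = 1u << 0,
  // Hypothetical second flag, present only to show combining options.
  DONT_CROSS_CALLS = 1u << 1,
};

inline WalkFlags operator|(WalkFlags lhs, WalkFlags rhs) {
  return static_cast<WalkFlags>(static_cast<std::uint32_t>(lhs) |
                                static_cast<std::uint32_t>(rhs));
}

inline WalkFlags operator&(WalkFlags lhs, WalkFlags rhs) {
  return static_cast<WalkFlags>(static_cast<std::uint32_t>(lhs) &
                                static_cast<std::uint32_t>(rhs));
}

// True if any bit of |bit| is also set in |bits|.
inline bool containsFlag(WalkFlags bits, WalkFlags bit) {
  return (static_cast<std::uint32_t>(bits) &
          static_cast<std::uint32_t>(bit)) != 0;
}

// Hypothetical helper: how a use-walker might consult the option before
// following tied results. Returns how many tied results would be queued.
inline int tiedResultsToQueue(int numTiedResults, WalkFlags options) {
  if (containsFlag(options, WalkFlags::DONT_WALK_TIED_VALUES)) return 0;
  return numTiedResults;
}
```

Because `DEFAULT` has no bits set, a membership test against it always yields false. Callers therefore check for specific opt-out flags rather than for the default.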
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
index b4b708a..318160f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
@@ -228,23 +228,23 @@
// op order issues.
SmallVector<std::map<StringRef, SmallVector<Operation *>>> sequencedBuckets;
sequencedBuckets.push_back({}); // Start in a sequence.
- block.walk([&](Operation *op) {
+ for (auto &op : block) {
auto &buckets = sequencedBuckets.back();
if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
if (!immutableGlobals.contains(loadOp.getGlobalName())) {
- buckets[loadOp.getGlobalName()].push_back(op);
+ buckets[loadOp.getGlobalName()].push_back(&op);
}
} else if (auto storeOp =
dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
- buckets[storeOp.getGlobalName()].push_back(op);
- } else if (doesOpBlockMotion(op)) {
+ buckets[storeOp.getGlobalName()].push_back(&op);
+ } else if (doesOpBlockMotion(&op)) {
// Split point - all accesses after this point must not assume anything
// about accesses before it.
if (!buckets.empty()) {
sequencedBuckets.push_back({});
}
}
- });
+ }
bool didRemoveAny = false;
for (auto &buckets : sequencedBuckets) {
didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny;
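The `SimplifyGlobalAccesses` change replaces `block.walk`, which also visits ops nested inside regions, with a loop over only the block's immediate operations. The bucketing logic can be sketched without MLIR as follows. This assumes a hypothetical flattened `Op` record in which `Barrier` stands in for any op where `doesOpBlockMotion` returns true. The sketch is illustrative and is not the pass's code.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical flattened op model; the real pass works on MLIR operations.
enum class OpKind { Load, Store, Barrier, Other };

struct Op {
  OpKind kind;
  std::string global;  // Only meaningful for Load/Store.
};

// Global name -> accesses of that global, in program order.
using Buckets = std::map<std::string, std::vector<const Op *>>;

// Groups loads/stores by global name, starting a new sequence at each op that
// blocks motion. Only the immediate ops of |block| are visited (no recursion).
std::vector<Buckets> sequenceAccesses(
    const std::vector<Op> &block,
    const std::set<std::string> &immutableGlobals) {
  std::vector<Buckets> sequenced(1);  // Start in a sequence.
  for (const Op &op : block) {
    Buckets &buckets = sequenced.back();
    switch (op.kind) {
      case OpKind::Load:
        // Loads of immutable globals never need reordering, so skip them.
        if (immutableGlobals.count(op.global) == 0) {
          buckets[op.global].push_back(&op);
        }
        break;
      case OpKind::Store:
        buckets[op.global].push_back(&op);
        break;
      case OpKind::Barrier:
        // Split point: later accesses must not assume anything about earlier
        // ones. An empty sequence is reused rather than split again.
        if (!buckets.empty()) sequenced.emplace_back();
        break;
      case OpKind::Other:
        break;
    }
  }
  return sequenced;
}
```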
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 5eef577..9df179f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -1063,7 +1063,6 @@
if (typeName[0] == '!') {
typeName = typeName.substr(1);
}
- typeName = std::string("\"") + typeName + std::string("\"");
Value stringView =
emitc_builders::ireeMakeCstringView(builder, loc, typeName);
@@ -2947,6 +2946,107 @@
}
};
+class SelectRefOpConversion
+ : public EmitCConversionPattern<IREE::VM::SelectRefOp> {
+ using Adaptor = typename IREE::VM::SelectRefOp::Adaptor;
+ using EmitCConversionPattern<IREE::VM::SelectRefOp>::EmitCConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(IREE::VM::SelectRefOp selectOp, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = selectOp.getContext();
+ auto loc = selectOp.getLoc();
+
+ auto moduleOp =
+ selectOp.getOperation()->template getParentOfType<IREE::VM::ModuleOp>();
+ auto funcOp = selectOp.getOperation()
+ ->template getParentOfType<mlir::emitc::FuncOp>();
+ auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp);
+
+ const BlockArgument moduleArg = funcOp.getArgument(CCONV_ARGUMENT_MODULE);
+ auto resultTypePtr =
+ createVmTypeDefPtr(rewriter, loc, this->getModuleAnalysis(), moduleOp,
+ moduleArg, selectOp.getType());
+ if (!resultTypePtr.has_value()) {
+ return selectOp->emitError() << "generating iree_vm_type_def_t* failed";
+ }
+ auto resultTypeAsRef =
+ rewriter
+ .create<emitc::CallOpaqueOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_type_def_as_ref"),
+ /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{resultTypePtr.value()})
+ .getResult(0);
+
+ bool moveTrue =
+ funcAnalysis.isMove(selectOp.getTrueValue(), selectOp.getOperation());
+ bool moveFalse =
+ funcAnalysis.isMove(selectOp.getFalseValue(), selectOp.getOperation());
+
+ Value refTrue =
+ this->getModuleAnalysis().lookupRef(selectOp.getTrueValue());
+ Value refFalse =
+ this->getModuleAnalysis().lookupRef(selectOp.getFalseValue());
+ Value refResult = this->getModuleAnalysis().lookupRef(selectOp.getResult());
+
+ Type boolType = rewriter.getI1Type();
+ auto condition = rewriter.create<IREE::VM::CmpNZI32Op>(
+ loc, rewriter.getI32Type(), selectOp.getCondition());
+ auto conditionI1 = rewriter.create<emitc::CastOp>(
+ /*location=*/loc,
+ /*type=*/boolType,
+ /*operand=*/condition.getResult());
+
+ auto *continueBlock =
+ rewriter.splitBlock(selectOp->getBlock(), Block::iterator(selectOp));
+
+ Block *trueBlock = nullptr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ trueBlock = rewriter.createBlock(continueBlock);
+ returnIfError(
+ /*rewriter=*/rewriter,
+ /*location=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"),
+ /*args=*/
+ ArrayAttr::get(
+ ctx, {rewriter.getBoolAttr(moveTrue), rewriter.getIndexAttr(0),
+ rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}),
+ /*operands=*/
+ ArrayRef<Value>{refTrue, resultTypeAsRef, refResult},
+ this->getModuleAnalysis());
+ rewriter.create<IREE::VM::BranchOp>(loc, continueBlock);
+ }
+
+ Block *falseBlock = nullptr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ falseBlock = rewriter.createBlock(continueBlock);
+ returnIfError(
+ /*rewriter=*/rewriter,
+ /*location=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"),
+ /*args=*/
+ ArrayAttr::get(
+ ctx, {rewriter.getBoolAttr(moveFalse), rewriter.getIndexAttr(0),
+ rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}),
+ /*operands=*/
+ ArrayRef<Value>{refFalse, resultTypeAsRef, refResult},
+ this->getModuleAnalysis());
+ rewriter.create<IREE::VM::BranchOp>(loc, continueBlock);
+ }
+
+ rewriter.setInsertionPointAfterValue(conditionI1);
+ rewriter.create<mlir::cf::CondBranchOp>(loc, conditionI1.getResult(),
+ trueBlock, falseBlock);
+ rewriter.replaceOp(selectOp, refResult);
+
+ return success();
+ }
+};
+
template <typename OpTy>
class ConstOpConversion : public EmitCConversionPattern<OpTy> {
using Adaptor = typename OpTy::Adaptor;
@@ -3429,12 +3529,8 @@
releaseRefs(rewriter, loc, funcOp, getModuleAnalysis());
- std::string messageStr = std::string("\"") +
- op.getMessage().value_or("").str() +
- std::string("\"");
-
- Value message =
- emitc_builders::ireeMakeCstringView(rewriter, loc, messageStr);
+ Value message = emitc_builders::ireeMakeCstringView(
+ rewriter, loc, op.getMessage().value_or("").str());
auto messageSizeOp = emitc_builders::structMember(
rewriter, loc,
@@ -4430,6 +4526,7 @@
CallOpConversion<IREE::VM::CallOp>,
CallOpConversion<IREE::VM::CallVariadicOp>,
CompareRefNotZeroOpConversion,
+ SelectRefOpConversion,
CondBranchOpConversion,
BranchTableOpConversion,
ConstOpConversion<IREE::VM::ConstF32Op>,
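`SelectRefOpConversion` lowers `vm.select.ref` into a conditional branch. Each arm calls `iree_vm_ref_retain_or_move_checked` on its operand and writes the result ref, with the retain-versus-move decision fixed at compile time by `isMove`. As an analogy only, the sketch below uses `std::shared_ptr` in place of an IREE VM ref. `selectRef` is a hypothetical helper, and unlike the generated code it performs no type check.

```cpp
#include <memory>
#include <utility>

// Analogy only: std::shared_ptr stands in for a VM ref. Retaining copies the
// pointer (bumping the refcount); moving transfers ownership and leaves the
// source empty. The move flags play the role of the compile-time isMove()
// results for each operand.
template <typename T>
std::shared_ptr<T> selectRef(bool condition, std::shared_ptr<T> &trueRef,
                             std::shared_ptr<T> &falseRef, bool moveTrue,
                             bool moveFalse) {
  std::shared_ptr<T> &chosen = condition ? trueRef : falseRef;
  bool move = condition ? moveTrue : moveFalse;
  std::shared_ptr<T> result;
  if (move) {
    result = std::move(chosen);  // Last use: steal the reference.
  } else {
    result = chosen;  // Operand still live: take an extra reference.
  }
  return result;
}
```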
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
index e817f9e..3076f2d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
@@ -299,6 +299,11 @@
Value ireeMakeCstringView(OpBuilder builder, Location location,
std::string str) {
+ std::string escapedStr;
+ llvm::raw_string_ostream os(escapedStr);
+ os.write_escaped(str);
+ auto quotedStr = std::string("\"") + escapedStr + std::string("\"");
+
auto ctx = builder.getContext();
return builder
.create<emitc::CallOpaqueOp>(
@@ -306,7 +311,7 @@
/*type=*/emitc::OpaqueType::get(ctx, "iree_string_view_t"),
/*callee=*/StringAttr::get(ctx, "iree_make_cstring_view"),
/*args=*/
- ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, str)}),
+ ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, quotedStr)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{})
.getResult(0);
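`ireeMakeCstringView` now escapes its input with `write_escaped` and adds the surrounding quotes itself, so callers pass raw strings. Previously, callers only wrapped the text in quotes, so an embedded `"` or newline would have produced an invalid C literal. The sketch below shows a hypothetical standalone version of this escape-then-quote step (`quoteCStringLiteral`). It is modeled loosely on the escapes `llvm::raw_ostream::write_escaped` emits by default and is not LLVM's implementation.

```cpp
#include <cstdio>
#include <string>

// Hypothetical helper: escapes |str| and wraps it in double quotes so it can
// be emitted verbatim as a C string literal. Backslash, double quote, newline
// and tab get short escapes; other non-printable bytes become 3-digit octal.
std::string quoteCStringLiteral(const std::string &str) {
  std::string out = "\"";
  for (unsigned char c : str) {
    switch (c) {
      case '\\': out += "\\\\"; break;
      case '"': out += "\\\""; break;
      case '\n': out += "\\n"; break;
      case '\t': out += "\\t"; break;
      default:
        if (c >= 0x20 && c < 0x7f) {
          out += static_cast<char>(c);  // Printable ASCII passes through.
        } else {
          char buf[5];  // Backslash + 3 octal digits + NUL.
          std::snprintf(buf, sizeof(buf), "\\%03o", static_cast<unsigned>(c));
          out += buf;
        }
    }
  }
  out += "\"";
  return out;
}
```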
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index 65ad8c9..871087c 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -97,8 +97,18 @@
// No uses - erase the global entirely.
deadOps.push_back(globalInfo->op);
} else {
- // If there are stores mark the global as mutable.
- globalInfo->op.setGlobalMutable(!globalInfo->getStores().empty());
+ // TODO(benvanik): verify we want this behavior - we likely want to change
+ // this to be mutable only if stores exist outside of initializers.
+ //
+ // If there are stores mark the global as mutable. We need to update all
+ // of the loads if this changes anything.
+ bool hasStores = !globalInfo->getStores().empty();
+ bool didChange = globalInfo->op.isGlobalMutable() != hasStores;
+ globalInfo->op.setGlobalMutable(hasStores);
+ if (didChange) {
+ for (auto loadOp : globalInfo->getLoads())
+ loadOp.setGlobalImmutable(!hasStores);
+ }
}
for (auto loadOp : globalInfo->getLoads())
loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable());
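In `GlobalInitialization`, a global's mutability is now derived from whether it has any stores, and its loads are kept in sync when that changes. The following hypothetical model illustrates the invariant. `GlobalInfo`, `LoadOp`, and `syncMutability` are illustrative stand-ins, not the pass's types.

```cpp
#include <vector>

// Hypothetical stand-ins for a global and the loads that reference it.
struct LoadOp {
  bool globalImmutable = false;
};

struct GlobalInfo {
  bool isMutable = true;
  int numStores = 0;
  std::vector<LoadOp *> loads;
};

// A global is mutable iff it has stores; every load is updated to agree so a
// load never claims a mutable global is immutable (which could let later
// passes treat the loaded value as constant). Returns true if mutability
// changed.
bool syncMutability(GlobalInfo &info) {
  bool hasStores = info.numStores > 0;
  bool didChange = info.isMutable != hasStores;
  info.isMutable = hasStores;
  for (LoadOp *load : info.loads) load->globalImmutable = !info.isMutable;
  return didChange;
}
```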
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index 7d23aaa..fc25a70 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -32,8 +32,8 @@
// Variant configuration
// ---------------------------------------------------------------------------
-static void
-buildVMVXConfigurationPassPipelineImpl(OpPassManager &modulePassManager) {
+void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager) {
+ OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
{
FunctionLikeNest funcPassManager(modulePassManager);
// ---------------------------------------------------------------------------
@@ -43,25 +43,19 @@
}
modulePassManager.addPass(createMaterializeUserConfigsPass());
FunctionLikeNest(modulePassManager)
- .addPass([&]() { return createCPUMaterializeEncodingPass(); })
+ .addPass(createCPUMaterializeDeviceEncodingPass)
// TODO: Remove the following pass the plumb support for
// #hal.descriptor_type memory space through the stack.
.addPass(createEraseHALDescriptorTypeFromMemRefPass);
modulePassManager.addPass(createVMVXSelectLoweringStrategyPass());
}
-void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager) {
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- buildVMVXConfigurationPassPipelineImpl(modulePassManager);
-}
-
// ---------------------------------------------------------------------------
// Variant Translation
// ---------------------------------------------------------------------------
static void
buildVectorVMVXTransformPassPipeline(OpPassManager &variantPassManager) {
-
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
// ---------------------------------------------------------------------------
// Tensor-level optimization, kernel dispatch and lower to buffers.
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
index 7bbd7f5..66643e6 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
@@ -29,6 +29,7 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
index 4e2f29a..a63fca3 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
@@ -34,6 +34,7 @@
MLIRValueBoundsOpInterface
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
index e3ba257..ab1adf0 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
@@ -6,6 +6,10 @@
#include "iree/compiler/ExternalInterfaces/StreamExternalModels.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
@@ -15,26 +19,77 @@
namespace {
template <typename OpT>
-struct AffinityOpAttrExternalModel
+struct OptionalOpAffinityAttrExternalModel
: public IREE::Stream::AffinityOpInterface::ExternalModel<
- AffinityOpAttrExternalModel<OpT>, OpT> {
+ OptionalOpAffinityAttrExternalModel<OpT>, OpT> {
static void add(MLIRContext *context) {
- OpT::template attachInterface<AffinityOpAttrExternalModel<OpT>>(*context);
+ OpT::template attachInterface<OptionalOpAffinityAttrExternalModel<OpT>>(
+ *context);
}
- // Most structural ops don't require affinities and after placement we don't
- // use the affinities even if the ops still exist.
- bool requiresAffinity(Operation *op) const { return false; }
+ // Affinity only required for results that hold resources that
+ // require placement.
+ bool requiresAffinity(Operation *op) const {
+ auto resultType = cast<OpT>(op).getResult().getType();
+ return isa<TensorType>(resultType);
+ }
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
- if (value)
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
op->setAttr("stream.affinity", value);
- else
+ } else {
op->removeAttr("stream.affinity");
+ }
+ }
+};
+
+struct FlowTransferTargetAffinityAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ FlowTransferTargetAffinityAttrExternalModel,
+ IREE::Flow::TensorTransferOp> {
+ static void add(MLIRContext *context) {
+ IREE::Flow::TensorTransferOp::attachInterface<
+ FlowTransferTargetAffinityAttrExternalModel>(*context);
+ }
+
+ bool requiresAffinity(Operation *op) const { return true; }
+
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("target");
+ }
+
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ op->setAttr("target", value);
+ }
+};
+
+template <typename OpT>
+struct HALTensorAffinityAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ HALTensorAffinityAttrExternalModel<OpT>, OpT> {
+ static void add(MLIRContext *context) {
+ OpT::template attachInterface<HALTensorAffinityAttrExternalModel<OpT>>(
+ *context);
+ }
+
+ bool requiresAffinity(Operation *op) const { return false; }
+
+ bool pinsValueAffinity(Operation *op) const { return true; }
+
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("affinity");
+ }
+
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
+ op->setAttr("affinity", value);
+ } else {
+ op->removeAttr("affinity");
+ }
}
};
@@ -54,30 +109,110 @@
return isa<TensorType>(globalType);
}
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
- if (value)
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
op->setAttr("stream.affinity", value);
- else
+ } else {
op->removeAttr("stream.affinity");
+ }
+ }
+};
+
+template <typename OpT, bool kRequiresAffinity = true>
+struct AffinityOpAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ AffinityOpAttrExternalModel<OpT, kRequiresAffinity>, OpT> {
+ static void add(MLIRContext *context) {
+ OpT::template attachInterface<
+ AffinityOpAttrExternalModel<OpT, kRequiresAffinity>>(*context);
+ }
+
+ // Most structural ops don't require affinities and after placement we don't
+ // use the affinities even if the ops still exist.
+ bool requiresAffinity(Operation *op) const { return kRequiresAffinity; }
+
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
+ }
+
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
+ op->setAttr("stream.affinity", value);
+ } else {
+ op->removeAttr("stream.affinity");
+ }
+ }
+};
+
+struct TensorAffinityTypeExternalModel
+ : public IREE::Stream::AffinityTypeInterface::ExternalModel<
+ TensorAffinityTypeExternalModel, RankedTensorType> {
+ static void add(MLIRContext *context) {
+ RankedTensorType::attachInterface<TensorAffinityTypeExternalModel>(
+ *context);
}
};
} // namespace
void registerStreamExternalModels(DialectRegistry &registry) {
- // Must ensure that any dependent dialects are registered.
- registry.insert<IREE::Util::UtilDialect>();
+ registry.addExtension(+[](MLIRContext *context) {
+ TensorAffinityTypeExternalModel::add(context);
+ });
+ registry.insert<arith::ArithDialect>();
registry.addExtension(
- +[](MLIRContext *context, IREE::Util::UtilDialect *dialect) {
- GlobalOpAffinityAttrExternalModel<IREE::Util::GlobalOp>::add(context);
- AffinityOpAttrExternalModel<IREE::Util::InitializerOp>::add(context);
- AffinityOpAttrExternalModel<IREE::Util::FuncOp>::add(context);
+ +[](MLIRContext *context, arith::ArithDialect *dialect) {
+ OptionalOpAffinityAttrExternalModel<arith::ConstantOp>::add(context);
});
+
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::Flow::FlowDialect *dialect) {
+ FlowTransferTargetAffinityAttrExternalModel::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchRegionOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchWorkgroupsOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CallOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorConstantOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorDynamicConstantOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorAllocaOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorEmptyOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorSplatOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorCloneOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorSliceOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorUpdateOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::ChannelDefaultOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllGatherOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllReduceOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllToAllOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveReduceScatterOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveSendRecvOp>::add(context);
+ });
+
+ registry.insert<IREE::HAL::HALDialect>();
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::HAL::HALDialect *dialect) {
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorImportOp>::add(context);
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorExportOp>::add(context);
+ HALTensorAffinityAttrExternalModel<IREE::HAL::TensorAliasOp>::add(context);
+ });
+
+ registry.insert<IREE::Util::UtilDialect>();
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::Util::UtilDialect *dialect) {
+ GlobalOpAffinityAttrExternalModel<IREE::Util::GlobalOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Util::InitializerOp, false>::add(context);
+ AffinityOpAttrExternalModel<IREE::Util::FuncOp, false>::add(context);
+ });
}
} // namespace mlir::iree_compiler
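The external models in `StreamExternalModels.cpp` share a common shape:

- a fixed attribute name (`stream.affinity`, `affinity`, or `target`);
- a `requiresAffinity` answer, which is a template parameter in `AffinityOpAttrExternalModel`;
- in most models, a setter that removes the attribute when given a null value.

Below is a hypothetical sketch of that shape. It uses a plain string map in place of MLIR attributes. `FakeOp` and `AffinityModel` are illustrative names, not IREE types.

```cpp
#include <map>
#include <optional>
#include <string>

// Hypothetical op whose attributes are a string->string dictionary; real
// models read and write MLIR attributes on an Operation.
struct FakeOp {
  std::map<std::string, std::string> attrs;
};

// Hypothetical model: the attribute name is data and whether an affinity is
// required is a compile-time parameter, like kRequiresAffinity above.
template <bool kRequiresAffinity>
struct AffinityModel {
  const char *attrName;

  bool requiresAffinity() const { return kRequiresAffinity; }

  std::optional<std::string> getAffinity(const FakeOp &op) const {
    auto it = op.attrs.find(attrName);
    if (it == op.attrs.end()) return std::nullopt;
    return it->second;
  }

  // A null value removes the attribute instead of storing an empty one.
  void setAffinity(FakeOp &op, const std::optional<std::string> &value) const {
    if (value) {
      op.attrs[attrName] = *value;
    } else {
      op.attrs.erase(attrName);
    }
  }
};
```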
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 8b4d450..46c4988 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -83,6 +83,7 @@
"//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
+ "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 4e29821..9410270 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -98,6 +98,7 @@
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
+ iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::LinalgExt::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
index bf62869..143b969 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
@@ -46,11 +47,16 @@
void runOnOperation() override {
auto moduleOp = getOperation();
- auto executableTargets =
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(moduleOp);
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run()))
+ return signalPassFailure();
+
+ SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
+ deviceAnalysis.gatherAllExecutableTargets(executableTargets);
if (executableTargets.size() != 1) {
return runNopPipeline(moduleOp);
}
+
// TODO: vmvx has its own logic about supporting dynamic tile
// sizes. It is not fully integrated into the pipeline, so we remain the
// materialization to the end.
@@ -65,12 +71,8 @@
}
OpPassManager passManager(moduleOp.getOperationName());
- FunctionLikeNest(passManager).addPass([&]() {
- return createCPUMaterializeUpperBoundTileSizePass(executableTargets);
- });
- FunctionLikeNest(passManager).addPass([&]() {
- return createCPUMaterializeEncodingPass(executableTarget);
- });
+ passManager.addPass(createCPUMaterializeUpperBoundTileSizePass());
+ passManager.addPass(createCPUMaterializeHostEncodingPass());
if (failed(runPipeline(passManager, moduleOp))) {
return signalPassFailure();
}
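`MaterializeHomogeneousEncodings` now asks `DeviceAnalysis` for every executable target across all devices. It falls back to the no-op pipeline unless exactly one unique target is found. The sketch below illustrates that homogeneity check. `Device`, `gatherUniqueTargets`, and `isHomogeneous` are hypothetical names, and targets are represented as plain strings.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical device model: the executable targets a device may use (a
// single device that can be one of several archs lists all of them).
struct Device {
  std::vector<std::string> executableTargets;
};

// Collects unique targets across all devices in first-seen order, similar in
// spirit to filling a SetVector.
std::vector<std::string> gatherUniqueTargets(
    const std::vector<Device> &devices) {
  std::vector<std::string> ordered;
  std::set<std::string> seen;
  for (const Device &device : devices) {
    for (const std::string &target : device.executableTargets) {
      if (seen.insert(target).second) ordered.push_back(target);
    }
  }
  return ordered;
}

// Encodings can only be materialized early when every device agrees on one
// target.
bool isHomogeneous(const std::vector<Device> &devices) {
  return gatherUniqueTargets(devices).size() == 1;
}
```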
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
index bbb7867..027c626 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
@@ -29,7 +29,6 @@
"global_loop_invariant_code_motion.mlir",
"hoist_into_globals.mlir",
"infer_numeric_narrowing.mlir",
- "materialize_homogeneous_encodings.mlir",
"optimize_numerics.mlir",
"propagate_linalg_transpose.mlir",
"raise_special_ops.mlir",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
index 79c75b3..b6823fc 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
@@ -27,7 +27,6 @@
"global_loop_invariant_code_motion.mlir"
"hoist_into_globals.mlir"
"infer_numeric_narrowing.mlir"
- "materialize_homogeneous_encodings.mlir"
"optimize_numerics.mlir"
"propagate_linalg_transpose.mlir"
"raise_special_ops.mlir"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
index 4082bbf..e289f07 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
@@ -142,12 +142,14 @@
// CHECK-LABEL: @hoist_dialect_attrs
module @hoist_dialect_attrs {
+ // CHECK: util.global private @device
+ util.global private @device : !hal.device
// CHECK: util.global private @[[HOISTED:[a-z0-9_]+]]
- // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]>
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device>
// CHECK: util.initializer
- // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]>
+ // CHECK-SAME: stream.affinity = #hal.device.affinity<@device>
util.func public @main() -> tensor<i32> attributes {
- hal.affinity = #hal.affinity.queue<[0, 1]>
+ stream.affinity = #hal.device.affinity<@device>
} {
%0 = arith.constant dense<3> : tensor<i32>
%1 = "iree_unregistered.const_expr"(%0) : (tensor<i32>) -> tensor<i32>
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index f6162eb..6e8ecc0 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -209,13 +209,15 @@
// the work.
rewriter.replaceOpWithNewOp<IREE::HAL::TensorImportOp>(
srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType),
- /*name=*/nullptr);
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
} else {
// Dynamic dims explicitly provided (or wrong, in which case the verifier
// will get it).
rewriter.replaceOpWithNewOp<IREE::HAL::TensorImportOp>(
srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType),
- adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr);
+ adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr,
+ /*affinity=*/nullptr);
}
return success();
}
@@ -237,14 +239,16 @@
// the work.
rewriter.replaceOpWithNewOp<IREE::HAL::TensorExportOp>(
srcOp, resultType, adaptor.getSource(),
- TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr);
+ TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr,
+ /*affinity=*/nullptr);
} else {
// Dynamic dims explicitly provided (or wrong, in which case the verifier
// will get it).
rewriter.replaceOpWithNewOp<IREE::HAL::TensorExportOp>(
srcOp, resultType, adaptor.getSource(),
TypeAttr::get(adaptor.getSource().getType()), adaptor.getSourceDims(),
- /*name=*/nullptr);
+ /*name=*/nullptr,
+ /*affinity=*/nullptr);
}
return success();
}
@@ -349,10 +353,8 @@
// Allowlist of function attributes to retain when importing funcs.
constexpr const char *kRetainedAttributes[] = {
- "iree.reflection",
- "vm.fallback",
- "vm.signature",
- "vm.version",
+ "iree.reflection", "stream.affinity", "vm.fallback",
+ "vm.signature", "vm.version",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
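The import step keeps only allowlisted function attributes. This change adds `stream.affinity` to that list so that a user-provided affinity survives import. The sketch below shows the filtering on its own, with string-valued attributes. `retainAllowlistedAttrs` is a hypothetical helper, and the array simply mirrors the list in the hunk above.

```cpp
#include <algorithm>
#include <iterator>
#include <map>
#include <string>

// Mirrors the allowlist in the hunk above.
constexpr const char *kRetainedAttributes[] = {
    "iree.reflection", "stream.affinity", "vm.fallback",
    "vm.signature",    "vm.version",
};

// Hypothetical helper: keeps only allowlisted attributes and drops the rest.
std::map<std::string, std::string> retainAllowlistedAttrs(
    const std::map<std::string, std::string> &attrs) {
  std::map<std::string, std::string> retained;
  for (const auto &entry : attrs) {
    const std::string &name = entry.first;
    bool allowed = std::any_of(
        std::begin(kRetainedAttributes), std::end(kRetainedAttributes),
        [&](const char *allowedName) { return name == allowedName; });
    if (allowed) retained.emplace(name, entry.second);
  }
  return retained;
}
```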
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
index 42d582b..a3255ea 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
+ "//compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL:Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Modules/Check/IR",
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
index 161a143..3c20a5b 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
@@ -22,6 +22,7 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::Conversion::StreamToHAL::Utils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::VM::Conversion
iree::compiler::Modules::Check::IR
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
index 3dab905..d9db0b3 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
@@ -68,8 +69,7 @@
state.addAttributes(srcOp->getAttrs());
// Add device argument.
- // TODO(multi-device): support multiple devices in check tests .
- Value device = IREE::HAL::DeviceType::resolveAny(srcOp->getLoc(), rewriter);
+ Value device = lookupDeviceFor(srcOp, rewriter);
state.addOperands({device});
for (auto [srcOperand, dstOperand] :
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
index 2c11a33..0d68029c 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
@@ -53,6 +53,12 @@
// Device assignment and interface materialization
//----------------------------------------------------------------------------
+ IREE::HAL::AssignmentOptions assignmentOptions;
+ assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends;
+ assignmentOptions.targetDevices = targetOptions.targetDevices;
+ assignmentOptions.defaultDevice = targetOptions.defaultDevice;
+ IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
+ assignmentOptions);
IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry,
targetOptions);
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
index 47f1bcd..96c7eb8 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
@@ -53,6 +53,12 @@
// Device assignment and interface materialization
//----------------------------------------------------------------------------
+ IREE::HAL::AssignmentOptions assignmentOptions;
+ assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends;
+ assignmentOptions.targetDevices = targetOptions.targetDevices;
+ assignmentOptions.defaultDevice = targetOptions.defaultDevice;
+ IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
+ assignmentOptions);
IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry,
targetOptions);
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
index 4842e51..5354b82 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -1,5 +1,7 @@
// RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterLoad
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.buffer, !hal.fence)
util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
@@ -7,7 +9,7 @@
%c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
%c101 = arith.constant 101 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -15,7 +17,7 @@
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
// CHECK-NEXT: "scope"::"key0"[%c50_i64] : !hal.buffer{%c100}
// CHECK-NEXT: "scope"::"key1"[%c51_i64] : !hal.buffer{%c101}
- %results:2, %result_timepoint = stream.parameter.load await(%wait) => {
+ %results:2, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => {
"scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100},
"scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c101}
} => !stream.timepoint
@@ -25,19 +27,21 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterLoadNoScope
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence)
util.func public @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable")
// CHECK-NEXT: "key"[%c50_i64] : !hal.buffer{%c100}
- %result, %result_timepoint = stream.parameter.load await(%wait) => {
+ %result, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => {
"key"[%c50_i64] : !stream.resource<constant>{%c100}
} => !stream.timepoint
// CHECK: return %[[BUFFER]], %[[SIGNAL]]
@@ -46,6 +50,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterRead
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterRead(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -53,19 +59,21 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: "scope"::"key"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer
- %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
+ %timepoint = stream.parameter.read on(#hal.device.affinity<@device>) await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint
// CHECK: return %[[SIGNAL]]
util.return %timepoint : !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterWrite
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
@@ -73,19 +81,21 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !hal.buffer -> "scope"::"key"[%c50_i64]
- %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
+ %timepoint = stream.parameter.write on(#hal.device.affinity<@device>) await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint
// CHECK: return %[[SIGNAL]]
util.return %timepoint : !stream.timepoint
}
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterGather
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -99,7 +109,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -107,7 +117,7 @@
// CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
// CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer,
// CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !hal.buffer
- %timepoint = stream.parameter.gather await(%wait) => {
+ %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => {
"scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
"scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300},
"scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource<transient>{%c300}
@@ -118,6 +128,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterGatherNoScope
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint {
@@ -128,14 +140,14 @@
%c200 = arith.constant 200 : index
%c201 = arith.constant 201 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
// CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]])
// CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer,
// CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer
- %timepoint = stream.parameter.gather await(%wait) => {
+ %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => {
"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300},
"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300}
} => !stream.timepoint
@@ -145,6 +157,8 @@
// -----
+util.global private @device : !hal.device
+
// CHECK-LABEL: @parameterScatter
// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence
util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint {
@@ -158,7 +172,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -167,7 +181,7 @@
// CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64],
// CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64]
// CHECK-NEXT: }
- %timepoint = stream.parameter.scatter await(%wait) => {
+ %timepoint = stream.parameter.scatter on(#hal.device.affinity<@device>) await(%wait) => {
%source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64],
%source[%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64],
%source[%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index d7a155b..bd56dc8 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -77,18 +77,9 @@
PreprocessingOptions preprocessingOptions,
GlobalOptimizationOptions globalOptimizationOptions,
SchedulingOptions schedulingOptions,
- IREE::HAL::TargetOptions executableOptions, IREEVMPipelineHooks &hooks,
+ IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks,
OpPassManager &passManager, IREEVMPipelinePhase compileFrom,
IREEVMPipelinePhase compileTo) {
- // If the user specified a set of target devices we attach them to the module
- // IR so that they are available for all passes that may want to use this
- // information. If trying to compile in a generic mode the user should omit
- // specifying targets.
- if (!executableOptions.targets.empty()) {
- passManager.addPass(IREE::HAL::createAssignTargetDevicesPass(
- {&targetRegistry, executableOptions.targets}));
- }
-
// Input pipelines can result in changes to the exported functions and types
// and must run before generating bindings.
// After input processing, there should only be IREE legal types in
@@ -169,6 +160,18 @@
if (compileTo == IREEVMPipelinePhase::ABI)
return; // early-exit
+ // If the user specified a set of target devices we attach them to the module
+ // IR so that they are available for all passes that may want to use this
+ // information. If trying to compile in a generic mode the user should omit
+ // specifying targets.
+ IREE::HAL::AssignmentOptions halAssignmentOptions;
+ halAssignmentOptions.legacyTargetBackends =
+ halTargetOptions.legacyTargetBackends;
+ halAssignmentOptions.targetDevices = halTargetOptions.targetDevices;
+ halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice;
+ IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry,
+ halAssignmentOptions);
+
GlobalOptimization::TransformOptions globalTransformOptions;
globalTransformOptions.options = globalOptimizationOptions;
@@ -244,13 +247,13 @@
PreprocessingOptions preprocessingOptions,
GlobalOptimizationOptions globalOptimizationOptions,
SchedulingOptions schedulingOptions,
- IREE::HAL::TargetOptions executableOptions,
- IREE::VM::TargetOptions targetOptions, IREEVMPipelineHooks &hooks,
+ IREE::HAL::TargetOptions halTargetOptions,
+ IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks,
OpPassManager &passManager, IREEVMPipelinePhase compileFrom,
IREEVMPipelinePhase compileTo) {
buildIREEPrecompileTransformPassPipeline(
targetRegistry, bindingOptions, inputOptions, preprocessingOptions,
- globalOptimizationOptions, schedulingOptions, executableOptions, hooks,
+ globalOptimizationOptions, schedulingOptions, halTargetOptions, hooks,
passManager, compileFrom, compileTo);
if (compileTo <= IREEVMPipelinePhase::GlobalOptimization)
@@ -313,16 +316,16 @@
case SchedulingOptions::ExecutionModel::AsyncInternal:
case SchedulingOptions::ExecutionModel::AsyncExternal:
IREE::HAL::buildHALTransformPassPipeline(passManager, targetRegistry,
- executableOptions, hooks,
+ halTargetOptions, hooks,
halCompileFrom, halCompileTo);
break;
case SchedulingOptions::ExecutionModel::InlineStatic:
IREE::HAL::Inline::buildHALInlineStaticTransformPassPipeline(
- passManager, targetRegistry, executableOptions);
+ passManager, targetRegistry, halTargetOptions);
break;
case SchedulingOptions::ExecutionModel::InlineDynamic:
IREE::HAL::Loader::buildHALInlineDynamicTransformPassPipeline(
- passManager, targetRegistry, executableOptions);
+ passManager, targetRegistry, halTargetOptions);
break;
}
if (hooks.afterPhase)
@@ -338,7 +341,7 @@
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "VM");
if (hooks.beforePhase)
hooks.beforePhase(IREEVMPipelinePhase::VM, passManager);
- IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
+ IREE::VM::buildVMTransformPassPipeline(passManager, vmTargetOptions);
passManager.addPass(IREE::Util::createDropCompilerHintsPass());
if (hooks.afterPhase)
hooks.afterPhase(IREEVMPipelinePhase::VM, passManager);
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h
index 1fba747..cdc754f 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.h
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h
@@ -102,7 +102,7 @@
PreprocessingOptions preprocessingOptions,
GlobalOptimizationOptions highLevelOptimizationOptions,
SchedulingOptions schedulingOptions,
- IREE::HAL::TargetOptions executableOptions, IREEVMPipelineHooks &hooks,
+ IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks,
OpPassManager &passManager,
IREEVMPipelinePhase compileFrom = IREEVMPipelinePhase::Start,
IREEVMPipelinePhase compileTo = IREEVMPipelinePhase::GlobalOptimization);
@@ -118,8 +118,8 @@
PreprocessingOptions preprocessingOptions,
GlobalOptimizationOptions highLevelOptimizationOptions,
SchedulingOptions schedulingOptions,
- IREE::HAL::TargetOptions executableOptions,
- IREE::VM::TargetOptions targetOptions, IREEVMPipelineHooks &hooks,
+ IREE::HAL::TargetOptions halTargetOptions,
+ IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks,
OpPassManager &passManager,
IREEVMPipelinePhase compileFrom = IREEVMPipelinePhase::Start,
IREEVMPipelinePhase compileTo = IREEVMPipelinePhase::End);
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index a481888..13bc6bf 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -54,8 +54,10 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
+ "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 1bc4c1e..3764d49 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -65,8 +65,10 @@
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
+ iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
PUBLIC
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index 721522a..ba415b3 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -5,9 +5,14 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <cstdint>
+#include <limits>
+
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -21,7 +26,7 @@
namespace mlir::iree_compiler::Preprocessing {
-#define GEN_PASS_DEF_PADTOINTRINSICS
+#define GEN_PASS_DEF_PADTOINTRINSICSPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
namespace {
@@ -141,10 +146,8 @@
}
static SmallVector<GPUMatmulShapeType>
-getIntrinsics(linalg::LinalgOp linalgOp) {
- SmallVector<IREE::HAL::ExecutableTargetAttr, 4> executableTargets =
- IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(linalgOp);
-
+getIntrinsics(linalg::LinalgOp linalgOp,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargets) {
IREE::GPU::TargetAttr target;
if (executableTargets.size() == 1) {
auto targetAttr = executableTargets.front();
@@ -165,7 +168,9 @@
});
}
-static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) {
+static void
+padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargets) {
if (!isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
return;
}
@@ -174,7 +179,8 @@
return;
// Early exit if cannot find intrinsics or if multiple executable targets.
- SmallVector<GPUMatmulShapeType> intrinsics = getIntrinsics(linalgOp);
+ SmallVector<GPUMatmulShapeType> intrinsics =
+ getIntrinsics(linalgOp, executableTargets);
if (intrinsics.empty())
return;
@@ -304,8 +310,9 @@
rewriter.replaceOp(linalgOp, extracted);
}
-static void padContractionLikeOp(RewriterBase &rewriter,
- linalg::LinalgOp linalgOp) {
+static void padContractionLikeOp(
+ RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargets) {
FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
mlir::linalg::inferContractionDims(linalgOp);
@@ -319,7 +326,8 @@
}
// Early exit if cannot find intrinsics or if multiple executable targets.
- SmallVector<GPUMatmulShapeType> intrinsics = getIntrinsics(linalgOp);
+ SmallVector<GPUMatmulShapeType> intrinsics =
+ getIntrinsics(linalgOp, executableTargets);
if (intrinsics.empty())
return;
@@ -526,7 +534,7 @@
}
struct PadToIntrinsicsPass
- : public impl::PadToIntrinsicsBase<PadToIntrinsicsPass> {
+ : public impl::PadToIntrinsicsPassBase<PadToIntrinsicsPass> {
using Base::Base;
void runOnOperation() override;
};
@@ -536,39 +544,63 @@
void PadToIntrinsicsPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- auto funcOp = getOperation();
+
+ auto moduleOp = getOperation();
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
+
bool padConvOps = padTargetType == PadTargetType::ConvOp ||
padTargetType == PadTargetType::All;
bool padContractionOps = padTargetType == PadTargetType::ContractionOp ||
padTargetType == PadTargetType::All;
SmallVector<linalg::LinalgOp> targetConvOps;
SmallVector<linalg::LinalgOp> targetContractOps;
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) && padConvOps) {
- // Add convOps into worklist.
- targetConvOps.push_back(linalgOp);
- } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
- linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&
- padContractionOps) {
- // Add named contractionOps into worklist.
- targetContractOps.push_back(linalgOp);
- } else if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
- linalg::isaContractionOpInterface(linalgOp) &&
- padContractionOps) {
- // Add named generic contractionOps into worklist.
- targetContractOps.push_back(linalgOp);
- }
- });
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ funcOp.walk([&](linalg::LinalgOp linalgOp) {
+ if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) &&
+ padConvOps) {
+ targetConvOps.push_back(linalgOp);
+ } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
+ linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&
+ padContractionOps) {
+ targetContractOps.push_back(linalgOp);
+ } else if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
+ linalg::isaContractionOpInterface(linalgOp) &&
+ padContractionOps) {
+ targetContractOps.push_back(linalgOp);
+ }
+ });
+ }
// Iterate through and pad ops in the worklists.
+ auto getRequiredExecutableTargetAttrs = [&](Operation *op) {
+ SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
+ if (affinityAnalysis.tryInferExecutionAffinity(op, affinityAttrs)) {
+ for (auto affinityAttr : affinityAttrs) {
+ deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
+ executableTargetAttrs);
+ }
+ }
+ return executableTargetAttrs;
+ };
IRRewriter rewriter(context);
for (auto convOp : targetConvOps) {
rewriter.setInsertionPoint(convOp);
- padConvOp(rewriter, convOp);
+ auto executableTargetAttrs = getRequiredExecutableTargetAttrs(convOp);
+ padConvOp(rewriter, convOp, executableTargetAttrs.getArrayRef());
}
for (auto contractOp : targetContractOps) {
rewriter.setInsertionPoint(contractOp);
- padContractionLikeOp(rewriter, contractOp);
+ auto executableTargetAttrs = getRequiredExecutableTargetAttrs(contractOp);
+ padContractionLikeOp(rewriter, contractOp,
+ executableTargetAttrs.getArrayRef());
}
}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index edc1705..ca29a52 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -84,8 +84,8 @@
];
}
-def PadToIntrinsics :
- InterfacePass<"iree-preprocessing-pad-to-intrinsics", "mlir::FunctionOpInterface"> {
+def PadToIntrinsicsPass :
+ Pass<"iree-preprocessing-pad-to-intrinsics", "ModuleOp"> {
let summary = "Pad linalg ops such that we can use target's intrinsics.";
let dependentDialects = [
"mlir::linalg::LinalgDialect",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index 5761741..7d9da45 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT
// CHECK-LABEL: func.func @main0(
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
index aba35c8..ece0283 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
@@ -1,7 +1,6 @@
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT
-
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT
// CHECK: func.func @matmul_static(
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf16>,
diff --git a/compiler/src/iree/compiler/Utils/IntegerSet.h b/compiler/src/iree/compiler/Utils/IntegerSet.h
index 594eecd..aed0376 100644
--- a/compiler/src/iree/compiler/Utils/IntegerSet.h
+++ b/compiler/src/iree/compiler/Utils/IntegerSet.h
@@ -33,6 +33,15 @@
return memoizedValue;
}
+ Value add(StorageT lhs, StorageT rhs) { return get(lhs + rhs); }
+ Value add(Value lhs, StorageT rhs) {
+ APInt lhsValue;
+ if (matchPattern(lhs, m_ConstantInt(&lhsValue))) {
+ return add(lhsValue.getSExtValue(), rhs);
+ }
+ return builder.create<arith::AddIOp>(loc, lhs, get(rhs));
+ }
+
void populate(ValueRange values) {
for (auto value : values) {
APInt intValue;
@@ -66,6 +75,15 @@
}
Value get(APInt value) { return get(value.getSExtValue()); }
+ Value add(int64_t lhs, int64_t rhs) { return get(lhs + rhs); }
+ Value add(Value lhs, int64_t rhs) {
+ APInt lhsValue;
+ if (matchPattern(lhs, m_ConstantInt(&lhsValue))) {
+ return add(lhsValue.getSExtValue(), rhs);
+ }
+ return builder.create<arith::AddIOp>(loc, lhs, get(rhs));
+ }
+
void populate(ValueRange values) {
for (auto value : values) {
APInt intValue;
diff --git a/docs/website/docs/community/blog/posts/microkernels.md b/docs/website/docs/community/blog/posts/microkernels.md
index f56b713..62c8e6a 100644
--- a/docs/website/docs/community/blog/posts/microkernels.md
+++ b/docs/website/docs/community/blog/posts/microkernels.md
@@ -338,7 +338,7 @@
[...]
// -----// IR Dump After Inliner (inline) //----- //
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> : !hal.device
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
func.func @matmul_dynamic(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_dynamic(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>, %input2: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
@@ -357,17 +357,17 @@
}
```
-### IR Dump After CPUMaterializeEncoding
+### IR Dump After CPUMaterializeHostEncoding
```mlir
-// -----// IR Dump After CPUMaterializeEncoding (iree-codegen-cpu-materialize-encoding) //----- //
+// -----// IR Dump After CPUMaterializeHostEncoding (iree-codegen-cpu-materialize-host-encoding) //----- //
[...]
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
[...]
// -----// IR Dump After CSE (cse) //----- //
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}>
#map = affine_map<()[s0] -> (s0 ceildiv 16)>
-#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
+#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> : !hal.device
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
func.func @matmul_dynamic(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_dynamic(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>, %input2: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
%cst = arith.constant 0.000000e+00 : f32
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
index 2725322..c08b434 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py
@@ -113,7 +113,7 @@
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
index f328211..2e5b189 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py
@@ -97,7 +97,7 @@
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
index 6d9ab66..881d93d 100644
--- a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py
@@ -68,7 +68,7 @@
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
]
###############################################################################
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
index 41b2e61..207ddaf 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py
@@ -99,7 +99,7 @@
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
index 4e1bc70..9d7f942 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py
@@ -103,7 +103,7 @@
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-rocm-waves-per-eu=2",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
index 49e49d3..5b9ab15 100644
--- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
+++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py
@@ -68,7 +68,7 @@
"--iree-flow-enable-aggressive-fusion=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-execution-model=async-external",
- "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index 07bd660..40f2610 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -65,14 +65,6 @@
iree_string_view_t key, int64_t* out_value) {
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_value);
-
- if (iree_string_view_equal(category,
- iree_make_cstring_view("hal.device.id"))) {
- *out_value =
- iree_string_view_match_pattern(iree_hal_device_id(device), key) ? 1 : 0;
- return iree_ok_status();
- }
-
return _VTABLE_DISPATCH(device, query_i64)(device, category, key, out_value);
}
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index 7c5012e..c2a310e 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -73,7 +73,6 @@
%p8 = arith.addf %p7, %cp1 : tensor<f32>
%p9 = arith.addf %p8, %cp1 : tensor<f32>
%approximately_1 = arith.addf %p9, %cp1 : tensor<f32>
-
check.expect_almost_eq(%approximately_1, %c1) : tensor<f32>
return
}
diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c
index 853c36f..ed6684c 100644
--- a/runtime/src/iree/vm/bytecode/disassembler.c
+++ b/runtime/src/iree/vm/bytecode/disassembler.c
@@ -1326,6 +1326,7 @@
IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(b, " : "));
EMIT_REF_REG_NAME(false_value_reg);
EMIT_OPTIONAL_VALUE_REF(&regs->ref[false_value_reg]);
+ IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(b, " -> !"));
EMIT_TYPE_NAME(type_def);
break;
}
diff --git a/runtime/src/iree/vm/test/assignment_ops.mlir b/runtime/src/iree/vm/test/assignment_ops.mlir
index 1388c1e..891165d 100644
--- a/runtime/src/iree/vm/test/assignment_ops.mlir
+++ b/runtime/src/iree/vm/test/assignment_ops.mlir
@@ -17,7 +17,7 @@
vm.return
}
- vm.export @test_select_ref attributes {emitc.exclude}
+ vm.export @test_select_ref
vm.func private @test_select_ref() {
%c0 = vm.const.i32 0
%list0 = vm.list.alloc %c0 : (i32) -> !vm.list<i8>
diff --git a/samples/custom_dispatch/cpu/embedded/example_hal.mlir b/samples/custom_dispatch/cpu/embedded/example_hal.mlir
index e9edfd5..91a87ad 100644
--- a/samples/custom_dispatch/cpu/embedded/example_hal.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_hal.mlir
@@ -43,7 +43,7 @@
// compiled binary (CPU + Vulkan, etc).
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/embedded/example_stream.mlir b/samples/custom_dispatch/cpu/embedded/example_stream.mlir
index a8b6861..910a007 100644
--- a/samples/custom_dispatch/cpu/embedded/example_stream.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_stream.mlir
@@ -48,7 +48,7 @@
#cpu_target = #hal.device.target<"llvm-cpu", [
#arm_64_target,
#x86_64_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/embedded/example_transform.mlir b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
index 709a016..858052c 100644
--- a/samples/custom_dispatch/cpu/embedded/example_transform.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_transform.mlir
@@ -28,7 +28,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
index 599ed8a..2aa5943 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
@@ -21,7 +21,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir
index 3bc9f12..c725daf 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir
@@ -72,7 +72,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir
index 8636c94..0dcc358 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir
@@ -29,7 +29,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir
index 7876085..6b6fbf1 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir
@@ -41,7 +41,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
module @example attributes {hal.device.targets = [#cpu_target]} {
diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir
index 0b27ae2..4bb0593 100644
--- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir
+++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir
@@ -41,7 +41,7 @@
// hence we only support llvm-cpu here.
#cpu_target = #hal.device.target<"llvm-cpu", [
#x86_64_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#cpu_target]} {
func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> {
diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir
index 15a3bb43..62e49c6 100644
--- a/samples/custom_dispatch/cuda/kernels/example.mlir
+++ b/samples/custom_dispatch/cuda/kernels/example.mlir
@@ -27,7 +27,7 @@
#cuda_target = #hal.device.target<"cuda", [
#nvptx_sm_52_target,
#nvptx_sm_80_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#cuda_target]} {
diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir
index 8819d86..3ca1bad 100644
--- a/samples/custom_dispatch/hip/kernels/example.mlir
+++ b/samples/custom_dispatch/hip/kernels/example.mlir
@@ -23,7 +23,7 @@
// compiled binary.
#rocm_target = #hal.device.target<"rocm", [
#rocm_gfx1100_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#rocm_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index ef10fb7..69843aa 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -29,7 +29,7 @@
// compiled binary.
#vulkan_target = #hal.device.target<"vulkan", [
#spirv_target
-]>
+]> : !hal.device
module @example attributes {hal.device.targets = [#vulkan_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 36912bb..4157651 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -27,7 +27,7 @@
// These can come from compiler flags and multiple targets can be supported
// It's possible, for example, to support targeting multiple devices in the same
// compiled binary.
-#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]>
+#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> : !hal.device
module @example attributes {hal.device.targets = [#vulkan_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index b4885a0..4bea02d 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -33,7 +33,7 @@
// hence we only support vulkan here. It is possible to hand author a custom
// kernel that supports multiple targets by specifying an object per-target, but
// that requires authoring the kernel for multiple targets.
-#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]>
+#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> : !hal.device
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
diff --git a/samples/multiple_modules/pipeline_async.mlir b/samples/multiple_modules/pipeline_async.mlir
index 46ad83c..3676359 100644
--- a/samples/multiple_modules/pipeline_async.mlir
+++ b/samples/multiple_modules/pipeline_async.mlir
@@ -1,7 +1,8 @@
// RUN: (iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %p/module_a.mlir -o=%t.module_a.vmfb && \
// RUN: iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %p/module_b.mlir -o=%t.module_b.vmfb && \
// RUN: iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %s | \
-// RUN: iree-run-module --device=local-task \
+// RUN: iree-run-module \
+// RUN: --device=local-task \
// RUN: --module=%t.module_a.vmfb \
// RUN: --module=%t.module_b.vmfb \
// RUN: --module=- --function=run \
diff --git a/samples/multiple_modules/pipeline_sync.mlir b/samples/multiple_modules/pipeline_sync.mlir
index 3f9a6e0..b9f8d15 100644
--- a/samples/multiple_modules/pipeline_sync.mlir
+++ b/samples/multiple_modules/pipeline_sync.mlir
@@ -1,7 +1,8 @@
// RUN: (iree-compile --iree-hal-target-backends=vmvx %p/module_a.mlir -o=%t.module_a.vmfb && \
// RUN: iree-compile --iree-hal-target-backends=vmvx %p/module_b.mlir -o=%t.module_b.vmfb && \
// RUN: iree-compile --iree-hal-target-backends=vmvx %s | \
-// RUN: iree-run-module --device=local-sync \
+// RUN: iree-run-module \
+// RUN: --device=local-sync \
// RUN: --module=%t.module_a.vmfb \
// RUN: --module=%t.module_b.vmfb \
// RUN: --module=- --function=run \
diff --git a/samples/simple_embedding/device_vmvx_sync.c b/samples/simple_embedding/device_vmvx_sync.c
index fa5981c..f1f633f 100644
--- a/samples/simple_embedding/device_vmvx_sync.c
+++ b/samples/simple_embedding/device_vmvx_sync.c
@@ -34,7 +34,7 @@
iree_vm_instance_release(instance);
// Use the default host allocator for buffer allocations.
- iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+ iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/samples/static_library/static_library_demo.c b/samples/static_library/static_library_demo.c
index 76a0b6c..e8670c5 100644
--- a/samples/static_library/static_library_demo.c
+++ b/samples/static_library/static_library_demo.c
@@ -42,7 +42,7 @@
&library_loader);
// Use the default host allocator for buffer allocations.
- iree_string_view_t identifier = iree_make_cstring_view("sync");
+ iree_string_view_t identifier = iree_make_cstring_view("local-sync");
iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_allocator_create_heap(identifier, host_allocator,
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 585bb25..017f393 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -35,7 +35,7 @@
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
iree.gpu.target = #target
}>
- ]>
+ ]> : !hal.device
]
} {
hal.executable private @example_module_dispatch_0 {
diff --git a/tests/compiler_driver/precompile.mlir b/tests/compiler_driver/precompile.mlir
index 5cdd117..b25cb34 100644
--- a/tests/compiler_driver/precompile.mlir
+++ b/tests/compiler_driver/precompile.mlir
@@ -7,4 +7,6 @@
}
// Just check that we have the right target and executable targets.
-// CHECK: module attributes {hal.device.targets = [#hal.device.target<"local", [#hal.executable.target<"vmvx"
+// CHECK: module
+// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@[[DEVICE:.+]]>
+// CHECK: util.global private @[[DEVICE]] = #hal.device.target<"local", [#hal.executable.target<"vmvx"
diff --git a/tests/compiler_driver/preprocessing_flags.mlir b/tests/compiler_driver/preprocessing_flags.mlir
index 3313988..f6e8adc 100644
--- a/tests/compiler_driver/preprocessing_flags.mlir
+++ b/tests/compiler_driver/preprocessing_flags.mlir
@@ -13,7 +13,7 @@
// CHECK: ConvertConv2DToImg2ColPass (iree-preprocessing-convert-conv2d-to-img2col)
// CHECK: PadLinalgOpsPass (iree-preprocessing-pad-linalg-ops)
// CHECK-LABEL: module
-// CHECK-NEXT: util.func public @test(
+// CHECK: util.func public @test(
// CHECK-DAG: %[[ARG0:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input0" : !hal.buffer_view -> tensor<10x20xf32>
// CHECK-DAG: %[[ARG1:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input1" : !hal.buffer_view -> tensor<20x30xf32>
// CHECK-DAG: %[[ARG2:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input2" : !hal.buffer_view -> tensor<10x30xf32>
diff --git a/tests/e2e/regression/libm_linking.mlir b/tests/e2e/regression/libm_linking.mlir
index e63e593..5cbeff0 100644
--- a/tests/e2e/regression/libm_linking.mlir
+++ b/tests/e2e/regression/libm_linking.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=llvm-cpu},iree-transformation-pipeline)' %s | FileCheck %s
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=llvm-cpu},iree-transformation-pipeline)' --iree-llvmcpu-link-embedded=false %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=llvm-cpu},iree-transformation-pipeline)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=llvm-cpu},iree-transformation-pipeline)' --iree-llvmcpu-link-embedded=false %s | FileCheck %s
// When lowering to CPU code through LLVM, certain LLVM intrinsics require
// linking against libm (the standard C library of math functions, `-lm`).
diff --git a/tools/test/BUILD.bazel b/tools/test/BUILD.bazel
index cf5878c..46709aa 100644
--- a/tools/test/BUILD.bazel
+++ b/tools/test/BUILD.bazel
@@ -31,6 +31,7 @@
"iree-run-mlir.mlir",
"iree-run-module-expected.mlir",
"iree-run-module-inputs.mlir",
+ "iree-run-module-multi.mlir",
"iree-run-module-outputs.mlir",
"iree-run-module.mlir",
"multiple_args.mlir",
diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt
index 75dde66..a866548 100644
--- a/tools/test/CMakeLists.txt
+++ b/tools/test/CMakeLists.txt
@@ -27,6 +27,7 @@
"iree-run-mlir.mlir"
"iree-run-module-expected.mlir"
"iree-run-module-inputs.mlir"
+ "iree-run-module-multi.mlir"
"iree-run-module-outputs.mlir"
"iree-run-module.mlir"
"multiple_args.mlir"
diff --git a/tools/test/compile_pipelines.mlir b/tools/test/compile_pipelines.mlir
index 2fd4a6c..fb6dbbe 100644
--- a/tools/test/compile_pipelines.mlir
+++ b/tools/test/compile_pipelines.mlir
@@ -1,10 +1,10 @@
// RUN: iree-opt --iree-common-input-transformation-pipeline %s | \
// RUN: iree-opt --iree-abi-transformation-pipeline - | \
-// RUN: iree-opt --iree-common-input-transformation-pipeline - | \
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-device-assignment-pipeline{target-devices=local})' --iree-hal-local-target-device-backends=vmvx - | \
// RUN: iree-opt --iree-global-optimization-transformation-pipeline - | \
// RUN: iree-opt --iree-flow-transformation-pipeline - | \
// RUN: iree-opt --iree-stream-transformation-pipeline - | \
-// RUN: iree-opt --iree-hal-transformation-pipeline --iree-hal-target-backends=vmvx - | \
+// RUN: iree-opt --iree-hal-transformation-pipeline - | \
// RUN: iree-opt --iree-vm-transformation-pipeline - | \
// RUN: FileCheck %s
diff --git a/tools/test/compile_to_continuation.mlir b/tools/test/compile_to_continuation.mlir
index 9c78153..5476462 100644
--- a/tools/test/compile_to_continuation.mlir
+++ b/tools/test/compile_to_continuation.mlir
@@ -1,79 +1,89 @@
// RUN: iree-compile --compile-to=input %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=INPUT-PHASE
// INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=abi %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=ABI-PHASE
// ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=flow %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FLOW-PHASE
// FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=stream %s | \
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow %s | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
+// RUN: FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE
+// FLOW-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
+
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=STREAM-PHASE
// STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=stream %s | \
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
+// RUN: FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE
+// STREAM-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
+
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
// EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
// EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=HAL-PHASE
// HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=VM-PHASE
// VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=input %s | \
-// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=FROM-INPUT-PHASE
// FROM-INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
// RUN: iree-compile --compile-to=abi %s | \
-// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \
// RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE
// FROM-ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=flow %s | \
-// RUN: iree-compile --compile-from=flow --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --compile-from=flow --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-FLOW-PHASE
// FROM-FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=stream %s | \
-// RUN: iree-compile --compile-from=stream --output-format=vm-asm --iree-hal-target-backends=vmvx - | \
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
+// RUN: iree-compile --compile-from=stream --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-STREAM-PHASE
// FROM-STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=executable-sources --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-SOURCES-PHASE
// FROM-EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=executable-targets --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-TARGETS-PHASE
// FROM-EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=hal --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-HAL-PHASE
// FROM-HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \
// RUN: iree-compile --compile-from=vm --output-format=vm-asm - | \
// RUN: FileCheck %s --check-prefix=FROM-VM-PHASE
// FROM-VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
diff --git a/tools/test/compile_to_phase.mlir b/tools/test/compile_to_phase.mlir
index 0390564..f1861a0 100644
--- a/tools/test/compile_to_phase.mlir
+++ b/tools/test/compile_to_phase.mlir
@@ -7,36 +7,44 @@
// ABI-PHASE: %[[INPUT:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor<f32>
// ABI-PHASE: math.absf %[[INPUT]] : tensor<f32>
-// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE
+// RUN: iree-compile --compile-to=flow %s --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx | FileCheck %s --check-prefix=FLOW-PHASE
// FLOW-PHASE: flow.executable.export public @abs_dispatch_0
// FLOW-PHASE: flow.dispatch @abs_dispatch_0
-// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE
+// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE
+// FLOW-PHASE-NO-DEVICE: flow.executable.export public @abs_dispatch_0
+// FLOW-PHASE-NO-DEVICE: flow.dispatch @abs_dispatch_0
+
+// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=STREAM-PHASE
// STREAM-PHASE: stream.executable.export public @abs_dispatch_0
// STREAM-PHASE: stream.cmd.dispatch @abs_dispatch_0
-// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
+// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE
+// STREAM-PHASE-NO-DEVICE: stream.executable.export public @abs_dispatch_0
+// STREAM-PHASE-NO-DEVICE: stream.cmd.dispatch @abs_dispatch_0
+
+// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE
// EXECUTABLE-SOURCES-PHASE: hal.executable private @abs_dispatch_0
// EXECUTABLE-SOURCES-PHASE: hal.executable.variant
// EXECUTABLE-SOURCES-PHASE: linalg.generic
// EXECUTABLE-SOURCES-PHASE: math.absf
-// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
+// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE
// EXECUTABLE-TARGETS-PHASE: hal.executable private @abs_dispatch_0
// EXECUTABLE-TARGETS-PHASE: hal.executable.variant
// EXECUTABLE-TARGETS-PHASE: vm.abs.f32
-// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE
+// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE
// HAL-PHASE: hal.executable private @abs_dispatch_0
// HAL-PHASE: hal.executable.binary
// HAL-PHASE: hal.command_buffer.dispatch
-// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE
+// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE
// VM-PHASE: vm.rodata private @abs_dispatch_0
// VM-PHASE: vm.call @hal.command_buffer.dispatch
-// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
-// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
+// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
+// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE
// END-PHASE: vm.rodata private @abs_dispatch_0
// END-PHASE: vm.call @hal.command_buffer.dispatch
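Every RUN line above applies the same mechanical migration: the deprecated `--iree-hal-target-backends=vmvx` becomes `--iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx`. The sketch below captures that rewrite. The function name `demo_migrate_target_backends_flag` is hypothetical, and it assumes every listed backend is a local (CPU-side) backend such as `vmvx`. Other device types need their own target-specific flags and are deliberately rejected by omission rather than guessed at.

```c
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: rewrites "--iree-hal-target-backends=<list>" into the
 * two flags used by the tests in this change. ASSUMPTION: every backend in
 * <list> is a local backend (e.g. vmvx); non-local backends are not handled.
 * Returns 0 on success, or -1 if |old_flag| is not the deprecated flag or
 * |out| is too small to hold the result. */
int demo_migrate_target_backends_flag(const char* old_flag, char* out,
                                      size_t out_size) {
  static const char kPrefix[] = "--iree-hal-target-backends=";
  const size_t prefix_len = sizeof(kPrefix) - 1;
  if (strncmp(old_flag, kPrefix, prefix_len) != 0) return -1;
  int n = snprintf(out, out_size,
                   "--iree-hal-target-device=local "
                   "--iree-hal-local-target-device-backends=%s",
                   old_flag + prefix_len);
  /* snprintf reports the untruncated length; treat truncation as failure. */
  return (n < 0 || (size_t)n >= out_size) ? -1 : 0;
}
```

For example, passing `--iree-hal-target-backends=vmvx` yields exactly the flag pair that replaced it in `compile_to_phase.mlir`.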
diff --git a/tools/test/iree-run-module-multi.mlir b/tools/test/iree-run-module-multi.mlir
new file mode 100644
index 0000000..3412596
--- /dev/null
+++ b/tools/test/iree-run-module-multi.mlir
@@ -0,0 +1,43 @@
+// Tests that multiple devices are supported through iree-run-module by
+// providing two local thread pools. This is not optimal and not an intended
+// route for multi-device CPU workloads but requires no additional hardware
+// resources for the test and still verifies the compiler/runtime tooling
+// rendezvous of devices as specified on the command line.
+
+// RUN: (iree-compile %s \
+// RUN: --iree-execution-model=async-external \
+// RUN: --iree-hal-target-device=device_a=local[0] \
+// RUN: --iree-hal-target-device=device_b=local[1] \
+// RUN: --iree-hal-local-target-device-backends=vmvx | \
+// RUN: iree-run-module \
+// RUN: --module=- \
+// RUN: --function=multi_device_mul \
+// RUN: --input=4xf32=10,11,12,13 \
+// RUN: --device=local-task \
+// RUN: --device=local-task \
+// RUN: --task_topology_group_count=1) | \
+// RUN: FileCheck %s
+
+// CHECK: EXEC @multi_device_mul
+// CHECK-NEXT: result[0]: hal.buffer_view
+// CHECK-NEXT: 4xf32=0 55 144 273
+func.func public @multi_device_mul(
+ // Input argument is resident on device_a (tooling defaults to the first device).
+ %input_a: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}
+) -> (
+ // Output result is expected to be on device_a (though not required).
+ tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}
+) {
+ // Compute on device_a (input is there).
+ %constant_a = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
+ %transient_a = arith.mulf %input_a, %constant_a : tensor<4xf32>
+ // Transfer the result from device_a -> device_b.
+ %transient_b = flow.tensor.transfer %transient_a : tensor<4xf32> to #hal.device.promise<@device_b>
+ // Compute on device_b.
+ %constant_b = arith.constant dense<[4.0, 5.0, 6.0, 7.0]> : tensor<4xf32>
+ %result_b = arith.mulf %transient_b, %constant_b : tensor<4xf32>
+ // Transfer the result from device_b -> device_a.
+ %result_a = flow.tensor.transfer %result_b : tensor<4xf32> to #hal.device.promise<@device_a>
+ // Return the result on device_a (as required by ABI attr).
+ func.return %result_a : tensor<4xf32>
+}
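The expected `4xf32=0 55 144 273` in the CHECK lines can be reproduced on the host. The input is multiplied element-wise by the device_a constant, and the product is then multiplied by the device_b constant. The `flow.tensor.transfer` ops move data between devices without changing values, so they have no counterpart here. This is a minimal reference sketch; the function and constant names are hypothetical and not part of the test.

```c
#include <stddef.h>

/* Constants taken from the arith.constant ops in the test function. */
static const float kConstantA[4] = {0.0f, 1.0f, 2.0f, 3.0f};
static const float kConstantB[4] = {4.0f, 5.0f, 6.0f, 7.0f};

/* Hypothetical host-side reference: the two-stage multiply performed by the
 * test, with the device placement noted for each stage. */
void reference_multi_device_mul(const float input[4], float result[4]) {
  for (size_t i = 0; i < 4; ++i) {
    float transient = input[i] * kConstantA[i]; /* computed on device_a */
    result[i] = transient * kConstantB[i];      /* computed on device_b */
  }
}
```

With the test input `10,11,12,13`, the intermediate values are `0 11 24 39` and the final results are `0 55 144 273`, matching the CHECK-NEXT line.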
diff --git a/tools/testing/e2e/iree-e2e-conv2d-test.cc b/tools/testing/e2e/iree-e2e-conv2d-test.cc
index 31d02e9..c4158fd 100644
--- a/tools/testing/e2e/iree-e2e-conv2d-test.cc
+++ b/tools/testing/e2e/iree-e2e-conv2d-test.cc
@@ -549,14 +549,17 @@
return EXIT_FAILURE;
}
+ // Run the tests. Note that some modules may be compiled for other platforms
+ // and may not contain the architectures required to execute here; to keep
+ // the test runner simple we skip those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), conv2d_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
- bool is_unavailable = iree_status_is_unavailable(status);
+ bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
- exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+ exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}
IREE_TRACE_APP_EXIT(exit_code);
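After this change, only a NOT_FOUND status (a module built for a device or feature that is not present) is treated as a skip. Any other error, including UNAVAILABLE, now fails the run. The sketch below isolates that mapping. `demo_status_code_t` is a hypothetical stand-in for the few status codes involved; the real `iree_status_t` carries more codes and a message payload.

```c
#include <stdlib.h>

/* Hypothetical stand-in for the subset of status codes relevant here. */
typedef enum {
  DEMO_STATUS_OK = 0,
  DEMO_STATUS_NOT_FOUND,
  DEMO_STATUS_UNAVAILABLE,
  DEMO_STATUS_INTERNAL,
} demo_status_code_t;

/* Maps the outcome of running a test module to a process exit code.
 * NOT_FOUND means the module targets a device/feature absent on this machine,
 * so the test is skipped by reporting success; every other error fails. */
int demo_exit_code_for_status(demo_status_code_t code) {
  if (code == DEMO_STATUS_OK) return EXIT_SUCCESS;
  return code == DEMO_STATUS_NOT_FOUND ? EXIT_SUCCESS : EXIT_FAILURE;
}
```

The matmul runner below applies the same mapping, so one helper like this could serve both if the runners were ever consolidated.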
diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc
index c9c82f9..f2773f0 100644
--- a/tools/testing/e2e/iree-e2e-matmul-test.cc
+++ b/tools/testing/e2e/iree-e2e-matmul-test.cc
@@ -725,14 +725,17 @@
return EXIT_FAILURE;
}
+ // Run the tests. Note that some modules may be compiled for other platforms
+ // and may not contain the architectures required to execute here; to keep
+ // the test runner simple we skip those cases by returning success.
iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
iree_allocator_system(), matmul_test_module_create);
int exit_code = EXIT_SUCCESS;
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
- bool is_unavailable = iree_status_is_unavailable(status);
+ bool is_device_unavailable = iree_status_is_not_found(status);
iree_status_free(status);
- exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+ exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}
IREE_TRACE_APP_EXIT(exit_code);
diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c
index 926b0ea..2981148 100644
--- a/tools/testing/e2e/test_utils.c
+++ b/tools/testing/e2e/test_utils.c
@@ -413,7 +413,7 @@
return iree_make_status(
// The error status matters. We distinguish "feature not supported"
// which is a normal thing to happen from actual errors.
- IREE_STATUS_UNAVAILABLE,
+ IREE_STATUS_NOT_FOUND,
"target device does not have the required feature '%.*s'",
(int)required_feature.size, required_feature.data);
}
diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h
index f3a18d2..f095537 100644
--- a/tools/testing/e2e/test_utils.h
+++ b/tools/testing/e2e/test_utils.h
@@ -133,7 +133,7 @@
iree_allocator_t host_allocator);
// Returns OK if there are declared requirements on |module| and they are all
-// met and otherwise UNAVAILABLE indicating that the module should not be run.
+// met and otherwise NOT_FOUND indicating that the module should not be run.
iree_status_t iree_test_utils_check_module_requirements(
iree_vm_module_t* module);
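As documented above, `iree_test_utils_check_module_requirements` returns OK when every declared requirement is met, and NOT_FOUND otherwise. The sketch below shows that contract in isolation with plain string lists. The names `demo_check_requirements` and `demo_req_status_t`, and the feature strings used in the example, are hypothetical. The real function inspects the module's declared requirements and the target device rather than taking arrays.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical result codes mirroring the OK / NOT_FOUND contract. */
typedef enum {
  DEMO_REQ_OK = 0,
  DEMO_REQ_NOT_FOUND,
} demo_req_status_t;

/* Returns 1 if |feature| appears in |available|, else 0. */
static int demo_has_feature(const char* const* available,
                            size_t available_count, const char* feature) {
  for (size_t i = 0; i < available_count; ++i) {
    if (strcmp(available[i], feature) == 0) return 1;
  }
  return 0;
}

/* Checks that every required feature is available. On failure, stores the
 * first missing feature in |out_missing| and returns DEMO_REQ_NOT_FOUND so the
 * caller can skip the module instead of reporting an error. An empty
 * requirement list always succeeds. */
demo_req_status_t demo_check_requirements(const char* const* required,
                                          size_t required_count,
                                          const char* const* available,
                                          size_t available_count,
                                          const char** out_missing) {
  *out_missing = NULL;
  for (size_t i = 0; i < required_count; ++i) {
    if (!demo_has_feature(available, available_count, required[i])) {
      *out_missing = required[i];
      return DEMO_REQ_NOT_FOUND;
    }
  }
  return DEMO_REQ_OK;
}
```

Returning a distinct code for a missing feature, rather than a generic error, is what lets the e2e runners treat these cases as skips.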