Replacing hal.ex.shared_device with hal.devices.* ops. (#15916)
This bumps the HAL version as the existing hal.ex.shared_device op is
removed. The runtime HAL module still only reports a single device and
the compiler code is still selecting the "default" device (0). Future
changes will support multiple `--device=` flags to register multiple
available devices in tools and add attributes/lookup for mapping
stream/device affinity to device ordinals.
The plan is to have a `hal.devices.lookup` op that is hoisted into an
initializer to enumerate devices and store the selected ones in globals.
The enumeration code will emit a loop over `hal.devices.count` and use
`hal.devices.get` and the various device query methods to match the
requested devices.
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index 144a501..2d0e90e 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -130,8 +130,12 @@
// HACK: this is relying on the fact that there's only one HAL device.
// We should instead have a way of creating fences on the device that
// is used to produce the tensors we're wrapping.
- auto device =
- entryBuilder.create<IREE::HAL::ExSharedDeviceOp>(importOp.getLoc());
+ //
+ // TODO(multi-device): emit get with derived ordinal or lookup with attr. We
+ // always resolve to device 0 for now but could instead look for an
+ // iree.abi.affinity/iree.abi.device/etc.
+ Value device =
+ IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder);
// When exporting a fence we need to put a barrier between the rest of the
// program and the tensors consumed by the import.
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
index c89451a..f6ff6d1 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
@@ -117,7 +117,7 @@
// CHECK: func.func private @_import(%[[ARG0_TENSOR:.+]]: tensor<?x2xi32>, %[[ARG1_TENSOR:.+]]: tensor<?x3xi32>) -> (tensor<2x?xi32>, tensor<3x?xi32>) {
// Prepare fences and put a barrier on input arguments:
-// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %[[WAIT_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
// CHECK: %[[ARG_BARRIER:.+]]:2 = hal.tensor.barrier join(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]] : tensor<?x2xi32>, tensor<?x3xi32>) => %[[WAIT_FENCE]] : !hal.fence
// CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
@@ -186,7 +186,7 @@
// CHECK: func.func private @_importI32Effects(%[[ARG0_TENSOR:.+]]: tensor<4xf32>) -> i32 {
// Wait for the inputs to be ready and create the signal fence to wait on.
-// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %[[WAIT_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
// CHECK: %[[ARG0_BARRIER:.+]] = hal.tensor.barrier join(%[[ARG0_TENSOR]] : tensor<4xf32>) => %[[WAIT_FENCE]] : !hal.fence
// CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create device(%[[DEVICE]]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
index 9e08de1..7b407a6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
@@ -73,7 +73,8 @@
testing.func.b = @dispatch_0::@spirv,
testing.func.c = @dispatch_0::@spirv::@dispatch_0
} {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
testing.op.a = @dispatch_0,
testing.op.b = @dispatch_0::@spirv,
@@ -86,7 +87,8 @@
return
}
util.initializer {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
%c1 = arith.constant 1 : index
hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@spirv::@dispatch_0) workgroups([%c1, %c1, %c1])
@@ -237,7 +239,8 @@
}
}
func.func @two_target_environments_1() -> () {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
%c1 = arith.constant 1 : index
hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@spirv::@dispatch_0) workgroups([%c1, %c1, %c1])
@@ -245,7 +248,8 @@
return
}
func.func @two_target_environments_2() -> () {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
%c1 = arith.constant 1 : index
hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_1::@spirv::@dispatch_1) workgroups([%c1, %c1, %c1])
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
index 98dcd6c..93425d8 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
@@ -72,7 +72,8 @@
testing.func.b = @dispatch_0::@vmvx,
testing.func.c = @dispatch_0::@vmvx::@dispatch_0
} {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
testing.op.a = @dispatch_0,
testing.op.b = @dispatch_0::@vmvx,
@@ -85,7 +86,8 @@
return
}
util.initializer {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
%c1 = arith.constant 1 : index
hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@vmvx::@dispatch_0) workgroups([%c1, %c1, %c1])
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 25d3c43..c5a93cf 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -440,9 +440,10 @@
iree_hal_driver_release(driver);
// Create hal module.
- IREE_CHECK_OK(iree_hal_module_create(runtime.instance.get(), device.get(),
- IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &hal_module));
+ iree_hal_device_t *device_ptr = device.get();
+ IREE_CHECK_OK(iree_hal_module_create(
+ runtime.instance.get(), /*device_count=*/1, &device_ptr,
+ IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module));
// Bytecode module.
IREE_CHECK_OK(iree_vm_bytecode_module_create(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
index c630dc4..fcbb192 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD.bazel
@@ -21,6 +21,7 @@
"ConvertChannelOps.cpp",
"ConvertCommandBufferOps.cpp",
"ConvertDeviceOps.cpp",
+ "ConvertDevicesOps.cpp",
"ConvertExecutableOps.cpp",
"ConvertExperimentalOps.cpp",
"ConvertFenceOps.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
index 5a5cb8a..713982b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
@@ -22,6 +22,7 @@
"ConvertChannelOps.cpp"
"ConvertCommandBufferOps.cpp"
"ConvertDeviceOps.cpp"
+ "ConvertDevicesOps.cpp"
"ConvertExecutableOps.cpp"
"ConvertExperimentalOps.cpp"
"ConvertFenceOps.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp
new file mode 100644
index 0000000..266233d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDevicesOps.cpp
@@ -0,0 +1,23 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+void populateHALDevicesToVMPatterns(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.insert<VMImportOpConversion<IREE::HAL::DevicesCountOp>>(
+ context, importSymbols, typeConverter, "hal.devices.count");
+ patterns.insert<VMImportOpConversion<IREE::HAL::DevicesGetOp>>(
+ context, importSymbols, typeConverter, "hal.devices.get");
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
index e724412..b8bc13d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp
@@ -14,8 +14,6 @@
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.insert<VMImportOpConversion<IREE::HAL::ExSharedDeviceOp>>(
- context, importSymbols, typeConverter, "hal.ex.shared_device");
patterns.insert<VMImportOpConversion<IREE::HAL::ExFileFromMemoryOp>>(
context, importSymbols, typeConverter, "hal.ex.file.from_memory");
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
index 983e2cf..592e78f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp
@@ -51,6 +51,10 @@
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns);
+extern void populateHALDevicesToVMPatterns(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
extern void populateHALExecutableToVMPatterns(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
@@ -79,6 +83,8 @@
patterns);
populateHALDeviceToVMPatterns(context, importSymbols, typeConverter,
patterns);
+ populateHALDevicesToVMPatterns(context, importSymbols, typeConverter,
+ patterns);
populateHALExecutableToVMPatterns(context, importSymbols, typeConverter,
patterns);
populateHALExperimentalToVMPatterns(context, importSymbols, typeConverter,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
index 0f80024..a486719 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel
@@ -22,6 +22,7 @@
"channel_ops.mlir",
"command_buffer_ops.mlir",
"device_ops.mlir",
+ "devices_ops.mlir",
"executable_ops.mlir",
"fence_ops.mlir",
],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
index 9ec645e..b65d9cf 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt
@@ -20,6 +20,7 @@
"channel_ops.mlir"
"command_buffer_ops.mlir"
"device_ops.mlir"
+ "devices_ops.mlir"
"executable_ops.mlir"
"fence_ops.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir
new file mode 100644
index 0000000..b8423ad
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/devices_ops.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s
+
+// CHECK-LABEL: @devices_count
+func.func @devices_count() -> index {
+ // CHECK: = vm.call @hal.devices.count() {nosideeffects} : () -> i32
+ %device_count = hal.devices.count : index
+ return %device_count : index
+}
+
+// -----
+
+// CHECK-LABEL: @devices_get
+// CHECK-SAME: (%[[INDEX:.+]]: i32)
+func.func @devices_get(%index: index) -> !hal.device {
+ // CHECK: = vm.call @hal.devices.get(%[[INDEX]]) {nosideeffects} : (i32) -> !vm.ref<!hal.device>
+ %device = hal.devices.get %index : !hal.device
+ return %device : !hal.device
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 55994dd..8b34b92 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -51,9 +51,9 @@
auto resultTypes = llvm::to_vector(resolveOp.getResultTypes());
assert(!resultTypes.empty() && "must have at least one result");
- // TODO(benvanik): make this do multi-device lookup and other fancy things.
+ // TODO(multi-device): emit get with derived ordinal or lookup with attr.
Value device =
- rewriter.create<IREE::HAL::ExSharedDeviceOp>(resolveOp.getLoc());
+ IREE::HAL::DeviceType::resolveAny(resolveOp.getLoc(), rewriter);
SmallVector<Value> results;
if (resultTypes[0].isa<IREE::HAL::DeviceType>()) {
@@ -678,7 +678,7 @@
// Get the device handle we're executing against in this execution region.
// Note that this is a dynamic value: we have to treat the device as unknown
// here.
- auto deviceValue = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
+ Value device = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
loc, rewriter.getType<IREE::HAL::DeviceType>(), commandBuffer);
// Prepare for variant switch table by gathering the conditions selecting
@@ -703,7 +703,7 @@
auto exportOp = caseExportOps[i].second;
auto variantOp =
exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
- return variantOp.buildCondition(deviceValue, rewriter);
+ return variantOp.buildCondition(device, rewriter);
},
rewriter);
@@ -718,12 +718,12 @@
auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
// Record push constants and buffer bindings.
- recordParameters(loc, deviceValue, commandBuffer, dispatchOp, adaptor,
+ recordParameters(loc, device, commandBuffer, dispatchOp, adaptor,
exportOp.getLayout(), caseBuilder);
// Dispatch with a target-specific workgroup count.
auto caseWorkgroupCount = exportOp.calculateWorkgroupCount(
- loc, deviceValue, adaptor.getWorkload(), caseBuilder);
+ loc, device, adaptor.getWorkload(), caseBuilder);
caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
loc, commandBuffer, entryPointAttr, caseWorkgroupCount[0],
caseWorkgroupCount[1], caseWorkgroupCount[2]);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
index 9161467..d11abe8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: @channel_create
// CHECK-SAME: () -> !hal.channel
func.func @channel_create() -> !stream.channel {
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} : !hal.device
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant 3
// CHECK-DAG: %[[ID:.+]] = util.null : !util.buffer
// CHECK-DAG: %[[GROUP:.+]] = util.buffer.constant : !util.buffer = "group"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
index 659c11b..dd43362 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: @contextResolveAllocator
func.func @contextResolveAllocator() -> !hal.allocator {
- // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
%allocator = stream.context.resolve : !hal.allocator
// CHECK: return %[[ALLOCATOR]]
@@ -13,7 +13,7 @@
// CHECK-LABEL: @contextResolveDevice
func.func @contextResolveDevice() -> !hal.device {
- // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
%device = stream.context.resolve : !hal.device
// CHECK: return %[[DEVICE]]
return %device : !hal.device
@@ -23,7 +23,7 @@
// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny
func.func @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) {
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64
%device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64
// CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
@@ -34,7 +34,7 @@
// CHECK-LABEL: @contextResolveDeviceQueueAffinity45
func.func @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) {
- // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64
%device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64
// CHECK: return %[[DEVICE]], %[[QUEUE_AFFINITY]]
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
index 2ac1385..9473df7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
@@ -5,7 +5,8 @@
func.func @file_constant(%buffer: !util.buffer) {
%c0 = arith.constant 0 : index
%c1088 = arith.constant 1088 : index
- // CHECK: = hal.ex.file.from_memory device(%device : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+ // CHECK: = hal.ex.file.from_memory device(%[[DEVICE]] : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file
%file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file
return
}
@@ -18,8 +19,9 @@
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%c1088 = arith.constant 1088 : index
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %[[SIGNAL:.+]] = hal.fence.create
- // CHECK: hal.device.queue.read<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0)
+ // CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0)
%signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} => !stream.timepoint
// CHECK: return %[[SIGNAL]]
return %signal : !stream.timepoint
@@ -33,8 +35,9 @@
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%c1088 = arith.constant 1088 : index
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %[[SIGNAL:.+]] = hal.fence.create
- // CHECK: hal.device.queue.write<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0)
+ // CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0)
%signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint
// CHECK: return %[[SIGNAL]]
return %signal : !stream.timepoint
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
index 7d71e1d..cca49b1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
@@ -45,7 +45,7 @@
// CHECK-LABEL: @timepointChainExternal
// CHECK-SAME: (%[[TIMEPOINT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence)
func.func @timepointChainExternal(%timepoint: !stream.timepoint, %signal: !hal.fence) {
- // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[TIMEPOINT]]) signal(%[[SIGNAL]])
stream.timepoint.chain_external %timepoint => (%signal : !hal.fence)
return
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 3711768..2abfd23 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -332,11 +332,6 @@
// hal.ex.*
//===----------------------------------------------------------------------===//
-void ExSharedDeviceOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResult(), "device");
-}
-
void ExFileFromMemoryOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "memory_file");
@@ -924,6 +919,27 @@
}
//===----------------------------------------------------------------------===//
+// hal.devices.*
+//===----------------------------------------------------------------------===//
+
+void DevicesCountOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "device_count");
+}
+
+void DevicesGetOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ APInt index;
+ if (matchPattern(getIndex(), m_ConstantInt(&index))) {
+ llvm::SmallString<16> str("device_");
+ index.toStringUnsigned(str);
+ setNameFn(getResult(), str);
+ } else {
+ setNameFn(getResult(), "device_n");
+ }
+}
+
+//===----------------------------------------------------------------------===//
// hal.executable.source
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 50414a6..059e451 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -37,24 +37,6 @@
let opDocGroup = OpGroupExperimentalOps in {
-def HAL_ExSharedDeviceOp : HAL_PureOp<"ex.shared_device", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- ]> {
- let results = (outs
- HAL_Device:$result
- );
-
- let assemblyFormat = "attr-dict `:` type($result)";
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<(ins),
- [{
- $_state.addTypes({DeviceType::get($_builder.getContext())});
- }]>,
- ];
-}
-
def HAL_ExFileFromMemoryOp : HAL_Op<"ex.file.from_memory", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
]> {
@@ -1864,6 +1846,57 @@
} // OpGroupDeviceOps
//===----------------------------------------------------------------------===//
+// !hal.device management
+//===----------------------------------------------------------------------===//
+
+def OpGroupDeviceManagementOps : OpDocGroup {
+ let summary = "Device management ops";
+ let description = "Device availability and selection support.";
+}
+
+let opDocGroup = OpGroupDeviceManagementOps in {
+
+def HAL_DevicesCountOp : HAL_PureOp<"devices.count", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{returns the number of available devices}];
+ let description = [{
+ Returns the total number of available devices registered at runtime.
+ }];
+
+ let results = (outs
+ Index:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict `:` type($result)
+ }];
+}
+
+def HAL_DevicesGetOp : HAL_PureOp<"devices.get", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{returns the device with the given index}];
+ let description = [{
+ Returns the device with the given index in the [0, hal.devices.count) range.
+ Devices may be lazily initialized upon first use.
+ }];
+
+ let arguments = (ins
+ Index:$index
+ );
+ let results = (outs
+ HAL_Device:$result
+ );
+
+ let assemblyFormat = [{
+ $index attr-dict `:` type($result)
+ }];
+}
+
+} // OpGroupDeviceManagementOps
+
+//===----------------------------------------------------------------------===//
// !hal.executable / iree_hal_executable_t
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index bcf3b84..9359e0f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -218,6 +218,13 @@
loc, builder.getType<IREE::HAL::BufferType>(), value));
}
+// static
+Value DeviceType::resolveAny(Location loc, OpBuilder &builder) {
+ Value deviceIndex = builder.create<arith::ConstantIndexOp>(loc, 0);
+ return builder.create<IREE::HAL::DevicesGetOp>(
+ loc, builder.getType<IREE::HAL::DeviceType>(), deviceIndex);
+}
+
//===----------------------------------------------------------------------===//
// #hal.device.target
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index acdca66..963f404 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -125,6 +125,12 @@
using Base::Base;
static constexpr StringLiteral name = "hal.device";
+
+ // Resolves to any device at runtime (currently always device index 0).
+ // This is unlikely to be what any particular caller wants and should be
+ // avoided outside of testing/debugging passes that don't care about
+ // multi-targeting.
+ static Value resolveAny(Location loc, OpBuilder &builder);
};
struct EventType : public Type::TypeBase<EventType, Type, TypeStorage> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
index 013db09..c8e6437 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel
@@ -27,6 +27,7 @@
"descriptor_set_ops.mlir",
"device_folding.mlir",
"device_ops.mlir",
+ "devices_ops.mlir",
"executable_folding.mlir",
"executable_ops.mlir",
"executable_targets.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
index 55bd72b..f1fb349 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
@@ -25,6 +25,7 @@
"descriptor_set_ops.mlir"
"device_folding.mlir"
"device_ops.mlir"
+ "devices_ops.mlir"
"executable_folding.mlir"
"executable_ops.mlir"
"executable_targets.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
index 392dc87..3e1b94d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -1,18 +1,17 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
// CHECK-LABEL: @skip_command_buffer_device
-func.func @skip_command_buffer_device() -> !hal.executable {
- // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
- %dev = hal.ex.shared_device : !hal.device
- %cmd = hal.command_buffer.create device(%dev : !hal.device)
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
+func.func @skip_command_buffer_device(%device: !hal.device) -> !hal.executable {
+ %cmd = hal.command_buffer.create device(%device : !hal.device)
mode(OneShot)
- categories("Transfer|Dispatch") : !hal.command_buffer
+ categories("Transfer|Dispatch") : !hal.command_buffer
// CHECK-NOT: hal.command_buffer.device
// CHECK: = hal.executable.lookup device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: executable(@executable_name) : !hal.executable
- %0 = hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device
- %exe = hal.executable.lookup device(%dev : !hal.device)
+ %device2 = hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device
+ %exe = hal.executable.lookup device(%device2 : !hal.device)
executable(@executable_name) : !hal.executable
return %exe : !hal.executable
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir
new file mode 100644
index 0000000..633318b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/devices_ops.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @devices_count
+func.func @devices_count() -> index {
+ // CHECK: = hal.devices.count : index
+ %device_count = hal.devices.count : index
+ return %device_count : index
+}
+
+// -----
+
+// CHECK-LABEL: @devices_get
+// CHECK-SAME: (%[[INDEX:.+]]: index)
+func.func @devices_get(%index: index) -> !hal.device {
+ // CHECK: = hal.devices.get %[[INDEX]] : !hal.device
+ %device = hal.devices.get %index : !hal.device
+ return %device : !hal.device
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
index 3fb6585..3c8fd84 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir
@@ -1,14 +1,5 @@
// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
-// CHECK-LABEL: @shared_device
-func.func @shared_device() -> !hal.device {
- // CHECK: %device = hal.ex.shared_device : !hal.device
- %device = hal.ex.shared_device : !hal.device
- return %device : !hal.device
-}
-
-// -----
-
// CHECK-LABEL: @file_from_memory
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[BUFFER:.+]]: !util.buffer)
func.func @file_from_memory(%device: !hal.device, %buffer: !util.buffer) -> !hal.file {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 4a6f515..5fc02a5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -173,8 +173,8 @@
auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock());
IndexSet indexSet(loc, initBuilder);
- // TODO(benvanik): real device lookup.
- auto device = initBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+ // TODO(multi-device): support multiple devices in benchmark generation.
+ Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder);
auto allocator =
initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).getResult();
@@ -242,8 +242,8 @@
auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
- // TODO(benvanik): real device lookup.
- auto device = funcBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
+ // TODO(multi-device): support multiple devices in benchmark generation.
+ Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder);
// Create and begin command buffer.
// TODO(benvanik): reuse the command buffer (initialize once and store).
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 29f8580..3b0bb79 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -136,10 +136,11 @@
auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
OpBuilder blockBuilder =
OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
- auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
- auto layoutValue = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
- loc, layoutType, deviceValue, flags, bindingAttrs);
- blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layoutValue,
+ // TODO(multi-device): pass in resolve info to the call and reuse.
+ Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
+ Value layout = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
+ loc, layoutType, device, flags, bindingAttrs);
+ blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
globalOp.getName());
blockBuilder.create<IREE::Util::InitializerReturnOp>(loc);
@@ -188,12 +189,13 @@
setLayoutGlobalOp.getSymName());
setLayoutValues.push_back(setLayoutValue);
}
- auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
- auto layoutValue = blockBuilder.createOrFold<PipelineLayoutCreateOp>(
- loc, layoutType, deviceValue,
+ // TODO(multi-device): pass in resolve info to the call and reuse.
+ Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
+ Value layout = blockBuilder.createOrFold<PipelineLayoutCreateOp>(
+ loc, layoutType, device,
blockBuilder.getIndexAttr(layoutAttr.getPushConstants()),
setLayoutValues);
- blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layoutValue,
+ blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
globalOp.getName());
blockBuilder.create<IREE::Util::InitializerReturnOp>(loc);
@@ -214,7 +216,8 @@
auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
OpBuilder blockBuilder =
OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
- auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
+ // TODO(multi-device): pass in resolve info to the call and reuse.
+ Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
// Create a switch statement with a case for each variant.
// Each case should then cache only executables which contain a matching
@@ -232,7 +235,7 @@
Value selectedIndex = buildIfElseTree(
loc, caseVariantOps.size(),
[&](Location loc, size_t i, OpBuilder &builder) {
- return caseVariantOps[i].buildCondition(deviceValue, builder);
+ return caseVariantOps[i].buildCondition(device, builder);
},
blockBuilder);
@@ -261,18 +264,18 @@
SmallVector<Value> constantValues;
for (auto blockOp :
llvm::make_early_inc_range(variantOp.getConstantBlockOps())) {
- constantValues.append(inlineConstantBlockOp(blockOp, moduleBuilder,
- caseBuilder, deviceValue));
+ constantValues.append(
+ inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, device));
blockOp.erase();
}
- auto executableValue = caseBuilder.createOrFold<ExecutableCreateOp>(
- loc, executableType, deviceValue,
+ Value executable = caseBuilder.createOrFold<ExecutableCreateOp>(
+ loc, executableType, device,
SymbolRefAttr::get(executableOp.getSymNameAttr(),
{SymbolRefAttr::get(variantOp.getSymNameAttr())}),
pipelineLayoutValues, constantValues);
- caseBuilder.create<scf::YieldOp>(loc, executableValue);
+ caseBuilder.create<scf::YieldOp>(loc, executable);
}
// Fallback for no available variant.
@@ -299,7 +302,7 @@
SmallVector<Value> inlineConstantBlockOp(ExecutableConstantBlockOp blockOp,
OpBuilder &moduleBuilder,
OpBuilder &callerBuilder,
- Value deviceValue) {
+ Value device) {
// Create the function with the region contents of the constant block.
auto funcName = (StringRef("__constant_block_") +
std::to_string(nextUniqueConstantBlockId++))
@@ -320,7 +323,7 @@
// Create the call passing in the device if needed.
SmallVector<Value> callOperands;
if (funcOp.getNumArguments() > 0) {
- callOperands.push_back(deviceValue);
+ callOperands.push_back(device);
}
auto callOp = callerBuilder.create<func::CallOp>(blockOp.getLoc(), funcOp,
callOperands);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 4e47212..80845d3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -44,6 +44,7 @@
auto fullKey = ArrayAttr::get(
moduleOp.getContext(),
{
+ // TODO(multi-device): add attr key on device resolve source.
StringAttr::get(moduleOp.getContext(),
queryOp.getCategory() + queryOp.getKey()),
queryOp.getDefaultValue().has_value()
@@ -90,8 +91,8 @@
auto initializerOp =
moduleBuilder.create<IREE::Util::InitializerOp>(fusedLoc);
auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
- auto device =
- funcBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(fusedLoc);
+ // TODO(multi-device): pass in resolve info to the call and reuse.
+ Value device = IREE::HAL::DeviceType::resolveAny(fusedLoc, funcBuilder);
auto queryOp = funcBuilder.create<IREE::HAL::DeviceQueryOp>(
fusedLoc, funcBuilder.getI1Type(), queryType, device,
anyQueryOp.getCategoryAttr(), anyQueryOp.getKeyAttr(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 1b9e70a..d1540ba 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -59,7 +59,7 @@
// CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer
// (annoyingly out of order)
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator
// CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index 8d468b0..ce67b16 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -2,15 +2,15 @@
// CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
// CHECK-NEXT: util.initializer {
-// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT: %descriptor_set_layout = hal.descriptor_set_layout.create
-// CHECK-SAME: device(%device : !hal.device)
+// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT: %[[LAYOUT:.+]] = hal.descriptor_set_layout.create
+// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: flags("None")
// CHECK-SAME: bindings([
// CHECK-SAME: #hal.descriptor_set.binding<0, storage_buffer>,
// CHECK-SAME: #hal.descriptor_set.binding<1, storage_buffer>
// CHECK-SAME: ]) : !hal.descriptor_set_layout
-// CHECK-NEXT: util.global.store %descriptor_set_layout, @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+// CHECK-NEXT: util.global.store %[[LAYOUT]], @_descriptor_set_layout_0 : !hal.descriptor_set_layout
// CHECK-LABEL: @descriptorSetLayoutLookup
func.func @descriptorSetLayoutLookup(%device : !hal.device) -> !hal.descriptor_set_layout {
@@ -31,13 +31,13 @@
// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-NEXT: util.initializer {
-// CHECK-NEXT: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT: %pipeline_layout = hal.pipeline_layout.create
-// CHECK-SAME: device(%device : !hal.device)
+// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create
+// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: push_constants(1)
// CHECK-SAME: layouts([%[[SET0]]]) : !hal.pipeline_layout
-// CHECK-NEXT: util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-LABEL: @exeLayoutLookup
func.func @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout {
@@ -60,14 +60,14 @@
// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-NEXT: util.initializer {
-// CHECK-NEXT: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
-// CHECK-NEXT: %[[SET1:.+]] = util.global.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout
-// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device
-// CHECK-NEXT: %pipeline_layout = hal.pipeline_layout.create
-// CHECK-SAME: device(%device : !hal.device)
+// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
+// CHECK-DAG: %[[SET1:.+]] = util.global.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout
+// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
+// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create
+// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: push_constants(1)
// CHECK-SAME: layouts([%[[SET0]], %[[SET1]]]) : !hal.pipeline_layout
-// CHECK-NEXT: util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
+// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout
// CHECK-LABEL: @sharedLayoutLookup
func.func @sharedLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout {
@@ -161,14 +161,16 @@
// CHECK-NEXT: util.initializer {
// Switch on the supported formats:
-// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb")
// CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 {
// CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature")
// CHECK: scf.yield %[[FEATURE]]
// CHECK: }
// CHECK: %[[VMVX_VARIANT_SELECTED:.+]] = arith.andi %[[FORMAT_VMVX]], %[[VMVX_CONDITION]]
-// CHECK: %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %c0, %c-1
+// CHECK-DAG: %[[VARIANT_VMVX:.+]] = arith.constant 0
+// CHECK-DAG: %[[VARIANT_DEFAULT:.+]] = arith.constant -1
+// CHECK: %[[VARIANT_INDEX:.+]] = arith.select %[[VMVX_VARIANT_SELECTED]], %[[VARIANT_VMVX]], %[[VARIANT_DEFAULT]]
// CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable
// CHECK: case 0 {
@@ -244,7 +246,8 @@
util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout
util.initializer {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%descriptor_set_layout = hal.descriptor_set_layout.create device(%device : !hal.device) flags("None") bindings([#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>]) : !hal.descriptor_set_layout
util.global.store %descriptor_set_layout, @_descriptor_set_layout_0 : !hal.descriptor_set_layout
util.initializer.return
@@ -253,7 +256,8 @@
util.global private @_pipeline_layout_0 : !hal.pipeline_layout
util.initializer {
%_descriptor_set_layout_0 = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%pipeline_layout = hal.pipeline_layout.create device(%device : !hal.device) push_constants(0) layouts([%_descriptor_set_layout_0]) : !hal.pipeline_layout
util.global.store %pipeline_layout, @_pipeline_layout_0 : !hal.pipeline_layout
util.initializer.return
@@ -261,9 +265,9 @@
util.global private @_executable_exe : !hal.executable
util.initializer {
- %device = hal.ex.shared_device : !hal.device
- %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1
%c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
+ %format_ok, %format_supported = hal.device.query<%device : !hal.device> key("hal.executable.format" :: "some-format") : i1, i1
%c-1 = arith.constant -1 : index
%variant = arith.select %format_supported, %c0, %c-1 : index
%selected = scf.index_switch %variant -> !hal.executable
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
index 6884811..1fd8492 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -3,7 +3,7 @@
// CHECK: util.global private @_device_query_0 : i1
// CHECK-NEXT: util.global private @_device_query_0_ok : i1
// CHECK-NEXT: util.initializer {
-// CHECK-NEXT: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-NEXT: %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false
// CHECK-NEXT: util.global.store %[[OK0]], @_device_query_0_ok : i1
// CHECK-NEXT: util.global.store %[[VALUE0]], @_device_query_0 : i1
@@ -11,7 +11,7 @@
// CHECK: util.global private @_device_query_1 : i1
// CHECK-NEXT: util.global private @_device_query_1_ok : i1
// CHECK-NEXT: util.initializer {
-// CHECK-NEXT: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-NEXT: %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false
// CHECK-NEXT: util.global.store %[[OK1]], @_device_query_1_ok : i1
// CHECK-NEXT: util.global.store %[[VALUE1]], @_device_query_1 : i1
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 5166f5a..445c508 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -8,9 +8,6 @@
// Experimental/temporary ops
//===----------------------------------------------------------------------===//
-vm.import private @ex.shared_device() -> !vm.ref<!hal.device>
-attributes {nosideeffects}
-
// Creates a file mapped into a byte range of a host buffer.
// EXPERIMENTAL: may be removed in future versions.
vm.import private @ex.file.from_memory(
@@ -419,6 +416,22 @@
)
//===----------------------------------------------------------------------===//
+// iree_hal_device_t management
+//===----------------------------------------------------------------------===//
+
+vm.import private @devices.count() -> i32
+attributes {
+ minimum_version = 2 : i32,
+ nosideeffects
+}
+
+vm.import private @devices.get(%index : i32) -> !vm.ref<!hal.device>
+attributes {
+ minimum_version = 2 : i32,
+ nosideeffects
+}
+
+//===----------------------------------------------------------------------===//
// iree_hal_executable_t
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
index c08f5b6..9cf13da 100644
--- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -68,7 +68,8 @@
state.addAttributes(srcOp->getAttrs());
// Add device argument.
- Value device = rewriter.create<IREE::HAL::ExSharedDeviceOp>(srcOp->getLoc());
+ // TODO(multi-device): support multiple devices in check tests.
+ Value device = IREE::HAL::DeviceType::resolveAny(srcOp->getLoc(), rewriter);
state.addOperands({device});
for (auto [srcOperand, dstOperand] :
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
index eb75719..92035c3 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -7,7 +7,7 @@
%c51_i64 = arith.constant 51 : i64
%c100 = arith.constant 100 : index
%c101 = arith.constant 101 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -30,7 +30,7 @@
func.func @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) {
%c50_i64 = arith.constant 50 : i64
%c100 = arith.constant 100 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -53,7 +53,7 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -73,7 +73,7 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -99,7 +99,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -128,7 +128,7 @@
%c200 = arith.constant 200 : index
%c201 = arith.constant 201 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
@@ -158,7 +158,7 @@
%c201 = arith.constant 201 : index
%c202 = arith.constant 202 : index
%c300 = arith.constant 300 : index
- // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device
+ // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}}
// CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1
// CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device)
// CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]])
diff --git a/docs/website/docs/developers/general/developer-tips.md b/docs/website/docs/developers/general/developer-tips.md
index c830cfd..f4e2090 100644
--- a/docs/website/docs/developers/general/developer-tips.md
+++ b/docs/website/docs/developers/general/developer-tips.md
@@ -123,8 +123,8 @@
hal, version >= 0, required
Imported Functions:
- [ 0] hal.ex.shared_device() -> (!vm.ref<?>)
- [ 1] hal.allocator.allocate(!vm.ref<?>, i32, i32, i64) -> (!vm.ref<?>)
+ [ 0] hal.allocator.allocate(!vm.ref<?>, i32, i32, i64) -> (!vm.ref<?>)
+ [ 1] hal.devices.get(i32) -> (!vm.ref<?>)
...
Exported Functions:
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index d6f1dc4..a49faa1 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1597,8 +1597,8 @@
// HAL module.
modules.push_back({});
IREE_RETURN_IF_ERROR(iree_hal_module_create(
- vm_instance(), hal_device, IREE_HAL_MODULE_FLAG_NONE, host_allocator(),
- &modules.back()));
+ vm_instance(), /*device_count=*/1, &hal_device, IREE_HAL_MODULE_FLAG_NONE,
+ host_allocator(), &modules.back()));
// Main module.
modules.push_back(main_module);
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 9843c68..2460b36 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -768,10 +768,12 @@
// HAL module
//------------------------------------------------------------------------------
+// TODO(multi-device): allow for multiple devices to be passed in.
VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
+ iree_hal_device_t* device_ptr = device->raw_ptr();
iree_vm_module_t* module = NULL;
- CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device->raw_ptr(),
- IREE_HAL_MODULE_FLAG_NONE,
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1,
+ &device_ptr, IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &module),
"Error creating hal module");
return VmModule::StealFromRawPtr(module);
diff --git a/runtime/bindings/tflite/interpreter.c b/runtime/bindings/tflite/interpreter.c
index 3dae25d..6f0120b 100644
--- a/runtime/bindings/tflite/interpreter.c
+++ b/runtime/bindings/tflite/interpreter.c
@@ -60,9 +60,10 @@
"failed creating the default device for driver '%.*s'",
(int)driver_name.size, driver_name.data);
- IREE_RETURN_IF_ERROR(iree_hal_module_create(
- interpreter->instance, interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
- interpreter->allocator, &interpreter->hal_module));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_module_create(interpreter->instance, /*device_count=*/1,
+ &interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
+ interpreter->allocator, &interpreter->hal_module));
return iree_ok_status();
}
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index 0966580..e6afaee 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -44,9 +44,9 @@
}
IREE_ASSERT_OK(iree_hal_driver_create_default_device(
hal_driver, iree_allocator_system(), &device_));
- IREE_ASSERT_OK(
- iree_hal_module_create(instance_, device_, IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &hal_module_));
+ IREE_ASSERT_OK(iree_hal_module_create(
+ instance_, /*device_count=*/1, &device_, IREE_HAL_MODULE_FLAG_NONE,
+ iree_allocator_system(), &hal_module_));
iree_hal_driver_release(hal_driver);
IREE_ASSERT_OK(iree_check_module_create(instance_, iree_allocator_system(),
diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir
index 45b7939..7c5012e 100644
--- a/runtime/src/iree/modules/check/test/success.mlir
+++ b/runtime/src/iree/modules/check/test/success.mlir
@@ -13,7 +13,8 @@
}
func.func @expect_all_true() {
- %device = hal.ex.shared_device : !hal.device
+ %c0 = arith.constant 0 : index
+ %device = hal.devices.get %c0 : !hal.device
%all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32>
%all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view
check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index c9eae4b..e61638a 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -71,8 +71,10 @@
EXPORT_FN("device.queue.read", iree_hal_module_device_queue_read, rIrrrIrIIi, v)
EXPORT_FN("device.queue.write", iree_hal_module_device_queue_write, rIrrrIrIIi, v)
+EXPORT_FN("devices.count", iree_hal_module_devices_count, v, i)
+EXPORT_FN("devices.get", iree_hal_module_devices_get, i, r)
+
EXPORT_FN("ex.file.from_memory", iree_hal_module_ex_file_from_memory, rIirIIi, r)
-EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r)
EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 57a4cc6..d00c39a 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -32,18 +32,26 @@
// Module type definitions
//===----------------------------------------------------------------------===//
-#define IREE_HAL_MODULE_VERSION_0_1 0x00000001u
-#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_1
+#define IREE_HAL_MODULE_VERSION_0_2 0x00000002u
+#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_2
typedef struct iree_hal_module_t {
iree_allocator_t host_allocator;
iree_hal_module_flags_t flags;
- iree_hal_device_t* shared_device;
+ iree_host_size_t device_count;
+ iree_hal_device_t* devices[];
} iree_hal_module_t;
#define IREE_HAL_MODULE_CAST(module) \
(iree_hal_module_t*)((uint8_t*)(module) + iree_vm_native_module_size());
+static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
+ iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+ for (iree_host_size_t i = 0; i < module->device_count; ++i) {
+ iree_hal_device_release(module->devices[i]);
+ }
+}
+
typedef struct iree_hal_module_state_t {
iree_allocator_t host_allocator;
@@ -51,29 +59,26 @@
// application. All instantiations of a module share the same flags.
iree_hal_module_flags_t flags;
- // HACK: today we only support a single device per context - in the future
- // this should be a set of available devices that the module is able to pick
- // from - the module will then hang on to them and use them as native globals
- // instead of storing anything in module state here.
- iree_hal_device_t* shared_device;
+ // Total number of devices available to the module.
+ iree_host_size_t device_count;
+ // Devices referencing the storage in the parent module.
+ // Unretained as the parent module must remain live longer than any module
+ // state allocated from it and we can rely on it to keep the devices retained.
+ iree_hal_device_t** devices;
// TODO(benvanik): add iree_loop_t to module constructor.
// Status of the nested loop we run for executable creation today. We should
// instead be taking a loop upon creation and scheduling work against that.
iree_status_t loop_status;
- // Shared executable cache for all executables created in the context.
- // We could have multiple to allow for modules to create distinct sets of
- // executables like ones for training vs inference in the same model, or just
- // always use this.
- iree_hal_executable_cache_t* executable_cache;
+ // Shared executable cache for each device used to cache all executables
+ // created in the context. We could have multiple to allow for modules to
+ // create distinct sets of executables like ones for training vs inference in
+ // the same model or allow these to be injected so that multiple loaded
+ // contexts share the caches.
+ iree_hal_executable_cache_t* executable_caches[];
} iree_hal_module_state_t;
-static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
- iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
- iree_hal_device_release(module->shared_device);
-}
-
static iree_status_t IREE_API_PTR
iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator,
iree_vm_module_state_t** out_module_state) {
@@ -81,24 +86,36 @@
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(self);
iree_hal_module_state_t* state = NULL;
+ iree_host_size_t total_size =
+ sizeof(*state) +
+ module->device_count * sizeof(state->executable_caches[0]);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0,
- iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
- memset(state, 0, sizeof(*state));
+ z0, iree_allocator_malloc(host_allocator, total_size, (void**)&state));
+ memset(state, 0, total_size);
state->host_allocator = host_allocator;
state->flags = module->flags;
- state->shared_device = module->shared_device;
- iree_hal_device_retain(state->shared_device);
-
+ state->device_count = module->device_count;
+ state->devices = module->devices;
state->loop_status = iree_ok_status();
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_executable_cache_create(
- state->shared_device, iree_string_view_empty(),
- iree_loop_inline(&state->loop_status), &state->executable_cache));
- *out_module_state = (iree_vm_module_state_t*)state;
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+ status = iree_hal_executable_cache_create(
+ state->devices[i], iree_string_view_empty(),
+ iree_loop_inline(&state->loop_status), &state->executable_caches[i]);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ if (iree_status_is_ok(status)) {
+ *out_module_state = (iree_vm_module_state_t*)state;
+ } else {
+ for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+ iree_hal_executable_cache_release(state->executable_caches[i]);
+ }
+ iree_allocator_free(host_allocator, state);
+ }
IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
+ return status;
}
static void IREE_API_PTR
@@ -106,23 +123,52 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
- iree_hal_executable_cache_release(state->executable_cache);
+ for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+ iree_hal_executable_cache_release(state->executable_caches[i]);
+ }
iree_status_ignore(state->loop_status);
- iree_hal_device_release(state->shared_device);
iree_allocator_free(state->host_allocator, state);
IREE_TRACE_ZONE_END(z0);
}
+// Finds the executable cache that |state| holds for |device|.
+// On success |out_executable_cache| receives a borrowed pointer (no retain is
+// performed here). Devices are matched by pointer identity and when the same
+// device occupies several slots the cache of the first matching slot wins.
+// Fails with NOT_FOUND when |device| is not present in the state's list.
+static iree_status_t iree_hal_module_state_lookup_executable_cache(
+    iree_hal_module_state_t* state, iree_hal_device_t* device,
+    iree_hal_executable_cache_t** out_executable_cache) {
+  IREE_ASSERT_ARGUMENT(state);
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(out_executable_cache);
+  *out_executable_cache = NULL;
+
+  // Advance to the first slot holding this exact device (or off the end).
+  iree_host_size_t ordinal = 0;
+  while (ordinal < state->device_count && state->devices[ordinal] != device) {
+    ++ordinal;
+  }
+  if (ordinal == state->device_count) {
+    return iree_make_status(
+        IREE_STATUS_NOT_FOUND,
+        "no executable cache for the given device found; possibly a device not "
+        "registered with the HAL module");
+  }
+
+  *out_executable_cache = state->executable_caches[ordinal];
+  return iree_ok_status();
+}
+
+// Handles an iree_vm_signal_t notification for this module's state.
+// SUSPEND and LOW_MEMORY trim each device listed in the state, in order;
+// every other signal is accepted and reports success without doing anything.
static iree_status_t IREE_API_PTR iree_hal_module_notify(
void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) {
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
switch (signal) {
case IREE_VM_SIGNAL_SUSPEND:
- case IREE_VM_SIGNAL_LOW_MEMORY:
- return iree_hal_device_trim(state->shared_device);
- default:
+ case IREE_VM_SIGNAL_LOW_MEMORY: {
+ // A failing trim propagates out immediately, so devices that come after
+ // the failing one in the list are not trimmed for this signal.
+ for (iree_host_size_t i = 0; i < state->device_count; ++i) {
+ IREE_RETURN_IF_ERROR(iree_hal_device_trim(state->devices[i]));
+ }
return iree_ok_status();
+ }
+ default: {
+ // Ignored today but if we started managing device power down we could
+ // use this to wake them back up again.
+ return iree_ok_status();
+ }
}
}
@@ -150,13 +196,6 @@
// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
// using these APIs are not forward compatible.
-IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, //
- iree_hal_module_state_t, //
- v, r) {
- rets->r0 = iree_hal_device_retain_ref(state->shared_device);
- return iree_ok_status();
-}
-
static void iree_hal_module_file_buffer_release(
void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data;
@@ -411,9 +450,8 @@
"load length byte count %d exceeds max", length);
}
- IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
- state->shared_device, source_buffer, source_offset, &target_buffer,
- length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_map_read(source_buffer, source_offset,
+ &target_buffer, length));
rets->i0 = target_buffer;
return iree_ok_status();
@@ -440,9 +478,8 @@
iree_hal_buffer_byte_length(target_buffer));
}
- return iree_hal_device_transfer_h2d(
- state->shared_device, &value, target_buffer, target_offset, length,
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+ return iree_hal_buffer_map_write(target_buffer, target_offset, &value,
+ length);
}
//===----------------------------------------------------------------------===//
@@ -1104,6 +1141,30 @@
return iree_hal_device_queue_flush(device, queue_affinity);
}
+//===----------------------------------------------------------------------===//
+// iree_hal_device_t management
+//===----------------------------------------------------------------------===//
+
+// Reports the number of devices held by the module state, narrowed to the
+// 32-bit integer result slot.
+IREE_VM_ABI_EXPORT(iree_hal_module_devices_count,  //
+                   iree_hal_module_state_t,        //
+                   v, i) {
+  const iree_host_size_t available_devices = state->device_count;
+  rets->i0 = (int32_t)available_devices;
+  return iree_ok_status();
+}
+
+// Returns a retained reference to the device at ordinal |args->i0|.
+// Fails with INVALID_ARGUMENT when the ordinal is negative or is not less than
+// the number of devices held by the module state.
+IREE_VM_ABI_EXPORT(iree_hal_module_devices_get,  //
+                   iree_hal_module_state_t,      //
+                   i, r) {
+  // The ordinal is a signed value (it is printed with %d below) while the
+  // device count is an unsigned host size. Reject negatives explicitly and
+  // compare in the unsigned domain instead of depending on the implicit
+  // signed->unsigned conversion of a mixed comparison.
+  if (args->i0 < 0 || (iree_host_size_t)args->i0 >= state->device_count) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "device index %d out of bounds (%" PRIhsz
+                            " devices available)",
+                            args->i0, state->device_count);
+  }
+  rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]);
+  return iree_ok_status();
+}
+
//===--------------------------------------------------------------------===//
// iree_hal_executable_t
//===--------------------------------------------------------------------===//
@@ -1135,6 +1196,11 @@
constant_count = constant_buffer->data.data_length / sizeof(uint32_t);
constants = (const uint32_t*)constant_buffer->data.data;
}
+
+ iree_hal_executable_cache_t* executable_cache = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_module_state_lookup_executable_cache(
+ state, device, &executable_cache));
+
iree_host_size_t pipeline_layout_count = args->a4_count;
iree_hal_pipeline_layout_t** pipeline_layouts = NULL;
IREE_RETURN_IF_ERROR(
@@ -1164,7 +1230,7 @@
executable_params.constant_count = constant_count;
executable_params.constants = constants;
status = iree_hal_executable_cache_prepare_executable(
- state->executable_cache, &executable_params, &executable);
+ executable_cache, &executable_params, &executable);
}
iree_allocator_free(state->host_allocator, pipeline_layouts);
@@ -1573,13 +1639,15 @@
};
IREE_API_EXPORT iree_status_t iree_hal_module_create(
- iree_vm_instance_t* instance, iree_hal_device_t* device,
- iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
- iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_host_size_t device_count,
+ iree_hal_device_t** devices, iree_hal_module_flags_t flags,
+ iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
IREE_ASSERT_ARGUMENT(instance);
- IREE_ASSERT_ARGUMENT(device);
+ IREE_ASSERT_ARGUMENT(device_count);
+ IREE_ASSERT_ARGUMENT(devices);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
// Setup the interface with the functions we implement ourselves. Any function
// we omit will be handled by the base native module.
@@ -1591,8 +1659,9 @@
};
// Allocate shared module state.
- iree_host_size_t total_size =
- iree_vm_native_module_size() + sizeof(iree_hal_module_t);
+ iree_host_size_t total_size = iree_vm_native_module_size() +
+ sizeof(iree_hal_module_t) +
+ device_count * sizeof(iree_hal_device_t*);
iree_vm_module_t* base_module = NULL;
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
@@ -1602,6 +1671,7 @@
instance, host_allocator, base_module);
if (!iree_status_is_ok(status)) {
iree_allocator_free(host_allocator, base_module);
+ IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -1609,15 +1679,25 @@
module->host_allocator = host_allocator;
// TODO(benvanik): fix vm yield with result storage.
module->flags = flags | IREE_HAL_MODULE_FLAG_SYNCHRONOUS;
- module->shared_device = device;
- iree_hal_device_retain(module->shared_device);
+ module->device_count = device_count;
+ for (iree_host_size_t i = 0; i < device_count; ++i) {
+ module->devices[i] = devices[i];
+ iree_hal_device_retain(module->devices[i]);
+ }
*out_module = base_module;
+ IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
-IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device(
- iree_vm_module_state_t* module_state) {
+// Returns the number of devices recorded in |module_state|.
+IREE_API_EXPORT iree_host_size_t
+iree_hal_module_state_device_count(iree_vm_module_state_t* module_state) {
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
- return state->shared_device;
+ return state->device_count;
+}
+
+// Returns the device stored at |index| in |module_state| without retaining it,
+// or NULL when |index| is not less than the recorded device count.
+IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device_get(
+ iree_vm_module_state_t* module_state, iree_host_size_t index) {
+ iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
+ return index < state->device_count ? state->devices[index] : NULL;
}
diff --git a/runtime/src/iree/modules/hal/module.h b/runtime/src/iree/modules/hal/module.h
index 5ba71af..ee45570 100644
--- a/runtime/src/iree/modules/hal/module.h
+++ b/runtime/src/iree/modules/hal/module.h
@@ -26,18 +26,23 @@
};
typedef uint32_t iree_hal_module_flags_t;
-// Creates the HAL module initialized to use a specific |device|.
-// Each context using this module will share the device and have compatible
+// Creates the HAL module initialized to use one or more |devices|.
+// Each context using this module will share the devices and have compatible
// allocations.
IREE_API_EXPORT iree_status_t iree_hal_module_create(
- iree_vm_instance_t* instance, iree_hal_device_t* device,
- iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
- iree_vm_module_t** out_module);
+ iree_vm_instance_t* instance, iree_host_size_t device_count,
+ iree_hal_device_t** devices, iree_hal_module_flags_t flags,
+ iree_allocator_t host_allocator, iree_vm_module_t** out_module);
-// Returns the device currently in use by the HAL module.
-// Returns NULL if no device has been initialized yet.
-IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device(
- iree_vm_module_state_t* module_state);
+// Returns the total number of available devices registered with the HAL module.
+IREE_API_EXPORT iree_host_size_t
+iree_hal_module_state_device_count(iree_vm_module_state_t* module_state);
+
+// Returns the device at |index| currently in use by the HAL module.
+// Returns NULL if no device has been initialized yet or the index is out of
+// bounds.
+IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device_get(
+ iree_vm_module_state_t* module_state, iree_host_size_t index);
#ifdef __cplusplus
} // extern "C"
diff --git a/runtime/src/iree/runtime/session.c b/runtime/src/iree/runtime/session.c
index e166569..4286178 100644
--- a/runtime/src/iree/runtime/session.c
+++ b/runtime/src/iree/runtime/session.c
@@ -94,8 +94,9 @@
iree_vm_module_t* hal_module = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_module_create(iree_runtime_instance_vm_instance(instance),
- device, IREE_HAL_MODULE_FLAG_NONE,
- host_allocator, &hal_module);
+ /*device_count=*/1, &device,
+ IREE_HAL_MODULE_FLAG_NONE, host_allocator,
+ &hal_module);
}
if (iree_status_is_ok(status)) {
status = iree_vm_context_register_modules(
@@ -163,7 +164,8 @@
IREE_API_EXPORT iree_hal_device_t* iree_runtime_session_device(
const iree_runtime_session_t* session) {
IREE_ASSERT_ARGUMENT(session);
- return iree_hal_module_state_device(session->hal_module_state);
+ // NOTE: only one device is supported via this API today, so the device at
+ // index 0 of the HAL module state is always the one returned.
+ return iree_hal_module_state_device_get(session->hal_module_state, 0);
}
IREE_API_EXPORT iree_hal_allocator_t* iree_runtime_session_device_allocator(
diff --git a/runtime/src/iree/runtime/session.h b/runtime/src/iree/runtime/session.h
index c78d75c..f867da1 100644
--- a/runtime/src/iree/runtime/session.h
+++ b/runtime/src/iree/runtime/session.h
@@ -116,6 +116,8 @@
//
// NOTE: this device will not be available until initialized by a user module
// and will return NULL if queried prior.
+//
+// NOTE: this API does not support multiple devices.
IREE_API_EXPORT iree_hal_device_t* iree_runtime_session_device(
const iree_runtime_session_t* session);
@@ -125,6 +127,8 @@
//
// NOTE: this device allocator will not be available until initialized by a
// user module and will return NULL if queried prior.
+//
+// NOTE: this API does not support multiple devices.
IREE_API_EXPORT iree_hal_allocator_t* iree_runtime_session_device_allocator(
const iree_runtime_session_t* session);
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index 135a111..ccb0d9c 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -196,6 +196,8 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_module_register_all_types(instance));
+ // TODO(multi-device): create multiple devices (maybe with an
+ // iree_hal_device_list_t helper for retaining/managing the dynamic list).
// Create the device to use.
// In the future this will change to a set of available devices instead.
if (iree_string_view_is_empty(default_device_uri)) {
@@ -214,8 +216,8 @@
// Create HAL module wrapping the device created above.
iree_hal_module_flags_t flags = IREE_HAL_MODULE_FLAG_NONE;
iree_vm_module_t* module = NULL;
- iree_status_t status =
- iree_hal_module_create(instance, device, flags, host_allocator, &module);
+ iree_status_t status = iree_hal_module_create(
+ instance, /*device_count=*/1, &device, flags, host_allocator, &module);
if (iree_status_is_ok(status)) {
*out_module = module;
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index 06ed094..5f2350a 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -176,8 +176,8 @@
IREE_RETURN_IF_ERROR(iree_trace_replay_create_device(
replay, device_node, replay->host_allocator, &replay->device));
IREE_RETURN_IF_ERROR(iree_hal_module_create(
- replay->instance, replay->device, IREE_HAL_MODULE_FLAG_NONE,
- replay->host_allocator, &module));
+ replay->instance, /*device_count=*/1, &replay->device,
+ IREE_HAL_MODULE_FLAG_NONE, replay->host_allocator, &module));
}
if (!module) {
return iree_make_status(
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 83245cb..80f11f5 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -7,6 +7,7 @@
#include "iree/vm/shims.h"
IREE_VM_ABI_DEFINE_SHIM(irIi, v);
+IREE_VM_ABI_DEFINE_SHIM(i, r);
IREE_VM_ABI_DEFINE_SHIM(r, i);
IREE_VM_ABI_DEFINE_SHIM(r, I);
IREE_VM_ABI_DEFINE_SHIM(r, ii);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index af75456..860d96f 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -614,6 +614,7 @@
//===----------------------------------------------------------------------===//
IREE_VM_ABI_DECLARE_SHIM(irIi, v);
+IREE_VM_ABI_DECLARE_SHIM(i, r);
IREE_VM_ABI_DECLARE_SHIM(r, i);
IREE_VM_ABI_DECLARE_SHIM(r, I);
IREE_VM_ABI_DECLARE_SHIM(r, ii);
diff --git a/samples/simple_embedding/simple_embedding.c b/samples/simple_embedding/simple_embedding.c
index 57f6761..b5df80b 100644
--- a/samples/simple_embedding/simple_embedding.c
+++ b/samples/simple_embedding/simple_embedding.c
@@ -40,9 +40,9 @@
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
"create device");
iree_vm_module_t* hal_module = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
- iree_allocator_system(), &hal_module));
+ IREE_RETURN_IF_ERROR(iree_hal_module_create(
+ instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
+ iree_allocator_system(), &hal_module));
// Load bytecode module from the embedded data.
const iree_const_byte_span_t module_data = load_bytecode_module_data();