Adding hal.device.id queries to HAL devices. (#16495)
This allows for the compiler to query whether a particular device ID
(`cuda`, `local-task`, etc) matches a pattern string (`local-*`, etc).
Today all devices just report on whatever ID they were assigned by their
driver but we could support other HAL driver matches in the future (for
things like remote devices that may expose multiple subdevices).
Note that this is not a vendor identifier (like `NVIDIA GeForce ..`).
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index cab95ba..f61a913 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -973,6 +973,17 @@
return success();
}
+// static
+Value DeviceQueryOp::createI1(Location loc, Value device, StringRef category,
+ StringRef key, OpBuilder &builder) {
+ auto i1Type = builder.getI1Type();
+ return builder
+ .create<IREE::HAL::DeviceQueryOp>(
+ loc, i1Type, i1Type, device, builder.getStringAttr(category),
+ builder.getStringAttr(key), builder.getIntegerAttr(i1Type, 0))
+ .getValue();
+}
+
//===----------------------------------------------------------------------===//
// hal.device.queue.*
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index e030399..0dd54ea 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1629,7 +1629,10 @@
Well-known keys:
- * hal.executable.format :: {some format}
+ * hal.device.id :: {some id pattern}
+ Returns 1 if the device identifier matches the given pattern string.
+
+ * hal.executable.format :: {some format pattern}
Returns 1 if the given format is supported by the device loader.
* hal.device :: concurrency
@@ -1662,6 +1665,14 @@
attr-dict-with-keyword
}];
+ let extraClassDeclaration = [{
+ // Returns a true i1 if the given query returns a non-zero value.
+ // Returns false if the query fails or returns a zero value.
+ static Value createI1(Location loc, Value device,
+ StringRef category, StringRef key,
+ OpBuilder &builder);
+ }];
+
let hasVerifier = 1;
}
diff --git a/experimental/hip/hip_device.c b/experimental/hip/hip_device.c
index 488313a..4093020 100644
--- a/experimental/hip/hip_device.c
+++ b/experimental/hip/hip_device.c
@@ -384,6 +384,12 @@
iree_string_view_t key, int64_t* out_value) {
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
*out_value = iree_string_view_equal(key, IREE_SV("rocm-hsaco-fb")) ? 1 : 0;
return iree_ok_status();
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index b76f165..86147f5 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -191,9 +191,15 @@
static iree_status_t iree_hal_rocm_device_query_i64(
iree_hal_device_t* base_device, iree_string_view_t category,
iree_string_view_t key, int64_t* out_value) {
- // iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
+ iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category,
iree_make_cstring_view("hal.executable.format"))) {
*out_value =
diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c
index 066261a..8af38c0 100644
--- a/experimental/webgpu/webgpu_device.c
+++ b/experimental/webgpu/webgpu_device.c
@@ -212,10 +212,15 @@
static iree_status_t iree_hal_webgpu_device_query_i64(
iree_hal_device_t* base_device, iree_string_view_t category,
iree_string_view_t key, int64_t* out_value) {
- // iree_hal_webgpu_device_t* device =
- // iree_hal_webgpu_device_cast(base_device);
+ iree_hal_webgpu_device_t* device = iree_hal_webgpu_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category,
iree_make_cstring_view("hal.executable.format"))) {
*out_value =
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 9189b4c..f7d550c 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -408,6 +408,12 @@
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
*out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
return iree_ok_status();
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 7a85962..711704c 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -185,6 +185,12 @@
iree_hal_sync_device_t* device = iree_hal_sync_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
*out_value =
iree_hal_query_any_executable_loader_support(
@@ -192,7 +198,9 @@
? 1
: 0;
return iree_ok_status();
- } else if (iree_string_view_equal(category, IREE_SV("hal.device"))) {
+ }
+
+ if (iree_string_view_equal(category, IREE_SV("hal.device"))) {
if (iree_string_view_equal(key, IREE_SV("concurrency"))) {
*out_value = 1;
return iree_ok_status();
diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c
index 19a5639..7601af5 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_device.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_device.c
@@ -226,6 +226,12 @@
iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
*out_value =
iree_hal_query_any_executable_loader_support(
@@ -233,7 +239,9 @@
? 1
: 0;
return iree_ok_status();
- } else if (iree_string_view_equal(category, IREE_SV("hal.device"))) {
+ }
+
+ if (iree_string_view_equal(category, IREE_SV("hal.device"))) {
if (iree_string_view_equal(key, IREE_SV("concurrency"))) {
*out_value = (int64_t)device->queue_count;
return iree_ok_status();
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index f88747c..05878b0 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -216,8 +216,14 @@
static iree_status_t iree_hal_metal_device_query_i64(iree_hal_device_t* base_device,
iree_string_view_t category,
iree_string_view_t key, int64_t* out_value) {
+ iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value = iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, iree_make_cstring_view("hal.executable.format"))) {
*out_value = iree_string_view_equal(key, iree_make_cstring_view("metal-msl-fb")) ? 1 : 0;
return iree_ok_status();
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index aeebe69..0941184 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -1420,6 +1420,12 @@
iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
*out_value = 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) {
+ *out_value =
+ iree_string_view_match_pattern(device->identifier, key) ? 1 : 0;
+ return iree_ok_status();
+ }
+
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb"))) {
// Base SPIR-V always supported.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index d00c39a..824766b 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -1155,13 +1155,11 @@
IREE_VM_ABI_EXPORT(iree_hal_module_devices_get, //
iree_hal_module_state_t, //
i, r) {
- if (args->i0 >= state->device_count) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "device index %d out of bounds (%" PRIhsz
- " devices available)",
- args->i0, state->device_count);
+ if (args->i0 < state->device_count) {
+ rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]);
+ } else {
+ rets->r0 = iree_vm_ref_null();
}
- rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]);
return iree_ok_status();
}