Allow specifying multiple --device= flags in tooling. (#16132)
diff --git a/runtime/src/iree/base/internal/flags.h b/runtime/src/iree/base/internal/flags.h
index 2826efd..a322593 100644
--- a/runtime/src/iree/base/internal/flags.h
+++ b/runtime/src/iree/base/internal/flags.h
@@ -227,13 +227,7 @@
// List flag utilities
//===----------------------------------------------------------------------===//
-// A list of string views referencing flag storage.
-typedef struct iree_flag_string_list_t {
- // Total number of values in the list.
- iree_host_size_t count;
- // Value list or NULL if no values.
- const iree_string_view_t* values;
-} iree_flag_string_list_t;
+typedef struct iree_string_view_list_t iree_flag_string_list_t;
#if IREE_FLAGS_ENABLE_CLI == 1
diff --git a/runtime/src/iree/base/string_view.h b/runtime/src/iree/base/string_view.h
index 313e3db..01e243c 100644
--- a/runtime/src/iree/base/string_view.h
+++ b/runtime/src/iree/base/string_view.h
@@ -98,6 +98,14 @@
// Returns a string view initialized with the given string literal.
#define IREE_SVL(cstr) iree_string_view_literal(cstr)
+// A list of string views.
+typedef struct iree_string_view_list_t {
+ // Total number of values in the list.
+ iree_host_size_t count;
+ // Value list or NULL if no values.
+ const iree_string_view_t* values;
+} iree_string_view_list_t;
+
// Returns true if the two strings are equal (compare == 0).
IREE_API_EXPORT bool iree_string_view_equal(iree_string_view_t lhs,
iree_string_view_t rhs);
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index c70fa69..4097a8d 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -12,6 +12,10 @@
#include "iree/hal/detail.h"
#include "iree/hal/resource.h"
+//===----------------------------------------------------------------------===//
+// iree_hal_device_t
+//===----------------------------------------------------------------------===//
+
#define _VTABLE_DISPATCH(device, method_name) \
IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, method_name)
@@ -381,3 +385,62 @@
IREE_TRACE_ZONE_END(z0);
return status;
}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_device_list_t
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t iree_hal_device_list_allocate(
+ iree_host_size_t capacity, iree_allocator_t host_allocator,
+ iree_hal_device_list_t** out_list) {
+ IREE_ASSERT_ARGUMENT(out_list);
+ *out_list = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_device_list_t* list = NULL;
+ iree_host_size_t total_size =
+ sizeof(*list) + capacity * sizeof(list->devices[0]);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, total_size, (void**)&list));
+ list->host_allocator = host_allocator;
+ list->capacity = capacity;
+ list->count = 0;
+ *out_list = list;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+IREE_API_EXPORT void iree_hal_device_list_free(iree_hal_device_list_t* list) {
+ if (!list) return;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator = list->host_allocator;
+ for (iree_host_size_t i = 0; i < list->count; ++i) {
+ iree_hal_device_release(list->devices[i]);
+ }
+ iree_allocator_free(host_allocator, list);
+ IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_device_list_push_back(
+ iree_hal_device_list_t* list, iree_hal_device_t* device) {
+ IREE_ASSERT_ARGUMENT(list);
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = iree_ok_status();
+ if (list->count + 1 <= list->capacity) {
+ iree_hal_device_retain(device);
+ list->devices[list->count++] = device;
+ } else {
+ status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+ "list capacity %" PRIhsz
+ " reached; no more devices can be added",
+ list->capacity);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_hal_device_t* iree_hal_device_list_at(
+ const iree_hal_device_list_t* list, iree_host_size_t i) {
+ IREE_ASSERT_ARGUMENT(list);
+ return i < list->count ? list->devices[i] : NULL;
+}
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index cc06cf8..7a04654 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -480,6 +480,36 @@
iree_hal_device_profiling_end(iree_hal_device_t* device);
//===----------------------------------------------------------------------===//
+// iree_hal_device_list_t
+//===----------------------------------------------------------------------===//
+
+// A fixed-size list of retained devices.
+typedef struct iree_hal_device_list_t {
+ iree_allocator_t host_allocator;
+ iree_host_size_t capacity;
+ iree_host_size_t count;
+ iree_hal_device_t* devices[];
+} iree_hal_device_list_t;
+
+// Allocates an empty device list with the given capacity.
+IREE_API_EXPORT iree_status_t iree_hal_device_list_allocate(
+ iree_host_size_t capacity, iree_allocator_t host_allocator,
+ iree_hal_device_list_t** out_list);
+
+// Frees a device |list|.
+IREE_API_EXPORT void iree_hal_device_list_free(iree_hal_device_list_t* list);
+
+// Pushes a |device| onto the |list| and retains it.
+IREE_API_EXPORT iree_status_t iree_hal_device_list_push_back(
+ iree_hal_device_list_t* list, iree_hal_device_t* device);
+
+// Returns the device at index |i| in the |list| or NULL if out of range.
+// Callers must retain the device if it's possible for the returned pointer to
+// live beyond the list.
+IREE_API_EXPORT iree_hal_device_t* iree_hal_device_list_at(
+ const iree_hal_device_list_t* list, iree_host_size_t i);
+
+//===----------------------------------------------------------------------===//
// iree_hal_device_t implementation details
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index ccb0d9c..5ec1028 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -196,18 +196,20 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_module_register_all_types(instance));
- // TODO(multi-device): create multiple devices (maybe with an
- // iree_hal_device_list_t helper for retaining/managing the dynamic list).
- // Create the device to use.
- // In the future this will change to a set of available devices instead.
+ // Create the device(s) to use.
if (iree_string_view_is_empty(default_device_uri)) {
default_device_uri = iree_hal_default_device_uri();
}
- iree_hal_device_t* device = NULL;
+ iree_hal_device_list_t* device_list = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_create_device_from_flags(
+ z0, iree_hal_create_devices_from_flags(
iree_hal_available_driver_registry(), default_device_uri,
- host_allocator, &device));
+ host_allocator, &device_list));
+
+ // Pick a lead device we'll use for bookkeeping.
+ iree_hal_device_t* device = iree_hal_device_list_at(device_list, 0);
+ IREE_ASSERT(device, "require at least one device");
+ iree_hal_device_retain(device);
// Fetch the allocator from the device to pass back to the caller.
iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
@@ -216,8 +218,11 @@
// Create HAL module wrapping the device created above.
iree_hal_module_flags_t flags = IREE_HAL_MODULE_FLAG_NONE;
iree_vm_module_t* module = NULL;
- iree_status_t status = iree_hal_module_create(
- instance, /*device_count=*/1, &device, flags, host_allocator, &module);
+ iree_status_t status =
+ iree_hal_module_create(instance, device_list->count, device_list->devices,
+ flags, host_allocator, &module);
+
+ iree_hal_device_list_free(device_list);
if (iree_status_is_ok(status)) {
*out_module = module;
diff --git a/runtime/src/iree/tooling/context_util.h b/runtime/src/iree/tooling/context_util.h
index 75a2459..8116e8b 100644
--- a/runtime/src/iree/tooling/context_util.h
+++ b/runtime/src/iree/tooling/context_util.h
@@ -51,9 +51,12 @@
//
// |default_device_uri| can be specified to provide a default if a device flag
// is not provided by the user.
-// |out_device| will contain the created device if using the full HAL.
+// |out_device| will contain the first created device if using the full HAL.
// |out_device_allocator| can be used to allocate buffers for use with the
// context and is available in all execution models.
+//
+// If multiple devices are created the one returned (and it's corresponding
+// allocator) are considered the 'lead' device for bookkeeping purposes.
iree_status_t iree_tooling_resolve_modules(
iree_vm_instance_t* instance, iree_host_size_t user_module_count,
iree_vm_module_t** user_modules, iree_string_view_t default_device_uri,
@@ -86,9 +89,12 @@
//
// |default_device_uri| can be specified to provide a default if a device flag
// is not provided by the user.
-// |out_device| will contain the created device if using the full HAL.
+// |out_device| will contain the first created device if using the full HAL.
// |out_device_allocator| can be used to allocate buffers for use with the
// context and is available in all execution models.
+//
+// If multiple devices are created the one returned (and it's corresponding
+// allocator) are considered the 'lead' device for bookkeeping purposes.
iree_status_t iree_tooling_create_context_from_flags(
iree_vm_instance_t* instance, iree_host_size_t user_module_count,
iree_vm_module_t** user_modules, iree_string_view_t default_device_uri,
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 0a0798a..bd864ad 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -307,6 +307,12 @@
// Collectives configuration
//===----------------------------------------------------------------------===//
+// TODO(multi-device): support more provider types/have a provider registry.
+// MPI is insufficient for heterogeneous/multi-device configurations. Currently
+// we set the same provider for every device and that'll really confuse things
+// as MPI rank/count configuration is global in the environment and not per
+// device. Hosting frameworks/runtimes can set their own providers that have
+// more meaningful representation of multi-device/multi-node.
iree_status_t iree_hal_device_set_default_channel_provider(
iree_hal_device_t* device) {
if (!iree_hal_mpi_is_configured()) return iree_ok_status();
@@ -330,63 +336,91 @@
"Use --list_devices/--dump_devices to see available devices and their\n"
"canonical URI used with this flag.");
-// TODO(#5724): remove this and replace with an iree_hal_device_set_t.
-void iree_hal_get_devices_flag_list(iree_host_size_t* out_count,
- const iree_string_view_t** out_list) {
- *out_count = FLAG_device_list().count;
- *out_list = FLAG_device_list().values;
+iree_string_view_list_t iree_hal_device_flag_list(void) {
+ return FLAG_device_list();
}
iree_status_t iree_hal_create_device_from_flags(
iree_hal_driver_registry_t* driver_registry,
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
- iree_string_view_t device_uri = default_device;
- const iree_flag_string_list_t list = FLAG_device_list();
- if (list.count == 0) {
+ iree_hal_device_list_t* device_list = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_create_devices_from_flags(
+ driver_registry, default_device, host_allocator, &device_list));
+ iree_hal_device_t* device = iree_hal_device_list_at(device_list, 0);
+ iree_hal_device_retain(device);
+ iree_hal_device_list_free(device_list);
+ *out_device = device;
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_create_devices_from_flags(
+ iree_hal_driver_registry_t* driver_registry,
+ iree_string_view_t default_device, iree_allocator_t host_allocator,
+ iree_hal_device_list_t** out_device_list) {
+ iree_flag_string_list_t flag_list = FLAG_device_list();
+ if (flag_list.count == 0) {
// No devices specified. Use default if provided.
if (iree_string_view_is_empty(default_device)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
- "no device specified; use --list_devices to see the "
- "available devices and specify one with --device=");
+ "no devices specified; use --list_devices to see the "
+ "available devices and specify one or more with --device=");
}
- } else if (list.count > 1) {
- // Too many devices for the single device creation function.
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "too many devices specified; only one --device= "
- "flag may be provided with this API");
- } else {
- // Exactly one device specified.
- device_uri = list.values[0];
+ flag_list.count = 1;
+ flag_list.values = &default_device;
}
IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE({
+ for (iree_host_size_t i = 0; i < flag_list.count; ++i) {
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, flag_list.values[i].data,
+ flag_list.values[i].size);
+ }
+ });
- // Create the device, which may be slow and dynamically load big dependencies
- // (CUDA, Vulkan, etc).
- iree_hal_device_t* device = NULL;
+ iree_hal_device_list_t* device_list = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_create_device(iree_hal_available_driver_registry(),
- device_uri, host_allocator, &device));
+ z0, iree_hal_device_list_allocate(flag_list.count, host_allocator,
+ &device_list));
- // Optionally wrap the base device allocator with caching/pooling.
- // Doing this here satisfies the requirement that no buffers have been
- // allocated yet - if we returned the device without doing this the caller can
- // more easily break the rules.
- iree_status_t status = iree_hal_configure_allocator_from_flags(device);
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < flag_list.count; ++i) {
+ // Create the device, which may be slow and dynamically load big
+ // dependencies (CUDA, Vulkan, etc).
+ iree_hal_device_t* device = NULL;
+ status = iree_hal_create_device(driver_registry, flag_list.values[i],
+ host_allocator, &device);
- // Optionally set a collective channel provider used by devices to initialize
- // their default channels. Hosting libraries or applications can do the same
- // to interface with their own implementations.
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_set_default_channel_provider(device);
+ // Optionally wrap the base device allocator with caching/pooling.
+ // Doing this here satisfies the requirement that no buffers have been
+ // allocated yet - if we returned the device without doing this the caller
+ // can more easily break the rules.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_configure_allocator_from_flags(device);
+ }
+
+ // Optionally set a collective channel provider used by devices to
+ // initialize their default channels. Hosting libraries or applications can
+ // do the same to interface with their own implementations. Note that this
+ // currently sets the same provider for all devices.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_set_default_channel_provider(device);
+ }
+
+ // Add the device to the list to retain it for the caller.
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_list_push_back(device_list, device);
+ }
+ iree_hal_device_release(device);
+
+ if (!iree_status_is_ok(status)) break;
}
if (iree_status_is_ok(status)) {
- *out_device = device;
+ *out_device_list = device_list;
} else {
- iree_hal_device_release(device);
+ iree_hal_device_list_free(device_list);
}
IREE_TRACE_ZONE_END(z0);
return status;
diff --git a/runtime/src/iree/tooling/device_util.h b/runtime/src/iree/tooling/device_util.h
index 68aede9..b9a4330 100644
--- a/runtime/src/iree/tooling/device_util.h
+++ b/runtime/src/iree/tooling/device_util.h
@@ -22,9 +22,9 @@
// flags and tools should encourage that.
iree_string_view_t iree_hal_default_device_uri(void);
-// TODO(#5724): remove this and replace with an iree_hal_device_set_t.
-void iree_hal_get_devices_flag_list(iree_host_size_t* out_count,
- const iree_string_view_t** out_list);
+// Returns a reference to the storage of the --device= flag.
+// Changes to flags invalidate the storage.
+iree_string_view_list_t iree_hal_device_flag_list(void);
// Creates a single device from the --device= flag.
// Uses the |default_device| if no flags were specified.
@@ -34,6 +34,13 @@
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
+// Creates one or more devices from the repeatable --device= flag.
+// Uses the |default_device| if no flags were specified.
+iree_status_t iree_hal_create_devices_from_flags(
+ iree_hal_driver_registry_t* driver_registry,
+ iree_string_view_t default_device, iree_allocator_t host_allocator,
+ iree_hal_device_list_t** out_device_list);
+
// Configures the |device| channel provider based on the current environment.
// Today this simply checks to see if the process is running under MPI and
// initializes that unconditionally.
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index 7bfdcf7..fea9aa6 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -88,9 +88,7 @@
}
void iree_trace_replay_set_hal_devices_override(
- iree_trace_replay_t* replay, iree_host_size_t device_uri_count,
- const iree_string_view_t* device_uris) {
- replay->device_uri_count = device_uri_count;
+ iree_trace_replay_t* replay, iree_string_view_list_t device_uris) {
replay->device_uris = device_uris;
}
@@ -133,7 +131,9 @@
// type: module_load
//===----------------------------------------------------------------------===//
-// TODO(benvanik): rework this to allow for multiple devices from a device set.
+// TODO(multi-device): rework this to allow for multiple devices. Today it
+// assumes there's a single device but the |device_node| can specify any
+// arbitrary device we want to create.
static iree_status_t iree_trace_replay_create_device(
iree_trace_replay_t* replay, yaml_node_t* device_node,
iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
@@ -147,15 +147,17 @@
// Use the provided driver name or override with the --device= flag.
iree_string_view_t device_uri = iree_yaml_node_as_string(device_node);
if (iree_string_view_is_empty(device_uri)) {
- if (replay->device_uri_count != 1) {
+ if (replay->device_uris.count != 1) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"exactly one device must be specified when none "
"is present in the trace file");
}
- device_uri = replay->device_uris[0];
+ device_uri = replay->device_uris.values[0];
}
// Try to create the device.
+ // NOTE: this is assuming a single device only - we may retire trace replay
+ // before we ever try to make it work with multiple devices.
return iree_hal_create_device_from_flags(replay->driver_registry, device_uri,
host_allocator, out_device);
}
@@ -345,7 +347,7 @@
IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_16, uint16_t)
IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_32, uint32_t)
IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_64, uint64_t)
- // clang-format off
+ // clang-format off
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_f16((float)value);
break;
diff --git a/runtime/src/iree/tooling/trace_replay.h b/runtime/src/iree/tooling/trace_replay.h
index 27d670b..ee707ed 100644
--- a/runtime/src/iree/tooling/trace_replay.h
+++ b/runtime/src/iree/tooling/trace_replay.h
@@ -71,9 +71,11 @@
// be read on each event. Must remain valid for the lifetime of the replay.
iree_const_byte_span_t stdin_contents;
+ // HAL driver registry used to enumerate and create devices.
iree_hal_driver_registry_t* driver_registry;
- iree_host_size_t device_uri_count;
- const iree_string_view_t* device_uris;
+ // Unowned reference to a list of device URIs overriding the original devices
+ // specified in the trace.
+ iree_string_view_list_t device_uris;
// All loaded modules, reset each context load unless
// IREE_TRACE_REPLAY_FLAG_REUSE_MODULES is set to preserve them.
@@ -114,11 +116,10 @@
void iree_trace_replay_deinitialize(iree_trace_replay_t* replay);
// TODO(#5724): remove this and instead provide a device set on initialize.
-// Overrides the HAL driver used in the trace with the given |driver|.
+// Overrides the HAL driver used in the trace with the given devices.
// |device_uris| must remain valid for the lifetime of the replay instance.
void iree_trace_replay_set_hal_devices_override(
- iree_trace_replay_t* replay, iree_host_size_t device_uri_count,
- const iree_string_view_t* device_uris);
+ iree_trace_replay_t* replay, iree_string_view_list_t device_uris);
// Resets replay input/output/blackboard state.
void iree_trace_replay_reset(iree_trace_replay_t* replay);
diff --git a/tools/iree-benchmark-trace-main.c b/tools/iree-benchmark-trace-main.c
index a992261..5920714 100644
--- a/tools/iree-benchmark-trace-main.c
+++ b/tools/iree-benchmark-trace-main.c
@@ -188,12 +188,8 @@
// Query device overrides, if any. When omitted the devices from the trace
// file will be used.
- // TODO(#5724): remove this and instead provide a device set on initialize.
- iree_host_size_t device_uri_count = 0;
- const iree_string_view_t* device_uris = NULL;
- iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
- iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
- device_uris);
+ iree_trace_replay_set_hal_devices_override(&replay,
+ iree_hal_device_flag_list());
// Open trace YAML file from the given file_path.
FILE* file = fopen(registration->file_path.data, "rb");
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 5d76c48..b6bdff0 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -1173,12 +1173,8 @@
// Query device overrides, if any. When omitted the devices from the trace
// file will be used.
- // TODO(#5724): remove this and instead provide a device set on initialize.
- iree_host_size_t device_uri_count = 0;
- const iree_string_view_t* device_uris = NULL;
- iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
- iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
- device_uris);
+ iree_trace_replay_set_hal_devices_override(&replay,
+ iree_hal_device_flag_list());
yaml_parser_t parser;
if (!yaml_parser_initialize(&parser)) {
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index 00c62fb..d2eb3f8 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -129,14 +129,13 @@
// guesswork. If we can't produce a target backend flag value we bail.
// Returns a comma-delimited list of target backends.
StatusOr<std::string> InferTargetBackendsFromDevices(
- iree_host_size_t device_flag_count,
- const iree_string_view_t* device_flag_values) {
+ iree_string_view_list_t device_uris) {
// No-op when no devices specified (probably no HAL).
- if (device_flag_count == 0) return "";
+ if (device_uris.count == 0) return "";
// If multiple devices were provided we need to target all of them.
std::set<std::string> target_backends;
- for (iree_host_size_t i = 0; i < device_flag_count; ++i) {
- auto target_backend = InferTargetBackendFromDevice(device_flag_values[i]);
+ for (iree_host_size_t i = 0; i < device_uris.count; ++i) {
+ auto target_backend = InferTargetBackendFromDevice(device_uris.values[i]);
if (!target_backend.empty()) {
target_backends.insert(std::move(target_backend));
}
@@ -176,20 +175,18 @@
// Query the tooling utils for the --device= flag values. Note that zero or
// more devices may be specified.
- iree_host_size_t device_flag_count = 0;
- const iree_string_view_t* device_flag_values = NULL;
- iree_hal_get_devices_flag_list(&device_flag_count, &device_flag_values);
+ iree_string_view_list_t device_uris = iree_hal_device_flag_list();
// No-op if no target backends or devices are specified - this can be an
// intentional decision as the user may be running a program that doesn't use
// the HAL.
- if (target_backends_flag.empty() && device_flag_count == 0) {
+ if (target_backends_flag.empty() && device_uris.count == 0) {
return OkStatus();
}
// No-op if both target backends and devices are set as the user has
// explicitly specified a configuration.
- if (!target_backends_flag.empty() && device_flag_count > 0) {
+ if (!target_backends_flag.empty() && device_uris.count > 0) {
return OkStatus();
}
@@ -197,7 +194,7 @@
// the compiler configuration. This only works if there's a single backend
// specified; if the user wants multiple target backends then they must
// specify the device(s) to use.
- if (device_flag_count == 0) {
+ if (device_uris.count == 0) {
if (target_backends_flag.find(',') != std::string::npos) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
@@ -216,9 +213,8 @@
// guesses. In the future we'll have more ways of configuring the compiler
// from available runtime devices (not just the target backend but also
// target-specific settings).
- IREE_ASSIGN_OR_RETURN(
- auto target_backends,
- InferTargetBackendsFromDevices(device_flag_count, device_flag_values));
+ IREE_ASSIGN_OR_RETURN(auto target_backends,
+ InferTargetBackendsFromDevices(device_uris));
if (!target_backends.empty()) {
auto target_backends_flag =
std::string("--iree-hal-target-backends=") + target_backends;
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index b1b39dc..4bca457 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -146,12 +146,8 @@
// Query device overrides, if any. When omitted the devices from the trace
// file will be used.
- // TODO(#5724): remove this and instead provide a device set on initialize.
- iree_host_size_t device_uri_count = 0;
- const iree_string_view_t* device_uris = NULL;
- iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
- iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
- device_uris);
+ iree_trace_replay_set_hal_devices_override(&replay,
+ iree_hal_device_flag_list());
yaml_parser_t parser;
if (!yaml_parser_initialize(&parser)) {