Enabling --device_allocator= support in iree-run-trace.
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD index 84e1a12..9fc50c8 100644 --- a/runtime/src/iree/tooling/BUILD +++ b/runtime/src/iree/tooling/BUILD
@@ -193,6 +193,7 @@ srcs = ["trace_replay.c"], hdrs = ["trace_replay.h"], deps = [ + ":device_util", ":numpy_io", ":yaml_util", "//runtime/src/iree/base",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt index b3f7284..7e4c87e 100644 --- a/runtime/src/iree/tooling/CMakeLists.txt +++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -214,6 +214,7 @@ SRCS "trace_replay.c" DEPS + ::device_util ::numpy_io ::yaml_util iree::base
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c index 301819e..9bc9398 100644 --- a/runtime/src/iree/tooling/context_util.c +++ b/runtime/src/iree/tooling/context_util.c
@@ -111,8 +111,9 @@ } iree_hal_device_t* device = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_create_device_from_flags(default_device_uri, host_allocator, - &device)); + z0, iree_hal_create_device_from_flags( + iree_hal_available_driver_registry(), default_device_uri, + host_allocator, &device)); // Fetch the allocator from the device to pass back to the caller. iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c index 4d0017f..f738dcd 100644 --- a/runtime/src/iree/tooling/device_util.c +++ b/runtime/src/iree/tooling/device_util.c
@@ -321,6 +321,7 @@ } iree_status_t iree_hal_create_device_from_flags( + iree_hal_driver_registry_t* driver_registry, iree_string_view_t default_device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_string_view_t device_uri = default_device;
diff --git a/runtime/src/iree/tooling/device_util.h b/runtime/src/iree/tooling/device_util.h index 915bb4b..d290fad 100644 --- a/runtime/src/iree/tooling/device_util.h +++ b/runtime/src/iree/tooling/device_util.h
@@ -30,6 +30,7 @@ // Uses the |default_device| if no flags were specified. // Fails if more than one device was specified. iree_status_t iree_hal_create_device_from_flags( + iree_hal_driver_registry_t* driver_registry, iree_string_view_t default_device, iree_allocator_t host_allocator, iree_hal_device_t** out_device);
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c index ea5a921..930ffbe 100644 --- a/runtime/src/iree/tooling/trace_replay.c +++ b/runtime/src/iree/tooling/trace_replay.c
@@ -17,6 +17,7 @@ #include "iree/base/internal/path.h" #include "iree/base/tracing.h" #include "iree/modules/hal/module.h" +#include "iree/tooling/device_util.h" #include "iree/tooling/numpy_io.h" #include "iree/vm/bytecode/module.h" @@ -125,8 +126,8 @@ } // Try to create the device. - return iree_hal_create_device(replay->driver_registry, device_uri, - host_allocator, out_device); + return iree_hal_create_device_from_flags(replay->driver_registry, device_uri, + host_allocator, out_device); } static iree_status_t iree_trace_replay_load_builtin_module(