Centralizing HAL allocator wrapper creation and exposing to python. (#12356)
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index b1e4b83..72ddffa 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -52,6 +52,7 @@
iree::base::tracing
iree::hal
iree::hal::drivers
+ iree::hal::utils::allocators
iree::modules::hal
iree::vm
iree::vm::bytecode_module
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index c11dfc2..110f121 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -10,6 +10,7 @@
#include "iree/base/internal/path.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
+#include "iree/hal/utils/allocators.h"
#include "iree/modules/hal/module.h"
#include "pybind11/numpy.h"
@@ -312,15 +313,48 @@
return results;
}
-HalDevice HalDriver::CreateDefaultDevice() {
+// Configures |device| based on flags before returning it to the user.
+static iree_status_t ConfigureDevice(iree_hal_device_t* device,
+ const py::kwargs& kwargs) {
+ // Optionally wrap the base device allocator with caching/pooling.
+ // Doing this here satisfies the requirement that no buffers have been
+ // allocated yet - if we returned the device without doing this the caller
+ // can more easily break the rules.
+ if (kwargs.contains("allocators")) {
+ // NOTE: we need to pass string views that point to the std::string storage.
+ // We do that in two passes because as we grow spec_storage it may
+ // reallocate itself and invalidate the pointers - only after we're done
+ // can we capture them in views.
+ auto spec_list = py::cast<py::list>(kwargs["allocators"]);
+ std::vector<std::string> spec_storage;
+ spec_storage.reserve(spec_list.size());
+ for (auto item : spec_list) {
+ auto spec = py::cast<std::string>(item);
+ spec_storage.push_back(std::move(spec));
+ }
+ std::vector<iree_string_view_t> spec_views;
+ spec_views.reserve(spec_list.size());
+ for (const auto& spec : spec_storage) {
+ spec_views.push_back(iree_make_string_view(spec.data(), spec.size()));
+ }
+ IREE_RETURN_IF_ERROR(iree_hal_configure_allocator_from_specs(
+ spec_views.size(), spec_views.data(), device));
+ }
+ return iree_ok_status();
+}
+
+HalDevice HalDriver::CreateDefaultDevice(const py::kwargs& kwargs) {
iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_default_device(
raw_ptr(), iree_allocator_system(), &device),
"Error creating default device");
+ CheckApiStatus(ConfigureDevice(device, kwargs),
+ "Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
-HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id) {
+HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id,
+ const py::kwargs& kwargs) {
// Since the device ids are supposed to be opaque, we need to verify
// them by querying available devices.
py::list available_devices = QueryAvailableDevices();
@@ -352,16 +386,21 @@
raw_ptr(), device_id, params.size(), ¶ms.front(),
iree_allocator_system(), &device),
"Error creating default device");
+ CheckApiStatus(ConfigureDevice(device, kwargs),
+ "Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
-HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri) {
+HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri,
+ const py::kwargs& kwargs) {
iree_hal_device_t* device;
iree_string_view_t device_uri_sv{device_uri.data(), device_uri.size()};
CheckApiStatus(
iree_hal_driver_create_device_by_uri(raw_ptr(), device_uri_sv,
iree_allocator_system(), &device),
"Error creating device");
+ CheckApiStatus(ConfigureDevice(device, kwargs),
+ "Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
@@ -591,12 +630,13 @@
py::keep_alive<0, 1>())
.def(
"create_device",
- [](HalDriver& self, py::dict device_info) -> HalDevice {
+ [](HalDriver& self, py::dict device_info,
+ const py::kwargs& kwargs) -> HalDevice {
// Alias of create_device that takes a dict as returned from
// query_available_devices for convenience.
auto device_id =
py::cast<iree_hal_device_id_t>(device_info["device_id"]);
- return self.CreateDevice(device_id);
+ return self.CreateDevice(device_id, kwargs);
},
py::keep_alive<0, 1>())
.def("query_available_devices", &HalDriver::QueryAvailableDevices);
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 64f0819..ff15dfa 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -97,9 +97,11 @@
py::dict& driver_cache);
py::list QueryAvailableDevices();
- HalDevice CreateDefaultDevice();
- HalDevice CreateDevice(iree_hal_device_id_t device_id);
- HalDevice CreateDeviceByURI(std::string& device_uri);
+ HalDevice CreateDefaultDevice(const py::kwargs& kwargs);
+ HalDevice CreateDevice(iree_hal_device_id_t device_id,
+ const py::kwargs& kwargs);
+ HalDevice CreateDeviceByURI(std::string& device_uri,
+ const py::kwargs& kwargs);
};
class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
diff --git a/runtime/bindings/python/tests/system_setup_test.py b/runtime/bindings/python/tests/system_setup_test.py
index d4707c3..3248b7c 100644
--- a/runtime/bindings/python/tests/system_setup_test.py
+++ b/runtime/bindings/python/tests/system_setup_test.py
@@ -61,6 +61,13 @@
with self.assertRaises(ValueError, msg="Device not found: local-sync://1"):
_ = ss.get_device("local-sync://1")
+ def testCreateDeviceWithAllocators(self):
+ driver = ss.get_driver("local-sync")
+ infos = driver.query_available_devices()
+ device1 = driver.create_device(infos[0]["device_id"], allocators=[])
+ device2 = driver.create_device(infos[0]["device_id"],
+ allocators=["caching", "debug"])
+
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
diff --git a/runtime/src/iree/hal/utils/BUILD b/runtime/src/iree/hal/utils/BUILD
index 83bb230..78a4755 100644
--- a/runtime/src/iree/hal/utils/BUILD
+++ b/runtime/src/iree/hal/utils/BUILD
@@ -14,6 +14,20 @@
)
iree_runtime_cc_library(
+ name = "allocators",
+ srcs = ["allocators.c"],
+ hdrs = ["allocators.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":caching_allocator",
+ ":debug_allocator",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base:tracing",
+ "//runtime/src/iree/hal",
+ ],
+)
+
+iree_runtime_cc_library(
name = "buffer_transfer",
srcs = ["buffer_transfer.c"],
hdrs = ["buffer_transfer.h"],
diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt
index 8199d25..3e84e67 100644
--- a/runtime/src/iree/hal/utils/CMakeLists.txt
+++ b/runtime/src/iree/hal/utils/CMakeLists.txt
@@ -12,6 +12,22 @@
iree_cc_library(
NAME
+ allocators
+ HDRS
+ "allocators.h"
+ SRCS
+ "allocators.c"
+ DEPS
+ ::caching_allocator
+ ::debug_allocator
+ iree::base
+ iree::base::tracing
+ iree::hal
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
buffer_transfer
HDRS
"buffer_transfer.h"
diff --git a/runtime/src/iree/hal/utils/allocators.c b/runtime/src/iree/hal/utils/allocators.c
new file mode 100644
index 0000000..98c406e
--- /dev/null
+++ b/runtime/src/iree/hal/utils/allocators.c
@@ -0,0 +1,69 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/utils/allocators.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/utils/caching_allocator.h"
+#include "iree/hal/utils/debug_allocator.h"
+
+iree_status_t iree_hal_configure_allocator_from_spec(
+ iree_string_view_t spec, iree_hal_device_t* device,
+ iree_hal_allocator_t* base_allocator,
+ iree_hal_allocator_t** out_wrapped_allocator) {
+ iree_string_view_t allocator_name = iree_string_view_empty();
+ iree_string_view_t config_pairs = iree_string_view_empty();
+ iree_string_view_split(spec, ':', &allocator_name, &config_pairs);
+ iree_status_t status = iree_ok_status();
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(base_allocator);
+ if (iree_string_view_equal(allocator_name, IREE_SV("caching"))) {
+ status = iree_hal_caching_allocator_create_from_spec(
+ config_pairs, base_allocator, host_allocator, out_wrapped_allocator);
+ } else if (iree_string_view_equal(allocator_name, IREE_SV("debug"))) {
+ status = iree_hal_debug_allocator_create(
+ device, base_allocator, host_allocator, out_wrapped_allocator);
+ } else {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unrecognized allocator '%.*s'",
+ (int)allocator_name.size, allocator_name.data);
+ }
+ if (iree_status_is_ok(status)) {
+ // New wrapping allocator has taken ownership of the base allocator.
+ iree_hal_allocator_release(base_allocator);
+ }
+ return status;
+}
+
+iree_status_t iree_hal_configure_allocator_from_specs(
+ iree_host_size_t spec_count, const iree_string_view_t* specs,
+ iree_hal_device_t* device) {
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // The current device allocator should be the base one registered or created
+ // with the device. If no allocator specs were provided this may be no-op and
+ // we'll just pass it right back in.
+ iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
+ iree_hal_allocator_retain(device_allocator);
+
+ // Walk the specs provided and wrap in order from base to last specified.
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < spec_count; ++i) {
+ status = iree_hal_configure_allocator_from_spec(
+ specs[i], device, device_allocator, &device_allocator);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ // Swap the allocator on the device - this is only safe because we know no
+ // allocations have been made yet.
+ if (iree_status_is_ok(status)) {
+ iree_hal_device_replace_allocator(device, device_allocator);
+ }
+ iree_hal_allocator_release(device_allocator);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
diff --git a/runtime/src/iree/hal/utils/allocators.h b/runtime/src/iree/hal/utils/allocators.h
new file mode 100644
index 0000000..faa5efd
--- /dev/null
+++ b/runtime/src/iree/hal/utils/allocators.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_UTILS_ALLOCATORS_H_
+#define IREE_HAL_UTILS_ALLOCATORS_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// WARNING: including this file will pull in all allocator implementations.
+// Only use this if you need the dynamic allocator configuration and otherwise
+// prefer to directly instantiate the allocators you want with their structured
+// options instead of strings.
+
+// Parses a single allocator specification value and wraps |base_allocator|.
+// The available allocators is based on the build configuration.
+//
+// Examples:
+// some_allocator
+// some_allocator:key=value
+// some_allocator:key=value,key=value
+iree_status_t iree_hal_configure_allocator_from_spec(
+ iree_string_view_t spec, iree_hal_device_t* device,
+ iree_hal_allocator_t* base_allocator,
+ iree_hal_allocator_t** out_wrapped_allocator);
+
+// Configures a |device| allocator based on the allocator |specs|.
+// This will wrap the underlying device allocator in zero or more configurable
+// allocator shims.
+//
+// WARNING: not thread-safe and must only be called immediately after device
+// creation.
+iree_status_t iree_hal_configure_allocator_from_specs(
+ iree_host_size_t spec_count, const iree_string_view_t* specs,
+ iree_hal_device_t* device);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_UTILS_ALLOCATORS_H_
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index e6ca5f8..423aa24 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -103,8 +103,7 @@
"//runtime/src/iree/base/internal:synchronization",
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/drivers",
- "//runtime/src/iree/hal/utils:caching_allocator",
- "//runtime/src/iree/hal/utils:debug_allocator",
+ "//runtime/src/iree/hal/utils:allocators",
],
)
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index a2735e4..2e4a9e3 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -115,8 +115,7 @@
iree::base::tracing
iree::hal
iree::hal::drivers
- iree::hal::utils::caching_allocator
- iree::hal::utils::debug_allocator
+ iree::hal::utils::allocators
PUBLIC
)
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 9c1492f..4d0017f 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -10,8 +10,7 @@
#include "iree/base/internal/flags.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/init.h"
-#include "iree/hal/utils/caching_allocator.h"
-#include "iree/hal/utils/debug_allocator.h"
+#include "iree/hal/utils/allocators.h"
//===----------------------------------------------------------------------===//
// Shared driver registry
@@ -291,40 +290,6 @@
"Specifies one or more HAL device allocator specs to augment the base\n"
"device allocator. See each allocator type for supported configurations.");
-// Parses a single flag and wraps |base_allocator|.
-// Flag values are specifications and may include configuration values.
-// Examples:
-// some_allocator
-// some_allocator:key=value
-// some_allocator:key=value,key=value
-static iree_status_t iree_hal_configure_allocator_from_spec(
- iree_string_view_t spec, iree_hal_device_t* device,
- iree_hal_allocator_t* base_allocator,
- iree_hal_allocator_t** out_wrapped_allocator) {
- iree_string_view_t allocator_name = iree_string_view_empty();
- iree_string_view_t config_pairs = iree_string_view_empty();
- iree_string_view_split(spec, ':', &allocator_name, &config_pairs);
- iree_status_t status = iree_ok_status();
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(base_allocator);
- if (iree_string_view_equal(allocator_name, IREE_SV("caching"))) {
- status = iree_hal_caching_allocator_create_from_spec(
- config_pairs, base_allocator, host_allocator, out_wrapped_allocator);
- } else if (iree_string_view_equal(allocator_name, IREE_SV("debug"))) {
- status = iree_hal_debug_allocator_create(
- device, base_allocator, host_allocator, out_wrapped_allocator);
- } else {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "unrecognized allocator '%.*s'",
- (int)allocator_name.size, allocator_name.data);
- }
- if (iree_status_is_ok(status)) {
- // New wrapping allocator has taken ownership of the base allocator.
- iree_hal_allocator_release(base_allocator);
- }
- return status;
-}
-
// Configures the |device| allocator based on the --device_allocator= flag.
// This will wrap the underlying device allocator in zero or more configurable
// allocator shims.
@@ -333,33 +298,9 @@
// creation.
static iree_status_t iree_hal_configure_allocator_from_flags(
iree_hal_device_t* device) {
- IREE_ASSERT_ARGUMENT(device);
- IREE_TRACE_ZONE_BEGIN(z0);
-
const iree_flag_string_list_t list = FLAG_device_allocator_list();
-
- // The current device allocator should be the base one registered or created
- // with the device. If no allocator flags were provided this may be no-op and
- // we'll just pass it right back in.
- iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
- iree_hal_allocator_retain(device_allocator);
-
- // Walk the specs provided and wrap in order from base to last specified.
- iree_status_t status = iree_ok_status();
- for (iree_host_size_t i = 0; i < list.count; ++i) {
- status = iree_hal_configure_allocator_from_spec(
- list.values[i], device, device_allocator, &device_allocator);
- if (!iree_status_is_ok(status)) break;
- }
-
- // Swap the allocator on the device - this is only safe because we know no
- // allocations have been made yet.
- if (iree_status_is_ok(status)) {
- iree_hal_device_replace_allocator(device, device_allocator);
- }
- iree_hal_allocator_release(device_allocator);
- IREE_TRACE_ZONE_END(z0);
- return status;
+ return iree_hal_configure_allocator_from_specs(list.count, list.values,
+ device);
}
//===----------------------------------------------------------------------===//