Centralizing HAL allocator wrapper creation and exposing to python. (#12356)

diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index b1e4b83..72ddffa 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -52,6 +52,7 @@
     iree::base::tracing
     iree::hal
     iree::hal::drivers
+    iree::hal::utils::allocators
     iree::modules::hal
     iree::vm
     iree::vm::bytecode_module
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index c11dfc2..110f121 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -10,6 +10,7 @@
 #include "iree/base/internal/path.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/api.h"
+#include "iree/hal/utils/allocators.h"
 #include "iree/modules/hal/module.h"
 #include "pybind11/numpy.h"
 
@@ -312,15 +313,48 @@
   return results;
 }
 
-HalDevice HalDriver::CreateDefaultDevice() {
+// Configures |device| based on flags before returning it to the user.
+static iree_status_t ConfigureDevice(iree_hal_device_t* device,
+                                     const py::kwargs& kwargs) {
+  // Optionally wrap the base device allocator with caching/pooling.
+  // Doing this here satisfies the requirement that no buffers have been
+  // allocated yet - if we returned the device without doing this the caller
+  // can more easily break the rules.
+  if (kwargs.contains("allocators")) {
+    // NOTE: we need to pass string views that point to the std::string storage.
+    // We do that in two passes because as we grow spec_storage it may
+    // reallocate itself and invalidate the pointers - only after we're done
+    // can we capture them in views.
+    auto spec_list = py::cast<py::list>(kwargs["allocators"]);
+    std::vector<std::string> spec_storage;
+    spec_storage.reserve(spec_list.size());
+    for (auto item : spec_list) {
+      auto spec = py::cast<std::string>(item);
+      spec_storage.push_back(std::move(spec));
+    }
+    std::vector<iree_string_view_t> spec_views;
+    spec_views.reserve(spec_list.size());
+    for (const auto& spec : spec_storage) {
+      spec_views.push_back(iree_make_string_view(spec.data(), spec.size()));
+    }
+    IREE_RETURN_IF_ERROR(iree_hal_configure_allocator_from_specs(
+        spec_views.size(), spec_views.data(), device));
+  }
+  return iree_ok_status();
+}
+
+HalDevice HalDriver::CreateDefaultDevice(const py::kwargs& kwargs) {
   iree_hal_device_t* device;
   CheckApiStatus(iree_hal_driver_create_default_device(
                      raw_ptr(), iree_allocator_system(), &device),
                  "Error creating default device");
+  CheckApiStatus(ConfigureDevice(device, kwargs),
+                 "Error configuring the device");
   return HalDevice::StealFromRawPtr(device);
 }
 
-HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id) {
+HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id,
+                                  const py::kwargs& kwargs) {
   // Since the device ids are supposed to be opaque, we need to verify
   // them by querying available devices.
   py::list available_devices = QueryAvailableDevices();
@@ -352,16 +386,21 @@
                      raw_ptr(), device_id, params.size(), &params.front(),
                      iree_allocator_system(), &device),
                  "Error creating default device");
+  CheckApiStatus(ConfigureDevice(device, kwargs),
+                 "Error configuring the device");
   return HalDevice::StealFromRawPtr(device);
 }
 
-HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri) {
+HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri,
+                                       const py::kwargs& kwargs) {
   iree_hal_device_t* device;
   iree_string_view_t device_uri_sv{device_uri.data(), device_uri.size()};
   CheckApiStatus(
       iree_hal_driver_create_device_by_uri(raw_ptr(), device_uri_sv,
                                            iree_allocator_system(), &device),
       "Error creating device");
+  CheckApiStatus(ConfigureDevice(device, kwargs),
+                 "Error configuring the device");
   return HalDevice::StealFromRawPtr(device);
 }
 
@@ -591,12 +630,13 @@
            py::keep_alive<0, 1>())
       .def(
           "create_device",
-          [](HalDriver& self, py::dict device_info) -> HalDevice {
+          [](HalDriver& self, py::dict device_info,
+             const py::kwargs& kwargs) -> HalDevice {
             // Alias of create_device that takes a dict as returned from
             // query_available_devices for convenience.
             auto device_id =
                 py::cast<iree_hal_device_id_t>(device_info["device_id"]);
-            return self.CreateDevice(device_id);
+            return self.CreateDevice(device_id, kwargs);
           },
           py::keep_alive<0, 1>())
       .def("query_available_devices", &HalDriver::QueryAvailableDevices);
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 64f0819..ff15dfa 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -97,9 +97,11 @@
                            py::dict& driver_cache);
 
   py::list QueryAvailableDevices();
-  HalDevice CreateDefaultDevice();
-  HalDevice CreateDevice(iree_hal_device_id_t device_id);
-  HalDevice CreateDeviceByURI(std::string& device_uri);
+  HalDevice CreateDefaultDevice(const py::kwargs& kwargs);
+  HalDevice CreateDevice(iree_hal_device_id_t device_id,
+                         const py::kwargs& kwargs);
+  HalDevice CreateDeviceByURI(std::string& device_uri,
+                              const py::kwargs& kwargs);
 };
 
 class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
diff --git a/runtime/bindings/python/tests/system_setup_test.py b/runtime/bindings/python/tests/system_setup_test.py
index d4707c3..3248b7c 100644
--- a/runtime/bindings/python/tests/system_setup_test.py
+++ b/runtime/bindings/python/tests/system_setup_test.py
@@ -61,6 +61,13 @@
     with self.assertRaises(ValueError, msg="Device not found: local-sync://1"):
       _ = ss.get_device("local-sync://1")
 
+  def testCreateDeviceWithAllocators(self):
+    driver = ss.get_driver("local-sync")
+    infos = driver.query_available_devices()
+    device1 = driver.create_device(infos[0]["device_id"], allocators=[])
+    device2 = driver.create_device(infos[0]["device_id"],
+                                   allocators=["caching", "debug"])
+
 
 if __name__ == "__main__":
   logging.basicConfig(level=logging.INFO)
diff --git a/runtime/src/iree/hal/utils/BUILD b/runtime/src/iree/hal/utils/BUILD
index 83bb230..78a4755 100644
--- a/runtime/src/iree/hal/utils/BUILD
+++ b/runtime/src/iree/hal/utils/BUILD
@@ -14,6 +14,20 @@
 )
 
 iree_runtime_cc_library(
+    name = "allocators",
+    srcs = ["allocators.c"],
+    hdrs = ["allocators.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":caching_allocator",
+        ":debug_allocator",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base:tracing",
+        "//runtime/src/iree/hal",
+    ],
+)
+
+iree_runtime_cc_library(
     name = "buffer_transfer",
     srcs = ["buffer_transfer.c"],
     hdrs = ["buffer_transfer.h"],
diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt
index 8199d25..3e84e67 100644
--- a/runtime/src/iree/hal/utils/CMakeLists.txt
+++ b/runtime/src/iree/hal/utils/CMakeLists.txt
@@ -12,6 +12,22 @@
 
 iree_cc_library(
   NAME
+    allocators
+  HDRS
+    "allocators.h"
+  SRCS
+    "allocators.c"
+  DEPS
+    ::caching_allocator
+    ::debug_allocator
+    iree::base
+    iree::base::tracing
+    iree::hal
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     buffer_transfer
   HDRS
     "buffer_transfer.h"
diff --git a/runtime/src/iree/hal/utils/allocators.c b/runtime/src/iree/hal/utils/allocators.c
new file mode 100644
index 0000000..98c406e
--- /dev/null
+++ b/runtime/src/iree/hal/utils/allocators.c
@@ -0,0 +1,69 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/utils/allocators.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/utils/caching_allocator.h"
+#include "iree/hal/utils/debug_allocator.h"
+
+iree_status_t iree_hal_configure_allocator_from_spec(
+    iree_string_view_t spec, iree_hal_device_t* device,
+    iree_hal_allocator_t* base_allocator,
+    iree_hal_allocator_t** out_wrapped_allocator) {
+  iree_string_view_t allocator_name = iree_string_view_empty();
+  iree_string_view_t config_pairs = iree_string_view_empty();
+  iree_string_view_split(spec, ':', &allocator_name, &config_pairs);
+  iree_status_t status = iree_ok_status();
+  iree_allocator_t host_allocator =
+      iree_hal_allocator_host_allocator(base_allocator);
+  if (iree_string_view_equal(allocator_name, IREE_SV("caching"))) {
+    status = iree_hal_caching_allocator_create_from_spec(
+        config_pairs, base_allocator, host_allocator, out_wrapped_allocator);
+  } else if (iree_string_view_equal(allocator_name, IREE_SV("debug"))) {
+    status = iree_hal_debug_allocator_create(
+        device, base_allocator, host_allocator, out_wrapped_allocator);
+  } else {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "unrecognized allocator '%.*s'",
+                            (int)allocator_name.size, allocator_name.data);
+  }
+  if (iree_status_is_ok(status)) {
+    // New wrapping allocator has taken ownership of the base allocator.
+    iree_hal_allocator_release(base_allocator);
+  }
+  return status;
+}
+
+iree_status_t iree_hal_configure_allocator_from_specs(
+    iree_host_size_t spec_count, const iree_string_view_t* specs,
+    iree_hal_device_t* device) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The current device allocator should be the base one registered or created
+  // with the device. If no allocator specs were provided this may be no-op and
+  // we'll just pass it right back in.
+  iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
+  iree_hal_allocator_retain(device_allocator);
+
+  // Walk the specs provided and wrap in order from base to last specified.
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < spec_count; ++i) {
+    status = iree_hal_configure_allocator_from_spec(
+        specs[i], device, device_allocator, &device_allocator);
+    if (!iree_status_is_ok(status)) break;
+  }
+
+  // Swap the allocator on the device - this is only safe because we know no
+  // allocations have been made yet.
+  if (iree_status_is_ok(status)) {
+    iree_hal_device_replace_allocator(device, device_allocator);
+  }
+  iree_hal_allocator_release(device_allocator);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/hal/utils/allocators.h b/runtime/src/iree/hal/utils/allocators.h
new file mode 100644
index 0000000..faa5efd
--- /dev/null
+++ b/runtime/src/iree/hal/utils/allocators.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_UTILS_ALLOCATORS_H_
+#define IREE_HAL_UTILS_ALLOCATORS_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// WARNING: including this file will pull in all allocator implementations.
+// Only use this if you need the dynamic allocator configuration and otherwise
+// prefer to directly instantiate the allocators you want with their structured
+// options instead of strings.
+
+// Parses a single allocator specification value and wraps |base_allocator|.
+// The available allocators is based on the build configuration.
+//
+// Examples:
+//   some_allocator
+//   some_allocator:key=value
+//   some_allocator:key=value,key=value
+iree_status_t iree_hal_configure_allocator_from_spec(
+    iree_string_view_t spec, iree_hal_device_t* device,
+    iree_hal_allocator_t* base_allocator,
+    iree_hal_allocator_t** out_wrapped_allocator);
+
+// Configures a |device| allocator based on the allocator |specs|.
+// This will wrap the underlying device allocator in zero or more configurable
+// allocator shims.
+//
+// WARNING: not thread-safe and must only be called immediately after device
+// creation.
+iree_status_t iree_hal_configure_allocator_from_specs(
+    iree_host_size_t spec_count, const iree_string_view_t* specs,
+    iree_hal_device_t* device);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_UTILS_ALLOCATORS_H_
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index e6ca5f8..423aa24 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -103,8 +103,7 @@
         "//runtime/src/iree/base/internal:synchronization",
         "//runtime/src/iree/hal",
         "//runtime/src/iree/hal/drivers",
-        "//runtime/src/iree/hal/utils:caching_allocator",
-        "//runtime/src/iree/hal/utils:debug_allocator",
+        "//runtime/src/iree/hal/utils:allocators",
     ],
 )
 
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index a2735e4..2e4a9e3 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -115,8 +115,7 @@
     iree::base::tracing
     iree::hal
     iree::hal::drivers
-    iree::hal::utils::caching_allocator
-    iree::hal::utils::debug_allocator
+    iree::hal::utils::allocators
   PUBLIC
 )
 
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 9c1492f..4d0017f 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -10,8 +10,7 @@
 #include "iree/base/internal/flags.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/drivers/init.h"
-#include "iree/hal/utils/caching_allocator.h"
-#include "iree/hal/utils/debug_allocator.h"
+#include "iree/hal/utils/allocators.h"
 
 //===----------------------------------------------------------------------===//
 // Shared driver registry
@@ -291,40 +290,6 @@
     "Specifies one or more HAL device allocator specs to augment the base\n"
     "device allocator. See each allocator type for supported configurations.");
 
-// Parses a single flag and wraps |base_allocator|.
-// Flag values are specifications and may include configuration values.
-// Examples:
-//   some_allocator
-//   some_allocator:key=value
-//   some_allocator:key=value,key=value
-static iree_status_t iree_hal_configure_allocator_from_spec(
-    iree_string_view_t spec, iree_hal_device_t* device,
-    iree_hal_allocator_t* base_allocator,
-    iree_hal_allocator_t** out_wrapped_allocator) {
-  iree_string_view_t allocator_name = iree_string_view_empty();
-  iree_string_view_t config_pairs = iree_string_view_empty();
-  iree_string_view_split(spec, ':', &allocator_name, &config_pairs);
-  iree_status_t status = iree_ok_status();
-  iree_allocator_t host_allocator =
-      iree_hal_allocator_host_allocator(base_allocator);
-  if (iree_string_view_equal(allocator_name, IREE_SV("caching"))) {
-    status = iree_hal_caching_allocator_create_from_spec(
-        config_pairs, base_allocator, host_allocator, out_wrapped_allocator);
-  } else if (iree_string_view_equal(allocator_name, IREE_SV("debug"))) {
-    status = iree_hal_debug_allocator_create(
-        device, base_allocator, host_allocator, out_wrapped_allocator);
-  } else {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "unrecognized allocator '%.*s'",
-                            (int)allocator_name.size, allocator_name.data);
-  }
-  if (iree_status_is_ok(status)) {
-    // New wrapping allocator has taken ownership of the base allocator.
-    iree_hal_allocator_release(base_allocator);
-  }
-  return status;
-}
-
 // Configures the |device| allocator based on the --device_allocator= flag.
 // This will wrap the underlying device allocator in zero or more configurable
 // allocator shims.
@@ -333,33 +298,9 @@
 // creation.
 static iree_status_t iree_hal_configure_allocator_from_flags(
     iree_hal_device_t* device) {
-  IREE_ASSERT_ARGUMENT(device);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
   const iree_flag_string_list_t list = FLAG_device_allocator_list();
-
-  // The current device allocator should be the base one registered or created
-  // with the device. If no allocator flags were provided this may be no-op and
-  // we'll just pass it right back in.
-  iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
-  iree_hal_allocator_retain(device_allocator);
-
-  // Walk the specs provided and wrap in order from base to last specified.
-  iree_status_t status = iree_ok_status();
-  for (iree_host_size_t i = 0; i < list.count; ++i) {
-    status = iree_hal_configure_allocator_from_spec(
-        list.values[i], device, device_allocator, &device_allocator);
-    if (!iree_status_is_ok(status)) break;
-  }
-
-  // Swap the allocator on the device - this is only safe because we know no
-  // allocations have been made yet.
-  if (iree_status_is_ok(status)) {
-    iree_hal_device_replace_allocator(device, device_allocator);
-  }
-  iree_hal_allocator_release(device_allocator);
-  IREE_TRACE_ZONE_END(z0);
-  return status;
+  return iree_hal_configure_allocator_from_specs(list.count, list.values,
+                                                 device);
 }
 
 //===----------------------------------------------------------------------===//