In Python API create devices with default collectives channel provider (#14384)
In a similar manner as is done in iree-run-module, populate created
devices with a default channel provider.
Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index 4b5196f..4e447ad 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -86,6 +86,7 @@
iree::hal::drivers
iree::hal::utils::allocators
iree::modules::hal
+ iree::tooling::device_util
iree::tooling::modules
iree::vm
iree::vm::bytecode::module
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 747f009..e56d974 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -12,6 +12,7 @@
#include "iree/hal/api.h"
#include "iree/hal/utils/allocators.h"
#include "iree/modules/hal/module.h"
+#include "iree/tooling/device_util.h"
namespace iree {
namespace python {
@@ -605,6 +606,9 @@
IREE_RETURN_IF_ERROR(iree_hal_configure_allocator_from_specs(
spec_views.size(), spec_views.data(), device));
}
+
+ IREE_RETURN_IF_ERROR(iree_hal_device_set_default_channel_provider(device));
+
return iree_ok_status();
}
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index b714333..0a0798a 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -307,13 +307,7 @@
// Collectives configuration
//===----------------------------------------------------------------------===//
-// Configures the |device| channel provider based on the current environment.
-// Today this simply checks to see if the process is running under MPI and
-// initializes that unconditionally.
-//
-// WARNING: not thread-safe and must only be called immediately after device
-// creation.
-static iree_status_t iree_hal_configure_collectives_from_flags(
+iree_status_t iree_hal_device_set_default_channel_provider(
iree_hal_device_t* device) {
if (!iree_hal_mpi_is_configured()) return iree_ok_status();
iree_hal_channel_provider_t* channel_provider = NULL;
@@ -386,7 +380,7 @@
// their default channels. Hosting libraries or applications can do the same
// to interface with their own implementations.
if (iree_status_is_ok(status)) {
- status = iree_hal_configure_collectives_from_flags(device);
+ status = iree_hal_device_set_default_channel_provider(device);
}
if (iree_status_is_ok(status)) {
diff --git a/runtime/src/iree/tooling/device_util.h b/runtime/src/iree/tooling/device_util.h
index d290fad..68aede9 100644
--- a/runtime/src/iree/tooling/device_util.h
+++ b/runtime/src/iree/tooling/device_util.h
@@ -34,6 +34,15 @@
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
+// Configures the |device| channel provider based on the current environment.
+// Today this simply checks to see if the process is running under MPI and
+// initializes that unconditionally.
+//
+// WARNING: not thread-safe and must only be called immediately after device
+// creation.
+iree_status_t iree_hal_device_set_default_channel_provider(
+ iree_hal_device_t* device);
+
// Equivalent to iree_hal_device_profiling_begin with options sourced from
// command line flags. No-op if profiling is not enabled.
// Must be matched with a call to iree_hal_end_profiling_from_flags.