[runtime][python] Add multi-device HAL module construction (#17943)
The underlying C HAL function supports creating the HAL with multiple
devices. The Python API should support that as well.
---------
Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index d102c00..61bd8a3 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,6 +6,11 @@
#include "./hal.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/vector.h>
+
+#include <optional>
+
#include "./local_dlpack.h"
#include "./numpy_interop.h"
#include "./vm.h"
@@ -1066,12 +1071,34 @@
// HAL module
//------------------------------------------------------------------------------
-// TODO(multi-device): allow for multiple devices to be passed in.
-VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
- iree_hal_device_t* device_ptr = device->raw_ptr();
+VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
+ std::optional<py::list> devices) {
+ if (device && devices) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ "\"device\" and \"devices\" are mutually exclusive arguments.");
+ }
+ std::vector<iree_hal_device_t*> devices_vector;
+ iree_hal_device_t* device_ptr;
+ iree_hal_device_t** devices_ptr;
+ iree_host_size_t device_count;
iree_vm_module_t* module = NULL;
- CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1,
- &device_ptr, IREE_HAL_MODULE_FLAG_NONE,
+ if (device) {
+ device_ptr = device.value()->raw_ptr();
+ devices_ptr = &device_ptr;
+ device_count = 1;
+ } else {
+ // Set device related arguments in the case of multiple devices.
+ devices_vector.reserve(devices->size());
+ for (auto devicesIt = devices->begin(); devicesIt != devices->end();
+ ++devicesIt) {
+ devices_vector.push_back(py::cast<HalDevice*>(*devicesIt)->raw_ptr());
+ }
+ devices_ptr = devices_vector.data();
+ device_count = devices_vector.size();
+ }
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
+ devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &module),
"Error creating hal module");
return VmModule::StealFromRawPtr(module);
@@ -1085,7 +1112,8 @@
py::dict driver_cache;
// Built-in module creation.
- m.def("create_hal_module", &CreateHalModule);
+ m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
+ py::arg("device") = py::none(), py::arg("devices") = py::none());
// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 803e10f..040b92f 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -4,7 +4,11 @@
import asyncio
-def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
+def create_hal_module(
+ instance: VmInstance,
+ device: Optional[HalDevice] = None,
+ devices: Optional[List[HalDevice]] = None,
+) -> VmModule: ...
def create_io_parameters_module(
instance: VmInstance, *providers: ParameterProvider
) -> VmModule: ...
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 8aae0b7..bc65926 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -219,6 +219,15 @@
logging.info("result: %s", result)
np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0])
+ def test_create_vm_module_with_multiple_devices(self):
+ """Sanity test that we can create a VM module with 2 devices."""
+ devices = [
+ iree.runtime.get_device("local-task"),
+ iree.runtime.get_device("local-sync"),
+ ]
+ module = iree.runtime.create_hal_module(self.instance, devices=devices)
+ assert isinstance(module, iree.runtime.VmModule)
+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)