[runtime][python] Add multi-device HAL module construction (#17943)

The underlying C HAL function supports creating the HAL with multiple
devices. The Python API should support that as well.

---------

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index d102c00..61bd8a3 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,6 +6,11 @@
 
 #include "./hal.h"
 
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/vector.h>
+
+#include <optional>
+
 #include "./local_dlpack.h"
 #include "./numpy_interop.h"
 #include "./vm.h"
@@ -1066,12 +1071,34 @@
 // HAL module
 //------------------------------------------------------------------------------
 
-// TODO(multi-device): allow for multiple devices to be passed in.
-VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
-  iree_hal_device_t* device_ptr = device->raw_ptr();
+VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
+                         std::optional<py::list> devices) {
+  if (device && devices) {
+    PyErr_SetString(
+        PyExc_ValueError,
+        "\"device\" and \"devices\" are mutually exclusive arguments.");
+  }
+  std::vector<iree_hal_device_t*> devices_vector;
+  iree_hal_device_t* device_ptr;
+  iree_hal_device_t** devices_ptr;
+  iree_host_size_t device_count;
   iree_vm_module_t* module = NULL;
-  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1,
-                                        &device_ptr, IREE_HAL_MODULE_FLAG_NONE,
+  if (device) {
+    device_ptr = device.value()->raw_ptr();
+    devices_ptr = &device_ptr;
+    device_count = 1;
+  } else {
+    // Set device related arguments in the case of multiple devices.
+    devices_vector.reserve(devices->size());
+    for (auto devicesIt = devices->begin(); devicesIt != devices->end();
+         ++devicesIt) {
+      devices_vector.push_back(py::cast<HalDevice*>(*devicesIt)->raw_ptr());
+    }
+    devices_ptr = devices_vector.data();
+    device_count = devices_vector.size();
+  }
+  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
+                                        devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
                                         iree_allocator_system(), &module),
                  "Error creating hal module");
   return VmModule::StealFromRawPtr(module);
@@ -1085,7 +1112,8 @@
   py::dict driver_cache;
 
   // Built-in module creation.
-  m.def("create_hal_module", &CreateHalModule);
+  m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
+        py::arg("device") = py::none(), py::arg("devices") = py::none());
 
   // Enums.
   py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 803e10f..040b92f 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -4,7 +4,11 @@
 
 import asyncio
 
-def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
+def create_hal_module(
+    instance: VmInstance,
+    device: Optional[HalDevice] = None,
+    devices: Optional[List[HalDevice]] = None,
+) -> VmModule: ...
 def create_io_parameters_module(
     instance: VmInstance, *providers: ParameterProvider
 ) -> VmModule: ...
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 8aae0b7..bc65926 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -219,6 +219,15 @@
         logging.info("result: %s", result)
         np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0])
 
+    def test_create_vm_module_with_multiple_devices(self):
+        """Sanity test that we can create a VM module with 2 devices."""
+        devices = [
+            iree.runtime.get_device("local-task"),
+            iree.runtime.get_device("local-sync"),
+        ]
+        module = iree.runtime.create_hal_module(self.instance, devices=devices)
+        assert isinstance(module, iree.runtime.VmModule)
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)