Add Python binding plumbing to enable loading modules with the new VM2.

* Adds just enough of the new HAL types to get a default device.
* Restricted to Vulkan for the moment.
* Verifies that the VM2-compiled simple_mul module loads with the new HAL module.
* Disables vm_test on TAP due to missing Vulkan deps on test hosts.

PiperOrigin-RevId: 286590561
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index e3d710a..d95dd48 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -127,6 +127,7 @@
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
         "//iree/compiler/Utils",
+        "//iree/modules/hal",
         "//iree/vm2",
         "//iree/vm2:bytecode_module",
         "@local_config_mlir//:IR",
@@ -210,6 +211,8 @@
     name = "vm_test",
     srcs = ["vm_test.py"],
     python_version = "PY3",
+    # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+    tags = ["notap"],
     deps = NUMPY_DEPS + [
         "//bindings/python:pathsetup",  # build_cleaner: keep
         "@absl_py//absl/testing:absltest",
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc
index 413224a..a78f46b 100644
--- a/bindings/python/pyiree/hal.cc
+++ b/bindings/python/pyiree/hal.cc
@@ -86,6 +86,43 @@
 
 }  // namespace
 
+//------------------------------------------------------------------------------
+// HalDriver
+//------------------------------------------------------------------------------
+
+std::vector<std::string> HalDriver::Query() {
+  iree_string_view_t* names;
+  iree_host_size_t count;
+  auto status = iree_hal_driver_registry_query_available_drivers(
+      IREE_ALLOCATOR_SYSTEM, &names, &count);
+  CheckApiStatus(status, "Error querying drivers");
+  // Copy every name into an owned std::string before freeing the C array.
+  std::vector<std::string> result;
+  result.reserve(count);
+  for (iree_host_size_t idx = 0; idx < count; ++idx) {
+    result.emplace_back(names[idx].data, names[idx].size);
+  }
+  free(names);
+  return result;
+}
+
+HalDriver HalDriver::Create(const std::string& driver_name) {
+  iree_hal_driver_t* raw_driver;  // Out-param for the registry call below.
+  auto status = iree_hal_driver_registry_create_driver(
+      {driver_name.data(), driver_name.size()}, IREE_ALLOCATOR_SYSTEM,
+      &raw_driver);
+  CheckApiStatus(status, "Error creating driver");
+  return HalDriver::CreateRetained(raw_driver);
+}
+
+HalDevice HalDriver::CreateDefaultDevice() {
+  iree_hal_device_t* raw_device;  // Out-param for the driver call below.
+  auto status = iree_hal_driver_create_default_device(
+      raw_ptr(), IREE_ALLOCATOR_SYSTEM, &raw_device);
+  CheckApiStatus(status, "Error creating default device");
+  return HalDevice::CreateRetained(raw_device);
+}
+
 void SetupHalBindings(pybind11::module m) {
   // Enums.
   py::enum_<iree_hal_memory_type_t>(m, "MemoryType")
@@ -115,6 +152,12 @@
       .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
       .export_values();
 
+  py::class_<HalDevice>(m, "HalDevice");
+  py::class_<HalDriver>(m, "HalDriver")
+      .def_static("query", &HalDriver::Query)
+      .def_static("create", &HalDriver::Create, py::arg("driver_name"))
+      .def("create_default_device", &HalDriver::CreateDefaultDevice);
+
   py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
   py::class_<HalBufferView>(m, "BufferView")
       .def("map", HalMappedMemory::Create);
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h
index 38e2c3b..2f7cfdb 100644
--- a/bindings/python/pyiree/hal.h
+++ b/bindings/python/pyiree/hal.h
@@ -22,6 +22,22 @@
 namespace iree {
 namespace python {
 
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+template <>
+struct ApiPtrAdapter<iree_hal_driver_t> {  // Retain/release for HAL drivers.
+  static void Retain(iree_hal_driver_t* drv) { iree_hal_driver_retain(drv); }
+  static void Release(iree_hal_driver_t* drv) { iree_hal_driver_release(drv); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_device_t> {  // Retain/release for HAL devices.
+  static void Retain(iree_hal_device_t* dev) { iree_hal_device_retain(dev); }
+  static void Release(iree_hal_device_t* dev) { iree_hal_device_release(dev); }
+};
+
 template <>
 struct ApiPtrAdapter<iree_hal_buffer_t> {
   static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); }
@@ -38,6 +54,22 @@
   }
 };
 
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
+ public:  // Adds no members beyond the ApiRefCounted base.
+};
+
+class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
+ public:
+  static std::vector<std::string> Query();  // Names of available drivers.
+  static HalDriver Create(const std::string& driver_name);  // Looked up by name.
+  // Asks this driver for its default device and wraps the result.
+  HalDevice CreateDefaultDevice();
+};
+
 struct HalShape {
  public:
   static HalShape FromIntVector(std::vector<int32_t> indices) {
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index 72af468..c2e59fc 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -17,10 +17,13 @@
 #include "absl/types/optional.h"
 #include "bindings/python/pyiree/status_utils.h"
 #include "iree/base/api.h"
+#include "iree/modules/hal/hal_module.h"
 
 namespace iree {
 namespace python {
 
+namespace {
+
 RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) {
   iree_rt_module_t* module;
   auto free_fn = OpaqueBlob::CreateFreeFn(blob);
@@ -31,6 +34,16 @@
   return RtModule::CreateRetained(module);
 }
 
+VmModule CreateHalModule(HalDevice* device) {
+  iree_vm_module_t* raw_module;  // Out-param for the created HAL module.
+  auto status = iree_hal_module_create(device->raw_ptr(),  // NOTE: no null check
+                                       IREE_ALLOCATOR_SYSTEM, &raw_module);
+  CheckApiStatus(status, "Error creating hal module");
+  return VmModule::CreateRetained(raw_module);
+}
+
+}  // namespace
+
 //------------------------------------------------------------------------------
 // VmInstance
 //------------------------------------------------------------------------------
@@ -115,9 +128,15 @@
 }
 
 void SetupVmBindings(pybind11::module m) {
+  CHECK_EQ(IREE_STATUS_OK, iree_vm_register_builtin_types());
+  CHECK_EQ(IREE_STATUS_OK, iree_hal_module_register_types());
+
   // Deprecated: VM1 module.
   m.def("create_module_from_blob", CreateModuleFromBlob);
 
+  // Built-in module creation.
+  m.def("create_hal_module", &CreateHalModule);
+
   py::enum_<iree_vm_function_linkage_t>(m, "Linkage")
       .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
       .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/vm_test.py
index 338d64d..3302907 100644
--- a/bindings/python/pyiree/vm_test.py
+++ b/bindings/python/pyiree/vm_test.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 # pylint: disable=unused-variable
-# pylint: disable=g-unreachable-test-method
 
 from absl.testing import absltest
 import pyiree
@@ -36,6 +35,15 @@
 
 class VmTest(absltest.TestCase):
 
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    driver_names = pyiree.binding.hal.HalDriver.query()
+    print("DRIVER_NAMES =", driver_names)
+    cls.driver = pyiree.binding.hal.HalDriver.create("vulkan")
+    cls.device = cls.driver.create_default_device()
+    cls.hal_module = pyiree.binding.vm.create_hal_module(cls.device)
+
   def test_variant_list(self):
     l = pyiree.binding.vm.VmVariantList(5)
     print(l)
@@ -47,25 +55,26 @@
     context2 = pyiree.binding.vm.VmContext(instance)
     self.assertGreater(context2.context_id, context1.context_id)
 
-  def disabled_test_module_basics(self):
+  def test_module_basics(self):
     m = create_simple_mul_module()
     f = m.lookup_function("simple_mul")
     self.assertGreater(f.ordinal, 0)
     notfound = m.lookup_function("notfound")
     self.assertIs(notfound, None)
 
-  def disabled_test_dynamic_module_context(self):
+  def test_dynamic_module_context(self):
     instance = pyiree.binding.vm.VmInstance()
     context = pyiree.binding.vm.VmContext(instance)
     m = create_simple_mul_module()
-    context.register_modules([m])
+    context.register_modules([self.hal_module, m])
 
-  def disabled_test_static_module_context(self):
+  def test_static_module_context(self):
     m = create_simple_mul_module()
     print(m)
     instance = pyiree.binding.vm.VmInstance()
     print(instance)
-    context = pyiree.binding.vm.VmContext(instance, modules=[m])
+    context = pyiree.binding.vm.VmContext(
+        instance, modules=[self.hal_module, m])
     print(context)