Make VM modules take an instance argument.
This makes it possible to use instance-specific information, such as
registered types, during module construction.
Progress on #8698.
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index bdd4929..3166fa2 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -387,12 +387,12 @@
 // HAL module
 //------------------------------------------------------------------------------
 
-VmModule CreateHalModule(HalDevice* device) {
-  iree_vm_module_t* module;
-  CheckApiStatus(
-      iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
-                             iree_allocator_system(), &module),
-      "Error creating hal module");
+VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
+  iree_vm_module_t* module = NULL;
+  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device->raw_ptr(),
+                                        IREE_HAL_MODULE_FLAG_NONE,
+                                        iree_allocator_system(), &module),
+                 "Error creating hal module");
   return VmModule::StealFromRawPtr(module);
 }
 
@@ -403,9 +403,6 @@
 void SetupHalBindings(pybind11::module m) {
   py::dict driver_cache;
 
-  // TODO(#8698): need to register these on an instance.
-  IREE_CHECK_OK(iree_hal_module_register_all_types(NULL));
-
   // Built-in module creation.
   m.def("create_hal_module", &CreateHalModule);
 
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py
index 5479224..025b4b5 100644
--- a/runtime/bindings/python/iree/runtime/system_api.py
+++ b/runtime/bindings/python/iree/runtime/system_api.py
@@ -72,7 +72,7 @@
           driver_name.split(",") if driver_name is not None else None)
 
     self.vm_instance = _binding.VmInstance()
-    hal_module = _binding.create_hal_module(self.device)
+    hal_module = _binding.create_hal_module(self.vm_instance, self.device)
     self.default_vm_modules = (hal_module,)
     self.tracer = tracer or tracing.get_default_tracer()
     if self.tracer and self.tracer.enabled:
@@ -283,8 +283,10 @@
                      "the driver from.")
   if backend is not None:
     driver = TARGET_BACKEND_TO_DRIVER[backend]
-  vm_module = _binding.VmModule.from_flatbuffer(vm_flatbuffer)
-  bound_module = load_vm_module(vm_module, Config(driver))
+  config = Config(driver)
+  vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance,
+                                                vm_flatbuffer)
+  bound_module = load_vm_module(vm_module, config)
   return bound_module
 
 
diff --git a/runtime/bindings/python/tests/system_api_test.py b/runtime/bindings/python/tests/system_api_test.py
index ccb0dbf..227df92 100644
--- a/runtime/bindings/python/tests/system_api_test.py
+++ b/runtime/bindings/python/tests/system_api_test.py
@@ -17,7 +17,7 @@
 import numpy as np
 
 
-def create_simple_mul_module():
+def create_simple_mul_module(instance):
   binary = iree.compiler.compile_str(
       """
       module @arithmetic {
@@ -29,7 +29,7 @@
       """,
       target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
   )
-  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
   return m
 
 
@@ -62,7 +62,7 @@
   def test_custom_dynamic(self):
     ctx = iree.runtime.SystemContext()
     self.assertTrue(ctx.is_dynamic)
-    ctx.add_vm_module(create_simple_mul_module())
+    ctx.add_vm_module(create_simple_mul_module(ctx.instance))
     self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
     f = ctx.modules.arithmetic["simple_mul"]
     f_repr = repr(f)
@@ -72,14 +72,14 @@
   def test_duplicate_module(self):
     ctx = iree.runtime.SystemContext()
     self.assertTrue(ctx.is_dynamic)
-    ctx.add_vm_module(create_simple_mul_module())
+    ctx.add_vm_module(create_simple_mul_module(ctx.instance))
     with self.assertRaisesRegex(ValueError, "arithmetic"):
-      ctx.add_vm_module(create_simple_mul_module())
+      ctx.add_vm_module(create_simple_mul_module(ctx.instance))
 
   def test_static_invoke(self):
     ctx = iree.runtime.SystemContext()
     self.assertTrue(ctx.is_dynamic)
-    ctx.add_vm_module(create_simple_mul_module())
+    ctx.add_vm_module(create_simple_mul_module(ctx.instance))
     self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
     f = ctx.modules.arithmetic["simple_mul"]
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -92,7 +92,7 @@
     # and input to functions.
     ctx = iree.runtime.SystemContext()
     self.assertTrue(ctx.is_dynamic)
-    ctx.add_vm_module(create_simple_mul_module())
+    ctx.add_vm_module(create_simple_mul_module(ctx.instance))
     self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
     f = ctx.modules.arithmetic["simple_mul"]
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -121,7 +121,7 @@
   def verify_tracing(self, config, temp_dir):
     logging.info("Tracing test to: %s", temp_dir)
     ctx = iree.runtime.SystemContext(config=config)
-    ctx.add_vm_module(create_simple_mul_module())
+    ctx.add_vm_module(create_simple_mul_module(ctx.instance))
     f = ctx.modules.arithmetic["simple_mul"]
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
@@ -131,7 +131,9 @@
     # TODO: Once replay is possible, verify that.
 
   def test_load_vm_module(self):
-    arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
+    ctx = iree.runtime.SystemContext()
+    arithmetic = iree.runtime.load_vm_module(
+        create_simple_mul_module(ctx.instance))
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
     results = arithmetic.simple_mul(arg0, arg1)
@@ -142,7 +144,8 @@
     # Doing default device configuration multiple times should be valid
     # (if this were instantiating drivers multiple times, it can trigger
     # a crash, depending on whether the driver supports multi-instantiation).
-    m = create_simple_mul_module()
+    ctx = iree.runtime.SystemContext()
+    m = create_simple_mul_module(ctx.instance)
     m1 = iree.runtime.load_vm_module(m)
     m2 = iree.runtime.load_vm_module(m)
 
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 98573b0..1e7d307 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -14,7 +14,7 @@
 import iree.runtime
 
 
-def create_add_scalar_module():
+def create_add_scalar_module(instance):
   binary = iree.compiler.compile_str(
       """
       func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
@@ -24,11 +24,11 @@
       """,
       target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
   )
-  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
   return m
 
 
-def create_simple_static_mul_module():
+def create_simple_static_mul_module(instance):
   binary = iree.compiler.compile_str(
       """
       func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@@ -38,11 +38,11 @@
       """,
       target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
   )
-  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
   return m
 
 
-def create_simple_dynamic_abs_module():
+def create_simple_dynamic_abs_module(instance):
   binary = iree.compiler.compile_str(
       """
       func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -52,7 +52,7 @@
       """,
       target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
   )
-  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
   return m
 
 
@@ -60,49 +60,46 @@
 
   @classmethod
   def setUp(self):
+    self.instance = iree.runtime.VmInstance()
     self.device = iree.runtime.get_device(
         iree.compiler.core.DEFAULT_TESTING_DRIVER)
-    self.hal_module = iree.runtime.create_hal_module(self.device)
+    self.hal_module = iree.runtime.create_hal_module(self.instance, self.device)
 
   def test_context_id(self):
-    instance = iree.runtime.VmInstance()
-    context1 = iree.runtime.VmContext(instance)
-    context2 = iree.runtime.VmContext(instance)
+    context1 = iree.runtime.VmContext(self.instance)
+    context2 = iree.runtime.VmContext(self.instance)
     self.assertNotEqual(context2.context_id, context1.context_id)
 
   def test_module_basics(self):
-    m = create_simple_static_mul_module()
+    m = create_simple_static_mul_module(self.instance)
     f = m.lookup_function("simple_mul")
     self.assertGreaterEqual(f.ordinal, 0)
     notfound = m.lookup_function("notfound")
     self.assertIs(notfound, None)
 
   def test_dynamic_module_context(self):
-    instance = iree.runtime.VmInstance()
-    context = iree.runtime.VmContext(instance)
-    m = create_simple_static_mul_module()
+    context = iree.runtime.VmContext(self.instance)
+    m = create_simple_static_mul_module(self.instance)
     context.register_modules([self.hal_module, m])
 
   def test_static_module_context(self):
-    m = create_simple_static_mul_module()
+    m = create_simple_static_mul_module(self.instance)
     logging.info("module: %s", m)
-    instance = iree.runtime.VmInstance()
-    logging.info("instance: %s", instance)
-    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    context = iree.runtime.VmContext(self.instance,
+                                     modules=[self.hal_module, m])
     logging.info("context: %s", context)
 
   def test_dynamic_shape_compile(self):
-    m = create_simple_dynamic_abs_module()
+    m = create_simple_dynamic_abs_module(self.instance)
     logging.info("module: %s", m)
-    instance = iree.runtime.VmInstance()
-    logging.info("instance: %s", instance)
-    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    context = iree.runtime.VmContext(self.instance,
+                                     modules=[self.hal_module, m])
     logging.info("context: %s", context)
 
   def test_add_scalar_new_abi(self):
-    m = create_add_scalar_module()
-    instance = iree.runtime.VmInstance()
-    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    m = create_add_scalar_module(self.instance)
+    context = iree.runtime.VmContext(self.instance,
+                                     modules=[self.hal_module, m])
     f = m.lookup_function("add_scalar")
     finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     result = finv(5, 6)
@@ -110,9 +107,9 @@
     self.assertEqual(result, 11)
 
   def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
-    m = create_simple_dynamic_abs_module()
-    instance = iree.runtime.VmInstance()
-    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    m = create_simple_dynamic_abs_module(self.instance)
+    context = iree.runtime.VmContext(self.instance,
+                                     modules=[self.hal_module, m])
     f = m.lookup_function("dynamic_abs")
     finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
@@ -121,9 +118,9 @@
     np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
 
   def test_synchronous_invoke_function_new_abi(self):
-    m = create_simple_static_mul_module()
-    instance = iree.runtime.VmInstance()
-    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    m = create_simple_static_mul_module(self.instance)
+    context = iree.runtime.VmContext(self.instance,
+                                     modules=[self.hal_module, m])
     f = m.lookup_function("simple_mul")
     finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 88fd9ff..891f647 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -13,6 +13,11 @@
 
 class VmTypesTest(unittest.TestCase):
 
+  @classmethod
+  def setUp(self):
+    # Ensures types are registered.
+    self.instance = rt.VmInstance()
+
   def testRefProtocol(self):
     lst1 = rt.VmVariantList(0)
     ref = lst1.__iree_vm_ref__
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index cf01457..027a8c3 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -58,9 +58,17 @@
 
 VmInstance VmInstance::Create() {
   IREE_TRACE_SCOPE0("VmInstance::Create");
-  iree_vm_instance_t* instance;
+
+  iree_vm_instance_t* instance = NULL;
   auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
   CheckApiStatus(status, "Error creating instance");
+
+  // The python bindings assume the HAL is always available for use.
+  // We register the types here so modules can be loaded using the HAL types
+  // in any order.
+  CheckApiStatus(iree_hal_module_register_all_types(instance),
+                 "registering HAL types");
+
   return VmInstance::StealFromRawPtr(instance);
 }
 
@@ -122,11 +130,12 @@
 // VmModule
 //------------------------------------------------------------------------------
 
-VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
+                                      py::object flatbuffer_blob_object) {
   IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
   auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
   auto buffer_info = flatbuffer_blob.request();
-  iree_vm_module_t* module;
+  iree_vm_module_t* module = nullptr;
 
   // Bridge to the C-based deallocator API.
   auto* raw_ptr = flatbuffer_blob.ptr();
@@ -141,6 +150,7 @@
   iree_allocator_t deallocator{/*self=*/NULL, /*ctl=*/ctl_fn};
 
   auto status = iree_vm_bytecode_module_create(
+      instance->raw_ptr(),
       {static_cast<const uint8_t*>(buffer_info.ptr),
        static_cast<iree_host_size_t>(buffer_info.size)},
       deallocator, iree_allocator_system(), &module);
@@ -511,11 +521,6 @@
 }
 
 void SetupVmBindings(pybind11::module m) {
-  // TODO(#8698): need to register these on an instance.
-  // The instance constructor does this for us and if we created it first we
-  // wouldn't need to call this.
-  IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
-
   py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
       .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
       .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 37443cc..202e221 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -129,7 +129,8 @@
 
 class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
  public:
-  static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
+  static VmModule FromFlatbufferBlob(VmInstance* instance,
+                                     py::object flatbuffer_blob_object);
 
   std::optional<iree_vm_function_t> LookupFunction(
       const std::string& name, iree_vm_function_linkage_t linkage);