Make VM modules take an instance argument.
This makes it possible to use instance-specific information, such as
registered types, during module construction.
Progress on #8698.
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index bdd4929..3166fa2 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -387,12 +387,12 @@
// HAL module
//------------------------------------------------------------------------------
-VmModule CreateHalModule(HalDevice* device) {
- iree_vm_module_t* module;
- CheckApiStatus(
- iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &module),
- "Error creating hal module");
+VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
+ iree_vm_module_t* module = nullptr;
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device->raw_ptr(),
+ IREE_HAL_MODULE_FLAG_NONE,
+ iree_allocator_system(), &module),
+ "Error creating hal module");
return VmModule::StealFromRawPtr(module);
}
@@ -403,9 +403,6 @@
void SetupHalBindings(pybind11::module m) {
py::dict driver_cache;
- // TODO(#8698): need to register these on an instance.
- IREE_CHECK_OK(iree_hal_module_register_all_types(NULL));
-
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule);
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py
index 5479224..025b4b5 100644
--- a/runtime/bindings/python/iree/runtime/system_api.py
+++ b/runtime/bindings/python/iree/runtime/system_api.py
@@ -72,7 +72,7 @@
driver_name.split(",") if driver_name is not None else None)
self.vm_instance = _binding.VmInstance()
- hal_module = _binding.create_hal_module(self.device)
+ hal_module = _binding.create_hal_module(self.vm_instance, self.device)
self.default_vm_modules = (hal_module,)
self.tracer = tracer or tracing.get_default_tracer()
if self.tracer and self.tracer.enabled:
@@ -283,8 +283,10 @@
"the driver from.")
if backend is not None:
driver = TARGET_BACKEND_TO_DRIVER[backend]
- vm_module = _binding.VmModule.from_flatbuffer(vm_flatbuffer)
- bound_module = load_vm_module(vm_module, Config(driver))
+ config = Config(driver)
+ vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance,
+ vm_flatbuffer)
+ bound_module = load_vm_module(vm_module, config)
return bound_module
diff --git a/runtime/bindings/python/tests/system_api_test.py b/runtime/bindings/python/tests/system_api_test.py
index ccb0dbf..227df92 100644
--- a/runtime/bindings/python/tests/system_api_test.py
+++ b/runtime/bindings/python/tests/system_api_test.py
@@ -17,7 +17,7 @@
import numpy as np
-def create_simple_mul_module():
+def create_simple_mul_module(instance):
binary = iree.compiler.compile_str(
"""
module @arithmetic {
@@ -29,7 +29,7 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
@@ -62,7 +62,7 @@
def test_custom_dynamic(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
f_repr = repr(f)
@@ -72,14 +72,14 @@
def test_duplicate_module(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
with self.assertRaisesRegex(ValueError, "arithmetic"):
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
def test_static_invoke(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -92,7 +92,7 @@
# and input to functions.
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -121,7 +121,7 @@
def verify_tracing(self, config, temp_dir):
logging.info("Tracing test to: %s", temp_dir)
ctx = iree.runtime.SystemContext(config=config)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
@@ -131,7 +131,9 @@
# TODO: Once replay is possible, verify that.
def test_load_vm_module(self):
- arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
+ ctx = iree.runtime.SystemContext()
+ arithmetic = iree.runtime.load_vm_module(
+ create_simple_mul_module(ctx.instance))
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
results = arithmetic.simple_mul(arg0, arg1)
@@ -142,7 +144,8 @@
# Doing default device configuration multiple times should be valid
# (if this were instantiating drivers multiple times, it can trigger
# a crash, depending on whether the driver supports multi-instantiation).
- m = create_simple_mul_module()
+ ctx = iree.runtime.SystemContext()
+ m = create_simple_mul_module(ctx.instance)
m1 = iree.runtime.load_vm_module(m)
m2 = iree.runtime.load_vm_module(m)
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 98573b0..1e7d307 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -14,7 +14,7 @@
import iree.runtime
-def create_add_scalar_module():
+def create_add_scalar_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
@@ -24,11 +24,11 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
-def create_simple_static_mul_module():
+def create_simple_static_mul_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@@ -38,11 +38,11 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
-def create_simple_dynamic_abs_module():
+def create_simple_dynamic_abs_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -52,7 +52,7 @@
""",
target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
@@ -60,49 +60,46 @@
-@classmethod
def setUp(self):
+ # Modules and contexts in each test are created against this instance.
+ self.instance = iree.runtime.VmInstance()
self.device = iree.runtime.get_device(
iree.compiler.core.DEFAULT_TESTING_DRIVER)
- self.hal_module = iree.runtime.create_hal_module(self.device)
+ self.hal_module = iree.runtime.create_hal_module(self.instance, self.device)
def test_context_id(self):
- instance = iree.runtime.VmInstance()
- context1 = iree.runtime.VmContext(instance)
- context2 = iree.runtime.VmContext(instance)
+ context1 = iree.runtime.VmContext(self.instance)
+ context2 = iree.runtime.VmContext(self.instance)
self.assertNotEqual(context2.context_id, context1.context_id)
def test_module_basics(self):
- m = create_simple_static_mul_module()
+ m = create_simple_static_mul_module(self.instance)
f = m.lookup_function("simple_mul")
self.assertGreaterEqual(f.ordinal, 0)
notfound = m.lookup_function("notfound")
self.assertIs(notfound, None)
def test_dynamic_module_context(self):
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance)
- m = create_simple_static_mul_module()
+ context = iree.runtime.VmContext(self.instance)
+ m = create_simple_static_mul_module(self.instance)
context.register_modules([self.hal_module, m])
def test_static_module_context(self):
- m = create_simple_static_mul_module()
+ m = create_simple_static_mul_module(self.instance)
logging.info("module: %s", m)
- instance = iree.runtime.VmInstance()
- logging.info("instance: %s", instance)
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_dynamic_shape_compile(self):
- m = create_simple_dynamic_abs_module()
+ m = create_simple_dynamic_abs_module(self.instance)
logging.info("module: %s", m)
- instance = iree.runtime.VmInstance()
- logging.info("instance: %s", instance)
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_add_scalar_new_abi(self):
- m = create_add_scalar_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_add_scalar_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
@@ -110,9 +107,9 @@
self.assertEqual(result, 11)
def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
- m = create_simple_dynamic_abs_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_simple_dynamic_abs_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("dynamic_abs")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
@@ -121,9 +118,9 @@
np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
def test_synchronous_invoke_function_new_abi(self):
- m = create_simple_static_mul_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_simple_static_mul_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 88fd9ff..891f647 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -13,6 +13,11 @@
class VmTypesTest(unittest.TestCase):
+ def setUp(self):
+ # Creating a VmInstance ensures the builtin VM types are registered; keep
+ # a reference so it outlives each test body.
+ self.instance = rt.VmInstance()
+
def testRefProtocol(self):
lst1 = rt.VmVariantList(0)
ref = lst1.__iree_vm_ref__
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index cf01457..027a8c3 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -58,9 +58,22 @@
VmInstance VmInstance::Create() {
IREE_TRACE_SCOPE0("VmInstance::Create");
- iree_vm_instance_t* instance;
+
+ iree_vm_instance_t* instance = nullptr;
auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
CheckApiStatus(status, "Error creating instance");
+
+ // Take ownership right away: CheckApiStatus is assumed to throw on error, so
+ // wrapping the raw pointer first means the instance is released during
+ // unwinding (instead of leaked) if HAL type registration below fails.
+ VmInstance owned_instance = VmInstance::StealFromRawPtr(instance);
+
+ // The python bindings assume the HAL is always available for use.
+ // We register the types here so modules can be loaded using the HAL types
+ // in any order.
+ CheckApiStatus(iree_hal_module_register_all_types(instance),
+ "registering HAL types");
+
- return VmInstance::StealFromRawPtr(instance);
+ return owned_instance;
}
@@ -122,11 +130,12 @@
// VmModule
//------------------------------------------------------------------------------
-VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
+ py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
auto buffer_info = flatbuffer_blob.request();
- iree_vm_module_t* module;
+ iree_vm_module_t* module = nullptr;
// Bridge to the C-based deallocator API.
auto* raw_ptr = flatbuffer_blob.ptr();
@@ -141,6 +150,7 @@
iree_allocator_t deallocator{/*self=*/NULL, /*ctl=*/ctl_fn};
auto status = iree_vm_bytecode_module_create(
+ instance->raw_ptr(),
{static_cast<const uint8_t*>(buffer_info.ptr),
static_cast<iree_host_size_t>(buffer_info.size)},
deallocator, iree_allocator_system(), &module);
@@ -511,11 +521,6 @@
}
void SetupVmBindings(pybind11::module m) {
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does this for us and if we created it first we
- // wouldn't need to call this.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
-
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 37443cc..202e221 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -129,7 +129,8 @@
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
public:
- static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
+ static VmModule FromFlatbufferBlob(VmInstance* instance,
+ py::object flatbuffer_blob_object);
std::optional<iree_vm_function_t> LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage);