Making VM modules take an instance arg.
This will make it possible to use instance-specific information during
module construction like registered types.
Progress on #8698.
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index fe2bd1f..7557c16 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -272,13 +272,14 @@
iree_hal_driver_release(driver);
// Create hal module.
- IREE_CHECK_OK(iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE,
+ IREE_CHECK_OK(iree_hal_module_create(runtime.instance, device,
+ IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));
// Bytecode module.
IREE_CHECK_OK(iree_vm_bytecode_module_create(
- iree_make_const_byte_span(data, length), iree_allocator_null(),
- iree_allocator_system(), &main_module));
+ runtime.instance, iree_make_const_byte_span(data, length),
+ iree_allocator_null(), iree_allocator_system(), &main_module));
// Context.
std::array<iree_vm_module_t*, 2> modules = {hal_module, main_module};
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 0953cb7..59b1a63 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -1007,10 +1007,16 @@
auto funcType = mlir::FunctionType::get(
ctx,
- {emitc::OpaqueType::get(ctx, "iree_allocator_t"),
- emitc::PointerType::get(emitc::PointerType::get(
- emitc::OpaqueType::get(ctx, "iree_vm_module_t")))},
- {emitc::OpaqueType::get(ctx, "iree_status_t")});
+ {
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(ctx, "iree_vm_instance_t")),
+ emitc::OpaqueType::get(ctx, "iree_allocator_t"),
+ emitc::PointerType::get(emitc::PointerType::get(
+ emitc::OpaqueType::get(ctx, "iree_vm_module_t"))),
+ },
+ {
+ emitc::OpaqueType::get(ctx, "iree_status_t"),
+ });
auto funcOp = builder.create<mlir::func::FuncOp>(
loc, moduleName + "_create", funcType);
@@ -1054,7 +1060,7 @@
returnIfError(builder, loc, StringAttr::get(ctx, "iree_allocator_malloc"),
{}, {},
- {funcOp.getArgument(0), moduleSize, voidPtr.getResult()},
+ {funcOp.getArgument(1), moduleSize, voidPtr.getResult()},
/*typeConverter=*/typeConverter);
builder.create<emitc::CallOp>(
@@ -1072,7 +1078,7 @@
emitc_builders::structPtrMemberAssign(builder, loc,
/*memberName=*/"allocator",
/*operand=*/module.getResult(),
- /*value=*/funcOp.getArgument(0));
+ /*value=*/funcOp.getArgument(1));
auto vmModule = builder.create<emitc::VariableOp>(
/*location=*/loc,
@@ -1122,7 +1128,7 @@
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/
- ArrayRef<Value>{funcOp.getArgument(0), module.getResult()});
+ ArrayRef<Value>{funcOp.getArgument(1), module.getResult()});
builder.create<mlir::func::ReturnOp>(loc,
vmInitializeStatus.getResult(0));
@@ -1144,7 +1150,7 @@
/*value=*/moduleName + "_" + funcName);
}
- std::string descriptoPtr = "&" + moduleName + "_descriptor_";
+ std::string descriptorPtr = "&" + moduleName + "_descriptor_";
auto status = builder.create<emitc::CallOp>(
/*location=*/loc,
@@ -1152,12 +1158,13 @@
/*callee=*/StringAttr::get(ctx, "iree_vm_native_module_create"),
/*args=*/
ArrayAttr::get(ctx, {builder.getIndexAttr(0),
- emitc::OpaqueAttr::get(ctx, descriptoPtr),
- builder.getIndexAttr(1), builder.getIndexAttr(2)}),
+ emitc::OpaqueAttr::get(ctx, descriptorPtr),
+ builder.getIndexAttr(1), builder.getIndexAttr(2),
+ builder.getIndexAttr(3)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/
ArrayRef<Value>{vmModulePtr, funcOp.getArgument(0),
- funcOp.getArgument(1)});
+ funcOp.getArgument(1), funcOp.getArgument(2)});
builder.create<mlir::func::ReturnOp>(loc, status.getResult(0));
}
diff --git a/docs/website/docs/bindings/c-api.md b/docs/website/docs/bindings/c-api.md
index 3f1e3ac..66033d0 100644
--- a/docs/website/docs/bindings/c-api.md
+++ b/docs/website/docs/bindings/c-api.md
@@ -128,7 +128,7 @@
// We'll load this module into a VM context later.
iree_vm_module_t* hal_module = NULL;
IREE_CHECK_OK(
- iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE,
+ iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));
// The reference to the driver can be released now.
iree_hal_driver_release(driver);
@@ -145,6 +145,7 @@
iree_vm_module_t* bytecode_module = NULL;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{module_data, module_size},
/*flatbuffer_allocator=*/iree_allocator_null(),
/*allocator=*/iree_allocator_system(), &bytecode_module));
diff --git a/docs/website/docs/getting-started/tflite.md b/docs/website/docs/getting-started/tflite.md
index 2d9db3c..4832dde 100644
--- a/docs/website/docs/getting-started/tflite.md
+++ b/docs/website/docs/getting-started/tflite.md
@@ -110,7 +110,7 @@
config = iree_rt.Config("local-task")
context = iree_rt.SystemContext(config=config)
with open(bytecodeModule, 'rb') as f:
- vm_module = iree_rt.VmModule.from_flatbuffer(f.read())
+ vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, f.read())
context.add_vm_module(vm_module)
```
diff --git a/experimental/web/sample_dynamic/main.c b/experimental/web/sample_dynamic/main.c
index 6cb8a48..33e4e7c 100644
--- a/experimental/web/sample_dynamic/main.c
+++ b/experimental/web/sample_dynamic/main.c
@@ -131,7 +131,7 @@
// Take ownership of the FlatBuffer data so JavaScript doesn't need to
// explicitly call `Module._free()`.
status = iree_vm_bytecode_module_create(
- iree_make_const_byte_span(vmfb_data, length),
+ sample_state->instance, iree_make_const_byte_span(vmfb_data, length),
/*flatbuffer_allocator=*/iree_allocator_system(),
iree_allocator_system(), &program_state->module);
} else {
diff --git a/experimental/web/sample_static/main.c b/experimental/web/sample_static/main.c
index 3f50ba9..b73ed9e 100644
--- a/experimental/web/sample_static/main.c
+++ b/experimental/web/sample_static/main.c
@@ -39,11 +39,13 @@
iree_runtime_call_t call;
} iree_sample_state_t;
-iree_status_t create_bytecode_module(iree_vm_module_t** out_module) {
+iree_status_t create_bytecode_module(iree_vm_instance_t* instance,
+ iree_vm_module_t** out_module) {
const struct iree_file_toc_t* module_file_toc = iree_static_mnist_create();
iree_const_byte_span_t module_data =
iree_make_const_byte_span(module_file_toc->data, module_file_toc->size);
- return iree_vm_bytecode_module_create(module_data, iree_allocator_null(),
+ return iree_vm_bytecode_module_create(instance, module_data,
+ iree_allocator_null(),
iree_allocator_system(), out_module);
}
@@ -76,7 +78,7 @@
}
if (iree_status_is_ok(status)) {
- status = create_bytecode_module(&state->module);
+ status = create_bytecode_module(state->instance, &state->module);
}
if (iree_status_is_ok(status)) {
status = iree_runtime_session_append_module(state->session, state->module);
diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
index ed889a3..8569b52 100644
--- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
+++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
@@ -369,8 +369,9 @@
backend_info=backend_info,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
- vm_module = iree.runtime.VmModule.from_flatbuffer(module_blob)
config = iree.runtime.Config(driver_name=backend_info.driver)
+ vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance,
+ module_blob)
compiled_paths = None
if compiled_path is not None:
@@ -412,8 +413,9 @@
module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
saved_model_dir, saved_model_tags, backend_info, exported_name,
artifacts_dir)
- vm_module = iree.runtime.VmModule.from_flatbuffer(module_blob)
config = iree.runtime.Config(driver_name=backend_info.driver)
+ vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance,
+ module_blob)
compiled_paths = None
if compiled_path is not None:
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
index 9e4210c..285b7df 100644
--- a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
@@ -105,7 +105,7 @@
with open(self.binary, 'rb') as f:
config = iree_rt.Config(configs[absl.flags.FLAGS.target_backend])
self.iree_context = iree_rt.SystemContext(config=config)
- vm_module = iree_rt.VmModule.from_flatbuffer(f.read())
+ vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, f.read())
self.iree_context.add_vm_module(vm_module)
def invoke_tflite(self, args):
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index bdd4929..3166fa2 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -387,12 +387,12 @@
// HAL module
//------------------------------------------------------------------------------
-VmModule CreateHalModule(HalDevice* device) {
- iree_vm_module_t* module;
- CheckApiStatus(
- iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &module),
- "Error creating hal module");
+VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
+ iree_vm_module_t* module = NULL;
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device->raw_ptr(),
+ IREE_HAL_MODULE_FLAG_NONE,
+ iree_allocator_system(), &module),
+ "Error creating hal module");
return VmModule::StealFromRawPtr(module);
}
@@ -403,9 +403,6 @@
void SetupHalBindings(pybind11::module m) {
py::dict driver_cache;
- // TODO(#8698): need to register these on an instance.
- IREE_CHECK_OK(iree_hal_module_register_all_types(NULL));
-
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule);
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py
index 5479224..025b4b5 100644
--- a/runtime/bindings/python/iree/runtime/system_api.py
+++ b/runtime/bindings/python/iree/runtime/system_api.py
@@ -72,7 +72,7 @@
driver_name.split(",") if driver_name is not None else None)
self.vm_instance = _binding.VmInstance()
- hal_module = _binding.create_hal_module(self.device)
+ hal_module = _binding.create_hal_module(self.vm_instance, self.device)
self.default_vm_modules = (hal_module,)
self.tracer = tracer or tracing.get_default_tracer()
if self.tracer and self.tracer.enabled:
@@ -283,8 +283,10 @@
"the driver from.")
if backend is not None:
driver = TARGET_BACKEND_TO_DRIVER[backend]
- vm_module = _binding.VmModule.from_flatbuffer(vm_flatbuffer)
- bound_module = load_vm_module(vm_module, Config(driver))
+ config = Config(driver)
+ vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance,
+ vm_flatbuffer)
+ bound_module = load_vm_module(vm_module, config)
return bound_module
diff --git a/runtime/bindings/python/tests/system_api_test.py b/runtime/bindings/python/tests/system_api_test.py
index ccb0dbf..227df92 100644
--- a/runtime/bindings/python/tests/system_api_test.py
+++ b/runtime/bindings/python/tests/system_api_test.py
@@ -17,7 +17,7 @@
import numpy as np
-def create_simple_mul_module():
+def create_simple_mul_module(instance):
binary = iree.compiler.compile_str(
"""
module @arithmetic {
@@ -29,7 +29,7 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
@@ -62,7 +62,7 @@
def test_custom_dynamic(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
f_repr = repr(f)
@@ -72,14 +72,14 @@
def test_duplicate_module(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
with self.assertRaisesRegex(ValueError, "arithmetic"):
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
def test_static_invoke(self):
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -92,7 +92,7 @@
# and input to functions.
ctx = iree.runtime.SystemContext()
self.assertTrue(ctx.is_dynamic)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
@@ -121,7 +121,7 @@
def verify_tracing(self, config, temp_dir):
logging.info("Tracing test to: %s", temp_dir)
ctx = iree.runtime.SystemContext(config=config)
- ctx.add_vm_module(create_simple_mul_module())
+ ctx.add_vm_module(create_simple_mul_module(ctx.instance))
f = ctx.modules.arithmetic["simple_mul"]
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
@@ -131,7 +131,9 @@
# TODO: Once replay is possible, verify that.
def test_load_vm_module(self):
- arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
+ ctx = iree.runtime.SystemContext()
+ arithmetic = iree.runtime.load_vm_module(
+ create_simple_mul_module(ctx.instance))
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
results = arithmetic.simple_mul(arg0, arg1)
@@ -142,7 +144,8 @@
# Doing default device configuration multiple times should be valid
# (if this were instantiating drivers multiple times, it can trigger
# a crash, depending on whether the driver supports multi-instantiation).
- m = create_simple_mul_module()
+ ctx = iree.runtime.SystemContext()
+ m = create_simple_mul_module(ctx.instance)
m1 = iree.runtime.load_vm_module(m)
m2 = iree.runtime.load_vm_module(m)
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 98573b0..1e7d307 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -14,7 +14,7 @@
import iree.runtime
-def create_add_scalar_module():
+def create_add_scalar_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
@@ -24,11 +24,11 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
-def create_simple_static_mul_module():
+def create_simple_static_mul_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
@@ -38,11 +38,11 @@
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
-def create_simple_dynamic_abs_module():
+def create_simple_dynamic_abs_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -52,7 +52,7 @@
""",
target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
)
- m = iree.runtime.VmModule.from_flatbuffer(binary)
+ m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
@@ -60,49 +60,46 @@
@classmethod
def setUp(self):
+ self.instance = iree.runtime.VmInstance()
self.device = iree.runtime.get_device(
iree.compiler.core.DEFAULT_TESTING_DRIVER)
- self.hal_module = iree.runtime.create_hal_module(self.device)
+ self.hal_module = iree.runtime.create_hal_module(self.instance, self.device)
def test_context_id(self):
- instance = iree.runtime.VmInstance()
- context1 = iree.runtime.VmContext(instance)
- context2 = iree.runtime.VmContext(instance)
+ context1 = iree.runtime.VmContext(self.instance)
+ context2 = iree.runtime.VmContext(self.instance)
self.assertNotEqual(context2.context_id, context1.context_id)
def test_module_basics(self):
- m = create_simple_static_mul_module()
+ m = create_simple_static_mul_module(self.instance)
f = m.lookup_function("simple_mul")
self.assertGreaterEqual(f.ordinal, 0)
notfound = m.lookup_function("notfound")
self.assertIs(notfound, None)
def test_dynamic_module_context(self):
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance)
- m = create_simple_static_mul_module()
+ context = iree.runtime.VmContext(self.instance)
+ m = create_simple_static_mul_module(self.instance)
context.register_modules([self.hal_module, m])
def test_static_module_context(self):
- m = create_simple_static_mul_module()
+ m = create_simple_static_mul_module(self.instance)
logging.info("module: %s", m)
- instance = iree.runtime.VmInstance()
- logging.info("instance: %s", instance)
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_dynamic_shape_compile(self):
- m = create_simple_dynamic_abs_module()
+ m = create_simple_dynamic_abs_module(self.instance)
logging.info("module: %s", m)
- instance = iree.runtime.VmInstance()
- logging.info("instance: %s", instance)
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_add_scalar_new_abi(self):
- m = create_add_scalar_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_add_scalar_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
@@ -110,9 +107,9 @@
self.assertEqual(result, 11)
def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
- m = create_simple_dynamic_abs_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_simple_dynamic_abs_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("dynamic_abs")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
@@ -121,9 +118,9 @@
np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
def test_synchronous_invoke_function_new_abi(self):
- m = create_simple_static_mul_module()
- instance = iree.runtime.VmInstance()
- context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+ m = create_simple_static_mul_module(self.instance)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 88fd9ff..891f647 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -13,6 +13,11 @@
class VmTypesTest(unittest.TestCase):
+ @classmethod
+ def setUp(self):
+ # Ensures types are registered.
+ self.instance = rt.VmInstance()
+
def testRefProtocol(self):
lst1 = rt.VmVariantList(0)
ref = lst1.__iree_vm_ref__
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index cf01457..027a8c3 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -58,9 +58,17 @@
VmInstance VmInstance::Create() {
IREE_TRACE_SCOPE0("VmInstance::Create");
- iree_vm_instance_t* instance;
+
+ iree_vm_instance_t* instance = NULL;
auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
CheckApiStatus(status, "Error creating instance");
+
+ // The python bindings assume the HAL is always available for use.
+ // We register the types here so modules can be loaded using the HAL types
+ // in any order.
+ CheckApiStatus(iree_hal_module_register_all_types(instance),
+ "registering HAL types");
+
return VmInstance::StealFromRawPtr(instance);
}
@@ -122,11 +130,12 @@
// VmModule
//------------------------------------------------------------------------------
-VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
+ py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
auto buffer_info = flatbuffer_blob.request();
- iree_vm_module_t* module;
+ iree_vm_module_t* module = nullptr;
// Bridge to the C-based deallocator API.
auto* raw_ptr = flatbuffer_blob.ptr();
@@ -141,6 +150,7 @@
iree_allocator_t deallocator{/*self=*/NULL, /*ctl=*/ctl_fn};
auto status = iree_vm_bytecode_module_create(
+ instance->raw_ptr(),
{static_cast<const uint8_t*>(buffer_info.ptr),
static_cast<iree_host_size_t>(buffer_info.size)},
deallocator, iree_allocator_system(), &module);
@@ -511,11 +521,6 @@
}
void SetupVmBindings(pybind11::module m) {
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does this for us and if we created it first we
- // wouldn't need to call this.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
-
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 37443cc..202e221 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -129,7 +129,8 @@
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
public:
- static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
+ static VmModule FromFlatbufferBlob(VmInstance* instance,
+ py::object flatbuffer_blob_object);
std::optional<iree_vm_function_t> LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage);
diff --git a/runtime/bindings/tflite/interpreter.c b/runtime/bindings/tflite/interpreter.c
index f12177b..4183b34 100644
--- a/runtime/bindings/tflite/interpreter.c
+++ b/runtime/bindings/tflite/interpreter.c
@@ -61,9 +61,9 @@
"failed creating the default device for driver '%.*s'",
(int)driver_name.size, driver_name.data);
- IREE_RETURN_IF_ERROR(
- iree_hal_module_create(interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
- interpreter->allocator, &interpreter->hal_module));
+ IREE_RETURN_IF_ERROR(iree_hal_module_create(
+ interpreter->instance, interpreter->device, IREE_HAL_MODULE_FLAG_NONE,
+ interpreter->allocator, &interpreter->hal_module));
return iree_ok_status();
}
@@ -357,13 +357,13 @@
sizeof(interpreter->options));
}
+ interpreter->instance = model->instance;
+ iree_vm_instance_retain(interpreter->instance);
interpreter->user_module = model->module;
iree_vm_module_retain(interpreter->user_module);
// External contexts could possibly used to emulate sharing this, but really
// if a user is running with multiple models the tflite API is insufficient.
- IREE_RETURN_IF_ERROR(
- iree_vm_instance_create(interpreter->allocator, &interpreter->instance));
IREE_RETURN_IF_ERROR(_TfLiteInterpreterPrepareHAL(interpreter));
// Context will contain both the user-provided bytecode and the HAL module.
diff --git a/runtime/bindings/tflite/model.c b/runtime/bindings/tflite/model.c
index a5760d9..fa2f538 100644
--- a/runtime/bindings/tflite/model.c
+++ b/runtime/bindings/tflite/model.c
@@ -13,20 +13,6 @@
#include "iree/modules/hal/module.h"
#include "iree/vm/bytecode_module.h"
-static iree_status_t _TfLiteModelPrepareRuntime() {
- IREE_TRACE_ZONE_BEGIN(z0);
-
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does the builtin types for us and if we created it
- // first we wouldn't need to call it.
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_vm_register_builtin_types(NULL));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
- iree_hal_module_register_all_types(NULL));
-
- IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
-}
-
static iree_status_t _TfLiteModelCalculateFunctionIOCounts(
const iree_vm_function_signature_t* signature, int32_t* out_input_count,
int32_t* out_output_count) {
@@ -47,15 +33,19 @@
TfLiteModel* model) {
IREE_TRACE_ZONE_BEGIN(z0);
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, _TfLiteModelPrepareRuntime());
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_vm_instance_create(allocator, &model->instance));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_module_register_all_types(model->instance));
iree_const_byte_span_t flatbuffer_span =
iree_make_const_byte_span(flatbuffer_data, flatbuffer_size);
iree_allocator_t flatbuffer_allocator = iree_allocator_null();
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
- iree_vm_bytecode_module_create(flatbuffer_span, flatbuffer_allocator,
- allocator, &model->module),
+ iree_vm_bytecode_module_create(model->instance, flatbuffer_span,
+ flatbuffer_allocator, allocator,
+ &model->module),
"error creating bytecode module");
IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -130,7 +120,7 @@
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
iree_status_free(status);
- iree_allocator_free(allocator, model);
+ TfLiteModelDelete(model);
IREE_TRACE_ZONE_END(z0);
return NULL;
}
@@ -170,6 +160,7 @@
int ret = fread(model->owned_model_data, 1, file_size, file);
fclose(file);
if (ret != file_size) {
+ TfLiteModelDelete(model);
IREE_TRACE_MESSAGE(ERROR, "failed model+data read");
IREE_TRACE_ZONE_END(z0);
return NULL;
@@ -178,6 +169,7 @@
status = _TfLiteModelInitializeModule(model->owned_model_data, file_size,
allocator, model);
if (!iree_status_is_ok(iree_status_consume_code(status))) {
+ TfLiteModelDelete(model);
IREE_TRACE_ZONE_END(z0);
return NULL;
}
@@ -196,6 +188,7 @@
if (model && iree_atomic_ref_count_dec(&model->ref_count) == 1) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_vm_module_release(model->module);
+ iree_vm_instance_release(model->instance);
iree_allocator_free(model->allocator, model);
IREE_TRACE_ZONE_END(z0);
}
diff --git a/runtime/bindings/tflite/model.h b/runtime/bindings/tflite/model.h
index 9db6b43..d5f1ea2 100644
--- a/runtime/bindings/tflite/model.h
+++ b/runtime/bindings/tflite/model.h
@@ -29,6 +29,10 @@
iree_allocator_t allocator;
void* owned_model_data;
+ // HACK: no public API that allows us to share this without spooky action
+ // at a distance. Today it's ok for these to be unique as we don't check that
+ // they are consistent and all instances have the same types registered.
+ iree_vm_instance_t* instance;
iree_vm_module_t* module;
_TfLiteModelExports exports;
int32_t input_count;
diff --git a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
index 21f8454..32cff19 100644
--- a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
+++ b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
@@ -471,7 +471,8 @@
// A single VMVX module is shared across all loaded executables.
iree_vm_module_t* vmvx_module = NULL;
- IREE_RETURN_IF_ERROR(iree_vmvx_module_create(host_allocator, &vmvx_module));
+ IREE_RETURN_IF_ERROR(
+ iree_vmvx_module_create(instance, host_allocator, &vmvx_module));
iree_host_size_t common_module_count = 1 + user_module_count;
iree_hal_vmvx_module_loader_t* executable_loader = NULL;
@@ -584,8 +585,9 @@
// we have it) to the module to manage.
iree_vm_module_t* bytecode_module = NULL;
iree_status_t status = iree_vm_bytecode_module_create(
- executable_params->executable_data, bytecode_module_allocator,
- executable_loader->host_allocator, &bytecode_module);
+ executable_loader->instance, executable_params->executable_data,
+ bytecode_module_allocator, executable_loader->host_allocator,
+ &bytecode_module);
// Create the context tying together the shared VMVX module and the
// user-provided module that references it. We always link the compiled module
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index 966ea4a..6f5e72e 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -46,13 +46,13 @@
}
IREE_ASSERT_OK(iree_hal_driver_create_default_device(
hal_driver, iree_allocator_system(), &device_));
- IREE_ASSERT_OK(iree_hal_module_create(device_, IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(),
- &hal_module_));
+ IREE_ASSERT_OK(
+ iree_hal_module_create(instance_, device_, IREE_HAL_MODULE_FLAG_NONE,
+ iree_allocator_system(), &hal_module_));
iree_hal_driver_release(hal_driver);
- IREE_ASSERT_OK(
- iree_check_module_create(iree_allocator_system(), &check_module_))
+ IREE_ASSERT_OK(iree_check_module_create(instance_, iree_allocator_system(),
+ &check_module_))
<< "Native module failed to init";
}
diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc
index c74793c..14e2487 100644
--- a/runtime/src/iree/modules/check/module.cc
+++ b/runtime/src/iree/modules/check/module.cc
@@ -397,11 +397,12 @@
// Note that while we are using C++ bindings internally we still expose the
// module as a C instance. This hides the details of our implementation.
extern "C" iree_status_t iree_check_module_create(
- iree_allocator_t allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
auto module = std::make_unique<CheckModule>(
- "check", /*version=*/0, allocator,
+ "check", /*version=*/0, instance, allocator,
iree::span<const vm::NativeFunction<CheckModuleState>>(
kCheckModuleFunctions));
*out_module = module.release()->interface();
diff --git a/runtime/src/iree/modules/check/module.h b/runtime/src/iree/modules/check/module.h
index 24d29ba..1cb5c9e 100644
--- a/runtime/src/iree/modules/check/module.h
+++ b/runtime/src/iree/modules/check/module.h
@@ -17,7 +17,8 @@
#endif // __cplusplus
// Creates a native custom module.
-iree_status_t iree_check_module_create(iree_allocator_t allocator,
+iree_status_t iree_check_module_create(iree_vm_instance_t* instance,
+ iree_allocator_t allocator,
iree_vm_module_t** out_module);
#ifdef __cplusplus
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index d998e77..e7d2326 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -1510,8 +1510,10 @@
};
IREE_API_EXPORT iree_status_t iree_hal_module_create(
- iree_hal_device_t* device, iree_hal_module_flags_t flags,
- iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_hal_device_t* device,
+ iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(instance);
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
@@ -1532,8 +1534,9 @@
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
memset(base_module, 0, total_size);
- iree_status_t status = iree_vm_native_module_initialize(
- &interface, &iree_hal_module_descriptor_, host_allocator, base_module);
+ iree_status_t status =
+ iree_vm_native_module_initialize(&interface, &iree_hal_module_descriptor_,
+ instance, host_allocator, base_module);
if (!iree_status_is_ok(status)) {
iree_allocator_free(host_allocator, base_module);
return status;
diff --git a/runtime/src/iree/modules/hal/module.h b/runtime/src/iree/modules/hal/module.h
index 09dd8ef..5ba71af 100644
--- a/runtime/src/iree/modules/hal/module.h
+++ b/runtime/src/iree/modules/hal/module.h
@@ -30,8 +30,9 @@
// Each context using this module will share the device and have compatible
// allocations.
IREE_API_EXPORT iree_status_t iree_hal_module_create(
- iree_hal_device_t* device, iree_hal_module_flags_t flags,
- iree_allocator_t host_allocator, iree_vm_module_t** out_module);
+ iree_vm_instance_t* instance, iree_hal_device_t* device,
+ iree_hal_module_flags_t flags, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module);
// Returns the device currently in use by the HAL module.
// Returns NULL if no device has been initialized yet.
diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c
index 631c0d9..293e7be 100644
--- a/runtime/src/iree/modules/vmvx/module.c
+++ b/runtime/src/iree/modules/vmvx/module.c
@@ -582,7 +582,9 @@
};
IREE_API_EXPORT iree_status_t iree_vmvx_module_create(
- iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(instance);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
@@ -602,7 +604,8 @@
iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
memset(base_module, 0, total_size);
iree_status_t status = iree_vm_native_module_initialize(
- &interface, &iree_vmvx_module_descriptor_, host_allocator, base_module);
+ &interface, &iree_vmvx_module_descriptor_, instance, host_allocator,
+ base_module);
if (!iree_status_is_ok(status)) {
iree_allocator_free(host_allocator, base_module);
return status;
diff --git a/runtime/src/iree/modules/vmvx/module.h b/runtime/src/iree/modules/vmvx/module.h
index 1755ad9..5861a86 100644
--- a/runtime/src/iree/modules/vmvx/module.h
+++ b/runtime/src/iree/modules/vmvx/module.h
@@ -18,7 +18,8 @@
// Creates the VMVX module with a default configuration.
IREE_API_EXPORT iree_status_t iree_vmvx_module_create(
- iree_allocator_t host_allocator, iree_vm_module_t** out_module);
+ iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module);
#ifdef __cplusplus
} // extern "C"
diff --git a/runtime/src/iree/runtime/instance.c b/runtime/src/iree/runtime/instance.c
index 9502211..c710f6c 100644
--- a/runtime/src/iree/runtime/instance.c
+++ b/runtime/src/iree/runtime/instance.c
@@ -53,6 +53,9 @@
// can find the same devices. This may mean a new HAL type like
// iree_hal_device_pool_t to prevent too much coupling and make weak
// references easier.
+
+ // VM instance shared across all sessions.
+ iree_vm_instance_t* vm_instance;
};
IREE_API_EXPORT iree_status_t iree_runtime_instance_create(
@@ -70,16 +73,6 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_api_version_check(options->api_version, &actual_version));
- // Register builtin types.
- // TODO(#8698): change to per-instance type registries to avoid these
- // global (UNSAFE!) calls. For now hosting applications should really only
- // be using a single instance anyway.
- // The instance constructor does the builtin types for us and if we created it
- // first we wouldn't need to call it.
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_vm_register_builtin_types(NULL));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
- iree_hal_module_register_all_types(NULL));
-
// Allocate the instance state.
iree_runtime_instance_t* instance = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -91,15 +84,26 @@
instance->driver_registry = options->driver_registry;
// TODO(benvanik): driver registry ref counting.
- *out_instance = instance;
+ iree_status_t status =
+ iree_vm_instance_create(host_allocator, &instance->vm_instance);
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_module_register_all_types(instance->vm_instance);
+ }
+
+ if (iree_status_is_ok(status)) {
+ *out_instance = instance;
+ } else {
+ iree_runtime_instance_release(instance);
+ }
IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
+ return status;
}
static void iree_runtime_instance_destroy(iree_runtime_instance_t* instance) {
IREE_ASSERT_ARGUMENT(instance);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_vm_instance_release(instance->vm_instance);
iree_allocator_free(instance->host_allocator, instance);
IREE_TRACE_ZONE_END(z0);
@@ -125,6 +129,12 @@
return instance->host_allocator;
}
+IREE_API_EXPORT iree_vm_instance_t* iree_runtime_instance_vm_instance(
+ const iree_runtime_instance_t* instance) {
+ IREE_ASSERT_ARGUMENT(instance);
+ return instance->vm_instance;
+}
+
IREE_API_EXPORT iree_hal_driver_registry_t*
iree_runtime_instance_driver_registry(const iree_runtime_instance_t* instance) {
IREE_ASSERT_ARGUMENT(instance);
diff --git a/runtime/src/iree/runtime/instance.h b/runtime/src/iree/runtime/instance.h
index 6bf5423..d3d79bf 100644
--- a/runtime/src/iree/runtime/instance.h
+++ b/runtime/src/iree/runtime/instance.h
@@ -9,6 +9,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
+#include "iree/vm/api.h"
#ifdef __cplusplus
extern "C" {
@@ -104,6 +105,10 @@
IREE_API_EXPORT iree_allocator_t
iree_runtime_instance_host_allocator(const iree_runtime_instance_t* instance);
+// Returns the VM instance shared by all sessions using the runtime instance.
+IREE_API_EXPORT iree_vm_instance_t* iree_runtime_instance_vm_instance(
+ const iree_runtime_instance_t* instance);
+
// Returns the optional driver registry used to enumerate drivers and devices.
// If not provided then iree_runtime_session_create_with_device must be used
// to specify the device that a session should use.
diff --git a/runtime/src/iree/runtime/session.c b/runtime/src/iree/runtime/session.c
index e64a8ca..53cdbc3 100644
--- a/runtime/src/iree/runtime/session.c
+++ b/runtime/src/iree/runtime/session.c
@@ -87,14 +87,15 @@
// Create the context empty so that we can add our modules to it.
iree_status_t status = iree_vm_context_create(
- /*instance=*/NULL, options->context_flags, host_allocator,
- &session->context);
+ iree_runtime_instance_vm_instance(instance), options->context_flags,
+ host_allocator, &session->context);
// Add the HAL module; it is always required when using the runtime API.
// Lower-level usage of the VM can avoid the HAL if it's not required.
iree_vm_module_t* hal_module = NULL;
if (iree_status_is_ok(status)) {
- status = iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE,
+ status = iree_hal_module_create(iree_runtime_instance_vm_instance(instance),
+ device, IREE_HAL_MODULE_FLAG_NONE,
host_allocator, &hal_module);
}
if (iree_status_is_ok(status)) {
@@ -208,8 +209,9 @@
iree_vm_module_t* module = NULL;
iree_status_t status = iree_vm_bytecode_module_create(
- flatbuffer_data, flatbuffer_allocator,
- iree_runtime_session_host_allocator(session), &module);
+ iree_runtime_instance_vm_instance(session->instance), flatbuffer_data,
+ flatbuffer_allocator, iree_runtime_session_host_allocator(session),
+ &module);
if (iree_status_is_ok(status)) {
status = iree_runtime_session_append_module(session, module);
}
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index 7163eaf..8883bde 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -58,7 +58,7 @@
// The module takes ownership of the file contents (when successful).
iree_vm_module_t* module = NULL;
iree_status_t status = iree_vm_bytecode_module_create(
- file_contents->const_buffer,
+ instance, file_contents->const_buffer,
iree_file_contents_deallocator(file_contents), host_allocator, &module);
if (iree_status_is_ok(status)) {
@@ -115,7 +115,7 @@
iree_hal_module_flags_t flags = IREE_HAL_MODULE_FLAG_NONE;
iree_vm_module_t* module = NULL;
iree_status_t status =
- iree_hal_module_create(device, flags, host_allocator, &module);
+ iree_hal_module_create(instance, device, flags, host_allocator, &module);
if (iree_status_is_ok(status)) {
*out_module = module;
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index e21e8ea..eaadddc 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -109,9 +109,9 @@
document, module_node, iree_make_cstring_view("driver"), &driver_node));
IREE_RETURN_IF_ERROR(iree_trace_replay_create_device(
replay, driver_node, replay->host_allocator, &replay->device));
- IREE_RETURN_IF_ERROR(
- iree_hal_module_create(replay->device, IREE_HAL_MODULE_FLAG_NONE,
- replay->host_allocator, &module));
+ IREE_RETURN_IF_ERROR(iree_hal_module_create(
+ replay->instance, replay->device, IREE_HAL_MODULE_FLAG_NONE,
+ replay->host_allocator, &module));
}
if (!module) {
return iree_make_status(
@@ -154,9 +154,9 @@
if (iree_status_is_ok(status)) {
iree_allocator_t flatbuffer_deallocator =
iree_file_contents_deallocator(flatbuffer_contents);
- status = iree_vm_bytecode_module_create(flatbuffer_contents->const_buffer,
- flatbuffer_deallocator,
- replay->host_allocator, &module);
+ status = iree_vm_bytecode_module_create(
+ replay->instance, flatbuffer_contents->const_buffer,
+ flatbuffer_deallocator, replay->host_allocator, &module);
if (!iree_status_is_ok(status)) {
iree_file_contents_free(flatbuffer_contents);
}
diff --git a/runtime/src/iree/vm/BUILD b/runtime/src/iree/vm/BUILD
index f8ef9b7..4483b76 100644
--- a/runtime/src/iree/vm/BUILD
+++ b/runtime/src/iree/vm/BUILD
@@ -55,7 +55,6 @@
name = "impl",
srcs = [
"buffer.c",
- "builtin_types.c",
"context.c",
"instance.c",
"invocation.c",
@@ -68,7 +67,6 @@
],
hdrs = [
"buffer.h",
- "builtin_types.h",
"context.h",
"instance.h",
"invocation.h",
diff --git a/runtime/src/iree/vm/CMakeLists.txt b/runtime/src/iree/vm/CMakeLists.txt
index e5872fc..9351f48 100644
--- a/runtime/src/iree/vm/CMakeLists.txt
+++ b/runtime/src/iree/vm/CMakeLists.txt
@@ -42,7 +42,6 @@
impl
HDRS
"buffer.h"
- "builtin_types.h"
"context.h"
"instance.h"
"invocation.h"
@@ -56,7 +55,6 @@
"value.h"
SRCS
"buffer.c"
- "builtin_types.c"
"context.c"
"instance.c"
"invocation.c"
diff --git a/runtime/src/iree/vm/api.h b/runtime/src/iree/vm/api.h
index 3f559f7..87e5750 100644
--- a/runtime/src/iree/vm/api.h
+++ b/runtime/src/iree/vm/api.h
@@ -9,7 +9,6 @@
#include "iree/base/api.h"
#include "iree/vm/buffer.h" // IWYU pragma: export
-#include "iree/vm/builtin_types.h" // IWYU pragma: export
#include "iree/vm/context.h" // IWYU pragma: export
#include "iree/vm/instance.h" // IWYU pragma: export
#include "iree/vm/invocation.h" // IWYU pragma: export
diff --git a/runtime/src/iree/vm/buffer_test.cc b/runtime/src/iree/vm/buffer_test.cc
index f443be9..ef1c660 100644
--- a/runtime/src/iree/vm/buffer_test.cc
+++ b/runtime/src/iree/vm/buffer_test.cc
@@ -10,18 +10,16 @@
#include "iree/base/api.h"
#include "iree/testing/gtest.h"
-#include "iree/vm/builtin_types.h"
+#include "iree/vm/instance.h"
namespace {
-class VMBufferTest : public ::testing::Test {
- protected:
+static iree_vm_instance_t* instance = NULL;
+struct VMBufferTest : public ::testing::Test {
static void SetUpTestSuite() {
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does this for us and if we created it first we
- // wouldn't need to call this.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
}
+ static void TearDownTestSuite() { iree_vm_instance_release(instance); }
};
// Tests that the data allocator is correctly called when using stack
diff --git a/runtime/src/iree/vm/builtin_types.c b/runtime/src/iree/vm/builtin_types.c
deleted file mode 100644
index f652629..0000000
--- a/runtime/src/iree/vm/builtin_types.c
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/vm/builtin_types.h"
-
-iree_status_t iree_vm_buffer_register_types(iree_vm_instance_t* instance);
-iree_status_t iree_vm_list_register_types(iree_vm_instance_t* instance);
-
-IREE_API_EXPORT iree_status_t
-iree_vm_register_builtin_types(iree_vm_instance_t* instance) {
- IREE_RETURN_IF_ERROR(iree_vm_buffer_register_types(instance));
- IREE_RETURN_IF_ERROR(iree_vm_list_register_types(instance));
- return iree_ok_status();
-}
diff --git a/runtime/src/iree/vm/builtin_types.h b/runtime/src/iree/vm/builtin_types.h
deleted file mode 100644
index f384e16..0000000
--- a/runtime/src/iree/vm/builtin_types.h
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_VM_BUILTIN_TYPES_H_
-#define IREE_VM_BUILTIN_TYPES_H_
-
-#include "iree/base/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-typedef struct iree_vm_instance_t iree_vm_instance_t;
-
-// Registers the builtin VM types. This must be called on startup. Safe to call
-// multiple times.
-IREE_API_EXPORT iree_status_t
-iree_vm_register_builtin_types(iree_vm_instance_t* instance);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-
-#endif // IREE_VM_BUILTIN_TYPES_H_
diff --git a/runtime/src/iree/vm/bytecode_dispatch_async_test.cc b/runtime/src/iree/vm/bytecode_dispatch_async_test.cc
index 8ef3fbc..e398c32 100644
--- a/runtime/src/iree/vm/bytecode_dispatch_async_test.cc
+++ b/runtime/src/iree/vm/bytecode_dispatch_async_test.cc
@@ -26,11 +26,6 @@
class VMBytecodeDispatchAsyncTest : public ::testing::Test {
protected:
- static void SetUpTestSuite() {
- // TODO(#8698): need to register these on an instance.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
- }
-
void SetUp() override {
IREE_TRACE_SCOPE();
const iree_file_toc_t* file = async_bytecode_modules_c_create();
@@ -38,6 +33,7 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance_,
iree_const_byte_span_t{reinterpret_cast<const uint8_t*>(file->data),
file->size},
iree_allocator_null(), iree_allocator_system(), &bytecode_module_));
diff --git a/runtime/src/iree/vm/bytecode_dispatch_test.cc b/runtime/src/iree/vm/bytecode_dispatch_test.cc
index 65f92bb..58ea382 100644
--- a/runtime/src/iree/vm/bytecode_dispatch_test.cc
+++ b/runtime/src/iree/vm/bytecode_dispatch_test.cc
@@ -36,10 +36,8 @@
std::vector<TestParams> GetModuleTestParams() {
std::vector<TestParams> test_params;
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does this for us and if we created it first we
- // wouldn't need to call this.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
+ iree_vm_instance_t* instance = NULL;
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
const struct iree_file_toc_t* module_file_toc =
all_bytecode_modules_c_create();
@@ -47,6 +45,7 @@
const auto& module_file = module_file_toc[i];
iree_vm_module_t* module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file.data),
module_file.size},
@@ -64,6 +63,8 @@
iree_vm_module_release(module);
}
+ iree_vm_instance_release(instance);
+
return test_params;
}
@@ -77,6 +78,7 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance_,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(test_params.module_file.data),
test_params.module_file.size},
diff --git a/runtime/src/iree/vm/bytecode_module.c b/runtime/src/iree/vm/bytecode_module.c
index f07f8a9..ec5d091 100644
--- a/runtime/src/iree/vm/bytecode_module.c
+++ b/runtime/src/iree/vm/bytecode_module.c
@@ -1119,8 +1119,9 @@
}
IREE_API_EXPORT iree_status_t iree_vm_bytecode_module_create(
- iree_const_byte_span_t archive_contents, iree_allocator_t archive_allocator,
- iree_allocator_t allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_const_byte_span_t archive_contents,
+ iree_allocator_t archive_allocator, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
IREE_TRACE_ZONE_BEGIN(z0);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
diff --git a/runtime/src/iree/vm/bytecode_module.h b/runtime/src/iree/vm/bytecode_module.h
index e158d40..93ebd82 100644
--- a/runtime/src/iree/vm/bytecode_module.h
+++ b/runtime/src/iree/vm/bytecode_module.h
@@ -21,8 +21,9 @@
// |archive_contents| when the module is destroyed and otherwise the ownership
// of the memory remains with the caller.
IREE_API_EXPORT iree_status_t iree_vm_bytecode_module_create(
- iree_const_byte_span_t archive_contents, iree_allocator_t archive_allocator,
- iree_allocator_t allocator, iree_vm_module_t** out_module);
+ iree_vm_instance_t* instance, iree_const_byte_span_t archive_contents,
+ iree_allocator_t archive_allocator, iree_allocator_t allocator,
+ iree_vm_module_t** out_module);
// Parses the module archive header in |archive_contents|.
// The subrange containing the FlatBuffer data is returned as well as the
diff --git a/runtime/src/iree/vm/bytecode_module_benchmark.cc b/runtime/src/iree/vm/bytecode_module_benchmark.cc
index e851bff..04e3716 100644
--- a/runtime/src/iree/vm/bytecode_module_benchmark.cc
+++ b/runtime/src/iree/vm/bytecode_module_benchmark.cc
@@ -61,11 +61,13 @@
};
static iree_status_t native_import_module_create(
- iree_allocator_t allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
iree_vm_module_t interface;
IREE_RETURN_IF_ERROR(iree_vm_module_initialize(&interface, NULL));
- return iree_vm_native_module_create(
- &interface, &native_import_module_descriptor_, allocator, out_module);
+ return iree_vm_native_module_create(&interface,
+ &native_import_module_descriptor_,
+ instance, allocator, out_module);
}
// Benchmarks the given exported function, optionally passing in arguments.
@@ -77,13 +79,14 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
iree_vm_module_t* import_module = NULL;
- IREE_CHECK_OK(
- native_import_module_create(iree_allocator_system(), &import_module));
+ IREE_CHECK_OK(native_import_module_create(instance, iree_allocator_system(),
+ &import_module));
const auto* module_file_toc =
iree_vm_bytecode_module_benchmark_module_create();
iree_vm_module_t* bytecode_module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
module_file_toc->size},
@@ -130,13 +133,15 @@
}
static void BM_ModuleCreate(benchmark::State& state) {
- IREE_CHECK_OK(iree_vm_register_builtin_types());
+ iree_vm_instance_t* instance = NULL;
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
while (state.KeepRunning()) {
const auto* module_file_toc =
iree_vm_bytecode_module_benchmark_module_create();
iree_vm_module_t* module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
module_file_toc->size},
@@ -147,16 +152,20 @@
iree_vm_module_release(module);
}
+
+ iree_vm_instance_release(instance);
}
BENCHMARK(BM_ModuleCreate);
static void BM_ModuleCreateState(benchmark::State& state) {
- IREE_CHECK_OK(iree_vm_register_builtin_types());
+ iree_vm_instance_t* instance = NULL;
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
const auto* module_file_toc =
iree_vm_bytecode_module_benchmark_module_create();
iree_vm_module_t* module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
module_file_toc->size},
@@ -174,17 +183,20 @@
}
iree_vm_module_release(module);
+ iree_vm_instance_release(instance);
}
BENCHMARK(BM_ModuleCreateState);
static void BM_FullModuleInit(benchmark::State& state) {
- IREE_CHECK_OK(iree_vm_register_builtin_types());
+ iree_vm_instance_t* instance = NULL;
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
while (state.KeepRunning()) {
const auto* module_file_toc =
iree_vm_bytecode_module_benchmark_module_create();
iree_vm_module_t* module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
module_file_toc->size},
@@ -198,6 +210,8 @@
module->free_state(module->self, module_state);
iree_vm_module_release(module);
}
+
+ iree_vm_instance_release(instance);
}
BENCHMARK(BM_FullModuleInit);
diff --git a/runtime/src/iree/vm/bytecode_module_size_benchmark.cc b/runtime/src/iree/vm/bytecode_module_size_benchmark.cc
index 6b949e3..af5c16e 100644
--- a/runtime/src/iree/vm/bytecode_module_size_benchmark.cc
+++ b/runtime/src/iree/vm/bytecode_module_size_benchmark.cc
@@ -17,6 +17,7 @@
iree_vm_bytecode_module_size_benchmark_module_create();
iree_vm_module_t* module = nullptr;
iree_vm_bytecode_module_create(
+ instance,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
module_file_toc->size},
diff --git a/runtime/src/iree/vm/instance.c b/runtime/src/iree/vm/instance.c
index 63744e3..2defba9 100644
--- a/runtime/src/iree/vm/instance.c
+++ b/runtime/src/iree/vm/instance.c
@@ -10,7 +10,19 @@
#include "iree/base/internal/atomics.h"
#include "iree/base/tracing.h"
-#include "iree/vm/builtin_types.h"
+
+// Defined in their respective files:
+iree_status_t iree_vm_buffer_register_types(iree_vm_instance_t* instance);
+iree_status_t iree_vm_list_register_types(iree_vm_instance_t* instance);
+
+// Registers the builtin VM types. This must be called on startup. Safe to call
+// multiple times.
+static iree_status_t iree_vm_register_builtin_types(
+ iree_vm_instance_t* instance) {
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_register_types(instance));
+ IREE_RETURN_IF_ERROR(iree_vm_list_register_types(instance));
+ return iree_ok_status();
+}
struct iree_vm_instance_t {
iree_atomic_ref_count_t ref_count;
diff --git a/runtime/src/iree/vm/list_test.cc b/runtime/src/iree/vm/list_test.cc
index ace1787..979ac02 100644
--- a/runtime/src/iree/vm/list_test.cc
+++ b/runtime/src/iree/vm/list_test.cc
@@ -12,7 +12,7 @@
#include "iree/base/api.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-#include "iree/vm/builtin_types.h"
+#include "iree/vm/instance.h"
#include "iree/vm/ref_cc.h"
class A : public iree::vm::RefObject<A> {
@@ -70,15 +70,13 @@
return ref;
}
-class VMListTest : public ::testing::Test {
- protected:
+static iree_vm_instance_t* instance = NULL;
+struct VMListTest : public ::testing::Test {
static void SetUpTestSuite() {
- // TODO(#8698): need to register these on an instance.
- // The instance constructor does this for us and if we created it first we
- // wouldn't need to call this.
- IREE_CHECK_OK(iree_vm_register_builtin_types(NULL));
- RegisterRefTypes(NULL);
+ IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
+ RegisterRefTypes(instance);
}
+ static void TearDownTestSuite() { iree_vm_instance_release(instance); }
};
// Tests simple primitive value list usage, mainly just for demonstration.
diff --git a/runtime/src/iree/vm/native_module.c b/runtime/src/iree/vm/native_module.c
index 92bc1af..95084b7 100644
--- a/runtime/src/iree/vm/native_module.c
+++ b/runtime/src/iree/vm/native_module.c
@@ -402,7 +402,11 @@
IREE_API_EXPORT iree_status_t iree_vm_native_module_create(
const iree_vm_module_t* interface,
const iree_vm_native_module_descriptor_t* module_descriptor,
- iree_allocator_t allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(interface);
+ IREE_ASSERT_ARGUMENT(module_descriptor);
+ IREE_ASSERT_ARGUMENT(instance);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
@@ -434,8 +438,9 @@
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(allocator, sizeof(*module), (void**)&module));
- iree_status_t status = iree_vm_native_module_initialize(
- interface, module_descriptor, allocator, (iree_vm_module_t*)module);
+ iree_status_t status =
+ iree_vm_native_module_initialize(interface, module_descriptor, instance,
+ allocator, (iree_vm_module_t*)module);
if (!iree_status_is_ok(status)) {
iree_allocator_free(allocator, module);
return status;
@@ -448,9 +453,11 @@
IREE_API_EXPORT iree_status_t iree_vm_native_module_initialize(
const iree_vm_module_t* interface,
const iree_vm_native_module_descriptor_t* module_descriptor,
- iree_allocator_t allocator, iree_vm_module_t* base_module) {
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t* base_module) {
IREE_ASSERT_ARGUMENT(interface);
IREE_ASSERT_ARGUMENT(module_descriptor);
+ IREE_ASSERT_ARGUMENT(instance);
IREE_ASSERT_ARGUMENT(base_module);
iree_vm_native_module_t* module = (iree_vm_native_module_t*)base_module;
diff --git a/runtime/src/iree/vm/native_module.h b/runtime/src/iree/vm/native_module.h
index f5e08fc..3693e6b 100644
--- a/runtime/src/iree/vm/native_module.h
+++ b/runtime/src/iree/vm/native_module.h
@@ -12,6 +12,7 @@
#include <stdint.h>
#include "iree/base/api.h"
+#include "iree/vm/instance.h"
#include "iree/vm/module.h"
#include "iree/vm/stack.h"
@@ -136,12 +137,14 @@
IREE_API_EXPORT iree_status_t iree_vm_native_module_create(
const iree_vm_module_t* interface,
const iree_vm_native_module_descriptor_t* module_descriptor,
- iree_allocator_t allocator, iree_vm_module_t** out_module);
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t** out_module);
IREE_API_EXPORT iree_status_t iree_vm_native_module_initialize(
const iree_vm_module_t* interface,
const iree_vm_native_module_descriptor_t* module_descriptor,
- iree_allocator_t allocator, iree_vm_module_t* module);
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t* module);
#ifdef __cplusplus
} // extern "C"
diff --git a/runtime/src/iree/vm/native_module_cc.h b/runtime/src/iree/vm/native_module_cc.h
index b13be56..dce141d 100644
--- a/runtime/src/iree/vm/native_module_cc.h
+++ b/runtime/src/iree/vm/native_module_cc.h
@@ -14,6 +14,7 @@
#include "iree/base/api.h"
#include "iree/base/internal/span.h"
#include "iree/base/status_cc.h"
+#include "iree/vm/instance.h"
#include "iree/vm/module.h"
#include "iree/vm/native_module_packing.h" // IWYU pragma: export
#include "iree/vm/stack.h"
@@ -67,12 +68,15 @@
template <typename State>
class NativeModule {
public:
- NativeModule(const char* name, uint32_t version, iree_allocator_t allocator,
+ NativeModule(const char* name, uint32_t version, iree_vm_instance_t* instance,
+ iree_allocator_t allocator,
iree::span<const NativeFunction<State>> dispatch_table)
: name_(name),
version_(version),
+ instance_(instance),
allocator_(allocator),
dispatch_table_(dispatch_table) {
+ iree_vm_instance_retain(instance);
IREE_CHECK_OK(iree_vm_module_initialize(&interface_, this));
interface_.destroy = NativeModule::ModuleDestroy;
interface_.name = NativeModule::ModuleName;
@@ -91,7 +95,7 @@
// TODO(benvanik): resume_call
}
- virtual ~NativeModule() = default;
+ virtual ~NativeModule() { iree_vm_instance_release(instance_); }
// C API module interface bound to this NativeModule instance.
iree_vm_module_t* interface() { return &interface_; }
@@ -278,6 +282,7 @@
const char* name_;
uint32_t version_;
+ iree_vm_instance_t* instance_;
const iree_allocator_t allocator_;
iree_vm_module_t interface_;
diff --git a/runtime/src/iree/vm/native_module_packing.h b/runtime/src/iree/vm/native_module_packing.h
index 7a87e0b..fc86c2d 100644
--- a/runtime/src/iree/vm/native_module_packing.h
+++ b/runtime/src/iree/vm/native_module_packing.h
@@ -15,7 +15,6 @@
#include "iree/base/api.h"
#include "iree/base/internal/span.h"
#include "iree/base/status_cc.h"
-#include "iree/vm/builtin_types.h"
#include "iree/vm/module.h"
#include "iree/vm/ref.h"
#include "iree/vm/ref_cc.h"
diff --git a/runtime/src/iree/vm/native_module_test.cc b/runtime/src/iree/vm/native_module_test.cc
index ee06bcc..5a4c489 100644
--- a/runtime/src/iree/vm/native_module_test.cc
+++ b/runtime/src/iree/vm/native_module_test.cc
@@ -32,9 +32,11 @@
// Create both modules shared instances. These are generally immutable and
// can be shared by multiple contexts.
iree_vm_module_t* module_a = nullptr;
- IREE_CHECK_OK(module_a_create(iree_allocator_system(), &module_a));
+ IREE_CHECK_OK(
+ module_a_create(instance_, iree_allocator_system(), &module_a));
iree_vm_module_t* module_b = nullptr;
- IREE_CHECK_OK(module_b_create(iree_allocator_system(), &module_b));
+ IREE_CHECK_OK(
+ module_b_create(instance_, iree_allocator_system(), &module_b));
// Create the context with both modules and perform runtime linkage.
// Imports from module_a -> module_b will be resolved and per-context state
diff --git a/runtime/src/iree/vm/native_module_test.h b/runtime/src/iree/vm/native_module_test.h
index 8ca958d..b6c1a16 100644
--- a/runtime/src/iree/vm/native_module_test.h
+++ b/runtime/src/iree/vm/native_module_test.h
@@ -123,13 +123,14 @@
/*functions=*/module_a_funcs_,
};
-static iree_status_t module_a_create(iree_allocator_t allocator,
+static iree_status_t module_a_create(iree_vm_instance_t* instance,
+ iree_allocator_t allocator,
iree_vm_module_t** out_module) {
// NOTE: this module has neither shared or per-context module state.
iree_vm_module_t interface;
IREE_RETURN_IF_ERROR(iree_vm_module_initialize(&interface, NULL));
return iree_vm_native_module_create(&interface, &module_a_descriptor_,
- allocator, out_module);
+ instance, allocator, out_module);
}
//===----------------------------------------------------------------------===//
@@ -274,7 +275,8 @@
/*functions=*/module_b_funcs_,
};
-static iree_status_t module_b_create(iree_allocator_t allocator,
+static iree_status_t module_b_create(iree_vm_instance_t* instance,
+ iree_allocator_t allocator,
iree_vm_module_t** out_module) {
// Allocate shared module state.
module_b_t* module = NULL;
@@ -307,5 +309,5 @@
interface.free_state = module_b_free_state;
interface.resolve_import = module_b_resolve_import;
return iree_vm_native_module_create(&interface, &module_b_descriptor_,
- allocator, out_module);
+ instance, allocator, out_module);
}
diff --git a/runtime/src/iree/vm/test/emitc/module_test.cc b/runtime/src/iree/vm/test/emitc/module_test.cc
index 8544003..d42fa57 100644
--- a/runtime/src/iree/vm/test/emitc/module_test.cc
+++ b/runtime/src/iree/vm/test/emitc/module_test.cc
@@ -41,7 +41,8 @@
namespace {
-typedef iree_status_t (*create_function_t)(iree_allocator_t,
+typedef iree_status_t (*create_function_t)(iree_vm_instance_t*,
+ iree_allocator_t,
iree_vm_module_t**);
struct TestParams {
@@ -121,8 +122,8 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
iree_vm_module_t* module_ = nullptr;
- IREE_CHECK_OK(
- test_params.create_function(iree_allocator_system(), &module_));
+ IREE_CHECK_OK(test_params.create_function(
+ instance_, iree_allocator_system(), &module_));
std::vector<iree_vm_module_t*> modules = {module_};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
diff --git a/samples/colab/edge_detection.ipynb b/samples/colab/edge_detection.ipynb
index 944a2bc..10fa776 100644
--- a/samples/colab/edge_detection.ipynb
+++ b/samples/colab/edge_detection.ipynb
@@ -404,11 +404,11 @@
"#@title Compile and prepare to test the edge detection module\n",
"\n",
"flatbuffer_blob = compile_str(compiler_module, target_backends=[\"vmvx\"], input_type=\"mhlo\")\n",
- "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
"\n",
"# Register the module with a runtime context.\n",
"config = ireert.Config(backend.driver)\n",
"ctx = ireert.SystemContext(config=config)\n",
+ "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n",
"ctx.add_vm_module(vm_module)"
],
"execution_count": 11,
diff --git a/samples/colab/low_level_invoke_function.ipynb b/samples/colab/low_level_invoke_function.ipynb
index 95ba289..7b9ede0 100644
--- a/samples/colab/low_level_invoke_function.ipynb
+++ b/samples/colab/low_level_invoke_function.ipynb
@@ -121,8 +121,7 @@
"\"\"\"\n",
"\n",
"# Compile using the vmvx (reference) target:\n",
- "compiled_flatbuffer = compile_str(SIMPLE_MUL_ASM, target_backends=[\"vmvx\"])\n",
- "vm_module = ireert.VmModule.from_flatbuffer(compiled_flatbuffer)"
+ "compiled_flatbuffer = compile_str(SIMPLE_MUL_ASM, target_backends=[\"vmvx\"])"
],
"execution_count": 4,
"outputs": []
@@ -141,6 +140,7 @@
"# Use the \"local-task\" CPU driver, which can load the vmvx executable:\n",
"config = ireert.Config(\"local-task\")\n",
"ctx = ireert.SystemContext(config=config)\n",
+ "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, compiled_flatbuffer)\n",
"ctx.add_vm_module(vm_module)\n",
"\n",
"# Invoke the function and print the result.\n",
diff --git a/samples/colab/tflite_text_classification.ipynb b/samples/colab/tflite_text_classification.ipynb
index 12178fe..437f492 100644
--- a/samples/colab/tflite_text_classification.ipynb
+++ b/samples/colab/tflite_text_classification.ipynb
@@ -391,11 +391,11 @@
"source": [
"# Compile the TOSA MLIR into a VM module.\n",
"compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
- "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n",
"\n",
"# Register the module with a runtime context.\n",
"config = iree_rt.Config(\"local-task\")\n",
"ctx = iree_rt.SystemContext(config=config)\n",
+ "vm_module = iree_rt.VmModule.from_flatbuffer(ctx.instance, compiled_flatbuffer)\n",
"ctx.add_vm_module(vm_module)\n",
"invoke_text_classification = ctx.modules.module[\"main\"]\n",
"\n",
diff --git a/samples/custom_module/README.md b/samples/custom_module/README.md
index eb09262..bece021 100644
--- a/samples/custom_module/README.md
+++ b/samples/custom_module/README.md
@@ -168,12 +168,12 @@
```c
// Ensure custom types are registered before loading modules that use them.
-// This only needs to be done once.
-IREE_CHECK_OK(iree_custom_module_register_types());
+// This only needs to be done once per instance.
+IREE_CHECK_OK(iree_custom_module_register_types(instance));
// Create the custom module that can be reused across contexts.
iree_vm_module_t* custom_module = NULL;
-IREE_CHECK_OK(iree_custom_module_create(allocator, &custom_module));
+IREE_CHECK_OK(iree_custom_module_create(instance, allocator, &custom_module));
// Create the context for this invocation reusing the loaded modules.
// Contexts hold isolated state and can be reused for multiple calls.
diff --git a/samples/custom_module/main.c b/samples/custom_module/main.c
index d740ff6..bf72aa0 100644
--- a/samples/custom_module/main.c
+++ b/samples/custom_module/main.c
@@ -48,12 +48,11 @@
// Ensure custom types are registered before loading modules that use them.
// This only needs to be done once.
- // TODO(benvanik): move to instance-based registration.
- IREE_CHECK_OK(iree_custom_module_register_types());
+ IREE_CHECK_OK(iree_custom_module_register_types(instance));
// Create the custom module that can be reused across contexts.
iree_vm_module_t* custom_module = NULL;
- IREE_CHECK_OK(iree_custom_module_create(allocator, &custom_module));
+ IREE_CHECK_OK(iree_custom_module_create(instance, allocator, &custom_module));
// Load the module from stdin or a file on disk.
// Applications can ship and load modules however they want (such as mapping
@@ -75,7 +74,7 @@
// Note that we let the module retain the file contents for as long as needed.
iree_vm_module_t* bytecode_module = NULL;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
- module_contents->const_buffer,
+ instance, module_contents->const_buffer,
iree_file_contents_deallocator(module_contents), allocator,
&bytecode_module));
diff --git a/samples/custom_module/module.cc b/samples/custom_module/module.cc
index 31ef0b9..8b5dd12 100644
--- a/samples/custom_module/module.cc
+++ b/samples/custom_module/module.cc
@@ -68,7 +68,8 @@
iree_allocator_free(string->allocator, ptr);
}
-extern "C" iree_status_t iree_custom_module_register_types(void) {
+extern "C" iree_status_t iree_custom_module_register_types(
+ iree_vm_instance_t* instance) {
if (iree_custom_string_descriptor.type) {
return iree_ok_status(); // Already registered.
}
@@ -185,11 +186,12 @@
// Note that while we are using C++ bindings internally we still expose the
// module as a C instance. This hides the details of our implementation.
extern "C" iree_status_t iree_custom_module_create(
- iree_allocator_t allocator, iree_vm_module_t** out_module) {
+ iree_vm_instance_t* instance, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
auto module = std::make_unique<CustomModule>(
- "custom", /*version=*/0, allocator,
+ "custom", /*version=*/0, instance, allocator,
iree::span<const vm::NativeFunction<CustomModuleState>>(
kCustomModuleFunctions));
*out_module = module.release()->interface();
diff --git a/samples/custom_module/module.h b/samples/custom_module/module.h
index 0e0df41..8fdc5c0 100644
--- a/samples/custom_module/module.h
+++ b/samples/custom_module/module.h
@@ -31,16 +31,15 @@
iree_custom_string_t** out_string);
// Registers types provided by the custom module.
-// TODO(benvanik): move to instance-based type registration (this would take
-// a iree_vm_instance_t).
-iree_status_t iree_custom_module_register_types(void);
+iree_status_t iree_custom_module_register_types(iree_vm_instance_t* instance);
// Creates a native custom module that can be reused in multiple contexts.
// The module itself may hold state that can be shared by all instantiated
// copies but it will require the module to provide synchronization; usually
// it's safer to just treat the module as immutable and keep state within the
// instantiated module states instead.
-iree_status_t iree_custom_module_create(iree_allocator_t allocator,
+iree_status_t iree_custom_module_create(iree_vm_instance_t* instance,
+ iree_allocator_t allocator,
iree_vm_module_t** out_module);
#ifdef __cplusplus
diff --git a/samples/dynamic_shapes/dynamic_shapes.ipynb b/samples/dynamic_shapes/dynamic_shapes.ipynb
index 1a5532a..392c816 100644
--- a/samples/dynamic_shapes/dynamic_shapes.ipynb
+++ b/samples/dynamic_shapes/dynamic_shapes.ipynb
@@ -319,9 +319,9 @@
"\n",
"from iree import runtime as ireert\n",
"\n",
- "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
"config = ireert.Config(\"local-task\")\n",
"ctx = ireert.SystemContext(config=config)\n",
+ "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n",
"ctx.add_vm_module(vm_module)"
],
"execution_count": 9,
diff --git a/samples/emitc_modules/add_module_test.cc b/samples/emitc_modules/add_module_test.cc
index 652702d..e769bfa 100644
--- a/samples/emitc_modules/add_module_test.cc
+++ b/samples/emitc_modules/add_module_test.cc
@@ -21,7 +21,8 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
iree_vm_module_t* add_module = nullptr;
- IREE_CHECK_OK(add_module_create(iree_allocator_system(), &add_module));
+ IREE_CHECK_OK(
+ add_module_create(instance_, iree_allocator_system(), &add_module));
std::vector<iree_vm_module_t*> modules = {add_module};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
diff --git a/samples/emitc_modules/import_module_test.cc b/samples/emitc_modules/import_module_test.cc
index 53f1464..4c93c50 100644
--- a/samples/emitc_modules/import_module_test.cc
+++ b/samples/emitc_modules/import_module_test.cc
@@ -21,10 +21,12 @@
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
iree_vm_module_t* module_a = nullptr;
- IREE_CHECK_OK(module_a_create(iree_allocator_system(), &module_a));
+ IREE_CHECK_OK(
+ module_a_create(instance_, iree_allocator_system(), &module_a));
iree_vm_module_t* module_b = nullptr;
- IREE_CHECK_OK(module_b_create(iree_allocator_system(), &module_b));
+ IREE_CHECK_OK(
+ module_b_create(instance_, iree_allocator_system(), &module_b));
// Note: order matters as module_a imports from module_b
std::vector<iree_vm_module_t*> modules = {module_b, module_a};
diff --git a/samples/py_custom_module/decode_secret_message.py b/samples/py_custom_module/decode_secret_message.py
index 4284f6e..42682b4 100644
--- a/samples/py_custom_module/decode_secret_message.py
+++ b/samples/py_custom_module/decode_secret_message.py
@@ -96,17 +96,17 @@
def compile():
- vmfb_contents = compiler.tools.compile_file(os.path.join(
- os.path.dirname(__file__), "main.mlir"),
- target_backends=["vmvx"])
- return rt.VmModule.from_flatbuffer(vmfb_contents)
+ return compiler.tools.compile_file(os.path.join(os.path.dirname(__file__),
+ "main.mlir"),
+ target_backends=["vmvx"])
def main():
print("Compiling...")
- main_module = compile()
+ vmfb_contents = compile()
print("Decoding secret message...")
config = rt.Config("local-sync")
+ main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents)
modules = config.default_vm_modules + (
create_tokenizer_module(),
main_module,
diff --git a/samples/simple_embedding/simple_embedding.c b/samples/simple_embedding/simple_embedding.c
index 9d5544b..d7585e0 100644
--- a/samples/simple_embedding/simple_embedding.c
+++ b/samples/simple_embedding/simple_embedding.c
@@ -41,7 +41,7 @@
"create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(
- iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
+ iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
iree_allocator_system(), &hal_module));
// Load bytecode module from the embedded data.
@@ -49,7 +49,7 @@
iree_vm_module_t* bytecode_module = NULL;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
- module_data, iree_allocator_null(), iree_allocator_system(),
+ instance, module_data, iree_allocator_null(), iree_allocator_system(),
&bytecode_module));
// Allocate a context that will hold the module state across invocations.
diff --git a/samples/static_library/create_bytecode_module.c b/samples/static_library/create_bytecode_module.c
index 4e5c5e7..b7b6b21 100644
--- a/samples/static_library/create_bytecode_module.c
+++ b/samples/static_library/create_bytecode_module.c
@@ -9,14 +9,15 @@
#include "samples/static_library/simple_mul_c.h"
// A function to create the bytecode module.
-iree_status_t create_module(iree_vm_module_t** module) {
+iree_status_t create_module(iree_vm_instance_t* instance,
+ iree_vm_module_t** out_module) {
const struct iree_file_toc_t* module_file_toc =
iree_samples_static_library_simple_mul_create();
iree_const_byte_span_t module_data =
iree_make_const_byte_span(module_file_toc->data, module_file_toc->size);
-
- return iree_vm_bytecode_module_create(module_data, iree_allocator_null(),
- iree_allocator_system(), module);
+ return iree_vm_bytecode_module_create(
+ instance, module_data, iree_allocator_null(),
+ iree_vm_instance_allocator(instance), out_module);
}
void print_success() { printf("static_library_run_bytecode passed\n"); }
diff --git a/samples/static_library/create_c_module.c b/samples/static_library/create_c_module.c
index 683aead..2dfc0fd 100644
--- a/samples/static_library/create_c_module.c
+++ b/samples/static_library/create_c_module.c
@@ -8,8 +8,10 @@
#include "samples/static_library/simple_mul_emitc.h"
// A function to create the C module.
-iree_status_t create_module(iree_vm_module_t** module) {
- return module_create(iree_allocator_system(), module);
+iree_status_t create_module(iree_vm_instance_t* instance,
+ iree_vm_module_t** out_module) {
+ return module_create(instance, iree_vm_instance_allocator(instance),
+ out_module);
}
void print_success() { printf("static_library_run_c passed\n"); }
diff --git a/samples/static_library/static_library_demo.c b/samples/static_library/static_library_demo.c
index 719802c..ceb9770 100644
--- a/samples/static_library/static_library_demo.c
+++ b/samples/static_library/static_library_demo.c
@@ -17,7 +17,8 @@
iree_hal_executable_library_version_t max_version,
const iree_hal_executable_environment_v0_t* environment);
// A function to create the bytecode or C module.
-extern iree_status_t create_module(iree_vm_module_t** module);
+extern iree_status_t create_module(iree_vm_instance_t* instance,
+ iree_vm_module_t** out_module);
extern void print_success();
@@ -95,7 +96,8 @@
iree_vm_module_t* module = NULL;
if (iree_status_is_ok(status)) {
- status = create_module(&module);
+ status =
+ create_module(iree_runtime_instance_vm_instance(instance), &module);
}
if (iree_status_is_ok(status)) {
diff --git a/samples/variables_and_state/variables_and_state.ipynb b/samples/variables_and_state/variables_and_state.ipynb
index 77c851d..5c0ca86 100644
--- a/samples/variables_and_state/variables_and_state.ipynb
+++ b/samples/variables_and_state/variables_and_state.ipynb
@@ -331,9 +331,9 @@
"\n",
"from iree import runtime as ireert\n",
"\n",
- "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
"config = ireert.Config(\"local-task\")\n",
"ctx = ireert.SystemContext(config=config)\n",
+ "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n",
"ctx.add_vm_module(vm_module)"
],
"execution_count": 9,
diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc
index 90626a1..f4047da 100644
--- a/tools/android/run_module_app/src/main.cc
+++ b/tools/android/run_module_app/src/main.cc
@@ -98,6 +98,7 @@
iree_vm_module_t* input_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
+ instance,
iree_make_const_byte_span((void*)invocation.module.data(),
invocation.module.size()),
iree_allocator_null(), iree_allocator_system(), &input_module));
@@ -108,8 +109,9 @@
iree_make_string_view(invocation.device.data(), invocation.device.size()),
iree_allocator_system(), &device));
iree_vm_module_t* hal_module = nullptr;
- IREE_RETURN_IF_ERROR(iree_hal_module_create(
- device, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_NONE,
+ iree_allocator_system(), &hal_module));
iree_vm_context_t* context = nullptr;
// Order matters. The input module will likely be dependent on the hal module.
diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc
index 21a85f3..8c4dd88 100644
--- a/tools/iree-check-module-main.cc
+++ b/tools/iree-check-module-main.cc
@@ -83,7 +83,8 @@
"creating instance");
iree_vm_module_t* check_module = nullptr;
- IREE_RETURN_IF_ERROR(iree_check_module_create(host_allocator, &check_module));
+ IREE_RETURN_IF_ERROR(
+ iree_check_module_create(instance, host_allocator, &check_module));
// TODO(benvanik): use --module_file= flag in order to reuse
// iree_tooling_load_module_from_flags.
@@ -98,7 +99,7 @@
}
iree_vm_module_t* main_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
- flatbuffer_contents->const_buffer,
+ instance, flatbuffer_contents->const_buffer,
iree_file_contents_deallocator(flatbuffer_contents), host_allocator,
&main_module));
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index 608faa4..62e1885 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -359,6 +359,7 @@
// with devices.
iree_vm_module_t* main_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
+ instance,
iree_make_const_byte_span((void*)flatbuffer_data.data(),
flatbuffer_data.size()),
iree_allocator_null(), iree_allocator_system(), &main_module));