Add module dependencies via python bindings (#13472)
This change allows loading custom modules through the Python bindings.
Example:
```python
ctx.add_module_dependency("cudnn")
ctx.add_vm_module(create_simple_mul_module(ctx.instance)) # depends on cuDNN module
```
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index defabc6..714aa53 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -54,6 +54,7 @@
iree::hal::drivers
iree::hal::utils::allocators
iree::modules::hal
+ iree::tooling::modules
iree::vm
iree::vm::bytecode::module
)
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py
index bdbe13e..2423d2f 100644
--- a/runtime/bindings/python/iree/runtime/system_api.py
+++ b/runtime/bindings/python/iree/runtime/system_api.py
@@ -240,6 +240,20 @@
def modules(self) -> BoundModules:
return self._bound_modules
+  def add_module_dependency(self, name, minimum_version=0):
+    """Resolves a module dependency by name and registers it in this context.
+
+    Args:
+      name: Name of the module to resolve (for example "cudnn").
+      minimum_version: Minimum acceptable version of the resolved module.
+    """
+    # Mirrors add_vm_modules: only a dynamic context accepts new modules.
+    assert self._is_dynamic, (
+        "Cannot 'add_module_dependency' on a static context")
+    resolved_module = _binding.VmModule.resolve_module_dependency(
+        self._config.vm_instance, name, minimum_version)
+    self._vm_context.register_modules([resolved_module])
+
def add_vm_modules(self, vm_modules):
assert self._is_dynamic, "Cannot 'add_module' on a static context"
for m in vm_modules:
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index b9c0a1b..5f9acfa 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -13,6 +13,7 @@
// summaries of HAL objects in lists. We should have a better way of doing this
// dynamically vs hard depending on a type switch here.
#include "iree/modules/hal/module.h"
+#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
#include "pybind11/numpy.h"
@@ -130,6 +131,26 @@
// VmModule
//------------------------------------------------------------------------------
+// Resolves the required module dependency `name` (at least `minimum_version`)
+// with the IREE tooling module resolver and returns the resolved module.
+VmModule VmModule::ResolveModuleDependency(VmInstance* instance,
+                                           const std::string& name,
+                                           uint32_t minimum_version) {
+  IREE_TRACE_SCOPE0("VmModule::ResolveModuleDependency");
+  iree_vm_module_dependency_t dependency = {
+      iree_make_cstring_view(name.c_str()), minimum_version,
+      IREE_VM_MODULE_DEPENDENCY_FLAG_REQUIRED};
+  iree_vm_module_t* module = nullptr;
+  // Check the status before inspecting `module`: it may be null on failure,
+  // and asserting first could abort the process instead of raising to Python.
+  CheckApiStatus(iree_tooling_resolve_module_dependency(
+                     instance->raw_ptr(), &dependency,
+                     iree_allocator_system(), &module),
+                 "Error resolving module dependency");
+  if (!module) throw py::value_error("Module dependency resolved to no module");
+  return VmModule::StealFromRawPtr(module);
+}
+
VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
@@ -638,6 +659,8 @@
.def("invoke", &VmContext::Invoke);
py::class_<VmModule>(m, "VmModule")
+ .def_static("resolve_module_dependency",
+ &VmModule::ResolveModuleDependency)
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
.def_property_readonly("name", &VmModule::name)
.def_property_readonly("version",
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index ed3661e..4b097e7 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -129,6 +129,10 @@
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
public:
+ static VmModule ResolveModuleDependency(VmInstance* instance,
+ const std::string& name,
+ uint32_t minimum_version);
+
static VmModule FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object);