Start python bindings for vm2.
Most things don't work yet because they need the HAL module, which is still in flight.
PiperOrigin-RevId: 286416263
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index d0d531d..e3d710a 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -127,6 +127,8 @@
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//iree/compiler/Utils",
+ "//iree/vm2",
+ "//iree/vm2:bytecode_module",
"@local_config_mlir//:IR",
"//iree/base:api",
"//iree/base:status",
@@ -203,3 +205,14 @@
"//bindings/python/pyiree",
],
)
+
+py_test(
+    name = "vm_test",
+    srcs = ["vm_test.py"],
+    # Dependency order is intentionally unchanged; NUMPY_DEPS come first.
+    deps = NUMPY_DEPS + [
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree",
+    ],
+    # The test module is written for Python 3 only.
+    python_version = "PY3",
+)
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h
index c39177c..24e4907 100644
--- a/bindings/python/pyiree/binding.h
+++ b/bindings/python/pyiree/binding.h
@@ -69,6 +69,22 @@
return {free_fn, holder};
}
+  // Returns a free-only allocator whose `self` is a heap-allocated holder that
+  // owns a reference to |blob|; invoking its `free` deletes the holder and
+  // thereby drops that reference. `alloc` is left null.
+  // The extra heap allocation could be avoided, but this is not expected to be
+  // a performance-sensitive path.
+  static iree_allocator_t CreateDeallocator(std::shared_ptr<OpaqueBlob> blob) {
+    struct Holder {
+      std::shared_ptr<OpaqueBlob> blob;
+    };
+    auto* owned_ref = new Holder{std::move(blob)};
+    auto release_fn = +([](void* self, void* /*ptr*/) -> iree_status_t {
+      delete static_cast<Holder*>(self);
+      return IREE_STATUS_OK;
+    });
+    iree_allocator_t deallocator = {owned_ref /* self */, nullptr /* alloc */,
+                                    release_fn /* free */};
+    return deallocator;
+  }
+
protected:
void* data_;
size_t size_;
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index ec9ab9b..b69082d 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -14,7 +14,10 @@
#include "bindings/python/pyiree/vm.h"
+#include "absl/types/optional.h"
#include "bindings/python/pyiree/status_utils.h"
+#include "iree/base/api.h"
+#include "iree/vm2/instance.h"
namespace iree {
namespace python {
@@ -29,8 +32,120 @@
return RtModule::CreateRetained(module);
}
+//------------------------------------------------------------------------------
+// VmInstance
+//------------------------------------------------------------------------------
+
+VmInstance VmInstance::Create() {
+ iree_vm_instance_t* instance;
+ auto status = iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &instance);
+ CheckApiStatus(status, "Error creating instance");
+ return VmInstance::CreateRetained(instance);
+}
+
+//------------------------------------------------------------------------------
+// VmContext
+//------------------------------------------------------------------------------
+
+// Creates a VM context on |instance|.
+//   - |modules| == nullopt: a context that allows modules to be registered
+//     later via RegisterModules().
+//   - |modules| present: a context created with exactly that closed set of
+//     modules.
+// Raises ValueError if |instance| or any module is null (pybind11 converts a
+// Python None into nullptr for pointer arguments). Other failures are
+// reported through CheckApiStatus.
+VmContext VmContext::Create(VmInstance* instance,
+                            absl::optional<std::vector<VmModule*>> modules) {
+  if (!instance) {
+    throw py::value_error("instance must not be None");
+  }
+
+  iree_vm_context_t* context = nullptr;
+  if (!modules) {
+    // Simple create with open allowed modules.
+    auto status = iree_vm_context_create(instance->raw_ptr(),
+                                         IREE_ALLOCATOR_SYSTEM, &context);
+    CheckApiStatus(status, "Error creating vm context");
+  } else {
+    // Closed set of modules.
+    absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+    module_handles.reserve(modules->size());
+    for (VmModule* module : *modules) {
+      if (!module) {
+        throw py::value_error("modules must not contain None");
+      }
+      module_handles.push_back(module->raw_ptr());
+    }
+    auto status = iree_vm_context_create_with_modules(
+        instance->raw_ptr(), module_handles.data(), module_handles.size(),
+        IREE_ALLOCATOR_SYSTEM, &context);
+    CheckApiStatus(status, "Error creating vm context with modules");
+  }
+
+  CHECK(context);
+  return VmContext::CreateRetained(context);
+}
+
+// Registers |modules| with this context (intended for contexts created
+// without a fixed module set). Raises ValueError if any entry is null (a
+// Python None); other failures are reported through CheckApiStatus.
+void VmContext::RegisterModules(std::vector<VmModule*> modules) {
+  absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+  module_handles.reserve(modules.size());
+  for (VmModule* module : modules) {
+    if (!module) {
+      throw py::value_error("modules must not contain None");
+    }
+    module_handles.push_back(module->raw_ptr());
+  }
+  // data() rather than &module_handles[0]: indexing an empty vector is
+  // undefined behavior, whereas data() is always valid to pass with size 0.
+  auto status = iree_vm_context_register_modules(
+      raw_ptr(), module_handles.data(), module_handles.size());
+  CheckApiStatus(status, "Error registering modules");
+}
+
+//------------------------------------------------------------------------------
+// VmModule
+//------------------------------------------------------------------------------
+
+// Creates a bytecode VM module over the bytes of |flatbuffer_blob|. A
+// deallocator holding an extra reference to the blob is passed to the module
+// so the bytes outlive this call. Raises ValueError for a null blob (pybind11
+// converts a Python None into an empty shared_ptr); other failures are
+// reported through CheckApiStatus.
+VmModule VmModule::FromFlatbufferBlob(
+    std::shared_ptr<OpaqueBlob> flatbuffer_blob) {
+  if (!flatbuffer_blob) {
+    throw py::value_error("flatbuffer blob must not be None");
+  }
+
+  iree_vm_module_t* module = nullptr;
+  auto deallocator = OpaqueBlob::CreateDeallocator(flatbuffer_blob);
+  auto status = iree_vm_bytecode_module_create(
+      {static_cast<const uint8_t*>(flatbuffer_blob->data()),
+       flatbuffer_blob->size()},
+      deallocator, IREE_ALLOCATOR_SYSTEM, &module);
+  if (status != IREE_STATUS_OK) {
+    // NOTE(review): this assumes the module does not take ownership of the
+    // deallocator when creation fails, so the blob reference is released here.
+    deallocator.free(deallocator.self, nullptr);
+  }
+
+  CheckApiStatus(status, "Error creating vm module from flatbuffer");
+  return VmModule::CreateRetained(module);
+}
+
+// Looks up |name| with the given |linkage|. A NOT_FOUND status maps to an
+// empty optional; any other error goes through CheckApiStatus.
+absl::optional<iree_vm_function_t> VmModule::LookupFunction(
+    const std::string& name, iree_vm_function_linkage_t linkage) {
+  iree_vm_function_t function;
+  const auto lookup_status = iree_vm_module_lookup_function(
+      raw_ptr(), linkage, {name.data(), name.size()}, &function);
+  if (lookup_status != IREE_STATUS_NOT_FOUND) {
+    CheckApiStatus(lookup_status, "Error looking up function");
+    return function;
+  }
+  return absl::nullopt;
+}
+
void SetupVmBindings(pybind11::module m) {
+  // Deprecated: the VM1 module loader, kept for existing callers.
m.def("create_module_from_blob", CreateModuleFromBlob);
+
+  // Function linkage kinds; export_values() also exposes INTERNAL, IMPORT and
+  // EXPORT directly on the module.
+  py::enum_<iree_vm_function_linkage_t> linkage_enum(m, "Linkage");
+  linkage_enum.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL);
+  linkage_enum.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT);
+  linkage_enum.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT);
+  linkage_enum.export_values();
+
+  // Variant lists are mostly opaque to Python: only construction and the
+  // current size are exposed.
+  py::class_<VmVariantList> variant_list_class(m, "VmVariantList");
+  variant_list_class.def(py::init(&VmVariantList::Create));
+  variant_list_class.def_property_readonly("size", &VmVariantList::size);
+
+  // Read-only view of the function handle returned by lookup_function.
+  py::class_<iree_vm_function_t> function_class(m, "VmFunction");
+  function_class.def_readonly("ordinal", &iree_vm_function_t::ordinal);
+  function_class.def_readonly("linkage", &iree_vm_function_t::linkage);
+
+  py::class_<VmInstance> instance_class(m, "VmInstance");
+  instance_class.def(py::init(&VmInstance::Create));
+
+  // `modules` defaults to None (absl::nullopt).
+  py::class_<VmContext> context_class(m, "VmContext");
+  context_class.def(py::init(&VmContext::Create), py::arg("instance"),
+                    py::arg("modules") = absl::nullopt);
+  context_class.def("register_modules", &VmContext::RegisterModules);
+  context_class.def_property_readonly("context_id", &VmContext::context_id);
+
+  py::class_<VmModule> module_class(m, "VmModule");
+  module_class.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob);
+  module_class.def("lookup_function", &VmModule::LookupFunction,
+                   py::arg("name"),
+                   py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT);
}
} // namespace python
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h
index ffb79b1..ee094b3 100644
--- a/bindings/python/pyiree/vm.h
+++ b/bindings/python/pyiree/vm.h
@@ -15,15 +15,125 @@
#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
+#include "absl/types/optional.h"
#include "bindings/python/pyiree/binding.h"
#include "bindings/python/pyiree/rt.h"
+#include "iree/base/api.h"
#include "iree/vm/api.h"
+#include "iree/vm2/bytecode_module.h"
+#include "iree/vm2/context.h"
+#include "iree/vm2/instance.h"
+#include "iree/vm2/invocation.h"
+#include "iree/vm2/module.h"
namespace iree {
namespace python {
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+// Retain/release hooks for iree_vm_instance_t.
+template <>
+struct ApiPtrAdapter<iree_vm_instance_t> {
+  static void Retain(iree_vm_instance_t* instance) {
+    iree_vm_instance_retain(instance);
+  }
+  static void Release(iree_vm_instance_t* instance) {
+    iree_vm_instance_release(instance);
+  }
+};
+
+// Retain/release hooks for iree_vm_context_t.
+template <>
+struct ApiPtrAdapter<iree_vm_context_t> {
+  static void Retain(iree_vm_context_t* context) {
+    iree_vm_context_retain(context);
+  }
+  static void Release(iree_vm_context_t* context) {
+    iree_vm_context_release(context);
+  }
+};
+
+// Retain/release hooks for iree_vm_module_t.
+template <>
+struct ApiPtrAdapter<iree_vm_module_t> {
+  static void Retain(iree_vm_module_t* module) {
+    iree_vm_module_retain(module);
+  }
+  static void Release(iree_vm_module_t* module) {
+    iree_vm_module_release(module);
+  }
+};
+
+// Retain/release hooks for iree_vm_invocation_t.
+template <>
+struct ApiPtrAdapter<iree_vm_invocation_t> {
+  static void Retain(iree_vm_invocation_t* invocation) {
+    iree_vm_invocation_retain(invocation);
+  }
+  static void Release(iree_vm_invocation_t* invocation) {
+    iree_vm_invocation_release(invocation);
+  }
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+// Wraps an iree_vm_instance_t (lifetime managed through ApiRefCounted).
+class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
+ public:
+  // Creates a fresh VM instance.
+  static VmInstance Create();
+};
+
+// Wraps an iree_vm_module_t (lifetime managed through ApiRefCounted).
+class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
+ public:
+  // Creates a bytecode module from the flatbuffer held in |flatbuffer_blob|.
+  static VmModule FromFlatbufferBlob(
+      std::shared_ptr<OpaqueBlob> flatbuffer_blob);
+
+  // Finds the function called |name| with the given |linkage|; yields an
+  // empty optional when no such function exists.
+  absl::optional<iree_vm_function_t> LookupFunction(
+      const std::string& name, iree_vm_function_linkage_t linkage);
+};
+
+// Wraps an iree_vm_context_t (lifetime managed through ApiRefCounted).
+class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
+ public:
+  // Creates a context on |instance|. Passing |modules| makes the context
+  // static: it is created with exactly those modules, further registration is
+  // disallowed, and it may be more efficient.
+  static VmContext Create(VmInstance* instance,
+                          absl::optional<std::vector<VmModule*>> modules);
+
+  // Registers additional modules. Only valid for non-static contexts (i.e.
+  // those created without modules).
+  void RegisterModules(std::vector<VmModule*> modules);
+
+  // Unique id for this context, as reported by iree_vm_context_id.
+  int context_id() const {
+    return static_cast<int>(iree_vm_context_id(raw_ptr()));
+  }
+};
+
+// Wraps an iree_vm_invocation_t. No operations are exposed yet.
+class VmInvocation
+    : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {};
+
void SetupVmBindings(pybind11::module m);
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
+// Move-only owner of an iree_vm_variant_list_t. The list is freed when the
+// owner is destroyed or move-assigned over.
+class VmVariantList {
+ public:
+  VmVariantList() : list_(nullptr) {}
+  ~VmVariantList() { Reset(); }
+
+  VmVariantList(VmVariantList&& other) noexcept : list_(other.list_) {
+    other.list_ = nullptr;
+  }
+
+  VmVariantList& operator=(VmVariantList&& other) noexcept {
+    if (this != &other) {
+      Reset();
+      list_ = other.list_;
+      other.list_ = nullptr;
+    }
+    return *this;
+  }
+
+  VmVariantList(const VmVariantList&) = delete;
+  VmVariantList& operator=(const VmVariantList&) = delete;
+
+  // Allocates an empty list able to hold |capacity| elements. Allocation
+  // errors are reported through CheckApiStatus.
+  static VmVariantList Create(iree_host_size_t capacity) {
+    iree_vm_variant_list_t* list = nullptr;
+    CheckApiStatus(
+        iree_vm_variant_list_alloc(capacity, IREE_ALLOCATOR_SYSTEM, &list),
+        "Error allocating variant list");
+    return VmVariantList(list);
+  }
+
+  // Number of elements currently in the list; 0 for an empty (e.g.
+  // moved-from) wrapper.
+  iree_host_size_t size() const {
+    return list_ ? iree_vm_variant_list_size(list_) : 0;
+  }
+
+ private:
+  explicit VmVariantList(iree_vm_variant_list_t* list) : list_(list) {}
+
+  // Frees the owned list, if any. Must not throw: it runs from the destructor,
+  // where an escaping exception would call std::terminate. A failing free
+  // status is therefore dropped rather than passed to CheckApiStatus.
+  void Reset() noexcept {
+    if (list_) {
+      iree_status_t status = iree_vm_variant_list_free(list_);
+      static_cast<void>(status);
+      list_ = nullptr;
+    }
+  }
+
+  iree_vm_variant_list_t* list_;
+};
+
} // namespace python
} // namespace iree
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/vm_test.py
new file mode 100644
index 0000000..338d64d
--- /dev/null
+++ b/bindings/python/pyiree/vm_test.py
@@ -0,0 +1,73 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=unused-variable
+# pylint: disable=g-unreachable-test-method
+
+from absl.testing import absltest
+import pyiree
+
+
+def create_simple_mul_module():
+ ctx = pyiree.CompilerContext()
+ input_module = ctx.parse_asm("""
+ func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+ attributes { iree.module.export } {
+ %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ """)
+ binary = input_module.compile()
+ m = pyiree.binding.vm.VmModule.from_flatbuffer(binary)
+ return m
+
+
+class VmTest(absltest.TestCase):
+
+ def test_variant_list(self):
+ l = pyiree.binding.vm.VmVariantList(5)
+ print(l)
+ self.assertEqual(l.size, 0)
+
+ def test_context_id(self):
+ instance = pyiree.binding.vm.VmInstance()
+ context1 = pyiree.binding.vm.VmContext(instance)
+ context2 = pyiree.binding.vm.VmContext(instance)
+ self.assertGreater(context2.context_id, context1.context_id)
+
+ def disabled_test_module_basics(self):
+ m = create_simple_mul_module()
+ f = m.lookup_function("simple_mul")
+ self.assertGreater(f.ordinal, 0)
+ notfound = m.lookup_function("notfound")
+ self.assertIs(notfound, None)
+
+ def disabled_test_dynamic_module_context(self):
+ instance = pyiree.binding.vm.VmInstance()
+ context = pyiree.binding.vm.VmContext(instance)
+ m = create_simple_mul_module()
+ context.register_modules([m])
+
+ def disabled_test_static_module_context(self):
+ m = create_simple_mul_module()
+ print(m)
+ instance = pyiree.binding.vm.VmInstance()
+ print(instance)
+ context = pyiree.binding.vm.VmContext(instance, modules=[m])
+ print(context)
+
+
+if __name__ == "__main__":
+ absltest.main()