Finish wiring vm2 up to the Python bindings.
* The example I have isn't producing the right result (the output is all zeros), so it is commented out. It will be easier to diagnose once this lands.
* Fixes a bug in merging reflection attrs (redundant nested I and R spans).
* Stores reflection attributes on the corresponding internal function, since that is what we get when we look up the export.
* Fixes reflection lookup to consult internal functions.
* Once I have this working, I'll introduce a high-level Python API to initialize the system, load modules, and invoke functions without all of the verbosity.
* I'm not yet happy with how this layers together, but given the number of changes in flight, it is acceptable for now. Once we delete all of the old rt code, it should be more obvious how to refactor the bindings to be simpler.
PiperOrigin-RevId: 286710166
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index c3fc37a..14578fe 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -15,9 +15,12 @@
#include "bindings/python/pyiree/vm.h"
#include "absl/types/optional.h"
+#include "bindings/python/pyiree/function_abi.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/base/api.h"
#include "iree/modules/hal/hal_module.h"
+#include "iree/vm2/invocation.h"
+#include "iree/vm2/module.h"
namespace iree {
namespace python {
@@ -95,6 +98,42 @@
CheckApiStatus(status, "Error registering modules");
}
+// Builds a FunctionAbi for the VM function `f` from its reflection attributes.
+//
+// Every reflection attribute of `f` is read by increasing index until the
+// runtime reports IREE_STATUS_NOT_FOUND. The collected (key, value) pairs are
+// then exposed to FunctionAbi::Create through a key -> value lookup callback.
+//
+// Args:
+//   device: HAL device forwarded to FunctionAbi::Create.
+//   host_type_factory: Factory moved into FunctionAbi::Create.
+//   f: Function whose reflection attributes describe its ABI.
+// Returns:
+//   The FunctionAbi produced by FunctionAbi::Create.
+// Errors:
+//   Each attribute-query status other than NOT_FOUND is checked with
+//   CheckApiStatus ("Error getting reflection attr").
+//   NOTE(review): presumably CheckApiStatus raises on failure; see
+//   status_utils.h.
+std::unique_ptr<FunctionAbi> VmContext::CreateFunctionAbi(
+ HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
+ iree_vm_function_t f) {
+ // Resolve attrs: fetch (key, value) pairs by index until NOT_FOUND.
+ absl::InlinedVector<std::pair<iree_string_view_t, iree_string_view_t>, 4>
+ attrs;
+ for (int i = 0;; ++i) {
+ // Append an empty slot so the runtime writes the key/value straight into
+ // it. The slot is discarded once the index runs past the last attribute.
+ attrs.push_back({});
+ auto status = iree_vm_get_function_reflection_attr(
+ f, i, &attrs.back().first, &attrs.back().second);
+ if (status == IREE_STATUS_NOT_FOUND) {
+ // No attribute at index i: drop the unused slot and stop.
+ attrs.pop_back();
+ break;
+ }
+ CheckApiStatus(status, "Error getting reflection attr");
+ }
+ // Linear scan that returns the value of the first attribute whose key
+ // matches, or nullopt if none does.
+ // NOTE(review): `attrs` is captured by reference, so this callback is only
+ // valid while this function runs. This assumes FunctionAbi::Create does not
+ // retain it; confirm. The key/value character data is not owned here
+ // (presumably by the module); verify its lifetime if FunctionAbi keeps the
+ // returned views.
+ auto attr_lookup =
+ [&attrs](absl::string_view key) -> absl::optional<absl::string_view> {
+ for (const auto& attr : attrs) {
+ absl::string_view found_key(attr.first.data, attr.first.size);
+ absl::string_view found_value(attr.second.data, attr.second.size);
+ if (found_key == key) return found_value;
+ }
+ return absl::nullopt;
+ };
+
+ return FunctionAbi::Create(device, std::move(host_type_factory), attr_lookup);
+}
+
+// Invokes the VM function `f` in this context.
+//
+// Calls iree_vm_invoke on this context's raw handle. It passes the raw lists
+// behind `inputs` and `outputs` (presumably the argument and result lists) and
+// uses the system allocator. The returned status is checked with
+// CheckApiStatus ("Error invoking function").
+// NOTE(review): the third argument (nullptr) is presumably the optional
+// invocation policy; confirm against iree/vm2/invocation.h.
+void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
+ VmVariantList& outputs) {
+ CheckApiStatus(iree_vm_invoke(raw_ptr(), f, nullptr, inputs.raw_ptr(),
+ outputs.raw_ptr(), IREE_ALLOCATOR_SYSTEM),
+ "Error invoking function");
+}
+
//------------------------------------------------------------------------------
// VmModule
//------------------------------------------------------------------------------
@@ -149,8 +188,8 @@
.def_property_readonly("size", &VmVariantList::size);
py::class_<iree_vm_function_t>(m, "VmFunction")
- .def_readonly("ordinal", &iree_vm_function_t::ordinal)
- .def_readonly("linkage", &iree_vm_function_t::linkage);
+ .def_readonly("linkage", &iree_vm_function_t::linkage)
+ .def_readonly("ordinal", &iree_vm_function_t::ordinal);
py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
@@ -158,7 +197,10 @@
.def(py::init(&VmContext::Create), py::arg("instance"),
py::arg("modules") = absl::nullopt)
.def("register_modules", &VmContext::RegisterModules)
- .def_property_readonly("context_id", &VmContext::context_id);
+ .def_property_readonly("context_id", &VmContext::context_id)
+ .def("create_function_abi", &VmContext::CreateFunctionAbi,
+ py::arg("device"), py::arg("host_type_factory"), py::arg("f"))
+ .def("invoke", &VmContext::Invoke);
py::class_<VmModule>(m, "VmModule")
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)