Port iree.runtime to nanobind. (#14214)
I believe that this should be a no-op for users. There is one minor API
change (the MappedMemory class no longer implements the buffer protocol,
but I've seen no evidence that this was actually used since it was a
less functional way to get a host ndarray).
More adventurous use of nanobind is possible in the future (e.g. using
`ndarray` and `dlpack` interop for sharing across frameworks), but using
that will most likely necessitate API changes, which I was working to
avoid.
Aside from relatively mechanical differences from pybind11, the main
issue was that support for the buffer protocol and arrays was dropped in
nanobind. This required some direct coding against the C API to achieve
the same characteristics. I think this is actually an improvement, as the
pybind11 implementations of these features were neither efficient nor
obvious in what they were doing.
A build time dependency on `nanobind` is added. When building Python
wheels, this gets satisfied automatically. Otherwise, the docker images
have been updated to pre-install the necessary Python package. In
addition, there is now a build time dependency on NumPy headers, which
should already be installed (pybind11 vendored stripped down copies of
these headers in an effort to avoid this, but I opted to just do the
normal thing).
Nanobind's performance is [quite
compelling](https://nanobind.readthedocs.io/en/latest/benchmark.html)
and is owed to a combination of favoring more efficient binding styles
(adopting them in pybind11 would basically be a rewrite) and exclusive use
of the new Python 3.8+ vectorcall ABI. Since the runtime is performance
critical and the cost of Python calls is already quite visibly adding
overhead on traces, it makes sense to baseline on the most efficient
implementation.
In addition, the compile-time savings seem to be real and the build is
noticeably faster (this was not a primary consideration, just a nice
bonus).
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 7a8e0dd..18c8e8d 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -6,6 +6,11 @@
#include "./vm.h"
+#include <ios>
+#include <sstream>
+#include <unordered_set>
+
+#include "./buffer_interop.h"
#include "./status_utils.h"
#include "iree/base/api.h"
@@ -15,9 +20,8 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
-#include "pybind11/numpy.h"
-using namespace pybind11::literals;
+using namespace nanobind::literals;
namespace iree {
namespace python {
@@ -152,7 +156,7 @@
//------------------------------------------------------------------------------
VmContext VmContext::Create(VmInstance* instance,
- std::optional<std::vector<VmModule*>> modules) {
+ std::optional<std::vector<VmModule*>>& modules) {
IREE_TRACE_SCOPE_NAMED("VmContext::Create");
iree_vm_context_t* context;
if (!modules) {
@@ -228,8 +232,8 @@
VmModule VmModule::MMap(VmInstance* instance, std::string filepath,
py::object destroy_callback) {
IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
- auto mmap_module = py::module::import("mmap");
- auto open_func = py::module::import("io").attr("open");
+ auto mmap_module = py::module_::import_("mmap");
+ auto open_func = py::module_::import_("io").attr("open");
auto file_obj = open_func(filepath, "r+b");
// The signature of mmap is different on Windows vs others. On others,
// we use explicit flags and protection attributes for better control,
@@ -267,13 +271,12 @@
VmModule VmModule::WrapBuffer(VmInstance* instance, py::object buffer_obj,
py::object destroy_callback, bool close_buffer) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromAlignedMemory");
- auto py_buffer = py::cast<py::buffer>(buffer_obj);
- auto buffer_info = py_buffer.request();
- if (!iree_host_size_has_alignment((uintptr_t)buffer_info.ptr,
+ PyBufferRequest buffer_info(buffer_obj, PyBUF_SIMPLE);
+ if (!iree_host_size_has_alignment((uintptr_t)buffer_info.view().buf,
IREE_HAL_HEAP_BUFFER_ALIGNMENT)) {
std::stringstream err;
err << "VmModule.from_aligned_memory received an unaligned buffer. ";
- err << "Got 0x" << (void*)buffer_info.ptr << ", expected alignment ";
+ err << "Got 0x" << (void*)buffer_info.view().buf << ", expected alignment ";
err << IREE_HAL_HEAP_BUFFER_ALIGNMENT;
throw std::invalid_argument(err.str());
}
@@ -329,8 +332,8 @@
auto status = iree_vm_bytecode_module_create(
instance->raw_ptr(),
- {static_cast<const uint8_t*>(buffer_info.ptr),
- static_cast<iree_host_size_t>(buffer_info.size)},
+ {static_cast<const uint8_t*>(buffer_info.view().buf),
+ static_cast<iree_host_size_t>(buffer_info.view().len)},
deallocator, iree_allocator_system(), &module);
if (!iree_status_is_ok(status)) {
delete state;
@@ -347,31 +350,33 @@
VmModule VmModule::CopyBuffer(VmInstance* instance, py::object buffer_obj) {
IREE_TRACE_SCOPE_NAMED("VmModule::CopyBuffer");
auto alignment =
- py::cast<uintptr_t>(py::module::import("mmap").attr("PAGESIZE"));
- auto bytearray_ctor = py::module::import("builtins").attr("bytearray");
- auto src_buffer = py::cast<py::buffer>(buffer_obj);
- auto src_buffer_info = src_buffer.request();
- py::ssize_t src_buffer_size = src_buffer_info.itemsize * src_buffer_info.size;
+ py::cast<uintptr_t>(py::module_::import_("mmap").attr("PAGESIZE"));
+ auto bytearray_ctor = py::module_::import_("builtins").attr("bytearray");
+ PyBufferRequest src_buffer_info(buffer_obj, PyBUF_SIMPLE);
+ auto src_buffer_size = src_buffer_info.view().len;
// Need to allocate an extra page because there is no control at the Python
// level for the alignment it may have.
- auto dst_buffer =
- py::cast<py::buffer>(bytearray_ctor(src_buffer_size + alignment));
- auto dst_buffer_info = dst_buffer.request();
+ auto dst_buffer = bytearray_ctor(src_buffer_size + alignment);
+ PyBufferRequest dst_buffer_info(dst_buffer, PyBUF_SIMPLE);
void* dst_aligned =
- (void*)iree_host_align((uintptr_t)dst_buffer_info.ptr, alignment);
+ (void*)iree_host_align((uintptr_t)dst_buffer_info.view().buf, alignment);
uintptr_t dst_offset =
- (uintptr_t)dst_aligned - (uintptr_t)dst_buffer_info.ptr;
+ (uintptr_t)dst_aligned - (uintptr_t)dst_buffer_info.view().buf;
// Now create a memoryview over the unaligned bytearray and slice into that
// to get the aligned Python buffer.
- auto dst_slice = py::slice(dst_offset, dst_offset + src_buffer_size, 1);
- py::object dst_view = py::memoryview(dst_buffer);
+ auto dst_slice =
+ py::slice(py::cast(dst_offset), py::cast(dst_offset + src_buffer_size),
+ py::cast(1));
+
+ py::object dst_view = py::steal<py::object>(
+ PyMemoryView_GetContiguous(dst_buffer.ptr(), PyBUF_READ, 'C'));
py::object dst_view_aligned = dst_view[dst_slice];
// If any of the indexing math was wrong, Python exceptions will be raised
// above, so this is implicitly guarding the memcpy if it is done last.
- std::memcpy(dst_aligned, src_buffer_info.ptr, src_buffer_size);
+ std::memcpy(dst_aligned, src_buffer_info.view().buf, src_buffer_size);
return WrapBuffer(instance, std::move(dst_view_aligned),
/*destroy_callback=*/py::none(),
/*close_buffer=*/false);
@@ -380,15 +385,15 @@
VmModule VmModule::FromBuffer(VmInstance* instance, py::object buffer_obj,
bool warn_if_copy) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromBuffer");
- auto py_buffer = py::cast<py::buffer>(buffer_obj);
- auto buffer_info = py_buffer.request();
- if (iree_host_size_has_alignment((uintptr_t)buffer_info.ptr,
+ PyBufferRequest buffer_info(buffer_obj, PyBUF_SIMPLE);
+
+ if (iree_host_size_has_alignment((uintptr_t)buffer_info.view().buf,
IREE_HAL_HEAP_BUFFER_ALIGNMENT)) {
return WrapBuffer(instance, std::move(buffer_obj),
/*destroy_callback=*/py::none(), /*close_buffer=*/false);
} else {
if (warn_if_copy) {
- py::module::import("warnings")
+ py::module_::import_("warnings")
.attr("warn")(
"Making copy of unaligned VmModule buffer. It is recommended to "
"make this deterministic by calling `copy_buffer` to always make "
@@ -423,7 +428,7 @@
const char* const VmRef::kTypeAttr = "__iree_vm_type__";
py::object VmRef::Deref(py::object ref_object_class, bool optional) {
- py::object casted = ref_object_class.attr(kCastAttr)(*this);
+ py::object casted = ref_object_class.attr(kCastAttr)(this);
if (!optional && casted.is_none()) {
throw py::type_error("Cannot dereference to specific type");
}
@@ -517,7 +522,7 @@
} else if (iree_vm_variant_is_ref(v)) {
VmRef ref;
iree_vm_ref_retain(&v.ref, &ref.ref());
- return py::cast(ref, py::return_value_policy::move);
+ return py::cast(ref, py::rv_policy::move);
}
throw RaiseValueError("Unsupported VM to Python Type Conversion");
@@ -642,7 +647,7 @@
}
VmRef ref;
iree_vm_ref_retain(&v.ref, &ref.ref());
- return py::cast(ref, py::return_value_policy::move);
+ return py::cast(ref, py::rv_policy::move);
}
py::object VmVariantList::GetAsObject(int index, py::object clazz) {
@@ -763,7 +768,7 @@
return s;
}
-void SetupVmBindings(pybind11::module m) {
+void SetupVmBindings(nanobind::module_ m) {
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
@@ -771,37 +776,67 @@
.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
.export_values();
- auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
+ auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer");
VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type,
iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
iree_vm_buffer_isa);
+ // Implement the buffer protocol with low-level API.
+ {
+ static PyBufferProcs buffer_procs = {
+ // It is not legal to raise exceptions from these callbacks.
+ +[](PyObject* raw_self, Py_buffer* view, int flags) -> int {
+ // Cast must succeed due to invariants.
+ auto self = py::cast<VmBuffer*>(py::handle(raw_self));
+ if (view == NULL) {
+ PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
+ return -1;
+ }
+
+ Py_INCREF(raw_self);
+ view->obj = raw_self;
+ view->buf = self->raw_ptr()->data.data;
+ view->len = self->raw_ptr()->data.data_length;
+ view->readonly =
+ !(self->raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE);
+ view->itemsize = 1;
+ view->format = (char*)"B"; // Byte
+ view->ndim = 1;
+ view->shape = nullptr;
+ view->strides = nullptr;
+ view->suboffsets = nullptr;
+ view->internal = nullptr;
+ return 0;
+ },
+ +[](PyObject* self_obj, Py_buffer* view) -> void {
+
+ },
+ };
+ auto heap_type = reinterpret_cast<PyHeapTypeObject*>(vm_buffer.ptr());
+ assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
+ "must be heap type");
+ heap_type->as_buffer = buffer_procs;
+ }
+
vm_buffer
- .def(py::init([](iree_host_size_t length, iree_host_size_t alignment,
- bool is_mutable) {
- iree_vm_buffer_access_t access = 0;
- if (is_mutable) {
- access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
- }
- iree_vm_buffer_t* raw_buffer;
- CheckApiStatus(
- iree_vm_buffer_create(access, length, alignment,
- iree_allocator_system(), &raw_buffer),
- "Error creating buffer");
- return VmBuffer::StealFromRawPtr(raw_buffer);
- }),
- py::arg("length"), py::arg("alignment") = 0,
- py::arg("mutable") = true)
- .def_buffer([](VmBuffer& self) -> py::buffer_info {
- return py::buffer_info(
- /*ptr=*/self.raw_ptr()->data.data,
- /*itemsize=*/sizeof(uint8_t),
- /*format=*/py::format_descriptor<uint8_t>::format(),
- /*ndim=*/1,
- /*shape=*/{self.raw_ptr()->data.data_length},
- /*strides=*/{1},
- /*readonly=*/
- !(self.raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE));
- })
+ .def(
+ "__init__",
+ [](VmBuffer* self, iree_host_size_t length,
+ iree_host_size_t alignment, bool is_mutable) {
+ iree_vm_buffer_access_t access = 0;
+ if (is_mutable) {
+ access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
+ }
+ iree_vm_buffer_t* raw_buffer;
+ CheckApiStatus(
+ iree_vm_buffer_create(access, length, alignment,
+ iree_allocator_system(), &raw_buffer),
+ "Error creating buffer");
+
+ new (self) VmBuffer();
+ *self = VmBuffer::StealFromRawPtr(raw_buffer);
+ },
+ py::arg("length"), py::arg("alignment") = 0,
+ py::arg("mutable") = true)
.def("__repr__", [](VmBuffer& self) {
std::stringstream ss;
ss << "<VmBuffer size " << self.raw_ptr()->data.data_length << " at 0x"
@@ -816,8 +851,14 @@
iree_vm_list_deref, iree_vm_list_isa);
vm_list
// User Methods.
- .def(py::init(&VmVariantList::Create))
- .def_property_readonly("size", &VmVariantList::size)
+ .def(
+ "__init__",
+ [](VmVariantList* self, iree_host_size_t capacity) {
+ new (self) VmVariantList();
+ *self = VmVariantList::Create(capacity);
+ },
+ py::arg("capacity"))
+ .def_prop_ro("size", &VmVariantList::size)
.def("__len__", &VmVariantList::size)
.def("get_as_ref", &VmVariantList::GetAsRef)
.def("get_as_object", &VmVariantList::GetAsObject)
@@ -832,24 +873,22 @@
.def("__repr__", &VmVariantList::DebugString);
py::class_<iree_vm_function_t>(m, "VmFunction")
- .def_readonly("linkage", &iree_vm_function_t::linkage)
- .def_readonly("ordinal", &iree_vm_function_t::ordinal)
- .def_property_readonly("name",
- [](iree_vm_function_t& self) {
- iree_string_view_t name =
- iree_vm_function_name(&self);
- return py::str(name.data, name.size);
- })
- .def_property_readonly("module_name",
- [](iree_vm_function_t& self) {
- iree_string_view_t name =
- iree_vm_module_name(self.module);
- return py::str(name.data, name.size);
- })
- .def_property_readonly("reflection",
- [](iree_vm_function_t& self) {
- return GetFunctionReflectionDict(self);
- })
+ .def_ro("linkage", &iree_vm_function_t::linkage)
+ .def_ro("ordinal", &iree_vm_function_t::ordinal)
+ .def_prop_ro("name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name = iree_vm_function_name(&self);
+ return py::str(name.data, name.size);
+ })
+ .def_prop_ro("module_name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name = iree_vm_module_name(self.module);
+ return py::str(name.data, name.size);
+ })
+ .def_prop_ro("reflection",
+ [](iree_vm_function_t& self) {
+ return GetFunctionReflectionDict(self);
+ })
.def("__repr__", [](iree_vm_function_t& self) {
iree_string_view_t name = iree_vm_function_name(&self);
std::string repr("<VmFunction ");
@@ -865,13 +904,22 @@
return repr;
});
- py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
-
+ py::class_<VmInstance>(m, "VmInstance").def("__init__", [](VmInstance* self) {
+ new (self) VmInstance();
+ *self = VmInstance::Create();
+ });
py::class_<VmContext>(m, "VmContext")
- .def(py::init(&VmContext::Create), py::arg("instance"),
- py::arg("modules") = std::optional<std::vector<VmModule*>>())
+ .def(
+ "__init__",
+ [](VmContext* self, VmInstance* instance,
+ std::optional<std::vector<VmModule*>> modules) {
+ new (self) VmContext();
+ *self = VmContext::Create(instance, modules);
+ },
+ py::arg("instance"),
+ py::arg("modules") = std::optional<std::vector<VmModule*>>())
.def("register_modules", &VmContext::RegisterModules)
- .def_property_readonly("context_id", &VmContext::context_id)
+ .def_prop_ro("context_id", &VmContext::context_id)
.def("invoke", &VmContext::Invoke);
py::class_<VmModule>(m, "VmModule")
@@ -891,40 +939,40 @@
.def_static("mmap", &VmModule::MMap, py::arg("instance"),
py::arg("filepath"), py::arg("destroy_callback") = py::none(),
kMMapDocstring)
- .def_property_readonly("name", &VmModule::name)
- .def_property_readonly("version",
- [](VmModule& self) {
- iree_vm_module_signature_t sig =
- iree_vm_module_signature(self.raw_ptr());
- return sig.version;
- })
+ .def_prop_ro("name", &VmModule::name)
+ .def_prop_ro("version",
+ [](VmModule& self) {
+ iree_vm_module_signature_t sig =
+ iree_vm_module_signature(self.raw_ptr());
+ return sig.version;
+ })
.def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
- .def_property_readonly(
+ .def_prop_ro(
"stashed_flatbuffer_blob",
[](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
- .def_property_readonly(
- "function_names",
- [](VmModule& self) {
- py::list names;
- iree_vm_module_signature_t sig =
- iree_vm_module_signature(self.raw_ptr());
- for (size_t ordinal = 0; ordinal < sig.export_function_count;
- ++ordinal) {
- iree_vm_function_t f;
- auto status = iree_vm_module_lookup_function_by_ordinal(
- self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
- if (iree_status_is_not_found(status)) {
- iree_status_ignore(status);
- break;
- }
- CheckApiStatus(status, "Error enumerating module");
- iree_string_view_t fname = iree_vm_function_name(&f);
- py::str name(fname.data, fname.size);
- names.append(name);
- }
- return names;
- })
+ .def_prop_ro("function_names",
+ [](VmModule& self) {
+ py::list names;
+ iree_vm_module_signature_t sig =
+ iree_vm_module_signature(self.raw_ptr());
+ for (size_t ordinal = 0;
+ ordinal < sig.export_function_count; ++ordinal) {
+ iree_vm_function_t f;
+ auto status = iree_vm_module_lookup_function_by_ordinal(
+ self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ ordinal, &f);
+ if (iree_status_is_not_found(status)) {
+ iree_status_ignore(status);
+ break;
+ }
+ CheckApiStatus(status, "Error enumerating module");
+ iree_string_view_t fname = iree_vm_function_name(&f);
+ py::str name(fname.data, fname.size);
+ names.append(name);
+ }
+ return names;
+ })
.def("__repr__", [](VmModule& self) {
std::string repr("<VmModule ");
iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
@@ -959,8 +1007,7 @@
.def("deref", &VmRef::Deref, py::arg("value"),
py::arg("optional") = false)
.def("__repr__", &VmRef::ToString)
- .def_property_readonly(VmRef::kRefAttr,
- [](py::object self) { return self; })
+ .def_prop_ro(VmRef::kRefAttr, [](py::object self) { return self; })
.def("__eq__",
[](VmRef& self, VmRef& other) {
return self.ref().ptr == other.ref().ptr;