Port iree.runtime to nanobind. (#14214)
I believe that this should be a no-op for users. There is one minor API
change (the MappedMemory class no longer implements the buffer protocol,
but I've seen no evidence that this was actually used since it was a
less functional way to get a host ndarray).
More adventurous use of nanobind is possible in the future (i.e. using
`ndarray` and `dlpack` interop for sharing across frameworks), but using
that will most likely necessitate API changes, which I was working to
avoid.
Aside from relatively mechanical differences from pybind11, the main
issue was that buffer protocol and array support were dropped in
nanobind. This required some direct coding against the C API to achieve
the same characteristics. I think this is actually an improvement, as the
pybind11 implementations of these features were neither efficient nor
obvious about what they were doing.
A build time dependency on `nanobind` is added. When building Python
wheels, this gets satisfied automatically. Otherwise, the docker images
have been updated to pre-install the necessary Python package. In
addition, there is now a build time dependency on NumPy headers, which
should already be installed (pybind11 vendored stripped down copies of
these headers in an effort to avoid this, but I opted to just do the
normal thing).
Nanobind's performance is [quite
compelling](https://nanobind.readthedocs.io/en/latest/benchmark.html),
which it owes to a combination of favoring more efficient binding styles
(adopting them in pybind11 would basically be a rewrite) and exclusive
use of the new Python 3.8+ vectorcall ABI. Since the runtime is
performance-critical and the cost of Python calls is already visibly adding
overhead in traces, it makes sense to baseline on the most efficient implementation.
In addition, the compile-time savings seem to be real and the build is
noticeably faster (this was not a primary consideration, just a nice
bonus).
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index f0c3fd3..8f4bac8 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -4,9 +4,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-set(NUMPY_DEPS "")
-set(PYBIND_COPTS "-fexceptions")
-set(PYBIND_EXTENSION_COPTS "-fvisibility=hidden")
+if(NOT nanobind_FOUND)
+ # nanobind requires Python >= 3.8.
+ find_package(Python 3.8 COMPONENTS Interpreter Development.Module NumPy REQUIRED)
+ find_package(nanobind CONFIG QUIET)
+ if(NOT nanobind_FOUND)
+ execute_process(
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ OUTPUT_VARIABLE NB_DIR
+ RESULT_VARIABLE RC)
+ if(RC AND NOT RC EQUAL 0)
+ message(WARNING "Probing for nanobind failed. Please install the project's Python dependencies or '${Python_EXECUTABLE} -m pip install nanobind'")
+ endif()
+ list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+ endif()
+ find_package(nanobind CONFIG REQUIRED)
+endif()
set(_PYTHON_EXTRA_SRCS)
set(_EXTRA_INSTALL_TOOL_TARGETS)
@@ -23,39 +37,78 @@
# Package
################################################################################
-iree_pyext_module(
- NAME
- PyExtRt
- MODULE_NAME iree/_runtime
- SRCS
- "binding.h"
- "initialize_module.cc"
- "invoke.h"
- "invoke.cc"
- "hal.h"
- "hal.cc"
- "py_module.h"
- "py_module.cc"
- "status_utils.cc"
- "status_utils.h"
- "vm.h"
- "vm.cc"
- UNIX_LINKER_SCRIPT
- "unix_version.lds"
- DEFINES
- # Pybind code seems to be incompatible with C++ allocation tracing
- # hooks so disable it.
- IREE_TRACING_HOOK_CPP_NEW_DELETE=0
- DEPS
- iree::base
- iree::base::internal::flags
- iree::hal
- iree::hal::drivers
- iree::hal::utils::allocators
- iree::modules::hal
- iree::tooling::modules
- iree::vm
- iree::vm::bytecode::module
+# nanobind requires both RTTI and Exceptions, and it does not know that
+# we have disabled them globally, so turn them back on. Since this is
+# *the only* place in the codebase where we do this, just inline here.
+# Note that this is playing with fire and the extension code is structured
+# so as not to cause problems with RTTI cross-module issues.
+iree_select_compiler_opts(_RTTI_AND_EXCEPTION_COPTS
+ CLANG_OR_GCC
+ "-frtti"
+ "-fexceptions"
+ MSVC_OR_CLANG_CL
+ # Configure exception handling for standard C++ behavior.
+ # - /EHs enables C++ catch-style exceptions
+ # - /EHc breaks unwinding across extern C boundaries, dramatically reducing
+ # unwind table size and associated exception handling overhead as the
+ # compiler can assume no exception will ever be thrown within any function
+ # annotated with extern "C".
+ # https://docs.microsoft.com/en-us/cpp/build/reference/eh-exception-handling-model
+ "/EHsc"
+ # Configure RTTI generation.
+ # - /GR - Enable generation of RTTI (default)
+ # - /GR- - Disables generation of RTTI
+ # https://docs.microsoft.com/en-us/cpp/build/reference/gr-enable-run-time-type-information?view=msvc-160
+ "/GR"
+)
+
+nanobind_add_module(iree_runtime_bindings_python_PyExtRt
+ NB_STATIC LTO
+ "binding.h"
+ "initialize_module.cc"
+ "invoke.h"
+ "invoke.cc"
+ "hal.h"
+ "hal.cc"
+ "numpy_interop.h"
+ "numpy_interop.cc"
+ "py_module.h"
+ "py_module.cc"
+ "status_utils.cc"
+ "status_utils.h"
+ "vm.h"
+ "vm.cc"
+)
+
+target_link_libraries(iree_runtime_bindings_python_PyExtRt
+ PRIVATE
+ iree::base
+ iree::base::internal::flags
+ iree::hal
+ iree::hal::drivers
+ iree::hal::utils::allocators
+ iree::modules::hal
+ iree::tooling::modules
+ iree::vm
+ iree::vm::bytecode::module
+
+ Python::NumPy
+)
+
+target_compile_options(iree_runtime_bindings_python_PyExtRt
+ PRIVATE
+ ${IREE_DEFAULT_COPTS}
+ # Default COPTS disable exceptions/rtti. Re-enable them.
+ ${_RTTI_AND_EXCEPTION_COPTS}
+)
+target_compile_definitions(iree_runtime_bindings_python_PyExtRt
+ PRIVATE
+ IREE_TRACING_HOOK_CPP_NEW_DELETE=0
+)
+
+set_target_properties(
+ iree_runtime_bindings_python_PyExtRt
+ PROPERTIES OUTPUT_NAME "iree/_runtime"
)
iree_py_library(
@@ -78,7 +131,7 @@
"iree/runtime/scripts/iree_run_module/__main__.py"
${_PYTHON_EXTRA_SRCS}
PYEXT_DEPS
- ::PyExtRt
+ iree_runtime_bindings_python_PyExtRt
)
iree_symlink_tool(
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index feb1722..71168be 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -7,17 +7,23 @@
#ifndef IREE_BINDINGS_PYTHON_IREE_BINDING_H_
#define IREE_BINDINGS_PYTHON_IREE_BINDING_H_
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/string_view.h>
+#include <nanobind/stl/vector.h>
+
#include <optional>
#include <vector>
#include "iree/base/api.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
namespace iree {
namespace python {
-namespace py = pybind11;
+namespace py = nanobind;
+using namespace nanobind::literals;
template <typename T>
struct ApiPtrAdapter {};
@@ -94,6 +100,22 @@
T* instance_;
};
+// Pybind11 had an isinstance helper for Python objects. Nanobind doesn't.
+inline bool is_instance_of_type_object(py::handle inst,
+ py::handle type_object) {
+ int rc = PyObject_IsInstance(inst.ptr(), type_object.ptr());
+ if (rc == -1) {
+ throw py::python_error();
+ }
+ return static_cast<bool>(rc);
+}
+
+// Nanobind's tuple class has a default constructor that creates a nullptr
+// tuple, which is not really what one wants.
+inline py::object create_empty_tuple() {
+ return py::steal(py::handle(PyTuple_New(0)));
+}
+
} // namespace python
} // namespace iree
diff --git a/runtime/bindings/python/buffer_interop.h b/runtime/bindings/python/buffer_interop.h
new file mode 100644
index 0000000..20ee352
--- /dev/null
+++ b/runtime/bindings/python/buffer_interop.h
@@ -0,0 +1,42 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Nanobind removed buffer protocol interop in favor of a new and improved
+// ndarray thingy. This thingy is mostly better for a subset of cases but is
+// not great for just generically accessing chunks of memory. For cases where
+// we do the latter (mapping files, etc), we just have some helpers over the
+// low level Python buffer protocol to ease the transition.
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_BUFFER_INTEROP_H_
+#define IREE_BINDINGS_PYTHON_IREE_BUFFER_INTEROP_H_
+
+#include "./binding.h"
+
+namespace iree::python {
+
+// Represents a Py_buffer obtained via PyObject_GetBuffer() and terminated via
+// PyBuffer_Release().
+class PyBufferRequest {
+ public:
+ PyBufferRequest(py::object &exporter, int flags) {
+ int rc = PyObject_GetBuffer(exporter.ptr(), &view_, flags);
+ if (rc != 0) {
+ throw py::python_error();
+ }
+ }
+ ~PyBufferRequest() { PyBuffer_Release(&view_); }
+ PyBufferRequest(const PyBufferRequest &) = delete;
+ void operator=(const PyBufferRequest &) = delete;
+
+ Py_buffer &view() { return view_; }
+
+ private:
+ Py_buffer view_;
+};
+
+} // namespace iree::python
+
+#endif // IREE_BINDINGS_PYTHON_IREE_BUFFER_INTEROP_H_
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index eb65c5b..1da899f 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,12 +6,12 @@
#include "./hal.h"
+#include "./numpy_interop.h"
#include "./vm.h"
#include "iree/base/internal/path.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/allocators.h"
#include "iree/modules/hal/module.h"
-#include "pybind11/numpy.h"
namespace iree {
namespace python {
@@ -96,7 +96,7 @@
// Acquire the backing buffer and setup RAII release.
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
// The GetBuffer call is required to set an appropriate error.
- throw py::error_already_set();
+ throw py::python_error();
}
PyBufferReleaser py_view_releaser(py_view);
@@ -117,7 +117,7 @@
if (!element_type) {
return py::cast(HalBuffer::StealFromRawPtr(hal_buffer),
- py::return_value_policy::move);
+ py::rv_policy::move);
}
// Create the buffer_view. (note that numpy shape is ssize_t, so we need to
@@ -135,7 +135,7 @@
iree_hal_buffer_release(hal_buffer);
return py::cast(HalBufferView::StealFromRawPtr(hal_buffer_view),
- py::return_value_policy::move);
+ py::rv_policy::move);
}
//------------------------------------------------------------------------------
@@ -177,7 +177,7 @@
std::string repr("<HalBuffer ");
AppendHalBufferRepr(raw_ptr(), repr);
repr.append(">");
- return py::str(repr);
+ return py::str(py::cast(repr));
}
//------------------------------------------------------------------------------
@@ -205,36 +205,32 @@
repr.append(", ");
AppendHalBufferRepr(iree_hal_buffer_view_buffer(raw_ptr()), repr);
repr.append(">");
- return py::str(repr);
+ return py::str(py::cast(repr));
}
//------------------------------------------------------------------------------
// HalDevice
//------------------------------------------------------------------------------
-void HalDevice::BeginProfiling(const py::kwargs& kwargs) {
+void HalDevice::BeginProfiling(std::optional<std::string> mode,
+ std::optional<std::string> file_path) {
iree_hal_device_profiling_options_t options;
memset(&options, 0, sizeof(options));
options.mode = IREE_HAL_DEVICE_PROFILING_MODE_QUEUE_OPERATIONS;
- if (kwargs.contains("mode")) {
- auto mode_str = kwargs["mode"].cast<std::string>();
- if (mode_str == "queue") {
+ if (mode) {
+ if (*mode == "queue") {
options.mode = IREE_HAL_DEVICE_PROFILING_MODE_QUEUE_OPERATIONS;
- } else if (mode_str == "dispatch") {
+ } else if (*mode == "dispatch") {
options.mode = IREE_HAL_DEVICE_PROFILING_MODE_DISPATCH_COUNTERS;
- } else if (mode_str == "executable") {
+ } else if (*mode == "executable") {
options.mode = IREE_HAL_DEVICE_PROFILING_MODE_EXECUTABLE_COUNTERS;
} else {
throw RaiseValueError("unrecognized profiling mode");
}
}
- std::string file_path = kwargs.contains("file_path")
- ? kwargs["file_path"].cast<std::string>()
- : "";
- options.file_path = !file_path.empty() ? file_path.c_str() : NULL;
-
+ options.file_path = file_path ? file_path->c_str() : nullptr;
CheckApiStatus(iree_hal_device_profiling_begin(raw_ptr(), &options),
"starting device profiling");
}
@@ -315,17 +311,17 @@
// Configures |device| based on flags before returning it to the user.
static iree_status_t ConfigureDevice(iree_hal_device_t* device,
- const py::kwargs& kwargs) {
+ std::optional<py::list> allocators) {
// Optionally wrap the base device allocator with caching/pooling.
// Doing this here satisfies the requirement that no buffers have been
// allocated yet - if we returned the device without doing this the caller
// can more easily break the rules.
- if (kwargs.contains("allocators")) {
+ if (allocators) {
// NOTE: we need to pass string views that point to the std::string storage.
// We do that in two passes because as we grow spec_storage it may
// reallocate itself and invalidate the pointers - only after we're done
// can we capture them in views.
- auto spec_list = py::cast<py::list>(kwargs["allocators"]);
+ auto& spec_list = *allocators;
std::vector<std::string> spec_storage;
spec_storage.reserve(spec_list.size());
for (auto item : spec_list) {
@@ -343,18 +339,18 @@
return iree_ok_status();
}
-HalDevice HalDriver::CreateDefaultDevice(const py::kwargs& kwargs) {
+HalDevice HalDriver::CreateDefaultDevice(std::optional<py::list> allocators) {
iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_default_device(
raw_ptr(), iree_allocator_system(), &device),
"Error creating default device");
- CheckApiStatus(ConfigureDevice(device, kwargs),
+ CheckApiStatus(ConfigureDevice(device, allocators),
"Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id,
- const py::kwargs& kwargs) {
+ std::optional<py::list> allocators) {
// Since the device ids are supposed to be opaque, we need to verify
// them by querying available devices.
py::list available_devices = QueryAvailableDevices();
@@ -376,7 +372,7 @@
msg.append("Device id ");
msg.append(std::to_string(device_id));
msg.append(" not found. Available devices: ");
- msg.append(py::repr(available_devices));
+ msg.append(py::cast<std::string>(py::repr(available_devices)));
throw std::invalid_argument(std::move(msg));
}
@@ -386,13 +382,13 @@
raw_ptr(), device_id, params.size(), ¶ms.front(),
iree_allocator_system(), &device),
"Error creating default device");
- CheckApiStatus(ConfigureDevice(device, kwargs),
+ CheckApiStatus(ConfigureDevice(device, allocators),
"Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri,
- const py::kwargs& kwargs) {
+ std::optional<py::list> allocators) {
iree_hal_device_t* device;
iree_string_view_t device_uri_sv{
device_uri.data(), static_cast<iree_host_size_t>(device_uri.size())};
@@ -400,82 +396,12 @@
iree_hal_driver_create_device_by_uri(raw_ptr(), device_uri_sv,
iree_allocator_system(), &device),
"Error creating device");
- CheckApiStatus(ConfigureDevice(device, kwargs),
+ CheckApiStatus(ConfigureDevice(device, allocators),
"Error configuring the device");
return HalDevice::StealFromRawPtr(device);
}
//------------------------------------------------------------------------------
-// Enum helpers
-//------------------------------------------------------------------------------
-
-namespace {
-
-py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
- // See:
- // * https://numpy.org/doc/stable/reference/arrays.dtypes.html
- // * https://docs.python.org/3/c-api/arg.html#numbers
- //
- // Single letter codes can be ambiguous across platforms, so prefer explicit
- // bit depth values, ("Type strings: Any string in numpy.sctypeDict.keys()").
- // See https://github.com/pybind/pybind11/issues/1908
- const char* dtype_string;
- switch (element_type) {
- case IREE_HAL_ELEMENT_TYPE_BOOL_8:
- dtype_string = "?";
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_8:
- case IREE_HAL_ELEMENT_TYPE_SINT_8:
- dtype_string = "int8";
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_8:
- dtype_string = "uint8";
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_16:
- case IREE_HAL_ELEMENT_TYPE_SINT_16:
- dtype_string = "int16";
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_16:
- dtype_string = "uint16";
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_32:
- case IREE_HAL_ELEMENT_TYPE_SINT_32:
- dtype_string = "int32";
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_32:
- dtype_string = "uint32";
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_64:
- case IREE_HAL_ELEMENT_TYPE_SINT_64:
- dtype_string = "int64";
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_64:
- dtype_string = "uint64";
- break;
- case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- dtype_string = "float16";
- break;
- case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
- dtype_string = "float32";
- break;
- case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
- dtype_string = "float64";
- break;
- case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
- dtype_string = "complex64";
- break;
- case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
- dtype_string = "complex128";
- break;
- default:
- throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
- }
- return py::dtype(dtype_string);
-}
-
-} // namespace
-
-//------------------------------------------------------------------------------
// HAL module
//------------------------------------------------------------------------------
@@ -492,7 +418,7 @@
// Bindings
//------------------------------------------------------------------------------
-void SetupHalBindings(pybind11::module m) {
+void SetupHalBindings(nanobind::module_ m) {
py::dict driver_cache;
// Built-in module creation.
@@ -621,36 +547,46 @@
.value("COMPLEX_64", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64)
.value("COMPLEX_128", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128)
.export_values()
- .def_static("map_to_dtype", &MapElementTypeToDType);
+ .def_static("map_to_dtype", [](iree_hal_element_type_t element_type) {
+ int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum(element_type);
+ return numpy::DescrNewFromType(typenum);
+ });
py::class_<HalDevice>(m, "HalDevice")
- .def_property_readonly(
+ .def_prop_ro(
"allocator",
[](HalDevice& self) {
return HalAllocator::BorrowFromRawPtr(self.allocator());
},
py::keep_alive<0, 1>())
- .def("begin_profiling", &HalDevice::BeginProfiling)
+ .def("begin_profiling", &HalDevice::BeginProfiling,
+ py::arg("mode") = py::none(), py::arg("file_path") = py::none())
.def("end_profiling", &HalDevice::EndProfiling);
py::class_<HalDriver>(m, "HalDriver")
.def_static("query", &HalDriver::Query)
+
+ // All 'create_device' functions take optional kwargs that should be kept
+ // in sync.
.def("create_default_device", &HalDriver::CreateDefaultDevice,
- py::keep_alive<0, 1>())
- .def("create_device", &HalDriver::CreateDevice, py::keep_alive<0, 1>())
+ py::keep_alive<0, 1>(), py::arg("allocators") = py::none())
+ .def("create_device", &HalDriver::CreateDevice, py::keep_alive<0, 1>(),
+ py::arg("device_id"), py::arg("allocators") = py::none())
.def("create_device_by_uri", &HalDriver::CreateDeviceByURI,
- py::keep_alive<0, 1>())
+ py::keep_alive<0, 1>(), py::arg("device_uri"),
+ py::arg("allocators") = py::none())
.def(
"create_device",
[](HalDriver& self, py::dict device_info,
- const py::kwargs& kwargs) -> HalDevice {
+ std::optional<py::list> allocators) -> HalDevice {
// Alias of create_device that takes a dict as returned from
// query_available_devices for convenience.
auto device_id =
py::cast<iree_hal_device_id_t>(device_info["device_id"]);
- return self.CreateDevice(device_id, kwargs);
+ return self.CreateDevice(device_id, allocators);
},
- py::keep_alive<0, 1>())
+ py::keep_alive<0, 1>(), py::arg("device_info"),
+ py::arg("allocators") = py::none())
.def("query_available_devices", &HalDriver::QueryAvailableDevices);
m.def(
@@ -667,12 +603,11 @@
CheckApiStatus(iree_hal_allocator_trim(self.raw_ptr()),
"Error trim()'ing HAL allocator");
})
- .def_property_readonly(
+ .def_prop_ro(
"has_statistics",
[](HalAllocator& self) -> bool { return IREE_STATISTICS_ENABLE; })
- .def_property_readonly("statistics", &HalAllocator::QueryStatistics)
- .def_property_readonly("formatted_statistics",
- &HalAllocator::FormattedStatistics)
+ .def_prop_ro("statistics", &HalAllocator::QueryStatistics)
+ .def_prop_ro("formatted_statistics", &HalAllocator::FormattedStatistics)
.def(
"query_buffer_compatibility",
[](HalAllocator& self, int memory_type, int allowed_usage,
@@ -726,37 +661,43 @@
iree_hal_buffer_view_retain_ref,
iree_hal_buffer_view_deref, iree_hal_buffer_view_isa);
hal_buffer_view.def("map", HalMappedMemory::Create, py::keep_alive<0, 1>())
- .def_property_readonly(
- "shape",
- [](HalBufferView& self) {
- iree_host_size_t rank =
- iree_hal_buffer_view_shape_rank(self.raw_ptr());
- auto* dims = iree_hal_buffer_view_shape_dims(self.raw_ptr());
- py::list result;
- for (iree_host_size_t i = 0; i < rank; ++i) {
- result.append(dims[i]);
- }
- return result;
- })
- .def_property_readonly(
- "element_type",
- [](HalBufferView& self) {
- return iree_hal_buffer_view_element_type(self.raw_ptr());
- })
+ .def_prop_ro("shape",
+ [](HalBufferView& self) {
+ iree_host_size_t rank =
+ iree_hal_buffer_view_shape_rank(self.raw_ptr());
+ auto* dims =
+ iree_hal_buffer_view_shape_dims(self.raw_ptr());
+ py::list result;
+ for (iree_host_size_t i = 0; i < rank; ++i) {
+ result.append(dims[i]);
+ }
+ return result;
+ })
+ .def_prop_ro("element_type",
+ [](HalBufferView& self) {
+ return iree_hal_buffer_view_element_type(self.raw_ptr());
+ })
.def("__repr__", &HalBufferView::Repr);
- py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
- .def_buffer(&HalMappedMemory::ToBufferInfo)
- .def("asarray",
- [](HalMappedMemory& self, std::vector<iree_host_size_t> shape,
- py::object dtype) {
- py::object py_mapped_memory = py::cast(self);
- return py::array(std::move(dtype), shape,
- self.mapped_memory().contents.data,
- std::move(py_mapped_memory) /* base */);
- });
+ py::class_<HalMappedMemory>(m, "MappedMemory")
+ .def(
+ "asarray",
+ [](HalMappedMemory* self, std::vector<iree_host_size_t> shape,
+ py::object dtype) {
+ py::object py_mapped_memory = py::cast(self);
+ static_assert(sizeof(shape[0]) == sizeof(intptr_t),
+ "size_t not of same size as intptr_t");
+ int typenum = numpy::TypenumFromDescr(dtype);
+ return numpy::SimpleNewFromData(
+ shape.size(), reinterpret_cast<intptr_t const*>(shape.data()),
+ typenum, self->mapped_memory().contents.data, py_mapped_memory);
+ },
+ py::arg("shape"), py::arg("element_type"));
- py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
+ py::class_<HalShape>(m, "Shape")
+ .def("__init__", [](HalShape* self, std::vector<iree_hal_dim_t> indices) {
+ new (self) HalShape(indices);
+ });
}
} // namespace python
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index ff15dfa..bae4d96 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -86,7 +86,8 @@
return iree_hal_device_allocator(raw_ptr());
}
- void BeginProfiling(const py::kwargs& kwargs);
+ void BeginProfiling(std::optional<std::string> mode,
+ std::optional<std::string> file_path);
void EndProfiling();
};
@@ -97,11 +98,11 @@
py::dict& driver_cache);
py::list QueryAvailableDevices();
- HalDevice CreateDefaultDevice(const py::kwargs& kwargs);
+ HalDevice CreateDefaultDevice(std::optional<py::list> allocators);
HalDevice CreateDevice(iree_hal_device_id_t device_id,
- const py::kwargs& kwargs);
+ std::optional<py::list> allocators);
HalDevice CreateDeviceByURI(std::string& device_uri,
- const py::kwargs& kwargs);
+ std::optional<py::list> allocators);
};
class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
@@ -116,10 +117,8 @@
struct HalShape {
public:
- static HalShape FromIntVector(std::vector<iree_hal_dim_t> indices) {
- HalShape s;
- s.s = {indices.begin(), indices.end()};
- return s;
+ HalShape(std::vector<iree_hal_dim_t>& indices) {
+ s = {indices.begin(), indices.end()};
}
std::vector<iree_hal_dim_t> s;
@@ -193,31 +192,6 @@
return HalMappedMemory(mapped_memory, bv.raw_ptr());
}
- py::buffer_info ToBufferInfo() {
- std::vector<iree_hal_dim_t> shape(iree_hal_buffer_view_shape_rank(bv_));
- CheckApiStatus(
- iree_hal_buffer_view_shape(bv_, shape.size(), shape.data(), nullptr),
- "Error getting buffer view shape");
- iree_hal_element_type_t element_type =
- iree_hal_buffer_view_element_type(bv_);
- int32_t element_size = iree_hal_element_dense_byte_count(element_type);
- std::vector<py::ssize_t> dims(shape.size());
- for (int i = 0; i < shape.size(); ++i) {
- dims[i] = shape[i];
- }
- std::vector<py::ssize_t> strides(shape.size());
- if (!strides.empty()) {
- strides[shape.size() - 1] = element_size;
- for (int i = shape.size() - 2; i >= 0; --i) {
- strides[i] = strides[i + 1] * shape[i + 1];
- }
- }
-
- return py::buffer_info(mapped_memory_.contents.data, element_size,
- py::format_descriptor<float>::format(), shape.size(),
- dims, strides);
- }
-
iree_hal_buffer_mapping_t& mapped_memory() { return mapped_memory_; }
private:
@@ -225,7 +199,7 @@
iree_hal_buffer_view_t* bv_ = nullptr;
};
-void SetupHalBindings(pybind11::module m);
+void SetupHalBindings(nanobind::module_ m);
} // namespace python
} // namespace iree
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index 3e76943..c3eac96 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -7,6 +7,7 @@
#include "./binding.h"
#include "./hal.h"
#include "./invoke.h"
+#include "./numpy_interop.h"
#include "./py_module.h"
#include "./status_utils.h"
#include "./vm.h"
@@ -16,7 +17,9 @@
namespace iree {
namespace python {
-PYBIND11_MODULE(_runtime, m) {
+NB_MODULE(_runtime, m) {
+ numpy::InitializeNumPyInterop();
+
IREE_CHECK_OK(iree_hal_register_all_available_drivers(
iree_hal_driver_registry_default()));
@@ -29,7 +32,7 @@
m.def("parse_flags", [](py::args py_flags) {
std::vector<std::string> alloced_flags;
alloced_flags.push_back("python");
- for (auto &py_flag : py_flags) {
+ for (py::handle py_flag : py_flags) {
alloced_flags.push_back(py::cast<std::string>(py_flag));
}
diff --git a/runtime/bindings/python/invoke.cc b/runtime/bindings/python/invoke.cc
index c3b5df1..cafa2ad 100644
--- a/runtime/bindings/python/invoke.cc
+++ b/runtime/bindings/python/invoke.cc
@@ -6,6 +6,8 @@
#include "./invoke.h"
+#include <functional>
+
#include "./hal.h"
#include "./vm.h"
#include "iree/base/api.h"
@@ -74,18 +76,19 @@
py::str kAttrBufferView = py::str("_buffer_view");
// Module 'numpy'.
- py::module &numpy_module() { return numpy_module_; }
+ py::module_ &numpy_module() { return numpy_module_; }
py::object &runtime_module() {
if (!runtime_module_) {
- runtime_module_ = py::module::import("iree.runtime");
+ runtime_module_ = py::module_::import_("iree.runtime");
}
return *runtime_module_;
}
- py::module &array_interop_module() {
+ py::module_ &array_interop_module() {
if (!array_interop_module_) {
- array_interop_module_ = py::module::import("iree.runtime.array_interop");
+ array_interop_module_ =
+ py::module_::import_("iree.runtime.array_interop");
}
return *array_interop_module_;
}
@@ -97,7 +100,7 @@
return *device_array_type_;
}
- py::type &hal_buffer_view_type() { return hal_buffer_view_type_; }
+ py::type_object &hal_buffer_view_type() { return hal_buffer_view_type_; }
py::object MapElementAbiTypeToDtype(py::object &element_abi_type) {
try {
@@ -165,12 +168,13 @@
IREE_TRACE_SCOPE_NAMED("ArgumentPacker::ReflectionNdarray");
HalBufferView *bv = nullptr;
py::object retained_bv;
- if (py::isinstance(py_value, device_array_type())) {
+ if (is_instance_of_type_object(py_value, device_array_type())) {
// Short-circuit: If a DeviceArray is provided, assume it is
// correct.
IREE_TRACE_SCOPE_NAMED("PackDeviceArray");
bv = py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
- } else if (py::isinstance(py_value, hal_buffer_view_type())) {
+ } else if (is_instance_of_type_object(py_value,
+ hal_buffer_view_type())) {
// Short-circuit: If a HalBufferView is provided directly.
IREE_TRACE_SCOPE_NAMED("PackBufferView");
bv = py::cast<HalBufferView *>(py_value);
@@ -363,7 +367,7 @@
PackCallback GetGenericPackCallbackFor(py::handle arg) {
PopulatePyTypeToPackCallbacks();
- py::type clazz = py::type::of(arg);
+ py::handle clazz = arg.type();
auto found_it = py_type_to_pack_callbacks_.find(clazz.ptr());
if (found_it == py_type_to_pack_callbacks_.end()) {
// Probe to see if we have a host array.
@@ -419,7 +423,7 @@
// has no further refinement of these, just treat them as vm 64 bit int and
// floats and let the VM take care of it. There isn't much else we can do.
AddPackCallback(
- py::type::of(py::cast(1)),
+ py::cast(1).type(),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
iree_vm_value_t vm_value =
iree_vm_value_make_i64(py::cast<int64_t>(py_value));
@@ -428,7 +432,7 @@
});
AddPackCallback(
- py::type::of(py::cast(1.0)),
+ py::cast(1.0).type(),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
iree_vm_value_t vm_value =
iree_vm_value_make_f64(py::cast<double>(py_value));
@@ -441,7 +445,7 @@
py::handle py_value) {
auto py_seq = py::cast<py::sequence>(py_value);
VmVariantList item_list = VmVariantList::Create(py::len(py_seq));
- for (py::object py_item : py_seq) {
+ for (py::handle py_item : py_seq) {
PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
if (!sub_packer) {
std::string message("could not convert python value to VM: ");
@@ -454,8 +458,8 @@
iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr());
iree_vm_list_push_ref_move(list, &retained);
};
- AddPackCallback(py::type::of(py::list{}), sequence_callback);
- AddPackCallback(py::type::of(py::tuple{}), sequence_callback);
+ AddPackCallback((py::list{}).type(), sequence_callback);
+ AddPackCallback((create_empty_tuple()).type(), sequence_callback);
// Dict.
auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list,
@@ -483,11 +487,11 @@
iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr());
iree_vm_list_push_ref_move(list, &retained);
};
- AddPackCallback(py::type::of(py::dict{}), dict_callback);
+ AddPackCallback((py::dict{}).type(), dict_callback);
// HalBufferView.
AddPackCallback(
- py::type::of<HalBufferView>(),
+ py::type<HalBufferView>(),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
HalBufferView *bv = py::cast<HalBufferView *>(py_value);
iree_vm_ref_t buffer_view_ref =
@@ -529,11 +533,12 @@
// Cached modules and types. Those that involve recursive lookup within
// our top level module, we defer. Those outside, we cache at creation.
- py::module numpy_module_ = py::module::import("numpy");
+ py::module_ numpy_module_ = py::module_::import_("numpy");
std::optional<py::object> runtime_module_;
- std::optional<py::module> array_interop_module_;
+ std::optional<py::module_> array_interop_module_;
std::optional<py::object> device_array_type_;
- py::type hal_buffer_view_type_ = py::type::of<HalBufferView>();
+ py::type_object hal_buffer_view_type_ =
+ py::cast<py::type_object>(py::type<HalBufferView>());
// Maps Python type to a PackCallback that can generically code it.
// This will have inc_ref() called on them when added.
@@ -594,12 +599,12 @@
// Dynamic dispatch.
if (dynamic_dispatch_) {
IREE_TRACE_SCOPE_NAMED("ArgumentPacker::PackDynamic");
- if (!kw_args.empty()) {
+ if (kw_args.size() != 0) {
throw std::invalid_argument(
"kwargs not supported for dynamic dispatch functions");
}
- VmVariantList arg_list = VmVariantList::Create(pos_args.size());
+ VmVariantList arg_list = VmVariantList::Create(py::len(pos_args));
for (py::handle py_arg : pos_args) {
PackCallback packer = statics_.GetGenericPackCallbackFor(py_arg);
if (!packer) {
@@ -618,11 +623,12 @@
// Reflection based dispatch.
std::vector<py::handle> py_args(flat_arg_packers_.size());
- if (pos_args.size() > pos_only_arg_count_) {
+ auto pos_args_size = py::len(pos_args);
+ if (pos_args_size > pos_only_arg_count_) {
std::string message("mismatched call arity: expected ");
message.append(std::to_string(pos_only_arg_count_));
message.append(" got ");
- message.append(std::to_string(pos_args.size()));
+ message.append(std::to_string(pos_args_size));
throw std::invalid_argument(std::move(message));
}
@@ -639,14 +645,14 @@
found_index = py::cast<int>(kwarg_to_index_[it.first]);
} catch (std::exception &) {
std::string message("specified kwarg '");
- message.append(py::cast<py::str>(it.first));
+ message.append(py::cast<std::string>(it.first));
message.append("' is unknown");
throw std::invalid_argument(std::move(message));
}
if (py_args[found_index]) {
std::string message(
"mismatched call arity: duplicate keyword argument '");
- message.append(py::cast<py::str>(it.first));
+ message.append(py::cast<std::string>(it.first));
message.append("'");
throw std::invalid_argument(std::move(message));
}
@@ -693,11 +699,12 @@
} // namespace
-void SetupInvokeBindings(pybind11::module &m) {
+void SetupInvokeBindings(nanobind::module_ &m) {
py::class_<InvokeStatics>(m, "_InvokeStatics");
py::class_<InvokeContext>(m, "InvokeContext").def(py::init<HalDevice &>());
py::class_<ArgumentPacker>(m, "ArgumentPacker")
- .def(py::init<InvokeStatics &, std::optional<py::list>>())
+ .def(py::init<InvokeStatics &, std::optional<py::list>>(),
+ py::arg("statics"), py::arg("arg_descs") = py::none())
.def("pack", &ArgumentPacker::Pack);
m.attr("_invoke_statics") = py::cast(InvokeStatics());
diff --git a/runtime/bindings/python/invoke.h b/runtime/bindings/python/invoke.h
index 206524f..4bc2de0 100644
--- a/runtime/bindings/python/invoke.h
+++ b/runtime/bindings/python/invoke.h
@@ -12,7 +12,7 @@
namespace iree {
namespace python {
-void SetupInvokeBindings(pybind11::module &m);
+void SetupInvokeBindings(py::module_ &m);
} // namespace python
} // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index bffc5fa..cb0bdd6 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -27,11 +27,11 @@
MemoryType,
PyModuleInterface,
Shape,
+ create_hal_module,
)
# Vm imports
from ._binding import (
- create_hal_module,
Linkage,
VmBuffer,
VmVariantList,
diff --git a/runtime/bindings/python/numpy_interop.cc b/runtime/bindings/python/numpy_interop.cc
new file mode 100644
index 0000000..85e5182
--- /dev/null
+++ b/runtime/bindings/python/numpy_interop.cc
@@ -0,0 +1,118 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./numpy_interop.h"
+
+#include "./binding.h"
+
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+#include "numpy/arrayobject.h"
+
+namespace iree::python::numpy {
+
+namespace {
+
+// import_array1(ret) returns `ret` from the enclosing function on failure, so
+// it is wrapped in a helper returning int: 0 on success, -1 on failure.
+int internal_import_array() {
+  import_array1(-1);
+  return 0;
+}
+
+}  // namespace
+
+// Loads the NumPy C API. Throws py::import_error if the import fails.
+void InitializeNumPyInterop() {
+  if (internal_import_array() < 0) {
+    throw py::import_error("numpy.core.multiarray failed to import");
+  }
+}
+
+// Maps an IREE HAL element type to the matching NPY_TYPES value. INT_* and
+// SINT_* both map to the signed NumPy type of the same width. Throws
+// py::value_error for element types that have no mapping here.
+int ConvertHalElementTypeToNumPyTypeNum(iree_hal_element_type_t t) {
+  switch (t) {
+    case IREE_HAL_ELEMENT_TYPE_BOOL_8:
+      return NPY_BOOL;
+    case IREE_HAL_ELEMENT_TYPE_INT_8:
+    case IREE_HAL_ELEMENT_TYPE_SINT_8:
+      return NPY_INT8;
+    case IREE_HAL_ELEMENT_TYPE_UINT_8:
+      return NPY_UINT8;
+    case IREE_HAL_ELEMENT_TYPE_INT_16:
+    case IREE_HAL_ELEMENT_TYPE_SINT_16:
+      return NPY_INT16;
+    case IREE_HAL_ELEMENT_TYPE_UINT_16:
+      return NPY_UINT16;
+    case IREE_HAL_ELEMENT_TYPE_INT_32:
+    case IREE_HAL_ELEMENT_TYPE_SINT_32:
+      return NPY_INT32;
+    case IREE_HAL_ELEMENT_TYPE_UINT_32:
+      return NPY_UINT32;
+    case IREE_HAL_ELEMENT_TYPE_INT_64:
+    case IREE_HAL_ELEMENT_TYPE_SINT_64:
+      return NPY_INT64;
+    case IREE_HAL_ELEMENT_TYPE_UINT_64:
+      return NPY_UINT64;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+      return NPY_FLOAT16;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+      return NPY_FLOAT32;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+      return NPY_FLOAT64;
+    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
+      return NPY_COMPLEX64;
+    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
+      return NPY_COMPLEX128;
+    default:
+      throw py::value_error("Unsupported VM Buffer -> numpy dtype mapping");
+  }
+}
+
+// Creates a new dtype descriptor for `typenum` via PyArray_DescrNewFromType.
+// Throws py::python_error if NumPy returns NULL.
+py::object DescrNewFromType(int typenum) {
+  PyArray_Descr *dtype = PyArray_DescrNewFromType(typenum);
+  if (!dtype) {
+    throw py::python_error();
+  }
+  return py::steal(reinterpret_cast<PyObject *>(dtype));
+}
+
+// Returns the type_num of `dtype`. Throws py::cast_error if `dtype` is not a
+// NumPy descriptor object.
+int TypenumFromDescr(py::handle dtype) {
+  if (!PyArray_DescrCheck(dtype.ptr())) {
+    throw py::cast_error();
+  }
+  PyArray_Descr *descr = reinterpret_cast<PyArray_Descr *>(dtype.ptr());
+  return descr->type_num;
+}
+
+// Creates an array over `data` via PyArray_SimpleNewFromData. If `base_object`
+// is non-null it is installed as the array's base object. Throws
+// py::python_error if either NumPy call fails.
+py::object SimpleNewFromData(int nd, intptr_t const *dims, int typenum,
+                             void *data, py::handle base_object) {
+  PyObject *array_c = PyArray_SimpleNewFromData(nd, dims, typenum, data);
+  if (!array_c) throw py::python_error();
+  py::object array = py::steal(array_c);
+  if (base_object) {
+    // PyArray_SetBaseObject steals a reference to the base and drops it itself
+    // on its error paths. The caller only lends us a borrowed handle, so the
+    // reference must be taken *before* the call; incrementing afterwards
+    // leaves the failure path (and any decref inside the call) one short.
+    base_object.inc_ref();
+    if (PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(array.ptr()),
+                              base_object.ptr())) {
+      throw py::python_error();
+    }
+  }
+  return array;
+}
+
+}  // namespace iree::python::numpy
diff --git a/runtime/bindings/python/numpy_interop.h b/runtime/bindings/python/numpy_interop.h
new file mode 100644
index 0000000..3084e14
--- /dev/null
+++ b/runtime/bindings/python/numpy_interop.h
@@ -0,0 +1,33 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_NUMPY_INTEROP_H_
+#define IREE_BINDINGS_PYTHON_NUMPY_INTEROP_H_
+
+#include "./binding.h"
+#include "iree/hal/api.h"
+
+namespace iree::python::numpy {
+
+// Imports the NumPy C API. Must be called in init of extension module.
+void InitializeNumPyInterop();
+
+// Maps an IREE element type to NPY_TYPES; throws py::value_error if unmapped.
+int ConvertHalElementTypeToNumPyTypeNum(iree_hal_element_type_t t);
+
+// Wraps PyArray_DescrNewFromType(int); throws py::python_error on failure.
+py::object DescrNewFromType(int typenum);
+
+// Returns the typenum of a dtype object; throws py::cast_error for non-dtypes.
+int TypenumFromDescr(py::handle dtype);
+
+// Wraps PyArray_SimpleNewFromData; sets base_object as base when non-null.
+py::object SimpleNewFromData(int nd, intptr_t const *dims, int typenum,
+ void *data, py::handle base_object);
+
+} // namespace iree::python::numpy
+
+#endif // IREE_BINDINGS_PYTHON_NUMPY_INTEROP_H_
diff --git a/runtime/bindings/python/py_module.cc b/runtime/bindings/python/py_module.cc
index f3b8b34..86eb1e1 100644
--- a/runtime/bindings/python/py_module.cc
+++ b/runtime/bindings/python/py_module.cc
@@ -7,6 +7,7 @@
#include "./py_module.h"
#include <string_view>
+#include <unordered_map>
#include "./vm.h"
@@ -370,7 +371,7 @@
// count.
VmRef py_ref;
iree_vm_ref_retain(&ref, &py_ref.ref());
- arguments.append(py::cast(py_ref, py::return_value_policy::move));
+ arguments.append(py::cast(py_ref, py::rv_policy::move));
packed_arguments += sizeof(iree_vm_ref_t);
break;
}
@@ -471,13 +472,13 @@
py::object retained_self_ref_;
};
-void SetupPyModuleBindings(py::module& m) {
+void SetupPyModuleBindings(py::module_& m) {
py::class_<PyModuleInterface>(m, "PyModuleInterface")
.def(py::init<std::string, py::object>(), py::arg("module_name"),
py::arg("ctor"))
.def("__str__", &PyModuleInterface::ToString)
- .def_property_readonly("initialized", &PyModuleInterface::initialized)
- .def_property_readonly("destroyed", &PyModuleInterface::destroyed)
+ .def_prop_ro("initialized", &PyModuleInterface::initialized)
+ .def_prop_ro("destroyed", &PyModuleInterface::destroyed)
.def("create", &PyModuleInterface::Create)
.def("export", &PyModuleInterface::ExportFunction, py::arg("name"),
py::arg("cconv"), py::arg("callable"));
diff --git a/runtime/bindings/python/py_module.h b/runtime/bindings/python/py_module.h
index ee04d7e..30b836c 100644
--- a/runtime/bindings/python/py_module.h
+++ b/runtime/bindings/python/py_module.h
@@ -13,7 +13,7 @@
namespace iree::python {
-void SetupPyModuleBindings(py::module &m);
+void SetupPyModuleBindings(py::module_ &m);
} // namespace iree::python
diff --git a/runtime/bindings/python/status_utils.cc b/runtime/bindings/python/status_utils.cc
index a97d0c6..36941c2 100644
--- a/runtime/bindings/python/status_utils.cc
+++ b/runtime/bindings/python/status_utils.cc
@@ -43,8 +43,8 @@
} // namespace
-pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
- const char* message) {
+nanobind::python_error ApiStatusToPyExc(iree_status_t status,
+ const char* message) {
assert(!iree_status_is_ok(status));
std::string full_message;
@@ -58,13 +58,12 @@
PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str());
iree_status_ignore(status);
- return pybind11::error_already_set();
+ return nanobind::python_error();
}
-pybind11::error_already_set RaisePyError(PyObject* exc_class,
- const char* message) {
+nanobind::python_error RaisePyError(PyObject* exc_class, const char* message) {
PyErr_SetString(exc_class, message);
- return pybind11::error_already_set();
+ return nanobind::python_error();
}
} // namespace python
diff --git a/runtime/bindings/python/status_utils.h b/runtime/bindings/python/status_utils.h
index d87d308..612f335 100644
--- a/runtime/bindings/python/status_utils.h
+++ b/runtime/bindings/python/status_utils.h
@@ -7,8 +7,9 @@
#ifndef IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
#define IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
+#include <nanobind/nanobind.h>
+
#include "iree/base/api.h"
-#include "pybind11/pybind11.h"
namespace iree {
namespace python {
@@ -16,18 +17,17 @@
// Raises a value error with the given message.
// Correct usage:
// throw RaiseValueError(PyExc_ValueError, "Foobar'd");
-pybind11::error_already_set RaisePyError(PyObject* exc_class,
- const char* message);
+nanobind::python_error RaisePyError(PyObject* exc_class, const char* message);
// Raises a value error with the given message.
// Correct usage:
// throw RaiseValueError("Foobar'd");
-inline pybind11::error_already_set RaiseValueError(const char* message) {
+inline nanobind::python_error RaiseValueError(const char* message) {
return RaisePyError(PyExc_ValueError, message);
}
-pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
- const char* message);
+nanobind::python_error ApiStatusToPyExc(iree_status_t status,
+ const char* message);
inline void CheckApiStatus(iree_status_t status, const char* message) {
if (iree_status_is_ok(status)) {
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 7a8e0dd..18c8e8d 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -6,6 +6,11 @@
#include "./vm.h"
+#include <ios>
+#include <sstream>
+#include <unordered_set>
+
+#include "./buffer_interop.h"
#include "./status_utils.h"
#include "iree/base/api.h"
@@ -15,9 +20,8 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
-#include "pybind11/numpy.h"
-using namespace pybind11::literals;
+using namespace nanobind::literals;
namespace iree {
namespace python {
@@ -152,7 +156,7 @@
//------------------------------------------------------------------------------
VmContext VmContext::Create(VmInstance* instance,
- std::optional<std::vector<VmModule*>> modules) {
+ std::optional<std::vector<VmModule*>>& modules) {
IREE_TRACE_SCOPE_NAMED("VmContext::Create");
iree_vm_context_t* context;
if (!modules) {
@@ -228,8 +232,8 @@
VmModule VmModule::MMap(VmInstance* instance, std::string filepath,
py::object destroy_callback) {
IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
- auto mmap_module = py::module::import("mmap");
- auto open_func = py::module::import("io").attr("open");
+ auto mmap_module = py::module_::import_("mmap");
+ auto open_func = py::module_::import_("io").attr("open");
auto file_obj = open_func(filepath, "r+b");
// The signature of mmap is different on Windows vs others. On others,
// we use explicit flags and protection attributes for better control,
@@ -267,13 +271,12 @@
VmModule VmModule::WrapBuffer(VmInstance* instance, py::object buffer_obj,
py::object destroy_callback, bool close_buffer) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromAlignedMemory");
- auto py_buffer = py::cast<py::buffer>(buffer_obj);
- auto buffer_info = py_buffer.request();
- if (!iree_host_size_has_alignment((uintptr_t)buffer_info.ptr,
+ PyBufferRequest buffer_info(buffer_obj, PyBUF_SIMPLE);
+ if (!iree_host_size_has_alignment((uintptr_t)buffer_info.view().buf,
IREE_HAL_HEAP_BUFFER_ALIGNMENT)) {
std::stringstream err;
err << "VmModule.from_aligned_memory received an unaligned buffer. ";
- err << "Got 0x" << (void*)buffer_info.ptr << ", expected alignment ";
+ err << "Got 0x" << (void*)buffer_info.view().buf << ", expected alignment ";
err << IREE_HAL_HEAP_BUFFER_ALIGNMENT;
throw std::invalid_argument(err.str());
}
@@ -329,8 +332,8 @@
auto status = iree_vm_bytecode_module_create(
instance->raw_ptr(),
- {static_cast<const uint8_t*>(buffer_info.ptr),
- static_cast<iree_host_size_t>(buffer_info.size)},
+ {static_cast<const uint8_t*>(buffer_info.view().buf),
+ static_cast<iree_host_size_t>(buffer_info.view().len)},
deallocator, iree_allocator_system(), &module);
if (!iree_status_is_ok(status)) {
delete state;
@@ -347,31 +350,33 @@
VmModule VmModule::CopyBuffer(VmInstance* instance, py::object buffer_obj) {
IREE_TRACE_SCOPE_NAMED("VmModule::CopyBuffer");
auto alignment =
- py::cast<uintptr_t>(py::module::import("mmap").attr("PAGESIZE"));
- auto bytearray_ctor = py::module::import("builtins").attr("bytearray");
- auto src_buffer = py::cast<py::buffer>(buffer_obj);
- auto src_buffer_info = src_buffer.request();
- py::ssize_t src_buffer_size = src_buffer_info.itemsize * src_buffer_info.size;
+ py::cast<uintptr_t>(py::module_::import_("mmap").attr("PAGESIZE"));
+ auto bytearray_ctor = py::module_::import_("builtins").attr("bytearray");
+ PyBufferRequest src_buffer_info(buffer_obj, PyBUF_SIMPLE);
+ auto src_buffer_size = src_buffer_info.view().len;
// Need to allocate an extra page because there is no control at the Python
// level for the alignment it may have.
- auto dst_buffer =
- py::cast<py::buffer>(bytearray_ctor(src_buffer_size + alignment));
- auto dst_buffer_info = dst_buffer.request();
+ auto dst_buffer = bytearray_ctor(src_buffer_size + alignment);
+ PyBufferRequest dst_buffer_info(dst_buffer, PyBUF_SIMPLE);
void* dst_aligned =
- (void*)iree_host_align((uintptr_t)dst_buffer_info.ptr, alignment);
+ (void*)iree_host_align((uintptr_t)dst_buffer_info.view().buf, alignment);
uintptr_t dst_offset =
- (uintptr_t)dst_aligned - (uintptr_t)dst_buffer_info.ptr;
+ (uintptr_t)dst_aligned - (uintptr_t)dst_buffer_info.view().buf;
// Now create a memoryview over the unaligned bytearray and slice into that
// to get the aligned Python buffer.
- auto dst_slice = py::slice(dst_offset, dst_offset + src_buffer_size, 1);
- py::object dst_view = py::memoryview(dst_buffer);
+ auto dst_slice =
+ py::slice(py::cast(dst_offset), py::cast(dst_offset + src_buffer_size),
+ py::cast(1));
+
+ py::object dst_view = py::steal<py::object>(
+ PyMemoryView_GetContiguous(dst_buffer.ptr(), PyBUF_READ, 'C'));
py::object dst_view_aligned = dst_view[dst_slice];
// If any of the indexing math was wrong, Python exceptions will be raised
// above, so this is implicitly guarding the memcpy if it is done last.
- std::memcpy(dst_aligned, src_buffer_info.ptr, src_buffer_size);
+ std::memcpy(dst_aligned, src_buffer_info.view().buf, src_buffer_size);
return WrapBuffer(instance, std::move(dst_view_aligned),
/*destroy_callback=*/py::none(),
/*close_buffer=*/false);
@@ -380,15 +385,15 @@
VmModule VmModule::FromBuffer(VmInstance* instance, py::object buffer_obj,
bool warn_if_copy) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromBuffer");
- auto py_buffer = py::cast<py::buffer>(buffer_obj);
- auto buffer_info = py_buffer.request();
- if (iree_host_size_has_alignment((uintptr_t)buffer_info.ptr,
+ PyBufferRequest buffer_info(buffer_obj, PyBUF_SIMPLE);
+
+ if (iree_host_size_has_alignment((uintptr_t)buffer_info.view().buf,
IREE_HAL_HEAP_BUFFER_ALIGNMENT)) {
return WrapBuffer(instance, std::move(buffer_obj),
/*destroy_callback=*/py::none(), /*close_buffer=*/false);
} else {
if (warn_if_copy) {
- py::module::import("warnings")
+ py::module_::import_("warnings")
.attr("warn")(
"Making copy of unaligned VmModule buffer. It is recommended to "
"make this deterministic by calling `copy_buffer` to always make "
@@ -423,7 +428,7 @@
const char* const VmRef::kTypeAttr = "__iree_vm_type__";
py::object VmRef::Deref(py::object ref_object_class, bool optional) {
- py::object casted = ref_object_class.attr(kCastAttr)(*this);
+ py::object casted = ref_object_class.attr(kCastAttr)(this);
if (!optional && casted.is_none()) {
throw py::type_error("Cannot dereference to specific type");
}
@@ -517,7 +522,7 @@
} else if (iree_vm_variant_is_ref(v)) {
VmRef ref;
iree_vm_ref_retain(&v.ref, &ref.ref());
- return py::cast(ref, py::return_value_policy::move);
+ return py::cast(ref, py::rv_policy::move);
}
throw RaiseValueError("Unsupported VM to Python Type Conversion");
@@ -642,7 +647,7 @@
}
VmRef ref;
iree_vm_ref_retain(&v.ref, &ref.ref());
- return py::cast(ref, py::return_value_policy::move);
+ return py::cast(ref, py::rv_policy::move);
}
py::object VmVariantList::GetAsObject(int index, py::object clazz) {
@@ -763,7 +768,7 @@
return s;
}
-void SetupVmBindings(pybind11::module m) {
+void SetupVmBindings(nanobind::module_ m) {
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
@@ -771,37 +776,67 @@
.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
.export_values();
- auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
+ auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer");
VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type,
iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
iree_vm_buffer_isa);
+ // Implement the buffer protocol with low-level API.
+ {
+ static PyBufferProcs buffer_procs = {
+ // It is not legal to raise exceptions from these callbacks.
+ +[](PyObject* raw_self, Py_buffer* view, int flags) -> int {
+ // Cast must succeed due to invariants.
+ auto self = py::cast<VmBuffer*>(py::handle(raw_self));
+ if (view == NULL) {
+ PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
+ return -1;
+ }
+
+ Py_INCREF(raw_self);
+ view->obj = raw_self;
+ view->buf = self->raw_ptr()->data.data;
+ view->len = self->raw_ptr()->data.data_length;
+ view->readonly =
+ !(self->raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE);
+ view->itemsize = 1;
+ view->format = (char*)"B"; // Byte
+ view->ndim = 1;
+ view->shape = nullptr;
+ view->strides = nullptr;
+ view->suboffsets = nullptr;
+ view->internal = nullptr;
+ return 0;
+ },
+ +[](PyObject* self_obj, Py_buffer* view) -> void {
+
+ },
+ };
+ auto heap_type = reinterpret_cast<PyHeapTypeObject*>(vm_buffer.ptr());
+ assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
+ "must be heap type");
+ heap_type->as_buffer = buffer_procs;
+ }
+
vm_buffer
- .def(py::init([](iree_host_size_t length, iree_host_size_t alignment,
- bool is_mutable) {
- iree_vm_buffer_access_t access = 0;
- if (is_mutable) {
- access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
- }
- iree_vm_buffer_t* raw_buffer;
- CheckApiStatus(
- iree_vm_buffer_create(access, length, alignment,
- iree_allocator_system(), &raw_buffer),
- "Error creating buffer");
- return VmBuffer::StealFromRawPtr(raw_buffer);
- }),
- py::arg("length"), py::arg("alignment") = 0,
- py::arg("mutable") = true)
- .def_buffer([](VmBuffer& self) -> py::buffer_info {
- return py::buffer_info(
- /*ptr=*/self.raw_ptr()->data.data,
- /*itemsize=*/sizeof(uint8_t),
- /*format=*/py::format_descriptor<uint8_t>::format(),
- /*ndim=*/1,
- /*shape=*/{self.raw_ptr()->data.data_length},
- /*strides=*/{1},
- /*readonly=*/
- !(self.raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE));
- })
+ .def(
+ "__init__",
+ [](VmBuffer* self, iree_host_size_t length,
+ iree_host_size_t alignment, bool is_mutable) {
+ iree_vm_buffer_access_t access = 0;
+ if (is_mutable) {
+ access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
+ }
+ iree_vm_buffer_t* raw_buffer;
+ CheckApiStatus(
+ iree_vm_buffer_create(access, length, alignment,
+ iree_allocator_system(), &raw_buffer),
+ "Error creating buffer");
+
+ new (self) VmBuffer();
+ *self = VmBuffer::StealFromRawPtr(raw_buffer);
+ },
+ py::arg("length"), py::arg("alignment") = 0,
+ py::arg("mutable") = true)
.def("__repr__", [](VmBuffer& self) {
std::stringstream ss;
ss << "<VmBuffer size " << self.raw_ptr()->data.data_length << " at 0x"
@@ -816,8 +851,14 @@
iree_vm_list_deref, iree_vm_list_isa);
vm_list
// User Methods.
- .def(py::init(&VmVariantList::Create))
- .def_property_readonly("size", &VmVariantList::size)
+ .def(
+ "__init__",
+ [](VmVariantList* self, iree_host_size_t capacity) {
+ new (self) VmVariantList();
+ *self = VmVariantList::Create(capacity);
+ },
+ py::arg("capacity"))
+ .def_prop_ro("size", &VmVariantList::size)
.def("__len__", &VmVariantList::size)
.def("get_as_ref", &VmVariantList::GetAsRef)
.def("get_as_object", &VmVariantList::GetAsObject)
@@ -832,24 +873,22 @@
.def("__repr__", &VmVariantList::DebugString);
py::class_<iree_vm_function_t>(m, "VmFunction")
- .def_readonly("linkage", &iree_vm_function_t::linkage)
- .def_readonly("ordinal", &iree_vm_function_t::ordinal)
- .def_property_readonly("name",
- [](iree_vm_function_t& self) {
- iree_string_view_t name =
- iree_vm_function_name(&self);
- return py::str(name.data, name.size);
- })
- .def_property_readonly("module_name",
- [](iree_vm_function_t& self) {
- iree_string_view_t name =
- iree_vm_module_name(self.module);
- return py::str(name.data, name.size);
- })
- .def_property_readonly("reflection",
- [](iree_vm_function_t& self) {
- return GetFunctionReflectionDict(self);
- })
+ .def_ro("linkage", &iree_vm_function_t::linkage)
+ .def_ro("ordinal", &iree_vm_function_t::ordinal)
+ .def_prop_ro("name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name = iree_vm_function_name(&self);
+ return py::str(name.data, name.size);
+ })
+ .def_prop_ro("module_name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name = iree_vm_module_name(self.module);
+ return py::str(name.data, name.size);
+ })
+ .def_prop_ro("reflection",
+ [](iree_vm_function_t& self) {
+ return GetFunctionReflectionDict(self);
+ })
.def("__repr__", [](iree_vm_function_t& self) {
iree_string_view_t name = iree_vm_function_name(&self);
std::string repr("<VmFunction ");
@@ -865,13 +904,22 @@
return repr;
});
- py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
-
+ py::class_<VmInstance>(m, "VmInstance").def("__init__", [](VmInstance* self) {
+ new (self) VmInstance();
+ *self = VmInstance::Create();
+ });
py::class_<VmContext>(m, "VmContext")
- .def(py::init(&VmContext::Create), py::arg("instance"),
- py::arg("modules") = std::optional<std::vector<VmModule*>>())
+ .def(
+ "__init__",
+ [](VmContext* self, VmInstance* instance,
+ std::optional<std::vector<VmModule*>> modules) {
+ new (self) VmContext();
+ *self = VmContext::Create(instance, modules);
+ },
+ py::arg("instance"),
+ py::arg("modules") = std::optional<std::vector<VmModule*>>())
.def("register_modules", &VmContext::RegisterModules)
- .def_property_readonly("context_id", &VmContext::context_id)
+ .def_prop_ro("context_id", &VmContext::context_id)
.def("invoke", &VmContext::Invoke);
py::class_<VmModule>(m, "VmModule")
@@ -891,40 +939,40 @@
.def_static("mmap", &VmModule::MMap, py::arg("instance"),
py::arg("filepath"), py::arg("destroy_callback") = py::none(),
kMMapDocstring)
- .def_property_readonly("name", &VmModule::name)
- .def_property_readonly("version",
- [](VmModule& self) {
- iree_vm_module_signature_t sig =
- iree_vm_module_signature(self.raw_ptr());
- return sig.version;
- })
+ .def_prop_ro("name", &VmModule::name)
+ .def_prop_ro("version",
+ [](VmModule& self) {
+ iree_vm_module_signature_t sig =
+ iree_vm_module_signature(self.raw_ptr());
+ return sig.version;
+ })
.def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
- .def_property_readonly(
+ .def_prop_ro(
"stashed_flatbuffer_blob",
[](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
- .def_property_readonly(
- "function_names",
- [](VmModule& self) {
- py::list names;
- iree_vm_module_signature_t sig =
- iree_vm_module_signature(self.raw_ptr());
- for (size_t ordinal = 0; ordinal < sig.export_function_count;
- ++ordinal) {
- iree_vm_function_t f;
- auto status = iree_vm_module_lookup_function_by_ordinal(
- self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
- if (iree_status_is_not_found(status)) {
- iree_status_ignore(status);
- break;
- }
- CheckApiStatus(status, "Error enumerating module");
- iree_string_view_t fname = iree_vm_function_name(&f);
- py::str name(fname.data, fname.size);
- names.append(name);
- }
- return names;
- })
+ .def_prop_ro("function_names",
+ [](VmModule& self) {
+ py::list names;
+ iree_vm_module_signature_t sig =
+ iree_vm_module_signature(self.raw_ptr());
+ for (size_t ordinal = 0;
+ ordinal < sig.export_function_count; ++ordinal) {
+ iree_vm_function_t f;
+ auto status = iree_vm_module_lookup_function_by_ordinal(
+ self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ ordinal, &f);
+ if (iree_status_is_not_found(status)) {
+ iree_status_ignore(status);
+ break;
+ }
+ CheckApiStatus(status, "Error enumerating module");
+ iree_string_view_t fname = iree_vm_function_name(&f);
+ py::str name(fname.data, fname.size);
+ names.append(name);
+ }
+ return names;
+ })
.def("__repr__", [](VmModule& self) {
std::string repr("<VmModule ");
iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
@@ -959,8 +1007,7 @@
.def("deref", &VmRef::Deref, py::arg("value"),
py::arg("optional") = false)
.def("__repr__", &VmRef::ToString)
- .def_property_readonly(VmRef::kRefAttr,
- [](py::object self) { return self; })
+ .def_prop_ro(VmRef::kRefAttr, [](py::object self) { return self; })
.def("__eq__",
[](VmRef& self, VmRef& other) {
return self.ref().ptr == other.ref().ptr;
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index f52ef2a..ef5d8ea 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -171,7 +171,7 @@
// static, disallowing further module registration (and may be more
// efficient).
static VmContext Create(VmInstance* instance,
- std::optional<std::vector<VmModule*>> modules);
+ std::optional<std::vector<VmModule*>>& modules);
// Registers additional modules. Only valid for non static contexts (i.e.
// those created without modules.
@@ -224,20 +224,20 @@
static void BindRefProtocol(PyClass& cls, TypeFunctor type,
RetainRefFunctor retain_ref, DerefFunctor deref,
IsaFunctor isa) {
- using WrapperType = typename PyClass::type;
+ using WrapperType = typename PyClass::Type;
using RawPtrType = typename WrapperType::RawPtrType;
auto ref_lambda = [=](WrapperType& self) {
return VmRef::Steal(retain_ref(self.raw_ptr()));
};
cls.def_static(VmRef::kTypeAttr, [=]() { return type(); });
- cls.def_property_readonly(VmRef::kRefAttr, ref_lambda);
- cls.def_property_readonly("ref", ref_lambda);
+ cls.def_prop_ro(VmRef::kRefAttr, ref_lambda);
+ cls.def_prop_ro("ref", ref_lambda);
cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) -> py::object {
if (!isa(ref.ref())) {
return py::none();
}
return py::cast(WrapperType::BorrowFromRawPtr(deref(ref.ref())),
- py::return_value_policy::move);
+ py::rv_policy::move);
});
cls.def("__eq__", [](WrapperType& self, WrapperType& other) {
return self.raw_ptr() == other.raw_ptr();
@@ -275,7 +275,7 @@
iree_vm_ref_t ref_;
};
-void SetupVmBindings(pybind11::module m);
+void SetupVmBindings(nanobind::module_ m);
} // namespace python
} // namespace iree