Make HalBufferView a VM Ref object in Python. (#9965)
Previously, the VM bindings depended on the HAL, which made this impossible.
Severs that dependency and makes VmVariantList operate on generic
reference types. Refactors call sites to cast explicitly to their
expected types.
Removed APIs:
* VmVariantList.get_as_buffer_view() (use get_as_ref() or
  get_as_object()).
* VmVariantList.push_buffer_view() (use push_ref()).
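Changed APIs:
* VmVariantList.get_variant() now returns a VmRef for every reference
  element (including nested lists) instead of a cast object.
* VmRef.deref() takes an `optional` argument (default False). When it is
  False, a cast that yields None (e.g. a type mismatch) raises TypeError;
  pass optional=True to get None instead.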
These are all fairly internal APIs that users are unlikely to have used
yet.
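* Also exposes VM buffers to Python as VmBuffer, which supports the Python
  buffer protocol (mutable by default; read-only when constructed with
  mutable=False).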
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 5ffeede..349fa9e 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -10,7 +10,9 @@
#include "iree/base/api.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
+// TODO: We shouldn't need the HAL API but it is used for direct printing
+// summaries of HAL objects in lists. We should have a better way of doing this
+// dynamically vs hard depending on a type switch here.
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "pybind11/numpy.h"
@@ -20,15 +22,6 @@
namespace {
-VmModule CreateHalModule(HalDevice* device) {
- iree_vm_module_t* module;
- CheckApiStatus(
- iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &module),
- "Error creating hal module");
- return VmModule::StealFromRawPtr(module);
-}
-
// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
// out of scope.
class PyBufferReleaser {
@@ -182,8 +175,12 @@
const char* const VmRef::kCastAttr = "__iree_vm_cast__";
const char* const VmRef::kTypeIdAttr = "__iree_vm_type_id__";
-py::object VmRef::Deref(py::object ref_object_class) {
- return ref_object_class.attr(kCastAttr)(*this);
+py::object VmRef::Deref(py::object ref_object_class, bool optional) {
+ py::object casted = ref_object_class.attr(kCastAttr)(*this);
+ if (!optional && casted.is_none()) {
+ throw py::type_error("Cannot dereference to specific type");
+ }
+ return casted;
}
bool VmRef::IsInstance(py::object ref_object_class) {
@@ -229,11 +226,11 @@
iree_vm_list_push_ref_move(raw_ptr(), &retained);
}
-void VmVariantList::PushBufferView(HalBufferView& buffer_view) {
- iree_vm_ref_t buffer_view_ref =
- iree_hal_buffer_view_retain_ref(buffer_view.raw_ptr());
- CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &buffer_view_ref),
- "Error moving buffer view");
+void VmVariantList::PushRef(py::handle ref_or_object) {
+ py::object py_ref = ref_or_object.attr(VmRef::kRefAttr);
+ VmRef& ref = py::cast<VmRef&>(py_ref);
+ CheckApiStatus(iree_vm_list_push_ref_retain(raw_ptr(), &ref.ref()),
+ "Failed to push ref");
}
py::object VmVariantList::GetAsList(int index) {
@@ -271,13 +268,10 @@
}
} else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
return py::none();
- } else if (iree_vm_type_def_is_ref(&v.type)) {
- // Convert reference type.
- if (iree_vm_list_isa(v.ref)) {
- return GetAsList(index);
- } else if (iree_hal_buffer_view_isa(v.ref)) {
- return GetAsBufferView(index);
- }
+ } else if (iree_vm_variant_is_ref(v)) {
+ VmRef ref;
+ iree_vm_ref_retain(&v.ref, &ref.ref());
+ return py::cast(ref, py::return_value_policy::move);
}
throw RaiseValueError("Unsupported VM to Python Type Conversion");
@@ -387,16 +381,20 @@
throw RaiseValueError("Unsupported VM to Python Type Conversion");
}
-py::object VmVariantList::GetAsBufferView(int index) {
+py::object VmVariantList::GetAsRef(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
"Could not access list element");
- iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
- if (!buffer_view) {
- throw RaiseValueError("Could not deref result buffer view (wrong type?)");
+ if (!iree_vm_variant_is_ref(v)) {
+ throw std::invalid_argument("list element is not a ref");
}
- return py::cast(HalBufferView::BorrowFromRawPtr(buffer_view),
- py::return_value_policy::move);
+ VmRef ref;
+ iree_vm_ref_retain(&v.ref, &ref.ref());
+ return py::cast(ref, py::return_value_policy::move);
+}
+
+py::object VmVariantList::GetAsObject(int index, py::object clazz) {
+ return clazz.attr(VmRef::kCastAttr)(GetAsRef(index));
}
namespace {
@@ -514,10 +512,6 @@
void SetupVmBindings(pybind11::module m) {
IREE_CHECK_OK(iree_vm_register_builtin_types());
- IREE_CHECK_OK(iree_hal_module_register_all_types());
-
- // Built-in module creation.
- m.def("create_hal_module", &CreateHalModule);
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
@@ -526,16 +520,54 @@
.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
.export_values();
+ auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
+ VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type_id,
+ iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
+ iree_vm_buffer_isa);
+ vm_buffer
+ .def(py::init([](iree_host_size_t length, bool is_mutable) {
+ iree_vm_buffer_access_t access = 0;
+ if (is_mutable) {
+ access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
+ }
+ iree_vm_buffer_t* raw_buffer;
+ CheckApiStatus(
+ iree_vm_buffer_create(access, length, iree_allocator_system(),
+ &raw_buffer),
+ "Error creating buffer");
+ return VmBuffer::StealFromRawPtr(raw_buffer);
+ }),
+ py::arg("length"), py::arg("mutable") = true)
+ .def_buffer([](VmBuffer& self) -> py::buffer_info {
+ return py::buffer_info(
+ /*ptr=*/self.raw_ptr()->data.data,
+ /*itemsize=*/sizeof(uint8_t),
+ /*format=*/py::format_descriptor<uint8_t>::format(),
+ /*ndim=*/1,
+ /*shape=*/{self.raw_ptr()->data.data_length},
+ /*strides=*/{1},
+ /*readonly=*/
+ !(self.raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE));
+ })
+ .def("__repr__", [](VmBuffer& self) {
+ std::stringstream ss;
+ ss << "<VmBuffer size " << self.raw_ptr()->data.data_length << " at 0x"
+ << std::hex << reinterpret_cast<uintptr_t>(self.raw_ptr()->data.data)
+ << ">";
+ return ss.str();
+ });
+
// Mutation and inspection of the variant list is mostly opaque to python.
auto vm_list = py::class_<VmVariantList>(m, "VmVariantList");
VmRef::BindRefProtocol(vm_list, iree_vm_list_type_id, iree_vm_list_retain_ref,
- iree_vm_list_check_deref);
+ iree_vm_list_deref, iree_vm_list_isa);
vm_list
// User Methods.
.def(py::init(&VmVariantList::Create))
.def_property_readonly("size", &VmVariantList::size)
.def("__len__", &VmVariantList::size)
- .def("get_as_buffer_view", &VmVariantList::GetAsBufferView)
+ .def("get_as_ref", &VmVariantList::GetAsRef)
+ .def("get_as_object", &VmVariantList::GetAsObject)
.def("get_as_list", &VmVariantList::GetAsList)
.def("get_variant", &VmVariantList::GetVariant)
.def("get_serialized_trace_value",
@@ -543,7 +575,7 @@
.def("push_float", &VmVariantList::PushFloat)
.def("push_int", &VmVariantList::PushInt)
.def("push_list", &VmVariantList::PushList)
- .def("push_buffer_view", &VmVariantList::PushBufferView)
+ .def("push_ref", &VmVariantList::PushRef)
.def("__repr__", &VmVariantList::DebugString);
py::class_<iree_vm_function_t>(m, "VmFunction")
@@ -650,8 +682,11 @@
py::class_<VmRef>(m, "VmRef")
.def("isinstance", &VmRef::IsInstance)
- .def("deref", &VmRef::Deref)
+ .def("deref", &VmRef::Deref, py::arg("value"),
+ py::arg("optional") = false)
.def("__repr__", &VmRef::ToString)
+ .def_property_readonly(VmRef::kRefAttr,
+ [](py::object self) { return self; })
.def("__eq__",
[](VmRef& self, VmRef& other) {
return self.ref().ptr == other.ref().ptr;