Make HalBufferView a VM Ref object in Python. (#9965)
Previously, the VM bindings depended on the HAL bindings, which made this
impossible. This change severs that dependency and makes VmVariantList operate
on generic reference types. Call sites are refactored to cast explicitly to
their expected types.
Eliminates APIs:
* VmVariantList.get_as_buffer_view() (use get_as_ref() or
get_as_object())
* VmVariantList.push_buffer_view() (use push_ref()).
Changed APIs:
* VmVariantList.get_variant() now returns a VmRef for reference-typed
elements instead of an already-cast object.
These are all fairly internal APIs that users are unlikely to have used
yet.
* Exposes VM buffers to Python as VmBuffer.
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 71a0270..6e8c18a 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,9 +6,11 @@
#include "./hal.h"
+#include "./vm.h"
#include "iree/base/internal/path.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
#include "pybind11/numpy.h"
namespace iree {
@@ -382,12 +384,30 @@
} // namespace
//------------------------------------------------------------------------------
+// HAL module
+//------------------------------------------------------------------------------
+
+// Creates a HAL VM module bound to `device` via iree_hal_module_create and
+// hands the newly created module to a VmModule wrapper (StealFromRawPtr).
+// pybind11 converts a Python None argument into nullptr for pointer
+// parameters, so a missing device is rejected here instead of dereferenced.
+VmModule CreateHalModule(HalDevice* device) {
+  if (!device) {
+    throw py::value_error("create_hal_module requires a HalDevice, got None");
+  }
+  iree_vm_module_t* module = nullptr;
+  CheckApiStatus(
+      iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
+                             iree_allocator_system(), &module),
+      "Error creating hal module");
+  return VmModule::StealFromRawPtr(module);
+}
+
+//------------------------------------------------------------------------------
// Bindings
//------------------------------------------------------------------------------
void SetupHalBindings(pybind11::module m) {
py::dict driver_cache;
+ IREE_CHECK_OK(iree_hal_module_register_all_types());
+
+ // Built-in module creation.
+ m.def("create_hal_module", &CreateHalModule);
+
// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
.value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
@@ -606,8 +626,11 @@
py::arg("element_size"), py::keep_alive<0, 1>())
.def("__repr__", &HalBuffer::Repr);
- py::class_<HalBufferView>(m, "HalBufferView")
- .def("map", HalMappedMemory::Create, py::keep_alive<0, 1>())
+ auto hal_buffer_view = py::class_<HalBufferView>(m, "HalBufferView");
+ VmRef::BindRefProtocol(hal_buffer_view, iree_hal_buffer_view_type_id,
+ iree_hal_buffer_view_retain_ref,
+ iree_hal_buffer_view_deref, iree_hal_buffer_view_isa);
+ hal_buffer_view.def("map", HalMappedMemory::Create, py::keep_alive<0, 1>())
.def_property_readonly(
"shape",
[](HalBufferView& self) {
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 507629b..d4a2c85 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -11,6 +11,7 @@
#include "./binding.h"
#include "./status_utils.h"
+#include "./vm.h"
#include "iree/hal/api.h"
namespace iree {
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index c7d22c3..bffc5fa 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -33,6 +33,7 @@
from ._binding import (
create_hal_module,
Linkage,
+ VmBuffer,
VmVariantList,
VmFunction,
VmInstance,
diff --git a/runtime/bindings/python/iree/runtime/function.py b/runtime/bindings/python/iree/runtime/function.py
index 30b63c2..5ad1624 100644
--- a/runtime/bindings/python/iree/runtime/function.py
+++ b/runtime/bindings/python/iree/runtime/function.py
@@ -21,6 +21,7 @@
MemoryType,
VmContext,
VmFunction,
+ VmRef,
VmVariantList,
)
@@ -202,7 +203,7 @@
# The descriptor for an ndarray is like:
# ["ndarray", "<dtype>", <rank>, <dim>...]
# ex: ['ndarray', 'i32', 1, 25948]
- buffer_view = vm_list.get_as_buffer_view(vm_index)
+ buffer_view = vm_list.get_as_object(vm_index, HalBufferView)
dtype_str = desc[1]
try:
dtype = ABI_TYPE_TO_DTYPE[dtype_str]
@@ -366,10 +367,12 @@
# Special case: Upgrade HalBufferView to a DeviceArray. We do that here
# since this is higher level and it preserves layering. Note that
# the reflection case also does this conversion.
- if isinstance(converted, HalBufferView):
- converted = DeviceArray(inv.device,
- converted,
- implicit_host_transfer=True)
+ if isinstance(converted, VmRef):
+ converted_buffer_view = converted.deref(HalBufferView, True)
+ if converted_buffer_view:
+ converted = DeviceArray(inv.device,
+ converted_buffer_view,
+ implicit_host_transfer=True)
else:
# Known type descriptor.
vm_type = desc if isinstance(desc, str) else desc[0]
diff --git a/runtime/bindings/python/tests/function_test.py b/runtime/bindings/python/tests/function_test.py
index 4fb6427..4c23cc6 100644
--- a/runtime/bindings/python/tests/function_test.py
+++ b/runtime/bindings/python/tests/function_test.py
@@ -515,7 +515,7 @@
allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
buffer=result_array,
element_type=rt.HalElementType.SINT_32)
- ret_list.push_buffer_view(buffer_view)
+ ret_list.push_ref(buffer_view)
vm_context = MockVmContext(invoke)
vm_function = MockVmFunction(reflection={
@@ -537,7 +537,7 @@
allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
buffer=result_array,
element_type=rt.HalElementType.SINT_32)
- ret_list.push_buffer_view(buffer_view)
+ ret_list.push_ref(buffer_view)
vm_context = MockVmContext(invoke)
vm_function = MockVmFunction(reflection={})
@@ -555,7 +555,7 @@
allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
buffer=result_array,
element_type=rt.HalElementType.UINT_8)
- ret_list.push_buffer_view(buffer_view)
+ ret_list.push_ref(buffer_view)
vm_context = MockVmContext(invoke)
vm_function = MockVmFunction(reflection={
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 02cf883..98573b0 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -64,52 +64,6 @@
iree.compiler.core.DEFAULT_TESTING_DRIVER)
self.hal_module = iree.runtime.create_hal_module(self.device)
- def test_variant_list(self):
- l = iree.runtime.VmVariantList(5)
- logging.info("variant_list: %s", l)
- self.assertEqual(l.size, 0)
-
- def test_variant_list_i64(self):
- l = iree.runtime.VmVariantList(5)
- # Push a value that exceeds 32-bit range.
- l.push_int(10 * 1000 * 1000 * 1000)
- self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
-
- def test_variant_list_buffers(self):
- ET = iree.runtime.HalElementType
- for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
- (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
- (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
- (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
- (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
- # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
- lst = iree.runtime.VmVariantList(5)
- ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
- bv1 = self.device.allocator.allocate_buffer_copy(
- memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
- allowed_usage=(iree.runtime.BufferUsage.DEFAULT |
- iree.runtime.BufferUsage.MAPPING),
- buffer=ary1,
- element_type=et)
- lst.push_buffer_view(bv1)
- ary2 = iree.runtime.DeviceArray(self.device,
- lst.get_as_buffer_view(0),
- override_dtype=dt,
- implicit_host_transfer=True)
- np.testing.assert_array_equal(ary1, ary2)
- with self.assertRaises(IndexError):
- lst.get_as_buffer_view(1)
-
- def test_variant_list_list(self):
- lst1 = iree.runtime.VmVariantList(5)
- lst2 = iree.runtime.VmVariantList(5)
- lst1.push_list(lst2)
- self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
- lstout = lst1.get_as_list(0)
- self.assertEqual("<VmVariantList(0): []>", str(lstout))
- with self.assertRaises(IndexError):
- lst1.get_as_list(1)
-
def test_context_id(self):
instance = iree.runtime.VmInstance()
context1 = iree.runtime.VmContext(instance)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index ada9b5d..88fd9ff 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import logging
+import numpy as np
import unittest
import iree.runtime as rt
@@ -28,6 +30,66 @@
self.assertNotEqual(lst1, False)
self.assertTrue(ref.isinstance(rt.VmVariantList))
+ def test_variant_list(self):
+ l = rt.VmVariantList(5)
+ logging.info("variant_list: %s", l)
+ self.assertEqual(l.size, 0)
+
+ def test_variant_list_i64(self):
+ l = rt.VmVariantList(5)
+ # Push a value that exceeds 32-bit range.
+ l.push_int(10 * 1000 * 1000 * 1000)
+ self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
+
+ def test_variant_list_buffers(self):
+ device = rt.get_device("local-sync")
+ ET = rt.HalElementType
+ for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
+ (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
+ (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
+ (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
+ (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
+ # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
+ lst = rt.VmVariantList(5)
+ ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
+ bv1 = device.allocator.allocate_buffer_copy(
+ memory_type=rt.MemoryType.DEVICE_LOCAL,
+ allowed_usage=(rt.BufferUsage.DEFAULT | rt.BufferUsage.MAPPING),
+ buffer=ary1,
+ element_type=et)
+ lst.push_ref(bv1)
+ ary2 = rt.DeviceArray(device,
+ lst.get_as_object(0, rt.HalBufferView),
+ override_dtype=dt,
+ implicit_host_transfer=True)
+ np.testing.assert_array_equal(ary1, ary2)
+ with self.assertRaises(IndexError):
+ lst.get_as_object(1, rt.HalBufferView)
+
+ def test_variant_list_list(self):
+ lst1 = rt.VmVariantList(5)
+ lst2 = rt.VmVariantList(5)
+ lst1.push_list(lst2)
+ self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
+ lstout = lst1.get_as_list(0)
+ self.assertEqual("<VmVariantList(0): []>", str(lstout))
+ with self.assertRaises(IndexError):
+ lst1.get_as_list(1)
+
+ def test_vm_buffer(self):
+ b1 = rt.VmBuffer(10, mutable=True)
+ print(b1)
+ contents = memoryview(b1)
+ contents[0:] = b'0123456789'
+ self.assertEqual(bytes(b1), b'0123456789')
+
+ def test_vm_buffer_ro(self):
+ b1 = rt.VmBuffer(10, mutable=False)
+ contents = memoryview(b1)
+ with self.assertRaises(TypeError):
+ contents[0:] = b'0123456789'
+
if __name__ == "__main__":
+ # Show DEBUG-level log records (such as the logging.info output in these
+ # tests) when this file is run directly.
+ logging.basicConfig(level=logging.DEBUG)
unittest.main()
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 5ffeede..349fa9e 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -10,7 +10,9 @@
#include "iree/base/api.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
+// TODO: We shouldn't need the HAL API but it is used for direct printing
+// summaries of HAL objects in lists. We should have a better way of doing this
+// dynamically vs hard depending on a type switch here.
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "pybind11/numpy.h"
@@ -20,15 +22,6 @@
namespace {
-VmModule CreateHalModule(HalDevice* device) {
- iree_vm_module_t* module;
- CheckApiStatus(
- iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE,
- iree_allocator_system(), &module),
- "Error creating hal module");
- return VmModule::StealFromRawPtr(module);
-}
-
// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
// out of scope.
class PyBufferReleaser {
@@ -182,8 +175,12 @@
const char* const VmRef::kCastAttr = "__iree_vm_cast__";
const char* const VmRef::kTypeIdAttr = "__iree_vm_type_id__";
-py::object VmRef::Deref(py::object ref_object_class) {
- return ref_object_class.attr(kCastAttr)(*this);
+// Casts this ref to an instance of `ref_object_class` through the class's
+// kCastAttr (__iree_vm_cast__) static. Under the BindRefProtocol convention
+// that static yields None when the ref holds a different type. The None is
+// passed through when `optional` is true and becomes a TypeError otherwise.
+py::object VmRef::Deref(py::object ref_object_class, bool optional) {
+  py::object result = ref_object_class.attr(kCastAttr)(*this);
+  if (optional || !result.is_none()) {
+    return result;
+  }
+  throw py::type_error("Cannot dereference to specific type");
}
bool VmRef::IsInstance(py::object ref_object_class) {
@@ -229,11 +226,11 @@
iree_vm_list_push_ref_move(raw_ptr(), &retained);
}
-void VmVariantList::PushBufferView(HalBufferView& buffer_view) {
- iree_vm_ref_t buffer_view_ref =
- iree_hal_buffer_view_retain_ref(buffer_view.raw_ptr());
- CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &buffer_view_ref),
- "Error moving buffer view");
+// Appends a newly retained reference to `ref_or_object`. The argument may be
+// a VmRef or any object whose class was bound with VmRef::BindRefProtocol;
+// both expose a VmRef through the VmRef::kRefAttr property. Anything else
+// raises TypeError instead of leaking a bare AttributeError.
+void VmVariantList::PushRef(py::handle ref_or_object) {
+  py::object py_ref =
+      py::getattr(ref_or_object, VmRef::kRefAttr, py::none());
+  if (py_ref.is_none()) {
+    throw py::type_error(
+        "push_ref expects a VmRef or an object that exposes a VmRef");
+  }
+  VmRef& ref = py::cast<VmRef&>(py_ref);
+  CheckApiStatus(iree_vm_list_push_ref_retain(raw_ptr(), &ref.ref()),
+                 "Failed to push ref");
}
py::object VmVariantList::GetAsList(int index) {
@@ -271,13 +268,10 @@
}
} else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
return py::none();
- } else if (iree_vm_type_def_is_ref(&v.type)) {
- // Convert reference type.
- if (iree_vm_list_isa(v.ref)) {
- return GetAsList(index);
- } else if (iree_hal_buffer_view_isa(v.ref)) {
- return GetAsBufferView(index);
- }
+ } else if (iree_vm_variant_is_ref(v)) {
+ VmRef ref;
+ iree_vm_ref_retain(&v.ref, &ref.ref());
+ return py::cast(ref, py::return_value_policy::move);
}
throw RaiseValueError("Unsupported VM to Python Type Conversion");
@@ -387,16 +381,20 @@
throw RaiseValueError("Unsupported VM to Python Type Conversion");
}
-py::object VmVariantList::GetAsBufferView(int index) {
+// Returns list element `index` as a new VmRef holding its own retained
+// reference to the element. Access failures (e.g. a bad index) are reported
+// through CheckApiStatus; a non-ref element (e.g. a primitive value) raises
+// ValueError via RaiseValueError, like the other accessors in this file.
+py::object VmVariantList::GetAsRef(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
"Could not access list element");
- iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
- if (!buffer_view) {
- throw RaiseValueError("Could not deref result buffer view (wrong type?)");
+  if (!iree_vm_variant_is_ref(v)) {
+    throw RaiseValueError("list element is not a ref");
}
- return py::cast(HalBufferView::BorrowFromRawPtr(buffer_view),
- py::return_value_policy::move);
+  VmRef ref;
+  iree_vm_ref_retain(&v.ref, &ref.ref());
+  return py::cast(ref, py::return_value_policy::move);
+}
+
+// Fetches element `index` as a VmRef and casts it to `clazz` (a class bound
+// with VmRef::BindRefProtocol). A ref of a different type raises TypeError,
+// matching VmRef.deref() without `optional`, instead of silently returning
+// None to callers that expect a concrete object.
+py::object VmVariantList::GetAsObject(int index, py::object clazz) {
+  py::object ref_object = GetAsRef(index);
+  return py::cast<VmRef&>(ref_object).Deref(clazz, /*optional=*/false);
}
namespace {
@@ -514,10 +512,6 @@
void SetupVmBindings(pybind11::module m) {
IREE_CHECK_OK(iree_vm_register_builtin_types());
- IREE_CHECK_OK(iree_hal_module_register_all_types());
-
- // Built-in module creation.
- m.def("create_hal_module", &CreateHalModule);
py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
@@ -526,16 +520,54 @@
.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
.export_values();
+ auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
+ VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type_id,
+ iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
+ iree_vm_buffer_isa);
+ vm_buffer
+ .def(py::init([](iree_host_size_t length, bool is_mutable) {
+ iree_vm_buffer_access_t access = 0;
+ if (is_mutable) {
+ access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
+ }
+ iree_vm_buffer_t* raw_buffer;
+ CheckApiStatus(
+ iree_vm_buffer_create(access, length, iree_allocator_system(),
+ &raw_buffer),
+ "Error creating buffer");
+ return VmBuffer::StealFromRawPtr(raw_buffer);
+ }),
+ py::arg("length"), py::arg("mutable") = true)
+ .def_buffer([](VmBuffer& self) -> py::buffer_info {
+ return py::buffer_info(
+ /*ptr=*/self.raw_ptr()->data.data,
+ /*itemsize=*/sizeof(uint8_t),
+ /*format=*/py::format_descriptor<uint8_t>::format(),
+ /*ndim=*/1,
+ /*shape=*/{self.raw_ptr()->data.data_length},
+ /*strides=*/{1},
+ /*readonly=*/
+ !(self.raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE));
+ })
+ .def("__repr__", [](VmBuffer& self) {
+ std::stringstream ss;
+ ss << "<VmBuffer size " << self.raw_ptr()->data.data_length << " at 0x"
+ << std::hex << reinterpret_cast<uintptr_t>(self.raw_ptr()->data.data)
+ << ">";
+ return ss.str();
+ });
+
// Mutation and inspection of the variant list is mostly opaque to python.
auto vm_list = py::class_<VmVariantList>(m, "VmVariantList");
VmRef::BindRefProtocol(vm_list, iree_vm_list_type_id, iree_vm_list_retain_ref,
- iree_vm_list_check_deref);
+ iree_vm_list_deref, iree_vm_list_isa);
vm_list
// User Methods.
.def(py::init(&VmVariantList::Create))
.def_property_readonly("size", &VmVariantList::size)
.def("__len__", &VmVariantList::size)
- .def("get_as_buffer_view", &VmVariantList::GetAsBufferView)
+ .def("get_as_ref", &VmVariantList::GetAsRef)
+ .def("get_as_object", &VmVariantList::GetAsObject)
.def("get_as_list", &VmVariantList::GetAsList)
.def("get_variant", &VmVariantList::GetVariant)
.def("get_serialized_trace_value",
@@ -543,7 +575,7 @@
.def("push_float", &VmVariantList::PushFloat)
.def("push_int", &VmVariantList::PushInt)
.def("push_list", &VmVariantList::PushList)
- .def("push_buffer_view", &VmVariantList::PushBufferView)
+ .def("push_ref", &VmVariantList::PushRef)
.def("__repr__", &VmVariantList::DebugString);
py::class_<iree_vm_function_t>(m, "VmFunction")
@@ -650,8 +682,11 @@
py::class_<VmRef>(m, "VmRef")
.def("isinstance", &VmRef::IsInstance)
- .def("deref", &VmRef::Deref)
+ .def("deref", &VmRef::Deref, py::arg("value"),
+ py::arg("optional") = false)
.def("__repr__", &VmRef::ToString)
+ .def_property_readonly(VmRef::kRefAttr,
+ [](py::object self) { return self; })
.def("__eq__",
[](VmRef& self, VmRef& other) {
return self.ref().ptr == other.ref().ptr;
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 135e98b..37443cc 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -10,7 +10,7 @@
#include <optional>
#include "./binding.h"
-#include "./hal.h"
+#include "./status_utils.h"
#include "iree/base/api.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -26,6 +26,12 @@
//------------------------------------------------------------------------------
template <>
+struct ApiPtrAdapter<iree_vm_buffer_t> {
+  // Forwards reference-count changes to the IREE VM buffer API; presumably
+  // consumed by the ApiRefCounted-based VmBuffer wrapper.
+  static void Retain(iree_vm_buffer_t* buffer) {
+    iree_vm_buffer_retain(buffer);
+  }
+  static void Release(iree_vm_buffer_t* buffer) {
+    iree_vm_buffer_release(buffer);
+  }
+};
+
+template <>
struct ApiPtrAdapter<iree_vm_instance_t> {
static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); }
static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); }
@@ -68,6 +74,12 @@
};
//------------------------------------------------------------------------------
+// VmBuffer
+//------------------------------------------------------------------------------
+
+// Wrapper around iree_vm_buffer_t built on ApiRefCounted; bound to Python as
+// `VmBuffer` (see SetupVmBindings in vm.cc).
+class VmBuffer : public ApiRefCounted<VmBuffer, iree_vm_buffer_t> {};
+
+//------------------------------------------------------------------------------
// VmVariantList
// TODO: Rename to VmList
//------------------------------------------------------------------------------
@@ -94,9 +106,10 @@
void PushFloat(double fvalue);
void PushInt(int64_t ivalue);
void PushList(VmVariantList& other);
- void PushBufferView(HalBufferView& buffer_view);
+ void PushRef(py::handle ref_or_object);
py::object GetAsList(int index);
- py::object GetAsBufferView(int index);
+ py::object GetAsRef(int index);
+ py::object GetAsObject(int index, py::object clazz);
py::object GetVariant(int index);
py::object GetAsSerializedTraceValue(int index);
};
@@ -174,7 +187,8 @@
// [readonly property] __iree_vm_ref__ :
// Gets a VmRef from the object.
// __iree_vm_cast__(ref) :
- // Dereferences the VmRef to the concrete type.
+ // Dereferences the VmRef to the concrete type. Returns None on cast
+ // failure.
//
// In addition, a user attribute of "ref" will be added that is an alias of
// __iree_vm_ref__.
@@ -191,10 +205,10 @@
static const char* const kCastAttr;
template <typename PyClass, typename TypeIdFunctor, typename RetainRefFunctor,
- typename CheckDerefFunctor>
+ typename DerefFunctor, typename IsaFunctor>
static void BindRefProtocol(PyClass& cls, TypeIdFunctor type_id,
- RetainRefFunctor retain_ref,
- CheckDerefFunctor check_deref) {
+ RetainRefFunctor retain_ref, DerefFunctor deref,
+ IsaFunctor isa) {
using WrapperType = typename PyClass::type;
using RawPtrType = typename WrapperType::RawPtrType;
auto ref_lambda = [=](WrapperType& self) {
@@ -203,10 +217,12 @@
cls.def_static(VmRef::kTypeIdAttr, [=]() { return type_id(); });
cls.def_property_readonly(VmRef::kRefAttr, ref_lambda);
cls.def_property_readonly("ref", ref_lambda);
- cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) {
- RawPtrType casted;
- CheckApiStatus(check_deref(ref.ref(), &casted), "Incompatible type");
- return WrapperType::StealFromRawPtr(casted);
+ cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) -> py::object {
+ if (!isa(ref.ref())) {
+ return py::none();
+ }
+ return py::cast(WrapperType::BorrowFromRawPtr(deref(ref.ref())),
+ py::return_value_policy::move);
});
cls.def("__eq__", [](WrapperType& self, WrapperType& other) {
return self.raw_ptr() == other.raw_ptr();
@@ -233,7 +249,7 @@
iree_vm_ref_t& ref() { return ref_; }
- py::object Deref(py::object ref_object_class);
+ py::object Deref(py::object ref_object_class, bool optional);
bool IsInstance(py::object ref_object_class);
std::string ToString();