Finally moving VM type registration to iree_vm_instance_t. (#12650)
This allows for thread-safe type registration scoped to instances and
unregistration of types, as required by types in dynamically loaded
modules that may be unloaded at some point. The main trick here was
changing the type ID from an ordinal in the type table to the pointer
of the type descriptor itself. This costs an extra 4 bytes per ref on
64-bit systems, but in exchange common operations no longer need to
round-trip through the type table.
As part of simplifying the way types are referenced, VM type descriptors
are now hidden behind iree_vm_ref_type_t. This makes refs much easier to
work with, as there's only one way to reference types and it always
bottoms out on the registered descriptor handle. It also allows us to
remove some type descriptor indirection we'd previously required in
order to get reference counter offsets, as we can share the same packed
type identifier in type defs, refs, and lists.
Thanks to @simon-camp for the required EmitC changes!
---------
Co-authored-by: Simon Camphausen <simon.camphausen@iml.fraunhofer.de>
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index bf43d93..6f51450 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -721,7 +721,7 @@
.def("__repr__", &HalBuffer::Repr);
auto hal_buffer_view = py::class_<HalBufferView>(m, "HalBufferView");
- VmRef::BindRefProtocol(hal_buffer_view, iree_hal_buffer_view_type_id,
+ VmRef::BindRefProtocol(hal_buffer_view, iree_hal_buffer_view_type,
iree_hal_buffer_view_retain_ref,
iree_hal_buffer_view_deref, iree_hal_buffer_view_isa);
hal_buffer_view.def("map", HalMappedMemory::Create, py::keep_alive<0, 1>())
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 70e5f53..b9c0a1b 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -59,7 +59,8 @@
IREE_TRACE_SCOPE0("VmInstance::Create");
iree_vm_instance_t* instance = NULL;
- auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
+ auto status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
+ iree_allocator_system(), &instance);
CheckApiStatus(status, "Error creating instance");
// The python bindings assume the HAL is always available for use.
@@ -182,7 +183,7 @@
const char* const VmRef::kRefAttr = "__iree_vm_ref__";
const char* const VmRef::kCastAttr = "__iree_vm_cast__";
-const char* const VmRef::kTypeIdAttr = "__iree_vm_type_id__";
+const char* const VmRef::kTypeAttr = "__iree_vm_type__";
py::object VmRef::Deref(py::object ref_object_class, bool optional) {
py::object casted = ref_object_class.attr(kCastAttr)(*this);
@@ -193,9 +194,8 @@
}
bool VmRef::IsInstance(py::object ref_object_class) {
- auto type_id =
- py::cast<iree_vm_ref_type_t>(ref_object_class.attr(kTypeIdAttr)());
- return type_id == ref_.type;
+ auto type = py::cast<iree_vm_ref_type_t>(ref_object_class.attr(kTypeAttr)());
+ return type == ref_.type;
}
std::string VmRef::ToString() {
@@ -257,9 +257,11 @@
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
"Could not access list element");
- if (iree_vm_type_def_is_value(&v.type)) {
+ if (iree_vm_variant_is_empty(v)) {
+ return py::none();
+ } else if (iree_vm_variant_is_value(v)) {
// Convert a value type.
- switch (v.type.value_type) {
+ switch (iree_vm_type_def_as_value(v.type)) {
case IREE_VM_VALUE_TYPE_I8:
return py::cast(v.i8);
case IREE_VM_VALUE_TYPE_I16:
@@ -275,8 +277,6 @@
default:
throw RaiseValueError("Unsupported VM value type conversion");
}
- } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
- return py::none();
} else if (iree_vm_variant_is_ref(v)) {
VmRef ref;
iree_vm_ref_retain(&v.ref, &ref.ref());
@@ -290,10 +290,14 @@
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
"Could not access list element");
- if (iree_vm_type_def_is_value(&v.type)) {
+ if (iree_vm_variant_is_empty(v)) {
+ py::dict record;
+ record["type"] = "null";
+ return std::move(record);
+ } else if (iree_vm_variant_is_value(v)) {
// Convert a value type.
py::dict record;
- switch (v.type.value_type) {
+ switch (iree_vm_type_def_as_value(v.type)) {
case IREE_VM_VALUE_TYPE_I8:
record["i8"] = py::cast(v.i8);
break;
@@ -317,11 +321,7 @@
}
record["type"] = py::cast("value");
return std::move(record);
- } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
- py::dict record;
- record["type"] = "null";
- return std::move(record);
- } else if (iree_vm_type_def_is_ref(&v.type)) {
+ } else if (iree_vm_variant_is_ref(v)) {
// Convert reference type.
if (iree_vm_list_isa(v.ref)) {
py::dict record;
@@ -442,7 +442,7 @@
if (iree_vm_variant_is_value(variant)) {
// Convert a value type to a string.
- switch (variant.type.value_type) {
+ switch (iree_vm_type_def_as_value(variant.type)) {
case IREE_VM_VALUE_TYPE_I8: {
out += std::to_string(variant.i8);
break;
@@ -501,7 +501,8 @@
}
out.append("]");
} else {
- out += "Unknown(" + std::to_string(variant.type.ref_type) + ")";
+ out += "Unknown(" +
+ std::to_string(iree_vm_type_def_as_ref(variant.type)) + ")";
}
} else {
out.append("None");
@@ -534,7 +535,7 @@
.export_values();
auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
- VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type_id,
+ VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type,
iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
iree_vm_buffer_isa);
vm_buffer
@@ -574,7 +575,7 @@
// Mutation and inspection of the variant list is mostly opaque to python.
auto vm_list = py::class_<VmVariantList>(m, "VmVariantList");
- VmRef::BindRefProtocol(vm_list, iree_vm_list_type_id, iree_vm_list_retain_ref,
+ VmRef::BindRefProtocol(vm_list, iree_vm_list_type, iree_vm_list_retain_ref,
iree_vm_list_deref, iree_vm_list_isa);
vm_list
// User Methods.
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 464db60..ed3661e 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -87,9 +87,10 @@
public:
static VmVariantList Create(iree_host_size_t capacity) {
iree_vm_list_t* list;
- CheckApiStatus(iree_vm_list_create(/*element_type=*/nullptr, capacity,
- iree_allocator_system(), &list),
- "Error allocating variant list");
+ CheckApiStatus(
+ iree_vm_list_create(iree_vm_make_undefined_type_def(), capacity,
+ iree_allocator_system(), &list),
+ "Error allocating variant list");
return VmVariantList::StealFromRawPtr(list);
}
@@ -182,8 +183,8 @@
//----------------------------------------------------------------------------
// Binds the reference protocol to a VmRefObject bound class.
// This defines three attributes:
- // __iree_vm_type_id__()
- // Gets the type id from the object.
+ // __iree_vm_type__()
+ // Gets the type from the object.
// [readonly property] __iree_vm_ref__ :
// Gets a VmRef from the object.
// __iree_vm_cast__(ref) :
@@ -200,13 +201,13 @@
// reference object. It takes some of the C helper functions that are defined
// for each type and is generic.
//----------------------------------------------------------------------------
- static const char* const kTypeIdAttr;
+ static const char* const kTypeAttr;
static const char* const kRefAttr;
static const char* const kCastAttr;
- template <typename PyClass, typename TypeIdFunctor, typename RetainRefFunctor,
+ template <typename PyClass, typename TypeFunctor, typename RetainRefFunctor,
typename DerefFunctor, typename IsaFunctor>
- static void BindRefProtocol(PyClass& cls, TypeIdFunctor type_id,
+ static void BindRefProtocol(PyClass& cls, TypeFunctor type,
RetainRefFunctor retain_ref, DerefFunctor deref,
IsaFunctor isa) {
using WrapperType = typename PyClass::type;
@@ -214,7 +215,7 @@
auto ref_lambda = [=](WrapperType& self) {
return VmRef::Steal(retain_ref(self.raw_ptr()));
};
- cls.def_static(VmRef::kTypeIdAttr, [=]() { return type_id(); });
+ cls.def_static(VmRef::kTypeAttr, [=]() { return type(); });
cls.def_property_readonly(VmRef::kRefAttr, ref_lambda);
cls.def_property_readonly("ref", ref_lambda);
cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) -> py::object {
diff --git a/runtime/bindings/tflite/interpreter.c b/runtime/bindings/tflite/interpreter.c
index 0aaeb92..b7c7f53 100644
--- a/runtime/bindings/tflite/interpreter.c
+++ b/runtime/bindings/tflite/interpreter.c
@@ -91,7 +91,7 @@
_TfLiteInterpreterShapeFrame* frame) {
// [int32...] storage for the shape dimension inputs/outputs.
iree_vm_type_def_t dim_type =
- iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I32);
+ iree_vm_make_value_type_def(IREE_VM_VALUE_TYPE_I32);
IREE_RETURN_IF_ERROR(iree_vm_list_initialize(
iree_make_byte_span(frame->shape_list_storage,
IREE_ARRAYSIZE(frame->shape_list_storage)),
@@ -107,7 +107,7 @@
// Arg 1 is always the shape list for all I/O, so do that once here.
iree_vm_ref_t shape_list_ref = {0};
IREE_RETURN_IF_ERROR(iree_vm_ref_wrap_assign(
- frame->shape_list, iree_vm_list_type_id(), &shape_list_ref));
+ frame->shape_list, iree_vm_list_type(), &shape_list_ref));
IREE_RETURN_IF_ERROR(
iree_vm_list_set_ref_retain(frame->arg_list, 1, &shape_list_ref));
@@ -233,7 +233,7 @@
iree_host_align(sizeof(TfLiteInterpreter), iree_max_align_t);
iree_vm_type_def_t buffer_view_type_def =
- iree_vm_type_def_make_ref_type(iree_hal_buffer_type_id());
+ iree_vm_make_ref_type_def(iree_hal_buffer_type());
total_size +=
iree_vm_list_storage_size(&buffer_view_type_def, model->input_count);
total_size +=
@@ -264,7 +264,7 @@
iree_host_align(sizeof(*interpreter), iree_max_align_t);
iree_vm_type_def_t buffer_view_type_def =
- iree_vm_type_def_make_ref_type(iree_hal_buffer_type_id());
+ iree_vm_make_ref_type_def(iree_hal_buffer_type());
iree_byte_span_t input_list_storage = iree_make_byte_span(
p, iree_vm_list_storage_size(&buffer_view_type_def, model->input_count));
@@ -588,8 +588,8 @@
// NOTE: we could defer the mapping unless requested and ensure state buffers
// remain where they currently are for the next invocation.
for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) {
- iree_hal_buffer_t* buffer = (iree_hal_buffer_t*)iree_vm_list_get_ref_deref(
- interpreter->output_list, i, &iree_hal_buffer_descriptor);
+ iree_hal_buffer_t* buffer =
+ iree_vm_list_get_buffer_assign(interpreter->output_list, i);
TfLiteTensor* tensor = &interpreter->output_tensors[i];
IREE_RETURN_IF_ERROR(_TfLiteTensorBind(tensor, buffer));
}
diff --git a/runtime/bindings/tflite/model.c b/runtime/bindings/tflite/model.c
index 55773a0..a3becd7 100644
--- a/runtime/bindings/tflite/model.c
+++ b/runtime/bindings/tflite/model.c
@@ -34,7 +34,8 @@
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_vm_instance_create(allocator, &model->instance));
+ z0, iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, allocator,
+ &model->instance));
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_module_register_all_types(model->instance));