[HAL] Allow iree_hal_buffer_view_shape to accept NULL out_shape (#14298)
When the buffer view has a 0 rank allow `out_shape` to be NULL.
Add Python bindings test to check `VmVariantList` -> `str` conversion
for 0-rank tensor.
---------
Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index db8c7e8..245a445 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -81,6 +81,19 @@
with self.assertRaises(IndexError):
lst.get_as_object(1, rt.HalBufferView)
+ def test_variant_list_zero_rank_tensor_to_str(self):
+ device = rt.get_device("local-sync")
+ lst = rt.VmVariantList(1)
+ array = np.array(1234, dtype=np.int32)
+ buffer_view = device.allocator.allocate_buffer_copy(
+ memory_type=rt.MemoryType.DEVICE_LOCAL,
+ allowed_usage=(rt.BufferUsage.DEFAULT | rt.BufferUsage.MAPPING),
+ buffer=array,
+ element_type=rt.HalElementType.SINT_32,
+ )
+ lst.push_ref(buffer_view)
+ self.assertEqual(str(lst), "<VmVariantList(1): [HalBufferView(:0x20000011)]>")
+
def test_variant_list_list(self):
lst1 = rt.VmVariantList(5)
lst2 = rt.VmVariantList(5)
diff --git a/runtime/src/iree/hal/buffer_view.c b/runtime/src/iree/hal/buffer_view.c
index 3dd6632..4c90f57 100644
--- a/runtime/src/iree/hal/buffer_view.c
+++ b/runtime/src/iree/hal/buffer_view.c
@@ -142,10 +142,6 @@
const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity,
iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) {
IREE_ASSERT_ARGUMENT(buffer_view);
- IREE_ASSERT_ARGUMENT(out_shape);
- if (out_shape_rank) {
- *out_shape_rank = 0;
- }
if (out_shape_rank) {
*out_shape_rank = buffer_view->shape_rank;
@@ -155,6 +151,7 @@
return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
}
+ IREE_ASSERT(buffer_view->shape_rank == 0 || out_shape);
for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) {
out_shape[i] = buffer_view->shape[i];
}
diff --git a/runtime/src/iree/hal/buffer_view.h b/runtime/src/iree/hal/buffer_view.h
index 5a2071e..1268637 100644
--- a/runtime/src/iree/hal/buffer_view.h
+++ b/runtime/src/iree/hal/buffer_view.h
@@ -233,7 +233,9 @@
// Returns the dimensions of the shape in |out_shape| and its rank in
// |out_shape_rank|. |rank_capacity| indicates the number of dimensions
// available in the |out_shape| buffer. If there is not enough capacity to store
-// all of the dimensions IREE_STATUS_OUT_OF_RANGE is returned.
+// all of the dimensions IREE_STATUS_OUT_OF_RANGE is returned
+// without populating |out_shape|.
+// If the shape rank of |buffer_view| is 0, |out_shape| can be NULL.
// |out_shape_rank| can be omitted if the rank is already known.
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_shape(
const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity,