function.py: Add support for mashaling list types on returns. (#6999)
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 028d6df..f42a4fa 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -381,11 +381,22 @@
return convert
+def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+ # The descriptor for a pylist is like:
+ # ['pylist', element_type]
+ sub_vm_list = vm_list.get_as_list(vm_index)
+ element_type_desc = desc[1:]
+ py_items = _extract_vm_sequence_to_python(
+ inv, sub_vm_list, element_type_desc * len(sub_vm_list))
+ return py_items
+
+
VM_TO_PYTHON_CONVERTERS = {
"ndarray": _vm_to_ndarray,
"sdict": _vm_to_sdict,
"slist": _vm_to_slist,
"stuple": _vm_to_stuple,
+ "py_homogeneous_list": _vm_to_pylist,
# Scalars.
"i8": _vm_to_scalar(int),
diff --git a/bindings/python/iree/runtime/function_test.py b/bindings/python/iree/runtime/function_test.py
index ed8e887..90ccda2 100644
--- a/bindings/python/iree/runtime/function_test.py
+++ b/bindings/python/iree/runtime/function_test.py
@@ -218,6 +218,26 @@
# assertEqual on bool arrays is fraught for... reasons.
self.assertEqual("array([ True, False])", repr(result))
+ def testReturnTypeList(self):
+ vm_list = VmVariantList(2)
+ vm_list.push_int(1)
+ vm_list.push_int(2)
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_list(vm_list)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [],
+ "r": [["py_homogeneous_list", "i64"]],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ result = invoker()
+ self.assertEqual("[1, 2]", repr(result))
+
if __name__ == "__main__":
absltest.main()
diff --git a/docs/developers/design_docs/function_abi.md b/docs/developers/design_docs/function_abi.md
index 49cdc54..d936bda 100644
--- a/docs/developers/design_docs/function_abi.md
+++ b/docs/developers/design_docs/function_abi.md
@@ -183,3 +183,5 @@
- `["sdict", ["key", {slot_type}]...]`: An anonymous structure with named
slots. Note that when passing these types, the keys are not passed to the
function (only the slot values).
+- `["py_homogeneous_list", {element_type}]`: A Python list of unknown size
+ with elements sharing a common type bound given by `element_type`.