function.py: Add support for mashaling list types on returns. (#6999)
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py index 028d6df..f42a4fa 100644 --- a/bindings/python/iree/runtime/function.py +++ b/bindings/python/iree/runtime/function.py
@@ -381,11 +381,22 @@ return convert +def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): + # The descriptor for a pylist is like: + # ['pylist', element_type] + sub_vm_list = vm_list.get_as_list(vm_index) + element_type_desc = desc[1:] + py_items = _extract_vm_sequence_to_python( + inv, sub_vm_list, element_type_desc * len(sub_vm_list)) + return py_items + + VM_TO_PYTHON_CONVERTERS = { "ndarray": _vm_to_ndarray, "sdict": _vm_to_sdict, "slist": _vm_to_slist, "stuple": _vm_to_stuple, + "py_homogeneous_list": _vm_to_pylist, # Scalars. "i8": _vm_to_scalar(int),
diff --git a/bindings/python/iree/runtime/function_test.py b/bindings/python/iree/runtime/function_test.py index ed8e887..90ccda2 100644 --- a/bindings/python/iree/runtime/function_test.py +++ b/bindings/python/iree/runtime/function_test.py
@@ -218,6 +218,26 @@ # assertEqual on bool arrays is fraught for... reasons. self.assertEqual("array([ True, False])", repr(result)) + def testReturnTypeList(self): + vm_list = VmVariantList(2) + vm_list.push_int(1) + vm_list.push_int(2) + + def invoke(arg_list, ret_list): + ret_list.push_list(vm_list) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={ + "iree.abi": + json.dumps({ + "a": [], + "r": [["py_homogeneous_list", "i64"]], + }) + }) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + self.assertEqual("[1, 2]", repr(result)) + if __name__ == "__main__": absltest.main()
diff --git a/docs/developers/design_docs/function_abi.md b/docs/developers/design_docs/function_abi.md index 49cdc54..d936bda 100644 --- a/docs/developers/design_docs/function_abi.md +++ b/docs/developers/design_docs/function_abi.md
@@ -183,3 +183,5 @@ - `["sdict", ["key", {slot_type}]...]`: An anonymous structure with named slots. Note that when passing these types, the keys are not passed to the function (only the slot values). +- `["py_homogeneous_list", {element_type}]`: A Python list of unknown size + with elements sharing a common type bound given by `element_type`.