Implement function.py return type coercion (#6832)
Update `function.py` to support `ndarray` type coercion on return (currently supported on input values). Tested for i8 to boolean conversion since the IREE runtime doesn't support boolean values internally.
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 7fef1c6..028d6df 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -329,7 +329,18 @@
def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int,
desc):
- return vm_list.get_as_ndarray(vm_index)
+ # The descriptor for an ndarray is like:
+ # ["ndarray", "<dtype>", <rank>, <dim>...]
+ # ex: ['ndarray', 'i32', 1, 25948]
+ x = vm_list.get_as_ndarray(vm_index)
+ dtype_str = desc[1]
+ try:
+ dtype = ABI_TYPE_TO_DTYPE[dtype_str]
+ except KeyError:
+ _raise_return_error(inv, f"unrecognized dtype '{dtype_str}'")
+ if dtype != x.dtype:
+ x = x.astype(dtype)
+ return x
def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
@@ -394,6 +405,7 @@
"i64": np.int64,
"f64": np.float64,
"i16": np.int16,
+ "i8": np.int8,
"i1": np.bool_,
}
diff --git a/bindings/python/iree/runtime/function_test.py b/bindings/python/iree/runtime/function_test.py
index 3db2671..ed8e887 100644
--- a/bindings/python/iree/runtime/function_test.py
+++ b/bindings/python/iree/runtime/function_test.py
@@ -6,6 +6,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import json
+import numpy as np
from absl.testing import absltest
@@ -197,6 +198,26 @@
with self.assertRaisesRegexp(ValueError, "specified kwarg 'c' is unknown"):
result = invoker(-1, a=1, b=2, c=3)
+ # TODO: Fill out all return types.
+ def testReturnTypeNdArrayBool(self):
+ result_array = np.asarray([1, 0], dtype=np.int8)
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_buffer_view(self.device, result_array,
+ rt.HalElementType.UINT_8)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={
+ "iree.abi": json.dumps({
+ "a": [],
+ "r": [["ndarray", "i1", 1, 2]],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ result = invoker()
+ # assertEqual on bool arrays is fraught for... reasons.
+ self.assertEqual("array([ True, False])", repr(result))
+
if __name__ == "__main__":
absltest.main()