| # Lint as: python3 |
| # Copyright 2021 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| from typing import Optional |
| |
| import json |
| import logging |
| |
| import numpy as np |
| |
| from .binding import HalDevice, HalElementType, VmContext, VmFunction, VmVariantList |
| from . import tracing |
| |
| __all__ = [ |
| "FunctionInvoker", |
| ] |
| |
| |
| class Invocation: |
| __slots__ = [ |
| "current_arg", |
| "current_desc", |
| "current_return_list", |
| "current_return_index", |
| "device", |
| ] |
| |
| def __init__(self, device: HalDevice): |
| self.device = device |
| # Captured during arg/ret processing to emit better error messages. |
| self.current_arg = None |
| self.current_desc = None |
| self.current_return_list = None |
| self.current_return_index = 0 |
| |
| def summarize_arg_error(self) -> str: |
| if self.current_arg is None: |
| return "" |
| if isinstance(self.current_arg, np.ndarray): |
| current_arg_repr = ( |
| f"ndarray({self.current_arg.shape}, {self.current_arg.dtype})") |
| else: |
| current_arg_repr = repr(self.current_arg) |
| return f"{repr(current_arg_repr)} with description {self.current_desc}" |
| |
| def summarize_return_error(self) -> str: |
| if self.current_return_list is None: |
| return "" |
| try: |
| vm_repr = f"{self.current_return_index}@{self.current_return_list}" |
| except: |
| vm_repr = "<error printing list item>" |
| return f"{vm_repr} with description {self.current_desc}" |
| |
| |
| class FunctionInvoker: |
| """Wraps a VmFunction, enabling invocations against it.""" |
| __slots__ = [ |
| "_vm_context", |
| "_device", |
| "_vm_function", |
| "_abi_dict", |
| "_arg_descs", |
| "_ret_descs", |
| "_has_kwargs", |
| "_tracer", |
| ] |
| |
| def __init__(self, vm_context: VmContext, device: HalDevice, |
| vm_function: VmFunction, |
| tracer: Optional[tracing.ContextTracer]): |
| self._vm_context = vm_context |
| # TODO: Needing to know the precise device to allocate on here is bad |
| # layering and will need to be fixed in some fashion if/when doing |
| # heterogenous dispatch. |
| self._device = device |
| self._vm_function = vm_function |
| self._tracer = tracer |
| self._abi_dict = None |
| self._arg_descs = None |
| self._ret_descs = None |
| self._has_kwargs = False |
| self._parse_abi_dict(vm_function) |
| |
| @property |
| def vm_function(self) -> VmFunction: |
| return self._vm_function |
| |
| def __call__(self, *args, **kwargs): |
| call_trace = None # type: Optional[tracing.CallTrace] |
| if self._tracer: |
| call_trace = self._tracer.start_call(self._vm_function) |
| |
| # Initialize the capacity to our total number of args, since we should |
| # be below that when doing a flat invocation. May want to be more |
| # conservative here when considering nesting. |
| inv = Invocation(self._device) |
| ret_descs = self._ret_descs |
| |
| # If kwargs are present, we treat those more as kwarg-only parameters (i.e. |
| # you cannot just arbitrarily use them to override positional arguments |
| # by name in the current implementation). If the backing ABI metadata |
| # declares support for kwargs, this will be done by having a final |
| # 'kwargs_sdict' arg descriptor, and we rewrite into this form. |
| # So we just append the kwargs dict to the args list and let decoding |
| # happen normally. |
| if self._has_kwargs: |
| args = list(args) |
| args.append(kwargs if kwargs else dict()) |
| |
| arg_list = VmVariantList(len(args)) |
| ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1) |
| _merge_python_sequence_to_vm(inv, arg_list, args, self._arg_descs) |
| if call_trace: |
| call_trace.add_vm_list(arg_list, "args") |
| self._vm_context.invoke(self._vm_function, arg_list, ret_list) |
| if call_trace: |
| call_trace.add_vm_list(ret_list, "results") |
| returns = _extract_vm_sequence_to_python(inv, ret_list, ret_descs) |
| if call_trace: |
| call_trace.end_call() |
| return_arity = len(returns) |
| if return_arity == 1: |
| return returns[0] |
| elif return_arity == 0: |
| return None |
| else: |
| return tuple(returns) |
| |
| def _parse_abi_dict(self, vm_function: VmFunction): |
| reflection = vm_function.reflection |
| abi_json = reflection.get("iree.abi") |
| if abi_json is None: |
| # It is valid to have no reflection data, and rely on pure dynamic |
| # dispatch. |
| logging.debug( |
| "Function lacks reflection data. Interop will be limited: %r", |
| vm_function) |
| return |
| try: |
| self._abi_dict = json.loads(abi_json) |
| except json.JSONDecodeError as e: |
| raise RuntimeError( |
| f"Reflection metadata is not valid JSON: {abi_json}") from e |
| try: |
| self._arg_descs = self._abi_dict["a"] |
| self._ret_descs = self._abi_dict["r"] |
| except KeyError as e: |
| raise RuntimeError( |
| f"Malformed function reflection metadata: {reflection}") from e |
| if not isinstance(self._arg_descs, list) or not isinstance( |
| self._ret_descs, list): |
| raise RuntimeError( |
| f"Malformed function reflection metadata structure: {reflection}") |
| |
| # See if kwargs are expected. |
| if self._arg_descs: |
| maybe_kwargs_desc = self._arg_descs[-1] |
| if maybe_kwargs_desc and maybe_kwargs_desc[0] == "sdict_kwargs": |
| self._has_kwargs = True |
| |
| def __repr__(self): |
| return repr(self._vm_function) |
| |
| |
| # Python type to VM Type converters. All of these take: |
| # inv: Invocation |
| # target_list: VmVariantList to append to |
| # python_value: The python value of the given type |
| # desc: The ABI descriptor list (or None if in dynamic mode). |
| |
| |
| def _bool_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| _int_to_vm(inv, t, int(x), desc) |
| |
| |
| def _int_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| # Implicit conversion to a 0d tensor. |
| if desc and _is_0d_ndarray_descriptor(desc): |
| casted = _cast_scalar_to_ndarray(inv, x, desc) |
| _ndarray_to_vm(inv, t, casted, desc) |
| return |
| t.push_int(x) |
| |
| |
| def _float_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| # Implicit conversion to a 0d tensor. |
| if desc and _is_0d_ndarray_descriptor(desc): |
| casted = _cast_scalar_to_ndarray(inv, x, desc) |
| _ndarray_to_vm(inv, t, casted, desc) |
| return |
| t.push_float(x) |
| |
| |
| def _list_or_tuple_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| desc_type = desc[0] |
| if desc_type != "slist" and desc_type != "stuple": |
| _raise_argument_error(inv, |
| f"passed a list or tuple but expected {desc_type}") |
| # When decoding a list or tuple, the desc object is like: |
| # ['slist', [...value_type_0...], ...] |
| # Where the type is either 'slist' or 'stuple'. |
| sub_descriptors = desc[1:] |
| arity = len(sub_descriptors) |
| if len(x) != arity: |
| _raise_argument_error(inv, |
| f"mismatched list/tuple arity: {len(x)} vs {arity}") |
| sub_list = VmVariantList(arity) |
| _merge_python_sequence_to_vm(inv, sub_list, x, sub_descriptors) |
| t.push_list(sub_list) |
| |
| |
| def _dict_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| desc_type = desc[0] |
| if desc_type != "sdict" and desc_type != "sdict_kwargs": |
| _raise_argument_error(inv, f"passed a dict but expected {desc_type}") |
| # When decoding a dict, the desc object is like: |
| # ['sdict', ['key0', [...value_type_0...]], ['key1', [...value_type_1...]]]] |
| sub_descriptors = desc[1:] |
| py_values = [] |
| value_descs = [] |
| for key, value_desc in sub_descriptors: |
| try: |
| py_values.append(x[key]) |
| except KeyError: |
| _raise_argument_error(inv, f"expected dict item with key '{key}'") |
| value_descs.append(value_desc) |
| |
| sub_list = VmVariantList(len(py_values)) |
| _merge_python_sequence_to_vm(inv, sub_list, py_values, value_descs) |
| t.push_list(sub_list) |
| |
| |
| def _str_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| _raise_argument_error(inv, "Python str arguments not yet supported") |
| |
| |
| def _ndarray_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| # Validate and implicit conversion against type descriptor. |
| if desc is not None: |
| desc_type = desc[0] |
| if desc_type != "ndarray": |
| _raise_argument_error(inv, f"passed an ndarray but expected {desc_type}") |
| dtype_str = desc[1] |
| try: |
| dtype = ABI_TYPE_TO_DTYPE[dtype_str] |
| except KeyError: |
| _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'") |
| if dtype != x.dtype: |
| x = x.astype(dtype) |
| rank = desc[2] |
| shape = desc[3:] |
| ndarray_shape = x.shape |
| if len(shape) != len(ndarray_shape) or rank != len(ndarray_shape): |
| _raise_argument_error( |
| inv, f"rank mismatch {len(ndarray_shape)} vs {len(shape)}") |
| for exp_dim, act_dim in zip(shape, ndarray_shape): |
| if exp_dim is not None and exp_dim != act_dim: |
| _raise_argument_error( |
| inv, f"shape mismatch {ndarray_shape} vs {tuple(shape)}") |
| actual_dtype = x.dtype |
| for match_dtype, element_type in DTYPE_TO_HAL_ELEMENT_TYPE: |
| if match_dtype == actual_dtype: |
| break |
| else: |
| _raise_argument_error(inv, f"unsupported numpy dtype {x.dtype}") |
| t.push_buffer_view(inv.device, x, element_type) |
| |
| |
| def _ndarray_like_to_vm(inv: Invocation, t: VmVariantList, x, desc): |
| return _ndarray_to_vm(inv, t, np.asarray(x), desc) |
| |
| |
| PYTHON_TO_VM_CONVERTERS = { |
| bool: _bool_to_vm, |
| int: _int_to_vm, |
| float: _float_to_vm, |
| list: _list_or_tuple_to_vm, |
| tuple: _list_or_tuple_to_vm, |
| dict: _dict_to_vm, |
| str: _str_to_vm, |
| np.ndarray: _ndarray_to_vm, |
| } |
| |
| # VM to Python converters. All take: |
| # inv: Invocation |
| # vm_list: VmVariantList to read from |
| # vm_index: Index in the vm_list to extract |
| # desc: The ABI descriptor list (or None if in dynamic mode) |
| # Return the corresponding Python object. |
| |
| |
| def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int, |
| desc): |
| return vm_list.get_as_ndarray(vm_index) |
| |
| |
| def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): |
| # The descriptor for an sdict is like: |
| # ['sdict', ['key1', value1], ...] |
| sub_vm_list = vm_list.get_as_list(vm_index) |
| item_keys = [] |
| item_descs = [] |
| for k, d in desc[1:]: |
| item_keys.append(k) |
| item_descs.append(d) |
| py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) |
| return dict(zip(item_keys, py_items)) |
| |
| |
| def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): |
| # The descriptor for an slist is like: |
| # ['slist, item1, ...] |
| sub_vm_list = vm_list.get_as_list(vm_index) |
| item_descs = desc[1:] |
| py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) |
| return py_items |
| |
| |
| def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): |
| return tuple(_vm_to_slist(inv, vm_list, vm_index, desc)) |
| |
| |
| VM_TO_PYTHON_CONVERTERS = { |
| "ndarray": _vm_to_ndarray, |
| "sdict": _vm_to_sdict, |
| "slist": _vm_to_slist, |
| "stuple": _vm_to_stuple, |
| } |
| |
| ABI_TYPE_TO_DTYPE = { |
| # TODO: Others. |
| "f32": np.float32, |
| "i32": np.int32, |
| "i64": np.int64, |
| "f64": np.float64, |
| "i16": np.int16, |
| "i1": np.bool_, |
| } |
| |
| # NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should |
| # be queried via isinstance checks. This should be done as a fallback but |
| # this is a linear list for quick access to the most common. There may also |
| # be a better way to do this. |
| DTYPE_TO_HAL_ELEMENT_TYPE = ( |
| (np.float32, HalElementType.FLOAT_32), |
| (np.float64, HalElementType.FLOAT_64), |
| (np.float16, HalElementType.FLOAT_16), |
| (np.int32, HalElementType.SINT_32), |
| (np.int64, HalElementType.SINT_64), |
| (np.int16, HalElementType.SINT_16), |
| (np.int8, HalElementType.SINT_8), |
| (np.uint32, HalElementType.UINT_32), |
| (np.uint64, HalElementType.UINT_64), |
| (np.uint16, HalElementType.UINT_16), |
| (np.uint8, HalElementType.UINT_8), |
| (np.bool_, HalElementType.BOOL_8), |
| ) |
| |
| |
| def _is_ndarray_descriptor(desc): |
| return desc and desc[0] == "ndarray" |
| |
| |
| def _is_0d_ndarray_descriptor(desc): |
| # Example: ["ndarray", "f32", 0] |
| return desc and desc[0] == "ndarray" and desc[2] == 0 |
| |
| |
| def _cast_scalar_to_ndarray(inv: Invocation, x, desc): |
| # Example descriptor: ["ndarray", "f32", 0] |
| dtype_str = desc[1] |
| try: |
| dtype = ABI_TYPE_TO_DTYPE[dtype_str] |
| except KeyError: |
| _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'") |
| return dtype(x) |
| |
| |
| def _raise_argument_error(inv: Invocation, |
| summary: str, |
| e: Optional[Exception] = None): |
| new_e = ValueError(f"Error passing argument: {summary} " |
| f"(while encoding argument {inv.summarize_arg_error()})") |
| if e: |
| raise new_e from e |
| else: |
| raise new_e |
| |
| |
| def _raise_return_error(inv: Invocation, |
| summary: str, |
| e: Optional[Exception] = None): |
| new_e = ValueError(f"Error processing function return: {summary} " |
| f"(while decoding return {inv.summarize_return_error()})") |
| if e: |
| raise new_e from e |
| else: |
| raise new_e |
| |
| |
| def _merge_python_sequence_to_vm(inv: Invocation, vm_list, py_list, descs): |
| # For dynamic mode, just assume we have the right arity. |
| if descs is None: |
| descs = [None] * len(py_list) |
| elif len(py_list) != len(descs): |
| _raise_argument_error( |
| inv, f"mismatched function call arity: " |
| f"expected={descs}, got={py_list}") |
| for py_value, desc in zip(py_list, descs): |
| inv.current_arg = py_value |
| inv.current_desc = desc |
| py_type = py_value.__class__ |
| |
| # For ndarray, we want to be able to handle array-like, so check for that |
| # explicitly (duck typed vs static typed). |
| if _is_ndarray_descriptor(desc): |
| converter = _ndarray_like_to_vm |
| else: |
| try: |
| converter = PYTHON_TO_VM_CONVERTERS[py_type] |
| except KeyError: |
| _raise_argument_error( |
| inv, f"cannot map Python type to VM: {py_type}" |
| f" (for desc {desc})") |
| try: |
| converter(inv, vm_list, py_value, desc) |
| except Exception as e: |
| _raise_argument_error(inv, f"exception converting from Python type to VM", |
| e) |
| |
| |
| def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs): |
| vm_list_arity = len(vm_list) |
| if descs is None: |
| descs = [None] * vm_list_arity |
| elif vm_list_arity != len(descs): |
| _raise_return_error( |
| inv, f"mismatched return arity: {vm_list_arity} vs {len(descs)}") |
| results = [] |
| for vm_index, desc in zip(range(vm_list_arity), descs): |
| inv.current_return_list = vm_list |
| inv.current_return_index = vm_index |
| inv.current_desc = desc |
| if desc is None: |
| # Dynamic (non reflection mode). |
| converted = vm_list.get_variant(vm_index) |
| else: |
| # Known type descriptor. |
| vm_type = desc[0] |
| try: |
| converter = VM_TO_PYTHON_CONVERTERS[vm_type] |
| except KeyError: |
| _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}") |
| try: |
| converted = converter(inv, vm_list, vm_index, desc) |
| except Exception as e: |
| _raise_return_error(inv, f"exception converting from VM type to Python", |
| e) |
| results.append(converted) |
| return results |