| # Lint as: python3 |
| # Copyright 2021 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| from typing import Dict, Optional |
| |
| import json |
| import logging |
| |
| import numpy as np |
| |
| from .binding import ( |
| _invoke_statics, |
| ArgumentPacker, |
| BufferUsage, |
| HalBufferView, |
| HalDevice, |
| InvokeContext, |
| MemoryType, |
| VmContext, |
| VmFunction, |
| VmVariantList, |
| ) |
| |
| from . import tracing |
| from .array_interop import ( |
| map_dtype_to_element_type, |
| DeviceArray, |
| ) |
| from .flags import ( |
| FUNCTION_INPUT_VALIDATION,) |
| |
# Names exported by `from <this module> import *`.
__all__ = ["FunctionInvoker"]
| |
| |
class Invocation:
  """Per-call bookkeeping used to produce contextual error messages.

  Holds the device handed to the constructor plus "cursor" fields describing
  the argument or return value that is currently being processed. The
  summarize_* methods render those cursors for inclusion in error text.

  Slots:
    device: The device passed to the constructor.
    current_arg: Argument currently being encoded, or None.
    current_desc: ABI descriptor of the current argument/return, or None.
    current_return_list: List that holds the return currently being decoded,
      or None.
    current_return_index: Index of that return within current_return_list.
  """
  __slots__ = [
      "current_arg",
      "current_desc",
      "current_return_list",
      "current_return_index",
      "device",
  ]

  def __init__(self, device: "HalDevice"):
    self.device = device
    # Captured during arg/ret processing to emit better error messages.
    self.current_arg = None
    self.current_desc = None
    self.current_return_list = None
    self.current_return_index = 0

  def summarize_arg_error(self) -> str:
    """Describes the argument currently being encoded.

    Returns:
      An empty string when no argument is recorded. Otherwise a string of the
      form "<arg summary> with description <descriptor>", where ndarrays are
      summarized by shape and dtype instead of by their contents.
    """
    arg = self.current_arg
    if arg is None:
      return ""
    if isinstance(arg, np.ndarray):
      arg_summary = f"ndarray({arg.shape}, {arg.dtype})"
    else:
      arg_summary = repr(arg)
    # NOTE: repr() is applied to the summary string itself, so it shows up
    # quoted in the message. Kept as-is so existing message text is unchanged.
    return f"{repr(arg_summary)} with description {self.current_desc}"

  def summarize_return_error(self) -> str:
    """Describes the return value currently being decoded.

    Returns:
      An empty string when no return list is recorded. Otherwise a string of
      the form "<index>@<list> with description <descriptor>". If formatting
      the list raises, a placeholder is substituted for that part.
    """
    if self.current_return_list is None:
      return ""
    try:
      vm_repr = f"{self.current_return_index}@{self.current_return_list}"
    except Exception:
      # Best effort only: a failure to print must not mask the real error.
      # Deliberately not a bare `except:` so that KeyboardInterrupt and
      # SystemExit still propagate.
      vm_repr = "<error printing list item>"
    return f"{vm_repr} with description {self.current_desc}"
| |
| |
class FunctionInvoker:
  """Wraps a VmFunction, enabling invocations against it.

  Calling an instance packs the Python arguments with an ArgumentPacker,
  invokes the wrapped function through the VmContext and converts the results
  back to Python values. When the function carries "iree.abi" reflection
  metadata, its argument/return descriptors guide both conversions.
  """
  __slots__ = [
      "_vm_context",
      "_device",
      "_vm_function",
      "_abi_dict",
      "_arg_descs",
      "_arg_packer",
      "_ret_descs",
      "_has_inlined_results",
      "_tracer",
  ]

  def __init__(self, vm_context: "VmContext", device: "HalDevice",
               vm_function: "VmFunction",
               tracer: "Optional[tracing.ContextTracer]"):
    """Initializes the invoker and parses the function's ABI metadata.

    Args:
      vm_context: Context on which the function is invoked.
      device: Device handed to InvokeContext and Invocation for each call.
      vm_function: The function to wrap.
      tracer: Optional tracer; when truthy, every call is recorded with it.

    Raises:
      RuntimeError: If reflection metadata is present but malformed.
    """
    self._vm_context = vm_context
    # TODO: Needing to know the precise device to allocate on here is bad
    # layering and will need to be fixed in some fashion if/when doing
    # heterogenous dispatch.
    self._device = device
    self._vm_function = vm_function
    self._tracer = tracer
    self._abi_dict = None
    self._arg_descs = None
    self._ret_descs = None
    self._has_inlined_results = False
    self._parse_abi_dict(vm_function)
    self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs)

  @property
  def vm_function(self) -> "VmFunction":
    """The wrapped VmFunction."""
    return self._vm_function

  def __call__(self, *args, **kwargs):
    """Invokes the wrapped function with the given Python arguments.

    Returns:
      None when the function produces no results, the bare value when it
      produces exactly one, and a tuple of values otherwise.
    """
    # Keep the InvokeContext bound to a local for the duration of the call.
    invoke_context = InvokeContext(self._device)
    arg_list = self._arg_packer.pack(invoke_context, args, kwargs)

    call_trace = None  # type: Optional[tracing.CallTrace]
    if self._tracer:
      call_trace = self._tracer.start_call(self._vm_function)
    try:
      inv = Invocation(self._device)
      ret_descs = self._ret_descs

      # Size the result list from the reflected number of returns; without
      # reflection data fall back to 1 (presumably an initial capacity).
      ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
      if call_trace:
        call_trace.add_vm_list(arg_list, "args")
      self._invoke(arg_list, ret_list)
      if call_trace:
        call_trace.add_vm_list(ret_list, "results")

      # Un-inline the results to align with reflection, as needed: the
      # descriptors describe a single aggregate, so wrap the flat results in
      # an outer one-element list before decoding.
      reflection_aligned_ret_list = ret_list
      if self._has_inlined_results:
        reflection_aligned_ret_list = VmVariantList(1)
        reflection_aligned_ret_list.push_list(ret_list)
      returns = _extract_vm_sequence_to_python(inv, reflection_aligned_ret_list,
                                               ret_descs)
      return_arity = len(returns)
      if return_arity == 1:
        return returns[0]
      elif return_arity == 0:
        return None
      else:
        return tuple(returns)
    finally:
      # Always close the trace, including when invocation or decoding fails.
      if call_trace:
        call_trace.end_call()

  # Break out invoke so it shows up in profiles.
  def _invoke(self, arg_list, ret_list):
    self._vm_context.invoke(self._vm_function, arg_list, ret_list)

  def _parse_abi_dict(self, vm_function: "VmFunction"):
    """Loads argument/return descriptors from the "iree.abi" reflection entry.

    Populates _abi_dict, _arg_descs, _ret_descs and _has_inlined_results.
    When the function has no "iree.abi" entry, nothing is changed.

    Raises:
      RuntimeError: If the entry is not valid JSON, is not a JSON object, lacks
        the "a" or "r" keys, or either of those is not a list.
    """
    reflection = vm_function.reflection
    abi_json = reflection.get("iree.abi")
    if abi_json is None:
      # It is valid to have no reflection data, and rely on pure dynamic
      # dispatch.
      logging.debug(
          "Function lacks reflection data. Interop will be limited: %r",
          vm_function)
      return
    try:
      self._abi_dict = json.loads(abi_json)
    except json.JSONDecodeError as e:
      raise RuntimeError(
          f"Reflection metadata is not valid JSON: {abi_json}") from e
    # Valid JSON need not be an object (e.g. "[1, 2]" or "null"); report that
    # as malformed metadata rather than letting a TypeError escape below.
    if not isinstance(self._abi_dict, dict):
      raise RuntimeError(
          f"Malformed function reflection metadata structure: {reflection}")
    try:
      self._arg_descs = self._abi_dict["a"]
      self._ret_descs = self._abi_dict["r"]
    except KeyError as e:
      raise RuntimeError(
          f"Malformed function reflection metadata: {reflection}") from e
    if not isinstance(self._arg_descs, list) or not isinstance(
        self._ret_descs, list):
      raise RuntimeError(
          f"Malformed function reflection metadata structure: {reflection}")

    # Detect whether the results are a slist/stuple/sdict, which indicates
    # that they are inlined with the function's results.
    if len(self._ret_descs) == 1:
      maybe_inlined = self._ret_descs[0]
      # Only list descriptors carry a leading type tag; scalar descriptors are
      # plain strings and can never be inlined aggregates.
      if (isinstance(maybe_inlined, list) and maybe_inlined and
          maybe_inlined[0] in ("slist", "stuple", "sdict")):
        self._has_inlined_results = True

  def __repr__(self):
    return repr(self._vm_function)
| |
| |
| # VM to Python converters. All take: |
| # inv: Invocation |
| # vm_list: VmVariantList to read from |
| # vm_index: Index in the vm_list to extract |
| # desc: The ABI descriptor list (or None if in dynamic mode) |
| # Return the corresponding Python object. |
| |
| |
def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int,
                   desc):
  """Wraps the buffer view at `vm_index` of `vm_list` in a DeviceArray.

  `desc` is an ndarray descriptor shaped like
  ["ndarray", "<dtype>", <rank>, <dim>...], e.g. ['ndarray', 'i32', 1, 25948].
  Only the dtype entry (index 1) is consulted here; it is mapped through
  ABI_TYPE_TO_DTYPE and passed as the array's override dtype.

  Raises:
    ReturnError: If the descriptor's dtype name is not in ABI_TYPE_TO_DTYPE.
  """
  view = vm_list.get_as_buffer_view(vm_index)
  abi_dtype_name = desc[1]
  try:
    numpy_dtype = ABI_TYPE_TO_DTYPE[abi_dtype_name]
  except KeyError:
    # Always raises, so numpy_dtype is bound whenever we get past this point.
    _raise_return_error(inv, f"unrecognized dtype '{abi_dtype_name}'")
  return DeviceArray(inv.device,
                     view,
                     implicit_host_transfer=True,
                     override_dtype=numpy_dtype)
| |
| |
def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested list at `vm_index` into a dict.

  `desc` is an sdict descriptor shaped like ['sdict', ['key1', value1], ...]:
  after the leading tag, each entry pairs a key with the descriptor of the
  value stored at the same position in the nested list.
  """
  nested = vm_list.get_as_list(vm_index)
  entries = desc[1:]
  keys = [key for key, _ in entries]
  value_descs = [value_desc for _, value_desc in entries]
  values = _extract_vm_sequence_to_python(inv, nested, value_descs)
  return {key: value for key, value in zip(keys, values)}
| |
| |
def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested list at `vm_index` into a Python list.

  `desc` is an slist descriptor shaped like ['slist', item1, ...]: everything
  after the leading tag is the descriptor of the item at the same position.
  """
  nested = vm_list.get_as_list(vm_index)
  return _extract_vm_sequence_to_python(inv, nested, desc[1:])
| |
| |
def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested list at `vm_index` into a Python tuple.

  Decoding is the same as for an slist (item descriptors follow the leading
  tag in `desc`); only the resulting container type differs.
  """
  nested = vm_list.get_as_list(vm_index)
  items = _extract_vm_sequence_to_python(inv, nested, desc[1:])
  return tuple(items)
| |
| |
| def _vm_to_scalar(type_bound: type): |
| |
| def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): |
| value = vm_list.get_variant(vm_index) |
| if not isinstance(value, type_bound): |
| raise ReturnError( |
| f"expected an {type_bound} value but got {value.__class__}") |
| return value |
| |
| return convert |
| |
| |
def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested homogeneous list at `vm_index` into a Python list.

  `desc` is shaped like [<tag>, element_type]. The part after the tag is
  repeated once per element of the nested list so that every element is
  decoded with the same descriptor.
  """
  elements = vm_list.get_as_list(vm_index)
  per_element_descs = desc[1:] * len(elements)
  return _extract_vm_sequence_to_python(inv, elements, per_element_descs)
| |
| |
# Dispatch table from ABI type tag to the converter that turns a VM list entry
# of that type into a Python object.
VM_TO_PYTHON_CONVERTERS = {
    # Arrays and aggregates.
    "ndarray": _vm_to_ndarray,
    "sdict": _vm_to_sdict,
    "slist": _vm_to_slist,
    "stuple": _vm_to_stuple,
    "py_homogeneous_list": _vm_to_pylist,
}
# Scalars: integer tags must decode to a Python int, float tags to a float.
# Each tag gets its own converter instance.
VM_TO_PYTHON_CONVERTERS.update(
    (tag, _vm_to_scalar(int)) for tag in ("i8", "i16", "i32", "i64"))
VM_TO_PYTHON_CONVERTERS.update(
    (tag, _vm_to_scalar(float)) for tag in ("f16", "f32", "f64", "bf16"))
| |
# Maps ABI element type names to the NumPy scalar type used for them.
# TODO: Others.
ABI_TYPE_TO_DTYPE = dict((
    ("f32", np.float32),
    ("i32", np.int32),
    ("i64", np.int64),
    ("f64", np.float64),
    ("i16", np.int16),
    ("i8", np.int8),
    ("i1", np.bool_),
))
| |
# When we get an ndarray as an argument and are implicitly mapping it to a
# buffer view, flags for doing so: the memory type is the union of the
# DEVICE_LOCAL and DEVICE_VISIBLE flags and the usage is BufferUsage.ALL.
# NOTE(review): neither constant is referenced in the visible part of this
# file; presumably kept for other modules - confirm before removing.
IMPLICIT_BUFFER_ARG_MEMORY_TYPE = (MemoryType.DEVICE_LOCAL |
                                   MemoryType.DEVICE_VISIBLE)
IMPLICIT_BUFFER_ARG_USAGE = BufferUsage.ALL
| |
| |
| def _is_ndarray_descriptor(desc): |
| return desc and desc[0] == "ndarray" |
| |
| |
| def _is_0d_ndarray_descriptor(desc): |
| # Example: ["ndarray", "f32", 0] |
| return desc and desc[0] == "ndarray" and desc[2] == 0 |
| |
| |
def _cast_scalar_to_ndarray(inv: Invocation, x, desc):
  """Casts the scalar `x` to the NumPy type named by an ndarray descriptor.

  Example descriptor: ["ndarray", "f32", 0]. The dtype name at index 1 is
  looked up in ABI_TYPE_TO_DTYPE and the resulting type is called on `x`.

  Raises:
    ArgumentError: If the dtype name is not in ABI_TYPE_TO_DTYPE.
  """
  abi_dtype_name = desc[1]
  try:
    scalar_type = ABI_TYPE_TO_DTYPE[abi_dtype_name]
  except KeyError:
    # Always raises, so scalar_type is bound whenever we get past this point.
    _raise_argument_error(inv, f"unrecognized dtype '{abi_dtype_name}'")
  return scalar_type(x)
| |
| |
class ArgumentError(ValueError):
  """Raised when a Python argument cannot be passed to the function."""
| |
| |
class ReturnError(ValueError):
  """Raised when a function return cannot be converted to Python."""
| |
| |
def _raise_argument_error(inv: Invocation,
                          summary: str,
                          e: Optional[Exception] = None):
  """Raises an ArgumentError describing a failure to encode an argument.

  Never returns. The message combines `summary` with the invocation's own
  description of the argument being encoded. When `e` is truthy it is
  attached as the explicit cause of the new error.
  """
  arg_context = inv.summarize_arg_error()
  failure = ArgumentError(f"Error passing argument: {summary} "
                          f"(while encoding argument {arg_context})")
  if not e:
    raise failure
  raise failure from e
| |
| |
def _raise_return_error(inv: Invocation,
                        summary: str,
                        e: Optional[Exception] = None):
  """Raises a ReturnError describing a failure to decode a return value.

  Never returns. The message combines `summary` with the invocation's own
  description of the return being decoded. When `e` is truthy it is attached
  as the explicit cause of the new error.
  """
  return_context = inv.summarize_return_error()
  failure = ReturnError(f"Error processing function return: {summary} "
                        f"(while decoding return {return_context})")
  if not e:
    raise failure
  raise failure from e
| |
| |
def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs):
  """Converts every entry of `vm_list` to a Python object.

  Args:
    inv: Invocation whose return cursor fields are updated before each entry
      is converted, so errors can name the entry being decoded.
    vm_list: The list to read from.
    descs: One ABI descriptor per entry, or None for dynamic mode, in which
      each entry is taken as its raw variant.

  Returns:
    A list with one converted Python object per entry of `vm_list`.

  Raises:
    ReturnError: If the number of descriptors differs from the list length, a
      descriptor's type has no converter, or a converter fails.
  """
  arity = len(vm_list)
  if descs is None:
    descs = [None] * arity
  elif len(descs) != arity:
    _raise_return_error(inv,
                        f"mismatched return arity: {arity} vs {len(descs)}")

  py_values = []
  # Lengths are equal at this point, so enumerating descs visits every entry.
  for index, desc in enumerate(descs):
    inv.current_return_list = vm_list
    inv.current_return_index = index
    inv.current_desc = desc

    if desc is None:
      # Dynamic (non reflection mode).
      value = vm_list.get_variant(index)
      # Special case: Upgrade HalBufferView to a DeviceArray. We do that here
      # since this is higher level and it preserves layering. Note that
      # the reflection case also does this conversion.
      if isinstance(value, HalBufferView):
        value = DeviceArray(inv.device, value, implicit_host_transfer=True)
      py_values.append(value)
      continue

    # Known type descriptor: scalars are bare strings, everything else is a
    # list whose first element is the type tag.
    type_tag = desc if isinstance(desc, str) else desc[0]
    try:
      convert = VM_TO_PYTHON_CONVERTERS[type_tag]
    except KeyError:
      _raise_return_error(inv, f"cannot map VM type to Python: {type_tag}")
    try:
      value = convert(inv, vm_list, index, desc)
    except ReturnError:
      # Already carries context from a nested conversion; pass it through.
      raise
    except Exception as e:
      _raise_return_error(inv, "exception converting from VM type to Python",
                          e)
    py_values.append(value)
  return py_values