blob: fd4a1e670686b04021c69919317664adad035bbf [file] [log] [blame]
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""BufferView and Python Array Protocol interop."""
from typing import Optional, Tuple
import logging
import numpy as np
import numpy.lib.mixins
from .binding import (
BufferUsage,
HalBufferView,
HalDevice,
HalElementType,
MappedMemory,
MemoryType,
)
# Names exported by `from <module> import *`.
__all__ = ["asdevicearray", "DeviceArray"]
# Maps a NumPy API function to the override that handles it for DeviceArray.
_DEVICE_HANDLED_FUNCTIONS = {}


def _device_implements(np_function):
    """Returns a decorator registering its target as the `np_function` override.

    The decorated callable is stored in `_DEVICE_HANDLED_FUNCTIONS` under
    `np_function` and handed back unchanged.
    """

    def register(impl):
        _DEVICE_HANDLED_FUNCTIONS.update({np_function: impl})
        return impl

    return register
class DeviceArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """An IREE device array.

    Device arrays can be in one of two states:
      1. Host accessible: The array will be backed by host accessible memory
         and can have the usual things done with it that one expects to be
         able to do with an ndarray.
      2. Device resident: The array is just a handle to a device resident
         Buffer (and BufferView wrapper). Metadata about the array are
         accessible (shape and dtype) but anything that touches the data
         cannot be accessed in this state.

    How a device array comes into existence controls how it can transition
    between these states:
      * A user can create a DeviceArray explicitly with a device allocator.
        Such an array will not be implicitly convertible to host accessible,
        although accessors exist to do so.
      * When created by the platform with a synchronization policy, then
        implicit transfer back to the host will trigger appropriate waits and
        be performed automatically (this is the common case for function
        return values if not otherwise configured, as an example).
    """

    def __init__(self,
                 device: HalDevice,
                 buffer_view: HalBufferView,
                 implicit_host_transfer: bool = False,
                 override_dtype=None):
        """Wraps an existing buffer view.

        Args:
          device: The device associated with the array. It is stored on the
            instance; no method of this class reads it.
          buffer_view: The buffer view supplying shape, element type and the
            host mapping.
          implicit_host_transfer: Whether `__array__` may map the data to the
            host without an explicit `to_host()` call.
          override_dtype: Optional dtype to report from `dtype` and to convert
            the host array to when it differs from the raw device dtype.
        """
        self._device = device
        self._buffer_view = buffer_view
        self._implicit_host_transfer = implicit_host_transfer
        self._override_dtype = override_dtype

        # If the array is host accessible, these will be non-None.
        self._mapped_memory: Optional[MappedMemory] = None
        self._host_array: Optional[np.ndarray] = None

    def __array__(self, dtype=None):
        """Array protocol hook; counts as an *implicit* host transfer.

        Raises:
          ValueError: If the array is not yet host accessible and implicit
            transfers are disabled.
        """
        self._transfer_to_host(True)
        if dtype is None:
            return self._host_array
        else:
            return self._host_array.__array__(dtype)  # pytype: disable=attribute-error

    def __array_function__(self, func, types, args, kwargs):
        """Dispatches NumPy API functions called with a DeviceArray argument."""
        if func in _DEVICE_HANDLED_FUNCTIONS:
            return _DEVICE_HANDLED_FUNCTIONS[func](*args, **kwargs)

        # Anything else forces a transfer to host and then delegates to the
        # host array.
        host_array = self.to_host()
        return host_array.__array_function__(func, types, args, kwargs)  # pytype: disable=attribute-error

    def __repr__(self):
        return f"<IREE DeviceArray: shape={self.shape}, dtype={self.dtype}>"

    @property
    def is_host_accessible(self):
        """Whether this array is currently host accessible."""
        return self._host_array is not None

    def to_host(self) -> np.ndarray:
        """Explicitly transfers to the host (if needed) and returns the array.

        The host array is cached, so repeated calls return the same object.
        """
        self._transfer_to_host(False)
        return self._host_array

    def _transfer_to_host(self, implicit):
        """Populates the cached host mapping if it is not already present.

        Args:
          implicit: True when the transfer was not explicitly requested by the
            caller; such transfers are refused unless the array was created
            with `implicit_host_transfer=True`.

        Raises:
          ValueError: On a refused implicit transfer.
        """
        if self._host_array is not None:
            return
        if implicit and not self._implicit_host_transfer:
            raise ValueError(
                "DeviceArray cannot be implicitly transferred to the host: "
                "if necessary, do an explicit transfer via .to_host()")
        self._mapped_memory, self._host_array = self._map_to_host()

    def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
        """Maps the buffer view and returns `(mapped_memory, host_array)`.

        Does not touch the cached host state on `self`.
        """
        # TODO: When synchronization is enabled, need to block here.
        raw_dtype = self._get_raw_dtype()
        mapped_memory = self._buffer_view.map()
        host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
        # Detect if we need to force an explicit conversion. This happens when
        # we were requested to pretend that the array is in a specific dtype,
        # even if that is not representable on the device. You guessed it:
        # this is to support bools.
        if self._override_dtype is not None and self._override_dtype != raw_dtype:
            host_array = host_array.astype(self._override_dtype)
        return mapped_memory, host_array

    def _get_raw_dtype(self):
        """Returns the dtype mapped from the buffer view's element type."""
        return HalElementType.map_to_dtype(self._buffer_view.element_type)

    @property
    def dtype(self):
        """The override dtype if one was given, else the raw device dtype."""
        # Test against None (as _map_to_host does) instead of truthiness: a
        # dtype-like object may evaluate as falsy (NumPy dtypes define
        # __len__, and their truth value has varied across NumPy versions),
        # which would silently drop the override.
        if self._override_dtype is not None:
            return self._override_dtype
        return self._get_raw_dtype()

    @property
    def shape(self):
        """The shape reported by the buffer view; needs no host transfer."""
        # Same value the registered np.shape override returns, read directly
        # rather than via NumPy's function dispatch.
        return self._buffer_view.shape

    def astype(self, dtype, casting="unsafe", copy=True):
        """Converts to `dtype` on the host.

        Returns `self` when the dtype already matches and `copy` is False;
        otherwise returns a host ndarray produced by `ndarray.astype`.
        """
        if self.dtype == dtype and not copy:
            return self
        host_ary = self.to_host()
        return host_ary.astype(dtype, casting=casting, copy=copy)

    def reshape(self, *args):
        """Reshapes on the host; returns a host ndarray, not a DeviceArray."""
        # TODO(scotttodd): add a native impl with a new buffer_view of the same data
        # TODO(scotttodd): return DeviceArray instead of host ndarray?
        host_ary = self.to_host()
        return host_ary.reshape(*args)

    def __iter__(self):
        """Iterates over the host array (explicit host transfer)."""
        host_ary = self.to_host()
        return host_ary.__iter__()

    def __getitem__(self, index):
        """Indexes the host array (explicit host transfer)."""
        host_ary = self.to_host()
        return host_ary.__getitem__(index)

    def __reduce__(self):
        # Since this is used for making deep copies and pickling, we map
        # separately from any interactive state. We just reduce to the actual
        # host ndarray, which supports the necessary serialization protocols.
        _, host_array = self._map_to_host()
        return _restore_reduced_array, (host_array,)
def _restore_reduced_array(ary):
    """Reconstruction callable returned by `DeviceArray.__reduce__`.

    The reduced state is already the host ndarray, so it is handed back as is.
    """
    return ary
# Function implementations with custom behavior.
@_device_implements(np.shape)
def _(arr: DeviceArray):
    """`np.shape` override: reads the shape from the buffer view."""
    buffer_view = arr._buffer_view
    return buffer_view.shape
@_device_implements(np.reshape)
def _(arr: DeviceArray, *args):
    """`np.reshape` override: forwards the positional args to `arr.reshape`."""
    reshaped = arr.reshape(*args)
    return reshaped
def asdevicearray(device: HalDevice,
                  a,
                  dtype=None,
                  *,
                  implicit_host_transfer: bool = False,
                  memory_type=MemoryType.DEVICE_LOCAL |
                  MemoryType.DEVICE_VISIBLE,
                  allowed_usage=BufferUsage.ALL,
                  element_type: Optional[HalElementType] = None) -> DeviceArray:
    """Helper to create a DeviceArray from an arbitrary array like.

    This is similar in purpose and usage to np.asarray, except that it takes
    a device as the first argument. This may not be the best mechanism for
    getting a DeviceArray, depending on your use case, but it is reliable
    and simple. This function may make a defensive copy or cause implicit
    transfers to satisfy the request. If this is important to you, then a lower
    level API is likely more appropriate.

    Note that additional flags `memory_type`, `allowed_usage` and `element_type`
    are only hints if creating a new DeviceArray. If `a` is already a
    DeviceArray and no `dtype` is requested, `a` is returned unchanged and they
    are ignored.

    Args:
      device: Device whose allocator receives the copy of the data.
      a: Array-like to convert, or an existing DeviceArray.
      dtype: Optional dtype passed to `np.asarray`.
      implicit_host_transfer: Forwarded to the new DeviceArray.
      memory_type: Forwarded to `device.allocator.allocate_buffer_copy`.
      allowed_usage: Forwarded to `device.allocator.allocate_buffer_copy`.
      element_type: Element type for the allocation. When None, it is derived
        from the dtype of the converted array.

    Returns:
      `a` itself if it is a DeviceArray and `dtype` is None, otherwise a new
      DeviceArray whose reported dtype is that of the converted ndarray.

    Raises:
      ValueError: If no element type was given and the dtype has no entry in
        the dtype to element type table.
    """
    if isinstance(a, DeviceArray):
        if dtype is None:
            return a
        # Need to do a conversion, which we currently do not support on the
        # device, so transfer back to the host.
        # `logging.warn` is a deprecated alias that newer Python versions no
        # longer provide; `logging.warning` is the supported spelling.
        logging.warning(
            "Implicit dtype conversion of a DeviceArray forces a host transfer")

    # First get an ndarray.
    a = np.asarray(a, dtype=dtype)

    # Only derive the element type when the caller did not provide one;
    # otherwise the documented `element_type` argument would be dead.
    if element_type is None:
        element_type = map_dtype_to_element_type(a.dtype)
        if element_type is None:
            raise ValueError(
                f"Could not map dtype {a.dtype} to IREE element type")

    buffer_view = device.allocator.allocate_buffer_copy(
        memory_type=memory_type,
        allowed_usage=allowed_usage,
        buffer=a,
        element_type=element_type)
    return DeviceArray(device,
                       buffer_view,
                       implicit_host_transfer=implicit_host_transfer,
                       override_dtype=a.dtype)
# NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should
# be queried via isinstance checks. Such hierarchy checks should be added as
# a fallback; for now this is a linear list for quick access to the most
# common dtypes. There may also be a better way to do this.
# (numpy scalar type, HalElementType) pairs. map_dtype_to_element_type scans
# them in order and returns the element type of the first equal entry.
_DTYPE_TO_HAL_ELEMENT_TYPE = (
    # Floating point.
    (np.float32, HalElementType.FLOAT_32),
    (np.float64, HalElementType.FLOAT_64),
    (np.float16, HalElementType.FLOAT_16),
    # Signed integers.
    (np.int32, HalElementType.SINT_32),
    (np.int64, HalElementType.SINT_64),
    (np.int16, HalElementType.SINT_16),
    (np.int8, HalElementType.SINT_8),
    # Unsigned integers.
    (np.uint32, HalElementType.UINT_32),
    (np.uint64, HalElementType.UINT_64),
    (np.uint16, HalElementType.UINT_16),
    (np.uint8, HalElementType.UINT_8),
    # Booleans.
    (np.bool_, HalElementType.BOOL_8),
)
def map_dtype_to_element_type(dtype) -> Optional[HalElementType]:
    """Looks up the HAL element type registered for `dtype`.

    Entries of `_DTYPE_TO_HAL_ELEMENT_TYPE` are compared to `dtype` with `==`
    in table order; the first equal entry wins.

    Returns:
      The matching element type, or None if no entry compares equal.
    """
    matches = (hal_type
               for known_dtype, hal_type in _DTYPE_TO_HAL_ELEMENT_TYPE
               if known_dtype == dtype)
    return next(matches, None)