| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "./numpy_interop.h" |
| |
| #include "./binding.h" |
| |
| #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION |
| #include "numpy/arrayobject.h" |
| |
| namespace iree::python::numpy { |
| |
| namespace { |
| |
| int internal_import_array() { |
| import_array1(-1); |
| return 0; |
| } |
| |
| } // namespace |
| |
| void InitializeNumPyInterop() { |
| if (internal_import_array() < 0) { |
| throw py::import_error("numpy.core.multiarray failed to import"); |
| } |
| } |
| |
| int ConvertHalElementTypeToNumPyTypeNum(iree_hal_element_type_t t) { |
| switch (t) { |
| case IREE_HAL_ELEMENT_TYPE_BOOL_8: |
| return NPY_BOOL; |
| case IREE_HAL_ELEMENT_TYPE_INT_8: |
| case IREE_HAL_ELEMENT_TYPE_SINT_8: |
| return NPY_INT8; |
| case IREE_HAL_ELEMENT_TYPE_UINT_8: |
| return NPY_UINT8; |
| case IREE_HAL_ELEMENT_TYPE_INT_16: |
| case IREE_HAL_ELEMENT_TYPE_SINT_16: |
| return NPY_INT16; |
| case IREE_HAL_ELEMENT_TYPE_UINT_16: |
| return NPY_UINT16; |
| case IREE_HAL_ELEMENT_TYPE_INT_32: |
| case IREE_HAL_ELEMENT_TYPE_SINT_32: |
| return NPY_INT32; |
| case IREE_HAL_ELEMENT_TYPE_UINT_32: |
| return NPY_UINT32; |
| case IREE_HAL_ELEMENT_TYPE_INT_64: |
| case IREE_HAL_ELEMENT_TYPE_SINT_64: |
| return NPY_INT64; |
| case IREE_HAL_ELEMENT_TYPE_UINT_64: |
| return NPY_UINT64; |
| case IREE_HAL_ELEMENT_TYPE_FLOAT_16: |
| return NPY_FLOAT16; |
| case IREE_HAL_ELEMENT_TYPE_FLOAT_32: |
| return NPY_FLOAT32; |
| case IREE_HAL_ELEMENT_TYPE_FLOAT_64: |
| return NPY_FLOAT64; |
| case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: |
| return NPY_COMPLEX64; |
| case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: |
| return NPY_COMPLEX128; |
| default: |
| throw py::value_error("Unsupported VM Buffer -> numpy dtype mapping"); |
| } |
| } |
| |
| py::object DescrNewFromType(int typenum) { |
| PyArray_Descr *dtype = PyArray_DescrNewFromType(typenum); |
| if (!dtype) { |
| throw py::python_error(); |
| } |
| return py::steal((PyObject *)dtype); |
| } |
| |
| int TypenumFromDescr(py::handle dtype) { |
| if (!PyArray_DescrCheck(dtype.ptr())) { |
| throw py::cast_error(); |
| } |
| PyArray_Descr *descr = (PyArray_Descr *)dtype.ptr(); |
| return descr->type_num; |
| } |
| |
| py::object SimpleNewFromData(int nd, intptr_t const *dims, int typenum, |
| void *data, py::handle base_object) { |
| PyObject *array_c = PyArray_SimpleNewFromData(nd, dims, typenum, data); |
| if (!array_c) throw py::python_error(); |
| py::object array = py::steal(array_c); |
| if (base_object) { |
| if (PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(array.ptr()), |
| base_object.ptr())) { |
| throw py::python_error(); |
| } |
| base_object.inc_ref(); |
| } |
| return array; |
| } |
| |
| } // namespace iree::python::numpy |