|  | // Copyright 2023 The IREE Authors | 
|  | // | 
|  | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | #include "./numpy_interop.h" | 
|  |  | 
|  | #include "./binding.h" | 
|  |  | 
|  | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | 
|  | #include "numpy/arrayobject.h" | 
|  |  | 
|  | namespace iree::python::numpy { | 
|  |  | 
namespace {

// Loads NumPy's C-API function table for this translation unit.
//
// NOTE: `import_array1(ret)` is a NumPy macro that expands to code which may
// `return ret;` from *this* function when the NumPy C API cannot be loaded;
// that is why it is given -1 and why the function has a second exit path.
// When loading succeeds, control falls through to `return 0`.
// Returns 0 on success and -1 on failure (InitializeNumPyInterop checks `< 0`).
int internal_import_array() {
  import_array1(-1);
  return 0;
}

}  // namespace
|  |  | 
|  | void InitializeNumPyInterop() { | 
|  | if (internal_import_array() < 0) { | 
|  | throw py::import_error("numpy.core.multiarray failed to import"); | 
|  | } | 
|  | } | 
|  |  | 
|  | int ConvertHalElementTypeToNumPyTypeNum(iree_hal_element_type_t t) { | 
|  | switch (t) { | 
|  | case IREE_HAL_ELEMENT_TYPE_BOOL_8: | 
|  | return NPY_BOOL; | 
|  | case IREE_HAL_ELEMENT_TYPE_INT_8: | 
|  | case IREE_HAL_ELEMENT_TYPE_SINT_8: | 
|  | return NPY_INT8; | 
|  | case IREE_HAL_ELEMENT_TYPE_UINT_8: | 
|  | return NPY_UINT8; | 
|  | case IREE_HAL_ELEMENT_TYPE_INT_16: | 
|  | case IREE_HAL_ELEMENT_TYPE_SINT_16: | 
|  | return NPY_INT16; | 
|  | case IREE_HAL_ELEMENT_TYPE_UINT_16: | 
|  | return NPY_UINT16; | 
|  | case IREE_HAL_ELEMENT_TYPE_INT_32: | 
|  | case IREE_HAL_ELEMENT_TYPE_SINT_32: | 
|  | return NPY_INT32; | 
|  | case IREE_HAL_ELEMENT_TYPE_UINT_32: | 
|  | return NPY_UINT32; | 
|  | case IREE_HAL_ELEMENT_TYPE_INT_64: | 
|  | case IREE_HAL_ELEMENT_TYPE_SINT_64: | 
|  | return NPY_INT64; | 
|  | case IREE_HAL_ELEMENT_TYPE_UINT_64: | 
|  | return NPY_UINT64; | 
|  | case IREE_HAL_ELEMENT_TYPE_FLOAT_16: | 
|  | return NPY_FLOAT16; | 
|  | case IREE_HAL_ELEMENT_TYPE_FLOAT_32: | 
|  | return NPY_FLOAT32; | 
|  | case IREE_HAL_ELEMENT_TYPE_FLOAT_64: | 
|  | return NPY_FLOAT64; | 
|  | case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: | 
|  | return NPY_COMPLEX64; | 
|  | case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: | 
|  | return NPY_COMPLEX128; | 
|  | default: | 
|  | throw py::value_error("Unsupported VM Buffer -> numpy dtype mapping"); | 
|  | } | 
|  | } | 
|  |  | 
|  | py::object DescrNewFromType(int typenum) { | 
|  | PyArray_Descr *dtype = PyArray_DescrNewFromType(typenum); | 
|  | if (!dtype) { | 
|  | throw py::python_error(); | 
|  | } | 
|  | return py::steal((PyObject *)dtype); | 
|  | } | 
|  |  | 
|  | int TypenumFromDescr(py::handle dtype) { | 
|  | if (!PyArray_DescrCheck(dtype.ptr())) { | 
|  | throw py::cast_error(); | 
|  | } | 
|  | PyArray_Descr *descr = (PyArray_Descr *)dtype.ptr(); | 
|  | return descr->type_num; | 
|  | } | 
|  |  | 
|  | py::object SimpleNewFromData(int nd, intptr_t const *dims, int typenum, | 
|  | void *data, py::handle base_object) { | 
|  | PyObject *array_c = PyArray_SimpleNewFromData(nd, dims, typenum, data); | 
|  | if (!array_c) throw py::python_error(); | 
|  | py::object array = py::steal(array_c); | 
|  | if (base_object) { | 
|  | if (PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(array.ptr()), | 
|  | base_object.ptr())) { | 
|  | throw py::python_error(); | 
|  | } | 
|  | base_object.inc_ref(); | 
|  | } | 
|  | return array; | 
|  | } | 
|  |  | 
|  | }  // namespace iree::python::numpy |