blob: b66505ae101dc2ffacbe30c241691bb8c9aef7bc [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "./numpy_interop.h"
#include "./binding.h"
namespace iree::python::numpy {
// Returns the numpy dtype name (a string accepted by numpy.dtype()) that
// corresponds to HAL element type `t`. Signless integer types (INT_*) are
// reported as the signed numpy dtype of the same width. Throws a Python
// ValueError for any element type not listed here.
static const char* ConvertHalElementTypeToDtypeName(iree_hal_element_type_t t) {
  const char* dtype_name = nullptr;
  switch (t) {
    case IREE_HAL_ELEMENT_TYPE_BOOL_8:
      dtype_name = "bool";
      break;
    // Signless and signed integers share the signed numpy dtype.
    case IREE_HAL_ELEMENT_TYPE_INT_8:
    case IREE_HAL_ELEMENT_TYPE_SINT_8:
      dtype_name = "int8";
      break;
    case IREE_HAL_ELEMENT_TYPE_INT_16:
    case IREE_HAL_ELEMENT_TYPE_SINT_16:
      dtype_name = "int16";
      break;
    case IREE_HAL_ELEMENT_TYPE_INT_32:
    case IREE_HAL_ELEMENT_TYPE_SINT_32:
      dtype_name = "int32";
      break;
    case IREE_HAL_ELEMENT_TYPE_INT_64:
    case IREE_HAL_ELEMENT_TYPE_SINT_64:
      dtype_name = "int64";
      break;
    // Unsigned integers.
    case IREE_HAL_ELEMENT_TYPE_UINT_8:
      dtype_name = "uint8";
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_16:
      dtype_name = "uint16";
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_32:
      dtype_name = "uint32";
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_64:
      dtype_name = "uint64";
      break;
    // Floating point and complex.
    case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
      dtype_name = "float16";
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
      dtype_name = "float32";
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
      dtype_name = "float64";
      break;
    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
      dtype_name = "complex64";
      break;
    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
      dtype_name = "complex128";
      break;
    default:
      // Left as nullptr; reported below.
      break;
  }
  if (dtype_name == nullptr) {
    throw py::value_error("Unsupported VM Buffer -> numpy dtype mapping");
  }
  return dtype_name;
}
// Returns a numpy.dtype object equivalent to HAL element type `t`.
// Unsupported element types raise a Python ValueError, thrown from
// ConvertHalElementTypeToDtypeName before numpy is touched.
py::object DescrNewFromType(iree_hal_element_type_t t) {
  const char* dtype_name = ConvertHalElementTypeToDtypeName(t);
  // import_() is a sys.modules dict lookup once numpy is loaded, so this is
  // cheap; caching the module would only matter on a hot path.
  py::object numpy_module = py::module_::import_("numpy");
  py::object dtype_ctor = numpy_module.attr("dtype");
  return dtype_ctor(dtype_name);
}
// Creates a writable numpy array with `nd` dimensions of extents
// `dims[0..nd)` and dtype `dtype_descr`. The array aliases, and does not
// copy, the memory at `data`.
//
// The array is built without the numpy C API, using numpy's documented
// `__array_interface__` protocol:
//   - A small holder object (types.SimpleNamespace) exposes the data address,
//     shape and typestr.
//   - numpy.asarray() wraps that memory in place and records the holder as
//     the array's base.
//   - The holder keeps a strong reference to `base_object` (when non-null).
// As a result, `base_object` stays alive for as long as the array, or any
// view derived from it, exists. This mirrors the original
// PyArray_SetBaseObject semantics, and the array is writable like one made
// by PyArray_SimpleNewFromData.
//
// Why not a memoryview: CPython's PyMemoryView_FromBuffer() treats
// Py_buffer::obj as a borrowed reference and never releases it. Routing
// `base_object` through PyBuffer_FillInfo therefore leaked one reference per
// call rather than tying lifetimes together.
//
// `dtype_descr` is expected to be a numpy.dtype (its `.str` is used as the
// interface typestr). Throws py::python_error if any Python call fails.
py::object SimpleNewFromData(int nd, intptr_t const* dims,
                             py::handle dtype_descr, void* data,
                             py::handle base_object) {
  // Shape tuple. It is owned by a py::object from creation, so it is released
  // if filling it fails part way (tuple dealloc tolerates unset slots).
  py::object shape = py::steal(PyTuple_New(nd));
  if (!shape.ptr()) throw py::python_error();
  for (int i = 0; i < nd; ++i) {
    PyObject* extent = PyLong_FromSsize_t(dims[i]);
    if (!extent) throw py::python_error();
    PyTuple_SET_ITEM(shape.ptr(), i, extent);  // Steals `extent`.
  }

  // Interface `data` entry: (address, read_only). read_only=False keeps the
  // resulting array writable.
  py::object address = py::steal(PyLong_FromVoidPtr(data));
  if (!address.ptr()) throw py::python_error();
  py::object data_spec = py::steal(PyTuple_Pack(2, address.ptr(), Py_False));
  if (!data_spec.ptr()) throw py::python_error();

  py::object typestr = dtype_descr.attr("str");
  py::object array_iface = py::module_::import_("builtins").attr("dict")(
      py::arg("version") = 3, py::arg("shape") = shape,
      py::arg("typestr") = typestr, py::arg("data") = data_spec);

  // The holder becomes array.base, and its `owner` attribute is what keeps
  // base_object alive.
  py::handle owner = base_object.ptr() ? base_object : py::handle(Py_None);
  py::object holder = py::module_::import_("types").attr("SimpleNamespace")(
      py::arg("__array_interface__") = array_iface, py::arg("owner") = owner);

  // import_() is a sys.modules dict lookup, not a full import. Could cache
  // the module reference if this becomes a hot path.
  return py::module_::import_("numpy").attr("asarray")(holder);
}
} // namespace iree::python::numpy