// blob: e92adf158308104cee70505a6fd43a8a56e156a2 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/rt/function_abi.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "bindings/python/pyiree/common/status_utils.h"
#include "bindings/python/pyiree/rt/hal.h"
#include "bindings/python/pyiree/rt/vm.h"
#include "iree/base/api.h"
#include "iree/base/signature_mangle.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/list.h"
#include "iree/vm/ref.h"
namespace iree {
namespace python {
namespace {
class SipLinearizeInputsVisitor {
public:
SipLinearizeInputsVisitor(SipSignatureParser& parser, py::tuple& py_args,
py::dict& py_kwargs,
absl::InlinedVector<py::handle, 4>& linear_py_args)
: parser_(parser),
py_args_(py_args),
py_kwargs_(py_kwargs),
linear_py_args_(linear_py_args) {}
void IntegerKey(SipSignatureParser& p, int k) {
auto current = tos();
try {
auto current_seq = current.cast<py::sequence>();
stack_.push_back(current_seq[k]);
} catch (std::exception& e) {
auto message =
absl::StrCat("Expected sequence index ", k, " not found in ",
py::repr(current).cast<std::string>());
SetError(std::move(message));
}
}
void StringKey(SipSignatureParser& p, absl::string_view k) {
auto current = tos();
py::str py_k(k.data(), k.size());
try {
auto current_dict = tos().cast<py::dict>();
stack_.push_back(current_dict[py_k]);
} catch (std::exception& e) {
auto message = absl::StrCat("Expected key '", k, "' not found in ",
py::repr(current).cast<std::string>());
SetError(std::move(message));
}
}
void OpenStruct(SipSignatureParser& p,
SipSignatureParser::StructType struct_type) {
// Only structs directly off of the root are opened without a key.
if (!stack_.empty()) return;
py::handle tos;
switch (struct_type) {
case SipSignatureParser::StructType::kDict:
tos = py_kwargs_;
break;
case SipSignatureParser::StructType::kSequence:
tos = py_args_;
break;
}
stack_.push_back(tos);
}
void CloseStruct(SipSignatureParser& p) {
if (!stack_.empty()) {
stack_.pop_back();
}
}
void MapToRawSignatureIndex(SipSignatureParser& p, int index) {
if (static_cast<int>(linear_py_args_.size()) <= index) {
linear_py_args_.resize(index + 1);
}
linear_py_args_[index] = tos();
if (!stack_.empty()) {
stack_.pop_back();
}
}
private:
py::handle tos() {
if (stack_.empty()) {
SetError("Mismatched structures during unpacking arguments");
return py::handle();
}
return stack_.back();
}
void SetError(std::string message) { parser_.SetError(message); }
SipSignatureParser& parser_;
py::tuple& py_args_;
py::dict& py_kwargs_;
absl::InlinedVector<py::handle, 4>& linear_py_args_;
// The struct stack. Top is the last.
// When the stack is empty, opening a struct will push the first entry:
// py_args_ if a sequence and py_kwargs_ if a dict. Otherwise, new stack
// levels are opened upon key resolution.
// Either CloseStruct or MapToRawSignatureIndex terminate each level of
// the stack.
absl::InlinedVector<py::handle, 4> stack_;
};
class SipStructureResultsVisitor {
public:
SipStructureResultsVisitor(
SipSignatureParser& parser,
absl::InlinedVector<py::object, 4>& linear_py_results)
: parser_(parser), linear_py_results_(linear_py_results) {}
void IntegerKey(SipSignatureParser& p, int k) {
pending_assign_key_ = py::int_(k);
}
void StringKey(SipSignatureParser& p, absl::string_view k) {
pending_assign_key_ = py::str(k.data(), k.size());
}
void OpenStruct(SipSignatureParser& p,
SipSignatureParser::StructType struct_type) {
py::object struct_obj;
bool is_dict;
switch (struct_type) {
case SipSignatureParser::StructType::kDict:
struct_obj = py::dict();
is_dict = true;
break;
case SipSignatureParser::StructType::kSequence:
struct_obj = py::list();
is_dict = false;
break;
default:
SetError("Illegal structure type");
return;
}
// Must assign before pushing so as to assign to the prior level.
AssignCurrent(struct_obj);
stack_.push_back(std::make_pair(std::move(struct_obj), is_dict));
}
void CloseStruct(SipSignatureParser& p) {
if (!stack_.empty()) stack_.pop_back();
pending_assign_key_ = py::none(); // Just in case (for error path).
}
void MapToRawSignatureIndex(SipSignatureParser& p, int index) {
if (index < 0 || index >= static_cast<int>(linear_py_results_.size())) {
SetError("Raw result index out of range in reflection metadata");
return;
}
py::object current_obj = linear_py_results_[index];
AssignCurrent(std::move(current_obj));
}
py::object ConsumeResult() {
if (result)
return std::move(result);
else
return py::none();
}
private:
void AssignCurrent(py::object value) {
if (stack_.empty()) {
if (result) {
SetError("Attempt to unpack multiple roots");
return;
}
result = std::move(value);
} else {
if (!pending_assign_key_ || pending_assign_key_.is_none()) {
SetError("Attempt to assign out of order");
return;
}
try {
auto stack_entry = stack_.back();
bool is_dict = stack_entry.second;
if (is_dict) {
stack_entry.first.cast<py::dict>()[pending_assign_key_] = value;
} else {
int index = pending_assign_key_.cast<int>();
py::list l = stack_entry.first.cast<py::list>();
// Technically, signature keys can come out of order, which is sad.
// none-fill the list as needed to fill the gap.
// TODO: Further guarantees can be enforced at conversion time,
// simplifying this.
bool extended = false;
int list_size = l.size();
if (list_size <= index) {
while (l.size() <= index) {
l.append(py::none());
extended = true;
}
l.append(std::move(value));
} else {
l[index] = std::move(value);
}
pending_assign_key_ = py::none();
}
} catch (std::exception& e) {
SetError("Corrupt sip signature: Signature/data type mismatch");
pending_assign_key_ = py::none();
}
}
}
void SetError(std::string message) { parser_.SetError(message); }
SipSignatureParser& parser_;
absl::InlinedVector<py::object, 4>& linear_py_results_;
py::object result;
// Parse state.
// A new level of the stack is opened for each container. Each entry is a
// pair of (container, is_dict). If not is_dict, it is assumed to be a list.
absl::InlinedVector<std::pair<py::object, bool>, 4> stack_;
// If a pending key has been set for a following assignment, it is noted
// here. The nested assignments, the call sequence is:
// 1. OpenStruct
// For-each key:
// a. IntegerKey or StringKey
// b. MapToRawSignatureIndex
// 2. CloseStruct
// For single-result situations, it is legal to just have a single, top-level
// call to MapToRawSignatureIndex, which causes the entire result to be
// equal to the current object.
py::object pending_assign_key_;
};
// Python-facing factory that builds a FunctionAbi from an explicit list of
// (key, value) reflection attributes. Each lookup scans the list linearly,
// which is acceptable because this path is primarily meant for testing.
// Normally a FunctionAbi is created directly from a function, with the
// attribute introspection done inside C++.
std::unique_ptr<FunctionAbi> PyCreateAbi(
    HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
    std::vector<std::pair<std::string, std::string>> attrs) {
  // Returns the value of the first attribute whose key equals `wanted`, or
  // nullopt when no attribute matches.
  auto find_attr =
      [&attrs](absl::string_view wanted) -> absl::optional<absl::string_view> {
    for (size_t i = 0; i < attrs.size(); ++i) {
      const std::pair<std::string, std::string>& entry = attrs[i];
      if (entry.first == wanted) {
        return absl::string_view(entry.second);
      }
    }
    return absl::nullopt;
  };
  return FunctionAbi::Create(device, std::move(host_type_factory), find_attr);
}
// Creates the result list for invoking `self`, sized via
// VmVariantList::Create(raw_result_arity()). When `static_alloc` is true, the
// result buffers are also pre-allocated from the statically known result
// descriptions (shape inference), using `f_args` as the call's inputs.
VmVariantList PyAllocateResults(FunctionAbi* self, VmVariantList& f_args,
                                bool static_alloc) {
  VmVariantList f_results = VmVariantList::Create(self->raw_result_arity());
  if (!static_alloc) {
    return f_results;
  }
  // Static dispatch: attempt full allocation and shape inference up front.
  const auto& result_descs = self->raw_config().results;
  self->AllocateResults(absl::MakeConstSpan(result_descs), f_args, f_results);
  return f_results;
}
// RAII guard that calls PyBuffer_Release on a Py_buffer when it goes out of
// scope. The guard only holds a reference, so the Py_buffer must outlive it.
// Copying is disabled: a copy would release the same buffer twice.
class PyBufferReleaser {
 public:
  explicit PyBufferReleaser(Py_buffer& b) : b_(b) {}
  ~PyBufferReleaser() { PyBuffer_Release(&b_); }

  PyBufferReleaser(const PyBufferReleaser&) = delete;
  PyBufferReleaser& operator=(const PyBufferReleaser&) = delete;

 private:
  Py_buffer& b_;
};
// Builds (but does not throw) a ValueError describing a buffer argument that
// is incompatible with `desc`. `message` is used as a prefix. The str() of
// `obj` and the expected description are appended to it. Callers are expected
// to `throw` the returned exception.
pybind11::error_already_set RaiseBufferMismatchError(
    std::string message, py::handle obj,
    const RawSignatureParser::Description& desc) {
  const std::string arg_str = static_cast<std::string>(py::str(obj));
  absl::StrAppend(&message, "For argument = ", arg_str, " (expected ");
  desc.ToString(message);
  message.push_back(')');
  return RaiseValueError(message.c_str());
}
// Verifies that a python buffer view is compatible with the signature
// description `desc`. It checks the rank, the item size, the buffer format
// (advisory only where the formats correspond) and every static dim. Dims that
// are dynamic (negative) in `desc` are appended, in order, to `dynamic_dims`.
// Any incompatibility throws a ValueError built by RaiseBufferMismatchError.
void MapBufferAttrs(Py_buffer& py_view,
                    const RawSignatureParser::Description& desc,
                    absl::InlinedVector<int, 2>& dynamic_dims) {
  // Verify that rank matches.
  const size_t expected_rank = desc.dims.size();
  if (py_view.ndim < 0 || static_cast<size_t>(py_view.ndim) != expected_rank) {
    throw RaiseBufferMismatchError(
        absl::StrCat("Mismatched buffer rank (received: ", py_view.ndim,
                     ", expected: ", desc.dims.size(), "): "),
        py::handle(py_view.obj), desc);
  }

  // Verify that the item size matches.
  size_t f_item_size =
      AbiConstants::kScalarTypeSize[static_cast<int>(desc.buffer.scalar_type)];
  if (py_view.itemsize < 0 ||
      static_cast<size_t>(py_view.itemsize) != f_item_size) {
    throw RaiseBufferMismatchError(
        absl::StrCat("Mismatched buffer item size (received: ",
                     py_view.itemsize, ", expected: ", f_item_size, "): "),
        py::handle(py_view.obj), desc);
  }

  // Note: The python buffer format does not map precisely to IREE's type
  // system, so the below is only advisory for where they do match. Otherwise,
  // it is basically a bitcast.
  const char* f_expected_format =
      kScalarTypePyFormat[static_cast<int>(desc.buffer.scalar_type)];
  // If the format is booleans, we should treat it as bytes.
  const char* f_found_format = py_view.format;
  if (strcmp(f_found_format, "?") == 0) {
    f_found_format = "b";
  }
  if (f_expected_format != nullptr &&
      strcmp(f_expected_format, f_found_format) != 0) {
    throw RaiseBufferMismatchError(
        absl::StrCat("Mismatched buffer format (received: ", py_view.format,
                     ", expected: ", f_expected_format, "): "),
        py::handle(py_view.obj), desc);
  }

  // Verify shape, populating dynamic_dims while looping. The rank check above
  // guarantees py_view.ndim == expected_rank.
  for (size_t i = 0; i < expected_rank; ++i) {
    auto py_dim = py_view.shape[i];
    auto f_dim = desc.dims[i];
    if (f_dim < 0) {
      // Dynamic.
      dynamic_dims.push_back(static_cast<int>(py_dim));
    } else if (py_dim != f_dim) {
      // Mismatch.
      throw RaiseBufferMismatchError(
          absl::StrCat("Mismatched buffer dim (received: ", py_dim,
                       ", expected: ", f_dim, "): "),
          py::handle(py_view.obj), desc);
    }
  }
}
// Converts the python scalar `py_arg` according to `desc` and appends it to
// `f_args` as an i32 VM value.
//
// Unsigned types accept the full uint32 range. Values above INT32_MAX keep
// their 32-bit pattern, matching UnpackScalar, which reinterprets the i32 as
// uint32 for unsigned types. Negative inputs are still accepted for unsigned
// types, as before. Throws NotImplementedError for unsupported scalar types and
// ValueError for unsigned values that do not fit in 32 bits.
void PackScalar(const RawSignatureParser::Description& desc, py::handle py_arg,
                VmVariantList& f_args) {
  iree_vm_value value;
  value.type = IREE_VM_VALUE_TYPE_I32;
  switch (desc.scalar.type) {
    case AbiConstants::ScalarType::kUint8:
    case AbiConstants::ScalarType::kUint16:
    case AbiConstants::ScalarType::kUint32: {
      // A direct cast<int32_t>() rejects valid uint32 inputs >= 2^31, so
      // convert through a wider type and keep the low 32 bits.
      const int64_t wide = py_arg.cast<int64_t>();
      if (wide < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
          wide > static_cast<int64_t>(std::numeric_limits<uint32_t>::max())) {
        throw RaiseValueError("Scalar argument out of 32-bit range");
      }
      value.i32 = static_cast<int32_t>(static_cast<uint32_t>(wide));
      break;
    }
    case AbiConstants::ScalarType::kSint8:
    case AbiConstants::ScalarType::kSint16:
    case AbiConstants::ScalarType::kSint32: {
      value.i32 = py_arg.cast<int32_t>();
      break;
    }
    default:
      throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type");
  }
  CheckApiStatus(iree_vm_list_push_value(f_args.raw_ptr(), &value),
                 "Could not pack scalar argument");
}
// Converts a VM result variant holding a 32-bit value into a python int
// according to `desc`. Unsigned types reinterpret the stored i32 bits as
// uint32. Throws ValueError if the variant is not a value (e.g. a ref), and
// NotImplementedError for unsupported scalar types.
py::object UnpackScalar(const RawSignatureParser::Description& desc,
                        const iree_vm_variant_t& f_result) {
  // Reading .i32 from a ref variant would silently yield garbage.
  if (!iree_vm_variant_is_value(f_result)) {
    throw RaiseValueError("Expected scalar result to be a VM value");
  }
  switch (desc.scalar.type) {
    case AbiConstants::ScalarType::kUint8:
    case AbiConstants::ScalarType::kUint16:
    case AbiConstants::ScalarType::kUint32: {
      return py::int_(static_cast<uint32_t>(f_result.i32));
    }
    case AbiConstants::ScalarType::kSint8:
    case AbiConstants::ScalarType::kSint16:
    case AbiConstants::ScalarType::kSint32: {
      return py::int_(f_result.i32);
    }
    default:
      throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type");
  }
}
} // namespace
//------------------------------------------------------------------------------
// FunctionAbi
//------------------------------------------------------------------------------
std::string FunctionAbi::DebugString() const {
RawSignatureParser p;
auto s = p.FunctionSignatureToString(raw_config_.signature);
if (!s) {
return "<FunctionAbi NO_DEBUG_INFO>";
}
auto result = absl::StrCat("<FunctionAbi ", *s);
if (sip_signature_) {
absl::StrAppend(&result, " SIP:'", *sip_signature_, "'");
}
absl::StrAppend(&result, ">");
return result;
}
// Builds a FunctionAbi from reflection attributes provided by `lookup`.
// Required attributes:
//   "f"  - raw function signature (parsed into input/result descriptions)
//   "fv" - raw ABI version; must be "1"
// Optional: when "abi" is "sip" and a "sip" attribute exists, the structured
// (sip) signature is recorded too.
// Throws ValueError on missing metadata, an unsupported version, or a raw
// signature parse error.
std::unique_ptr<FunctionAbi> FunctionAbi::Create(
    HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
    AttributeLookup lookup) {
  auto abi =
      absl::make_unique<FunctionAbi>(device, std::move(host_type_factory));

  // Raw ABI attributes. A missing signature is reported before a bad version.
  const auto raw_version = lookup("fv");
  const auto raw_fsig_str = lookup("f");
  if (!raw_fsig_str) {
    throw RaiseValueError("No raw abi reflection metadata for function");
  }
  const bool version_supported = raw_version && *raw_version == "1";
  if (!version_supported) {
    throw RaiseValueError("Unsupported raw function ABI version");
  }

  // Record the signature text and collect per-input/per-result descriptions.
  abi->raw_config().signature = std::string(*raw_fsig_str);
  RawSignatureParser raw_parser;
  auto add_input = [&abi](const RawSignatureParser::Description& d) {
    abi->raw_config().inputs.push_back(d);
  };
  auto add_result = [&abi](const RawSignatureParser::Description& d) {
    abi->raw_config().results.push_back(d);
  };
  raw_parser.VisitInputs(*raw_fsig_str, add_input);
  raw_parser.VisitResults(*raw_fsig_str, add_result);
  if (auto parse_error = raw_parser.GetError()) {
    auto message = absl::StrCat("Error parsing raw ABI signature: ",
                                *parse_error, " ('", *raw_fsig_str, "')");
    throw RaiseValueError(message.c_str());
  }

  // Optional structured (sip) signature.
  const auto reported_abi = lookup("abi");
  const auto sip_signature = lookup("sip");
  const bool is_sip = reported_abi && *reported_abi == "sip";
  if (is_sip && sip_signature) {
    abi->sip_signature_ = std::string(*sip_signature);
  }
  return abi;
}
// Packs python call arguments into `args` according to `descs`.
// With a sip signature, positional and keyword arguments are first linearized
// by SipLinearizeInputsVisitor. Without one, the positional arguments map 1:1
// to raw inputs and `py_kwargs` is ignored. Throws ValueError on linearization
// errors. Per-argument conversion errors come from RawPack.
void FunctionAbi::Pack(py::tuple& py_args, py::dict& py_kwargs,
                       absl::Span<const Description> descs, VmVariantList& args,
                       bool writable) {
  absl::InlinedVector<py::handle, 4> linear_py_args;
  if (sip_signature_) {
    // Note that we use explicit errors here and do not let exceptions escape
    // since parsing may be happening in a library not compiled for exceptions.
    SipSignatureParser parser;
    SipLinearizeInputsVisitor visitor(parser, py_args, py_kwargs,
                                      linear_py_args);
    parser.VisitInputs(visitor, *sip_signature_);
    if (auto error = parser.GetError()) {
      auto message =
          absl::StrCat("Could not unpack python arguments: ", *error);
      throw RaiseValueError(message.c_str());
    }
  } else {
    // There is no python -> linear translation.
    const size_t arity = py_args.size();
    linear_py_args.resize(arity);
    for (size_t idx = 0; idx < arity; ++idx) {
      linear_py_args[idx] = py_args[idx];
    }
  }
  RawPack(descs, absl::MakeSpan(linear_py_args), args, writable);
}
// Converts VM results into a python value.
// With a sip signature, the structured result is rebuilt by
// SipStructureResultsVisitor. Otherwise this follows python's
// multiple-return conventions: no results gives None, one result gives that
// object, and several results give a tuple. Throws ValueError if the
// structured rebuild fails.
py::object FunctionAbi::Unpack(absl::Span<const Description> descs,
                               VmVariantList& f_results) {
  absl::InlinedVector<py::object, 4> linear_py_results(f_results.size());
  RawUnpack(descs, f_results, absl::MakeSpan(linear_py_results));

  if (sip_signature_) {
    // Note that we use explicit errors here and do not let exceptions escape
    // since parsing may be happening in a library not compiled for exceptions.
    SipSignatureParser parser;
    SipStructureResultsVisitor visitor(parser, linear_py_results);
    parser.VisitResults(visitor, *sip_signature_);
    if (auto error = parser.GetError()) {
      auto message = absl::StrCat(
          "Could not create python structured results: ", *error);
      throw RaiseValueError(message.c_str());
    }
    assert(!PyErr_Occurred());
    return visitor.ConsumeResult();
  }

  const size_t count = linear_py_results.size();
  switch (count) {
    case 0:
      return py::none();
    case 1:
      return std::move(linear_py_results[0]);
    default:
      break;
  }
  py::tuple py_result_tuple(count);
  for (size_t idx = 0; idx < count; ++idx) {
    py_result_tuple[idx] = std::move(linear_py_results[idx]);
  }
  return std::move(py_result_tuple);  // Without move, warns of copy.
}
// Packs already-linearized python arguments into `f_args`, one per
// description. Buffers go through PackBuffer and scalars through PackScalar.
// Ref objects and other types raise NotImplementedError. A length mismatch
// between `descs` and `py_args` raises ValueError.
void FunctionAbi::RawPack(absl::Span<const Description> descs,
                          absl::Span<py::handle> py_args, VmVariantList& f_args,
                          bool writable) {
  if (descs.size() != py_args.size()) {
    throw RaiseValueError("Mismatched RawPack() input arity");
  }
  for (size_t idx = 0; idx < descs.size(); ++idx) {
    const Description& desc = descs[idx];
    py::handle py_arg = py_args[idx];
    if (desc.type == RawSignatureParser::Type::kBuffer) {
      PackBuffer(desc, py_arg, f_args, writable);
    } else if (desc.type == RawSignatureParser::Type::kScalar) {
      PackScalar(desc, py_arg, f_args);
    } else if (desc.type == RawSignatureParser::Type::kRefObject) {
      throw RaisePyError(PyExc_NotImplementedError,
                         "Ref objects not yet supported");
    } else {
      throw RaisePyError(PyExc_NotImplementedError,
                         "Unsupported argument type");
    }
  }
}
// Converts each VM result in `f_results` into a python object in `py_results`
// according to `descs`. Buffer views become ndarrays through the host type
// factory, and scalars are converted by UnpackScalar. Ref objects and other
// types raise NotImplementedError. Arity mismatches and bad list entries
// raise ValueError.
void FunctionAbi::RawUnpack(absl::Span<const Description> descs,
                            VmVariantList& f_results,
                            absl::Span<py::object> py_results) {
  if (descs.size() != f_results.size() || descs.size() != py_results.size()) {
    throw RaiseValueError("Mismatched RawUnpack() result arity");
  }
  for (size_t i = 0, e = descs.size(); i < e; ++i) {
    const Description& desc = descs[i];
    iree_vm_variant_t f_result = iree_vm_variant_empty();
    iree_status_t get_status =
        iree_vm_list_get_variant(f_results.raw_ptr(), i, &f_result);
    if (!iree_status_is_ok(get_status)) {
      iree_status_ignore(get_status);
      throw RaiseValueError("Could not get result from list");
    }
    switch (desc.type) {
      case RawSignatureParser::Type::kBuffer: {
        iree_hal_buffer_view_t* buffer_view =
            iree_hal_buffer_view_deref(&f_result.ref);
        if (!buffer_view) {
          throw RaiseValueError(
              "Could not deref result buffer view (wrong type?)");
        }
        iree_hal_buffer* raw_buffer = iree_hal_buffer_view_buffer(buffer_view);
        if (!raw_buffer) {
          throw RaiseValueError("Could not deref result buffer (wrong type?)");
        }
        HalBuffer buffer = HalBuffer::RetainAndCreate(raw_buffer);

        // Extract dims from the buffer view. The first query uses 6 inline
        // slots. On OUT_OF_RANGE, this code relies on `rank` reporting the
        // needed size, grows the vector, and queries again.
        size_t rank = 0;
        absl::InlinedVector<int32_t, 6> dims(6);
        iree_status_t status = iree_hal_buffer_view_shape(
            buffer_view, dims.size(), dims.data(), &rank);
        if (iree_status_is_out_of_range(status)) {
          // Free the failed status before overwriting it so it is not leaked.
          iree_status_ignore(status);
          dims.resize(rank);
          // Pass size() (not capacity()) so only constructed elements are
          // written.
          status = iree_hal_buffer_view_shape(buffer_view, dims.size(),
                                              dims.data(), &rank);
        }
        CheckApiStatus(status, "Error extracting shape");
        dims.resize(rank);

        // Deal with int32_t != int (but require 32bits). Happens on some
        // embedded platforms.
        static_assert(sizeof(dims[0]) == sizeof(int),
                      "expected int to be 32 bits");
        py_results[i] = host_type_factory_->CreateImmediateNdarray(
            desc.buffer.scalar_type,
            absl::MakeConstSpan(reinterpret_cast<int*>(dims.data()),
                                dims.size()),
            std::move(buffer));
        break;
      }
      case RawSignatureParser::Type::kRefObject:
        throw RaisePyError(PyExc_NotImplementedError,
                           "Ref objects not yet supported");
        break;
      case RawSignatureParser::Type::kScalar:
        py_results[i] = UnpackScalar(desc, f_result);
        break;
      default:
        throw RaisePyError(PyExc_NotImplementedError,
                           "Unsupported result type");
    }
  }
}
// Pre-allocates result storage in `f_results` for the result descriptions
// `descs`, using only statically known shapes. `f_args` is only used to check
// the input arity.
//
// Per result:
//   - Buffer with fully static dims: allocates a device-local, host-visible
//     HAL buffer and appends a buffer view for it.
//   - Buffer with any dynamic (negative) dim: appends exactly one null ref so
//     that the function allocates the result itself.
//   - Scalar: appends nothing.
//     NOTE(review): if f_results is consumed positionally, this misaligns
//     later entries. Confirm against the invocation code.
//   - Ref objects and other types raise NotImplementedError.
void FunctionAbi::AllocateResults(absl::Span<const Description> descs,
                                  VmVariantList& f_args,
                                  VmVariantList& f_results) {
  if (f_args.size() != raw_config().inputs.size()) {
    throw RaiseValueError("Mismatched AllocateResults() input arity");
  }
  for (size_t i = 0, e = descs.size(); i < e; ++i) {
    const Description& desc = descs[i];
    switch (desc.type) {
      case RawSignatureParser::Type::kBuffer: {
        bool has_dynamic_dim = false;
        for (auto dim : desc.dims) {
          if (dim < 0) {
            has_dynamic_dim = true;
            break;
          }
        }
        if (has_dynamic_dim) {
          // If there is a dynamic dim, fallback to completely func allocated
          // result. This is the worst case because it will force a
          // pipeline stall. Exactly one placeholder is appended and nothing
          // is allocated (a negative dim must never reach the size math).
          // TODO(laurenzo): Invoke shape resolution function if available
          // to allocate full result.
          f_results.AppendNullRef();
          break;
        }

        // Static cases are easy.
        iree_device_size_t alloc_size =
            AbiConstants::kScalarTypeSize[static_cast<int>(
                desc.buffer.scalar_type)];
        absl::InlinedVector<int32_t, 5> dims;
        dims.reserve(desc.dims.size());
        for (auto dim : desc.dims) {
          alloc_size *= dim;
          dims.push_back(dim);
        }

        iree_hal_buffer_t* raw_buffer;
        CheckApiStatus(iree_hal_allocator_allocate_buffer(
                           device_.allocator(),
                           static_cast<iree_hal_memory_type_t>(
                               IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
                               IREE_HAL_MEMORY_TYPE_HOST_VISIBLE),
                           IREE_HAL_BUFFER_USAGE_ALL, alloc_size, &raw_buffer),
                       "Error allocating host visible buffer");

        // Buffer descriptions carry their element type in buffer.scalar_type.
        auto element_type = static_cast<iree_hal_element_type_t>(
            kScalarTypeToHalElementType[static_cast<unsigned>(
                desc.buffer.scalar_type)]);
        iree_hal_buffer_view_t* buffer_view = nullptr;
        iree_status_t status = iree_hal_buffer_view_create(
            raw_buffer, dims.data(), dims.size(), element_type,
            iree_allocator_system(), &buffer_view);
        // Drop our reference whether or not creation succeeded. On success the
        // view is expected to hold its own reference. On failure this avoids
        // leaking the buffer.
        iree_hal_buffer_release(raw_buffer);
        CheckApiStatus(status, "Error allocating buffer_view");

        iree_vm_ref_t buffer_view_ref =
            iree_hal_buffer_view_move_ref(buffer_view);
        CheckApiStatus(
            iree_vm_list_push_ref_move(f_results.raw_ptr(), &buffer_view_ref),
            "Error moving buffer");
        break;
      }
      case RawSignatureParser::Type::kRefObject:
        throw RaisePyError(PyExc_NotImplementedError,
                           "Ref objects not yet supported");
        break;
      case RawSignatureParser::Type::kScalar:
        break;
      default:
        throw RaisePyError(PyExc_NotImplementedError,
                           "Unsupported allocation argument type");
    }
  }
}
// Converts the buffer-protocol object `py_arg` into a HAL buffer view that
// matches `desc`, then appends it to `f_args`.
//
// Only C-contiguous buffers are requested (PyBUF_ND). `writable` additionally
// requests PyBUF_WRITABLE. The data is copied into a newly allocated
// host-local, device-visible HAL buffer, and the python object is not
// retained. Throws on incompatible buffers (via MapBufferAttrs) and on HAL
// errors.
void FunctionAbi::PackBuffer(const RawSignatureParser::Description& desc,
                             py::handle py_arg, VmVariantList& f_args,
                             bool writable) {
  // Request a view of the buffer (use the raw python C API to avoid some
  // allocation and copying at the pybind level).
  Py_buffer py_view;
  // Note that only C-Contiguous ND-arrays are presently supported, so
  // only request that via PyBUF_ND. Long term, we should consult an
  // "oracle" in the runtime to determine the precise required format and
  // set flags accordingly (and fallback/copy on failure).
  int flags = PyBUF_FORMAT | PyBUF_ND;
  if (writable) {
    flags |= PyBUF_WRITABLE;
  }
  // Acquire the backing buffer and setup RAII release.
  if (PyObject_GetBuffer(py_arg.ptr(), &py_view, flags) != 0) {
    // The GetBuffer call is required to set an appropriate error.
    throw py::error_already_set();
  }
  PyBufferReleaser py_view_releaser(py_view);

  // Whether the py object needs to be retained with the argument.
  // Should be set to true if directly mapping, false if copied.
  bool depends_on_pyobject = false;

  // Verify compatibility (throws on mismatch).
  absl::InlinedVector<int, 2> dynamic_dims;
  MapBufferAttrs(py_view, desc, dynamic_dims);

  // Allocate a HalBuffer.
  // This is hard-coded to C-contiguous right now.
  // TODO(laurenzo): Expand to other layouts as needed.
  // TODO(laurenzo): Wrap and retain original buffer (depends_on_pyobject=true).
  iree_hal_buffer_t* raw_buffer;
  CheckApiStatus(iree_hal_allocator_allocate_buffer(
                     device_.allocator(),
                     static_cast<iree_hal_memory_type_t>(
                         IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
                         IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
                     IREE_HAL_BUFFER_USAGE_ALL, py_view.len, &raw_buffer),
                 "Failed to allocate device visible buffer");
  iree_status_t status =
      iree_hal_buffer_write_data(raw_buffer, 0, py_view.buf, py_view.len);
  if (!iree_status_is_ok(status)) {
    // Release the freshly allocated buffer before reporting the error so it
    // is not leaked.
    iree_hal_buffer_release(raw_buffer);
    CheckApiStatus(status, "Error writing to input buffer");
  }

  // Only capture the reference to the exporting object (incrementing it)
  // once guaranteed successful.
  if (depends_on_pyobject) {
    // Note for future implementation: there needs to be a place to stash
    // references to be kept alive which back a buffer. This is likely an
    // additional bag of refs returned from this function, which can then
    // be attached to an invocation.
    iree_hal_buffer_release(raw_buffer);
    throw RaisePyError(PyExc_NotImplementedError,
                       "Dependent buffer arguments not implemented");
  }

  // Create the buffer_view (note that the python shape is Py_ssize_t). Buffer
  // descriptions carry their element type in buffer.scalar_type.
  auto element_type = static_cast<iree_hal_element_type_t>(
      kScalarTypeToHalElementType[static_cast<unsigned>(
          desc.buffer.scalar_type)]);
  absl::InlinedVector<int, 5> dims(py_view.ndim);
  std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin());
  iree_hal_buffer_view_t* buffer_view = nullptr;
  status = iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(),
                                       element_type, iree_allocator_system(),
                                       &buffer_view);
  // Drop our reference whether or not creation succeeded. On success the view
  // is expected to hold its own reference. On failure this avoids leaking the
  // buffer.
  iree_hal_buffer_release(raw_buffer);
  CheckApiStatus(status, "Error allocating buffer_view");

  iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
  CheckApiStatus(iree_vm_list_push_ref_move(f_args.raw_ptr(), &buffer_view_ref),
                 "Error moving buffer view");
}
// Renders each element of `vm_list` as a string. VM values become
// "i32=<value>", and HAL buffer views are formatted with
// iree_hal_buffer_view_format. Any other element type throws ValueError
// instead of being silently skipped.
std::vector<std::string> SerializeVmVariantList(VmVariantList& vm_list) {
  const size_t size = vm_list.size();
  std::vector<std::string> results;
  results.reserve(size);
  for (iree_host_size_t i = 0; i < size; ++i) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    CheckApiStatus(iree_vm_list_get_variant(vm_list.raw_ptr(), i, &variant),
                   "Failed to get vm variant from list");

    if (iree_vm_variant_is_value(variant)) {
      results.push_back("i32=" + std::to_string(variant.i32));
    } else if (iree_vm_variant_is_ref(variant) &&
               iree_hal_buffer_view_isa(&variant.ref)) {
      auto buffer_view = iree_hal_buffer_view_deref(&variant.ref);
      std::string result_str(4096, '\0');
      const iree_host_size_t max_element_count =
          std::numeric_limits<iree_host_size_t>::max();
      while (true) {
        iree_host_size_t actual_length = 0;
        iree_status_t format_status = iree_hal_buffer_view_format(
            buffer_view, max_element_count, result_str.size() + 1,
            &result_str[0], &actual_length);
        result_str.resize(actual_length);
        if (!iree_status_is_out_of_range(format_status)) {
          CheckApiStatus(
              format_status,
              "Failed to create a string representation of the inputs");
          break;
        }
        // Output did not fit: `result_str` now has the reported length. Free
        // this status (overwriting it would leak) and format again.
        iree_status_ignore(format_status);
      }
      results.push_back(std::move(result_str));
    } else {
      // Previously constructed but never thrown, which silently dropped the
      // element.
      throw RaiseValueError(
          "Expected vm_list's elements to be scalars or buffer views.");
    }
  }
  return results;
}
// Registers the python `FunctionAbi` class on module `m`.
void SetupFunctionAbiBindings(pybind11::module m) {
  using PyFunctionAbiClass =
      py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>;
  PyFunctionAbiClass abi_class(m, "FunctionAbi");

  // Construction and introspection.
  abi_class.def(py::init(&PyCreateAbi));
  abi_class.def("__repr__", &FunctionAbi::DebugString);
  abi_class.def_property_readonly("raw_input_arity",
                                  &FunctionAbi::raw_input_arity);
  abi_class.def_property_readonly("raw_result_arity",
                                  &FunctionAbi::raw_result_arity);

  // Packs python *args/**kwargs into a new VM argument list (read-only view
  // of the input buffers).
  auto pack_inputs = [](FunctionAbi* self, py::args py_args,
                        py::kwargs py_kwargs) {
    VmVariantList f_args = VmVariantList::Create(py_args.size());
    const auto& input_descs = self->raw_config().inputs;
    self->Pack(py_args, py_kwargs, absl::MakeConstSpan(input_descs), f_args,
               /*writable=*/false);
    return f_args;
  };
  abi_class.def("pack_inputs", pack_inputs);

  auto serialize_vm_list = [](FunctionAbi* self, VmVariantList& vm_list) {
    return SerializeVmVariantList(vm_list);
  };
  abi_class.def("serialize_vm_list", serialize_vm_list);

  // NOTE(review): the python keyword is "f_results", but PyAllocateResults
  // receives the *input* argument list there. The name is kept for backward
  // compatibility with existing callers.
  abi_class.def("allocate_results", &PyAllocateResults, py::arg("f_results"),
                py::arg("static_alloc") = true);

  auto unpack_results = [](FunctionAbi* self, VmVariantList& f_results) {
    const auto& result_descs = self->raw_config().results;
    return self->Unpack(absl::MakeConstSpan(result_descs), f_results);
  };
  abi_class.def("unpack_results", unpack_results);
}
} // namespace python
} // namespace iree