blob: b3d3546ce9023d44da357be3f6bb1b9483b5ef4b [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/function_abi.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "bindings/python/pyiree/rt.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/base/api.h"
#include "iree/base/signature_mangle.h"
#include "iree/hal/api.h"
namespace iree {
namespace python {
namespace {
// Python friendly entry-point for creating an instance from a list
// of attributes. This is not particularly efficient and is primarily
// for testing. Typically, this will be created directly from a function
// and the attribute introspection will happen internal to C++.
std::unique_ptr<FunctionAbi> PyCreateAbi(
std::shared_ptr<HostTypeFactory> host_type_factory,
std::vector<std::pair<std::string, std::string>> attrs) {
auto lookup =
[&attrs](absl::string_view key) -> absl::optional<absl::string_view> {
for (const auto& kv : attrs) {
if (kv.first == key) return kv.second;
}
return absl::nullopt;
};
return FunctionAbi::Create(std::move(host_type_factory), lookup);
}
std::unique_ptr<FunctionArgVariantList> PyRawPack(
FunctionAbi* self, RtContext& context,
absl::Span<const FunctionAbi::Description> descs, py::sequence py_args,
bool writable) {
if (py_args.size() != descs.size()) {
throw RaiseValueError("Mismatched pack arity");
}
auto f_list = absl::make_unique<FunctionArgVariantList>();
f_list->contents().resize(py_args.size());
absl::InlinedVector<py::handle, 8> local_py_args(py_args.begin(),
py_args.end());
self->RawPack(context, descs, absl::MakeSpan(local_py_args),
absl::MakeSpan(f_list->contents()), writable);
return f_list;
}
std::unique_ptr<FunctionArgVariantList> PyAllocateResults(
FunctionAbi* self, RtContext& context, FunctionArgVariantList* f_args) {
auto f_results = absl::make_unique<FunctionArgVariantList>();
f_results->contents().resize(self->raw_result_arity());
self->AllocateResults(context,
absl::MakeConstSpan(self->raw_config().results),
absl::MakeConstSpan(f_args->contents()),
absl::MakeSpan(f_results->contents()));
return f_results;
}
py::object PyRawUnpackResults(FunctionAbi* self, RtContext& context,
FunctionArgVariantList* f_args) {
absl::InlinedVector<py::object, 4> py_results;
py_results.resize(f_args->contents().size());
self->RawUnpack(context, absl::MakeConstSpan(self->raw_config().results),
absl::MakeSpan(f_args->contents()),
absl::MakeSpan(py_results));
py::tuple py_result_tuple(py_results.size());
for (size_t i = 0, e = py_results.size(); i < e; ++i) {
py_result_tuple[i] = std::move(py_results[i]);
}
return py_result_tuple;
}
// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
// out of scope.
class PyBufferReleaser {
public:
PyBufferReleaser(Py_buffer& b) : b_(b) {}
~PyBufferReleaser() { PyBuffer_Release(&b_); }
private:
Py_buffer& b_;
};
pybind11::error_already_set RaiseBufferMismatchError(
std::string message, py::handle obj,
const RawSignatureParser::Description& desc) {
message.append("For argument = ");
auto arg_py_str = py::str(obj);
auto arg_str = static_cast<std::string>(arg_py_str);
message.append(arg_str);
message.append(" (expected ");
desc.ToString(message);
message.append(")");
return RaiseValueError(message.c_str());
}
// Verifies and maps the py buffer shape and layout to the bound argument.
// Returns false if not compatible.
void MapBufferAttrs(Py_buffer& py_view,
const RawSignatureParser::Description& desc,
BoundHalBufferFunctionArg& bound_arg) {
// Verify that rank matches.
if (py_view.ndim != desc.dims.size()) {
throw RaiseBufferMismatchError(
absl::StrCat("Mismatched buffer rank (received: ", py_view.ndim,
", expected: ", desc.dims.size(), "): "),
py::handle(py_view.obj), desc);
}
// Verify that the item size matches.
size_t f_item_size =
AbiConstants::kScalarTypeSize[static_cast<int>(desc.buffer.scalar_type)];
if (f_item_size != py_view.itemsize) {
throw RaiseBufferMismatchError(
absl::StrCat("Mismatched buffer item size (received: ",
py_view.itemsize, ", expected: ", f_item_size, "): "),
py::handle(py_view.obj), desc);
}
// Note: The python buffer format does not map precisely to IREE's type
// system, so the below is only advisory for where they do match. Otherwise,
// it is basically a bitcast.
const char* f_expected_format =
kScalarTypePyFormat[static_cast<int>(desc.buffer.scalar_type)];
if (f_expected_format != nullptr &&
strcmp(f_expected_format, py_view.format) != 0) {
throw RaiseBufferMismatchError(
absl::StrCat("Mismatched buffer format (received: ", py_view.format,
", expected: ", f_expected_format, "): "),
py::handle(py_view.obj), desc);
}
// Verify shape, populating dynamic_dims while looping.
for (size_t i = 0; i < py_view.ndim; ++i) {
auto py_dim = py_view.shape[i];
auto f_dim = desc.dims[i];
if (f_dim < 0) {
// Dynamic.
bound_arg.dynamic_dims.push_back(py_dim);
} else if (py_dim != f_dim) {
// Mismatch.
throw RaiseBufferMismatchError(
absl::StrCat("Mismatched buffer dim (received: ", py_dim,
", expected: ", f_dim, "): "),
py::handle(py_view.obj), desc);
}
}
}
void PackBuffer(RtContext& context, const RawSignatureParser::Description& desc,
py::handle py_arg, FunctionArgVariant& f_arg, bool writable) {
// Request a view of the buffer (use the raw python C API to avoid some
// allocation and copying at the pybind level).
Py_buffer py_view;
// Note that only C-Contiguous ND-arrays are presently supported, so
// only request that via PyBUF_ND. Long term, we should consult an
// "oracle" in the runtime to determine the precise required format and
// set flags accordingly (and fallback/copy on failure).
int flags = PyBUF_FORMAT | PyBUF_ND;
if (writable) {
flags |= PyBUF_WRITABLE;
}
// Acquire the backing buffer and setup RAII release.
if (PyObject_GetBuffer(py_arg.ptr(), &py_view, flags) != 0) {
// The GetBuffer call is required to set an appropriate error.
throw py::error_already_set();
}
PyBufferReleaser py_view_releaser(py_view);
auto& bound_arg = f_arg.emplace<BoundHalBufferFunctionArg>();
// Whether the py object needs to be retained with the argument.
// Should be set to true if directly mapping, false if copied.
bool depends_on_pyobject = false;
// Verify compatibility.
MapBufferAttrs(py_view, desc, bound_arg);
// Allocate a HalBuffer.
// This is hard-coded to C-contiguous right now.
// TODO(laurenzo): Expand to other layouts as needed.
// TODO(laurenzo): Wrap and retain original buffer (depends_on_pyobject=true).
bound_arg.buffer =
context.AllocateDeviceVisible(py_view.len, IREE_HAL_BUFFER_USAGE_ALL);
CheckApiStatus(iree_hal_buffer_write_data(bound_arg.buffer.raw_ptr(), 0,
py_view.buf, py_view.len),
"Error writing to input buffer");
// Only capture the reference to the exporting object (incrementing it)
// once guaranteed successful.
if (depends_on_pyobject) {
bound_arg.dependent_pyobject = py::object(py::handle(py_view.obj), false);
}
}
std::string FunctionArgVariantListRepr(FunctionArgVariantList* self) {
std::string s;
struct Visitor {
Visitor(std::string& s) : s(s) {}
void operator()(std::nullptr_t) { s.append("None"); }
void operator()(const BoundHalBufferFunctionArg& arg) {
absl::StrAppend(&s, "HalBuffer(", arg.buffer.byte_length());
if (!arg.dynamic_dims.empty()) {
absl::StrAppend(&s, ", dynamic_dims=[");
for (size_t i = 0, e = arg.dynamic_dims.size(); i < e; ++i) {
if (i > 0) s.append(",");
absl::StrAppend(&s, arg.dynamic_dims[i]);
}
absl::StrAppend(&s, "]");
}
if (arg.dependent_pyobject) {
absl::StrAppend(&s, ", wraps=");
std::string wraps_s = py::str(arg.dependent_pyobject);
s.append(wraps_s);
}
absl::StrAppend(&s, ")");
}
std::string& s;
};
Visitor v(s);
absl::StrAppend(&s, "<FunctionArgVariantList(", self->contents().size(),
"): [");
for (size_t i = 0, e = self->contents().size(); i < e; ++i) {
if (i > 0) s.append(", ");
const auto& arg = self->contents()[i];
absl::visit(v, arg);
}
absl::StrAppend(&s, "]>");
return s;
}
} // namespace
//------------------------------------------------------------------------------
// FunctionAbi
//------------------------------------------------------------------------------
std::string FunctionAbi::DebugString() const {
RawSignatureParser p;
auto s = p.FunctionSignatureToString(raw_config_.signature);
if (!s) {
return "<FunctionAbi NO_DEBUG_INFO>";
}
return absl::StrCat("<FunctionAbi ", *s, ">");
}
std::unique_ptr<FunctionAbi> FunctionAbi::Create(
std::shared_ptr<HostTypeFactory> host_type_factory,
AttributeLookup lookup) {
auto abi = absl::make_unique<FunctionAbi>(std::move(host_type_factory));
// Fetch key attributes for the raw ABI.
auto raw_version = lookup("fv");
auto raw_fsig_str = lookup("f");
// Validation.
if (!raw_fsig_str) {
throw RaiseValueError("No raw abi reflection metadata for function");
}
if (!raw_version || *raw_version != "1") {
throw RaiseValueError("Unsupported raw function ABI version");
}
// Parse signature.
abi->raw_config().signature = std::string(*raw_fsig_str);
RawSignatureParser raw_parser;
raw_parser.VisitInputs(*raw_fsig_str,
[&abi](const RawSignatureParser::Description& d) {
abi->raw_config().inputs.push_back(d);
});
raw_parser.VisitResults(*raw_fsig_str,
[&abi](const RawSignatureParser::Description& d) {
abi->raw_config().results.push_back(d);
});
if (raw_parser.GetError()) {
auto message = absl::StrCat(
"Error parsing raw ABI signature: ", *raw_parser.GetError(), " (",
*raw_fsig_str, ")");
throw RaiseValueError(message.c_str());
}
// TODO(laurenzo): Detect sip ABI and add a translation layer.
return abi;
}
void FunctionAbi::RawPack(RtContext& context,
absl::Span<const Description> descs,
absl::Span<py::handle> py_args,
absl::Span<FunctionArgVariant> f_args,
bool writable) {
if (descs.size() != py_args.size() || descs.size() != f_args.size()) {
throw RaiseValueError("Mismatched RawPack() input arity");
}
for (size_t i = 0, e = descs.size(); i < e; ++i) {
const Description& desc = descs[i];
switch (desc.type) {
case RawSignatureParser::Type::kBuffer:
PackBuffer(context, desc, py_args[i], f_args[i], writable);
break;
case RawSignatureParser::Type::kRefObject:
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
default:
throw RaisePyError(PyExc_NotImplementedError,
"Unsupported argument type");
}
}
}
void FunctionAbi::RawUnpack(RtContext& context,
absl::Span<const Description> descs,
absl::Span<FunctionArgVariant> f_results,
absl::Span<py::object> py_results) {
if (descs.size() != f_results.size() || descs.size() != py_results.size()) {
throw RaiseValueError("Mismatched RawUnpack() result arity");
}
for (size_t i = 0, e = descs.size(); i < e; ++i) {
const Description& desc = descs[i];
switch (desc.type) {
case RawSignatureParser::Type::kBuffer: {
auto& bound_arg = absl::get<BoundHalBufferFunctionArg>(f_results[i]);
absl::InlinedVector<int, 7> backing_full_dims;
absl::Span<const int> dims;
if (bound_arg.dynamic_dims.empty()) {
// No dynamic dims. Just use the desc dims.
dims = absl::MakeSpan(desc.dims);
} else {
// Splice the dims back together.
size_t j = 0;
for (auto dim : desc.dims) {
if (dim >= 0) {
backing_full_dims.push_back(dim);
} else if (j < bound_arg.dynamic_dims.size()) {
backing_full_dims.push_back(bound_arg.dynamic_dims[j++]);
} else {
throw RaiseValueError("Mismatched static and dynamic dims");
}
}
dims = absl::MakeConstSpan(backing_full_dims);
}
py_results[i] = host_type_factory_->CreateImmediateNdarray(
desc.buffer.scalar_type, absl::MakeSpan(desc.dims),
std::move(bound_arg.buffer));
// Already moved the buffer out. Clear the rest.
f_results[i] = nullptr;
break;
}
case RawSignatureParser::Type::kRefObject:
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
default:
throw RaisePyError(PyExc_NotImplementedError,
"Unsupported argument type");
}
}
}
void FunctionAbi::AllocateResults(RtContext& context,
absl::Span<const Description> descs,
absl::Span<const FunctionArgVariant> f_args,
absl::Span<FunctionArgVariant> f_results) {
if (descs.size() != f_results.size()) {
throw RaiseValueError("Mismatched AllocateResults() result arity");
}
if (f_args.size() != raw_config().inputs.size()) {
throw RaiseValueError("Mismatched AllocatResults() input arity");
}
for (size_t i = 0, e = descs.size(); i < e; ++i) {
const Description& desc = descs[i];
iree_device_size_t alloc_size =
AbiConstants::kScalarTypeSize[static_cast<int>(
desc.buffer.scalar_type)];
switch (desc.type) {
case RawSignatureParser::Type::kBuffer: {
BoundHalBufferFunctionArg bound_result;
for (auto dim : desc.dims) {
if (dim < 0) {
bound_result.dynamic_dims.push_back(dim);
// TODO(laurenzo): Should accumulate dynamic dim slices into an
// invocation of a result allocator function and then use it to
// invoke.
// TODO(laurenzo): Should fallback to completely unknown result
// thunk.
throw RaisePyError(
PyExc_NotImplementedError,
"AllocateResult() does not yet support dynamic dims");
}
alloc_size *= dim;
}
// Static cases are easy.
// TODO(laurenzo): This should probably be AllocateDeviceLocal
// with a memory type of HOST_VISIBLE, which is not yet plumbed
// through.
bound_result.buffer = context.AllocateDeviceVisible(
alloc_size, IREE_HAL_BUFFER_USAGE_ALL);
f_results[i] = std::move(bound_result);
break;
}
case RawSignatureParser::Type::kRefObject:
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
default:
throw RaisePyError(PyExc_NotImplementedError,
"Unsupported argument type");
}
}
}
void SetupFunctionAbiBindings(pybind11::module m) {
m.def("create", &PyCreateAbi);
py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>(m, "FunctionAbi")
.def("__repr__", &FunctionAbi::DebugString)
.def_property_readonly("raw_input_arity", &FunctionAbi::raw_input_arity)
.def_property_readonly("raw_result_arity", &FunctionAbi::raw_result_arity)
.def("raw_pack_inputs",
[](FunctionAbi* self, RtContext& context, py::sequence py_args) {
return PyRawPack(self, context,
absl::MakeConstSpan(self->raw_config().inputs),
py_args, false /* writable */);
})
.def("allocate_results", &PyAllocateResults)
.def("raw_unpack_results", &PyRawUnpackResults);
// Note that FunctionArgVariantList is meant to be opaque to normal users,
// but we provide some printing and other introspection to aid testing.
py::class_<FunctionArgVariantList, std::unique_ptr<FunctionArgVariantList>>(
m, "FunctionArgVariantList")
.def_property_readonly(
"size",
[](FunctionArgVariantList* self) { return self->contents().size(); })
.def("__repr__", &FunctionArgVariantListRepr);
}
} // namespace python
} // namespace iree