blob: b0b9a5951f6927432d5f8b3b8509c576bec2a835 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "./invoke.h"
#include "./hal.h"
#include "./vm.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
namespace iree {
namespace python {
namespace {
// State handed to every PackCallback while arguments are being packed.
// It gives the callbacks access to the HAL device (and, through it, an
// allocator) that buffers should be created on.
class InvokeContext {
 public:
  // Copy-initializes the by-value `device_` member from `device`. The
  // constructor is `explicit` so that a HalDevice never converts to an
  // InvokeContext implicitly.
  explicit InvokeContext(HalDevice &device) : device_(device) {}

  // The device this context was created for.
  HalDevice &device() { return device_; }

  // Returns a HalAllocator wrapper around the device's allocator.
  HalAllocator allocator() {
    // TODO: Unfortunate that we inc ref here but that is how our object model
    // is set up.
    return HalAllocator::BorrowFromRawPtr(device().allocator());
  }

 private:
  HalDevice device_;
};
// Callable that converts a single Python value (the py::handle) and appends
// the converted value to the given VM list. The InvokeContext is passed so
// the callback can reach per-invocation resources such as the allocator.
// A default-constructed (empty) PackCallback is used elsewhere in this file
// to signal "no packer available for this value".
using PackCallback =
std::function<void(InvokeContext &, iree_vm_list_t *, py::handle)>;
class InvokeStatics {
public:
// Drops the references that AddPackCallback took on the Python type objects
// used as keys of py_type_to_pack_callbacks_.
~InvokeStatics() {
  // Iterate by const reference: iterating by value would copy every entry,
  // including its std::function, just to read the key.
  for (const auto &entry : py_type_to_pack_callbacks_) {
    py::handle(entry.first).dec_ref();
  }
}
// Python constants created once per InvokeStatics instance (presumably so
// that they do not have to be re-created on every call — confirm intent).
// Tags found in element 0 of a compound ABI descriptor.
py::str kNamedTag = py::str("named");
py::str kSlistTag = py::str("slist");
py::str kStupleTag = py::str("stuple");
py::str kSdictTag = py::str("sdict");
// Integer indices used to subscript descriptor lists.
py::int_ kZero = py::int_(0);
py::int_ kOne = py::int_(1);
py::int_ kTwo = py::int_(2);
// Name of the numpy function used to obtain a host array.
py::str kAsArray = py::str("asarray");
// Name of the function looked up on the iree.runtime.array_interop module.
py::str kMapDtypeToElementTypeAttr = py::str("map_dtype_to_element_type");
// Passed as the third positional argument to numpy's asarray.
py::str kContiguousArg = py::str("C");
// Attribute probed to decide whether a value can be treated as a host array.
py::str kArrayProtocolAttr = py::str("__array__");
// Attribute read from a host array to obtain its dtype.
py::str kDtypeAttr = py::str("dtype");
// Primitive type names.
py::str kF32 = py::str("f32");
py::str kF64 = py::str("f64");
py::str kI1 = py::str("i1");
py::str kI8 = py::str("i8");
py::str kI16 = py::str("i16");
py::str kI32 = py::str("i32");
py::str kI64 = py::str("i64");
// Compound types names.
py::str kNdarray = py::str("ndarray");
// Attribute names.
// Attribute of a DeviceArray that is cast to a HalBufferView when packing.
py::str kAttrBufferView = py::str("_buffer_view");
// Module 'numpy'. Imported when the statics object is constructed.
py::module &numpy_module() { return numpy_module_; }

// Module 'iree.runtime'. Imported on first use and cached afterwards.
py::object &runtime_module() {
  if (runtime_module_.has_value()) return *runtime_module_;
  runtime_module_ = py::module::import("iree.runtime");
  return *runtime_module_;
}

// Module 'iree.runtime.array_interop'. Imported on first use and cached.
py::module &array_interop_module() {
  if (array_interop_module_.has_value()) return *array_interop_module_;
  array_interop_module_ = py::module::import("iree.runtime.array_interop");
  return *array_interop_module_;
}

// The `DeviceArray` attribute of the runtime module, looked up on first use.
py::object &device_array_type() {
  if (device_array_type_.has_value()) return *device_array_type_;
  device_array_type_ = runtime_module().attr("DeviceArray");
  return *device_array_type_;
}

// The Python type object bound to HalBufferView.
py::type &hal_buffer_view_type() { return hal_buffer_view_type_; }
// Looks up the numpy dtype registered for an ABI element type name (such as
// 'f32') in abi_type_to_dtype_. Any failure of the lookup is reported as a
// std::invalid_argument naming the offending ABI type.
py::object MapElementAbiTypeToDtype(py::object &element_abi_type) {
  try {
    py::object dtype = abi_type_to_dtype_[element_abi_type];
    return dtype;
  } catch (const std::exception &) {
    std::string msg = "could not map abi type ";
    msg += py::cast<std::string>(py::repr(element_abi_type));
    msg += " to numpy dtype";
    throw std::invalid_argument(std::move(msg));
  }
}
// Converts a numpy dtype to a HAL element type by calling the Python helper
// `map_dtype_to_element_type` on the array interop module. A `None` result
// and any exception raised along the way are both reported as a
// std::invalid_argument that includes the dtype and the underlying reason.
enum iree_hal_element_types_t MapDtypeToElementType(py::object dtype) {
  // TODO: Consider porting this from a py func to C++ as it can be on
  // the critical path.
  try {
    py::object mapped =
        array_interop_module().attr(kMapDtypeToElementTypeAttr)(dtype);
    if (!mapped.is_none()) {
      return py::cast<enum iree_hal_element_types_t>(mapped);
    }
    // Deliberately thrown inside the try so it is re-wrapped below with the
    // dtype in the message.
    throw std::invalid_argument("mapping not found");
  } catch (const std::exception &e) {
    std::string msg = "could not map dtype ";
    msg += py::cast<std::string>(py::repr(dtype));
    msg += " to element type: ";
    msg += e.what();
    throw std::invalid_argument(std::move(msg));
  }
}
// Convenience overload: works out whether `desc` is a Python list and then
// defers to the two-argument overload.
PackCallback AbiTypeToPackCallback(py::handle desc) {
  const bool desc_is_list = py::isinstance<py::list>(desc);
  return AbiTypeToPackCallback(desc, desc_is_list);
}
// Given an ABI desc, return a callback that can pack a corresponding py
// value into a list. For efficiency, the caller must specify whether the
// desc is a list (this check already needs to be done typically so
// passed in).
//
// Descriptors handled here:
//   ["ndarray", element_type, dim0, dim1, ...]   (a dim of None maps to -1)
//   ["slist" | "stuple", sub_desc, ...]
//   ["sdict", [key, sub_desc], ...]
//   "f32" | "f64" | "i8" | "i16" | "i32" | "i64"
// Throws std::invalid_argument for an unrecognized or malformed descriptor.
// The returned callbacks throw std::invalid_argument when the Python value
// does not match the descriptor.
PackCallback AbiTypeToPackCallback(py::handle desc, bool desc_is_list) {
  // Switch based on descriptor type.
  if (desc_is_list) {
    // Compound type.
    py::object compound_type = desc[kZero];
    if (compound_type.equal(kNdarray)) {
      // Has format:
      //   ["ndarray", "f32", dim0, dim1, ...]
      // Guard the subtraction below: with fewer than two entries the
      // unsigned `len - 2` would wrap around to a huge vector size.
      const size_t desc_len = py::len(desc);
      if (desc_len < 2) {
        throw std::invalid_argument(
            "malformed ndarray reflection descriptor: expected "
            "['ndarray', element_type, dims...]");
      }
      // Extract static information about the target.
      std::vector<int64_t> abi_shape(desc_len - 2);
      for (size_t i = 0, e = abi_shape.size(); i < e; ++i) {
        py::handle dim = desc[py::int_(i + 2)];
        abi_shape[i] = dim.is_none() ? -1 : py::cast<int64_t>(dim);
      }
      // Map abi element type to dtype.
      py::object abi_type = desc[kOne];
      py::object target_dtype = MapElementAbiTypeToDtype(abi_type);
      auto hal_element_type = MapDtypeToElementType(target_dtype);
      return [this, target_dtype = std::move(target_dtype), hal_element_type,
              abi_shape = std::move(abi_shape)](InvokeContext &c,
                                                iree_vm_list_t *list,
                                                py::handle py_value) {
        IREE_TRACE_SCOPE0("ArgumentPacker::ReflectionNdarray");
        HalBufferView *bv = nullptr;
        // Keeps a freshly allocated buffer view alive until it is pushed.
        py::object retained_bv;
        if (py::isinstance(py_value, device_array_type())) {
          // Short-circuit: If a DeviceArray is provided, assume it is
          // correct.
          IREE_TRACE_SCOPE0("PackDeviceArray");
          bv = py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
        } else if (py::isinstance(py_value, hal_buffer_view_type())) {
          // Short-circuit: If a HalBufferView is provided directly.
          IREE_TRACE_SCOPE0("PackBufferView");
          bv = py::cast<HalBufferView *>(py_value);
        } else {
          // Fall back to the array protocol to generate a host side
          // array and then convert that.
          IREE_TRACE_SCOPE0("PackHostArray");
          py::object host_array;
          try {
            host_array = numpy_module().attr(kAsArray)(py_value, target_dtype,
                                                       kContiguousArg);
          } catch (std::exception &e) {
            std::string msg("could not convert value to numpy array: dtype=");
            msg.append(py::cast<std::string>(py::repr(target_dtype)));
            msg.append(", error='");
            msg.append(e.what());
            msg.append("', value=");
            msg.append(py::cast<std::string>(py::repr(py_value)));
            throw std::invalid_argument(std::move(msg));
          }
          retained_bv = c.allocator().AllocateBufferCopy(
              IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
              IREE_HAL_BUFFER_USAGE_DISPATCH |
                  IREE_HAL_BUFFER_USAGE_TRANSFER |
                  IREE_HAL_BUFFER_USAGE_MAPPING,
              host_array, hal_element_type);
          bv = py::cast<HalBufferView *>(retained_bv);
        }
        // TODO: Add some shape verification. Not strictly necessary as the VM
        // will check, but may make error reporting nicer.
        // TODO: It is theoretically possible to enqueue further conversions
        // on the device, but for now we require things to line up closely.
        // TODO: If adding further manipulation here, please make this common
        // with the generic access case.
        iree_vm_ref_t buffer_view_ref =
            iree_hal_buffer_view_retain_ref(bv->raw_ptr());
        CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
                       "could not push buffer view to list");
      };
    } else if (compound_type.equal(kSlistTag) ||
               compound_type.equal(kStupleTag)) {
      // Tuple/list extraction.
      // When decoding a list or tuple, the desc object is like:
      //   ['slist', [...value_type_0...], ...]
      // Where the type is either 'slist' or 'stuple'.
      std::vector<PackCallback> sub_packers(py::len(desc) - 1);
      for (size_t i = 0; i < sub_packers.size(); i++) {
        sub_packers[i] = AbiTypeToPackCallback(desc[py::int_(i + 1)]);
      }
      return [sub_packers = std::move(sub_packers)](InvokeContext &c,
                                                    iree_vm_list_t *list,
                                                    py::handle py_value) {
        if (py::len(py_value) != sub_packers.size()) {
          std::string msg("expected a sequence with ");
          msg.append(std::to_string(sub_packers.size()));
          msg.append(" values. got: ");
          msg.append(py::cast<std::string>(py::repr(py_value)));
          throw std::invalid_argument(std::move(msg));
        }
        VmVariantList item_list = VmVariantList::Create(sub_packers.size());
        for (size_t i = 0; i < sub_packers.size(); ++i) {
          py::object item_py_value;
          try {
            item_py_value = py_value[py::int_(i)];
          } catch (std::exception &e) {
            std::string msg("could not get item ");
            msg.append(std::to_string(i));
            msg.append(" from: ");
            msg.append(py::cast<std::string>(py::repr(py_value)));
            msg.append(": ");
            msg.append(e.what());
            throw std::invalid_argument(std::move(msg));
          }
          sub_packers[i](c, item_list.raw_ptr(), item_py_value);
        }
        // Push the sub list. Retain from raw_ptr() rather than from
        // steal_raw_ptr(): retaining a stolen pointer adds a reference on
        // top of the one taken from the wrapper, and nothing ever releases
        // that extra one (NOTE(review): assumes steal_raw_ptr() transfers
        // ownership as its name says — confirm in binding.h). Here the
        // parent list owns the retained reference and `item_list` gives up
        // its own when it goes out of scope.
        iree_vm_ref_t retained = iree_vm_list_retain_ref(item_list.raw_ptr());
        CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
                       "could not push sub list to list");
      };
    } else if (compound_type.equal(kSdictTag)) {
      // Dict extraction.
      // The descriptor for an sdict is like:
      //   ['sdict', ['key1', value1], ...]
      std::vector<std::pair<py::object, PackCallback>> sub_packers(
          py::len(desc) - 1);
      for (size_t i = 0; i < sub_packers.size(); i++) {
        py::object sub_desc = desc[py::int_(i + 1)];
        py::object key = sub_desc[kZero];
        py::object value_desc = sub_desc[kOne];
        sub_packers[i] =
            std::make_pair(std::move(key), AbiTypeToPackCallback(value_desc));
      }
      return [sub_packers = std::move(sub_packers)](InvokeContext &c,
                                                    iree_vm_list_t *list,
                                                    py::handle py_value) {
        if (py::len(py_value) != sub_packers.size()) {
          std::string msg("expected a dict with ");
          msg.append(std::to_string(sub_packers.size()));
          msg.append(" values. got: ");
          msg.append(py::cast<std::string>(py::repr(py_value)));
          throw std::invalid_argument(std::move(msg));
        }
        VmVariantList item_list = VmVariantList::Create(sub_packers.size());
        for (size_t i = 0; i < sub_packers.size(); ++i) {
          py::object item_py_value;
          try {
            item_py_value = py_value[sub_packers[i].first];
          } catch (std::exception &e) {
            std::string msg("could not get item ");
            msg.append(py::cast<std::string>(py::repr(sub_packers[i].first)));
            msg.append(" from: ");
            msg.append(py::cast<std::string>(py::repr(py_value)));
            msg.append(": ");
            msg.append(e.what());
            throw std::invalid_argument(std::move(msg));
          }
          sub_packers[i].second(c, item_list.raw_ptr(), item_py_value);
        }
        // Push the sub list. Retained from raw_ptr() and status-checked for
        // the same reasons as in the list/tuple case above.
        iree_vm_ref_t retained = iree_vm_list_retain_ref(item_list.raw_ptr());
        CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
                       "could not push sub list to list");
      };
    } else {
      std::string message("Unrecognized reflection compound type: ");
      message.append(py::cast<std::string>(compound_type));
      throw std::invalid_argument(message);
    }
  } else {
    // Primitive type.
    py::str prim_type = py::cast<py::str>(desc);
    if (prim_type.equal(kF32)) {
      // f32
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_f32(py::cast<float>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else if (prim_type.equal(kF64)) {
      // f64
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_f64(py::cast<double>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else if (prim_type.equal(kI32)) {
      // i32.
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_i32(py::cast<int32_t>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else if (prim_type.equal(kI64)) {
      // i64.
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_i64(py::cast<int64_t>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else if (prim_type.equal(kI8)) {
      // i8.
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_i8(py::cast<int8_t>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else if (prim_type.equal(kI16)) {
      // i16.
      return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
        iree_vm_value_t vm_value =
            iree_vm_value_make_i16(py::cast<int16_t>(py_value));
        CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
                       "could not append value");
      };
    } else {
      std::string message("Unrecognized reflection primitive type: ");
      message.append(py::cast<std::string>(prim_type));
      throw std::invalid_argument(message);
    }
  }
}
// Picks a packer for `arg` based purely on its Python type (no reflection
// metadata). Exact-type matches in py_type_to_pack_callbacks_ win; otherwise
// anything exposing `__array__` is treated as a host array. Returns an empty
// PackCallback when the value cannot be handled.
PackCallback GetGenericPackCallbackFor(py::handle arg) {
  PopulatePyTypeToPackCallbacks();
  py::type arg_type = py::type::of(arg);
  auto entry = py_type_to_pack_callbacks_.find(arg_type.ptr());
  if (entry != py_type_to_pack_callbacks_.end()) {
    return entry->second;
  }
  // Probe to see if we have a host array.
  if (py::hasattr(arg, kArrayProtocolAttr)) {
    return GetGenericPackCallbackForNdarray();
  }
  return PackCallback();
}
private:
// Builds the packer used for values that expose the array protocol when no
// reflection metadata is available: the value is converted with numpy's
// `asarray` (dtype left as None), copied into a newly allocated buffer view
// and that buffer view is appended to the VM list.
PackCallback GetGenericPackCallbackForNdarray() {
  return [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
    IREE_TRACE_SCOPE0("ArgumentPacker::GenericNdarray");
    py::object host_array;
    try {
      host_array = numpy_module().attr(kAsArray)(
          py_value, /*dtype=*/py::none(), kContiguousArg);
    } catch (std::exception &e) {
      std::string msg = "could not convert value to numpy array: ";
      msg += "error='";
      msg += e.what();
      msg += "', value=";
      msg += py::cast<std::string>(py::repr(py_value));
      throw std::invalid_argument(std::move(msg));
    }
    // The element type follows whatever dtype numpy chose for the array.
    auto element_type = MapDtypeToElementType(host_array.attr(kDtypeAttr));

    // Put it on the device.
    py::object device_bv = c.allocator().AllocateBufferCopy(
        IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
        IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER |
            IREE_HAL_BUFFER_USAGE_MAPPING,
        host_array, element_type);
    HalBufferView *buffer_view = py::cast<HalBufferView *>(device_bv);

    // TODO: If adding further manipulation here, please make this common
    // with the reflection access case.
    iree_vm_ref_t bv_ref =
        iree_hal_buffer_view_retain_ref(buffer_view->raw_ptr());
    CheckApiStatus(iree_vm_list_push_ref_move(list, &bv_ref),
                   "could not append value");
  };
}
void PopulatePyTypeToPackCallbacks() {
if (!py_type_to_pack_callbacks_.empty()) return;
// We only care about int and double in the numeric hierarchy. Since Python
// has no further refinement of these, just treat them as vm 64 bit int and
// floats and let the VM take care of it. There isn't much else we can do.
AddPackCallback(
py::type::of(py::cast(1)),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
iree_vm_value_t vm_value =
iree_vm_value_make_i64(py::cast<int64_t>(py_value));
CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
"could not append value");
});
AddPackCallback(
py::type::of(py::cast(1.0)),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
iree_vm_value_t vm_value =
iree_vm_value_make_f64(py::cast<double>(py_value));
CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
"could not append value");
});
// List/tuple.
auto sequence_callback = [this](InvokeContext &c, iree_vm_list_t *list,
py::handle py_value) {
auto py_seq = py::cast<py::sequence>(py_value);
VmVariantList item_list = VmVariantList::Create(py::len(py_seq));
for (py::object py_item : py_seq) {
PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
if (!sub_packer) {
std::string message("could not convert python value to VM: ");
message.append(py::cast<std::string>(py::repr(py_item)));
throw std::invalid_argument(std::move(message));
}
sub_packer(c, item_list.raw_ptr(), py_item);
}
// Push the sub list.
iree_vm_ref_t retained =
iree_vm_list_retain_ref(item_list.steal_raw_ptr());
iree_vm_list_push_ref_move(list, &retained);
};
AddPackCallback(py::type::of(py::list{}), sequence_callback);
AddPackCallback(py::type::of(py::tuple{}), sequence_callback);
// Dict.
auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list,
py::handle py_value) {
// Gets all dict items and sorts (by key).
auto py_dict = py::cast<py::dict>(py_value);
py::list py_keys;
for (std::pair<py::handle, py::handle> it : py_dict) {
py_keys.append(it.first);
}
py_keys.attr("sort")();
VmVariantList item_list = VmVariantList::Create(py_keys.size());
for (auto py_key : py_keys) {
py::object py_item = py_dict[py_key];
PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
if (!sub_packer) {
std::string message("could not convert python value to VM: ");
message.append(py::cast<std::string>(py::repr(py_item)));
throw std::invalid_argument(std::move(message));
}
sub_packer(c, item_list.raw_ptr(), py_item);
}
// Push the sub list.
iree_vm_ref_t retained =
iree_vm_list_retain_ref(item_list.steal_raw_ptr());
iree_vm_list_push_ref_move(list, &retained);
};
AddPackCallback(py::type::of(py::dict{}), dict_callback);
// HalBufferView.
AddPackCallback(
py::type::of<HalBufferView>(),
[](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
HalBufferView *bv = py::cast<HalBufferView *>(py_value);
iree_vm_ref_t buffer_view_ref =
iree_hal_buffer_view_retain_ref(bv->raw_ptr());
CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
"could not append value");
});
// DeviceArray.
AddPackCallback(
device_array_type(),
[this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
HalBufferView *bv =
py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
iree_vm_ref_t buffer_view_ref =
iree_hal_buffer_view_retain_ref(bv->raw_ptr());
CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
"could not append value");
});
}
// Registers `pcb` as the generic packer for Python type `t`. A reference is
// taken on the type object so the raw pointer used as the map key stays
// valid; the destructor releases it again.
void AddPackCallback(py::handle t, PackCallback pcb) {
  PyObject *type_key = t.ptr();
  assert(py_type_to_pack_callbacks_.count(type_key) == 0 && "duplicate types");
  t.inc_ref();
  py_type_to_pack_callbacks_.emplace(type_key, std::move(pcb));
}
// Creates the dict that maps ABI element type names ('f32', 'i1', ...) to
// the matching numpy scalar types.
py::dict BuildAbiTypeToDtype() {
  py::module &np = numpy_module();
  py::dict mapping;
  mapping[kF32] = np.attr("float32");
  mapping[kF64] = np.attr("float64");
  mapping[kI1] = np.attr("bool_");
  mapping[kI8] = np.attr("int8");
  mapping[kI16] = np.attr("int16");
  mapping[kI64] = np.attr("int64");
  mapping[kI32] = np.attr("int32");
  return mapping;
}
// Cached modules and types. Those that involve recursive lookup within
// our top level module, we defer. Those outside, we cache at creation.
// numpy is imported eagerly, as part of constructing this object.
py::module numpy_module_ = py::module::import("numpy");
// Deferred lookups: each stays empty until its accessor is first called.
std::optional<py::object> runtime_module_;
std::optional<py::module> array_interop_module_;
std::optional<py::object> device_array_type_;
// Python type object bound to HalBufferView, resolved at construction.
py::type hal_buffer_view_type_ = py::type::of<HalBufferView>();
// Maps Python type to a PackCallback that can generically code it.
// This will have inc_ref() called on them when added.
// The destructor calls dec_ref() on every key.
std::unordered_map<PyObject *, PackCallback> py_type_to_pack_callbacks_;
// Dict of str (ABI dtype like 'f32') to numpy dtype.
// Must stay declared after numpy_module_ and the k* name constants: the
// builder reads them, and members are initialized in declaration order.
py::dict abi_type_to_dtype_ = BuildAbiTypeToDtype();
};
/// Object that can pack Python arguments into a VM List for a specific
/// function.
class ArgumentPacker {
public:
// Builds the per-argument packers from the function's reflection metadata.
// When `arg_descs` is absent the packer switches to fully dynamic dispatch
// and no per-argument state is created. Otherwise each descriptor is either
// a plain ABI descriptor (a positional argument) or
// ["named", kwarg_name, sub_desc] (a keyword argument whose name is recorded
// together with its index in the flat argument list).
ArgumentPacker(InvokeStatics &statics, std::optional<py::list> arg_descs)
    : statics_(statics) {
  IREE_TRACE_SCOPE0("ArgumentPacker::Init");
  if (!arg_descs.has_value()) {
    dynamic_dispatch_ = true;
    return;
  }

  // Reflection dispatch.
  for (py::handle arg_desc : *arg_descs) {
    int flat_index = flat_arg_packers_.size();
    std::optional<std::string> keyword;
    // Owns the unwrapped sub descriptor while `effective_desc` borrows it.
    py::object unwrapped_desc;
    py::handle effective_desc = arg_desc;
    bool is_list = py::isinstance<py::list>(effective_desc);

    // Check if named.
    //   ["named", "kwarg_name", sub_desc]
    // If found, record the keyword and continue with the sub_desc.
    if (is_list) {
      py::object tag = effective_desc[statics.kZero];
      if (tag.equal(statics.kNamedTag)) {
        py::object name_obj = effective_desc[statics.kOne];
        unwrapped_desc = effective_desc[statics.kTwo];
        keyword = py::cast<std::string>(name_obj);
        effective_desc = unwrapped_desc;
        is_list = py::isinstance<py::list>(effective_desc);
        kwarg_to_index_[name_obj] = flat_index;
      }
    }

    if (!keyword.has_value()) {
      pos_only_arg_count_ += 1;
    }
    flat_arg_packers_.push_back(
        statics.AbiTypeToPackCallback(effective_desc, is_list));
  }
}
/// Packs positional/kw arguments into a suitable VmVariantList and returns
/// it.
///
/// With dynamic dispatch every positional argument is packed according to
/// its Python type and keyword arguments are rejected. With reflection
/// dispatch the arguments are first laid out in declaration order (keywords
/// resolved through kwarg_to_index_) and then packed by the per-argument
/// packers. Arity problems, unknown or duplicate keywords and unconvertible
/// values are reported as std::invalid_argument.
VmVariantList Pack(InvokeContext &invoke_context, py::sequence pos_args,
                   py::dict kw_args) {
  // Dynamic dispatch.
  if (dynamic_dispatch_) {
    IREE_TRACE_SCOPE0("ArgumentPacker::PackDynamic");
    if (!kw_args.empty()) {
      throw std::invalid_argument(
          "kwargs not supported for dynamic dispatch functions");
    }
    VmVariantList packed = VmVariantList::Create(pos_args.size());
    for (py::handle value : pos_args) {
      PackCallback pack_one = statics_.GetGenericPackCallbackFor(value);
      if (!pack_one) {
        std::string message = "could not convert python value to VM: ";
        message += py::cast<std::string>(py::repr(value));
        throw std::invalid_argument(std::move(message));
      }
      // TODO: Better error handling by catching the exception and
      // reporting which arg has a problem.
      pack_one(invoke_context, packed.raw_ptr(), value);
    }
    return packed;
  }

  // Reflection based dispatch.
  IREE_TRACE_SCOPE0("ArgumentPacker::PackReflection");
  // One slot per declared argument; a null handle means "not supplied yet".
  std::vector<py::handle> slots(flat_arg_packers_.size());
  if (pos_args.size() > static_cast<size_t>(pos_only_arg_count_)) {
    std::string message = "mismatched call arity: expected ";
    message += std::to_string(pos_only_arg_count_);
    message += " got ";
    message += std::to_string(pos_args.size());
    throw std::invalid_argument(std::move(message));
  }

  // Positional args.
  size_t next_slot = 0;
  for (py::handle value : pos_args) {
    slots[next_slot] = value;
    ++next_slot;
  }

  // Keyword args.
  for (auto kw : kw_args) {
    int slot_index;
    try {
      slot_index = py::cast<int>(kwarg_to_index_[kw.first]);
    } catch (const std::exception &) {
      std::string message = "specified kwarg '";
      message.append(py::cast<py::str>(kw.first));
      message += "' is unknown";
      throw std::invalid_argument(std::move(message));
    }
    if (slots[slot_index]) {
      std::string message =
          "mismatched call arity: duplicate keyword argument '";
      message.append(py::cast<py::str>(kw.first));
      message += "'";
      throw std::invalid_argument(std::move(message));
    }
    slots[slot_index] = kw.second;
  }

  // Now check to see that all args are set.
  for (size_t i = 0; i < slots.size(); ++i) {
    if (slots[i]) continue;
    std::string message =
        "mismatched call arity: expected a value for argument ";
    message += std::to_string(i);
    throw std::invalid_argument(std::move(message));
  }

  // Start packing into the list.
  VmVariantList packed = VmVariantList::Create(flat_arg_packers_.size());
  for (size_t i = 0; i < slots.size(); ++i) {
    // TODO: Better error handling by catching the exception and
    // reporting which arg has a problem.
    flat_arg_packers_[i](invoke_context, packed.raw_ptr(), slots[i]);
  }
  return packed;
}
private:
// Shared statics used to look up generic packers at pack time. Held by
// reference, so the InvokeStatics object must outlive this packer.
InvokeStatics &statics_;
// Number of descriptors that were not wrapped in a "named" entry; also the
// maximum number of positional arguments Pack accepts.
int pos_only_arg_count_ = 0;
// Dictionary of py::str -> py::int_ mapping kwarg names to position in
// the argument list. We store this as a py::dict because it is optimized
// for py::str lookup.
py::dict kwarg_to_index_;
// One packer per declared argument, in declaration order (reflection
// dispatch only).
std::vector<PackCallback> flat_arg_packers_;
// If true, then there is no dispatch metadata and we process fully
// dynamically.
bool dynamic_dispatch_ = false;
};
} // namespace
// Registers the invoke helper classes on module `m` and publishes one
// InvokeStatics instance as the module attribute `_invoke_statics`.
void SetupInvokeBindings(pybind11::module &m) {
  // Bound without methods: only needed so instances can cross into Python.
  py::class_<InvokeStatics>(m, "_InvokeStatics");

  py::class_<InvokeContext> context_class(m, "InvokeContext");
  context_class.def(py::init<HalDevice &>());

  py::class_<ArgumentPacker> packer_class(m, "ArgumentPacker");
  packer_class.def(py::init<InvokeStatics &, std::optional<py::list>>());
  packer_class.def("pack", &ArgumentPacker::Pack);

  m.attr("_invoke_statics") = py::cast(InvokeStatics());
}
} // namespace python
} // namespace iree