Rewrite Python function argument packing in C++. (#8311)
* Rewrite Python function argument packing in C++.
* This is a first step towards implementing all hot path code in C++.
* With the Python code, I didn't have great performance baselines, but I estimate it was taking ~80-100us to do argument packing for a 3 item call. This patch:
* Average of 8us total for argument packing for this case.
* Average of 1us per DeviceArray.
* Both numbers are for invocations with static reflection data (which is also now layered for easier access, so that it can be added at runtime if desired to make things faster). The dynamic case does a bit more work.
* Once more of this is ported to C++, we can eliminate a lot more overhead by intercepting the tp_call protocol directly (vs doing list/dict manipulation).
* This also fixes a critical bug where the Python "DEVICE_VISIBLE" enumeration value was actually set to IREE_HAL_MEMORY_TYPE_HOST_VISIBLE. I believe the result was that all arguments were being set up in a non-optimal way (but we were still getting a latency boost from the results). Attempting to make this fix had the effect of opening a bug farm of other issues setting up a buffer allocation, so I ended up adding a TODO and re-enabling HOST_VISIBLE. Will triage more once this lands, now that it is possible to actually disable HOST_VISIBLE.
* Need to check performance in a real setting. My machine is pretty noisy. I am seeing lower e2e numbers than before.
* Release the GIL around invoke and buffer copy
diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt
index f2d62ed..95cab2e 100644
--- a/bindings/python/iree/runtime/CMakeLists.txt
+++ b/bindings/python/iree/runtime/CMakeLists.txt
@@ -24,13 +24,15 @@
PyExtRt
MODULE_NAME binding
SRCS
- "initialize_module.cc"
"binding.h"
+ "initialize_module.cc"
+ "invoke.h"
+ "invoke.cc"
"hal.h"
- "vm.h"
"hal.cc"
"status_utils.cc"
"status_utils.h"
+ "vm.h"
"vm.cc"
UNIX_LINKER_SCRIPT
"unix_version.lds"
@@ -42,6 +44,7 @@
iree::base
iree::base::cc
iree::base::internal::flags
+ iree::base::tracing
iree::hal
iree::hal::drivers
iree::modules::hal
diff --git a/bindings/python/iree/runtime/binding.h b/bindings/python/iree/runtime/binding.h
index cded74f..64240d4 100644
--- a/bindings/python/iree/runtime/binding.h
+++ b/bindings/python/iree/runtime/binding.h
@@ -26,6 +26,7 @@
class ApiRefCounted {
public:
ApiRefCounted() : instance_(nullptr) {}
+ ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
other.instance_ = nullptr;
}
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index d0ebd79..9c39845 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -13,10 +13,12 @@
import numpy as np
from .binding import (
+ _invoke_statics,
+ ArgumentPacker,
BufferUsage,
HalBufferView,
HalDevice,
- HalElementType,
+ InvokeContext,
MemoryType,
VmContext,
VmFunction,
@@ -81,9 +83,8 @@
"_vm_function",
"_abi_dict",
"_arg_descs",
+ "_arg_packer",
"_ret_descs",
- "_named_arg_indices",
- "_max_named_arg_index",
"_has_inlined_results",
"_tracer",
]
@@ -102,15 +103,17 @@
self._arg_descs = None
self._ret_descs = None
self._has_inlined_results = False
- self._named_arg_indices: Dict[str, int] = {}
- self._max_named_arg_index: int = -1
self._parse_abi_dict(vm_function)
+ self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs)
@property
def vm_function(self) -> VmFunction:
return self._vm_function
def __call__(self, *args, **kwargs):
+ invoke_context = InvokeContext(self._device)
+ arg_list = self._arg_packer.pack(invoke_context, args, kwargs)
+
call_trace = None # type: Optional[tracing.CallTrace]
if self._tracer:
call_trace = self._tracer.start_call(self._vm_function)
@@ -121,25 +124,7 @@
inv = Invocation(self._device)
ret_descs = self._ret_descs
- # Merge keyword args in by name->position mapping.
- if kwargs:
- args = list(args)
- len_delta = self._max_named_arg_index - len(args) + 1
- if len_delta > 0:
- # Fill in MissingArgument placeholders before arranging kwarg input.
- # Any remaining placeholders will fail arity checks later on.
- args.extend([MissingArgument] * len_delta)
-
- for kwarg_key, kwarg_value in kwargs.items():
- try:
- kwarg_index = self._named_arg_indices[kwarg_key]
- except KeyError:
- raise ArgumentError(f"specified kwarg '{kwarg_key}' is unknown")
- args[kwarg_index] = kwarg_value
-
- arg_list = VmVariantList(len(args))
ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
- _merge_python_sequence_to_vm(inv, arg_list, args, self._arg_descs)
if call_trace:
call_trace.add_vm_list(arg_list, "args")
self._invoke(arg_list, ret_list)
@@ -194,17 +179,6 @@
raise RuntimeError(
f"Malformed function reflection metadata structure: {reflection}")
- # Post-process the arg descs to transform "named" records to just their
- # type, stashing the index.
- for i in range(len(self._arg_descs)):
- maybe_named_desc = self._arg_descs[i]
- if maybe_named_desc and maybe_named_desc[0] == "named":
- arg_name, arg_type_desc = maybe_named_desc[1:]
- self._arg_descs[i] = arg_type_desc
- self._named_arg_indices[arg_name] = i
- if i > self._max_named_arg_index:
- self._max_named_arg_index = i
-
# Detect whether the results are a slist/stuple/sdict, which indicates
# that they are inlined with the function's results.
if len(self._ret_descs) == 1:
@@ -216,164 +190,6 @@
return repr(self._vm_function)
-# Python type to VM Type converters. All of these take:
-# inv: Invocation
-# target_list: VmVariantList to append to
-# python_value: The python value of the given type
-# desc: The ABI descriptor list (or None if in dynamic mode).
-
-
-def _bool_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- _int_to_vm(inv, t, int(x), desc)
-
-
-def _int_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- # Implicit conversion to a 0d tensor.
- if desc and _is_0d_ndarray_descriptor(desc):
- casted = _cast_scalar_to_ndarray(inv, x, desc)
- _ndarray_to_vm(inv, t, casted, desc)
- return
- t.push_int(x)
-
-
-def _float_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- # Implicit conversion to a 0d tensor.
- if desc and _is_0d_ndarray_descriptor(desc):
- casted = _cast_scalar_to_ndarray(inv, x, desc)
- _ndarray_to_vm(inv, t, casted, desc)
- return
- t.push_float(x)
-
-
-def _list_or_tuple_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- desc_type = desc[0]
- if desc_type != "slist" and desc_type != "stuple":
- _raise_argument_error(inv,
- f"passed a list or tuple but expected {desc_type}")
- # When decoding a list or tuple, the desc object is like:
- # ['slist', [...value_type_0...], ...]
- # Where the type is either 'slist' or 'stuple'.
- sub_descriptors = desc[1:]
- arity = len(sub_descriptors)
- if len(x) != arity:
- _raise_argument_error(inv,
- f"mismatched list/tuple arity: {len(x)} vs {arity}")
- sub_list = VmVariantList(arity)
- _merge_python_sequence_to_vm(inv, sub_list, x, sub_descriptors)
- t.push_list(sub_list)
-
-
-def _dict_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- desc_type = desc[0]
- if desc_type != "sdict":
- _raise_argument_error(inv, f"passed a dict but expected {desc_type}")
- # When decoding a dict, the desc object is like:
- # ['sdict', ['key0', [...value_type_0...]], ['key1', [...value_type_1...]]]]
- sub_descriptors = desc[1:]
- py_values = []
- value_descs = []
- for key, value_desc in sub_descriptors:
- try:
- py_values.append(x[key])
- except KeyError:
- _raise_argument_error(inv, f"expected dict item with key '{key}'")
- value_descs.append(value_desc)
-
- sub_list = VmVariantList(len(py_values))
- _merge_python_sequence_to_vm(inv, sub_list, py_values, value_descs)
- t.push_list(sub_list)
-
-
-def _str_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- _raise_argument_error(inv, "Python str arguments not yet supported")
-
-
-def _ndarray_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- # Validate and implicit conversion against type descriptor.
- if FUNCTION_INPUT_VALIDATION and desc is not None:
- desc_type = desc[0]
- if desc_type != "ndarray":
- _raise_argument_error(inv, f"passed an ndarray but expected {desc_type}")
- dtype_str = desc[1]
- try:
- dtype = ABI_TYPE_TO_DTYPE[dtype_str]
- except KeyError:
- _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'")
- if dtype != x.dtype:
- # TODO: If we got a DeviceArray in which triggers this implicit
- # conversion, it will fault back to the host, be converted and
- # then sent back. This is... not great.
- # At least warning about it so we know it might be a problem.
- logging.warn(
- "Implicit dtype conversion of DeviceArray forces transfer to host")
- x = np.asarray(x).astype(dtype)
- rank = desc[2]
- shape = desc[3:]
- ndarray_shape = x.shape
- if len(shape) != len(ndarray_shape) or rank != len(ndarray_shape):
- _raise_argument_error(
- inv, f"rank mismatch {len(ndarray_shape)} vs {len(shape)}")
- for exp_dim, act_dim in zip(shape, ndarray_shape):
- if exp_dim is not None and exp_dim != act_dim:
- _raise_argument_error(
- inv, f"shape mismatch {ndarray_shape} vs {tuple(shape)}")
- actual_dtype = x.dtype
- element_type = map_dtype_to_element_type(actual_dtype)
- if element_type is None:
- _raise_argument_error(inv, f"unsupported numpy dtype {x.dtype}")
-
- if isinstance(x, DeviceArray):
- # Already one of ours and did not get implicitly converted.
- buffer_view = x._buffer_view
- else:
- # Not one of ours. Put it on the device.
- buffer_view = inv.device.allocator.allocate_buffer_copy(
- memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
- allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
- buffer=np.asarray(x),
- element_type=element_type)
-
- t.push_buffer_view(buffer_view)
-
-
-def _buffer_view_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- # BufferView is a low-level object and we do no validation here for it.
- # The assumption is that it is coming from either an advanced use case
- # or a systematic integration that knows what it is doing. The runtime
- # will do necessary validation.
- t.push_buffer_view(x)
-
-
-# Called in reflection mode when we know we want to coerce from something
-# 'ndarray' like (as defined by the reflection metadata).
-def _ndarray_like_to_vm(inv: Invocation, t: VmVariantList, x, desc):
- if isinstance(x, HalBufferView):
- return _buffer_view_to_vm(inv, t, x, desc)
- return _ndarray_to_vm(inv, t, x, desc)
-
-
-class _MissingArgument:
- """Placeholder for missing kwargs in the function input."""
-
- def __repr__(self):
- return "<mising argument>"
-
-
-MissingArgument = _MissingArgument()
-
-PYTHON_TO_VM_CONVERTERS = {
- bool: _bool_to_vm,
- int: _int_to_vm,
- float: _float_to_vm,
- list: _list_or_tuple_to_vm,
- tuple: _list_or_tuple_to_vm,
- dict: _dict_to_vm,
- str: _str_to_vm,
- np.ndarray: _ndarray_to_vm,
- HalBufferView: _buffer_view_to_vm,
- DeviceArray: _ndarray_to_vm,
-}
-
# VM to Python converters. All take:
# inv: Invocation
# vm_list: VmVariantList to read from
@@ -534,54 +350,6 @@
raise new_e
-def _merge_python_sequence_to_vm(inv: Invocation, vm_list, py_list, descs):
- # For dynamic mode, just assume we have the right arity.
- if descs is None:
- descs = [None] * len(py_list)
- elif FUNCTION_INPUT_VALIDATION:
- len_py_list = sum([1 for x in py_list if x is not MissingArgument])
- if len(py_list) != len_py_list:
- _raise_argument_error(
- inv,
- f"mismatched call arity: expected {len(descs)} arguments but got "
- f"{len_py_list}. Expected signature=\n{descs}\nfor input=\n{py_list}")
-
- for py_value, desc in zip(py_list, descs):
- inv.current_arg = py_value
- inv.current_desc = desc
-
- # For ndarray, we want to be able to handle array-like, so check for that
- # explicitly (duck typed vs static typed).
- if _is_ndarray_descriptor(desc):
- converter = _ndarray_like_to_vm
- else:
- converter = _get_python_to_vm_converter(inv, py_value, desc)
-
- try:
- converter(inv, vm_list, py_value, desc)
- except ArgumentError:
- raise
- except Exception as e:
- _raise_argument_error(inv, f"exception converting from Python type to VM",
- e)
-
-
-def _get_python_to_vm_converter(inv: Invocation, py_value, desc):
- py_type = py_value.__class__
- converter = PYTHON_TO_VM_CONVERTERS.get(py_type)
- if converter is not None:
- return converter
- # See if it supports the __array__ protocol and if so, pull it to
- # the host and use the ndarray converter. This will create round-trips
- # between frameworks but at least enables interop.
- if hasattr(py_value, "__array__"):
- return _ndarray_to_vm
-
- _raise_argument_error(
- inv, f"cannot map Python type to VM: {py_type}"
- f" (for desc {desc})")
-
-
def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs):
vm_list_arity = len(vm_list)
if descs is None:
diff --git a/bindings/python/iree/runtime/function_test.py b/bindings/python/iree/runtime/function_test.py
index 0a013cc..57b8481 100644
--- a/bindings/python/iree/runtime/function_test.py
+++ b/bindings/python/iree/runtime/function_test.py
@@ -43,10 +43,11 @@
class FunctionTest(absltest.TestCase):
- def setUp(self):
+ @classmethod
+ def setUpClass(cls):
# Doesn't matter what device. We just need one.
config = rt.Config("vmvx")
- self.device = config.device
+ cls.device = config.device
def testNoReflectionScalars(self):
@@ -85,6 +86,140 @@
vm_context.mock_arg_reprs)
self.assertEqual(3, result)
+ def testListArg(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["slist", "i32", "i32"],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ _ = invoker([2, 3])
+ self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+ vm_context.mock_arg_reprs)
+
+ def testListArgNoReflection(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={})
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ _ = invoker([2, 3])
+ self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+ vm_context.mock_arg_reprs)
+
+ def testListArgArityMismatch(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["slist", "i32", "i32"],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ with self.assertRaisesRegex(ValueError,
+ "expected a sequence with 2 values. got:"):
+ _ = invoker([2, 3, 4])
+
+ def testTupleArg(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["stuple", "i32", "i32"],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ _ = invoker((2, 3))
+ self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+ vm_context.mock_arg_reprs)
+
+ def testDictArg(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(
+ reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ _ = invoker({"b": 3, "a": 2})
+ self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+ vm_context.mock_arg_reprs)
+
+ def testDictArgArityMismatch(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(
+ reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ with self.assertRaisesRegex(ValueError,
+ "expected a dict with 2 values. got:"):
+ _ = invoker({"a": 2, "b": 3, "c": 4})
+
+ def testDictArgKeyError(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(
+ reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "):
+ _ = invoker({"a": 2, "c": 3})
+
+ def testDictArgNoReflection(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={})
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ _ = invoker({"b": 3, "a": 2})
+ self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+ vm_context.mock_arg_reprs)
+
def testInlinedResults(self):
def invoke(arg_list, ret_list):
@@ -316,6 +451,43 @@
self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
repr(invoked_arg_list))
+ def testNdarrayArgNoReflection(self):
+ arg_array = np.asarray([1, 0], dtype=np.int32)
+
+ invoked_arg_list = None
+
+ def invoke(arg_list, ret_list):
+ nonlocal invoked_arg_list
+ invoked_arg_list = arg_list
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={})
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ result = invoker(arg_array)
+ self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+ repr(invoked_arg_list))
+
+ def testDeviceArrayArgNoReflection(self):
+ # Note that since the device array is set up to disallow implicit host
+ # transfers, this also verifies that no accidental/automatic transfers
+ # are done as part of marshalling the array to the function.
+ arg_array = rt.asdevicearray(self.device,
+ np.asarray([1, 0], dtype=np.int32),
+ implicit_host_transfer=False)
+
+ invoked_arg_list = None
+
+ def invoke(arg_list, ret_list):
+ nonlocal invoked_arg_list
+ invoked_arg_list = arg_list
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(reflection={})
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ result = invoker(arg_array)
+ self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+ repr(invoked_arg_list))
+
def testBufferViewArgNoReflection(self):
arg_buffer_view = self.device.allocator.allocate_buffer_copy(
memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
diff --git a/bindings/python/iree/runtime/hal.cc b/bindings/python/iree/runtime/hal.cc
index afb1ed7..d6be403 100644
--- a/bindings/python/iree/runtime/hal.cc
+++ b/bindings/python/iree/runtime/hal.cc
@@ -6,6 +6,7 @@
#include "bindings/python/iree/runtime/hal.h"
+#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "pybind11/numpy.h"
@@ -79,6 +80,7 @@
py::object HalAllocator::AllocateBufferCopy(
int memory_type, int allowed_usage, py::object buffer,
std::optional<iree_hal_element_types_t> element_type) {
+ IREE_TRACE_SCOPE0("HalAllocator::AllocateBufferCopy");
// Request a view of the buffer (use the raw python C API to avoid
// some allocation and copying at the pybind level).
Py_buffer py_view;
@@ -96,11 +98,16 @@
PyBufferReleaser py_view_releaser(py_view);
iree_hal_buffer_t* hal_buffer;
- CheckApiStatus(
- iree_hal_allocator_allocate_buffer(
- raw_ptr(), memory_type, allowed_usage, py_view.len,
- iree_make_const_byte_span(py_view.buf, py_view.len), &hal_buffer),
- "Failed to allocate device visible buffer");
+ // TODO: Should not require host visible :(
+ {
+ py::gil_scoped_release release;
+ CheckApiStatus(
+ iree_hal_allocator_allocate_buffer(
+ raw_ptr(), memory_type | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+ allowed_usage, py_view.len,
+ iree_make_const_byte_span(py_view.buf, py_view.len), &hal_buffer),
+ "Failed to allocate device visible buffer");
+ }
if (!element_type) {
return py::cast(HalBuffer::StealFromRawPtr(hal_buffer),
@@ -306,7 +313,7 @@
.value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT)
.value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED)
.value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL)
- .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
+ .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)
.value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)
.export_values()
.def("__or__",
diff --git a/bindings/python/iree/runtime/initialize_module.cc b/bindings/python/iree/runtime/initialize_module.cc
index b48586c..6c4487d 100644
--- a/bindings/python/iree/runtime/initialize_module.cc
+++ b/bindings/python/iree/runtime/initialize_module.cc
@@ -6,6 +6,7 @@
#include "bindings/python/iree/runtime/binding.h"
#include "bindings/python/iree/runtime/hal.h"
+#include "bindings/python/iree/runtime/invoke.h"
#include "bindings/python/iree/runtime/status_utils.h"
#include "bindings/python/iree/runtime/vm.h"
#include "iree/base/internal/flags.h"
@@ -21,6 +22,7 @@
m.doc() = "IREE Binding Backend Helpers";
SetupHalBindings(m);
+ SetupInvokeBindings(m);
SetupVmBindings(m);
m.def("parse_flags", [](py::args py_flags) {
diff --git a/bindings/python/iree/runtime/invoke.cc b/bindings/python/iree/runtime/invoke.cc
new file mode 100644
index 0000000..33686bd
--- /dev/null
+++ b/bindings/python/iree/runtime/invoke.cc
@@ -0,0 +1,712 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "bindings/python/iree/runtime/invoke.h"
+
+#include "bindings/python/iree/runtime/hal.h"
+#include "bindings/python/iree/runtime/vm.h"
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
+#include "iree/vm/api.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+class InvokeContext {
+ public:
+ InvokeContext(HalDevice &device) : device_(device) {}
+
+ HalDevice &device() { return device_; }
+ HalAllocator allocator() {
+ // TODO: Unfortunate that we inc ref here but that is how our object model
+ // is set up.
+ return HalAllocator::BorrowFromRawPtr(device().allocator());
+ }
+
+ private:
+ HalDevice device_;
+};
+
+using PackCallback =
+ std::function<void(InvokeContext &, iree_vm_list_t *, py::handle)>;
+
+class InvokeStatics {
+ public:
+ ~InvokeStatics() {
+ for (auto it : py_type_to_pack_callbacks_) {
+ py::handle(it.first).dec_ref();
+ }
+ }
+
+ py::str kNamedTag = py::str("named");
+ py::str kSlistTag = py::str("slist");
+ py::str kStupleTag = py::str("stuple");
+ py::str kSdictTag = py::str("sdict");
+
+ py::int_ kZero = py::int_(0);
+ py::int_ kOne = py::int_(1);
+ py::int_ kTwo = py::int_(2);
+ py::str kAsArray = py::str("asarray");
+ py::str kMapDtypeToElementTypeAttr = py::str("map_dtype_to_element_type");
+ py::str kContiguousArg = py::str("C");
+ py::str kArrayProtocolAttr = py::str("__array__");
+ py::str kDtypeAttr = py::str("dtype");
+
+ // Primitive type names.
+ py::str kF32 = py::str("f32");
+ py::str kF64 = py::str("f64");
+ py::str kI1 = py::str("i1");
+ py::str kI8 = py::str("i8");
+ py::str kI16 = py::str("i16");
+ py::str kI32 = py::str("i32");
+ py::str kI64 = py::str("i64");
+
+  // Compound type names.
+ py::str kNdarray = py::str("ndarray");
+
+ // Attribute names.
+ py::str kAttrBufferView = py::str("_buffer_view");
+
+ // Module 'numpy'.
+ py::module &numpy_module() { return numpy_module_; }
+
+ py::object &runtime_module() {
+ if (!runtime_module_) {
+ runtime_module_ = py::module::import("iree.runtime");
+ }
+ return *runtime_module_;
+ }
+
+ py::module &array_interop_module() {
+ if (!array_interop_module_) {
+ array_interop_module_ = py::module::import("iree.runtime.array_interop");
+ }
+ return *array_interop_module_;
+ }
+
+ py::object &device_array_type() {
+ if (!device_array_type_) {
+ device_array_type_ = runtime_module().attr("DeviceArray");
+ }
+ return *device_array_type_;
+ }
+
+ py::type &hal_buffer_view_type() { return hal_buffer_view_type_; }
+
+ py::object MapElementAbiTypeToDtype(py::object &element_abi_type) {
+ try {
+ return abi_type_to_dtype_[element_abi_type];
+ } catch (std::exception &) {
+ std::string msg("could not map abi type ");
+ msg.append(py::cast<std::string>(py::repr(element_abi_type)));
+ msg.append(" to numpy dtype");
+ throw std::invalid_argument(std::move(msg));
+ }
+ }
+
+ enum iree_hal_element_types_t MapDtypeToElementType(py::object dtype) {
+ // TODO: Consider porting this from a py func to C++ as it can be on
+ // the critical path.
+ try {
+ py::object element_type =
+ array_interop_module().attr(kMapDtypeToElementTypeAttr)(dtype);
+ if (element_type.is_none()) {
+ throw std::invalid_argument("mapping not found");
+ }
+ return py::cast<enum iree_hal_element_types_t>(element_type);
+ } catch (std::exception &e) {
+ std::string msg("could not map dtype ");
+ msg.append(py::cast<std::string>(py::repr(dtype)));
+ msg.append(" to element type: ");
+ msg.append(e.what());
+ throw std::invalid_argument(std::move(msg));
+ }
+ }
+
+ PackCallback AbiTypeToPackCallback(py::handle desc) {
+ return AbiTypeToPackCallback(
+ std::move(desc), /*desc_is_list=*/py::isinstance<py::list>(desc));
+ }
+
+ // Given an ABI desc, return a callback that can pack a corresponding py
+ // value into a list. For efficiency, the caller must specify whether the
+ // desc is a list (this check already needs to be done typically so
+ // passed in).
+ PackCallback AbiTypeToPackCallback(py::handle desc, bool desc_is_list) {
+ // Switch based on descriptor type.
+ if (desc_is_list) {
+ // Compound type.
+ py::object compound_type = desc[kZero];
+ if (compound_type.equal(kNdarray)) {
+ // Has format:
+ // ["ndarray", "f32", dim0, dim1, ...]
+ // Extract static information about the target.
+ std::vector<int64_t> abi_shape(py::len(desc) - 2);
+ for (size_t i = 0, e = abi_shape.size(); i < e; ++i) {
+ py::handle dim = desc[py::int_(i + 2)];
+ abi_shape[i] = dim.is_none() ? -1 : py::cast<int64_t>(dim);
+ }
+
+ // Map abi element type to dtype.
+ py::object abi_type = desc[kOne];
+ py::object target_dtype = MapElementAbiTypeToDtype(abi_type);
+ auto hal_element_type = MapDtypeToElementType(target_dtype);
+
+ return [this, target_dtype = std::move(target_dtype), hal_element_type,
+ abi_shape = std::move(abi_shape)](InvokeContext &c,
+ iree_vm_list_t *list,
+ py::handle py_value) {
+ IREE_TRACE_SCOPE0("ArgumentPacker::ReflectionNdarray");
+ HalBufferView *bv = nullptr;
+ py::object retained_bv;
+ if (py::isinstance(py_value, device_array_type())) {
+ // Short-circuit: If a DeviceArray is provided, assume it is
+ // correct.
+ IREE_TRACE_SCOPE0("PackDeviceArray");
+ bv = py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
+ } else if (py::isinstance(py_value, hal_buffer_view_type())) {
+ // Short-circuit: If a HalBufferView is provided directly.
+ IREE_TRACE_SCOPE0("PackBufferView");
+ bv = py::cast<HalBufferView *>(py_value);
+ } else {
+ // Fall back to the array protocol to generate a host side
+ // array and then convert that.
+ IREE_TRACE_SCOPE0("PackHostArray");
+ py::object host_array;
+ try {
+ host_array = numpy_module().attr(kAsArray)(py_value, target_dtype,
+ kContiguousArg);
+ } catch (std::exception &e) {
+ std::string msg("could not convert value to numpy array: dtype=");
+ msg.append(py::cast<std::string>(py::repr(target_dtype)));
+ msg.append(", error='");
+ msg.append(e.what());
+ msg.append("', value=");
+ msg.append(py::cast<std::string>(py::repr(py_value)));
+ throw std::invalid_argument(std::move(msg));
+ }
+
+ retained_bv = c.allocator().AllocateBufferCopy(
+ IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ IREE_HAL_BUFFER_USAGE_ALL, host_array, hal_element_type);
+ bv = py::cast<HalBufferView *>(retained_bv);
+ }
+
+ // TODO: Add some shape verification. Not strictly necessary as the VM
+ // will check, but may make error reporting nicer.
+ // TODO: It is theoretically possible to enqueue further conversions
+ // on the device, but for now we require things to line up closely.
+ // TODO: If adding further manipulation here, please make this common
+ // with the generic access case.
+ iree_vm_ref_t buffer_view_ref =
+ iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+ CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+ "could not push buffer view to list");
+ };
+ } else if (compound_type.equal(kSlistTag) ||
+ compound_type.equal(kStupleTag)) {
+ // Tuple/list extraction.
+ // When decoding a list or tuple, the desc object is like:
+ // ['slist', [...value_type_0...], ...]
+ // Where the type is either 'slist' or 'stuple'.
+ std::vector<PackCallback> sub_packers(py::len(desc) - 1);
+ for (size_t i = 0; i < sub_packers.size(); i++) {
+ sub_packers[i] = AbiTypeToPackCallback(desc[py::int_(i + 1)]);
+ }
+ return [sub_packers = std::move(sub_packers)](InvokeContext &c,
+ iree_vm_list_t *list,
+ py::handle py_value) {
+ if (py::len(py_value) != sub_packers.size()) {
+ std::string msg("expected a sequence with ");
+ msg.append(std::to_string(sub_packers.size()));
+ msg.append(" values. got: ");
+ msg.append(py::cast<std::string>(py::repr(py_value)));
+ throw std::invalid_argument(std::move(msg));
+ }
+ VmVariantList item_list = VmVariantList::Create(sub_packers.size());
+ for (size_t i = 0; i < sub_packers.size(); ++i) {
+ py::object item_py_value;
+ try {
+ item_py_value = py_value[py::int_(i)];
+ } catch (std::exception &e) {
+ std::string msg("could not get item ");
+ msg.append(std::to_string(i));
+ msg.append(" from: ");
+ msg.append(py::cast<std::string>(py::repr(py_value)));
+ msg.append(": ");
+ msg.append(e.what());
+ throw std::invalid_argument(std::move(msg));
+ }
+ sub_packers[i](c, item_list.raw_ptr(), item_py_value);
+ }
+
+ // Push the sub list.
+ iree_vm_ref_t retained =
+ iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+ iree_vm_list_push_ref_move(list, &retained);
+ };
+ } else if (compound_type.equal(kSdictTag)) {
+ // Dict extraction.
+ // The descriptor for an sdict is like:
+ // ['sdict', ['key1', value1], ...]
+ std::vector<std::pair<py::object, PackCallback>> sub_packers(
+ py::len(desc) - 1);
+ for (size_t i = 0; i < sub_packers.size(); i++) {
+ py::object sub_desc = desc[py::int_(i + 1)];
+ py::object key = sub_desc[kZero];
+ py::object value_desc = sub_desc[kOne];
+ sub_packers[i] =
+ std::make_pair(std::move(key), AbiTypeToPackCallback(value_desc));
+ }
+ return [sub_packers = std::move(sub_packers)](InvokeContext &c,
+ iree_vm_list_t *list,
+ py::handle py_value) {
+ if (py::len(py_value) != sub_packers.size()) {
+ std::string msg("expected a dict with ");
+ msg.append(std::to_string(sub_packers.size()));
+ msg.append(" values. got: ");
+ msg.append(py::cast<std::string>(py::repr(py_value)));
+ throw std::invalid_argument(std::move(msg));
+ }
+ VmVariantList item_list = VmVariantList::Create(sub_packers.size());
+ for (size_t i = 0; i < sub_packers.size(); ++i) {
+ py::object item_py_value;
+ try {
+ item_py_value = py_value[sub_packers[i].first];
+ } catch (std::exception &e) {
+ std::string msg("could not get item ");
+ msg.append(py::cast<std::string>(py::repr(sub_packers[i].first)));
+ msg.append(" from: ");
+ msg.append(py::cast<std::string>(py::repr(py_value)));
+ msg.append(": ");
+ msg.append(e.what());
+ throw std::invalid_argument(std::move(msg));
+ }
+ sub_packers[i].second(c, item_list.raw_ptr(), item_py_value);
+ }
+
+ // Push the sub list.
+ iree_vm_ref_t retained =
+ iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+ iree_vm_list_push_ref_move(list, &retained);
+ };
+ } else {
+ std::string message("Unrecognized reflection compound type: ");
+ message.append(py::cast<std::string>(compound_type));
+ throw std::invalid_argument(message);
+ }
+ } else {
+      // Primitive type.
+ py::str prim_type = py::cast<py::str>(desc);
+ if (prim_type.equal(kF32)) {
+ // f32
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_f32(py::cast<float>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else if (prim_type.equal(kF64)) {
+ // f64
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_f64(py::cast<double>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else if (prim_type.equal(kI32)) {
+ // i32.
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_i32(py::cast<int32_t>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else if (prim_type.equal(kI64)) {
+ // i64.
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_i64(py::cast<int64_t>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else if (prim_type.equal(kI8)) {
+ // i8.
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_i8(py::cast<int8_t>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else if (prim_type.equal(kI16)) {
+ // i16.
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+ iree_vm_value_t vm_value =
+ iree_vm_value_make_i16(py::cast<int16_t>(py_value));
+ CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+ "could not append value");
+ };
+ } else {
+ std::string message("Unrecognized reflection primitive type: ");
+ message.append(py::cast<std::string>(prim_type));
+ throw std::invalid_argument(message);
+ }
+ }
+ }
+
+  /// Looks up the generic (non-reflection) packer for a Python value.
+  ///
+  /// The lookup is keyed on the value's exact Python type object. When no
+  /// packer is registered for that type, a value exposing the attribute named
+  /// by kArrayProtocolAttr is routed to the ndarray packer. Returns an empty
+  /// PackCallback when neither applies; callers test it for truthiness.
+  PackCallback GetGenericPackCallbackFor(py::handle arg) {
+    // Fills the type table on first use; later calls return immediately.
+    PopulatePyTypeToPackCallbacks();
+
+    py::type arg_type = py::type::of(arg);
+    auto entry = py_type_to_pack_callbacks_.find(arg_type.ptr());
+    if (entry != py_type_to_pack_callbacks_.end()) {
+      return entry->second;
+    }
+
+    // Unregistered type: probe to see if we have a host array.
+    if (!py::hasattr(arg, kArrayProtocolAttr)) {
+      return {};
+    }
+    return GetGenericPackCallbackForNdarray();
+  }
+
+ private:
+  /// Builds the packer used for host arrays under dynamic dispatch.
+  ///
+  /// The returned callback converts the Python value with the numpy function
+  /// named by kAsArray (passing dtype=None and kContiguousArg), copies the
+  /// resulting array into a buffer obtained from the invoke context's
+  /// allocator and appends the buffer view to the VM list. A failure during
+  /// the numpy conversion is rethrown as std::invalid_argument carrying the
+  /// original error text and the repr of the offending value.
+  PackCallback GetGenericPackCallbackForNdarray() {
+    return [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+      IREE_TRACE_SCOPE0("ArgumentPacker::GenericNdarray");
+
+      py::object ndarray;
+      try {
+        ndarray = numpy_module().attr(kAsArray)(py_value, /*dtype=*/py::none(),
+                                                kContiguousArg);
+      } catch (const std::exception &err) {
+        std::string detail("could not convert value to numpy array: ");
+        detail += "error='";
+        detail += err.what();
+        detail += "', value=";
+        detail += py::cast<std::string>(py::repr(py_value));
+        throw std::invalid_argument(std::move(detail));
+      }
+
+      auto element_type = MapDtypeToElementType(ndarray.attr(kDtypeAttr));
+
+      // Put it on the device. The py::object keeps the buffer view alive
+      // until it has been retained for the VM list below.
+      py::object buffer_view_owner = c.allocator().AllocateBufferCopy(
+          IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+              IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+          IREE_HAL_BUFFER_USAGE_ALL, ndarray, element_type);
+      HalBufferView *buffer_view =
+          py::cast<HalBufferView *>(buffer_view_owner);
+
+      // TODO: If adding further manipulation here, please make this common
+      // with the reflection access case.
+      iree_vm_ref_t vm_ref =
+          iree_hal_buffer_view_retain_ref(buffer_view->raw_ptr());
+      CheckApiStatus(iree_vm_list_push_ref_move(list, &vm_ref),
+                     "could not append value");
+    };
+  }
+
+  /// Registers the packers used for dynamic (non-reflection) dispatch.
+  ///
+  /// Packers are keyed by exact Python type and cover int, float, list, tuple,
+  /// dict, HalBufferView and the DeviceArray type. The table is filled once;
+  /// calls made after it is non-empty return immediately.
+  void PopulatePyTypeToPackCallbacks() {
+    if (!py_type_to_pack_callbacks_.empty()) return;
+
+    // We only care about int and double in the numeric hierarchy. Since Python
+    // has no further refinement of these, just treat them as vm 64 bit int and
+    // floats and let the VM take care of it. There isn't much else we can do.
+    AddPackCallback(
+        py::type::of(py::cast(1)),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i64(py::cast<int64_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        });
+
+    AddPackCallback(
+        py::type::of(py::cast(1.0)),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_f64(py::cast<double>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        });
+
+    // List/tuple: each element is packed into a nested VM list using the
+    // generic packer for its own runtime type.
+    auto sequence_callback = [this](InvokeContext &c, iree_vm_list_t *list,
+                                    py::handle py_value) {
+      auto py_seq = py::cast<py::sequence>(py_value);
+      VmVariantList item_list = VmVariantList::Create(py::len(py_seq));
+      for (py::object py_item : py_seq) {
+        PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
+        if (!sub_packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_item)));
+          throw std::invalid_argument(std::move(message));
+        }
+        sub_packer(c, item_list.raw_ptr(), py_item);
+      }
+      // Push the sub list. The parent list receives its own retained
+      // reference while item_list keeps the one it was created with, so the
+      // pointer is not stolen from the wrapper here.
+      // NOTE(review): relies on VmVariantList releasing its reference when it
+      // goes out of scope - confirm against vm.h.
+      iree_vm_ref_t retained = iree_vm_list_retain_ref(item_list.raw_ptr());
+      CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
+                     "could not append value");
+    };
+    AddPackCallback(py::type::of(py::list{}), sequence_callback);
+    AddPackCallback(py::type::of(py::tuple{}), sequence_callback);
+
+    // Dict: values are packed into a nested VM list in sorted-key order; the
+    // keys themselves are not stored.
+    auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list,
+                                py::handle py_value) {
+      // Gets all dict items and sorts (by key).
+      auto py_dict = py::cast<py::dict>(py_value);
+      py::list py_keys;
+      for (std::pair<py::handle, py::handle> it : py_dict) {
+        py_keys.append(it.first);
+      }
+      py_keys.attr("sort")();
+
+      VmVariantList item_list = VmVariantList::Create(py_keys.size());
+      for (auto py_key : py_keys) {
+        py::object py_item = py_dict[py_key];
+        PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
+        if (!sub_packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_item)));
+          throw std::invalid_argument(std::move(message));
+        }
+        sub_packer(c, item_list.raw_ptr(), py_item);
+      }
+      // Push the sub list. Same ownership handling as the sequence packer:
+      // retain for the parent list and leave item_list's own reference with
+      // the wrapper.
+      // NOTE(review): relies on VmVariantList releasing its reference when it
+      // goes out of scope - confirm against vm.h.
+      iree_vm_ref_t retained = iree_vm_list_retain_ref(item_list.raw_ptr());
+      CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
+                     "could not append value");
+    };
+    AddPackCallback(py::type::of(py::dict{}), dict_callback);
+
+    // HalBufferView: appended directly as a retained reference.
+    AddPackCallback(
+        py::type::of<HalBufferView>(),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          HalBufferView *bv = py::cast<HalBufferView *>(py_value);
+          iree_vm_ref_t buffer_view_ref =
+              iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                         "could not append value");
+        });
+
+    // DeviceArray: unwraps the buffer view stored under the attribute named by
+    // kAttrBufferView and appends it as a retained reference.
+    AddPackCallback(
+        device_array_type(),
+        [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          HalBufferView *bv =
+              py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
+          iree_vm_ref_t buffer_view_ref =
+              iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                         "could not append value");
+        });
+  }
+
+  /// Registers `pcb` as the generic packer for the Python type `t`.
+  ///
+  /// The table is keyed on the raw type pointer, so a reference to the type is
+  /// taken here to keep that pointer valid. Registering the same type twice
+  /// trips the assertion in builds where asserts are enabled.
+  void AddPackCallback(py::handle t, PackCallback pcb) {
+    assert(py_type_to_pack_callbacks_.count(t.ptr()) == 0 && "duplicate types");
+    t.inc_ref();
+    py_type_to_pack_callbacks_.emplace(t.ptr(), std::move(pcb));
+  }
+
+  /// Creates the dict that maps ABI element type names (the kF32, kF64, kI1,
+  /// kI8, kI16, kI64 and kI32 keys) to the corresponding numpy scalar types.
+  py::dict BuildAbiTypeToDtype() {
+    auto &&np = numpy_module();
+    py::dict dtypes;
+    dtypes[kF32] = np.attr("float32");
+    dtypes[kF64] = np.attr("float64");
+    dtypes[kI1] = np.attr("bool_");
+    dtypes[kI8] = np.attr("int8");
+    dtypes[kI16] = np.attr("int16");
+    dtypes[kI64] = np.attr("int64");
+    dtypes[kI32] = np.attr("int32");
+    return dtypes;
+  }
+
+ // Cached modules and types. Those that involve recursive lookup within
+ // our top level module, we defer. Those outside, we cache at creation.
+ py::module numpy_module_ = py::module::import("numpy");
+ std::optional<py::object> runtime_module_;
+ std::optional<py::module> array_interop_module_;
+ std::optional<py::object> device_array_type_;
+ py::type hal_buffer_view_type_ = py::type::of<HalBufferView>();
+
+ // Maps Python type to a PackCallback that can generically code it.
+ // This will have inc_ref() called on them when added.
+ std::unordered_map<PyObject *, PackCallback> py_type_to_pack_callbacks_;
+
+ // Dict of str (ABI dtype like 'f32') to numpy dtype.
+ py::dict abi_type_to_dtype_ = BuildAbiTypeToDtype();
+};
+
+/// Object that can pack Python arguments into a VM List for a specific
+/// function.
+///
+/// Two modes are supported, selected at construction time:
+///   * Reflection dispatch: one ABI descriptor is supplied per argument. A
+///     descriptor of the form ["named", "kwarg_name", sub_desc] declares a
+///     keyword argument; every other descriptor is positional.
+///   * Dynamic dispatch: no descriptors are available and each positional
+///     argument is packed according to its runtime Python type.
+class ArgumentPacker {
+ public:
+  /// Builds the per-argument packers.
+  ///
+  /// `statics` is stored by reference and must outlive this object. Passing
+  /// std::nullopt for `arg_descs` selects dynamic dispatch.
+  ArgumentPacker(InvokeStatics &statics, std::optional<py::list> arg_descs)
+      : statics_(statics) {
+    IREE_TRACE_SCOPE0("ArgumentPacker::Init");
+    if (!arg_descs) {
+      dynamic_dispatch_ = true;
+      return;
+    }
+
+    // Reflection dispatch.
+    // Pack() assigns positional values to the leading slots, so positional
+    // descriptors are expected to come before named ones.
+    // NOTE(review): that ordering is not validated here.
+    for (py::handle desc : *arg_descs) {
+      const size_t arg_index = flat_arg_packers_.size();
+      std::optional<std::string> kwarg_name;
+      // Keeps the nested descriptor alive while `desc` borrows it.
+      py::object retained_sub_desc;
+
+      bool desc_is_list = py::isinstance<py::list>(desc);
+
+      // Check if named.
+      //   ["named", "kwarg_name", sub_desc]
+      // If found, then we set kwarg_name and reset desc to the sub_desc.
+      if (desc_is_list) {
+        py::object maybe_named_field = desc[statics.kZero];
+        if (maybe_named_field.equal(statics.kNamedTag)) {
+          py::object name_field = desc[statics.kOne];
+          retained_sub_desc = desc[statics.kTwo];
+          kwarg_name = py::cast<std::string>(name_field);
+          desc = retained_sub_desc;
+          desc_is_list = py::isinstance<py::list>(desc);
+
+          kwarg_to_index_[name_field] = arg_index;
+        }
+      }
+
+      if (!kwarg_name) {
+        pos_only_arg_count_ += 1;
+      }
+
+      flat_arg_packers_.push_back(
+          statics.AbiTypeToPackCallback(desc, desc_is_list));
+    }
+  }
+
+  /// Packs positional/kw arguments into a suitable VmVariantList and returns
+  /// it.
+  ///
+  /// Throws std::invalid_argument when:
+  ///   * kwargs are given to a dynamic dispatch function,
+  ///   * a dynamically dispatched value has no generic packer,
+  ///   * more positional values are given than positional descriptors exist,
+  ///   * a keyword is unknown or targets a slot that is already filled, or
+  ///   * any argument slot is left without a value.
+  /// Exceptions raised by the individual packers propagate unchanged.
+  VmVariantList Pack(InvokeContext &invoke_context, py::sequence pos_args,
+                     py::dict kw_args) {
+    // Dynamic dispatch.
+    if (dynamic_dispatch_) {
+      IREE_TRACE_SCOPE0("ArgumentPacker::PackDynamic");
+      if (!kw_args.empty()) {
+        throw std::invalid_argument(
+            "kwargs not supported for dynamic dispatch functions");
+      }
+
+      VmVariantList arg_list = VmVariantList::Create(pos_args.size());
+      for (py::handle py_arg : pos_args) {
+        PackCallback packer = statics_.GetGenericPackCallbackFor(py_arg);
+        if (!packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_arg)));
+          throw std::invalid_argument(std::move(message));
+        }
+        // TODO: Better error handling by catching the exception and
+        // reporting which arg has a problem.
+        packer(invoke_context, arg_list.raw_ptr(), py_arg);
+      }
+      return arg_list;
+    }
+
+    IREE_TRACE_SCOPE0("ArgumentPacker::PackReflection");
+
+    // Reflection based dispatch. One slot per declared argument; a slot that
+    // still holds a null handle has not been supplied yet.
+    std::vector<py::handle> py_args(flat_arg_packers_.size());
+
+    if (pos_args.size() > pos_only_arg_count_) {
+      std::string message("mismatched call arity: expected ");
+      message.append(std::to_string(pos_only_arg_count_));
+      message.append(" got ");
+      message.append(std::to_string(pos_args.size()));
+      throw std::invalid_argument(std::move(message));
+    }
+
+    // Positional args fill the leading slots in order.
+    size_t pos_index = 0;
+    for (py::handle py_arg : pos_args) {
+      py_args[pos_index++] = py_arg;
+    }
+
+    // Keyword args are routed to the slot recorded at construction time.
+    for (auto it : kw_args) {
+      size_t found_index = 0;
+      try {
+        found_index = py::cast<size_t>(kwarg_to_index_[it.first]);
+      } catch (const std::exception &) {
+        std::string message("specified kwarg '");
+        message.append(py::cast<py::str>(it.first));
+        message.append("' is unknown");
+        throw std::invalid_argument(std::move(message));
+      }
+      if (py_args[found_index]) {
+        std::string message(
+            "mismatched call arity: duplicate keyword argument '");
+        message.append(py::cast<py::str>(it.first));
+        message.append("'");
+        throw std::invalid_argument(std::move(message));
+      }
+      py_args[found_index] = it.second;
+    }
+
+    // Now check to see that all args are set.
+    for (size_t i = 0; i < py_args.size(); ++i) {
+      if (!py_args[i]) {
+        std::string message(
+            "mismatched call arity: expected a value for argument ");
+        message.append(std::to_string(i));
+        throw std::invalid_argument(std::move(message));
+      }
+    }
+
+    // Start packing into the list.
+    VmVariantList arg_list = VmVariantList::Create(flat_arg_packers_.size());
+    for (size_t i = 0; i < py_args.size(); ++i) {
+      // TODO: Better error handling by catching the exception and
+      // reporting which arg has a problem.
+      flat_arg_packers_[i](invoke_context, arg_list.raw_ptr(), py_args[i]);
+    }
+    return arg_list;
+  }
+
+ private:
+  // Shared statics; held by reference, not owned.
+  InvokeStatics &statics_;
+
+  // Number of descriptors that are not wrapped in a "named" descriptor.
+  size_t pos_only_arg_count_ = 0;
+
+  // Dictionary of py::str -> py::int_ mapping kwarg names to position in
+  // the argument list. We store this as a py::dict because it is optimized
+  // for py::str lookup.
+  py::dict kwarg_to_index_;
+
+  // One packer per declared argument, in declaration order. Empty when
+  // dispatching dynamically.
+  std::vector<PackCallback> flat_arg_packers_;
+
+  // If true, then there is no dispatch metadata and we process fully
+  // dynamically.
+  bool dynamic_dispatch_ = false;
+};
+
+} // namespace
+
+/// Registers the invoke helper classes (_InvokeStatics, InvokeContext and
+/// ArgumentPacker) on module `m` and publishes one InvokeStatics instance as
+/// the module attribute `_invoke_statics`.
+void SetupInvokeBindings(pybind11::module &m) {
+  py::class_<InvokeStatics>(m, "_InvokeStatics");
+  py::class_<InvokeContext>(m, "InvokeContext").def(py::init<HalDevice &>());
+  py::class_<ArgumentPacker>(m, "ArgumentPacker")
+      // ArgumentPacker stores a C++ reference to the InvokeStatics it is
+      // constructed with, so keep that argument alive for as long as the
+      // packer itself exists.
+      .def(py::init<InvokeStatics &, std::optional<py::list>>(),
+           py::keep_alive<1, 2>())
+      .def("pack", &ArgumentPacker::Pack);
+
+  m.attr("_invoke_statics") = py::cast(InvokeStatics());
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/iree/runtime/invoke.h b/bindings/python/iree/runtime/invoke.h
new file mode 100644
index 0000000..65e86b7
--- /dev/null
+++ b/bindings/python/iree/runtime/invoke.h
@@ -0,0 +1,20 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
+
+#include "bindings/python/iree/runtime/binding.h"
+
+namespace iree {
+namespace python {
+
+void SetupInvokeBindings(pybind11::module &m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index 0e35083..68b7522 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -9,6 +9,7 @@
#include "bindings/python/iree/runtime/status_utils.h"
#include "iree/base/api.h"
#include "iree/base/status_cc.h"
+#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
@@ -63,6 +64,7 @@
//------------------------------------------------------------------------------
VmInstance VmInstance::Create() {
+ IREE_TRACE_SCOPE0("VmInstance::Create");
iree_vm_instance_t* instance;
auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
CheckApiStatus(status, "Error creating instance");
@@ -75,6 +77,7 @@
VmContext VmContext::Create(VmInstance* instance,
std::optional<std::vector<VmModule*>> modules) {
+ IREE_TRACE_SCOPE0("VmContext::Create");
iree_vm_context_t* context;
if (!modules) {
// Simple create with open allowed modules.
@@ -112,6 +115,7 @@
void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
VmVariantList& outputs) {
+ py::gil_scoped_release release;
CheckApiStatus(iree_vm_invoke(raw_ptr(), f, IREE_VM_INVOCATION_FLAG_NONE,
nullptr, inputs.raw_ptr(), outputs.raw_ptr(),
iree_allocator_system()),
@@ -123,6 +127,7 @@
//------------------------------------------------------------------------------
VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+ IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
auto buffer_info = flatbuffer_blob.request();
iree_vm_module_t* module;
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index e9c3db7..aa6c62a 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -83,7 +83,11 @@
iree_vm_list_t* raw_ptr() { return list_; }
const iree_vm_list_t* raw_ptr() const { return list_; }
-
+  // Hands the raw list pointer to the caller and leaves this wrapper holding
+  // nullptr, so the wrapper no longer refers to the list afterwards.
+  iree_vm_list_t* steal_raw_ptr() {
+    iree_vm_list_t* const released = list_;
+    list_ = nullptr;
+    return released;
+  }
void AppendNullRef() {
iree_vm_ref_t null_ref = {0};
CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref),