Modernize and relocate iree/runtime Python package to runtime/bindings/python. (#8912)

* Modernize and relocate iree/runtime Python package to iree/runtime/python.

Non-functional changes:

* Moves `bindings/python/iree/runtime` -> `iree/runtime/python/iree/runtime`.
* Fixes dash vs underscore inconsistency in compile install path. Now both use underscores (python_packages/iree_compiler and python_packages/iree_runtime).
* Moves build directory for iree/compiler/python to iree/compiler/python (was outputting to bindings/python). Updates locations that were hard-coded.

Functional changes:

* Removes the old build-dir only setup.py in favor of an iree/runtime/setup.py that works from either the source or build dir.
* Reworks the releases to use the new setup.py as-is vs scripting the build manually.
* iree.runtime.version is now generated in the same way as iree.compiler.version.
* Users can now run iree/runtime/setup.py with pip themselves to generate a wheel (i.e. `pip wheel iree/runtime`).

It is now possible to integrate Python package testing into the presubmit and have build jobs that generate Python installable binaries for subsequent steps.

The only file left in bindings/python is build_requirements.txt. It is referred to by some docs and CI jobs so leaving as-is for the moment (will find it a new home in a followup).
diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt new file mode 100644 index 0000000..8a3ce7a --- /dev/null +++ b/runtime/CMakeLists.txt
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Everything in this file is only relevant when the Python bindings are being
# built; otherwise nothing is configured here.
if(IREE_BUILD_PYTHON_BINDINGS)
  # Copy Python packaging files to the build dir so that we can install from
  # there.
  # pyproject.toml is copied verbatim (COPYONLY); setup.py is configured with
  # @ONLY so that only @VAR@ style references are substituted.
  configure_file(pyproject.toml pyproject.toml COPYONLY)
  configure_file(setup.py setup.py @ONLY)

  # The native extension, Python sources, tests and install rules for the
  # iree.runtime package live in this subdirectory.
  add_subdirectory(bindings/python/iree/runtime)
endif()
diff --git a/runtime/README.md b/runtime/README.md new file mode 100644 index 0000000..e395753 --- /dev/null +++ b/runtime/README.md
# IREE runtime

Note that this directory is in a transitional state. The C code still lives
in directories under `iree/` and will be relocated here in the future.

## Language Bindings

### Python

The included `setup.py` file can be used to build Python binaries or directly
install the IREE runtime API. Do note that the runtime is quite heavy and
unless you are developing it and have a powerful machine, you will want to
use released binaries.

There are two ways to build/install Python packages:

* Directly from the source tree (this is how official releases are done).
* From the build directory while developing.

It is recommended to use your favorite method for managing
[virtual environments](https://docs.python.org/3/library/venv.html) instead
of modifying the system installation.

Only relatively recent versions of `pip` are supported. Always use the latest
via `pip install --upgrade pip`.

You can build either from the source or build tree (assumes that CMake has
been configured and the project built). The latter is typically used by
project developers who are already set up for development and want to
incrementally generate Python packages without rebuilding.

To build a wheel that can be installed on the same Python version and OS:

```
python -m pip wheel runtime/
```

To directly install:

```
python -m pip install runtime/
```
diff --git a/runtime/bindings/python/iree/runtime/CMakeLists.txt b/runtime/bindings/python/iree/runtime/CMakeLists.txt new file mode 100644 index 0000000..e93f41f --- /dev/null +++ b/runtime/bindings/python/iree/runtime/CMakeLists.txt
# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(NUMPY_DEPS "")
set(PYBIND_COPTS "-fexceptions")
set(PYBIND_EXTENSION_COPTS "-fvisibility=hidden")

# Optional additions that are only populated when the Tracy capture server
# target exists in this build.
set(_python_extra_srcs)
set(_extra_install_tool_targets)
set(_tracy_enabled OFF)

if(TARGET IREETracyCaptureServer)
  message(STATUS "Bundling Tracy CLI tools with Python API")
  set(_tracy_enabled ON)
  list(APPEND _python_extra_srcs "scripts/iree-tracy-capture")
  list(APPEND _extra_install_tool_targets "IREETracyCaptureServer")
endif()

################################################################################
# Package
################################################################################

# Native extension module, importable as `iree.runtime.binding`.
iree_pyext_module(
  NAME
    PyExtRt
  MODULE_NAME binding
  SRCS
    "binding.h"
    "initialize_module.cc"
    "invoke.h"
    "invoke.cc"
    "hal.h"
    "hal.cc"
    "status_utils.cc"
    "status_utils.h"
    "vm.h"
    "vm.cc"
  UNIX_LINKER_SCRIPT
    "unix_version.lds"
  DEFINES
    # Pybind code seems to be incompatible with C++ allocation tracing
    # hooks so disable it.
    IREE_TRACING_HOOK_CPP_NEW_DELETE=0
  DEPS
    iree::base
    iree::base::cc
    iree::base::internal::flags
    iree::base::tracing
    iree::hal
    iree::hal::drivers
    iree::modules::hal
    iree::vm
    iree::vm::bytecode_module
)

# Pure-Python sources of the package plus the launcher scripts for the
# bundled command line tools.
iree_py_library(
  NAME
    runtime
  SRCS
    "__init__.py"
    "array_interop.py"
    "flags.py"
    "function.py"
    "system_api.py"
    "tracing.py"
    "scripts/iree_benchmark_trace/__main__.py"
    "scripts/iree_run_trace/__main__.py"
    "scripts/iree_run_module/__main__.py"
    ${_python_extra_srcs}
  PYEXT_DEPS
    ::PyExtRt
)

# Make the native tools available next to the Python package under their
# user-facing executable names.
iree_symlink_tool(
  TARGET runtime
  FROM_TOOL_TARGET iree_tools_iree-benchmark-trace
  TO_EXE_NAME iree-benchmark-trace
)

iree_symlink_tool(
  TARGET runtime
  FROM_TOOL_TARGET iree_tools_iree-run-trace
  TO_EXE_NAME iree-run-trace
)

iree_symlink_tool(
  TARGET runtime
  FROM_TOOL_TARGET iree_tools_iree-run-module
  TO_EXE_NAME iree-run-module
)

if(_tracy_enabled)
  iree_symlink_tool(
    TARGET runtime
    FROM_TOOL_TARGET IREETracyCaptureServer
    TO_EXE_NAME iree-tracy-capture
  )
endif()

################################################################################
# Tests
################################################################################

iree_py_test(
  NAME
    array_interop_test
  SRCS
    "array_interop_test.py"
)

iree_py_test(
  NAME
    flags_test
  SRCS
    "flags_test.py"
)

iree_py_test(
  NAME
    function_test
  SRCS
    "function_test.py"
)

iree_py_test(
  NAME
    hal_test
  SRCS
    "hal_test.py"
)

iree_py_test(
  NAME
    system_api_test
  SRCS
    "system_api_test.py"
)

iree_py_test(
  NAME
    vm_test
  SRCS
    "vm_test.py"
)

################################################################################
# Install
################################################################################

iree_py_install_package(
  COMPONENT IreePythonPackage-runtime
  PACKAGE_NAME iree_runtime
  MODULE_PATH iree/runtime
  DEPS
    # TODO: Update CMake target/path mangling rules to make this syntactically
    # rooted on "iree" in some way.
    runtime_bindings_python_iree_runtime_PyExtRt
    iree_tools_iree-benchmark-trace
    iree_tools_iree-run-module
    iree_tools_iree-run-trace
    ${_extra_install_tool_targets}
  ADDL_PACKAGE_FILES
    ${CMAKE_CURRENT_SOURCE_DIR}/README.md
)

install(
  # TODO: Update CMake target/path mangling rules to make this syntactically
  # rooted on "iree" in some way.
  TARGETS runtime_bindings_python_iree_runtime_PyExtRt
  COMPONENT ${PY_INSTALL_COMPONENT}
  DESTINATION "${PY_INSTALL_MODULE_DIR}"
)

# The tool binaries are installed into the Python module directory as well.
install(
  TARGETS
    iree_tools_iree-benchmark-trace
    iree_tools_iree-run-module
    iree_tools_iree-run-trace
    ${_extra_install_tool_targets}
  DESTINATION "${PY_INSTALL_MODULE_DIR}"
  COMPONENT "${PY_INSTALL_COMPONENT}"
)
diff --git a/runtime/bindings/python/iree/runtime/README.md b/runtime/bindings/python/iree/runtime/README.md new file mode 100644 index 0000000..cd21ffe --- /dev/null +++ b/runtime/bindings/python/iree/runtime/README.md
@@ -0,0 +1,23 @@ +# IREE Python Runtime Components + +This package provides an API for running compiled IREE binaries and interfacing with the hardware-abstraction-layer. + +## Tracing + +Execution of calls against binaries can be traced for later replay (i.e. via +tools like `iree-run-module`). This can be set up either explicitly or +via environment variables. + +To trace via environment variable, set `IREE_SAVE_CALLS` to a directory to dump +traces into. Each created `SystemContext` will result in one `calls.yaml` +file (with an index appended to the stem for multiples). Any referenced +module binaries will be dumped into the same directory and referenced by the +YAML file. + +### Explicit API + +```python +tracer = iree.runtime.Tracer(some_dir) +config = iree.runtime.Config(driver, tracer) +... +```
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py new file mode 100644 index 0000000..e0d8643 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -0,0 +1,47 @@ +"""Module init for the python bindings.""" + +# Copyright 2019 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# pylint: disable=g-multiple-import +# pylint: disable=g-bad-import-order +# pylint: disable=wildcard-import + +from . import binding + +# Pull some of the native symbols into the public API. +# Hal imports +from .binding import ( + BufferCompatibility, + BufferUsage, + HalAllocator, + HalBuffer, + HalBufferView, + HalDevice, + HalDriver, + HalElementType, + MemoryAccess, + MemoryType, + Shape, +) + +# Vm imports +from .binding import ( + create_hal_module, + Linkage, + VmVariantList, + VmFunction, + VmInstance, + VmContext, + VmModule, +) + +from .array_interop import * +from .system_api import * +from .function import * +from .tracing import * + +from . import flags
diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py new file mode 100644 index 0000000..d761274 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/array_interop.py
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""BufferView and Python Array Protocol interop."""

from typing import Optional, Tuple
import logging
import numpy as np
import numpy.lib.mixins

from .binding import (
    BufferUsage,
    HalBufferView,
    HalDevice,
    HalElementType,
    MappedMemory,
    MemoryType,
)

__all__ = [
    "asdevicearray",
    "DeviceArray",
]

# Maps a numpy function to the DeviceArray-specific implementation that
# DeviceArray.__array_function__ dispatches to instead of transferring to host.
_DEVICE_HANDLED_FUNCTIONS = {}


def _device_implements(np_function):
  """Decorator that registers a base class implementation.

  The decorated function is stored in `_DEVICE_HANDLED_FUNCTIONS` under
  `np_function` and returned unchanged.
  """

  def decorator(func):
    _DEVICE_HANDLED_FUNCTIONS[np_function] = func
    return func

  return decorator


class DeviceArray(numpy.lib.mixins.NDArrayOperatorsMixin):
  """An IREE device array.

  Device arrays can be in one of two states:
    1. Host accessible: The array will be backed by host accessible memory
       and can have the usual things done with it that one expects to be
       able to do with an ndarray.
    2. Device resident: The array is just a handle to a device resident
       Buffer (and BufferView wrapper). Metadata about the array are accessible
       (shape and dtype) but anything that touches the data cannot be accessed
       in this state.

  How a device array comes into existence controls how it can transition
  between these states:
    * A user can create a DeviceArray explicitly with a device allocator.
      Such an array will not be implicitly convertible to host accessible,
      although accessors exist to do so.
    * When created by the platform with a synchronization policy, then
      implicit transfer back to the host will trigger appropriate waits and
      be performed automatically (this is the common case for function return
      values if not otherwise configured, as an example).
  """

  def __init__(self,
               device: HalDevice,
               buffer_view: HalBufferView,
               implicit_host_transfer: bool = False,
               override_dtype=None):
    """Wraps `buffer_view`.

    Args:
      device: The device associated with the buffer view.
      buffer_view: The buffer view holding the array contents.
      implicit_host_transfer: Whether operations that implicitly need the
        data on the host (such as `np.asarray`) are allowed to map it.
      override_dtype: If not None, the dtype this array reports and converts
        to on the host, regardless of the buffer view's element type.
    """
    self._device = device
    self._buffer_view = buffer_view
    self._implicit_host_transfer = implicit_host_transfer
    self._override_dtype = override_dtype

    # If the array is host accessible, these will be non-None.
    self._mapped_memory: Optional[MappedMemory] = None
    self._host_array: Optional[np.ndarray] = None

  def __array__(self, dtype=None):
    # This is the implicit conversion path: it raises ValueError unless
    # implicit host transfer was enabled (or the array is already mapped).
    self._transfer_to_host(True)
    if dtype is None:
      return self._host_array
    else:
      return self._host_array.__array__(dtype)  # pytype: disable=attribute-error

  def __array_function__(self, func, types, args, kwargs):
    if func in _DEVICE_HANDLED_FUNCTIONS:
      return _DEVICE_HANDLED_FUNCTIONS[func](*args, **kwargs)

    # Anything else forces a transfer to host and then delegates to the
    # host array.
    host_array = self.to_host()
    return host_array.__array_function__(func, types, args, kwargs)  # pytype: disable=attribute-error

  def __repr__(self):
    return f"<IREE DeviceArray: shape={np.shape(self)}, dtype={self.dtype}>"

  @property
  def is_host_accessible(self):
    """Whether this array is currently host accessible."""
    return self._host_array is not None

  def to_host(self) -> np.ndarray:
    """Explicitly transfers to the host (if needed) and returns the ndarray."""
    self._transfer_to_host(False)
    return self._host_array

  def _transfer_to_host(self, implicit):
    """Maps the array to the host once and caches the result.

    Raises:
      ValueError: If `implicit` is set but this array was not created with
        implicit host transfer enabled.
    """
    if self._host_array is not None:
      return
    if implicit and not self._implicit_host_transfer:
      raise ValueError(
          "DeviceArray cannot be implicitly transferred to the host: "
          "if necessary, do an explicit transfer via .to_host()")
    self._mapped_memory, self._host_array = self._map_to_host()

  def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
    """Maps the buffer view and returns (mapped memory, host ndarray)."""
    # TODO: When synchronization is enabled, need to block here.
    raw_dtype = self._get_raw_dtype()
    mapped_memory = self._buffer_view.map()
    host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
    # Detect if we need to force an explicit conversion. This happens when
    # we were requested to pretend that the array is in a specific dtype,
    # even if that is not representable on the device. You guessed it:
    # this is to support bools.
    if self._override_dtype is not None and self._override_dtype != raw_dtype:
      host_array = host_array.astype(self._override_dtype)
    return mapped_memory, host_array

  def _get_raw_dtype(self):
    """Returns the dtype corresponding to the buffer view's element type."""
    return HalElementType.map_to_dtype(self._buffer_view.element_type)

  @property
  def dtype(self):
    """The override dtype if one was given, else the buffer view's dtype."""
    # Compare against None explicitly, matching _map_to_host: dtype objects
    # define __len__, so truthiness is not a safe stand-in for "was set".
    if self._override_dtype is not None:
      return self._override_dtype
    return self._get_raw_dtype()

  @property
  def shape(self):
    return np.shape(self)

  def astype(self, dtype, casting="unsafe", copy=True):
    """Returns this array converted to `dtype`.

    Returns `self` when the dtype already matches and `copy` is False;
    otherwise transfers to the host and returns a host ndarray.
    """
    if self.dtype == dtype and not copy:
      return self
    host_ary = self.to_host()
    return host_ary.astype(dtype, casting=casting, copy=copy)

  def reshape(self, *args):
    # TODO(scotttodd): add a native impl with a new buffer_view of the same data
    # TODO(scotttodd): return DeviceArray instead of host ndarray?
    host_ary = self.to_host()
    return host_ary.reshape(*args)

  def __iter__(self):
    host_ary = self.to_host()
    return host_ary.__iter__()

  def __getitem__(self, index):
    host_ary = self.to_host()
    return host_ary.__getitem__(index)

  def __reduce__(self):
    # Since this is used for making deep copies and pickling, we map
    # separately from any interactive state. We just reduce to the actual
    # host ndarray, which supports the necessary serialization protocols.
    _, host_array = self._map_to_host()
    return _restore_reduced_array, (host_array,)


def _restore_reduced_array(ary):
  """Reconstruction hook for DeviceArray.__reduce__: returns `ary` as is."""
  return ary


# Function implementations with custom behavior.
@_device_implements(np.shape)
def _(arr: DeviceArray):
  return arr._buffer_view.shape


@_device_implements(np.reshape)
def _(arr: DeviceArray, *args):
  return arr.reshape(*args)


def asdevicearray(device: HalDevice,
                  a,
                  dtype=None,
                  *,
                  implicit_host_transfer: bool = False,
                  memory_type=MemoryType.DEVICE_LOCAL,
                  allowed_usage=(BufferUsage.DISPATCH | BufferUsage.TRANSFER |
                                 BufferUsage.MAPPING),
                  element_type: Optional[HalElementType] = None) -> DeviceArray:
  """Helper to create a DeviceArray from an arbitrary array like.

  This is similar in purpose and usage to np.asarray, except that it takes
  a device as the first argument. This may not be the best mechanism for
  getting a DeviceArray, depending on your use case, but it is reliable
  and simple. This function may make a defensive copy or cause implicit
  transfers to satisfy the request. If this is important to you, then a lower
  level API is likely more appropriate.

  Note that additional flags `memory_type`, `allowed_usage` and `element_type`
  are only hints if creating a new DeviceArray. If `a` is already a DeviceArray,
  they are ignored.

  Args:
    device: Device whose allocator receives the copy of the data.
    a: A DeviceArray or anything accepted by `np.asarray`.
    dtype: Optional dtype to convert to.
    implicit_host_transfer: Forwarded to the created DeviceArray.
    memory_type: Memory type passed to the allocator.
    allowed_usage: Buffer usage passed to the allocator.
    element_type: See the note in the body; the value is currently derived
      from the array's dtype.

  Returns:
    `a` itself if it is already a DeviceArray and no `dtype` was requested,
    otherwise a new DeviceArray.

  Raises:
    ValueError: If the array's dtype has no corresponding IREE element type.
  """
  if isinstance(a, DeviceArray):
    if dtype is None:
      return a
    # Need to do a conversion, which we currently do not support on the
    # device, so transfer back to the host.
    # logging.warn is a deprecated alias; logging.warning is the supported
    # spelling.
    logging.warning(
        "Implicit dtype conversion of a DeviceArray forces a host transfer")
  # First get an ndarray.
  a = np.asarray(a, dtype=dtype)
  # NOTE(review): the caller supplied `element_type` is overwritten here and
  # therefore has no effect - confirm whether it should be honored when given.
  element_type = map_dtype_to_element_type(a.dtype)
  if element_type is None:
    raise ValueError(f"Could not map dtype {a.dtype} to IREE element type")
  buffer_view = device.allocator.allocate_buffer_copy(
      memory_type=memory_type,
      allowed_usage=allowed_usage,
      buffer=a,
      element_type=element_type)
  return DeviceArray(device,
                     buffer_view,
                     implicit_host_transfer=implicit_host_transfer,
                     override_dtype=a.dtype)


# NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should
# be queried via isinstance checks. This should be done as a fallback but
# this is a linear list for quick access to the most common. There may also
# be a better way to do this.
_DTYPE_TO_HAL_ELEMENT_TYPE = (
    (np.float32, HalElementType.FLOAT_32),
    (np.float64, HalElementType.FLOAT_64),
    (np.float16, HalElementType.FLOAT_16),
    (np.int32, HalElementType.SINT_32),
    (np.int64, HalElementType.SINT_64),
    (np.int16, HalElementType.SINT_16),
    (np.int8, HalElementType.SINT_8),
    (np.uint32, HalElementType.UINT_32),
    (np.uint64, HalElementType.UINT_64),
    (np.uint16, HalElementType.UINT_16),
    (np.uint8, HalElementType.UINT_8),
    (np.bool_, HalElementType.BOOL_8),
)


def map_dtype_to_element_type(dtype) -> Optional[HalElementType]:
  """Returns the element type listed for `dtype`, or None if there is none."""
  for match_dtype, element_type in _DTYPE_TO_HAL_ELEMENT_TYPE:
    if match_dtype == dtype:
      return element_type
  # No entry matched.
  return None
diff --git a/runtime/bindings/python/iree/runtime/array_interop_test.py b/runtime/bindings/python/iree/runtime/array_interop_test.py new file mode 100644 index 0000000..032f296 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/array_interop_test.py
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
import numpy as np
import unittest

import iree.runtime


class DeviceHalTest(unittest.TestCase):
  """Tests DeviceArray / asdevicearray against a "vmvx" device."""

  def setUp(self):
    super().setUp()
    self.driver = iree.runtime.HalDriver.create("vmvx")
    self.device = self.driver.create_default_device()
    self.allocator = self.device.allocator

  def testMetadataAttributes(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    ary = iree.runtime.asdevicearray(self.device, init_ary)
    self.assertEqual([3, 4], ary.shape)
    self.assertEqual(np.int32, ary.dtype)

  def testExplicitHostTransfer(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    ary = iree.runtime.asdevicearray(self.device, init_ary)
    self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=int32>")
    self.assertFalse(ary.is_host_accessible)

    # Explicit transfer.
    cp = ary.to_host()
    np.testing.assert_array_equal(cp, init_ary)
    self.assertTrue(ary.is_host_accessible)

  def testOverrideDtype(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    buffer_view = self.allocator.allocate_buffer_copy(
        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
        buffer=init_ary,
        element_type=iree.runtime.HalElementType.SINT_32)

    ary = iree.runtime.DeviceArray(self.device,
                                   buffer_view,
                                   override_dtype=np.float32)

    # Explicit transfer.
    cp = ary.to_host()
    self.assertEqual(cp.dtype, np.float32)
    np.testing.assert_array_equal(cp, init_ary.astype(np.float32))
    self.assertTrue(ary.is_host_accessible)

  def testIllegalImplicitHostTransfer(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    ary = iree.runtime.asdevicearray(self.device, init_ary)
    # Implicit transfer.
    with self.assertRaises(ValueError):
      _ = np.asarray(ary)

  def testImplicitHostArithmetic(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    ary = iree.runtime.asdevicearray(self.device,
                                     init_ary,
                                     implicit_host_transfer=True)
    # Named `total` rather than `sum` to avoid shadowing the builtin.
    total = ary + init_ary
    np.testing.assert_array_equal(total, init_ary + 2)
    self.assertTrue(ary.is_host_accessible)

  def testArrayFunctions(self):
    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
    ary = iree.runtime.asdevicearray(self.device,
                                     init_ary,
                                     implicit_host_transfer=True)
    f = np.isfinite(ary)
    self.assertTrue(f.all())

  def testIteration(self):
    init_ary = np.array([0, 1, 2, 3, 4, 5])
    ary = iree.runtime.asdevicearray(self.device,
                                     init_ary,
                                     implicit_host_transfer=True)

    for index, value in enumerate(ary):
      self.assertEqual(index, value)

  def testSubscriptable(self):
    init_ary = np.array([0, 1, 2, 3, 4, 5])
    ary = iree.runtime.asdevicearray(self.device,
                                     init_ary,
                                     implicit_host_transfer=True)

    for index in range(0, 6):
      value = ary[index]
      self.assertEqual(index, value)

  def testReshape(self):
    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
    ary = iree.runtime.asdevicearray(self.device,
                                     init_ary,
                                     implicit_host_transfer=True)
    reshaped = ary.reshape((4, 3))
    self.assertEqual((4, 3), reshaped.shape)

    np_reshaped = np.reshape(ary, (2, 2, 3))
    self.assertEqual((2, 2, 3), np_reshaped.shape)

  def testDeepcopy(self):
    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
    orig_ary = iree.runtime.asdevicearray(self.device,
                                          init_ary,
                                          implicit_host_transfer=True)
    copy_ary = copy.deepcopy(orig_ary)
    self.assertIsNot(orig_ary, copy_ary)
    np.testing.assert_array_equal(orig_ary, copy_ary)

  def testAsType(self):
    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
    orig_ary = iree.runtime.asdevicearray(self.device,
                                          init_ary,
                                          implicit_host_transfer=True)
    # Same dtype, no copy.
    i32_nocopy = orig_ary.astype(np.int32, copy=False)
    self.assertIs(orig_ary, i32_nocopy)

    # Same dtype, copy.
    i32_copy = orig_ary.astype(np.int32)
    self.assertIsNot(orig_ary, i32_copy)
    np.testing.assert_array_equal(orig_ary, i32_copy)

    # Different dtype, copy.
    f32_copy = orig_ary.astype(np.float32)
    self.assertIsNot(orig_ary, f32_copy)
    self.assertEqual(f32_copy.dtype, np.float32)
    np.testing.assert_array_equal(orig_ary.astype(np.float32), f32_copy)

  def testBool(self):
    init_ary = np.zeros([3, 4], dtype=np.bool_)
    init_ary[1] = True  # Set some non-zero value.
    ary = iree.runtime.asdevicearray(self.device, init_ary)
    self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=bool>")
    np.testing.assert_array_equal(ary.to_host(), init_ary)


if __name__ == "__main__":
  unittest.main()
diff --git a/runtime/bindings/python/iree/runtime/binding.h b/runtime/bindings/python/iree/runtime/binding.h new file mode 100644 index 0000000..64240d4 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/binding.h
@@ -0,0 +1,99 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BINDINGS_PYTHON_IREE_BINDING_H_ +#define IREE_BINDINGS_PYTHON_IREE_BINDING_H_ + +#include <optional> +#include <vector> + +#include "iree/base/api.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace iree { +namespace python { + +namespace py = pybind11; + +template <typename T> +struct ApiPtrAdapter {}; + +template <typename Self, typename T> +class ApiRefCounted { + public: + ApiRefCounted() : instance_(nullptr) {} + ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); } + ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) { + other.instance_ = nullptr; + } + ApiRefCounted& operator=(ApiRefCounted&& other) { + instance_ = other.instance_; + other.instance_ = nullptr; + return *this; + } + void operator=(const ApiRefCounted&) = delete; + + ~ApiRefCounted() { Release(); } + + // Steals the reference to the object referenced by the given raw pointer and + // returns a wrapper (transfers ownership). + static Self StealFromRawPtr(T* retained_inst) { + auto self = Self(); + self.instance_ = retained_inst; + return self; + } + + // Retains the object referenced by the given raw pointer and returns + // a wrapper. + static Self BorrowFromRawPtr(T* non_retained_inst) { + auto self = Self(); + self.instance_ = non_retained_inst; + if (non_retained_inst) { + ApiPtrAdapter<T>::Retain(non_retained_inst); + } + return self; + } + + // Whether it is nullptr. 
+ operator bool() const { return instance_; } + + T* steal_raw_ptr() { + T* ret = instance_; + instance_ = nullptr; + return ret; + } + + T* raw_ptr() { + if (!instance_) { + throw std::invalid_argument("API object is null"); + } + return instance_; + } + + const T* raw_ptr() const { + return const_cast<ApiRefCounted*>(this)->raw_ptr(); + } + + void Retain() { + if (instance_) { + ApiPtrAdapter<T>::Retain(instance_); + } + } + void Release() { + if (instance_) { + ApiPtrAdapter<T>::Release(instance_); + } + } + + private: + T* instance_; +}; + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_IREE_BINDING_H_
diff --git a/runtime/bindings/python/iree/runtime/export.def b/runtime/bindings/python/iree/runtime/export.def new file mode 100644 index 0000000..24e1f7b --- /dev/null +++ b/runtime/bindings/python/iree/runtime/export.def
@@ -0,0 +1,9 @@ +;; Copyright 2019 The IREE Authors +;; +;; Licensed under the Apache License v2.0 with LLVM Exceptions. +;; See https://llvm.org/LICENSE.txt for license information. +;; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +LIBRARY BINDING +EXPORTS + PyInit_binding @1
diff --git a/runtime/bindings/python/iree/runtime/flags.py b/runtime/bindings/python/iree/runtime/flags.py new file mode 100644 index 0000000..a7b1020 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/flags.py
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Flag parsing entry point and Python-level toggles for the runtime."""

# Imported from the native binding so that it is available as an attribute of
# this module.
from .binding import parse_flags

# When enabled, performs additional function input validation checks. In the
# event of errors, this will yield nicer error messages but comes with a
# runtime cost.
FUNCTION_INPUT_VALIDATION = True
diff --git a/runtime/bindings/python/iree/runtime/flags_test.py b/runtime/bindings/python/iree/runtime/flags_test.py new file mode 100644 index 0000000..886176a --- /dev/null +++ b/runtime/bindings/python/iree/runtime/flags_test.py
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree import runtime as rt
import numpy as np
import unittest


class FlagsTest(unittest.TestCase):
  """Checks flag parsing as exposed through `rt.flags.parse_flags`."""

  def testParse(self):
    # We always have the logging verbose level available so use it.
    known_flag = "--iree_v=1"
    rt.flags.parse_flags(known_flag)

  def testParseError(self):
    # An unrecognized flag must surface as a ValueError naming the flag.
    expected_message = "flag 'barbar' not recognized"
    self.assertRaisesRegex(ValueError, expected_message, rt.flags.parse_flags,
                           "--barbar")


if __name__ == "__main__":
  unittest.main()
diff --git a/runtime/bindings/python/iree/runtime/function.py b/runtime/bindings/python/iree/runtime/function.py new file mode 100644 index 0000000..40eb89c --- /dev/null +++ b/runtime/bindings/python/iree/runtime/function.py
# Lint as: python3
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict, Optional

import json
import logging

import numpy as np

from .binding import (
    _invoke_statics,
    ArgumentPacker,
    BufferUsage,
    HalBufferView,
    HalDevice,
    InvokeContext,
    MemoryType,
    VmContext,
    VmFunction,
    VmVariantList,
)

from . import tracing
from .array_interop import (
    map_dtype_to_element_type,
    DeviceArray,
)
from .flags import (
    FUNCTION_INPUT_VALIDATION,)

__all__ = [
    "FunctionInvoker",
]


class Invocation:
  """Scratch state for one call, used to build informative error messages.

  The `current_*` fields are overwritten while arguments/results are being
  processed so that a failure can report which value and which ABI descriptor
  were being handled at the time.
  """
  __slots__ = [
      "current_arg",
      "current_desc",
      "current_return_list",
      "current_return_index",
      "device",
  ]

  def __init__(self, device: HalDevice):
    self.device = device
    # Captured during arg/ret processing to emit better error messages.
    self.current_arg = None
    self.current_desc = None
    self.current_return_list = None
    self.current_return_index = 0

  def summarize_arg_error(self) -> str:
    """Describes the argument being encoded, or "" if none is recorded."""
    if self.current_arg is None:
      return ""
    if isinstance(self.current_arg, np.ndarray):
      # Summarize arrays by shape/dtype instead of dumping their contents.
      current_arg_repr = (
          f"ndarray({self.current_arg.shape}, {self.current_arg.dtype})")
    else:
      current_arg_repr = repr(self.current_arg)
    # Note: the summary string is passed through repr() again, so it appears
    # quoted in the final message. Kept as-is to preserve message text.
    return f"{repr(current_arg_repr)} with description {self.current_desc}"

  def summarize_return_error(self) -> str:
    """Describes the result being decoded, or "" if none is recorded."""
    if self.current_return_list is None:
      return ""
    try:
      vm_repr = f"{self.current_return_index}@{self.current_return_list}"
    except Exception:
      # Formatting the list is best-effort: a failure here must not mask the
      # error being reported. Only `Exception` is caught so that
      # KeyboardInterrupt/SystemExit still propagate.
      vm_repr = "<error printing list item>"
    return f"{vm_repr} with description {self.current_desc}"


class FunctionInvoker:
  """Wraps a VmFunction, enabling invocations against it.

  Args:
    vm_context: Context the function is invoked in.
    device: Device used when packing arguments and wrapping returned buffer
      views.
    vm_function: The function to invoke. Its `reflection` mapping may carry an
      "iree.abi" JSON entry describing arguments ("a") and results ("r").
    tracer: Optional tracer; when set, each call and its argument/result lists
      are recorded on it.

  Raises:
    RuntimeError: If the "iree.abi" reflection entry is present but is not
      valid JSON or lacks list-valued "a"/"r" entries.
  """
  __slots__ = [
      "_vm_context",
      "_device",
      "_vm_function",
      "_abi_dict",
      "_arg_descs",
      "_arg_packer",
      "_ret_descs",
      "_has_inlined_results",
      "_tracer",
  ]

  def __init__(self, vm_context: VmContext, device: HalDevice,
               vm_function: VmFunction,
               tracer: Optional[tracing.ContextTracer]):
    self._vm_context = vm_context
    # TODO: Needing to know the precise device to allocate on here is bad
    # layering and will need to be fixed in some fashion if/when doing
    # heterogeneous dispatch.
    self._device = device
    self._vm_function = vm_function
    self._tracer = tracer
    self._abi_dict = None
    self._arg_descs = None
    self._ret_descs = None
    self._has_inlined_results = False
    # Populates _abi_dict/_arg_descs/_ret_descs/_has_inlined_results. They
    # stay at the defaults above when the function has no reflection data.
    self._parse_abi_dict(vm_function)
    self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs)

  @property
  def vm_function(self) -> VmFunction:
    """The wrapped VmFunction."""
    return self._vm_function

  def __call__(self, *args, **kwargs):
    """Invokes the function and converts its results to Python values.

    Returns:
      None when there are no results, the bare value when there is exactly
      one, otherwise a tuple of the results.
    """
    invoke_context = InvokeContext(self._device)
    arg_list = self._arg_packer.pack(invoke_context, args, kwargs)

    call_trace = None  # type: Optional[tracing.CallTrace]
    if self._tracer:
      call_trace = self._tracer.start_call(self._vm_function)
    try:
      inv = Invocation(self._device)
      ret_descs = self._ret_descs

      # Initialize the capacity to the number of declared results (or 1 when
      # there is no reflection data), since a flat invocation should not
      # exceed that. May want to be more conservative here when considering
      # nesting.
      ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
      if call_trace:
        call_trace.add_vm_list(arg_list, "args")
      self._invoke(arg_list, ret_list)
      if call_trace:
        call_trace.add_vm_list(ret_list, "results")

      # Un-inline the results to align with reflection, as needed.
      reflection_aligned_ret_list = ret_list
      if self._has_inlined_results:
        reflection_aligned_ret_list = VmVariantList(1)
        reflection_aligned_ret_list.push_list(ret_list)
      returns = _extract_vm_sequence_to_python(inv, reflection_aligned_ret_list,
                                               ret_descs)
      return_arity = len(returns)
      if return_arity == 1:
        return returns[0]
      elif return_arity == 0:
        return None
      else:
        return tuple(returns)
    finally:
      # End the trace record whether or not the invocation succeeded.
      if call_trace:
        call_trace.end_call()

  # Break out invoke so it shows up in profiles.
  def _invoke(self, arg_list, ret_list):
    self._vm_context.invoke(self._vm_function, arg_list, ret_list)

  def _parse_abi_dict(self, vm_function: VmFunction):
    """Loads argument/result descriptors from the function's reflection data.

    Raises:
      RuntimeError: If the "iree.abi" entry is not valid JSON, lacks the "a"
        or "r" key, or either of them is not a list.
    """
    reflection = vm_function.reflection
    abi_json = reflection.get("iree.abi")
    if abi_json is None:
      # It is valid to have no reflection data, and rely on pure dynamic
      # dispatch.
      logging.debug(
          "Function lacks reflection data. Interop will be limited: %r",
          vm_function)
      return
    try:
      self._abi_dict = json.loads(abi_json)
    except json.JSONDecodeError as e:
      raise RuntimeError(
          f"Reflection metadata is not valid JSON: {abi_json}") from e
    try:
      self._arg_descs = self._abi_dict["a"]
      self._ret_descs = self._abi_dict["r"]
    except KeyError as e:
      raise RuntimeError(
          f"Malformed function reflection metadata: {reflection}") from e
    if not isinstance(self._arg_descs, list) or not isinstance(
        self._ret_descs, list):
      raise RuntimeError(
          f"Malformed function reflection metadata structure: {reflection}")

    # Detect whether the results are a slist/stuple/sdict, which indicates
    # that they are inlined with the function's results.
    if len(self._ret_descs) == 1:
      maybe_inlined = self._ret_descs[0]
      if maybe_inlined and maybe_inlined[0] in ("slist", "stuple", "sdict"):
        self._has_inlined_results = True

  def __repr__(self):
    return repr(self._vm_function)


# VM to Python converters. All take:
#   inv: Invocation
#   vm_list: VmVariantList to read from
#   vm_index: Index in the vm_list to extract
#   desc: The ABI descriptor list (or None if in dynamic mode)
# Return the corresponding Python object.


def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int,
                   desc):
  """Wraps the buffer view at `vm_index` in a DeviceArray of the ABI dtype."""
  # The descriptor for an ndarray is like:
  #   ["ndarray", "<dtype>", <rank>, <dim>...]
  #   ex: ['ndarray', 'i32', 1, 25948]
  buffer_view = vm_list.get_as_buffer_view(vm_index)
  dtype_str = desc[1]
  try:
    dtype = ABI_TYPE_TO_DTYPE[dtype_str]
  except KeyError:
    _raise_return_error(inv, f"unrecognized dtype '{dtype_str}'")
  x = DeviceArray(inv.device,
                  buffer_view,
                  implicit_host_transfer=True,
                  override_dtype=dtype)
  return x


def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested list at `vm_index` to a dict keyed per descriptor."""
  # The descriptor for an sdict is like:
  #   ['sdict', ['key1', value1], ...]
  sub_vm_list = vm_list.get_as_list(vm_index)
  item_keys = []
  item_descs = []
  for k, d in desc[1:]:
    item_keys.append(k)
    item_descs.append(d)
  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
  return dict(zip(item_keys, py_items))


def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts the nested list at `vm_index` to a Python list."""
  # The descriptor for an slist is like:
  #   ['slist', item1, ...]
  sub_vm_list = vm_list.get_as_list(vm_index)
  item_descs = desc[1:]
  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
  return py_items


def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Same as `_vm_to_slist` but returns a tuple."""
  return tuple(_vm_to_slist(inv, vm_list, vm_index, desc))


def _vm_to_scalar(type_bound: type):
  """Builds a converter that reads a scalar and checks its Python type.

  Args:
    type_bound: Type the extracted value must be an instance of.

  Returns:
    A converter with the standard (inv, vm_list, vm_index, desc) signature
    that raises ReturnError on a type mismatch.
  """

  def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
    value = vm_list.get_variant(vm_index)
    if not isinstance(value, type_bound):
      # Go through the shared helper so that the error carries the same
      # "while decoding return ..." context as every other converter error.
      _raise_return_error(
          inv, f"expected an {type_bound} value but got {value.__class__}")
    return value

  return convert


def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
  """Converts a nested list whose elements all share one descriptor."""
  # The descriptor for a homogeneous list is like:
  #   ['py_homogeneous_list', element_type]
  sub_vm_list = vm_list.get_as_list(vm_index)
  element_type_desc = desc[1:]
  # Repeat the single element descriptor once per list item.
  py_items = _extract_vm_sequence_to_python(
      inv, sub_vm_list, element_type_desc * len(sub_vm_list))
  return py_items


# Maps the leading tag of an ABI descriptor to its converter.
VM_TO_PYTHON_CONVERTERS = {
    "ndarray": _vm_to_ndarray,
    "sdict": _vm_to_sdict,
    "slist": _vm_to_slist,
    "stuple": _vm_to_stuple,
    "py_homogeneous_list": _vm_to_pylist,

    # Scalars.
    "i8": _vm_to_scalar(int),
    "i16": _vm_to_scalar(int),
    "i32": _vm_to_scalar(int),
    "i64": _vm_to_scalar(int),
    "f16": _vm_to_scalar(float),
    "f32": _vm_to_scalar(float),
    "f64": _vm_to_scalar(float),
    "bf16": _vm_to_scalar(float),
}

# Maps ABI element type names to the numpy scalar types used for ndarrays.
ABI_TYPE_TO_DTYPE = {
    # TODO: Others.
    "f32": np.float32,
    "i32": np.int32,
    "i64": np.int64,
    "f64": np.float64,
    "i16": np.int16,
    "i8": np.int8,
    "i1": np.bool_,
}

# When we get an ndarray as an argument and are implicitly mapping it to a
# buffer view, flags for doing so.
IMPLICIT_BUFFER_ARG_MEMORY_TYPE = MemoryType.DEVICE_LOCAL
IMPLICIT_BUFFER_ARG_USAGE = (BufferUsage.DISPATCH | BufferUsage.TRANSFER |
                             BufferUsage.MAPPING)


def _is_ndarray_descriptor(desc):
  """Truthy iff `desc` is a non-empty descriptor tagged "ndarray"."""
  return desc and desc[0] == "ndarray"


def _is_0d_ndarray_descriptor(desc):
  """Truthy iff `desc` describes a rank-0 ndarray."""
  # Example: ["ndarray", "f32", 0]
  # The length guard makes a truncated descriptor answer "no" instead of
  # raising IndexError.
  return desc and desc[0] == "ndarray" and len(desc) > 2 and desc[2] == 0


def _cast_scalar_to_ndarray(inv: Invocation, x, desc):
  """Converts scalar `x` to the numpy scalar type named by the descriptor.

  Raises:
    ArgumentError: If the descriptor's dtype name is not in
      ABI_TYPE_TO_DTYPE.
  """
  # Example descriptor: ["ndarray", "f32", 0]
  dtype_str = desc[1]
  try:
    dtype = ABI_TYPE_TO_DTYPE[dtype_str]
  except KeyError:
    _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'")
  return dtype(x)


class ArgumentError(ValueError):
  """Raised when a Python argument cannot be passed to the function."""
  pass


class ReturnError(ValueError):
  """Raised when a function result cannot be converted to Python."""
  pass


def _raise_argument_error(inv: Invocation,
                          summary: str,
                          e: Optional[Exception] = None):
  """Raises ArgumentError with `summary` plus the current argument context.

  Args:
    inv: Invocation whose recorded argument is included in the message.
    summary: Short description of what went wrong.
    e: Optional underlying exception to chain as the cause.
  """
  new_e = ArgumentError(
      f"Error passing argument: {summary} "
      f"(while encoding argument {inv.summarize_arg_error()})")
  if e:
    raise new_e from e
  else:
    raise new_e


def _raise_return_error(inv: Invocation,
                        summary: str,
                        e: Optional[Exception] = None):
  """Raises ReturnError with `summary` plus the current result context.

  Args:
    inv: Invocation whose recorded result position is included in the message.
    summary: Short description of what went wrong.
    e: Optional underlying exception to chain as the cause.
  """
  new_e = ReturnError(f"Error processing function return: {summary} "
                      f"(while decoding return {inv.summarize_return_error()})")
  if e:
    raise new_e from e
  else:
    raise new_e


def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs):
  """Converts every element of `vm_list` to a Python value.

  Args:
    inv: Invocation updated with the element being processed, for error
      reporting.
    vm_list: VmVariantList to read from.
    descs: One ABI descriptor per element, or None for dynamic
      (non-reflection) mode.

  Returns:
    A list with one converted value per element of `vm_list`.

  Raises:
    ReturnError: On an arity mismatch between `vm_list` and `descs`, an
      unknown descriptor tag, or any failure inside a converter.
  """
  vm_list_arity = len(vm_list)
  if descs is None:
    descs = [None] * vm_list_arity
  elif vm_list_arity != len(descs):
    _raise_return_error(
        inv, f"mismatched return arity: {vm_list_arity} vs {len(descs)}")
  results = []
  # Past this point len(descs) == vm_list_arity, so enumerating the
  # descriptors visits exactly the valid list indices.
  for vm_index, desc in enumerate(descs):
    inv.current_return_list = vm_list
    inv.current_return_index = vm_index
    inv.current_desc = desc
    if desc is None:
      # Dynamic (non reflection mode).
      converted = vm_list.get_variant(vm_index)
      # Special case: Upgrade HalBufferView to a DeviceArray. We do that here
      # since this is higher level and it preserves layering. Note that
      # the reflection case also does this conversion.
      if isinstance(converted, HalBufferView):
        converted = DeviceArray(inv.device,
                                converted,
                                implicit_host_transfer=True)
    else:
      # Known type descriptor: either a bare tag string or a list whose first
      # element is the tag.
      vm_type = desc if isinstance(desc, str) else desc[0]
      try:
        converter = VM_TO_PYTHON_CONVERTERS[vm_type]
      except KeyError:
        _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}")
      try:
        converted = converter(inv, vm_list, vm_index, desc)
      except ReturnError:
        # Already carries context from a nested conversion; do not re-wrap.
        raise
      except Exception as e:
        _raise_return_error(inv, "exception converting from VM type to Python",
                            e)
    results.append(converted)
  return results
diff --git a/runtime/bindings/python/iree/runtime/function_test.py b/runtime/bindings/python/iree/runtime/function_test.py new file mode 100644 index 0000000..57b8481 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/function_test.py
# Lint as: python3
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import numpy as np

from absl.testing import absltest

from iree import runtime as rt
from iree.runtime.function import (
    FunctionInvoker,
    IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
    IMPLICIT_BUFFER_ARG_USAGE,
)
from iree.runtime.binding import VmVariantList

# ABI descriptors shared by several tests. They are only ever serialized with
# json.dumps and never mutated.
_NAMED_I32_ABI = {
    "a": [
        "i32",
        ["named", "a", "i32"],
        ["named", "b", "i32"],
    ],
    "r": ["i32",],
}
_NAMED_NDARRAY_ABI = {
    "a": [
        ["ndarray", "i32", 1, 1],
        ["named", "a", ["ndarray", "i32", 1, 1]],
        ["named", "b", ["ndarray", "i32", 1, 1]],
    ],
    "r": ["i32",],
}
_SLIST_ARG_ABI = {
    "a": [["slist", "i32", "i32"],],
    "r": ["i32",],
}
_STUPLE_ARG_ABI = {
    "a": [["stuple", "i32", "i32"],],
    "r": ["i32",],
}
_SDICT_ARG_ABI = {
    "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
    "r": ["i32",],
}
_NDARRAY_ARG_ABI = {
    "a": [["ndarray", "i32", 1, 2]],
    "r": [],
}

# Expected repr of the recorded argument lists when the values 2 and 3 are
# packed into a single nested list argument.
_NESTED_PAIR_ARGS_REPR = "[<VmVariantList(1): [List[2, 3]]>]"
# Expected repr of an argument list holding the single two-element int32
# buffer view used by the array-argument tests.
_BUFFER_VIEW_ARGS_REPR = "<VmVariantList(1): [HalBufferView(2:0x20000011)]>"


def _push_int_3(arg_list, ret_list):
  """Invoke callback that produces the single scalar result 3."""
  ret_list.push_int(3)


def _no_results(arg_list, ret_list):
  """Invoke callback that produces no results."""


class MockVmContext:
  """Stand-in for a VmContext that records invocations.

  Each call to `invoke` runs the supplied callback (which fills in the result
  list) and then appends (vm_function, arg_list, ret_list) to `invocations`.
  """

  def __init__(self, invoke_callback):
    self._invoke_callback = invoke_callback
    self.invocations = []

  def invoke(self, vm_function, arg_list, ret_list):
    self._invoke_callback(arg_list, ret_list)
    self.invocations.append((vm_function, arg_list, ret_list))
    print(f"INVOKE: {arg_list} -> {ret_list}")

  @property
  def mock_arg_reprs(self):
    """repr() of the list of argument lists, one per recorded invocation."""
    return repr([arg_list for _, arg_list, _ in self.invocations])


class MockVmFunction:
  """Stand-in for a VmFunction that only exposes `reflection`."""

  def __init__(self, reflection):
    self.reflection = reflection


class FunctionTest(absltest.TestCase):
  """Tests FunctionInvoker argument packing and result conversion."""

  @classmethod
  def setUpClass(cls):
    # Doesn't matter what device. We just need one.
    config = rt.Config("vmvx")
    cls.device = config.device

  # ---------------------------------------------------------------------------
  # Helpers
  # ---------------------------------------------------------------------------

  def _make_invoker(self, invoke_callback, abi=None):
    """Builds a FunctionInvoker over mocks.

    Args:
      invoke_callback: Called as (arg_list, ret_list) on each invocation.
      abi: Dict serialized to JSON as the "iree.abi" reflection entry, or
        None for a function with no reflection data.

    Returns:
      A (vm_context, invoker) pair.
    """
    vm_context = MockVmContext(invoke_callback)
    reflection = {} if abi is None else {"iree.abi": json.dumps(abi)}
    vm_function = MockVmFunction(reflection=reflection)
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    return vm_context, invoker

  def _allocate_buffer_view(self, array, element_type):
    """Copies `array` into a buffer view using the implicit-arg flags."""
    return self.device.allocator.allocate_buffer_copy(
        memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
        allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
        buffer=array,
        element_type=element_type)

  def _buffer_view_result_callback(self, array, element_type):
    """Returns an invoke callback that pushes `array` as a buffer view."""

    def invoke(arg_list, ret_list):
      # Allocate at invoke time, as a real function would.
      ret_list.push_buffer_view(self._allocate_buffer_view(array,
                                                           element_type))

    return invoke

  def _assert_passed_as_buffer_view(self, arg, abi=None):
    """Invokes with `arg` and checks it reached the VM as one buffer view."""
    vm_context, invoker = self._make_invoker(_no_results, abi)
    _ = invoker(arg)
    self.assertEqual(1, len(vm_context.invocations))
    _, invoked_arg_list, _ = vm_context.invocations[0]
    self.assertEqual(_BUFFER_VIEW_ARGS_REPR, repr(invoked_arg_list))

  def _assert_call_raises(self, abi, message_regex, *args, **kwargs):
    """Checks that calling with the given arguments raises ValueError."""
    _, invoker = self._make_invoker(_push_int_3, abi)
    with self.assertRaisesRegex(ValueError, message_regex):
      invoker(*args, **kwargs)

  # ---------------------------------------------------------------------------
  # Scalar and structured arguments
  # ---------------------------------------------------------------------------

  def testNoReflectionScalars(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    vm_context, invoker = self._make_invoker(invoke)
    result = invoker(1, 2)
    self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
    self.assertEqual((3, 4), result)

  def testKeywordArgs(self):
    vm_context, invoker = self._make_invoker(_push_int_3, _NAMED_I32_ABI)
    result = invoker(-1, a=1, b=2)
    self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]",
                     vm_context.mock_arg_reprs)
    self.assertEqual(3, result)

  def testListArg(self):
    vm_context, invoker = self._make_invoker(_push_int_3, _SLIST_ARG_ABI)
    _ = invoker([2, 3])
    self.assertEqual(_NESTED_PAIR_ARGS_REPR, vm_context.mock_arg_reprs)

  def testListArgNoReflection(self):
    vm_context, invoker = self._make_invoker(_push_int_3)
    _ = invoker([2, 3])
    self.assertEqual(_NESTED_PAIR_ARGS_REPR, vm_context.mock_arg_reprs)

  def testListArgArityMismatch(self):
    self._assert_call_raises(_SLIST_ARG_ABI,
                             "expected a sequence with 2 values. got:",
                             [2, 3, 4])

  def testTupleArg(self):
    vm_context, invoker = self._make_invoker(_push_int_3, _STUPLE_ARG_ABI)
    _ = invoker((2, 3))
    self.assertEqual(_NESTED_PAIR_ARGS_REPR, vm_context.mock_arg_reprs)

  def testDictArg(self):
    vm_context, invoker = self._make_invoker(_push_int_3, _SDICT_ARG_ABI)
    # Keys are given out of descriptor order; the packed list is expected in
    # descriptor order (a, b).
    _ = invoker({"b": 3, "a": 2})
    self.assertEqual(_NESTED_PAIR_ARGS_REPR, vm_context.mock_arg_reprs)

  def testDictArgArityMismatch(self):
    self._assert_call_raises(_SDICT_ARG_ABI,
                             "expected a dict with 2 values. got:", {
                                 "a": 2,
                                 "b": 3,
                                 "c": 4
                             })

  def testDictArgKeyError(self):
    self._assert_call_raises(_SDICT_ARG_ABI, "could not get item 'b' from: ", {
        "a": 2,
        "c": 3
    })

  def testDictArgNoReflection(self):
    vm_context, invoker = self._make_invoker(_push_int_3)
    _ = invoker({"b": 3, "a": 2})
    self.assertEqual(_NESTED_PAIR_ARGS_REPR, vm_context.mock_arg_reprs)

  # ---------------------------------------------------------------------------
  # Structured results
  # ---------------------------------------------------------------------------

  def testInlinedResults(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    _, invoker = self._make_invoker(invoke, {
        "a": [],
        "r": [["slist", "i32", "i32"]],
    })
    result = invoker()
    self.assertEqual([3, 4], result)

  def testNestedResults(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      sub_list = VmVariantList(2)
      sub_dict = VmVariantList(2)
      sub_dict.push_int(100)
      sub_dict.push_int(200)
      sub_list.push_list(sub_dict)
      sub_list.push_int(6)
      ret_list.push_list(sub_list)

    _, invoker = self._make_invoker(
        invoke, {
            "a": [],
            "r": [
                "i32",
                [
                    "slist",
                    ["sdict", ["bar", "i32"], ["foo", "i32"]],
                    "i64",
                ]
            ],
        })
    result = invoker()
    self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result)

  # ---------------------------------------------------------------------------
  # Call arity errors
  # ---------------------------------------------------------------------------

  def testMissingPositional(self):
    self._assert_call_raises(_NAMED_I32_ABI, "mismatched call arity:", a=1, b=1)

  def testMissingPositionalNdarray(self):
    self._assert_call_raises(_NAMED_NDARRAY_ABI,
                             "mismatched call arity:",
                             a=1,
                             b=1)

  def testMissingKeyword(self):
    self._assert_call_raises(_NAMED_I32_ABI, "mismatched call arity:", -1, a=1)

  def testMissingKeywordNdArray(self):
    self._assert_call_raises(_NAMED_NDARRAY_ABI,
                             "mismatched call arity:",
                             -1,
                             a=1)

  def testExtraKeyword(self):
    self._assert_call_raises(_NAMED_I32_ABI,
                             "specified kwarg 'c' is unknown",
                             -1,
                             a=1,
                             b=2,
                             c=3)

  # ---------------------------------------------------------------------------
  # Array-like arguments
  # ---------------------------------------------------------------------------

  def testNdarrayArg(self):
    arg_array = np.asarray([1, 0], dtype=np.int32)
    self._assert_passed_as_buffer_view(arg_array, _NDARRAY_ARG_ABI)

  def testDeviceArrayArg(self):
    # Note that since the device array is set up to disallow implicit host
    # transfers, this also verifies that no accidental/automatic transfers
    # are done as part of marshalling the array to the function.
    arg_array = rt.asdevicearray(self.device,
                                 np.asarray([1, 0], dtype=np.int32),
                                 implicit_host_transfer=False)
    self._assert_passed_as_buffer_view(arg_array, _NDARRAY_ARG_ABI)

  def testBufferViewArg(self):
    arg_buffer_view = self._allocate_buffer_view(
        np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32)
    self._assert_passed_as_buffer_view(arg_buffer_view, _NDARRAY_ARG_ABI)

  def testNdarrayArgNoReflection(self):
    arg_array = np.asarray([1, 0], dtype=np.int32)
    self._assert_passed_as_buffer_view(arg_array)

  def testDeviceArrayArgNoReflection(self):
    # Note that since the device array is set up to disallow implicit host
    # transfers, this also verifies that no accidental/automatic transfers
    # are done as part of marshalling the array to the function.
    arg_array = rt.asdevicearray(self.device,
                                 np.asarray([1, 0], dtype=np.int32),
                                 implicit_host_transfer=False)
    self._assert_passed_as_buffer_view(arg_array)

  def testBufferViewArgNoReflection(self):
    arg_buffer_view = self._allocate_buffer_view(
        np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32)
    self._assert_passed_as_buffer_view(arg_buffer_view)

  # ---------------------------------------------------------------------------
  # Array-like results
  # ---------------------------------------------------------------------------

  def testReturnBufferView(self):
    result_array = np.asarray([1, 0], dtype=np.int32)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.SINT_32)
    _, invoker = self._make_invoker(invoke, {
        "a": [],
        "r": [["ndarray", "i32", 1, 2]],
    })
    result = invoker()
    np.testing.assert_array_equal([1, 0], result)

  def testReturnBufferViewNoReflection(self):
    result_array = np.asarray([1, 0], dtype=np.int32)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.SINT_32)
    _, invoker = self._make_invoker(invoke)
    result = invoker()
    np.testing.assert_array_equal([1, 0], result)

  # TODO: Fill out all return types.
  def testReturnTypeNdArrayBool(self):
    # NOTE(review): the source array is int8 while the buffer view is tagged
    # UINT_8; kept as in the original test - confirm the mismatch is intended.
    result_array = np.asarray([1, 0], dtype=np.int8)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.UINT_8)
    _, invoker = self._make_invoker(invoke, {
        "a": [],
        "r": [["ndarray", "i1", 1, 2]],
    })
    result = invoker()
    # assertEqual on bool arrays is fraught for... reasons.
    np.testing.assert_array_equal([True, False], result)

  def testReturnTypeList(self):
    vm_list = VmVariantList(2)
    vm_list.push_int(1)
    vm_list.push_int(2)

    def invoke(arg_list, ret_list):
      ret_list.push_list(vm_list)

    _, invoker = self._make_invoker(invoke, {
        "a": [],
        "r": [["py_homogeneous_list", "i64"]],
    })
    result = invoker()
    self.assertEqual("[1, 2]", repr(result))


if __name__ == "__main__":
  absltest.main()
diff --git a/runtime/bindings/python/iree/runtime/hal.cc b/runtime/bindings/python/iree/runtime/hal.cc new file mode 100644 index 0000000..1d577a6 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/hal.cc
@@ -0,0 +1,514 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./hal.h" + +#include "iree/base/tracing.h" +#include "iree/hal/api.h" +#include "pybind11/numpy.h" + +namespace iree { +namespace python { + +namespace { + +// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes +// out of scope. +class PyBufferReleaser { + public: + PyBufferReleaser(Py_buffer& b) : b_(b) {} + ~PyBufferReleaser() { PyBuffer_Release(&b_); } + + private: + Py_buffer& b_; +}; + +static std::string ToHexString(const uint8_t* data, size_t length) { + static constexpr char kHexChars[] = {'0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; + std::string s(length * 2, ' '); + for (size_t i = 0; i < length; ++i) { + s[2 * i + 0] = kHexChars[(data[i] & 0xF0) >> 4]; + s[2 * i + 1] = kHexChars[(data[i] & 0x0F) >> 0]; + } + return s; +} +static std::string ToHexString(uint32_t value) { + return ToHexString((const uint8_t*)&value, sizeof(value)); +} + +} // namespace + +//------------------------------------------------------------------------------ +// HalAllocator +//------------------------------------------------------------------------------ + +py::dict HalAllocator::QueryStatistics() { + py::dict items; + iree_hal_allocator_statistics_t stats; + iree_hal_allocator_query_statistics(raw_ptr(), &stats); +#if IREE_STATISTICS_ENABLE + items["host_bytes_peak"] = stats.host_bytes_peak; + items["host_bytes_allocated"] = stats.host_bytes_allocated; + items["host_bytes_freed"] = stats.host_bytes_freed; + items["device_bytes_peak"] = stats.device_bytes_peak; + items["device_bytes_allocated"] = stats.device_bytes_allocated; + items["device_bytes_freed"] = stats.device_bytes_freed; +#endif + return items; +} + +py::str HalAllocator::FormattedStatistics() { + // 
Perform all allocating string manipulation without early exit. + iree_string_builder_t builder; + iree_string_builder_initialize(iree_allocator_system(), &builder); + iree_hal_allocator_statistics_t stats; + iree_hal_allocator_query_statistics(raw_ptr(), &stats); + auto status = iree_hal_allocator_statistics_format(&stats, &builder); + iree_string_view_t view = iree_string_builder_view(&builder); + py::str result = py::str(view.data, view.size); + iree_string_builder_deinitialize(&builder); + + // Check/raise after all memory alloc/dealloc. + CheckApiStatus(status, "unable to format statistics"); + return result; +} + +py::object HalAllocator::AllocateBufferCopy( + int memory_type, int allowed_usage, py::object buffer, + std::optional<iree_hal_element_types_t> element_type) { + IREE_TRACE_SCOPE0("HalAllocator::AllocateBufferCopy"); + // Request a view of the buffer (use the raw python C API to avoid + // some allocation and copying at the pybind level). + Py_buffer py_view; + // Note that only C-Contiguous ND-arrays are presently supported, so + // only request that via PyBUF_ND. Long term, we should consult an + // "oracle" in the runtime to determine the precise required format + // and set flags accordingly (and fallback/copy on failure). + int flags = PyBUF_FORMAT | PyBUF_ND; + + // Acquire the backing buffer and setup RAII release. + if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { + // The GetBuffer call is required to set an appropriate error. 
+ throw py::error_already_set(); + } + PyBufferReleaser py_view_releaser(py_view); + + iree_hal_buffer_params_t params = {0}; + // TODO: Should not require host visible :( + params.type = memory_type | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; + params.usage = allowed_usage; + + iree_hal_buffer_t* hal_buffer = nullptr; + iree_status_t status = iree_ok_status(); + { + py::gil_scoped_release release; + status = iree_hal_allocator_allocate_buffer( + raw_ptr(), params, py_view.len, + iree_make_const_byte_span(py_view.buf, py_view.len), &hal_buffer); + } + CheckApiStatus(status, "Failed to allocate device visible buffer"); + + if (!element_type) { + return py::cast(HalBuffer::StealFromRawPtr(hal_buffer), + py::return_value_policy::move); + } + + // Create the buffer_view. (note that numpy shape is ssize_t, so we need to + // copy). + iree_hal_encoding_type_t encoding_type = + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; + std::vector<iree_hal_dim_t> dims(py_view.ndim); + std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin()); + iree_hal_buffer_view_t* hal_buffer_view; + CheckApiStatus( + iree_hal_buffer_view_create( + hal_buffer, dims.data(), dims.size(), *element_type, encoding_type, + iree_hal_allocator_host_allocator(raw_ptr()), &hal_buffer_view), + "Error allocating buffer_view"); + iree_hal_buffer_release(hal_buffer); + + return py::cast(HalBufferView::StealFromRawPtr(hal_buffer_view), + py::return_value_policy::move); +} + +//------------------------------------------------------------------------------ +// HalBuffer +//------------------------------------------------------------------------------ + +namespace { + +void AppendHalBufferRepr(iree_hal_buffer_t* buffer, std::string& repr) { + repr.append(std::to_string(iree_hal_buffer_byte_length(buffer))); + repr.append(" bytes (at offset "); + repr.append(std::to_string(iree_hal_buffer_byte_offset(buffer))); + repr.append(" into "); + repr.append(std::to_string(iree_hal_buffer_allocation_size(buffer))); + 
repr.append("), memory_type="); + + // Memory type. + iree_bitfield_string_temp_t tmp; + iree_string_view_t sv; + sv = iree_hal_memory_type_format(iree_hal_buffer_memory_type(buffer), &tmp); + repr.append(sv.data, sv.size); + + // Allowed access. + repr.append(", allowed_access="); + sv = iree_hal_memory_access_format(iree_hal_buffer_allowed_access(buffer), + &tmp); + repr.append(sv.data, sv.size); + + // Allowed usage. + repr.append(", allowed_usage="); + sv = + iree_hal_buffer_usage_format(iree_hal_buffer_allowed_usage(buffer), &tmp); + repr.append(sv.data, sv.size); +} + +} // namespace + +py::str HalBuffer::Repr() { + std::string repr("<HalBuffer "); + AppendHalBufferRepr(raw_ptr(), repr); + repr.append(">"); + return py::str(repr); +} + +//------------------------------------------------------------------------------ +// HalBufferView +//------------------------------------------------------------------------------ + +py::str HalBufferView::Repr() { + std::string repr("<HalBufferView ("); + + // Shape. + iree_host_size_t rank = iree_hal_buffer_view_shape_rank(raw_ptr()); + for (iree_host_size_t i = 0; i < rank; ++i) { + if (i > 0) { + repr.append(", "); + } + repr.append(std::to_string(iree_hal_buffer_view_shape_dim(raw_ptr(), i))); + } + repr.append(")"); + + // Element type. 
+ repr.append(", element_type=0x"); + auto element_type = iree_hal_buffer_view_element_type(raw_ptr()); + repr.append(ToHexString(static_cast<uint32_t>(element_type))); + + repr.append(", "); + AppendHalBufferRepr(iree_hal_buffer_view_buffer(raw_ptr()), repr); + repr.append(">"); + return py::str(repr); +} + +//------------------------------------------------------------------------------ +// HalDriver +//------------------------------------------------------------------------------ + +std::vector<std::string> HalDriver::Query() { + iree_hal_driver_info_t* driver_infos = NULL; + iree_host_size_t driver_info_count = 0; + CheckApiStatus( + iree_hal_driver_registry_enumerate(iree_hal_driver_registry_default(), + iree_allocator_system(), &driver_infos, + &driver_info_count), + "Error enumerating HAL drivers"); + std::vector<std::string> driver_names(driver_info_count); + for (iree_host_size_t i = 0; i < driver_info_count; ++i) { + driver_names[i] = std::string(driver_infos[i].driver_name.data, + driver_infos[i].driver_name.size); + } + iree_allocator_free(iree_allocator_system(), driver_infos); + return driver_names; +} + +HalDriver HalDriver::Create(const std::string& driver_name) { + iree_hal_driver_t* driver; + CheckApiStatus(iree_hal_driver_registry_try_create_by_name( + iree_hal_driver_registry_default(), + {driver_name.data(), driver_name.size()}, + iree_allocator_system(), &driver), + "Error creating driver"); + return HalDriver::StealFromRawPtr(driver); +} + +HalDevice HalDriver::CreateDefaultDevice() { + iree_hal_device_t* device; + CheckApiStatus(iree_hal_driver_create_default_device( + raw_ptr(), iree_allocator_system(), &device), + "Error creating default device"); + return HalDevice::StealFromRawPtr(device); +} + +//------------------------------------------------------------------------------ +// Enum helpers +//------------------------------------------------------------------------------ + +namespace { + +py::object 
MapElementTypeToDType(iree_hal_element_type_t element_type) { + // See: https://docs.python.org/3/c-api/arg.html#numbers + // TODO: Handle dtypes that do not map to a code (i.e. fp16). + const char* dtype_code; + switch (element_type) { + case IREE_HAL_ELEMENT_TYPE_INT_8: + case IREE_HAL_ELEMENT_TYPE_SINT_8: + dtype_code = "b"; + break; + case IREE_HAL_ELEMENT_TYPE_UINT_8: + dtype_code = "B"; + break; + case IREE_HAL_ELEMENT_TYPE_INT_16: + case IREE_HAL_ELEMENT_TYPE_SINT_16: + dtype_code = "h"; + break; + case IREE_HAL_ELEMENT_TYPE_UINT_16: + dtype_code = "H"; + break; + case IREE_HAL_ELEMENT_TYPE_INT_32: + case IREE_HAL_ELEMENT_TYPE_SINT_32: + dtype_code = "i"; + break; + case IREE_HAL_ELEMENT_TYPE_UINT_32: + dtype_code = "I"; + break; + case IREE_HAL_ELEMENT_TYPE_INT_64: + case IREE_HAL_ELEMENT_TYPE_SINT_64: + dtype_code = "l"; + break; + case IREE_HAL_ELEMENT_TYPE_UINT_64: + dtype_code = "L"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + dtype_code = "f"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_64: + dtype_code = "d"; + break; + case IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER, 1): + dtype_code = "?"; + break; + default: + throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping"); + } + return py::dtype(dtype_code); +} + +} // namespace + +//------------------------------------------------------------------------------ +// Bindings +//------------------------------------------------------------------------------ + +void SetupHalBindings(pybind11::module m) { + // Enums. 
+ py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType") + .value("NONE", IREE_HAL_MEMORY_TYPE_NONE) + .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT) + .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) + .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT) + .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED) + .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL) + .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE) + .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) + .export_values() + .def("__or__", + [](enum iree_hal_memory_type_bits_t self, + enum iree_hal_memory_type_bits_t other) { return self | other; }) + .def("__and__", + [](enum iree_hal_memory_type_bits_t self, + enum iree_hal_memory_type_bits_t other) { return self & other; }); + + py::enum_<enum iree_hal_buffer_compatibility_bits_t>(m, "BufferCompatibility") + .value("NONE", IREE_HAL_BUFFER_COMPATIBILITY_NONE) + .value("ALLOCATABLE", IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE) + .value("IMPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE) + .value("EXPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_EXPORTABLE) + .value("QUEUE_TRANSFER", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER) + .value("QUEUE_DISPATCH", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH) + .export_values() + .def("__or__", + [](enum iree_hal_buffer_compatibility_bits_t self, + enum iree_hal_buffer_compatibility_bits_t other) { + return self | other; + }) + .def("__and__", [](enum iree_hal_buffer_compatibility_bits_t self, + enum iree_hal_buffer_compatibility_bits_t other) { + return self & other; + }); + + py::enum_<enum iree_hal_buffer_usage_bits_t>(m, "BufferUsage") + .value("NONE", IREE_HAL_BUFFER_USAGE_NONE) + .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT) + .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER) + .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING) + .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH) + .export_values() + .def("__or__", + [](enum iree_hal_buffer_usage_bits_t 
self, + enum iree_hal_buffer_usage_bits_t other) { + return (enum iree_hal_buffer_usage_bits_t)(self | other); + }) + .def("__and__", [](enum iree_hal_buffer_usage_bits_t self, + enum iree_hal_buffer_usage_bits_t other) { + return (enum iree_hal_buffer_usage_bits_t)(self & other); + }); + + py::enum_<enum iree_hal_memory_access_bits_t>(m, "MemoryAccess") + .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE) + .value("READ", IREE_HAL_MEMORY_ACCESS_READ) + .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE) + .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD) + .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) + .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) + .export_values() + .def( + "__or__", + [](enum iree_hal_memory_access_bits_t self, + enum iree_hal_memory_access_bits_t other) { return self | other; }) + .def("__and__", [](enum iree_hal_memory_access_bits_t self, + enum iree_hal_memory_access_bits_t other) { + return self & other; + }); + + py::enum_<enum iree_hal_element_types_t>(m, "HalElementType") + .value("NONE", IREE_HAL_ELEMENT_TYPE_NONE) + .value("OPAQUE_8", IREE_HAL_ELEMENT_TYPE_OPAQUE_8) + .value("OPAQUE_16", IREE_HAL_ELEMENT_TYPE_OPAQUE_16) + .value("OPAQUE_32", IREE_HAL_ELEMENT_TYPE_OPAQUE_32) + .value("OPAQUE_64", IREE_HAL_ELEMENT_TYPE_OPAQUE_64) + .value("INT_4", IREE_HAL_ELEMENT_TYPE_INT_4) + .value("INT_8", IREE_HAL_ELEMENT_TYPE_INT_8) + .value("INT_16", IREE_HAL_ELEMENT_TYPE_INT_16) + .value("INT_32", IREE_HAL_ELEMENT_TYPE_INT_32) + .value("INT_64", IREE_HAL_ELEMENT_TYPE_INT_64) + .value("SINT_4", IREE_HAL_ELEMENT_TYPE_SINT_4) + .value("SINT_8", IREE_HAL_ELEMENT_TYPE_SINT_8) + .value("SINT_16", IREE_HAL_ELEMENT_TYPE_SINT_16) + .value("SINT_32", IREE_HAL_ELEMENT_TYPE_SINT_32) + .value("SINT_64", IREE_HAL_ELEMENT_TYPE_SINT_64) + .value("UINT_4", IREE_HAL_ELEMENT_TYPE_UINT_4) + .value("UINT_8", IREE_HAL_ELEMENT_TYPE_UINT_8) + .value("UINT_16", IREE_HAL_ELEMENT_TYPE_UINT_16) + .value("UINT_32", IREE_HAL_ELEMENT_TYPE_UINT_32) + .value("UINT_64", 
IREE_HAL_ELEMENT_TYPE_UINT_64) + .value("FLOAT_16", IREE_HAL_ELEMENT_TYPE_FLOAT_16) + .value("FLOAT_32", IREE_HAL_ELEMENT_TYPE_FLOAT_32) + .value("FLOAT_64", IREE_HAL_ELEMENT_TYPE_FLOAT_64) + .value("BFLOAT_16", IREE_HAL_ELEMENT_TYPE_BFLOAT_16) + .value("BOOL_8", + static_cast<iree_hal_element_types_t>(IREE_HAL_ELEMENT_TYPE_VALUE( + IREE_HAL_NUMERICAL_TYPE_INTEGER, 1))) + .export_values() + .def_static("map_to_dtype", &MapElementTypeToDType); + + py::class_<HalDevice>(m, "HalDevice") + .def_property_readonly("allocator", [](HalDevice& self) { + return HalAllocator::BorrowFromRawPtr(self.allocator()); + }); + + py::class_<HalDriver>(m, "HalDriver") + .def_static("query", &HalDriver::Query) + .def_static("create", &HalDriver::Create, py::arg("driver_name")) + .def("create_default_device", &HalDriver::CreateDefaultDevice); + + py::class_<HalAllocator>(m, "HalAllocator") + .def("trim", + [](HalAllocator& self) { + CheckApiStatus(iree_hal_allocator_trim(self.raw_ptr()), + "Error trim()'ing HAL allocator"); + }) + .def_property_readonly( + "has_statistics", + [](HalAllocator& self) -> bool { return IREE_STATISTICS_ENABLE; }) + .def_property_readonly("statistics", &HalAllocator::QueryStatistics) + .def_property_readonly("formatted_statistics", + &HalAllocator::FormattedStatistics) + .def( + "query_compatibility", + [](HalAllocator& self, int memory_type, int allowed_usage, + int intended_usage, iree_device_size_t allocation_size) -> int { + iree_hal_buffer_params_t params = {0}; + params.type = memory_type; + params.usage = allowed_usage & intended_usage; + return iree_hal_allocator_query_compatibility( + self.raw_ptr(), params, allocation_size); + }, + py::arg("memory_type"), py::arg("allowed_usage"), + py::arg("intended_usage"), py::arg("allocation_size")) + .def( + "allocate_buffer", + [](HalAllocator& self, int memory_type, int allowed_usage, + iree_device_size_t allocation_size) { + iree_hal_buffer_params_t params = {0}; + params.type = memory_type; + params.usage = 
allowed_usage; + iree_hal_buffer_t* buffer = nullptr; + iree_const_byte_span_t empty_initial_data{nullptr, 0}; + CheckApiStatus(iree_hal_allocator_allocate_buffer( + self.raw_ptr(), params, allocation_size, + empty_initial_data, &buffer), + "could not allocate buffer"); + return HalBuffer::StealFromRawPtr(buffer); + }, + py::arg("memory_type"), py::arg("allowed_usage"), + py::arg("allocation_size"), + "Allocates a new buffer with requested characteristics (does not " + "initialize with specific data).") + .def("allocate_buffer_copy", &HalAllocator::AllocateBufferCopy, + py::arg("memory_type"), py::arg("allowed_usage"), py::arg("buffer"), + py::arg("element_type") = py::none(), + "Allocates a new buffer and initializes it from a Python buffer " + "object. If an element type is specified, wraps in a BufferView " + "matching the characteristics of the Python buffer. The format is " + "requested as ND/C-Contiguous, which may incur copies if not " + "already in that format."); + + py::class_<HalBuffer>(m, "HalBuffer") + .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), + py::arg("byte_length")) + .def("create_view", &HalBuffer::CreateView, py::arg("shape"), + py::arg("element_size")) + .def("__repr__", &HalBuffer::Repr); + + py::class_<HalBufferView>(m, "HalBufferView") + .def("map", HalMappedMemory::Create) + .def_property_readonly( + "shape", + [](HalBufferView& self) { + iree_host_size_t rank = + iree_hal_buffer_view_shape_rank(self.raw_ptr()); + auto* dims = iree_hal_buffer_view_shape_dims(self.raw_ptr()); + py::list result; + for (iree_host_size_t i = 0; i < rank; ++i) { + result.append(dims[i]); + } + return result; + }) + .def_property_readonly( + "element_type", + [](HalBufferView& self) { + return iree_hal_buffer_view_element_type(self.raw_ptr()); + }) + .def("__repr__", &HalBufferView::Repr); + + py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol()) + .def_buffer(&HalMappedMemory::ToBufferInfo) + .def("asarray", + 
[](HalMappedMemory& self, std::vector<iree_host_size_t> shape, + py::object dtype) { + py::object py_mapped_memory = py::cast(self); + return py::array(std::move(dtype), shape, + self.mapped_memory().contents.data, + std::move(py_mapped_memory) /* base */); + }); + + py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector)); +} + +} // namespace python +} // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/hal.h b/runtime/bindings/python/iree/runtime/hal.h new file mode 100644 index 0000000..02201a9 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/hal.h
@@ -0,0 +1,204 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_ +#define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_ + +#include <vector> + +#include "./binding.h" +#include "./status_utils.h" +#include "iree/hal/api.h" + +namespace iree { +namespace python { + +//------------------------------------------------------------------------------ +// Retain/release bindings +//------------------------------------------------------------------------------ + +template <> +struct ApiPtrAdapter<iree_hal_driver_t> { + static void Retain(iree_hal_driver_t* d) { iree_hal_driver_retain(d); } + static void Release(iree_hal_driver_t* d) { iree_hal_driver_release(d); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_device_t> { + static void Retain(iree_hal_device_t* d) { iree_hal_device_retain(d); } + static void Release(iree_hal_device_t* d) { iree_hal_device_release(d); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_allocator_t> { + static void Retain(iree_hal_allocator_t* d) { iree_hal_allocator_retain(d); } + static void Release(iree_hal_allocator_t* d) { + iree_hal_allocator_release(d); + } +}; + +template <> +struct ApiPtrAdapter<iree_hal_buffer_t> { + static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); } + static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_buffer_view_t> { + static void Retain(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_retain(bv); + } + static void Release(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_release(bv); + } +}; + +//------------------------------------------------------------------------------ +// ApiRefCounted types +//------------------------------------------------------------------------------ + +class HalDevice 
: public ApiRefCounted<HalDevice, iree_hal_device_t> { + public: + iree_hal_allocator_t* allocator() { + return iree_hal_device_allocator(raw_ptr()); + } +}; + +class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> { + public: + static std::vector<std::string> Query(); + static HalDriver Create(const std::string& driver_name); + + HalDevice CreateDefaultDevice(); +}; + +class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> { + public: + py::dict QueryStatistics(); + py::str FormattedStatistics(); + + py::object AllocateBufferCopy( + int memory_type, int allowed_usage, py::object buffer, + std::optional<iree_hal_element_types_t> element_type); +}; + +struct HalShape { + public: + static HalShape FromIntVector(std::vector<int32_t> indices) { + HalShape s; + s.s = {indices.begin(), indices.end()}; + return s; + } + + std::vector<int32_t> s; +}; + +class HalBufferView + : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> { + public: + py::str Repr(); +}; + +class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> { + public: + iree_device_size_t byte_length() const { + return iree_hal_buffer_byte_length(raw_ptr()); + } + + void FillZero(iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + CheckApiStatus( + iree_hal_buffer_map_zero(raw_ptr(), byte_offset, byte_length), + "Error zero filling buffer"); + } + + // TODO(laurenzo): make this take element_type instead. 
+ HalBufferView CreateView(HalShape& shape, size_t element_size) { + iree_hal_buffer_view_t* bv; + iree_hal_element_type_t element_type = iree_hal_make_element_type( + IREE_HAL_ELEMENT_TYPE_NONE, element_size * 8); + iree_hal_encoding_type_t encoding_type = + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; + CheckApiStatus(iree_hal_buffer_view_create( + raw_ptr(), shape.s.data(), shape.s.size(), element_type, + encoding_type, iree_allocator_system(), &bv), + "Error creating buffer view"); + return HalBufferView::StealFromRawPtr(bv); + } + + py::str Repr(); +}; + +// Wrapper around an iree_hal_buffer_mapping_t and iree_hal_buffer_view_t +// which retains the latter and unmaps/releases on deallocation. +class HalMappedMemory { + public: + HalMappedMemory(iree_hal_buffer_mapping_t mapped_memory, + iree_hal_buffer_view_t* bv) + : mapped_memory_(mapped_memory), bv_(bv) { + iree_hal_buffer_view_retain(bv_); + } + ~HalMappedMemory() { + if (bv_) { + iree_hal_buffer_unmap_range(&mapped_memory_); + iree_hal_buffer_view_release(bv_); + } + } + HalMappedMemory(HalMappedMemory&& other) + : mapped_memory_(other.mapped_memory_), bv_(other.bv_) { + other.bv_ = nullptr; + } + + static HalMappedMemory Create(HalBufferView& bv) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr()); + iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); + iree_hal_buffer_mapping_t mapped_memory = {{0}}; + CheckApiStatus( + iree_hal_buffer_map_range(buffer, IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0, byte_length, + &mapped_memory), + "Could not map memory"); + return HalMappedMemory(mapped_memory, bv.raw_ptr()); + } + + py::buffer_info ToBufferInfo() { + std::vector<int32_t> shape(iree_hal_buffer_view_shape_rank(bv_)); + CheckApiStatus( + iree_hal_buffer_view_shape(bv_, shape.size(), shape.data(), nullptr), + "Error getting buffer view shape"); + iree_hal_element_type_t element_type = + iree_hal_buffer_view_element_type(bv_); + int32_t element_size = 
iree_hal_element_dense_byte_count(element_type); + std::vector<py::ssize_t> dims(shape.size()); + for (int i = 0; i < shape.size(); ++i) { + dims[i] = shape[i]; + } + std::vector<py::ssize_t> strides(shape.size()); + if (!strides.empty()) { + strides[shape.size() - 1] = element_size; + for (int i = shape.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + } + + return py::buffer_info(mapped_memory_.contents.data, element_size, + py::format_descriptor<float>::format(), shape.size(), + dims, strides); + } + + iree_hal_buffer_mapping_t& mapped_memory() { return mapped_memory_; } + + private: + iree_hal_buffer_mapping_t mapped_memory_ = {{0}}; + iree_hal_buffer_view_t* bv_ = nullptr; +}; + +void SetupHalBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
diff --git a/runtime/bindings/python/iree/runtime/hal_test.py b/runtime/bindings/python/iree/runtime/hal_test.py new file mode 100644 index 0000000..f9f27e9 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/hal_test.py
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import iree.runtime
import numpy as np
import unittest


class NonDeviceHalTest(unittest.TestCase):
  """Exercises HAL enum bindings without creating a driver or device."""

  def testEnums(self):
    """The bitwise or/and of enum members equals the same op on their ints."""
    mem_type = iree.runtime.MemoryType
    print("MemoryType:", mem_type)
    print("HOST_VISIBLE:", int(mem_type.HOST_VISIBLE))

    # Enum and/or operations on BufferCompatibility.
    compat = iree.runtime.BufferCompatibility
    self.assertEqual(compat.IMPORTABLE | compat.EXPORTABLE,
                     int(compat.IMPORTABLE) | int(compat.EXPORTABLE))
    self.assertEqual(compat.EXPORTABLE & compat.EXPORTABLE,
                     int(compat.EXPORTABLE))

    # Enum and/or operations on BufferUsage.
    usage = iree.runtime.BufferUsage
    self.assertEqual(usage.CONSTANT | usage.TRANSFER,
                     int(usage.CONSTANT) | int(usage.TRANSFER))
    self.assertEqual(usage.CONSTANT & usage.CONSTANT, int(usage.CONSTANT))

    # Enum and/or operations on MemoryAccess.
    access = iree.runtime.MemoryAccess
    self.assertEqual(access.READ | access.WRITE,
                     int(access.READ) | int(access.WRITE))
    self.assertEqual(access.ALL & access.READ, int(access.READ))

    # Enum and/or operations on MemoryType.
    self.assertEqual(mem_type.TRANSIENT | mem_type.HOST_VISIBLE,
                     int(mem_type.TRANSIENT) | int(mem_type.HOST_VISIBLE))
    self.assertEqual(mem_type.TRANSIENT & mem_type.TRANSIENT,
                     int(mem_type.TRANSIENT))


class DeviceHalTest(unittest.TestCase):
  """Exercises the allocator of the default device of the "vmvx" driver."""

  def setUp(self):
    super().setUp()
    self.driver = iree.runtime.HalDriver.create("vmvx")
    self.device = self.driver.create_default_device()
    self.allocator = self.device.allocator

  def testTrim(self):
    # Just running is sufficient.
    self.allocator.trim()

  def testStatistics(self):
    """Both statistics properties are read; contents checked if enabled."""
    stats_dict = self.allocator.statistics
    stats_str = self.allocator.formatted_statistics
    if not self.allocator.has_statistics:
      return
    expected_keys = (
        "host_bytes_peak",
        "host_bytes_allocated",
        "host_bytes_freed",
        "device_bytes_peak",
        "device_bytes_allocated",
        "device_bytes_freed",
    )
    for key in expected_keys:
      self.assertIn(key, stats_dict)
    self.assertIn("HOST_LOCAL", stats_str)

  def testQueryCompatibility(self):
    """A DEVICE_LOCAL constant buffer reports all three compatibilities."""
    usage = iree.runtime.BufferUsage
    compat = self.allocator.query_compatibility(
        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
        allowed_usage=usage.CONSTANT,
        intended_usage=usage.CONSTANT | usage.TRANSFER,
        allocation_size=1024)
    print("COMPAT:", compat)
    flags = iree.runtime.BufferCompatibility
    self.assertTrue(bool(compat & int(flags.ALLOCATABLE)),
                    "should be allocatable")
    self.assertTrue(bool(compat & int(flags.IMPORTABLE)),
                    "should be importable")
    self.assertTrue(bool(compat & int(flags.EXPORTABLE)),
                    "should be exportable")

  def testAllocateBuffer(self):
    """Allocating an uninitialized buffer succeeds and can be printed."""
    buffer = self.allocator.allocate_buffer(
        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
        allocation_size=13)
    print("BUFFER:", buffer)

  def testAllocateBufferCopy(self):
    """Without an element type, the copy is returned as a HalBuffer."""
    ary = np.zeros([3, 4], dtype=np.int32) + 2
    buffer = self.allocator.allocate_buffer_copy(
        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
        buffer=ary)
    expected_repr = "<HalBuffer 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=CONSTANT|TRANSFER|MAPPING>"
    self.assertEqual(repr(buffer), expected_repr)

  def testAllocateBufferViewCopy(self):
    """With an element type, the copy is wrapped in a HalBufferView."""
    ary = np.zeros([3, 4], dtype=np.int32) + 2
    buffer = self.allocator.allocate_buffer_copy(
        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
        buffer=ary,
        element_type=iree.runtime.HalElementType.SINT_32)
    expected_repr = "<HalBufferView (3, 4), element_type=0x20000011, 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=CONSTANT|TRANSFER|MAPPING>"
    self.assertEqual(repr(buffer), expected_repr)


if __name__ == "__main__":
  unittest.main()
diff --git a/runtime/bindings/python/iree/runtime/initialize_module.cc b/runtime/bindings/python/iree/runtime/initialize_module.cc new file mode 100644 index 0000000..211e7dc --- /dev/null +++ b/runtime/bindings/python/iree/runtime/initialize_module.cc
@@ -0,0 +1,50 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./binding.h" +#include "./hal.h" +#include "./invoke.h" +#include "./status_utils.h" +#include "./vm.h" +#include "iree/base/internal/flags.h" +#include "iree/base/status_cc.h" +#include "iree/hal/drivers/init.h" + +namespace iree { +namespace python { + +PYBIND11_MODULE(binding, m) { + IREE_CHECK_OK(iree_hal_register_all_available_drivers( + iree_hal_driver_registry_default())); + + m.doc() = "IREE Binding Backend Helpers"; + SetupHalBindings(m); + SetupInvokeBindings(m); + SetupVmBindings(m); + + m.def("parse_flags", [](py::args py_flags) { + std::vector<std::string> alloced_flags; + alloced_flags.push_back("python"); + for (auto &py_flag : py_flags) { + alloced_flags.push_back(py::cast<std::string>(py_flag)); + } + + // Must build pointer vector after filling so pointers are stable. + std::vector<char *> flag_ptrs; + for (auto &alloced_flag : alloced_flags) { + flag_ptrs.push_back(const_cast<char *>(alloced_flag.c_str())); + } + + char **argv = &flag_ptrs[0]; + int argc = flag_ptrs.size(); + CheckApiStatus( + iree_flags_parse(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv), + "Error parsing flags"); + }); +} + +} // namespace python +} // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/invoke.cc b/runtime/bindings/python/iree/runtime/invoke.cc new file mode 100644 index 0000000..b0b9a59 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/invoke.cc
@@ -0,0 +1,713 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./invoke.h" + +#include "./hal.h" +#include "./vm.h" +#include "iree/base/api.h" +#include "iree/base/tracing.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/vm/api.h" + +namespace iree { +namespace python { + +namespace { + +class InvokeContext { + public: + InvokeContext(HalDevice &device) : device_(device) {} + + HalDevice &device() { return device_; } + HalAllocator allocator() { + // TODO: Unfortunate that we inc ref here but that is how our object model + // is set up. + return HalAllocator::BorrowFromRawPtr(device().allocator()); + } + + private: + HalDevice device_; +}; + +using PackCallback = + std::function<void(InvokeContext &, iree_vm_list_t *, py::handle)>; + +class InvokeStatics { + public: + ~InvokeStatics() { + for (auto it : py_type_to_pack_callbacks_) { + py::handle(it.first).dec_ref(); + } + } + + py::str kNamedTag = py::str("named"); + py::str kSlistTag = py::str("slist"); + py::str kStupleTag = py::str("stuple"); + py::str kSdictTag = py::str("sdict"); + + py::int_ kZero = py::int_(0); + py::int_ kOne = py::int_(1); + py::int_ kTwo = py::int_(2); + py::str kAsArray = py::str("asarray"); + py::str kMapDtypeToElementTypeAttr = py::str("map_dtype_to_element_type"); + py::str kContiguousArg = py::str("C"); + py::str kArrayProtocolAttr = py::str("__array__"); + py::str kDtypeAttr = py::str("dtype"); + + // Primitive type names. + py::str kF32 = py::str("f32"); + py::str kF64 = py::str("f64"); + py::str kI1 = py::str("i1"); + py::str kI8 = py::str("i8"); + py::str kI16 = py::str("i16"); + py::str kI32 = py::str("i32"); + py::str kI64 = py::str("i64"); + + // Compound types names. + py::str kNdarray = py::str("ndarray"); + + // Attribute names. 
+ py::str kAttrBufferView = py::str("_buffer_view"); + + // Module 'numpy'. + py::module &numpy_module() { return numpy_module_; } + + py::object &runtime_module() { + if (!runtime_module_) { + runtime_module_ = py::module::import("iree.runtime"); + } + return *runtime_module_; + } + + py::module &array_interop_module() { + if (!array_interop_module_) { + array_interop_module_ = py::module::import("iree.runtime.array_interop"); + } + return *array_interop_module_; + } + + py::object &device_array_type() { + if (!device_array_type_) { + device_array_type_ = runtime_module().attr("DeviceArray"); + } + return *device_array_type_; + } + + py::type &hal_buffer_view_type() { return hal_buffer_view_type_; } + + py::object MapElementAbiTypeToDtype(py::object &element_abi_type) { + try { + return abi_type_to_dtype_[element_abi_type]; + } catch (std::exception &) { + std::string msg("could not map abi type "); + msg.append(py::cast<std::string>(py::repr(element_abi_type))); + msg.append(" to numpy dtype"); + throw std::invalid_argument(std::move(msg)); + } + } + + enum iree_hal_element_types_t MapDtypeToElementType(py::object dtype) { + // TODO: Consider porting this from a py func to C++ as it can be on + // the critical path. + try { + py::object element_type = + array_interop_module().attr(kMapDtypeToElementTypeAttr)(dtype); + if (element_type.is_none()) { + throw std::invalid_argument("mapping not found"); + } + return py::cast<enum iree_hal_element_types_t>(element_type); + } catch (std::exception &e) { + std::string msg("could not map dtype "); + msg.append(py::cast<std::string>(py::repr(dtype))); + msg.append(" to element type: "); + msg.append(e.what()); + throw std::invalid_argument(std::move(msg)); + } + } + + PackCallback AbiTypeToPackCallback(py::handle desc) { + return AbiTypeToPackCallback( + std::move(desc), /*desc_is_list=*/py::isinstance<py::list>(desc)); + } + + // Given an ABI desc, return a callback that can pack a corresponding py + // value into a list. 
For efficiency, the caller must specify whether the + // desc is a list (this check already needs to be done typically so + // passed in). + PackCallback AbiTypeToPackCallback(py::handle desc, bool desc_is_list) { + // Switch based on descriptor type. + if (desc_is_list) { + // Compound type. + py::object compound_type = desc[kZero]; + if (compound_type.equal(kNdarray)) { + // Has format: + // ["ndarray", "f32", dim0, dim1, ...] + // Extract static information about the target. + std::vector<int64_t> abi_shape(py::len(desc) - 2); + for (size_t i = 0, e = abi_shape.size(); i < e; ++i) { + py::handle dim = desc[py::int_(i + 2)]; + abi_shape[i] = dim.is_none() ? -1 : py::cast<int64_t>(dim); + } + + // Map abi element type to dtype. + py::object abi_type = desc[kOne]; + py::object target_dtype = MapElementAbiTypeToDtype(abi_type); + auto hal_element_type = MapDtypeToElementType(target_dtype); + + return [this, target_dtype = std::move(target_dtype), hal_element_type, + abi_shape = std::move(abi_shape)](InvokeContext &c, + iree_vm_list_t *list, + py::handle py_value) { + IREE_TRACE_SCOPE0("ArgumentPacker::ReflectionNdarray"); + HalBufferView *bv = nullptr; + py::object retained_bv; + if (py::isinstance(py_value, device_array_type())) { + // Short-circuit: If a DeviceArray is provided, assume it is + // correct. + IREE_TRACE_SCOPE0("PackDeviceArray"); + bv = py::cast<HalBufferView *>(py_value.attr(kAttrBufferView)); + } else if (py::isinstance(py_value, hal_buffer_view_type())) { + // Short-circuit: If a HalBufferView is provided directly. + IREE_TRACE_SCOPE0("PackBufferView"); + bv = py::cast<HalBufferView *>(py_value); + } else { + // Fall back to the array protocol to generate a host side + // array and then convert that. 
+ IREE_TRACE_SCOPE0("PackHostArray"); + py::object host_array; + try { + host_array = numpy_module().attr(kAsArray)(py_value, target_dtype, + kContiguousArg); + } catch (std::exception &e) { + std::string msg("could not convert value to numpy array: dtype="); + msg.append(py::cast<std::string>(py::repr(target_dtype))); + msg.append(", error='"); + msg.append(e.what()); + msg.append("', value="); + msg.append(py::cast<std::string>(py::repr(py_value))); + throw std::invalid_argument(std::move(msg)); + } + + retained_bv = c.allocator().AllocateBufferCopy( + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING, + host_array, hal_element_type); + bv = py::cast<HalBufferView *>(retained_bv); + } + + // TODO: Add some shape verification. Not strictly necessary as the VM + // will check, but may make error reporting nicer. + // TODO: It is theoretically possible to enqueue further conversions + // on the device, but for now we require things to line up closely. + // TODO: If adding further manipulation here, please make this common + // with the generic access case. + iree_vm_ref_t buffer_view_ref = + iree_hal_buffer_view_retain_ref(bv->raw_ptr()); + CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), + "could not push buffer view to list"); + }; + } else if (compound_type.equal(kSlistTag) || + compound_type.equal(kStupleTag)) { + // Tuple/list extraction. + // When decoding a list or tuple, the desc object is like: + // ['slist', [...value_type_0...], ...] + // Where the type is either 'slist' or 'stuple'. 
+ std::vector<PackCallback> sub_packers(py::len(desc) - 1); + for (size_t i = 0; i < sub_packers.size(); i++) { + sub_packers[i] = AbiTypeToPackCallback(desc[py::int_(i + 1)]); + } + return [sub_packers = std::move(sub_packers)](InvokeContext &c, + iree_vm_list_t *list, + py::handle py_value) { + if (py::len(py_value) != sub_packers.size()) { + std::string msg("expected a sequence with "); + msg.append(std::to_string(sub_packers.size())); + msg.append(" values. got: "); + msg.append(py::cast<std::string>(py::repr(py_value))); + throw std::invalid_argument(std::move(msg)); + } + VmVariantList item_list = VmVariantList::Create(sub_packers.size()); + for (size_t i = 0; i < sub_packers.size(); ++i) { + py::object item_py_value; + try { + item_py_value = py_value[py::int_(i)]; + } catch (std::exception &e) { + std::string msg("could not get item "); + msg.append(std::to_string(i)); + msg.append(" from: "); + msg.append(py::cast<std::string>(py::repr(py_value))); + msg.append(": "); + msg.append(e.what()); + throw std::invalid_argument(std::move(msg)); + } + sub_packers[i](c, item_list.raw_ptr(), item_py_value); + } + + // Push the sub list. + iree_vm_ref_t retained = + iree_vm_list_retain_ref(item_list.steal_raw_ptr()); + iree_vm_list_push_ref_move(list, &retained); + }; + } else if (compound_type.equal(kSdictTag)) { + // Dict extraction. + // The descriptor for an sdict is like: + // ['sdict', ['key1', value1], ...] 
+ std::vector<std::pair<py::object, PackCallback>> sub_packers( + py::len(desc) - 1); + for (size_t i = 0; i < sub_packers.size(); i++) { + py::object sub_desc = desc[py::int_(i + 1)]; + py::object key = sub_desc[kZero]; + py::object value_desc = sub_desc[kOne]; + sub_packers[i] = + std::make_pair(std::move(key), AbiTypeToPackCallback(value_desc)); + } + return [sub_packers = std::move(sub_packers)](InvokeContext &c, + iree_vm_list_t *list, + py::handle py_value) { + if (py::len(py_value) != sub_packers.size()) { + std::string msg("expected a dict with "); + msg.append(std::to_string(sub_packers.size())); + msg.append(" values. got: "); + msg.append(py::cast<std::string>(py::repr(py_value))); + throw std::invalid_argument(std::move(msg)); + } + VmVariantList item_list = VmVariantList::Create(sub_packers.size()); + for (size_t i = 0; i < sub_packers.size(); ++i) { + py::object item_py_value; + try { + item_py_value = py_value[sub_packers[i].first]; + } catch (std::exception &e) { + std::string msg("could not get item "); + msg.append(py::cast<std::string>(py::repr(sub_packers[i].first))); + msg.append(" from: "); + msg.append(py::cast<std::string>(py::repr(py_value))); + msg.append(": "); + msg.append(e.what()); + throw std::invalid_argument(std::move(msg)); + } + sub_packers[i].second(c, item_list.raw_ptr(), item_py_value); + } + + // Push the sub list. + iree_vm_ref_t retained = + iree_vm_list_retain_ref(item_list.steal_raw_ptr()); + iree_vm_list_push_ref_move(list, &retained); + }; + } else { + std::string message("Unrecognized reflection compound type: "); + message.append(py::cast<std::string>(compound_type)); + throw std::invalid_argument(message); + } + } else { + // Primtive type. 
+ py::str prim_type = py::cast<py::str>(desc); + if (prim_type.equal(kF32)) { + // f32 + return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_f32(py::cast<float>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else if (prim_type.equal(kF64)) { + // f64 + return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_f64(py::cast<double>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else if (prim_type.equal(kI32)) { + // i32. + return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_i32(py::cast<int32_t>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else if (prim_type.equal(kI64)) { + // i64. + return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_i64(py::cast<int64_t>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else if (prim_type.equal(kI8)) { + // i8. + return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_i8(py::cast<int8_t>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else if (prim_type.equal(kI16)) { + // i16. 
+ return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_i16(py::cast<int16_t>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }; + } else { + std::string message("Unrecognized reflection primitive type: "); + message.append(py::cast<std::string>(prim_type)); + throw std::invalid_argument(message); + } + } + } + + PackCallback GetGenericPackCallbackFor(py::handle arg) { + PopulatePyTypeToPackCallbacks(); + py::type clazz = py::type::of(arg); + auto found_it = py_type_to_pack_callbacks_.find(clazz.ptr()); + if (found_it == py_type_to_pack_callbacks_.end()) { + // Probe to see if we have a host array. + if (py::hasattr(arg, kArrayProtocolAttr)) { + return GetGenericPackCallbackForNdarray(); + } + return {}; + } + + return found_it->second; + } + + private: + PackCallback GetGenericPackCallbackForNdarray() { + return [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + IREE_TRACE_SCOPE0("ArgumentPacker::GenericNdarray"); + py::object host_array; + try { + host_array = numpy_module().attr(kAsArray)( + py_value, /*dtype=*/py::none(), kContiguousArg); + } catch (std::exception &e) { + std::string msg("could not convert value to numpy array: "); + msg.append("error='"); + msg.append(e.what()); + msg.append("', value="); + msg.append(py::cast<std::string>(py::repr(py_value))); + throw std::invalid_argument(std::move(msg)); + } + + auto hal_element_type = + MapDtypeToElementType(host_array.attr(kDtypeAttr)); + + // Put it on the device. + py::object retained_bv = c.allocator().AllocateBufferCopy( + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING, + host_array, hal_element_type); + HalBufferView *bv = py::cast<HalBufferView *>(retained_bv); + + // TODO: If adding further manipulation here, please make this common + // with the reflection access case. 
+ iree_vm_ref_t buffer_view_ref = + iree_hal_buffer_view_retain_ref(bv->raw_ptr()); + CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), + "could not append value"); + }; + } + + void PopulatePyTypeToPackCallbacks() { + if (!py_type_to_pack_callbacks_.empty()) return; + + // We only care about int and double in the numeric hierarchy. Since Python + // has no further refinement of these, just treat them as vm 64 bit int and + // floats and let the VM take care of it. There isn't much else we can do. + AddPackCallback( + py::type::of(py::cast(1)), + [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_i64(py::cast<int64_t>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }); + + AddPackCallback( + py::type::of(py::cast(1.0)), + [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + iree_vm_value_t vm_value = + iree_vm_value_make_f64(py::cast<double>(py_value)); + CheckApiStatus(iree_vm_list_push_value(list, &vm_value), + "could not append value"); + }); + + // List/tuple. + auto sequence_callback = [this](InvokeContext &c, iree_vm_list_t *list, + py::handle py_value) { + auto py_seq = py::cast<py::sequence>(py_value); + VmVariantList item_list = VmVariantList::Create(py::len(py_seq)); + for (py::object py_item : py_seq) { + PackCallback sub_packer = GetGenericPackCallbackFor(py_item); + if (!sub_packer) { + std::string message("could not convert python value to VM: "); + message.append(py::cast<std::string>(py::repr(py_item))); + throw std::invalid_argument(std::move(message)); + } + sub_packer(c, item_list.raw_ptr(), py_item); + } + // Push the sub list. + iree_vm_ref_t retained = + iree_vm_list_retain_ref(item_list.steal_raw_ptr()); + iree_vm_list_push_ref_move(list, &retained); + }; + AddPackCallback(py::type::of(py::list{}), sequence_callback); + AddPackCallback(py::type::of(py::tuple{}), sequence_callback); + + // Dict. 
+ auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list, + py::handle py_value) { + // Gets all dict items and sorts (by key). + auto py_dict = py::cast<py::dict>(py_value); + py::list py_keys; + for (std::pair<py::handle, py::handle> it : py_dict) { + py_keys.append(it.first); + } + py_keys.attr("sort")(); + + VmVariantList item_list = VmVariantList::Create(py_keys.size()); + for (auto py_key : py_keys) { + py::object py_item = py_dict[py_key]; + PackCallback sub_packer = GetGenericPackCallbackFor(py_item); + if (!sub_packer) { + std::string message("could not convert python value to VM: "); + message.append(py::cast<std::string>(py::repr(py_item))); + throw std::invalid_argument(std::move(message)); + } + sub_packer(c, item_list.raw_ptr(), py_item); + } + // Push the sub list. + iree_vm_ref_t retained = + iree_vm_list_retain_ref(item_list.steal_raw_ptr()); + iree_vm_list_push_ref_move(list, &retained); + }; + AddPackCallback(py::type::of(py::dict{}), dict_callback); + + // HalBufferView. + AddPackCallback( + py::type::of<HalBufferView>(), + [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + HalBufferView *bv = py::cast<HalBufferView *>(py_value); + iree_vm_ref_t buffer_view_ref = + iree_hal_buffer_view_retain_ref(bv->raw_ptr()); + CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), + "could not append value"); + }); + + // DeviceArray. 
+ AddPackCallback( + device_array_type(), + [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { + HalBufferView *bv = + py::cast<HalBufferView *>(py_value.attr(kAttrBufferView)); + iree_vm_ref_t buffer_view_ref = + iree_hal_buffer_view_retain_ref(bv->raw_ptr()); + CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), + "could not append value"); + }); + } + + void AddPackCallback(py::handle t, PackCallback pcb) { + assert(py_type_to_pack_callbacks_.count(t.ptr()) == 0 && "duplicate types"); + t.inc_ref(); + py_type_to_pack_callbacks_.insert(std::make_pair(t.ptr(), std::move(pcb))); + } + + py::dict BuildAbiTypeToDtype() { + auto d = py::dict(); + d[kF32] = numpy_module().attr("float32"); + d[kF64] = numpy_module().attr("float64"); + d[kI1] = numpy_module().attr("bool_"); + d[kI8] = numpy_module().attr("int8"); + d[kI16] = numpy_module().attr("int16"); + d[kI64] = numpy_module().attr("int64"); + d[kI32] = numpy_module().attr("int32"); + return d; + } + + // Cached modules and types. Those that involve recursive lookup within + // our top level module, we defer. Those outside, we cache at creation. + py::module numpy_module_ = py::module::import("numpy"); + std::optional<py::object> runtime_module_; + std::optional<py::module> array_interop_module_; + std::optional<py::object> device_array_type_; + py::type hal_buffer_view_type_ = py::type::of<HalBufferView>(); + + // Maps Python type to a PackCallback that can generically code it. + // This will have inc_ref() called on them when added. + std::unordered_map<PyObject *, PackCallback> py_type_to_pack_callbacks_; + + // Dict of str (ABI dtype like 'f32') to numpy dtype. + py::dict abi_type_to_dtype_ = BuildAbiTypeToDtype(); +}; + +/// Object that can pack Python arguments into a VM List for a specific +/// function. 
class ArgumentPacker {
 public:
  // Builds the packer for one function.
  //   statics:   shared caches and callback factories; stored by reference.
  //   arg_descs: reflection descriptors, one per argument. When absent the
  //              packer switches to fully dynamic (type based) dispatch.
  ArgumentPacker(InvokeStatics &statics, std::optional<py::list> arg_descs)
      : statics_(statics) {
    IREE_TRACE_SCOPE0("ArgumentPacker::Init");
    if (!arg_descs) {
      dynamic_dispatch_ = true;
    } else {
      // Reflection dispatch.
      for (py::handle desc : *arg_descs) {
        // Position of this argument in the flattened argument list.
        int arg_index = flat_arg_packers_.size();
        std::optional<std::string> kwarg_name;
        // Owns the unwrapped sub descriptor while |desc| refers to it.
        py::object retained_sub_desc;

        bool desc_is_list = py::isinstance<py::list>(desc);

        // Check if named.
        //   ["named", "kwarg_name", sub_desc]
        // If found, then we set kwarg_name and reset desc to the sub_desc.
        if (desc_is_list) {
          py::object maybe_named_field = desc[statics.kZero];
          if (maybe_named_field.equal(statics.kNamedTag)) {
            py::object name_field = desc[statics.kOne];
            retained_sub_desc = desc[statics.kTwo];
            kwarg_name = py::cast<std::string>(name_field);
            desc = retained_sub_desc;
            desc_is_list = py::isinstance<py::list>(desc);

            kwarg_to_index_[name_field] = arg_index;
          }
        }

        // Arguments without a "named" wrapper count as positional-only.
        if (!kwarg_name) {
          pos_only_arg_count_ += 1;
        }

        flat_arg_packers_.push_back(
            statics.AbiTypeToPackCallback(desc, desc_is_list));
      }
    }
  }

  /// Packs positional/kw arguments into a suitable VmVariantList and returns
  /// it.
  /// Throws std::invalid_argument when kwargs are given in dynamic dispatch
  /// mode, when a value cannot be converted, when more positional arguments
  /// are given than there are unnamed parameters, when a kwarg is unknown or
  /// targets an already filled slot, or when a parameter is left without a
  /// value.
  VmVariantList Pack(InvokeContext &invoke_context, py::sequence pos_args,
                     py::dict kw_args) {
    // Dynamic dispatch.
    if (dynamic_dispatch_) {
      IREE_TRACE_SCOPE0("ArgumentPacker::PackDynamic");
      if (!kw_args.empty()) {
        throw std::invalid_argument(
            "kwargs not supported for dynamic dispatch functions");
      }

      VmVariantList arg_list = VmVariantList::Create(pos_args.size());
      for (py::handle py_arg : pos_args) {
        // An empty callback means no packer exists for this value's type.
        PackCallback packer = statics_.GetGenericPackCallbackFor(py_arg);
        if (!packer) {
          std::string message("could not convert python value to VM: ");
          message.append(py::cast<std::string>(py::repr(py_arg)));
          throw std::invalid_argument(std::move(message));
        }
        // TODO: Better error handling by catching the exception and
        // reporting which arg has a problem.
        packer(invoke_context, arg_list.raw_ptr(), py_arg);
      }
      return arg_list;
    } else {
      IREE_TRACE_SCOPE0("ArgumentPacker::PackReflection");

      // Reflection based dispatch.
      // One slot per declared parameter; a default-constructed (null) handle
      // marks a slot that has not been supplied yet.
      std::vector<py::handle> py_args(flat_arg_packers_.size());

      if (pos_args.size() > pos_only_arg_count_) {
        std::string message("mismatched call arity: expected ");
        message.append(std::to_string(pos_only_arg_count_));
        message.append(" got ");
        message.append(std::to_string(pos_args.size()));
        throw std::invalid_argument(std::move(message));
      }

      // Positional args.
      size_t pos_index = 0;
      for (py::handle py_arg : pos_args) {
        py_args[pos_index++] = py_arg;
      }

      // Keyword args.
      for (auto it : kw_args) {
        int found_index;
        try {
          // A failed lookup or conversion surfaces as an exception here and
          // is reported as an unknown kwarg.
          found_index = py::cast<int>(kwarg_to_index_[it.first]);
        } catch (std::exception &) {
          std::string message("specified kwarg '");
          message.append(py::cast<py::str>(it.first));
          message.append("' is unknown");
          throw std::invalid_argument(std::move(message));
        }
        if (py_args[found_index]) {
          std::string message(
              "mismatched call arity: duplicate keyword argument '");
          message.append(py::cast<py::str>(it.first));
          message.append("'");
          throw std::invalid_argument(std::move(message));
        }
        py_args[found_index] = it.second;
      }

      // Now check to see that all args are set.
      for (size_t i = 0; i < py_args.size(); ++i) {
        if (!py_args[i]) {
          std::string message(
              "mismatched call arity: expected a value for argument ");
          message.append(std::to_string(i));
          throw std::invalid_argument(std::move(message));
        }
      }

      // Start packing into the list.
      VmVariantList arg_list = VmVariantList::Create(flat_arg_packers_.size());
      for (size_t i = 0; i < py_args.size(); ++i) {
        // TODO: Better error handling by catching the exception and
        // reporting which arg has a problem.
        flat_arg_packers_[i](invoke_context, arg_list.raw_ptr(), py_args[i]);
      }
      return arg_list;
    }
  }

 private:
  InvokeStatics &statics_;

  // Number of parameters declared without a "named" wrapper.
  int pos_only_arg_count_ = 0;

  // Dictionary of py::str -> py::int_ mapping kwarg names to position in
  // the argument list. We store this as a py::dict because it is optimized
  // for py::str lookup.
  py::dict kwarg_to_index_;

  // One pack callback per declared parameter, in declaration order.
  std::vector<PackCallback> flat_arg_packers_;

  // If true, then there is no dispatch metadata and we process fully
  // dynamically.
  bool dynamic_dispatch_ = false;
};

}  // namespace

// Registers the `_InvokeStatics`, `InvokeContext` and `ArgumentPacker` classes
// on |m| and stores one InvokeStatics instance as `m._invoke_statics`.
void SetupInvokeBindings(pybind11::module &m) {
  py::class_<InvokeStatics>(m, "_InvokeStatics");
  py::class_<InvokeContext>(m, "InvokeContext").def(py::init<HalDevice &>());
  py::class_<ArgumentPacker>(m, "ArgumentPacker")
      .def(py::init<InvokeStatics &, std::optional<py::list>>())
      .def("pack", &ArgumentPacker::Pack);

  m.attr("_invoke_statics") = py::cast(InvokeStatics());
}

}  // namespace python
}  // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/invoke.h b/runtime/bindings/python/iree/runtime/invoke.h new file mode 100644 index 0000000..206524f --- /dev/null +++ b/runtime/bindings/python/iree/runtime/invoke.h
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
#define IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_

#include "./binding.h"

namespace iree {
namespace python {

// Installs the function invocation helpers (argument packing support) into
// the given pybind11 module |m|.
void SetupInvokeBindings(pybind11::module &m);

}  // namespace python
}  // namespace iree

#endif  // IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py new file mode 100644 index 0000000..007e96e --- /dev/null +++ b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import subprocess
import sys


def main(args=None):
  """Runs the `iree-benchmark-trace` binary and returns its exit code.

  Args:
    args: Command line arguments to forward to the tool. Any iterable of
      strings is accepted (list, tuple, ...). Defaults to `sys.argv[1:]`.

  Returns:
    The value returned by `subprocess.call` for the tool invocation.
  """
  if args is None:
    args = sys.argv[1:]
  # The binary is looked up two directories above this script's directory.
  exe = os.path.join(os.path.dirname(__file__), "..", "..",
                     "iree-benchmark-trace")
  # Unpack rather than concatenate so that non-list sequences such as tuples
  # do not raise TypeError.
  return subprocess.call(args=[exe, *args])


if __name__ == "__main__":
  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py new file mode 100644 index 0000000..a5509a3 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import subprocess
import sys


def main(args=None):
  """Runs the `iree-run-module` binary and returns its exit code.

  Args:
    args: Command line arguments to forward to the tool. Any iterable of
      strings is accepted (list, tuple, ...). Defaults to `sys.argv[1:]`.

  Returns:
    The value returned by `subprocess.call` for the tool invocation.
  """
  if args is None:
    args = sys.argv[1:]
  # The binary is looked up two directories above this script's directory.
  exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-module")
  # Unpack rather than concatenate so that non-list sequences such as tuples
  # do not raise TypeError.
  return subprocess.call(args=[exe, *args])


if __name__ == "__main__":
  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py new file mode 100644 index 0000000..08dced3 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import subprocess
import sys


def main(args=None):
  """Runs the `iree-run-trace` binary and returns its exit code.

  Args:
    args: Command line arguments to forward to the tool. Any iterable of
      strings is accepted (list, tuple, ...). Defaults to `sys.argv[1:]`.

  Returns:
    The value returned by `subprocess.call` for the tool invocation.
  """
  if args is None:
    args = sys.argv[1:]
  # The binary is looked up two directories above this script's directory.
  exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-trace")
  # Unpack rather than concatenate so that non-list sequences such as tuples
  # do not raise TypeError.
  return subprocess.call(args=[exe, *args])


if __name__ == "__main__":
  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py new file mode 100644 index 0000000..58f2118 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import subprocess
import sys


def main(args=None):
  """Runs the `iree-tracy-capture` binary and returns its exit code.

  Args:
    args: Command line arguments to forward to the tool. Any iterable of
      strings is accepted (list, tuple, ...). Defaults to `sys.argv[1:]`.

  Returns:
    The value returned by `subprocess.call` for the tool invocation.
  """
  if args is None:
    args = sys.argv[1:]
  # The binary is looked up two directories above this script's directory.
  exe = os.path.join(os.path.dirname(__file__), "..", "..",
                     "iree-tracy-capture")
  # Unpack rather than concatenate so that non-list sequences such as tuples
  # do not raise TypeError.
  return subprocess.call(args=[exe, *args])


if __name__ == "__main__":
  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/status_utils.cc b/runtime/bindings/python/iree/runtime/status_utils.cc new file mode 100644 index 0000000..a05bfd5 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/status_utils.cc
@@ -0,0 +1,69 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./status_utils.h" + +namespace iree { +namespace python { + +namespace { + +PyObject* ApiStatusToPyExcClass(iree_status_t status) { + switch (iree_status_code(status)) { + case IREE_STATUS_INVALID_ARGUMENT: + return PyExc_ValueError; + case IREE_STATUS_OUT_OF_RANGE: + return PyExc_IndexError; + case IREE_STATUS_UNIMPLEMENTED: + return PyExc_NotImplementedError; + default: + return PyExc_RuntimeError; + } +} + +static std::string ApiStatusToString(iree_status_t status) { + iree_host_size_t buffer_length = 0; + if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0, + /*buffer=*/NULL, &buffer_length))) { + return ""; + } + std::string result; + result.resize(buffer_length); + // NOTE: buffer capacity needs to be +1 for the NUL terminator in snprintf. + return iree_status_format(status, result.size() + 1, + const_cast<char*>(result.data()), &buffer_length) + ? result + : ""; +} + +} // namespace + +pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, + const char* message) { + assert(!iree_status_is_ok(status)); + std::string full_message; + + auto status_str = ApiStatusToString(status); + if (status_str.empty()) { + full_message = std::string(message) + ": " + + iree_status_code_string(iree_status_code(status)); + } else { + full_message = std::string(message) + ": " + status_str; + } + + PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str()); + iree_status_ignore(status); + return pybind11::error_already_set(); +} + +pybind11::error_already_set RaisePyError(PyObject* exc_class, + const char* message) { + PyErr_SetString(exc_class, message); + return pybind11::error_already_set(); +} + +} // namespace python +} // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/status_utils.h b/runtime/bindings/python/iree/runtime/status_utils.h new file mode 100644 index 0000000..d87d308 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/status_utils.h
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Helpers for turning IREE statuses and simple failure conditions into Python
// exceptions from pybind11 binding code.

#ifndef IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
#define IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_

#include "iree/base/api.h"
#include "pybind11/pybind11.h"

namespace iree {
namespace python {

// Sets a Python error of class |exc_class| with the given message and returns
// the pybind11 exception object. The function does not throw by itself: the
// caller is expected to throw the returned value.
// Correct usage:
//   throw RaisePyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set RaisePyError(PyObject* exc_class,
                                         const char* message);

// Raises a value error with the given message.
// Correct usage:
//   throw RaiseValueError("Foobar'd");
inline pybind11::error_already_set RaiseValueError(const char* message) {
  return RaisePyError(PyExc_ValueError, message);
}

// Converts a failing |status| into a Python error whose text starts with
// |message| and returns the exception object for the caller to throw.
// |status| must not be OK (this is asserted in the implementation), and the
// implementation calls iree_status_ignore on it, so the caller should not
// use |status| afterwards.
// Correct usage:
//   throw ApiStatusToPyExc(status, "Foobar'd");
pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
                                             const char* message);

// Returns normally if |status| is OK; otherwise throws the Python exception
// produced by ApiStatusToPyExc(status, message).
inline void CheckApiStatus(iree_status_t status, const char* message) {
  if (iree_status_is_ok(status)) {
    return;
  }
  throw ApiStatusToPyExc(status, message);
}

// Throws a Python ValueError carrying |message| if |p| is null.
inline void CheckApiNotNull(const void* p, const char* message) {
  if (!p) {
    throw RaiseValueError(message);
  }
}

}  // namespace python
}  // namespace iree

#endif  // IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py new file mode 100644 index 0000000..fc82a3b --- /dev/null +++ b/runtime/bindings/python/iree/runtime/system_api.py
# Lint as: python3
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Top-level python system API.

This facility layers on top of the underlying binding native facilities and
exposes them in a way that allows general operation against contexts, modules
and functions.
"""

# pylint: disable=protected-access
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test

# TODO(#4131) python>=3.7: Use postponed type annotations.

__all__ = [
    "load_vm_flatbuffer",
    "load_vm_flatbuffer_file",
    "load_vm_module",
    "load_vm_modules",
    "normalize_value",
    "Config",
    "SystemContext",
    "TARGET_BACKEND_TO_DRIVER",
]

import logging
import os
import sys

from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

from . import binding as _binding
from .function import FunctionInvoker
from . import tracing

import numpy as np

# Environment key for a comma-delimited list of drivers to try to load.
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"

# Default value for IREE_DEFAULT_DRIVER.
DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmvx"

# Mapping from IREE target backends to their corresponding drivers.
TARGET_BACKEND_TO_DRIVER = {
    "dylib-llvm-aot": "dylib",
    "vmvx": "vmvx",
    "vulkan-spirv": "vulkan",
}


def _create_default_iree_driver(
    driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
  """Returns a default driver based on environment settings.

  Candidates are tried in order and the first one that is registered, can be
  created and can create a default device is returned.

  Args:
    driver_names: Driver names to try, in priority order. When None, the
      comma-separated list in the IREE_DEFAULT_DRIVER environment variable is
      used, falling back to DEFAULT_IREE_DRIVER_VALUE.

  Returns:
    The first driver that could be created.

  Raises:
    RuntimeError: If none of the requested drivers could be created.
  """
  # TODO(laurenzo): Ideally this should take a VmModule and join any explicitly
  # provided driver list with environmental constraints and what the module
  # was compiled for.
  if driver_names is None:
    # Read from environment.
    driver_names = os.environ.get(PREFERRED_DRIVER_ENV_KEY,
                                  DEFAULT_IREE_DRIVER_VALUE).split(",")
  available_driver_names = _binding.HalDriver.query()
  driver_exceptions = {}
  for driver_name in driver_names:
    if driver_name not in available_driver_names:
      logging.error("Could not create driver %s (not registered)", driver_name)
      continue
    try:
      driver = _binding.HalDriver.create(driver_name)
    except Exception as ex:  # pylint: disable=broad-except
      logging.exception("Could not create default driver %s", driver_name)
      driver_exceptions[driver_name] = ex
      continue

    # Sanity check creation of the default device and skip the driver if
    # this fails (this works around issues where the driver is present
    # but there are no devices). This default initialization scheme needs
    # to be improved.
    try:
      # The device is only probed here; the caller creates its own.
      driver.create_default_device()
    except Exception as ex:  # pylint: disable=broad-except
      logging.exception("Could not create default driver device %s",
                        driver_name)
      driver_exceptions[driver_name] = ex
      continue

    logging.debug("Created IREE driver %s: %r", driver_name, driver)
    return driver

  # All failed.
  raise RuntimeError(
      f"Could not create any requested driver {repr(driver_names)} (available="
      f"{repr(available_driver_names)}) : {repr(driver_exceptions)}")


class Config:
  """System configuration.

  Holds the VM instance, HAL driver and device, the default VM modules (the
  HAL module) and an optional tracer that contexts are created with.
  """

  driver: _binding.HalDriver
  device: _binding.HalDevice
  vm_instance: _binding.VmInstance
  default_vm_modules: Tuple[_binding.VmModule, ...]
  tracer: Optional[tracing.Tracer]

  def __init__(self,
               driver_name: Optional[str] = None,
               tracer: Optional[tracing.Tracer] = None):
    """Creates the configuration.

    Args:
      driver_name: Comma-separated list of driver names to try in order, or
        None to use the environment/default list.
      tracer: Explicit tracer. When omitted, the default tracer from the
        tracing module is used. A tracer that is not enabled is dropped.
    """
    self.vm_instance = _binding.VmInstance()
    self.driver = _create_default_iree_driver(
        driver_name.split(",") if driver_name is not None else None)
    self.device = self.driver.create_default_device()
    hal_module = _binding.create_hal_module(self.device)
    self.default_vm_modules = (hal_module,)
    self.tracer = tracer or tracing.get_default_tracer()
    if self.tracer and self.tracer.enabled:
      logging.info("IREE runtime tracing calls to path: %s",
                   self.tracer.trace_path)
    else:
      self.tracer = None


_global_config = None


def _get_global_config():
  """Returns the process-wide Config, creating it on first use."""
  global _global_config
  if _global_config is None:
    _global_config = Config()
  return _global_config


def _bool_to_int8(
    array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
  """Casts boolean ndarrays to int8; returns anything else unchanged."""
  if not isinstance(array, np.ndarray):
    return array

  # IREE models booleans as i8s.
  # TODO(#5359): This cast should be moved into the function abi.
  # NOTE: np.bool_ is used because the bare `np.bool` alias was deprecated and
  # later removed from NumPy.
  if array.dtype == np.bool_:
    array = array.astype(np.int8)
  return array


def normalize_value(
    value: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
  """Normalizes the given value for input to (or comparison with) IREE.

  None, lists, tuples and dicts are returned as-is. Everything else is
  converted with np.asarray; Python bool/int/float scalars that end up as
  64-bit floats or ints are narrowed to 32 bits.
  """
  if value is None:
    # Exclude None from falling through to blanket np.asarray conversion.
    return value

  if isinstance(value, (list, tuple, dict)):
    return value

  array = np.asarray(value)
  # TODO(#5359): Move into the function abi.
  if isinstance(value, (bool, int, float)):
    # Manually convert ints and floats to 32 bits.
    if array.dtype == np.float64:
      array = array.astype(np.float32)
    elif array.dtype == np.int64:
      array = array.astype(np.int32)

  return array


def _convert_lists_to_tuples(pytree):
  """Recursively converts sequences in a pytree to tuples.

  Mappings are updated in place (their values are converted) and returned.
  Strings and bytes are leaves: they are Sequences too, and a one-character
  string contains itself, so descending into them would never terminate.
  """
  if isinstance(pytree, (str, bytes)):
    return pytree
  if isinstance(pytree, Sequence):
    return tuple(_convert_lists_to_tuples(leaf) for leaf in pytree)
  if isinstance(pytree, Mapping):
    for key in pytree:
      pytree[key] = _convert_lists_to_tuples(pytree[key])
    return pytree
  return pytree


class BoundModule:
  """Wraps a VmModule with its context and provides nice python accessors.

  Resolves item access (["foo"]) as function resolution.
  """

  def __init__(self, context: "SystemContext", vm_module: _binding.VmModule):
    self._context = context
    self._tracer = self._context._config.tracer
    self._vm_module = vm_module
    # Cache of function name -> FunctionInvoker, filled by __getitem__.
    self._lazy_functions = dict()

    # Let the tracing infra create a traced module.
    self.traced_module = None
    if self._tracer:
      self.traced_module = self._tracer.persist_vm_module(vm_module)

  @property
  def name(self):
    return self._vm_module.name

  @property
  def vm_module(self):
    return self._vm_module

  def __getattr__(self, name):
    # Attribute access falls back to function lookup by name.
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __getitem__(self, name):
    """Returns the invoker for function `name`, creating it on first access.

    Raises:
      KeyError: If the module has no function with that name.
    """
    invoker = self._lazy_functions.get(name)
    if invoker is not None:
      return invoker

    vm_function = self._vm_module.lookup_function(name)
    if vm_function is None:
      raise KeyError(f"Function '{name}' not found in module '{self}'")

    # TODO: Needing to know the precise device to allocate on here is bad
    # layering and will need to be fixed in some fashion if/when doing
    # heterogenous dispatch.
    invoker = FunctionInvoker(self._context.vm_context,
                              self._context.config.device, vm_function,
                              self._context._tracer)
    # Remember the invoker so repeated lookups reuse it.
    self._lazy_functions[name] = invoker
    return invoker

  def __repr__(self):
    return f"<BoundModule {repr(self._vm_module)}>"


class BoundModules(dict):
  """Provides nice python accessors for a dict of BoundModules."""

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)


class SystemContext:
  """Global system.

  Owns a VmContext and the BoundModules registered with it. A context created
  without an explicit module list is "dynamic" and accepts more modules via
  add_vm_modules; otherwise its module set is fixed at construction.
  """

  def __init__(self, vm_modules=None, config: Optional[Config] = None):
    self._config = config if config is not None else _get_global_config()
    logging.debug("SystemContext driver=%r", self._config.driver)
    self._is_dynamic = vm_modules is None
    if self._is_dynamic:
      init_vm_modules = None
    else:
      init_vm_modules = self._config.default_vm_modules + tuple(vm_modules)

    self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
                                          modules=init_vm_modules)

    if self._is_dynamic:
      self._vm_context.register_modules(self._config.default_vm_modules)
      self._bound_modules = BoundModules([
          (m.name, BoundModule(self, m))
          for m in self._config.default_vm_modules
      ])
    else:
      self._bound_modules = BoundModules([
          (m.name, BoundModule(self, m)) for m in init_vm_modules
      ])

    self._tracer = None  # type: Optional[tracing.ContextTracer]
    if self._config.tracer:
      self._tracer = tracing.ContextTracer(
          self._config.tracer,
          is_dynamic=self._is_dynamic,
          modules=[bm.traced_module for bm in self._bound_modules.values()])

  @property
  def vm_context(self) -> _binding.VmContext:
    return self._vm_context

  @property
  def is_dynamic(self) -> bool:
    return self._is_dynamic

  @property
  def config(self) -> Config:
    return self._config

  @property
  def instance(self) -> _binding.VmInstance:
    # The VM instance lives on the config; this class never stores its own.
    return self._config.vm_instance

  @property
  def modules(self) -> BoundModules:
    return self._bound_modules

  def add_vm_modules(self, vm_modules):
    """Binds and registers additional modules on a dynamic context.

    Raises:
      ValueError: If a module with the same name is already registered.
    """
    assert self._is_dynamic, "Cannot 'add_module' on a static context"
    for m in vm_modules:
      if m.name in self._bound_modules:
        raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'")
      bound_module = BoundModule(self, m)
      self._bound_modules[m.name] = bound_module
      if self._tracer:
        self._tracer.add_module(bound_module.traced_module)
    self._vm_context.register_modules(vm_modules)

  def add_vm_module(self, vm_module):
    self.add_vm_modules((vm_module,))


def load_vm_modules(*vm_modules, config: Optional[Config] = None):
  """Loads VmModules into a new SystemContext and returns them."""
  context = SystemContext(vm_modules=vm_modules, config=config)
  bound_modules = [context.modules[m.name] for m in vm_modules]
  return bound_modules


def load_vm_module(vm_module, config: Optional[Config] = None):
  """Loads a VmModule into a new SystemContext and returns it."""
  return load_vm_modules(vm_module, config=config)[0]


def load_vm_flatbuffer(vm_flatbuffer: bytes,
                       *,
                       driver: Optional[str] = None,
                       backend: Optional[str] = None) -> BoundModule:
  """Loads a VM Flatbuffer into a callable module.

  Either 'driver' or 'backend' must be specified.

  Raises:
    ValueError: If neither or both of 'driver' and 'backend' are given.
    KeyError: If 'backend' is not a key of TARGET_BACKEND_TO_DRIVER.
  """
  if driver is None and backend is None:
    raise ValueError("Either 'driver' or 'backend' must be specified, but got "
                     "'None' for both.")
  if backend is not None and driver is not None:
    raise ValueError("Cannot specify both 'driver' and a 'backend' to infer "
                     "the driver from.")
  if backend is not None:
    driver = TARGET_BACKEND_TO_DRIVER[backend]
  vm_module = _binding.VmModule.from_flatbuffer(vm_flatbuffer)
  bound_module = load_vm_module(vm_module, Config(driver))
  return bound_module


# TODO: There should be an API for mmap'ing the file which should be used
# instead of reading into memory.
def load_vm_flatbuffer_file(path: str,
                            *,
                            driver: Optional[str] = None,
                            backend: Optional[str] = None) -> BoundModule:
  """Loads a file containing a VM Flatbuffer into a callable module.

  Either 'driver' or 'backend' must be specified.
  """
  with open(path, "rb") as f:
    vm_flatbuffer = f.read()
  return load_vm_flatbuffer(vm_flatbuffer, driver=driver, backend=backend)
diff --git a/runtime/bindings/python/iree/runtime/system_api_test.py b/runtime/bindings/python/iree/runtime/system_api_test.py new file mode 100644 index 0000000..ed9a585 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/system_api_test.py
# Lint as: python3
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Tests for the top-level system API (Config, SystemContext, loaders)."""

# pylint: disable=unused-variable

import os
import re
import tempfile

from absl import logging
from absl.testing import absltest
import iree.compiler
import iree.runtime
import numpy as np


def create_simple_mul_module():
  """Compiles a module `arithmetic` with one elementwise-mul function."""
  binary = iree.compiler.compile_str(
      """
      module @arithmetic {
        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
            %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
            return %0 : tensor<4xf32>
        }
      }
      """,
      input_type="mhlo",
      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
  )
  m = iree.runtime.VmModule.from_flatbuffer(binary)
  return m


class SystemApiTest(absltest.TestCase):

  def test_non_existing_driver(self):
    with self.assertRaisesRegex(RuntimeError,
                                "Could not create any requested driver"):
      config = iree.runtime.Config("nothere1,nothere2")

  def test_subsequent_driver(self):
    # The unknown first entry must be skipped in favor of the second.
    config = iree.runtime.Config("nothere1,dylib")

  def test_empty_dynamic(self):
    ctx = iree.runtime.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    self.assertIn("hal", ctx.modules)
    self.assertEqual(ctx.modules.hal.name, "hal")

  def test_empty_static(self):
    ctx = iree.runtime.SystemContext(vm_modules=())
    self.assertFalse(ctx.is_dynamic)
    self.assertIn("hal", ctx.modules)
    self.assertEqual(ctx.modules.hal.name, "hal")

  def test_custom_dynamic(self):
    ctx = iree.runtime.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_vm_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    f_repr = repr(f)
    logging.info("f_repr: %s", f_repr)
    self.assertEqual(f_repr,
                     "<VmFunction simple_mul(0rr_r), reflection = {}>")

  def test_duplicate_module(self):
    ctx = iree.runtime.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_vm_module(create_simple_mul_module())
    with self.assertRaisesRegex(ValueError, "arithmetic"):
      ctx.add_vm_module(create_simple_mul_module())

  def test_static_invoke(self):
    ctx = iree.runtime.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_vm_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = f(arg0, arg1)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])

  def test_chained_invoke(self):
    # This ensures that everything works if DeviceArrays are returned
    # and input to functions.
    ctx = iree.runtime.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_vm_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = f(arg0, arg1)
    results2 = f(results, results)
    np.testing.assert_allclose(results2, [16., 100., 324., 784.])

  def test_tracing_explicit(self):
    with tempfile.TemporaryDirectory() as temp_dir:
      tracer = iree.runtime.Tracer(temp_dir)
      config = iree.runtime.Config("dylib", tracer=tracer)
      self.verify_tracing(config, temp_dir)

  def test_tracing_from_environment(self):
    original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY)
    try:
      with tempfile.TemporaryDirectory() as temp_dir:
        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir
        config = iree.runtime.Config("dylib")
        self.verify_tracing(config, temp_dir)
    finally:
      # Restore the exact prior state. When the variable was unset it must be
      # removed again, otherwise later tests would silently enable tracing
      # into a directory that no longer exists.
      if original is None:
        os.environ.pop(iree.runtime.TRACE_PATH_ENV_KEY, None)
      else:
        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original

  def verify_tracing(self, config, temp_dir):
    """Runs one call under `config` and checks the trace files exist."""
    logging.info("Tracing test to: %s", temp_dir)
    ctx = iree.runtime.SystemContext(config=config)
    ctx.add_vm_module(create_simple_mul_module())
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = f(arg0, arg1)
    self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb")))
    self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml")))
    # TODO: Once replay is possible, verify that.

  def test_load_vm_module(self):
    arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = arithmetic.simple_mul(arg0, arg1)
    print("SIMPLE_MUL RESULTS:", results)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])


if __name__ == "__main__":
  absltest.main()
diff --git a/runtime/bindings/python/iree/runtime/tracing.py b/runtime/bindings/python/iree/runtime/tracing.py new file mode 100644 index 0000000..ea804ac --- /dev/null +++ b/runtime/bindings/python/iree/runtime/tracing.py
"""Tracing support."""

# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# NOTE: `exists` is not used in this module; the import is retained only so
# that the module's importable names stay unchanged.
from genericpath import exists
from typing import Dict, List, Optional, Sequence

import logging
import os
import sys

from . import binding as _binding

try:
  import yaml
except ModuleNotFoundError:
  _has_yaml = False
else:
  _has_yaml = True

__all__ = [
    "get_default_tracer",
    "Tracer",
    "TRACE_PATH_ENV_KEY",
]

TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"


class Tracer:
  """Object for tracing calls made into the runtime."""

  def __init__(self, trace_path: str):
    """Creates a tracer that writes into directory `trace_path`.

    If PyYAML is not importable the tracer is created with `enabled` set to
    False and the directory is not created.
    """
    # Set all attributes up front so a disabled tracer still has them.
    self.trace_path = trace_path
    self._name_count = dict()  # type: Dict[str, int]
    if not _has_yaml:
      self.enabled = False
      logging.warning("PyYAML not installed: tracing will be disabled")
      return
    self.enabled = True
    os.makedirs(trace_path, exist_ok=True)

  def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule":
    """Saves the module's flatbuffer (if it has one) and wraps the module."""
    # Depending on how the module was created, there are different bits
    # of information available to reconstruct.
    name = vm_module.name
    flatbuffer_blob = vm_module.stashed_flatbuffer_blob
    if flatbuffer_blob:
      save_path = os.path.join(self.trace_path,
                               self.get_unique_name(f"{name}.vmfb"))
      logging.info("Saving traced vmfb to %s", save_path)
      with open(save_path, "wb") as f:
        f.write(flatbuffer_blob)
      return TracedModule(self, vm_module, save_path)

    # No persistent form, but likely they are built-in modules.
    return TracedModule(self, vm_module)

  def get_unique_name(self, local_name: str) -> str:
    """Returns `local_name`, suffixed with `__N` on repeated requests.

    The first request for a name returns it unchanged; later requests insert
    `__1`, `__2`, ... before the extension.
    """
    if local_name not in self._name_count:
      self._name_count[local_name] = 1
      return local_name
    stem, ext = os.path.splitext(local_name)
    index = self._name_count[local_name]
    self._name_count[local_name] += 1
    unique_name = f"{stem}__{index}{ext}"
    return unique_name


class TracedModule:
  """Wraps a VmModule with additional information for tracing."""

  def __init__(self,
               parent: Tracer,
               vm_module: _binding.VmModule,
               vmfb_path: Optional[str] = None):
    self._parent = parent
    self._vm_module = vm_module
    self._vmfb_path = vmfb_path

  def serialize(self):
    """Returns the dict record describing this module in the trace."""
    module_record = {"name": self._vm_module.name}
    if self._vmfb_path:
      module_record["type"] = "bytecode"
      # Stored relative to the trace directory.
      module_record["path"] = os.path.relpath(self._vmfb_path,
                                              self._parent.trace_path)
    else:
      module_record["type"] = "builtin"

    return module_record


class ContextTracer:
  """Traces invocations against a context."""

  def __init__(self, parent: Tracer, is_dynamic: bool,
               modules: Sequence[TracedModule]):
    """Starts a trace file and records the context and its initial modules."""
    self._parent = parent
    self._modules = list(modules)  # type: List[TracedModule]
    self._frame_count = 0
    self._file_path = os.path.join(parent.trace_path,
                                   parent.get_unique_name("calls.yaml"))
    if os.path.exists(self._file_path):
      # Truncate the file.
      with open(self._file_path, "wt"):
        pass
    else:
      # Make sure the directory that will hold the trace file exists. The
      # directory of the file itself is used (not the parent of the trace
      # directory), and an empty dirname is skipped because os.makedirs("")
      # raises FileNotFoundError.
      trace_dir = os.path.dirname(self._file_path)
      if trace_dir:
        os.makedirs(trace_dir, exist_ok=True)
    logging.info("Tracing context events to: %s", self._file_path)
    self.emit_frame({
        "type": "context_load",
    })
    for module in self._modules:
      self.emit_frame({
          "type": "module_load",
          "module": module.serialize(),
      })

  def add_module(self, module: TracedModule):
    """Records a module added to the context after creation."""
    self._modules.append(module)
    self.emit_frame({
        "type": "module_load",
        "module": module.serialize(),
    })

  def start_call(self, function: _binding.VmFunction):
    """Returns a CallTrace pre-populated with the function's qualified name."""
    logging.info("Tracing call to %s.%s", function.module_name, function.name)

    # Start assembling the call record.
    record = {
        "type": "call",
        "function": "%s.%s" % (function.module_name, function.name),
    }
    return CallTrace(self, record)

  def emit_frame(self, frame: dict):
    """Appends `frame` to the trace file as one YAML document."""
    self._frame_count += 1
    with open(self._file_path, "at") as f:
      # Every document after the first is preceded by a separator.
      if self._frame_count != 1:
        f.write("---\n")
      contents = yaml.dump(frame, sort_keys=False)
      f.write(contents)


class CallTrace:
  """Accumulates the record for a single call and emits it on end_call."""

  def __init__(self, parent: ContextTracer, record: dict):
    self._parent = parent
    self._record = record

  def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
    """Stores the serialized trace values of `vm_list` under `key`."""
    self._record[key] = [
        vm_list.get_serialized_trace_value(i) for i in range(len(vm_list))
    ]

  def end_call(self):
    self._parent.emit_frame(self._record)


def get_default_tracer() -> Optional[Tracer]:
  """Gets a default call tracer based on environment variables."""
  default_path = os.getenv(TRACE_PATH_ENV_KEY)
  if not default_path:
    return None
  return Tracer(default_path)
diff --git a/runtime/bindings/python/iree/runtime/unix_version.lds b/runtime/bindings/python/iree/runtime/unix_version.lds new file mode 100644 index 0000000..fd766d1 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/unix_version.lds
@@ -0,0 +1,11 @@ +/* Copyright 2020 The IREE Authors + * + * Licensed under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + */ + +{ + global: PyInit_binding; + local: *; +};
diff --git a/runtime/bindings/python/iree/runtime/vm.cc b/runtime/bindings/python/iree/runtime/vm.cc new file mode 100644 index 0000000..53dbbd1 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/vm.cc
@@ -0,0 +1,616 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./vm.h" + +#include "./status_utils.h" +#include "iree/base/api.h" +#include "iree/base/status_cc.h" +#include "iree/base/tracing.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/vm/api.h" +#include "pybind11/numpy.h" + +namespace iree { +namespace python { + +namespace { + +VmModule CreateHalModule(HalDevice* device) { + iree_vm_module_t* module; + CheckApiStatus(iree_hal_module_create(device->raw_ptr(), + iree_allocator_system(), &module), + "Error creating hal module"); + return VmModule::StealFromRawPtr(module); +} + +// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes +// out of scope. +class PyBufferReleaser { + public: + PyBufferReleaser(Py_buffer& b) : b_(b) {} + ~PyBufferReleaser() { PyBuffer_Release(&b_); } + + private: + Py_buffer& b_; +}; + +py::dict GetFunctionReflectionDict(iree_vm_function_t& f) { + py::dict attrs; + for (int i = 0;; ++i) { + iree_string_view_t key; + iree_string_view_t value; + auto status = iree_vm_get_function_reflection_attr(f, i, &key, &value); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + break; + } + CheckApiStatus(status, "Error getting reflection attr"); + py::str key_str(key.data, key.size); + py::str value_str(value.data, value.size); + attrs[std::move(key_str)] = std::move(value_str); + } + return attrs; +} + +} // namespace + +//------------------------------------------------------------------------------ +// VmInstance +//------------------------------------------------------------------------------ + +VmInstance VmInstance::Create() { + IREE_TRACE_SCOPE0("VmInstance::Create"); + iree_vm_instance_t* instance; + auto status = iree_vm_instance_create(iree_allocator_system(), 
&instance); + CheckApiStatus(status, "Error creating instance"); + return VmInstance::StealFromRawPtr(instance); +} + +//------------------------------------------------------------------------------ +// VmContext +//------------------------------------------------------------------------------ + +VmContext VmContext::Create(VmInstance* instance, + std::optional<std::vector<VmModule*>> modules) { + IREE_TRACE_SCOPE0("VmContext::Create"); + iree_vm_context_t* context; + if (!modules) { + // Simple create with open allowed modules. + auto status = + iree_vm_context_create(instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE, + iree_allocator_system(), &context); + CheckApiStatus(status, "Error creating vm context"); + } else { + // Closed set of modules. + std::vector<iree_vm_module_t*> module_handles; + module_handles.resize(modules->size()); + for (size_t i = 0, e = module_handles.size(); i < e; ++i) { + module_handles[i] = (*modules)[i]->raw_ptr(); + } + auto status = iree_vm_context_create_with_modules( + instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE, module_handles.data(), + module_handles.size(), iree_allocator_system(), &context); + CheckApiStatus(status, "Error creating vm context with modules"); + } + + IREE_CHECK(context); + return VmContext::StealFromRawPtr(context); +} + +void VmContext::RegisterModules(std::vector<VmModule*> modules) { + std::vector<iree_vm_module_t*> module_handles; + module_handles.resize(modules.size()); + for (size_t i = 0, e = module_handles.size(); i < e; ++i) { + module_handles[i] = modules[i]->raw_ptr(); + } + auto status = iree_vm_context_register_modules(raw_ptr(), &module_handles[0], + module_handles.size()); + CheckApiStatus(status, "Error registering modules"); +} + +void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs, + VmVariantList& outputs) { + iree_status_t status; + { + py::gil_scoped_release release; + status = iree_vm_invoke(raw_ptr(), f, IREE_VM_INVOCATION_FLAG_NONE, nullptr, + inputs.raw_ptr(), 
outputs.raw_ptr(), + iree_allocator_system()); + } + CheckApiStatus(status, "Error invoking function"); +} + +//------------------------------------------------------------------------------ +// VmModule +//------------------------------------------------------------------------------ + +VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) { + IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob"); + auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object); + auto buffer_info = flatbuffer_blob.request(); + iree_vm_module_t* module; + + // Bridge to the C-based deallocator API. + auto* raw_ptr = flatbuffer_blob.ptr(); + auto ctl_fn = +([](void* self, iree_allocator_command_t command, + const void* params, void** inout_ptr) { + assert(command == IREE_ALLOCATOR_COMMAND_FREE); + PyObject* object_ptr = static_cast<PyObject*>(*inout_ptr); + Py_XDECREF(object_ptr); + return iree_ok_status(); + }); + flatbuffer_blob.inc_ref(); + iree_allocator_t deallocator{/*self=*/NULL, /*ctl=*/ctl_fn}; + + auto status = iree_vm_bytecode_module_create( + {static_cast<const uint8_t*>(buffer_info.ptr), + static_cast<iree_host_size_t>(buffer_info.size)}, + deallocator, iree_allocator_system(), &module); + if (!iree_status_is_ok(status)) { + iree_allocator_free(deallocator, raw_ptr); + } + + CheckApiStatus(status, "Error creating vm module from flatbuffer"); + auto py_module = VmModule::StealFromRawPtr(module); + py_module.stashed_flatbuffer_blob = flatbuffer_blob_object; + return py_module; +} + +std::optional<iree_vm_function_t> VmModule::LookupFunction( + const std::string& name, iree_vm_function_linkage_t linkage) { + iree_vm_function_t f; + auto status = iree_vm_module_lookup_function_by_name( + raw_ptr(), linkage, {name.data(), name.size()}, &f); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + return std::nullopt; + } + CheckApiStatus(status, "Error looking up function"); + return f; +} + 
+//------------------------------------------------------------------------------ +// VmVariantList +//------------------------------------------------------------------------------ + +void VmVariantList::PushFloat(double fvalue) { + // Note that Python floats are f64. + iree_vm_value_t value = iree_vm_value_make_f64(fvalue); + CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value), + "Could not push float"); +} + +void VmVariantList::PushInt(int64_t ivalue) { + // Note that Python ints are unbounded, so just use the largest type we + // have. + iree_vm_value_t value = iree_vm_value_make_i64(ivalue); + CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value), + "Could not push int"); +} + +void VmVariantList::PushList(VmVariantList& other) { + iree_vm_ref_t retained = iree_vm_list_retain_ref(other.raw_ptr()); + iree_vm_list_push_ref_move(raw_ptr(), &retained); +} + +void VmVariantList::PushBufferView(HalBufferView& buffer_view) { + iree_vm_ref_t buffer_view_ref = + iree_hal_buffer_view_retain_ref(buffer_view.raw_ptr()); + CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &buffer_view_ref), + "Error moving buffer view"); +} + +py::object VmVariantList::GetAsList(int index) { + iree_vm_ref_t ref = {0}; + CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &ref), + "Could not access list element"); + iree_vm_list_t* sub_list = NULL; + CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list), + "Could not deref list (wrong type?)"); + iree_vm_list_retain(sub_list); + return py::cast(VmVariantList(sub_list)); +} + +py::object VmVariantList::GetVariant(int index) { + iree_vm_variant_t v = iree_vm_variant_empty(); + CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v), + "Could not access list element"); + if (iree_vm_type_def_is_value(&v.type)) { + // Convert a value type. 
+ switch (v.type.value_type) { + case IREE_VM_VALUE_TYPE_I8: + return py::cast(v.i8); + case IREE_VM_VALUE_TYPE_I16: + return py::cast(v.i16); + case IREE_VM_VALUE_TYPE_I32: + return py::cast(v.i32); + case IREE_VM_VALUE_TYPE_I64: + return py::cast(v.i64); + case IREE_VM_VALUE_TYPE_F32: + return py::cast(v.f32); + case IREE_VM_VALUE_TYPE_F64: + return py::cast(v.f64); + default: + throw RaiseValueError("Unsupported VM value type conversion"); + } + } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) { + return py::none(); + } else if (iree_vm_type_def_is_ref(&v.type)) { + // Convert reference type. + if (iree_vm_list_isa(v.ref)) { + return GetAsList(index); + } else if (iree_hal_buffer_view_isa(v.ref)) { + return GetAsBufferView(index); + } + } + + throw RaiseValueError("Unsupported VM to Python Type Conversion"); +} + +py::object VmVariantList::GetAsSerializedTraceValue(int index) { + iree_vm_variant_t v = iree_vm_variant_empty(); + CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v), + "Could not access list element"); + if (iree_vm_type_def_is_value(&v.type)) { + // Convert a value type. + py::dict record; + switch (v.type.value_type) { + case IREE_VM_VALUE_TYPE_I8: + record["i8"] = py::cast(v.i8); + break; + case IREE_VM_VALUE_TYPE_I16: + record["i16"] = py::cast(v.i16); + break; + case IREE_VM_VALUE_TYPE_I32: + record["i32"] = py::cast(v.i32); + break; + case IREE_VM_VALUE_TYPE_I64: + record["i64"] = py::cast(v.i64); + break; + case IREE_VM_VALUE_TYPE_F32: + record["f32"] = py::cast(v.f32); + break; + case IREE_VM_VALUE_TYPE_F64: + record["f64"] = py::cast(v.f64); + break; + default: + throw RaiseValueError("Unsupported VM value type conversion"); + } + record["type"] = py::cast("value"); + return std::move(record); + } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) { + py::dict record; + record["type"] = "null"; + return std::move(record); + } else if (iree_vm_type_def_is_ref(&v.type)) { + // Convert reference type. 
+ if (iree_vm_list_isa(v.ref)) { + py::dict record; + record["type"] = "vm.list"; + py::list items; + iree_vm_list_t* sub_list = NULL; + CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list), + "Could not deref list (wrong type?)"); + iree_vm_list_retain(sub_list); + VmVariantList sub_list_object(sub_list); + for (int i = 0, e = sub_list_object.size(); i < e; ++i) { + items.append(sub_list_object.GetAsSerializedTraceValue(i)); + } + record["items"] = std::move(items); + return std::move(record); + } else if (iree_hal_buffer_view_isa(v.ref)) { + py::dict record; + record["type"] = "hal.buffer_view"; + iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref); + if (!buffer_view) { + throw RaiseValueError( + "Could not deref result buffer view (wrong type?)"); + } + iree_hal_buffer_t* raw_buffer = iree_hal_buffer_view_buffer(buffer_view); + if (!raw_buffer) { + throw RaiseValueError("Could not deref result buffer (wrong type?)"); + } + + // Extract dims from the buffer view. + size_t rank = 0; + std::vector<int32_t> dims(6); + iree_status_t status = iree_hal_buffer_view_shape( + buffer_view, dims.capacity(), dims.data(), &rank); + if (iree_status_is_out_of_range(status)) { + dims.resize(rank); + status = iree_hal_buffer_view_shape(buffer_view, dims.capacity(), + dims.data(), &rank); + } + CheckApiStatus(status, "Error extracting shape"); + dims.resize(rank); + record["shape"] = py::cast(std::move(dims)); + + // Element type. + iree_hal_element_type_t element_type = + iree_hal_buffer_view_element_type(buffer_view); + // TODO: Would be nice to output as hex. + record["element_type"] = element_type; + + // Map memory. 
+ iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer); + iree_hal_buffer_mapping_t mapped_memory = {{0}}; + CheckApiStatus(iree_hal_buffer_map_range( + raw_buffer, IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0 /* element_offset */, + byte_length, &mapped_memory), + "Could not map memory"); + record["contents"] = + py::bytes(reinterpret_cast<const char*>(mapped_memory.contents.data), + mapped_memory.contents.data_length); + iree_hal_buffer_unmap_range(&mapped_memory); + + return std::move(record); + } + } + + throw RaiseValueError("Unsupported VM to Python Type Conversion"); +} + +py::object VmVariantList::GetAsBufferView(int index) { + iree_vm_variant_t v = iree_vm_variant_empty(); + CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v), + "Could not access list element"); + iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref); + if (!buffer_view) { + throw RaiseValueError("Could not deref result buffer view (wrong type?)"); + } + return py::cast(HalBufferView::BorrowFromRawPtr(buffer_view), + py::return_value_policy::move); +} + +namespace { + +static std::string ToHexString(const uint8_t* data, size_t length) { + static constexpr char kHexChars[] = {'0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; + std::string s(length * 2, ' '); + for (size_t i = 0; i < length; ++i) { + s[2 * i + 0] = kHexChars[(data[i] & 0xF0) >> 4]; + s[2 * i + 1] = kHexChars[(data[i] & 0x0F) >> 0]; + } + return s; +} +static std::string ToHexString(uint32_t value) { + return ToHexString((const uint8_t*)&value, sizeof(value)); +} + +void AppendListContents(std::string& out, iree_vm_list_t* list, + std::unordered_set<iree_vm_list_t*>& visited) { + for (iree_host_size_t i = 0, e = iree_vm_list_size(list); i < e; ++i) { + iree_vm_variant_t variant = iree_vm_variant_empty(); + iree_status_t status = iree_vm_list_get_variant(list, i, &variant); + if (!iree_status_is_ok(status)) { + 
iree_status_ignore(status); + out.append("Error"); + continue; + } + if (i > 0) out.append(", "); + + if (iree_vm_variant_is_value(variant)) { + // Convert a value type to a string. + switch (variant.type.value_type) { + case IREE_VM_VALUE_TYPE_I8: { + out += std::to_string(variant.i8); + break; + } + case IREE_VM_VALUE_TYPE_I16: { + out += std::to_string(variant.i16); + break; + } + case IREE_VM_VALUE_TYPE_I32: { + out += std::to_string(variant.i32); + break; + } + case IREE_VM_VALUE_TYPE_I64: { + out += std::to_string(variant.i64); + break; + } + case IREE_VM_VALUE_TYPE_F32: { + out += std::to_string(variant.f32); + break; + } + case IREE_VM_VALUE_TYPE_F64: { + out += std::to_string(variant.f64); + break; + } + default: + throw RaiseValueError("Unsupported VM value type to string"); + } + } else if (iree_vm_variant_is_ref(variant)) { + // Pretty print a subset of ABI impacting known types. + if (iree_hal_buffer_isa(variant.ref)) { + auto* hal_buffer = iree_hal_buffer_deref(variant.ref); + assert(hal_buffer); + out += std::string("HalBuffer(") + + std::to_string(iree_hal_buffer_byte_length(hal_buffer)) + ")"; + } else if (iree_hal_buffer_view_isa(variant.ref)) { + auto hal_bv = iree_hal_buffer_view_deref(variant.ref); + out += "HalBufferView("; + std::vector<int32_t> shape(iree_hal_buffer_view_shape_rank(hal_bv)); + iree_hal_buffer_view_shape(hal_bv, shape.size(), shape.data(), nullptr); + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) out += 'x'; + out += std::to_string(shape[i]); + } + out += ":0x" + + ToHexString(static_cast<uint32_t>( + iree_hal_buffer_view_element_type(hal_bv))) + + ")"; + } else if (iree_vm_list_isa(variant.ref)) { + out.append("List["); + iree_vm_list_t* sub_list = iree_vm_list_deref(variant.ref); + if (visited.insert(sub_list).second) { + AppendListContents(out, sub_list, visited); + } else { + out.append("...circular..."); + } + out.append("]"); + } else { + out += "Unknown(" + std::to_string(variant.type.ref_type) + ")"; + } + 
} else { + out.append("None"); + } + } +} + +} // namespace + +std::string VmVariantList::DebugString() const { + // The variant list API requires mutability, so we const cast to it internally + // so we can maintain a const DebugString() for callers. + auto mutable_this = const_cast<VmVariantList*>(this); + std::string s = + std::string("<VmVariantList(") + std::to_string(size()) + "): ["; + iree_vm_list_t* list = mutable_this->raw_ptr(); + std::unordered_set<iree_vm_list_t*> visited; + visited.insert(list); + AppendListContents(s, list, visited); + s.append("]>"); + return s; +} + +void SetupVmBindings(pybind11::module m) { + IREE_CHECK_OK(iree_vm_register_builtin_types()); + IREE_CHECK_OK(iree_hal_module_register_types()); + + // Built-in module creation. + m.def("create_hal_module", &CreateHalModule); + + py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage") + .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL) + .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT) + .value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT) + .export_values(); + + // Mutation and inspection of the variant list is mostly opaque to python. 
+ py::class_<VmVariantList>(m, "VmVariantList") + .def(py::init(&VmVariantList::Create)) + .def_property_readonly("size", &VmVariantList::size) + .def("__len__", &VmVariantList::size) + .def("get_as_buffer_view", &VmVariantList::GetAsBufferView) + .def("get_as_list", &VmVariantList::GetAsList) + .def("get_variant", &VmVariantList::GetVariant) + .def("get_serialized_trace_value", + &VmVariantList::GetAsSerializedTraceValue) + .def("push_float", &VmVariantList::PushFloat) + .def("push_int", &VmVariantList::PushInt) + .def("push_list", &VmVariantList::PushList) + .def("push_buffer_view", &VmVariantList::PushBufferView) + .def("__repr__", &VmVariantList::DebugString); + + py::class_<iree_vm_function_t>(m, "VmFunction") + .def_readonly("linkage", &iree_vm_function_t::linkage) + .def_readonly("ordinal", &iree_vm_function_t::ordinal) + .def_property_readonly("name", + [](iree_vm_function_t& self) { + iree_string_view_t name = + iree_vm_function_name(&self); + return py::str(name.data, name.size); + }) + .def_property_readonly("module_name", + [](iree_vm_function_t& self) { + iree_string_view_t name = + iree_vm_module_name(self.module); + return py::str(name.data, name.size); + }) + .def_property_readonly("reflection", + [](iree_vm_function_t& self) { + return GetFunctionReflectionDict(self); + }) + .def("__repr__", [](iree_vm_function_t& self) { + iree_string_view_t name = iree_vm_function_name(&self); + std::string repr("<VmFunction "); + repr.append(name.data, name.size); + + iree_vm_function_signature_t sig = iree_vm_function_signature(&self); + repr.append("("); + repr.append(sig.calling_convention.data, sig.calling_convention.size); + repr.append("), reflection = "); + py::dict reflection = GetFunctionReflectionDict(self); + repr.append(py::cast<std::string>(py::repr(reflection))); + repr.append(">"); + return repr; + }); + + py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create)); + + py::class_<VmContext>(m, "VmContext") + 
.def(py::init(&VmContext::Create), py::arg("instance"), + py::arg("modules") = std::optional<std::vector<VmModule*>>()) + .def("register_modules", &VmContext::RegisterModules) + .def_property_readonly("context_id", &VmContext::context_id) + .def("invoke", &VmContext::Invoke); + + py::class_<VmModule>(m, "VmModule") + .def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob) + .def_property_readonly("name", &VmModule::name) + .def("lookup_function", &VmModule::LookupFunction, py::arg("name"), + py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT) + .def_property_readonly( + "stashed_flatbuffer_blob", + [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); }) + .def_property_readonly( + "function_names", + [](VmModule& self) { + py::list names; + iree_vm_module_signature_t sig = + iree_vm_module_signature(self.raw_ptr()); + for (size_t ordinal = 0; ordinal < sig.export_function_count; + ++ordinal) { + iree_vm_function_t f; + auto status = iree_vm_module_lookup_function_by_ordinal( + self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + break; + } + CheckApiStatus(status, "Error enumerating module"); + iree_string_view_t fname = iree_vm_function_name(&f); + py::str name(fname.data, fname.size); + names.append(name); + } + return names; + }) + .def("__repr__", [](VmModule& self) { + std::string repr("<VmModule "); + iree_string_view_t name = iree_vm_module_name(self.raw_ptr()); + repr.append(name.data, name.size); + + iree_vm_module_signature_t sig = + iree_vm_module_signature(self.raw_ptr()); + repr.append(" : ["); + for (size_t ordinal = 0; ordinal < sig.export_function_count; + ++ordinal) { + iree_vm_function_t f; + auto status = iree_vm_module_lookup_function_by_ordinal( + self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + break; + } + CheckApiStatus(status, "Error enumerating 
module"); + iree_string_view_t fname = iree_vm_function_name(&f); + if (ordinal > 0) { + repr.append(", "); + } + repr.append(fname.data, fname.size); + } + repr.append("]"); + repr.append(">"); + return repr; + }); +} + +} // namespace python +} // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/vm.h b/runtime/bindings/python/iree/runtime/vm.h new file mode 100644 index 0000000..48fdbab --- /dev/null +++ b/runtime/bindings/python/iree/runtime/vm.h
@@ -0,0 +1,168 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ +#define IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ + +#include <optional> + +#include "./binding.h" +#include "./hal.h" +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "iree/vm/bytecode_module.h" + +namespace iree { +namespace python { + +class FunctionAbi; + +//------------------------------------------------------------------------------ +// Retain/release bindings +//------------------------------------------------------------------------------ + +template <> +struct ApiPtrAdapter<iree_vm_instance_t> { + static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); } + static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_vm_context_t> { + static void Retain(iree_vm_context_t* b) { iree_vm_context_retain(b); } + static void Release(iree_vm_context_t* b) { iree_vm_context_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_vm_module_t> { + static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); } + static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_vm_invocation_t> { + static void Retain(iree_vm_invocation_t* b) { iree_vm_invocation_retain(b); } + static void Release(iree_vm_invocation_t* b) { + iree_vm_invocation_release(b); + } +}; + +//------------------------------------------------------------------------------ +// VmVariantList +//------------------------------------------------------------------------------ + +class VmVariantList { + public: + VmVariantList() : list_(nullptr) {} + ~VmVariantList() { + if (list_) { + iree_vm_list_release(list_); + } + } + + VmVariantList(VmVariantList&& other) { + 
list_ = other.list_; + other.list_ = nullptr; + } + + VmVariantList& operator=(const VmVariantList&) = delete; + VmVariantList(const VmVariantList&) = delete; + + static VmVariantList Create(iree_host_size_t capacity) { + iree_vm_list_t* list; + CheckApiStatus(iree_vm_list_create(/*element_type=*/nullptr, capacity, + iree_allocator_system(), &list), + "Error allocating variant list"); + return VmVariantList(list); + } + + iree_host_size_t size() const { return iree_vm_list_size(list_); } + + iree_vm_list_t* raw_ptr() { return list_; } + const iree_vm_list_t* raw_ptr() const { return list_; } + iree_vm_list_t* steal_raw_ptr() { + iree_vm_list_t* stolen = list_; + list_ = nullptr; + return stolen; + } + void AppendNullRef() { + iree_vm_ref_t null_ref = {0}; + CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref), + "Error appending to list"); + } + + std::string DebugString() const; + void PushFloat(double fvalue); + void PushInt(int64_t ivalue); + void PushList(VmVariantList& other); + void PushBufferView(HalBufferView& buffer_view); + py::object GetAsList(int index); + py::object GetAsBufferView(int index); + py::object GetVariant(int index); + py::object GetAsSerializedTraceValue(int index); + + private: + VmVariantList(iree_vm_list_t* list) : list_(list) {} + iree_vm_list_t* list_; +}; + +//------------------------------------------------------------------------------ +// ApiRefCounted types +//------------------------------------------------------------------------------ + +class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> { + public: + static VmInstance Create(); +}; + +class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> { + public: + static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object); + + std::optional<iree_vm_function_t> LookupFunction( + const std::string& name, iree_vm_function_linkage_t linkage); + + std::string name() const { + auto name_sv = iree_vm_module_name(raw_ptr()); + return 
std::string(name_sv.data, name_sv.size); + } + + py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; } + + private: + // If the module was created from a flatbuffer blob, we stash it here. + py::object stashed_flatbuffer_blob = py::none(); +}; + +class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> { + public: + // Creates a context, optionally with modules, which will make the context + // static, disallowing further module registration (and may be more + // efficient). + static VmContext Create(VmInstance* instance, + std::optional<std::vector<VmModule*>> modules); + + // Registers additional modules. Only valid for non static contexts (i.e. + // those created without modules. + void RegisterModules(std::vector<VmModule*> modules); + + // Unique id for this context. + int context_id() const { return iree_vm_context_id(raw_ptr()); } + + // Synchronously invokes the given function. + void Invoke(iree_vm_function_t f, VmVariantList& inputs, + VmVariantList& outputs); +}; + +class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> { +}; + +void SetupVmBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_IREE_RT_VM_H_
diff --git a/runtime/bindings/python/iree/runtime/vm_test.py b/runtime/bindings/python/iree/runtime/vm_test.py new file mode 100644 index 0000000..d730442 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/vm_test.py
# Lint as: python3
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# pylint: disable=unused-variable

from absl import logging
from absl.testing import absltest
import iree.compiler
import iree.runtime
import numpy as np


def create_add_scalar_module():
  """Compiles a module exporting `add_scalar` (i32 + i32) and loads it."""
  binary = iree.compiler.compile_str(
      """
      func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
        %0 = arith.addi %arg0, %arg1 : i32
        return %0 : i32
      }
      """,
      input_type="mhlo",
      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
  )
  m = iree.runtime.VmModule.from_flatbuffer(binary)
  return m


def create_simple_static_mul_module():
  """Compiles a module exporting `simple_mul` over tensor<4xf32> and loads it."""
  binary = iree.compiler.compile_str(
      """
      func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
          %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
          return %0 : tensor<4xf32>
      }
      """,
      input_type="mhlo",
      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
  )
  m = iree.runtime.VmModule.from_flatbuffer(binary)
  return m


def create_simple_dynamic_abs_module():
  """Compiles a dynamically shaped abs module and loads it.

  The exported function is named `simple_mul` even though it computes
  `mhlo.abs`; the tests below look it up under that name.
  """
  # TODO(laurenzo): Compile for more backends as dynamic shapes come online.
  # Read the default backends through `iree.compiler.core`, like every other
  # use in this file.
  target_backends = iree.compiler.core.DEFAULT_TESTING_BACKENDS
  binary = iree.compiler.compile_str(
      """
      func.func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
          %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
          return %0 : tensor<?x?xf32>
      }
      """,
      input_type="mhlo",
      target_backends=target_backends,
  )
  m = iree.runtime.VmModule.from_flatbuffer(binary)
  return m


class VmTest(absltest.TestCase):
  """Tests for the VM bindings: variant lists, modules, contexts, invocation."""

  @classmethod
  def setUpClass(cls):
    """Creates the driver, device and HAL module shared by all tests."""
    super().setUpClass()
    driver_names = iree.runtime.HalDriver.query()
    logging.info("driver_names: %s", driver_names)
    cls.driver = iree.runtime.HalDriver.create(
        iree.compiler.core.DEFAULT_TESTING_DRIVER)
    cls.device = cls.driver.create_default_device()
    cls.hal_module = iree.runtime.create_hal_module(cls.device)

  def test_variant_list(self):
    """A new list is empty regardless of the requested capacity."""
    lst = iree.runtime.VmVariantList(5)
    logging.info("variant_list: %s", lst)
    self.assertEqual(lst.size, 0)

  def test_variant_list_i64(self):
    """Integers beyond the 32-bit range round-trip through the list repr."""
    lst = iree.runtime.VmVariantList(5)
    # Push a value that exceeds 32-bit range.
    lst.push_int(10 * 1000 * 1000 * 1000)
    self.assertEqual(str(lst), "<VmVariantList(1): [10000000000]>")

  def test_variant_list_buffers(self):
    """Buffer views of each supported dtype can be pushed and read back."""
    ET = iree.runtime.HalElementType
    for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
                   (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
                   (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
                   (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
                   (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
      # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
      lst = iree.runtime.VmVariantList(5)
      ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
      bv1 = self.device.allocator.allocate_buffer_copy(
          memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
          allowed_usage=(iree.runtime.BufferUsage.DISPATCH |
                         iree.runtime.BufferUsage.TRANSFER |
                         iree.runtime.BufferUsage.MAPPING),
          buffer=ary1,
          element_type=et)
      lst.push_buffer_view(bv1)
      ary2 = iree.runtime.DeviceArray(self.device,
                                      lst.get_as_buffer_view(0),
                                      override_dtype=dt,
                                      implicit_host_transfer=True)
      np.testing.assert_array_equal(ary1, ary2)
      with self.assertRaises(IndexError):
        lst.get_as_buffer_view(1)

  def test_variant_list_list(self):
    """A list can be nested in another list and read back."""
    lst1 = iree.runtime.VmVariantList(5)
    lst2 = iree.runtime.VmVariantList(5)
    lst1.push_list(lst2)
    self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
    lstout = lst1.get_as_list(0)
    self.assertEqual("<VmVariantList(0): []>", str(lstout))
    with self.assertRaises(IndexError):
      lst1.get_as_list(1)

  def test_context_id(self):
    """A later context on the same instance gets a greater id."""
    instance = iree.runtime.VmInstance()
    context1 = iree.runtime.VmContext(instance)
    context2 = iree.runtime.VmContext(instance)
    self.assertGreater(context2.context_id, context1.context_id)

  def test_module_basics(self):
    """Function lookup returns a function, or None for unknown names."""
    m = create_simple_static_mul_module()
    f = m.lookup_function("simple_mul")
    self.assertGreaterEqual(f.ordinal, 0)
    notfound = m.lookup_function("notfound")
    self.assertIs(notfound, None)

  def test_dynamic_module_context(self):
    """Modules can be registered on a context created without modules."""
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance)
    m = create_simple_static_mul_module()
    context.register_modules([self.hal_module, m])

  def test_static_module_context(self):
    """A context can be created directly with a list of modules."""
    m = create_simple_static_mul_module()
    logging.info("module: %s", m)
    instance = iree.runtime.VmInstance()
    logging.info("instance: %s", instance)
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    logging.info("context: %s", context)

  def test_dynamic_shape_compile(self):
    """The dynamically shaped module compiles and loads into a context."""
    m = create_simple_dynamic_abs_module()
    logging.info("module: %s", m)
    instance = iree.runtime.VmInstance()
    logging.info("instance: %s", instance)
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    logging.info("context: %s", context)

  def test_add_scalar_new_abi(self):
    """Scalar arguments and results round-trip through FunctionInvoker."""
    m = create_add_scalar_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("add_scalar")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    result = finv(5, 6)
    logging.info("result: %s", result)
    self.assertEqual(result, 11)

  def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
    """The dynamically shaped abs function returns elementwise |x|."""
    m = create_simple_dynamic_abs_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("simple_mul")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
    result = finv(arg0)
    logging.info("result: %s", result)
    np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])

  def test_synchronous_invoke_function_new_abi(self):
    """The static mul function returns the elementwise product."""
    m = create_simple_static_mul_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("simple_mul")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    result = finv(arg0, arg1)
    logging.info("result: %s", result)
    np.testing.assert_allclose(result, [4., 10., 18., 28.])


if __name__ == "__main__":
  absltest.main()
diff --git a/runtime/pyproject.toml b/runtime/pyproject.toml new file mode 100644 index 0000000..22b4019 --- /dev/null +++ b/runtime/pyproject.toml
# Build-time requirements for the runtime Python package. These are the
# packages a PEP 517/518 front end (e.g. pip) makes available before invoking
# the setuptools backend on setup.py.
[build-system]
requires = [
    "setuptools>=42",
    "wheel",
    # There is no fundamental reason to pin this CMake version, beyond
    # build stability.
    "cmake==3.22.2",
    "ninja==1.10.2",
    "packaging",
    # Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
    "pybind11>=2.6.0,!=2.7.0",
    "PyYAML",
]
# Build through setuptools (i.e. the setup.py that sits next to this file).
build-backend = "setuptools.build_meta"
diff --git a/runtime/setup.py b/runtime/setup.py new file mode 100644 index 0000000..fa684bc --- /dev/null +++ b/runtime/setup.py
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This builds just the runtime API and is relatively quick to build.
# To install:
#   pip install .
# To build a wheel:
#   pip wheel .
#
# It is recommended to build with Ninja and ccache. To do so, set environment
# variables by prefixing to above invocations:
#   CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache
#
# On CIs, it is often advantageous to re-use/control the CMake build directory.
# This can be set with the IREE_RUNTIME_API_CMAKE_BUILD_DIR env var.
#
# A custom package suffix can be specified with the environment variable:
#   IREE_RUNTIME_CUSTOM_PACKAGE_SUFFIX
#
# Select CMake options are available from environment variables:
#   IREE_HAL_DRIVER_CUDA
#   IREE_HAL_DRIVER_VULKAN
#   IREE_ENABLE_RUNTIME_TRACING
#   IREE_BUILD_TRACY

# NOTE(review): `install` and `prepare` are not referenced anywhere in this
# file and look like accidental IDE auto-imports. Confirm and remove.
from gettext import install
import json
from multiprocessing.spawn import prepare
import os
import platform
import re
import shutil
import subprocess
import sys
import sysconfig

# setuptools must be imported before distutils so that it can install its own
# distutils shim; the reverse order triggers a warning on newer setuptools.
from setuptools import find_namespace_packages, setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools.command.build_py import build_py as _build_py
from distutils.command.build import build as _build


def check_pip_version():
  """Aborts (exit status 2) if an installed pip is older than 21.3.

  Does nothing when pip is not importable at all.
  """
  from packaging import version
  # Pip versions < 21.3 default to out of tree builds, which is quite
  # incompatible with what we do (and has other issues). Newer versions
  # build in-tree by default (and later releases removed the out of tree
  # option entirely). Since the old behavior can silently produce unworking
  # installations, we aggressively suppress it.
  try:
    import pip
  except ModuleNotFoundError:
    # If pip not installed, we are obviously not trying to package via pip.
    return
  if version.parse(pip.__version__) < version.parse("21.3"):
    print("ERROR: pip version >= 21.3 required", file=sys.stderr)
    print("Upgrade: pip install pip --upgrade", file=sys.stderr)
    sys.exit(2)


check_pip_version()

# This file can be run directly from the source tree or it can be CMake
# configured so it can run from the build tree with an already existing
# build tree. We detect the difference based on whether the following
# are expanded by CMake.
CONFIGURED_SOURCE_DIR = "@IREE_SOURCE_DIR@"
CONFIGURED_BINARY_DIR = "@IREE_BINARY_DIR@"

IREE_SOURCE_DIR = None
IREE_BINARY_DIR = None

# We must do the intermediate installation to a fixed location that agrees
# between what we pass to setup() and cmake. So hard-code it here.
# Note that setup() needs a relative path (to the setup.py file).
SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
CMAKE_INSTALL_DIR_REL = os.path.join("build", "cmake_install")
CMAKE_INSTALL_DIR_ABS = os.path.join(SETUPPY_DIR, CMAKE_INSTALL_DIR_REL)

# An unexpanded placeholder still starts with "@"; anything else means CMake
# substituted a real path and we are running from the build tree.
IS_CONFIGURED = CONFIGURED_SOURCE_DIR[0] != "@"
if IS_CONFIGURED:
  IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR
  IREE_BINARY_DIR = CONFIGURED_BINARY_DIR
  print(
      f"Running setup.py from build tree: "
      f"SOURCE_DIR = {IREE_SOURCE_DIR} "
      f"BINARY_DIR = {IREE_BINARY_DIR}",
      file=sys.stderr)
else:
  IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..")
  IREE_BINARY_DIR = os.getenv("IREE_RUNTIME_API_CMAKE_BUILD_DIR")
  if not IREE_BINARY_DIR:
    # Note that setuptools always builds into a "build" directory that
    # is a sibling of setup.py, so we just colonize a sub-directory of that
    # by default.
    IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build")
  print(
      f"Running setup.py from source tree: "
      f"SOURCE_DIR = {IREE_SOURCE_DIR} "
      f"BINARY_DIR = {IREE_BINARY_DIR}",
      file=sys.stderr)

# Setup and get version information.
VERSION_INFO_FILE = os.path.join(IREE_SOURCE_DIR, "version_info.json")


def load_version_info():
  """Returns the parsed contents of VERSION_INFO_FILE.

  Raises:
    FileNotFoundError: if the file does not exist.
  """
  with open(VERSION_INFO_FILE, "rt") as f:
    return json.load(f)


try:
  version_info = load_version_info()
except FileNotFoundError:
  print("version_info.json not found. Using defaults", file=sys.stderr)
  version_info = {}

PACKAGE_SUFFIX = version_info.get("package-suffix") or ""
PACKAGE_VERSION = version_info.get("package-version") or "0.1dev1"


def maybe_nuke_cmake_cache():
  """Deletes CMakeCache.txt if the Python interpreter or ninja path changed.

  A stamp file in IREE_BINARY_DIR records the interpreter and ninja module
  paths of the previous run. If it is missing or differs from the current
  values, any existing CMakeCache.txt is removed and the stamp is rewritten.
  """
  # From run to run under pip, we can end up with different paths to ninja,
  # which isn't great and will confuse cmake. Detect if the location of
  # ninja changes and force a cache flush.
  ninja_path = ""
  try:
    import ninja
  except ModuleNotFoundError:
    pass
  else:
    ninja_path = ninja.__file__
  expected_stamp_contents = f"{sys.executable}\n{ninja_path}"

  # In order to speed things up on CI and not rebuild everything, we nuke
  # the CMakeCache.txt file if the path to the Python interpreter changed.
  # Ideally, CMake would let us reconfigure this dynamically... but it does
  # not (and gets very confused).
  PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt")
  if os.path.exists(PYTHON_STAMP_FILE):
    with open(PYTHON_STAMP_FILE, "rt") as f:
      actual_stamp_contents = f.read()
      if actual_stamp_contents == expected_stamp_contents:
        # All good.
        return

  # Mismatch or not found. Clean it.
  cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt")
  if os.path.exists(cmake_cache_file):
    print("Removing CMakeCache.txt because Python version changed",
          file=sys.stderr)
    os.remove(cmake_cache_file)

  # And write.
  with open(PYTHON_STAMP_FILE, "wt") as f:
    f.write(expected_stamp_contents)


def get_env_cmake_option(name: str, default_value: bool = False) -> str:
  """Builds a `-D<name>=<value>` CMake argument from the environment.

  Args:
    name: Name of both the environment variable and the CMake option.
    default_value: Used when the variable is unset or empty; True maps to
      "ON" and False to "OFF". Must be a real bool: any non-empty string
      (including "OFF") is truthy and would select "ON".

  Returns:
    The command line argument, with the environment value passed verbatim
    when one is set.
  """
  svalue = os.getenv(name)
  if not svalue:
    svalue = "ON" if default_value else "OFF"
  return f"-D{name}={svalue}"


def prepare_installation():
  """Builds (when run from the source tree) and installs into the staging dir.

  When running from the source tree, configures CMake in IREE_BINARY_DIR if
  there is no CMakeCache.txt yet and builds the runtime Python target. In all
  cases it then runs the CMake install script for the runtime bindings into
  CMAKE_INSTALL_DIR_ABS and writes the generated version.py next to the
  installed package.

  Raises:
    subprocess.CalledProcessError: if any cmake invocation fails.
  """
  subprocess.check_call(["cmake", "--version"])
  version_py_content = generate_version_py()
  print(f"Generating version.py:\n{version_py_content}", file=sys.stderr)

  if not IS_CONFIGURED:
    # Build from source tree.
    os.makedirs(IREE_BINARY_DIR, exist_ok=True)
    maybe_nuke_cmake_cache()
    print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr)
    print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr)
    cfg = "Release"
    cmake_args = [
        "-GNinja",
        "--log-level=VERBOSE",
        "-DIREE_BUILD_PYTHON_BINDINGS=ON",
        "-DIREE_BUILD_COMPILER=OFF",
        "-DIREE_BUILD_SAMPLES=OFF",
        "-DIREE_BUILD_TESTS=OFF",
        "-DPython3_EXECUTABLE={}".format(sys.executable),
        "-DCMAKE_BUILD_TYPE={}".format(cfg),
        get_env_cmake_option("IREE_HAL_DRIVER_CUDA"),
        # Vulkan defaults to ON everywhere except macOS. The default must be
        # a bool: passing the string "OFF" would be truthy and enable it.
        get_env_cmake_option("IREE_HAL_DRIVER_VULKAN",
                             platform.system() != "Darwin"),
        get_env_cmake_option("IREE_ENABLE_RUNTIME_TRACING"),
        get_env_cmake_option("IREE_BUILD_TRACY"),
    ]

    # Only do a from-scratch configure if not already configured.
    cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt")
    if not os.path.exists(cmake_cache_file):
      print(f"Configuring with: {cmake_args}", file=sys.stderr)
      subprocess.check_call(["cmake", IREE_SOURCE_DIR] + cmake_args,
                            cwd=IREE_BINARY_DIR)
    else:
      print("Not re-configuring (already configured)", file=sys.stderr)

    # Build.
    subprocess.check_call([
        "cmake", "--build", ".", "--target",
        "runtime/bindings/python/iree/runtime/all"
    ],
                          cwd=IREE_BINARY_DIR)
    print("Build complete.", file=sys.stderr)

  # Install the directory we care about.
  install_subdirectory = os.path.join(IREE_BINARY_DIR, "runtime", "bindings",
                                      "python", "iree", "runtime")
  install_args = [
      "-DCMAKE_INSTALL_DO_STRIP=ON",
      f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}/",
      "-P",
      os.path.join(install_subdirectory, "cmake_install.cmake"),
  ]
  print(f"Installing with: {install_args}", file=sys.stderr)
  subprocess.check_call(["cmake"] + install_args, cwd=install_subdirectory)

  # Write version.py directly into install dir.
  version_py_file = os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages",
                                 "iree_runtime", "iree", "runtime",
                                 "version.py")
  os.makedirs(os.path.dirname(version_py_file), exist_ok=True)
  with open(version_py_file, "wt") as f:
    f.write(version_py_content)

  print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr)


class CMakeBuildPy(_build_py):
  """build_py that copies the CMake-installed package tree into build_lib."""

  def run(self):
    # It is critical that the target directory contain all built extensions,
    # or else setuptools will helpfully compile an empty binary for us
    # (this is the **worst** possible thing it could do). We just copy
    # everything. What's another hundred megs between friends?
    target_dir = os.path.abspath(self.build_lib)
    print(f"Building in target dir: {target_dir}", file=sys.stderr)
    print("Copying install to target.", file=sys.stderr)
    # copytree() refuses to write into an existing directory, so drop any
    # stale copy first; it recreates the directory (and parents) itself.
    if os.path.exists(target_dir):
      shutil.rmtree(target_dir)
    shutil.copytree(os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages",
                                 "iree_runtime"),
                    target_dir,
                    symlinks=False)
    print("Target populated.", file=sys.stderr)


class CustomBuild(_build):
  """build command that runs build_py before build_ext and build_scripts."""

  def run(self):
    self.run_command("build_py")
    self.run_command("build_ext")
    self.run_command("build_scripts")


class CMakeExtension(Extension):
  """Source-less Extension placeholder for a module that CMake builds."""

  def __init__(self, name, sourcedir=""):
    Extension.__init__(self, name, sources=[])
    self.sourcedir = os.path.abspath(sourcedir)


class NoopBuildExtension(_build_ext):
  """build_ext that builds nothing.

  The native module is produced by CMake and copied into place by
  CMakeBuildPy, so there is nothing for setuptools to compile.
  """

  def build_extension(self, ext):
    pass


def generate_version_py():
  """Returns the source text of the generated version.py module."""
  return f"""# Auto-generated version info.
PACKAGE_SUFFIX = "{PACKAGE_SUFFIX}"
VERSION = "{PACKAGE_VERSION}"
REVISIONS = {json.dumps(find_git_versions())}
"""


def find_git_versions():
  """Returns {"IREE": <HEAD commit hash>}, or {} if git cannot provide it."""
  revisions = {}
  try:
    revisions["IREE"] = subprocess.check_output(
        ["git", "rev-parse", "HEAD"],
        cwd=IREE_SOURCE_DIR).decode("utf-8").strip()
  except (subprocess.SubprocessError, OSError) as e:
    # OSError covers a missing git executable; the revision is best-effort
    # metadata and must not fail the build.
    print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr)
  return revisions


def find_git_submodule_revision(submodule_path):
  """Returns the commit recorded at HEAD for `submodule_path`, or "".

  Any failure (git error, unexpected output) is logged and yields "".
  """
  try:
    data = subprocess.check_output(["git", "ls-tree", "HEAD", submodule_path],
                                   cwd=IREE_SOURCE_DIR).decode("utf-8").strip()
    # The object hash is taken from the third whitespace-separated column.
    columns = re.split(r"\s+", data)
    return columns[2]
  except Exception as e:
    print(
        f"ERROR: Could not get submodule revision for {submodule_path}"
        f" ({e})",
        file=sys.stderr)
    return ""


prepare_installation()

packages = find_namespace_packages(where=os.path.join(CMAKE_INSTALL_DIR_ABS,
                                                      "python_packages",
                                                      "iree_runtime"),
                                   include=[
                                       "iree.runtime",
                                       "iree.runtime.*",
                                   ])
print(f"Found runtime packages: {packages}")

# Read explicitly as UTF-8 so the result does not depend on the locale.
with open(os.path.join(IREE_SOURCE_DIR, "runtime", "bindings", "python",
                       "iree", "runtime", "README.md"),
          "rt",
          encoding="utf-8") as f:
  README = f.read()

custom_package_suffix = os.getenv("IREE_RUNTIME_CUSTOM_PACKAGE_SUFFIX")
if not custom_package_suffix:
  custom_package_suffix = ""

setup(
    name=f"iree-runtime{PACKAGE_SUFFIX}{custom_package_suffix}",
    version=f"{PACKAGE_VERSION}",
    author="IREE Authors",
    author_email="iree-discuss@googlegroups.com",
    description="IREE Python Runtime Components",
    long_description=README,
    long_description_content_type="text/markdown",
    license="Apache-2.0",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
    ],
    url="https://github.com/google/iree",
    python_requires=">=3.7",
    ext_modules=[
        CMakeExtension("iree.runtime.binding"),
    ],
    cmdclass={
        "build": CustomBuild,
        # The key must be the real command name ("build_ext") for the
        # override to take effect.
        "build_ext": NoopBuildExtension,
        "build_py": CMakeBuildPy,
    },
    zip_safe=False,
    package_dir={
        # Note: Must be relative path, so we line this up with the absolute
        # path built above. Note that this must exist prior to the call.
        "": f"{CMAKE_INSTALL_DIR_REL}/python_packages/iree_runtime",
    },
    packages=packages,
    # Matching the native extension as a data file keeps setuptools from
    # "building" it (i.e. turning it into a static binary).
    package_data={
        "": [
            f"*{sysconfig.get_config_var('EXT_SUFFIX')}",
            "iree-run-module*",
            "iree-run-trace*",
            "iree-benchmark-trace*",
            "iree-tracy-capture*",
        ],
    },
    entry_points={
        "console_scripts": [
            "iree-run-module = iree.runtime.scripts.iree_run_module.__main__:main",
            "iree-run-trace = iree.runtime.scripts.iree_run_trace.__main__:main",
            "iree-benchmark-trace = iree.runtime.scripts.iree_benchmark_trace.__main__:main",
            "iree-tracy-capture = iree.runtime.scripts.iree_tracy_capture.__main__:main",
        ],
    },
    install_requires=[
        "numpy",
        "PyYAML",
    ],
)