[python] Add a HalDeviceLoopBridge class for routing runtime events to futures. (#16385)
This is the most basic level of interop possible with the asyncio system
in Python.
As noted offline when we were looking at this, the chosen test case
seems to expose a synchronization bug in local-task that causes a wait
to not return even though there are signaled handles. This does not
(and cannot) occur in local-sync, and does not appear to be an issue in
vulkan. As such, this lands with local-sync so that we can iterate further.
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index 28fb964..83fb55b 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -68,6 +68,8 @@
"io.cc"
"hal.h"
"hal.cc"
+ "loop.h"
+ "loop.cc"
"numpy_interop.h"
"numpy_interop.cc"
"py_module.h"
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index 4465879..c9a3a30 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -18,6 +18,19 @@
#include "nanobind/stl/string_view.h"
#include "nanobind/stl/vector.h"
+// Optional tracing of the Python bindings to stderr, useful for verifying
+// sequencing and reference counting. Uncomment the following to enable it.
+// #define IREE_PY_TRACE_ENABLED 1
+
+#if IREE_PY_TRACE_ENABLED
+// Emits one printf-style formatted, newline-terminated trace line.
+// Requires at least one argument after `fmt`.
+#define IREE_PY_TRACEF(fmt, ...) \
+  fprintf(stderr, "[[IREE_PY_TRACE]]: " fmt "\n", __VA_ARGS__)
+// Emits one plain-message trace line, newline-terminated like IREE_PY_TRACEF
+// so that consecutive trace messages do not run together.
+#define IREE_PY_TRACE(msg) fprintf(stderr, "[[IREE_PY_TRACE]]: %s\n", msg)
+#else
+// Tracing disabled: both macros expand to nothing.
+#define IREE_PY_TRACEF(...)
+#define IREE_PY_TRACE(...)
+#endif
+
namespace iree {
namespace python {
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index bac89bd..7eb9254 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -8,6 +8,7 @@
#include "./hal.h"
#include "./invoke.h"
#include "./io.h"
+#include "./loop.h"
#include "./numpy_interop.h"
#include "./py_module.h"
#include "./status_utils.h"
@@ -20,6 +21,7 @@
NB_MODULE(_runtime, m) {
numpy::InitializeNumPyInterop();
+ IREE_TRACE_APP_ENTER();
IREE_CHECK_OK(iree_hal_register_all_available_drivers(
iree_hal_driver_registry_default()));
@@ -28,6 +30,7 @@
SetupHalBindings(m);
SetupInvokeBindings(m);
SetupIoBindings(m);
+ SetupLoopBindings(m);
SetupPyModuleBindings(m);
SetupVmBindings(m);
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index b95b773..056faa7 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -30,6 +30,7 @@
HalBufferView,
HalCommandBuffer,
HalDevice,
+ HalDeviceLoopBridge,
HalDriver,
HalElementType,
HalFence,
diff --git a/runtime/bindings/python/loop.cc b/runtime/bindings/python/loop.cc
new file mode 100644
index 0000000..d191bb3
--- /dev/null
+++ b/runtime/bindings/python/loop.cc
@@ -0,0 +1,309 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./loop.h"
+
+#include <cstdio>
+#include <vector>
+
+#include "./hal.h"
+#include "iree/base/internal/synchronization.h"
+
+namespace iree::python {
+
+namespace {
+
+// Docstring assigned to the Python-visible HalDeviceLoopBridge class in
+// SetupLoopBindings().
+static const char kHalDeviceLoopBridgeDocstring[] =
+ R"(Bridges device semaphore signalling to asyncio futures.
+
+This is intended to be run alongside an asyncio loop, allowing arbitrary
+semaphore timepoints to be bridged to the loop, satisfying futures.
+
+Internally, it starts a thread which spins to poll the requested semaphores
+(which all must be from the same device). It can be used in single-device
+cases as a simpler implementation than a full integration with an asyncio
+event loop, theoretically resulting in fewer heavy-weight, kernel/device
+synchronization interactions.
+)";
+
+// Bridges HAL semaphore timepoints to futures created on a Python event loop.
+//
+// The constructor starts a Python `threading.Thread` that executes Run().
+// Run() waits on every registered semaphore plus an internal control
+// semaphore; OnSemaphore() and Stop() signal the control semaphore to wake it.
+// Futures are completed through the loop's `call_soon_threadsafe`.
+class HalDeviceLoopBridge {
+ private:
+  // One registered wait.
+  // Fields: semaphore, wait_payload_value, future, future_value.
+  // The semaphore pointer and both Python handles each carry one owned
+  // reference that must be returned when the entry is retired.
+  using PendingFuture =
+      std::tuple<iree_hal_semaphore_t *, uint64_t, py::handle, py::handle>;
+
+ public:
+  // Creates the control semaphore on `device` and starts the bridge thread.
+  // `loop` is the Python event loop object; its `call_soon_threadsafe` and
+  // `create_future` attributes are used.
+  HalDeviceLoopBridge(HalDevice device, py::object loop)
+      : device_(std::move(device)), loop_(std::move(loop)) {
+    IREE_PY_TRACEF("new HalDeviceLoopBridge (%p)", this);
+    iree_slim_mutex_initialize(&mu_);
+    CheckApiStatus(
+        iree_hal_semaphore_create(device_.raw_ptr(), 0, &control_sem_),
+        "create semaphore");
+
+    loop_call_soon_ = loop_.attr("call_soon_threadsafe");
+
+    // Start the thread.
+    auto threading_m = py::module_::import_("threading");
+    thread_ = threading_m.attr("Thread")(
+        /*group=*/py::none(),
+        /*target=*/py::cpp_function([this]() { Run(); }),
+        /*name=*/"HalDeviceLoopBridge");
+    thread_.attr("start")();
+  }
+
+  // Stops the thread if it is still running (with a Python warning), cancels
+  // any futures that were queued but never picked up, and releases the control
+  // semaphore.
+  ~HalDeviceLoopBridge() {
+    IREE_PY_TRACEF("~HalDeviceLoopBridge(%p)", this);
+
+    // Stopping the thread during destruction is not great. But it is better
+    // than invalidating live memory.
+    if (!thread_.is_none()) {
+      auto warnings_m = py::module_::import_("warnings");
+      warnings_m.attr("warn")(
+          "HalDeviceLoopBridge deleted while running. Recommend explicitly "
+          "calling stop() to avoid hanging the gc");
+      Stop();
+    }
+
+    // Cancel all futures.
+    iree_slim_mutex_lock(&mu_);
+    for (auto &entry : next_pending_futures_) {
+      iree_hal_semaphore_release(std::get<0>(entry));
+      py::handle future = std::get<2>(entry);
+      py::handle value = std::get<3>(entry);
+      CancelFuture(future, value);
+    }
+    next_pending_futures_.clear();
+    iree_slim_mutex_unlock(&mu_);
+
+    iree_slim_mutex_deinitialize(&mu_);
+    iree_hal_semaphore_release(control_sem_);
+  }
+
+  // Requests shutdown, wakes the bridge thread via the control semaphore and
+  // joins it. Calling it again after the thread has been joined is a no-op.
+  void Stop() {
+    if (thread_.is_none()) {
+      IREE_PY_TRACEF("HalDeviceLoopBridge::Stop(%p): Already stopped", this);
+      return;
+    }
+    IREE_PY_TRACEF("HalDeviceLoopBridge::Stop(%p)", this);
+    iree_slim_mutex_lock(&mu_);
+    shutdown_signaled_ = true;
+    auto status = iree_hal_semaphore_signal(control_sem_, control_next_++);
+    iree_slim_mutex_unlock(&mu_);
+    CheckApiStatus(status, "iree_hal_semaphore_signal");
+    thread_.attr("join")();
+    thread_ = py::none();
+    IREE_PY_TRACEF("HalDeviceLoopBridge::Stop(%p): Joined", this);
+  }
+
+  // Thread body. Each iteration drains next_pending_futures_, queries every
+  // pending semaphore, completes the futures whose payload has been reached
+  // and then blocks until any remaining semaphore or the control semaphore
+  // advances. On shutdown all still-pending futures are cancelled.
+  void Run() {
+    IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p)", this);
+    py::gil_scoped_release gil_release;
+    // Wait list.
+    std::vector<iree_hal_semaphore_t *> wait_semaphores;
+    std::vector<uint64_t> wait_payloads;
+    wait_semaphores.reserve(5);
+    wait_payloads.reserve(5);
+    // Pending futures that are actively being waited on. Owned by Run().
+    std::vector<PendingFuture> pending_futures;
+    // Scratch pad of pending futures that we must keep waiting on. Owned by
+    // Run().
+    std::vector<PendingFuture> scratch_pending_futures;
+    // next_pending_futures_ is mutated by OnSemaphore() under mu_, so read its
+    // capacity under the same lock.
+    iree_slim_mutex_lock(&mu_);
+    const size_t initial_capacity = next_pending_futures_.capacity();
+    iree_slim_mutex_unlock(&mu_);
+    pending_futures.reserve(initial_capacity);
+
+    bool keep_running = true;
+    uint64_t next_control_wakeup = 1;
+    while (true) {
+      IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p): Loop begin", this);
+      // Transfer any pending futures into the current list.
+      iree_slim_mutex_lock(&mu_);
+      while (!next_pending_futures_.empty()) {
+        pending_futures.push_back(std::move(next_pending_futures_.back()));
+        next_pending_futures_.pop_back();
+      }
+      keep_running = !shutdown_signaled_;
+      iree_slim_mutex_unlock(&mu_);
+      if (!keep_running) {
+        IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p): Loop break", this);
+        break;
+      }
+      wait_semaphores.clear();
+      wait_payloads.clear();
+
+      // Poll all futures and dispatch. Any that are still pending are routed
+      // to the scratch_pending_futures. Important: we don't hold the gil so
+      // can not do anything that toggles reference counts or calls Python yet.
+      iree_status_t status;
+      for (size_t i = 0; i < pending_futures.size(); ++i) {
+        auto &entry = pending_futures[i];
+        uint64_t current_payload;
+        iree_hal_semaphore_t *semaphore = std::get<0>(entry);
+        status = iree_hal_semaphore_query(semaphore, &current_payload);
+        if (iree_status_is_ok(status)) {
+          if (current_payload >= std::get<1>(entry)) {
+            // All done.
+            iree_hal_semaphore_release(semaphore);
+            SignalFuture(std::get<2>(entry), std::get<3>(entry));
+          } else {
+            // Keep it pending.
+            IREE_PY_TRACEF(" Add to wait list: semaphore=%p, payload=%" PRIu64,
+                           semaphore, std::get<1>(entry));
+            wait_semaphores.push_back(semaphore);
+            wait_payloads.push_back(std::get<1>(entry));
+            scratch_pending_futures.push_back(std::move(entry));
+          }
+        } else {
+          // The query failed: retire the entry and fail its future. Ownership
+          // of `status` passes to SignalFutureFailure().
+          iree_hal_semaphore_release(semaphore);
+          SignalFutureFailure(std::get<2>(entry), std::get<3>(entry), status);
+        }
+      }
+      pending_futures.clear();
+      pending_futures.swap(scratch_pending_futures);
+
+      // Add the control semaphore.
+      wait_semaphores.push_back(control_sem_);
+      wait_payloads.push_back(next_control_wakeup);
+
+      // Wait.
+      IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p): wait_semaphores(%zu)", this,
+                     wait_semaphores.size());
+      status = iree_hal_device_wait_semaphores(
+          device_.raw_ptr(), IREE_HAL_WAIT_MODE_ANY,
+          {wait_semaphores.size(), wait_semaphores.data(),
+           wait_payloads.data()},
+          iree_infinite_timeout());
+      if (!iree_status_is_ok(status)) {
+        py::gil_scoped_acquire acquire_gil;
+        CheckApiStatus(
+            status, "iree_hal_device_wait_semaphores from HalDeviceLoopBridge");
+      }
+
+      // Re-read the control semaphore so the next wait targets the value after
+      // whatever has already been signaled.
+      status = iree_hal_semaphore_query(control_sem_, &next_control_wakeup);
+      if (!iree_status_is_ok(status)) {
+        py::gil_scoped_acquire acquire_gil;
+        CheckApiStatus(status,
+                       "iree_hal_semaphore_query from HalDeviceLoopBridge");
+      }
+      next_control_wakeup += 1;
+      IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p): Loop end", this);
+    }
+
+    // Cancel all pending futures.
+    {
+      for (auto &entry : pending_futures) {
+        iree_hal_semaphore_release(std::get<0>(entry));
+        py::handle future = std::get<2>(entry);
+        py::handle value = std::get<3>(entry);
+        CancelFuture(future, value);
+      }
+    }
+
+    IREE_PY_TRACEF("HalDeviceLoopBridge::Run(%p): Thread complete", this);
+  }
+
+  // Calls `future.cancel()` and drops the references held on `future` and
+  // `value`. Acquires the GIL; a Python error from cancel() is reported as
+  // unraisable rather than propagated.
+  void CancelFuture(py::handle future, py::handle value) {
+    IREE_PY_TRACEF("HalDeviceLoopBridge::CancelFuture(%p)", future.ptr());
+    py::gil_scoped_acquire acquire_gil;
+    try {
+      future.attr("cancel")();
+    } catch (py::python_error &e) {
+      ReportUncaughtException(e);
+    }
+    future.dec_ref();
+    value.dec_ref();
+  }
+
+  // Schedules `future.set_result(value)` on the loop via
+  // call_soon_threadsafe. Takes over the references held on both handles.
+  void SignalFuture(py::handle future, py::handle value) {
+    IREE_PY_TRACEF("HalDeviceLoopBridge::SignalFuture(%p)", future.ptr());
+    py::gil_scoped_acquire acquire_gil;
+    py::object future_owned = py::steal(future);
+    py::object value_owned = py::steal(value);
+    loop_call_soon_(py::cpp_function([future_owned = std::move(future_owned),
+                                      value_owned = std::move(value_owned)]() {
+      future_owned.attr("set_result")(value_owned);
+    }));
+  }
+
+  // Schedules `future.set_exception(RuntimeError(<status message>))` on the
+  // loop via call_soon_threadsafe. Takes over the references held on both
+  // handles and consumes `status`.
+  void SignalFutureFailure(py::handle future, py::handle value,
+                           iree_status_t status) {
+    py::gil_scoped_acquire acquire_gil;
+    py::object future_owned = py::steal(future);
+    py::object value_owned = py::steal(value);
+    std::string message = ApiStatusToString(status);
+    IREE_PY_TRACEF("HalDeviceLoopBridge::SignalFutureFailure(future=%p) : %s",
+                   future.ptr(), message.c_str());
+    iree_status_ignore(status);
+    loop_call_soon_(py::cpp_function([future_owned = std::move(future_owned),
+                                      value_owned = std::move(value_owned),
+                                      message = std::move(message)]() {
+      PyErr_SetString(PyExc_RuntimeError, message.c_str());
+      PyObject *exc_type;
+      PyObject *exc_value;
+      PyObject *exc_tb;
+      PyErr_Fetch(&exc_type, &exc_value, &exc_tb);
+      // Before Python 3.12 the fetched value may still be the raw message
+      // string; normalize so that set_exception() receives an exception
+      // instance.
+      PyErr_NormalizeException(&exc_type, &exc_value, &exc_tb);
+      future_owned.attr("set_exception")(exc_value);
+      Py_XDECREF(exc_type);
+      Py_XDECREF(exc_tb);
+      Py_XDECREF(exc_value);
+    }));
+  }
+
+  // Registers a wait for `semaphore` to reach `payload` and returns a new
+  // future from the loop's `create_future()`. Run() later completes the future
+  // with `value`, fails it, or cancels it. Wakes the bridge thread so the new
+  // entry is picked up.
+  py::object OnSemaphore(HalSemaphore semaphore, uint64_t payload,
+                         py::object value) {
+    IREE_PY_TRACEF(
+        "HalDeviceLoopBridge::OnSemaphore(semaphore=%p, payload=%" PRIu64 ")",
+        semaphore.raw_ptr(), payload);
+    py::object future = loop_.attr("create_future")();
+    iree_slim_mutex_lock(&mu_);
+    next_pending_futures_.push_back(std::make_tuple(
+        semaphore.steal_raw_ptr(), payload, future, value.release()));
+    // The queued entry stores a bare handle; give it its own reference.
+    future.inc_ref();
+    auto status = iree_hal_semaphore_signal(control_sem_, control_next_++);
+    iree_slim_mutex_unlock(&mu_);
+    CheckApiStatus(status, "iree_hal_semaphore_signal");
+    return future;
+  }
+
+ private:
+  // Certain calls into Futures may raise exceptions because of illegal states.
+  // There is really not much we can do about this, so we attempt to report.
+  // TODO: Have some kind of fatal exception hook.
+  void ReportUncaughtException(py::python_error &e) {
+    e.discard_as_unraisable(py::str(__func__));
+  }
+
+  // Guards control_next_, shutdown_signaled_ and next_pending_futures_.
+  iree_slim_mutex_t mu_;
+  HalDevice device_;
+  py::object loop_;
+  // The Python threading.Thread running Run(); set to None once joined.
+  py::object thread_;
+  // Bound `loop_.call_soon_threadsafe`.
+  py::object loop_call_soon_;
+  // Signaled by OnSemaphore() and Stop() to wake Run().
+  iree_hal_semaphore_t *control_sem_ = nullptr;
+  // Next value to signal on control_sem_.
+  uint64_t control_next_ = 1;
+  bool shutdown_signaled_ = false;
+
+  // Incoming futures to add to the pending list on next cycle. Must be locked
+  // with mu_.
+  // Note that because these structures are processed without the GIL being
+  // held, we cannot unexpectedly do any reference count manipulation.
+  // Therefore, when added here, it is added with a reference. And the reference
+  // must be returned when retired.
+  std::vector<PendingFuture> next_pending_futures_;
+};
+
+} // namespace
+
+// Registers the HalDeviceLoopBridge class, its constructor, `stop` and
+// `on_semaphore` methods, and its docstring on module `m`.
+void SetupLoopBindings(py::module_ &m) {
+  py::class_<HalDeviceLoopBridge> bridge_class(m, "HalDeviceLoopBridge");
+  bridge_class.def(py::init<HalDevice, py::object>(), py::arg("device"),
+                   py::arg("loop"));
+  bridge_class.def("stop", &HalDeviceLoopBridge::Stop);
+  bridge_class.def("on_semaphore", &HalDeviceLoopBridge::OnSemaphore,
+                   py::arg("semaphore"), py::arg("payload"), py::arg("value"));
+  bridge_class.doc() = kHalDeviceLoopBridgeDocstring;
+}
+
+} // namespace iree::python
diff --git a/runtime/bindings/python/loop.h b/runtime/bindings/python/loop.h
new file mode 100644
index 0000000..b8c91ad
--- /dev/null
+++ b/runtime/bindings/python/loop.h
@@ -0,0 +1,18 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_LOOP_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_LOOP_H_
+
+#include "./binding.h"
+
+namespace iree::python {
+
+void SetupLoopBindings(py::module_ &m);
+
+} // namespace iree::python
+
+#endif // IREE_BINDINGS_PYTHON_IREE_RT_LOOP_H_
diff --git a/runtime/bindings/python/status_utils.cc b/runtime/bindings/python/status_utils.cc
index 36941c2..46d293c 100644
--- a/runtime/bindings/python/status_utils.cc
+++ b/runtime/bindings/python/status_utils.cc
@@ -26,7 +26,9 @@
}
}
-static std::string ApiStatusToString(iree_status_t status) {
+} // namespace
+
+std::string ApiStatusToString(iree_status_t status) {
iree_host_size_t buffer_length = 0;
if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0,
/*buffer=*/NULL, &buffer_length))) {
@@ -41,8 +43,6 @@
: "";
}
-} // namespace
-
nanobind::python_error ApiStatusToPyExc(iree_status_t status,
const char* message) {
assert(!iree_status_is_ok(status));
diff --git a/runtime/bindings/python/status_utils.h b/runtime/bindings/python/status_utils.h
index 03181ae..936ea72 100644
--- a/runtime/bindings/python/status_utils.h
+++ b/runtime/bindings/python/status_utils.h
@@ -25,6 +25,8 @@
return RaisePyError(PyExc_ValueError, message);
}
+std::string ApiStatusToString(iree_status_t status);
+
nanobind::python_error ApiStatusToPyExc(iree_status_t status,
const char* message);
diff --git a/runtime/bindings/python/tests/hal_device_loop_test.py b/runtime/bindings/python/tests/hal_device_loop_test.py
new file mode 100644
index 0000000..d3c4717
--- /dev/null
+++ b/runtime/bindings/python/tests/hal_device_loop_test.py
@@ -0,0 +1,81 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import asyncio
+import timeit
+import unittest
+
+from iree.runtime import (
+ get_device,
+ HalDeviceLoopBridge,
+)
+
+
+class HalDeviceLoopBridgeTest(unittest.TestCase):
+    """Drives HalDeviceLoopBridge futures from an asyncio event loop."""
+
+    def testBridge(self):
+        """Chains three semaphore futures and checks their results."""
+        loop = asyncio.new_event_loop()
+        bridge = HalDeviceLoopBridge(self.device, loop)
+        # Rebound by run_iter() before every iteration.
+        sem1 = None
+        sem2 = None
+        report = None
+
+        async def main():
+            def done_1(x):
+                report("PYTHON: sem2.signal(1)")
+                sem2.signal(1)
+
+            def done_2(x):
+                report("PYTHON: sem2.signal(2)")
+                sem2.signal(2)
+
+            # f1 completing signals sem2 to 1, which completes f2, which in
+            # turn signals sem2 to 2 and completes f2_again.
+            f1 = bridge.on_semaphore(sem1, 1, "Semaphore 1 Signaled")
+            f1.add_done_callback(done_1)
+            f2 = bridge.on_semaphore(sem2, 1, "Semaphore 2 Signaled")
+            f2.add_done_callback(done_2)
+            f2_again = bridge.on_semaphore(sem2, 2, "Semaphore 2 Signaled Again")
+
+            sem1.signal(1)
+            f1_result = await f1
+            report("PYTHON: await f1 =", f1_result)
+            f2_result = await f2
+            report("PYTHON: await f2 =", f2_result)
+            f2_again_result = await f2_again
+            report("PYTHON: await f2_again =", f2_again_result)
+
+            self.assertEqual(f1_result, "Semaphore 1 Signaled")
+            self.assertEqual(f2_result, "Semaphore 2 Signaled")
+            self.assertEqual(f2_again_result, "Semaphore 2 Signaled Again")
+            report("PYTHON: ASYNC MAIN() COMPLETE")
+
+        def run_iter(with_report):
+            nonlocal sem1
+            nonlocal sem2
+            nonlocal report
+            sem1 = self.device.create_semaphore(0)
+            sem2 = self.device.create_semaphore(0)
+            if with_report:
+                report = lambda *args: print(*args)
+            else:
+                report = lambda *args: None
+            loop.run_until_complete(main())
+
+        try:
+            run_iter(True)
+            iter_time = timeit.timeit("run_iter(False)", globals=locals(), number=10)
+            print(f"Time/iter = {iter_time}s")
+        finally:
+            bridge.stop()
+            # The bridge thread has been joined, so nothing can schedule onto
+            # the loop any more; close it to release its resources.
+            loop.close()
+
+    def setUp(self):
+        super().setUp()
+        # TODO: Switch to local-task (experiencing some wait deadlocking
+        # that needs triage).
+        self.device = get_device("local-sync")
+        self.allocator = self.device.allocator
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/runtime/src/iree/base/internal/wait_handle_poll.c b/runtime/src/iree/base/internal/wait_handle_poll.c
index b2adc9b..ef0db0d 100644
--- a/runtime/src/iree/base/internal/wait_handle_poll.c
+++ b/runtime/src/iree/base/internal/wait_handle_poll.c
@@ -337,7 +337,9 @@
// Make the syscall only when we have at least one valid fd.
// Don't use this as a sleep.
if (set->handle_count <= 0) {
- memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+ if (out_wake_handle) {
+ memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+ }
return iree_ok_status();
}
@@ -354,18 +356,20 @@
&signaled_count));
// Find at least one signaled handle.
- memset(out_wake_handle, 0, sizeof(*out_wake_handle));
- if (signaled_count > 0) {
- for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
- bool signaled = false;
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_wait_set_resolve_poll_events(set->poll_fds[i].revents,
- &signaled));
- if (signaled) {
- memcpy(out_wake_handle, &set->user_handles[i],
- sizeof(*out_wake_handle));
- out_wake_handle->set_internal.index = i;
- break;
+ if (out_wake_handle) {
+ memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+ if (signaled_count > 0) {
+ for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
+ bool signaled = false;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_wait_set_resolve_poll_events(set->poll_fds[i].revents,
+ &signaled));
+ if (signaled) {
+ memcpy(out_wake_handle, &set->user_handles[i],
+ sizeof(*out_wake_handle));
+ out_wake_handle->set_internal.index = i;
+ break;
+ }
}
}
}