[python] Add a couple more async APIs. (#16419)
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 84c2d52..89d18db 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -69,8 +69,8 @@
signal_semaphores: Semaphores/Fence to signal.
)";
-static const char kHalFenceWait[] =
- R"(Waits until the fence is signalled or errored.
+static const char kHalWait[] =
+ R"(Waits until the semaphore or fence is signalled or errored.
Three wait cases are supported:
* timeout: Relative nanoseconds to wait.
@@ -106,6 +106,19 @@
return ToHexString((const uint8_t*)&value, sizeof(value));
}
+iree_timeout_t NormalizeTimeout(std::optional<iree_duration_t> timeout,
+ std::optional<iree_time_t> deadline) {
+ if (!timeout && !deadline) {
+ return iree_infinite_timeout();
+ } else if (timeout && deadline) {
+ throw std::invalid_argument("timeout and deadline cannot both be set");
+ } else if (timeout) {
+ return iree_make_timeout_ns(*timeout);
+ } else {
+ return iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline};
+ }
+}
+
} // namespace
//------------------------------------------------------------------------------
@@ -1119,6 +1132,15 @@
.def("__repr__", &HalBufferView::Repr);
py::class_<HalSemaphore>(m, "HalSemaphore")
+ .def(
+ "fail",
+ [](HalSemaphore& self, std::string& message) {
+ // TODO: Take some category enum and use that is available.
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str());
+ iree_hal_semaphore_fail(self.raw_ptr(), status);
+ },
+ py::arg("message"))
.def("query",
[](HalSemaphore& self) {
uint64_t out_value;
@@ -1127,10 +1149,52 @@
"querying semaphore");
return out_value;
})
- .def("signal", [](HalSemaphore& self, uint64_t new_value) {
- CheckApiStatus(iree_hal_semaphore_signal(self.raw_ptr(), new_value),
- "signaling semaphore");
- });
+ .def("signal",
+ [](HalSemaphore& self, uint64_t new_value) {
+ CheckApiStatus(
+ iree_hal_semaphore_signal(self.raw_ptr(), new_value),
+ "signaling semaphore");
+ })
+ .def(
+ "wait",
+ [](HalSemaphore& self, uint64_t payload,
+ std::optional<iree_duration_t> timeout,
+ std::optional<iree_time_t> deadline) -> bool {
+ iree_timeout_t t = NormalizeTimeout(timeout, deadline);
+ iree_status_t status;
+ uint64_t unused_value;
+ {
+ py::gil_scoped_release release;
+ status = iree_hal_semaphore_wait(self.raw_ptr(), payload, t);
+ }
+ if (iree_status_is_deadline_exceeded(status)) {
+ // Time out.
+ return false;
+ } else if (iree_status_is_aborted(status)) {
+ // Synchronous failure.
+ iree_status_ignore(status);
+ status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value);
+ if (iree_status_is_ok(status)) {
+ status = iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "expected synchronous status failure missing");
+ }
+ CheckApiStatus(status, "synchronous semaphore failure");
+ } else {
+ // General failure check.
+ CheckApiStatus(status, "waiting for semaphore");
+ }
+
+ // Asynchronous failure.
+ status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value);
+ if (iree_status_is_deferred(status)) {
+ return false;
+ }
+ CheckApiStatus(status, "asynchronous semaphore failure");
+ return true;
+ },
+ py::arg("payload"), py::arg("timeout") = py::none(),
+ py::arg("deadline") = py::none(), kHalWait);
auto hal_fence = py::class_<HalFence>(m, "HalFence");
VmRef::BindRefProtocol(hal_fence, iree_hal_fence_type,
@@ -1214,17 +1278,7 @@
"wait",
[](HalFence& self, std::optional<iree_duration_t> timeout,
std::optional<iree_time_t> deadline) -> bool {
- iree_timeout_t t;
- if (!timeout && !deadline) {
- t = iree_infinite_timeout();
- } else if (timeout && deadline) {
- throw std::invalid_argument(
- "timeout and deadline cannot both be set");
- } else if (timeout) {
- t = iree_make_timeout_ns(*timeout);
- } else {
- t = iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline};
- }
+ iree_timeout_t t = NormalizeTimeout(timeout, deadline);
iree_status_t status;
{
py::gil_scoped_release release;
@@ -1257,7 +1311,7 @@
return true;
},
py::arg("timeout") = py::none(), py::arg("deadline") = py::none(),
- kHalFenceWait);
+ kHalWait);
py::class_<HalMappedMemory>(m, "MappedMemory")
.def(
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index db21fb7..3c44e88 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -2,6 +2,8 @@
from typing import overload
+import asyncio
+
def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
def create_io_parameters_module(
instance: VmInstance, *providers: ParameterProvider
@@ -178,6 +180,13 @@
@property
def allocator(self) -> HalAllocator: ...
+class HalDeviceLoopBridge:
+ def __init__(self, device: HalDevice, loop: asyncio.BaseEventLoop): ...
+ def stop(self): ...
+ def on_semaphore(
+ self, semaphore: HalSemaphore, payload: int, value: Any
+ ) -> asyncio.Future: ...
+
class HalDriver:
@staticmethod
def query() -> List[str]: ...
@@ -254,8 +263,15 @@
def __iree_vm_ref__(self) -> VmRef: ...
class HalSemaphore:
+ def fail(self, message: str): ...
def query(self) -> int: ...
def signal(self, new_value: int) -> None: ...
+ def wait(
+ self,
+ payload: int,
+ timeout: Optional[int] = None,
+ deadline: Optional[int] = None,
+ ) -> None: ...
class Linkage(int):
EXPORT: ClassVar[Linkage] = ...
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 4d59673..b19caf9 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -210,6 +210,52 @@
sem1.signal(2)
self.assertEqual(sem1.query(), 2)
+ def testSemaphoreSignal(self):
+ sem = self.device.create_semaphore(0)
+ self.assertFalse(sem.wait(1, deadline=0))
+ sem.signal(1)
+ self.assertTrue(sem.wait(1, deadline=0))
+
+ def testSynchronousSemaphoreFailed(self):
+ sem = self.device.create_semaphore(0)
+ sem.fail("TEST FAILURE")
+ with self.assertRaisesRegex(
+ RuntimeError, "^synchronous semaphore failure.*TEST FAILURE"
+ ):
+ sem.wait(1, deadline=0)
+
+ def testAsynchronousSemaphoreFailed(self):
+ sem = self.device.create_semaphore(0)
+ exceptions = []
+
+ def run():
+ print("SIGNALLING ASYNC FAILURE")
+ time.sleep(0.2)
+ sem.fail("TEST FAILURE")
+ print("SIGNALLED")
+
+ def wait():
+ print("WAITING")
+ try:
+ sem.wait(1)
+ except RuntimeError as e:
+ exceptions.append(e)
+
+ runner = threading.Thread(target=run)
+ waiter = threading.Thread(target=wait)
+ waiter.start()
+ runner.start()
+ waiter.join()
+ runner.join()
+ self.assertTrue(exceptions)
+ print(exceptions)
+ # Note: It is impossible to 100% guarantee that this sequences such as to
+ # report an asynchronous vs synchronous failure, although we tip the odds in
+ # this favor with the sleep in the signalling thread. Therefore, we do not
+ # check the "asynchronous" vs "synchronous" message prefix to avoid flaky
+ # test races.
+ self.assertIn("TEST FAILURE", str(exceptions[0]))
+
def testTrivialQueueAlloc(self):
sem = self.device.create_semaphore(0)
buf = self.device.queue_alloca(