[python] Add a couple more async APIs. (#16419)

diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 84c2d52..89d18db 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -69,8 +69,8 @@
   signal_semaphores: Semaphores/Fence to signal.
 )";
 
-static const char kHalFenceWait[] =
-    R"(Waits until the fence is signalled or errored.
+static const char kHalWait[] =
+    R"(Waits until the semaphore or fence is signalled or errored.
 
 Three wait cases are supported:
   * timeout: Relative nanoseconds to wait.
@@ -106,6 +106,19 @@
   return ToHexString((const uint8_t*)&value, sizeof(value));
 }
 
+iree_timeout_t NormalizeTimeout(std::optional<iree_duration_t> timeout,
+                                std::optional<iree_time_t> deadline) {
+  if (!timeout && !deadline) {
+    return iree_infinite_timeout();
+  } else if (timeout && deadline) {
+    throw std::invalid_argument("timeout and deadline cannot both be set");
+  } else if (timeout) {
+    return iree_make_timeout_ns(*timeout);
+  } else {
+    return iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline};
+  }
+}
+
 }  // namespace
 
 //------------------------------------------------------------------------------
@@ -1119,6 +1132,15 @@
       .def("__repr__", &HalBufferView::Repr);
 
   py::class_<HalSemaphore>(m, "HalSemaphore")
+      .def(
+          "fail",
+          [](HalSemaphore& self, std::string& message) {
+            // TODO: Take some category enum and use that is available.
+            iree_status_t status =
+                iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str());
+            iree_hal_semaphore_fail(self.raw_ptr(), status);
+          },
+          py::arg("message"))
       .def("query",
            [](HalSemaphore& self) {
              uint64_t out_value;
@@ -1127,10 +1149,52 @@
                  "querying semaphore");
              return out_value;
            })
-      .def("signal", [](HalSemaphore& self, uint64_t new_value) {
-        CheckApiStatus(iree_hal_semaphore_signal(self.raw_ptr(), new_value),
-                       "signaling semaphore");
-      });
+      .def("signal",
+           [](HalSemaphore& self, uint64_t new_value) {
+             CheckApiStatus(
+                 iree_hal_semaphore_signal(self.raw_ptr(), new_value),
+                 "signaling semaphore");
+           })
+      .def(
+          "wait",
+          [](HalSemaphore& self, uint64_t payload,
+             std::optional<iree_duration_t> timeout,
+             std::optional<iree_time_t> deadline) -> bool {
+            iree_timeout_t t = NormalizeTimeout(timeout, deadline);
+            iree_status_t status;
+            uint64_t unused_value;
+            {
+              py::gil_scoped_release release;
+              status = iree_hal_semaphore_wait(self.raw_ptr(), payload, t);
+            }
+            if (iree_status_is_deadline_exceeded(status)) {
+              // Time out.
+              return false;
+            } else if (iree_status_is_aborted(status)) {
+              // Synchronous failure.
+              iree_status_ignore(status);
+              status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value);
+              if (iree_status_is_ok(status)) {
+                status = iree_make_status(
+                    IREE_STATUS_FAILED_PRECONDITION,
+                    "expected synchronous status failure missing");
+              }
+              CheckApiStatus(status, "synchronous semaphore failure");
+            } else {
+              // General failure check.
+              CheckApiStatus(status, "waiting for semaphore");
+            }
+
+            // Asynchronous failure.
+            status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value);
+            if (iree_status_is_deferred(status)) {
+              return false;
+            }
+            CheckApiStatus(status, "asynchronous semaphore failure");
+            return true;
+          },
+          py::arg("payload"), py::arg("timeout") = py::none(),
+          py::arg("deadline") = py::none(), kHalWait);
 
   auto hal_fence = py::class_<HalFence>(m, "HalFence");
   VmRef::BindRefProtocol(hal_fence, iree_hal_fence_type,
@@ -1214,17 +1278,7 @@
           "wait",
           [](HalFence& self, std::optional<iree_duration_t> timeout,
              std::optional<iree_time_t> deadline) -> bool {
-            iree_timeout_t t;
-            if (!timeout && !deadline) {
-              t = iree_infinite_timeout();
-            } else if (timeout && deadline) {
-              throw std::invalid_argument(
-                  "timeout and deadline cannot both be set");
-            } else if (timeout) {
-              t = iree_make_timeout_ns(*timeout);
-            } else {
-              t = iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline};
-            }
+            iree_timeout_t t = NormalizeTimeout(timeout, deadline);
             iree_status_t status;
             {
               py::gil_scoped_release release;
@@ -1257,7 +1311,7 @@
             return true;
           },
           py::arg("timeout") = py::none(), py::arg("deadline") = py::none(),
-          kHalFenceWait);
+          kHalWait);
 
   py::class_<HalMappedMemory>(m, "MappedMemory")
       .def(
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index db21fb7..3c44e88 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -2,6 +2,8 @@
 
 from typing import overload
 
+import asyncio
+
 def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
 def create_io_parameters_module(
     instance: VmInstance, *providers: ParameterProvider
@@ -178,6 +180,13 @@
     @property
     def allocator(self) -> HalAllocator: ...
 
+class HalDeviceLoopBridge:
+    def __init__(self, device: HalDevice, loop: asyncio.BaseEventLoop): ...
+    def stop(self): ...
+    def on_semaphore(
+        self, semaphore: HalSemaphore, payload: int, value: Any
+    ) -> asyncio.Future: ...
+
 class HalDriver:
     @staticmethod
     def query() -> List[str]: ...
@@ -254,8 +263,15 @@
     def __iree_vm_ref__(self) -> VmRef: ...
 
 class HalSemaphore:
+    def fail(self, message: str): ...
     def query(self) -> int: ...
     def signal(self, new_value: int) -> None: ...
+    def wait(
+        self,
+        payload: int,
+        timeout: Optional[int] = None,
+        deadline: Optional[int] = None,
+    ) -> None: ...
 
 class Linkage(int):
     EXPORT: ClassVar[Linkage] = ...
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 4d59673..b19caf9 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -210,6 +210,52 @@
         sem1.signal(2)
         self.assertEqual(sem1.query(), 2)
 
+    def testSemaphoreSignal(self):
+        sem = self.device.create_semaphore(0)
+        self.assertFalse(sem.wait(1, deadline=0))
+        sem.signal(1)
+        self.assertTrue(sem.wait(1, deadline=0))
+
+    def testSynchronousSemaphoreFailed(self):
+        sem = self.device.create_semaphore(0)
+        sem.fail("TEST FAILURE")
+        with self.assertRaisesRegex(
+            RuntimeError, "^synchronous semaphore failure.*TEST FAILURE"
+        ):
+            sem.wait(1, deadline=0)
+
+    def testAsynchronousSemaphoreFailed(self):
+        sem = self.device.create_semaphore(0)
+        exceptions = []
+
+        def run():
+            print("SIGNALLING ASYNC FAILURE")
+            time.sleep(0.2)
+            sem.fail("TEST FAILURE")
+            print("SIGNALLED")
+
+        def wait():
+            print("WAITING")
+            try:
+                sem.wait(1)
+            except RuntimeError as e:
+                exceptions.append(e)
+
+        runner = threading.Thread(target=run)
+        waiter = threading.Thread(target=wait)
+        waiter.start()
+        runner.start()
+        waiter.join()
+        runner.join()
+        self.assertTrue(exceptions)
+        print(exceptions)
+        # Note: It is impossible to 100% guarantee that this sequences such as to
+        # report an asynchronous vs synchronous failure, although we tip the odds in
+        # this favor with the sleep in the signalling thread. Therefore, we do not
+        # check the "asynchronous" vs "synchronous" message prefix to avoid flaky
+        # test races.
+        self.assertIn("TEST FAILURE", str(exceptions[0]))
+
     def testTrivialQueueAlloc(self):
         sem = self.device.create_semaphore(0)
         buf = self.device.queue_alloca(