Add a number of runtime python bindings and refine the HalFence.wait() behavior. (#16371)

* Adds static methods to `HalElementType`: `is_byte_aligned`,
`dense_byte_count`
* Binds the ref protocol for `HalBuffer` so it can be put in lists.
* Adds `HalBufferView.byte_length`
* Adds `HalFence.signal` and `HalFence.fail`
* Changes behavior of `HalFence.wait` to return False if not satisfied
due to timeout and True if satisfied. Raises a Python exception on
either synchronous or asynchronous failure. The prior behavior raised
exceptions for timeout and reported success on asynchronous failure. The
latter is a bug, and the former is just a bad practice (i.e. don't raise
exceptions for normal behavior).
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 2460b36..84c2d52 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -76,6 +76,9 @@
   * timeout: Relative nanoseconds to wait.
   * deadine: Absolute nanoseconds to wait.
   * Neither: Waits for infinite time.
+
+Returns whether the wait succeeded (True) or timed out (False). If the fence was
+asynchronously failed, an exception is raised.
 )";
 
 // RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
@@ -912,9 +915,18 @@
       .value("COMPLEX_64", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64)
       .value("COMPLEX_128", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128)
       .export_values()
-      .def_static("map_to_dtype", [](iree_hal_element_type_t element_type) {
-        int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum(element_type);
-        return numpy::DescrNewFromType(typenum);
+      .def_static("map_to_dtype",
+                  [](iree_hal_element_type_t element_type) {
+                    int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum(
+                        element_type);
+                    return numpy::DescrNewFromType(typenum);
+                  })
+      .def_static("is_byte_aligned",
+                  [](iree_hal_element_type_t element_type) {
+                    return iree_hal_element_is_byte_aligned(element_type);
+                  })
+      .def_static("dense_byte_count", [](iree_hal_element_type_t element_type) {
+        return iree_hal_element_dense_byte_count(element_type);
       });
 
   py::class_<HalDevice>(m, "HalDevice")
@@ -1041,7 +1053,11 @@
            "last resort method for making them compatible for transfer to "
            "arbitrary devices.");
 
-  py::class_<HalBuffer>(m, "HalBuffer")
+  auto hal_buffer = py::class_<HalBuffer>(m, "HalBuffer");
+  VmRef::BindRefProtocol(hal_buffer, iree_hal_buffer_type,
+                         iree_hal_buffer_retain_ref, iree_hal_buffer_deref,
+                         iree_hal_buffer_isa);
+  hal_buffer
       .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
            py::arg("byte_length"))
       .def("byte_length", &HalBuffer::byte_length)
@@ -1096,6 +1112,10 @@
                    [](HalBufferView& self) {
                      return iree_hal_buffer_view_element_type(self.raw_ptr());
                    })
+      .def_prop_ro("byte_length",
+                   [](HalBufferView& self) {
+                     return iree_hal_buffer_view_byte_length(self.raw_ptr());
+                   })
       .def("__repr__", &HalBufferView::Repr);
 
   py::class_<HalSemaphore>(m, "HalSemaphore")
@@ -1177,9 +1197,23 @@
           },
           py::arg("from_fence"))
       .def(
+          "fail",
+          [](HalFence& self, std::string& message) {
+            // TODO: Take some category enum and use that is available.
+            iree_status_t status =
+                iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str());
+            iree_hal_fence_fail(self.raw_ptr(), status);
+          },
+          py::arg("message"))
+      .def("signal",
+           [](HalFence& self) {
+             CheckApiStatus(iree_hal_fence_signal(self.raw_ptr()),
+                            "signalling fence");
+           })
+      .def(
           "wait",
           [](HalFence& self, std::optional<iree_duration_t> timeout,
-             std::optional<iree_time_t> deadline) {
+             std::optional<iree_time_t> deadline) -> bool {
             iree_timeout_t t;
             if (!timeout && !deadline) {
               t = iree_infinite_timeout();
@@ -1196,7 +1230,31 @@
               py::gil_scoped_release release;
               status = iree_hal_fence_wait(self.raw_ptr(), t);
             }
-            CheckApiStatus(status, "waiting for fence");
+            if (iree_status_is_deadline_exceeded(status)) {
+              // Time out.
+              return false;
+            } else if (iree_status_is_aborted(status)) {
+              // Synchronous failure.
+              iree_status_ignore(status);
+              status = iree_hal_fence_query(self.raw_ptr());
+              if (iree_status_is_ok(status)) {
+                status = iree_make_status(
+                    IREE_STATUS_FAILED_PRECONDITION,
+                    "expected synchronous status failure missing");
+              }
+              CheckApiStatus(status, "synchronous fence failure");
+            } else {
+              // General failure check.
+              CheckApiStatus(status, "waiting for fence");
+            }
+
+            // Asynchronous failure.
+            status = iree_hal_fence_query(self.raw_ptr());
+            if (iree_status_is_deferred(status)) {
+              return false;
+            }
+            CheckApiStatus(status, "asynchronous fence failure");
+            return true;
           },
           py::arg("timeout") = py::none(), py::arg("deadline") = py::none(),
           kHalFenceWait);
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index ea8794a..b95b773 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -51,6 +51,7 @@
     VmInstance,
     VmContext,
     VmModule,
+    VmRef,
 )
 
 from .array_interop import *
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 2ed5282..db21fb7 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -113,6 +113,8 @@
     @property
     def shape(self) -> list[int]: ...
     @property
+    def byte_length(self) -> int: ...
+    @property
     def __iree_vm_ref__(self) -> VmRef: ...
 
 class HalCommandBuffer:
@@ -225,6 +227,10 @@
     UINT_8: ClassVar[HalElementType] = ...
     @staticmethod
     def map_to_dtype(element_type: HalElementType) -> Any: ...
+    @staticmethod
+    def is_byte_aligned(element_type: HalElementType) -> bool: ...
+    @staticmethod
+    def dense_byte_count(element_type: HalElementType) -> int: ...
     __name__: Any
 
 class HalFence:
@@ -234,7 +240,9 @@
     def join(fences: Sequence[HalFence]) -> HalFence: ...
     def __init__(self, capacity: int) -> None: ...
     def extend(self, from_fence: HalFence) -> None: ...
+    def fail(self, message: str) -> None: ...
     def insert(self, sem: HalSemaphore, value: int) -> None: ...
+    def signal(self) -> None: ...
     def wait(
         self, timeout: Optional[int] = None, deadline: Optional[int] = None
     ) -> None: ...
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index b38da47..4d59673 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -8,11 +8,13 @@
 
 import gc
 import numpy as np
+import threading
+import time
 import unittest
 
 
 class NonDeviceHalTest(unittest.TestCase):
-    def testEnums(self):
+    def testMemoryEnums(self):
         print("MemoryType:", iree.runtime.MemoryType)
         print("HOST_VISIBLE:", int(iree.runtime.MemoryType.HOST_VISIBLE))
 
@@ -61,6 +63,13 @@
             int(iree.runtime.MemoryType.OPTIMAL),
         )
 
+    def testElementTypeEnums(self):
+        i8 = iree.runtime.HalElementType.INT_8
+        i4 = iree.runtime.HalElementType.INT_4
+        self.assertTrue(iree.runtime.HalElementType.is_byte_aligned(i8))
+        self.assertFalse(iree.runtime.HalElementType.is_byte_aligned(i4))
+        self.assertEqual(1, iree.runtime.HalElementType.dense_byte_count(i8))
+
 
 class DeviceHalTest(unittest.TestCase):
     def setUp(self):
@@ -143,6 +152,7 @@
             repr(bv),
             "<HalBufferView (1, 2), element_type=0x10000010, 13 bytes (at offset 0 into 13), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT>",
         )
+        self.assertEqual(4, bv.byte_length)
 
     def testBufferMap(self):
         buffer = self.allocator.allocate_buffer(
@@ -226,16 +236,64 @@
         self.device.queue_dealloca(
             buf, wait_semaphores=fence1, signal_semaphores=fence2
         )
-        fence2.wait()
+        self.assertTrue(fence2.wait())
         self.assertEqual(sem.query(), 2)
 
     def testFenceCreateAt(self):
         sem = self.device.create_semaphore(0)
         fence = iree.runtime.HalFence.create_at(sem, 1)
-        with self.assertRaisesRegex(RuntimeError, "DEADLINE_EXCEEDED"):
-            fence.wait(deadline=0)
+        self.assertFalse(fence.wait(deadline=0))
         sem.signal(1)
-        fence.wait(deadline=0)
+        self.assertTrue(fence.wait(deadline=0))
+
+    def testFenceSignal(self):
+        sem = self.device.create_semaphore(0)
+        fence = iree.runtime.HalFence.create_at(sem, 1)
+        self.assertFalse(fence.wait(deadline=0))
+        fence.signal()
+        self.assertTrue(fence.wait(deadline=0))
+
+    def testSynchronousFenceFailed(self):
+        sem = self.device.create_semaphore(0)
+        fence = iree.runtime.HalFence.create_at(sem, 1)
+        fence.fail("TEST FAILURE")
+        with self.assertRaisesRegex(
+            RuntimeError, "^synchronous fence failure.*TEST FAILURE"
+        ):
+            fence.wait(deadline=0)
+
+    def testAsynchronousFenceFailed(self):
+        sem = self.device.create_semaphore(0)
+        fence = iree.runtime.HalFence.create_at(sem, 1)
+        exceptions = []
+
+        def run():
+            print("SIGNALLING ASYNC FAILURE")
+            time.sleep(0.2)
+            fence.fail("TEST FAILURE")
+            print("SIGNALLED")
+
+        def wait():
+            print("WAITING")
+            try:
+                fence.wait()
+            except RuntimeError as e:
+                exceptions.append(e)
+
+        runner = threading.Thread(target=run)
+        waiter = threading.Thread(target=wait)
+        waiter.start()
+        runner.start()
+        waiter.join()
+        runner.join()
+        self.assertTrue(exceptions)
+        print(exceptions)
+        # Note: It is impossible to 100% guarantee that this sequences such as to
+        # report an asynchronous vs synchronous failure, although we tip the odds in
+        # this favor with the sleep in the signalling thread. Therefore, we do not
+        # check the "asynchronous" vs "synchronous" message prefix to avoid flaky
+        # test races.
+        self.assertIn("TEST FAILURE", str(exceptions[0]))
 
     def testFenceJoin(self):
         sem1 = self.device.create_semaphore(0)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 2bce419..f3bd0f8 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -45,7 +45,7 @@
         l.push_int(10 * 1000 * 1000 * 1000)
         self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
 
-    def test_variant_list_buffers(self):
+    def test_variant_list_buffer_view(self):
         device = rt.get_device("local-sync")
         ET = rt.HalElementType
         for dt, et in (
@@ -82,6 +82,17 @@
             with self.assertRaises(IndexError):
                 lst.get_as_object(1, rt.HalBufferView)
 
+    def test_variant_list_buffer(self):
+        device = rt.get_device("local-sync")
+        lst = rt.VmVariantList(5)
+        buffer = device.allocator.allocate_buffer(
+            memory_type=rt.MemoryType.DEVICE_LOCAL,
+            allowed_usage=rt.BufferUsage.DEFAULT,
+            allocation_size=1024,
+        )
+        lst.push_ref(buffer)
+        _ = (lst.get_as_object(0, rt.HalBuffer),)
+
     def test_variant_list_zero_rank_tensor_to_str(self):
         device = rt.get_device("local-sync")
         lst = rt.VmVariantList(1)