Add a number of runtime python bindings and refine the HalFence.wait() behavior. (#16371)
* Adds static methods to `HalElementType`: `is_byte_aligned`,
`dense_byte_count`
* Binds the ref protocol for `HalBuffer` so it can be put in lists.
* Adds `HalBufferView.byte_length`
* Adds `HalFence.signal` and `HalFence.fail`
* Changes behavior of `HalFence.wait` to return False if not satisfied
due to timeout and True if satisfied. Raises a Python exception on
either synchronous or asynchronous failure. The prior behavior raised
exceptions for timeout and reported success on asynchronous failure. The
latter is a bug, and the former is just a bad practice (i.e. don't raise
exceptions for normal behavior).
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 2460b36..84c2d52 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -76,6 +76,9 @@
* timeout: Relative nanoseconds to wait.
* deadine: Absolute nanoseconds to wait.
* Neither: Waits for infinite time.
+
+Returns whether the wait succeeded (True) or timed out (False). If the fence was
+asynchronously failed, an exception is raised.
)";
// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
@@ -912,9 +915,18 @@
.value("COMPLEX_64", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64)
.value("COMPLEX_128", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128)
.export_values()
- .def_static("map_to_dtype", [](iree_hal_element_type_t element_type) {
- int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum(element_type);
- return numpy::DescrNewFromType(typenum);
+ .def_static("map_to_dtype",
+ [](iree_hal_element_type_t element_type) {
+ int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum(
+ element_type);
+ return numpy::DescrNewFromType(typenum);
+ })
+ .def_static("is_byte_aligned",
+ [](iree_hal_element_type_t element_type) {
+ return iree_hal_element_is_byte_aligned(element_type);
+ })
+ .def_static("dense_byte_count", [](iree_hal_element_type_t element_type) {
+ return iree_hal_element_dense_byte_count(element_type);
});
py::class_<HalDevice>(m, "HalDevice")
@@ -1041,7 +1053,11 @@
"last resort method for making them compatible for transfer to "
"arbitrary devices.");
- py::class_<HalBuffer>(m, "HalBuffer")
+ auto hal_buffer = py::class_<HalBuffer>(m, "HalBuffer");
+ VmRef::BindRefProtocol(hal_buffer, iree_hal_buffer_type,
+ iree_hal_buffer_retain_ref, iree_hal_buffer_deref,
+ iree_hal_buffer_isa);
+ hal_buffer
.def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
py::arg("byte_length"))
.def("byte_length", &HalBuffer::byte_length)
@@ -1096,6 +1112,10 @@
[](HalBufferView& self) {
return iree_hal_buffer_view_element_type(self.raw_ptr());
})
+ .def_prop_ro("byte_length",
+ [](HalBufferView& self) {
+ return iree_hal_buffer_view_byte_length(self.raw_ptr());
+ })
.def("__repr__", &HalBufferView::Repr);
py::class_<HalSemaphore>(m, "HalSemaphore")
@@ -1177,9 +1197,23 @@
},
py::arg("from_fence"))
.def(
+ "fail",
+ [](HalFence& self, std::string& message) {
+ // TODO: Take some category enum and use that if available.
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str());
+ iree_hal_fence_fail(self.raw_ptr(), status);
+ },
+ py::arg("message"))
+ .def("signal",
+ [](HalFence& self) {
+ CheckApiStatus(iree_hal_fence_signal(self.raw_ptr()),
+ "signalling fence");
+ })
+ .def(
"wait",
[](HalFence& self, std::optional<iree_duration_t> timeout,
- std::optional<iree_time_t> deadline) {
+ std::optional<iree_time_t> deadline) -> bool {
iree_timeout_t t;
if (!timeout && !deadline) {
t = iree_infinite_timeout();
@@ -1196,7 +1230,31 @@
py::gil_scoped_release release;
status = iree_hal_fence_wait(self.raw_ptr(), t);
}
- CheckApiStatus(status, "waiting for fence");
+ if (iree_status_is_deadline_exceeded(status)) {
+ // Time out.
+ return false;
+ } else if (iree_status_is_aborted(status)) {
+ // Synchronous failure.
+ iree_status_ignore(status);
+ status = iree_hal_fence_query(self.raw_ptr());
+ if (iree_status_is_ok(status)) {
+ status = iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "expected synchronous status failure missing");
+ }
+ CheckApiStatus(status, "synchronous fence failure");
+ } else {
+ // General failure check.
+ CheckApiStatus(status, "waiting for fence");
+ }
+
+ // Asynchronous failure.
+ status = iree_hal_fence_query(self.raw_ptr());
+ if (iree_status_is_deferred(status)) {
+ return false;
+ }
+ CheckApiStatus(status, "asynchronous fence failure");
+ return true;
},
py::arg("timeout") = py::none(), py::arg("deadline") = py::none(),
kHalFenceWait);
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index ea8794a..b95b773 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -51,6 +51,7 @@
VmInstance,
VmContext,
VmModule,
+ VmRef,
)
from .array_interop import *
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 2ed5282..db21fb7 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -113,6 +113,8 @@
@property
def shape(self) -> list[int]: ...
@property
+ def byte_length(self) -> int: ...
+ @property
def __iree_vm_ref__(self) -> VmRef: ...
class HalCommandBuffer:
@@ -225,6 +227,10 @@
UINT_8: ClassVar[HalElementType] = ...
@staticmethod
def map_to_dtype(element_type: HalElementType) -> Any: ...
+ @staticmethod
+ def is_byte_aligned(element_type: HalElementType) -> bool: ...
+ @staticmethod
+ def dense_byte_count(element_type: HalElementType) -> int: ...
__name__: Any
class HalFence:
@@ -234,7 +240,9 @@
def join(fences: Sequence[HalFence]) -> HalFence: ...
def __init__(self, capacity: int) -> None: ...
def extend(self, from_fence: HalFence) -> None: ...
+ def fail(self, message: str) -> None: ...
def insert(self, sem: HalSemaphore, value: int) -> None: ...
+ def signal(self) -> None: ...
def wait(
self, timeout: Optional[int] = None, deadline: Optional[int] = None
- ) -> None: ...
+ ) -> bool: ...
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index b38da47..4d59673 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -8,11 +8,13 @@
import gc
import numpy as np
+import threading
+import time
import unittest
class NonDeviceHalTest(unittest.TestCase):
- def testEnums(self):
+ def testMemoryEnums(self):
print("MemoryType:", iree.runtime.MemoryType)
print("HOST_VISIBLE:", int(iree.runtime.MemoryType.HOST_VISIBLE))
@@ -61,6 +63,13 @@
int(iree.runtime.MemoryType.OPTIMAL),
)
+ def testElementTypeEnums(self):
+ i8 = iree.runtime.HalElementType.INT_8
+ i4 = iree.runtime.HalElementType.INT_4
+ self.assertTrue(iree.runtime.HalElementType.is_byte_aligned(i8))
+ self.assertFalse(iree.runtime.HalElementType.is_byte_aligned(i4))
+ self.assertEqual(1, iree.runtime.HalElementType.dense_byte_count(i8))
+
class DeviceHalTest(unittest.TestCase):
def setUp(self):
@@ -143,6 +152,7 @@
repr(bv),
"<HalBufferView (1, 2), element_type=0x10000010, 13 bytes (at offset 0 into 13), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT>",
)
+ self.assertEqual(4, bv.byte_length)
def testBufferMap(self):
buffer = self.allocator.allocate_buffer(
@@ -226,16 +236,64 @@
self.device.queue_dealloca(
buf, wait_semaphores=fence1, signal_semaphores=fence2
)
- fence2.wait()
+ self.assertTrue(fence2.wait())
self.assertEqual(sem.query(), 2)
def testFenceCreateAt(self):
sem = self.device.create_semaphore(0)
fence = iree.runtime.HalFence.create_at(sem, 1)
- with self.assertRaisesRegex(RuntimeError, "DEADLINE_EXCEEDED"):
- fence.wait(deadline=0)
+ self.assertFalse(fence.wait(deadline=0))
sem.signal(1)
- fence.wait(deadline=0)
+ self.assertTrue(fence.wait(deadline=0))
+
+ def testFenceSignal(self):
+ sem = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence.create_at(sem, 1)
+ self.assertFalse(fence.wait(deadline=0))
+ fence.signal()
+ self.assertTrue(fence.wait(deadline=0))
+
+ def testSynchronousFenceFailed(self):
+ sem = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence.create_at(sem, 1)
+ fence.fail("TEST FAILURE")
+ with self.assertRaisesRegex(
+ RuntimeError, "^synchronous fence failure.*TEST FAILURE"
+ ):
+ fence.wait(deadline=0)
+
+ def testAsynchronousFenceFailed(self):
+ sem = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence.create_at(sem, 1)
+ exceptions = []
+
+ def run():
+ print("SIGNALLING ASYNC FAILURE")
+ time.sleep(0.2)
+ fence.fail("TEST FAILURE")
+ print("SIGNALLED")
+
+ def wait():
+ print("WAITING")
+ try:
+ fence.wait()
+ except RuntimeError as e:
+ exceptions.append(e)
+
+ runner = threading.Thread(target=run)
+ waiter = threading.Thread(target=wait)
+ waiter.start()
+ runner.start()
+ waiter.join()
+ runner.join()
+ self.assertTrue(exceptions)
+ print(exceptions)
+ # Note: It is impossible to 100% guarantee that the threads are sequenced so
+ # that an asynchronous (rather than synchronous) failure is reported,
+ # although we tip the odds in that favor with the sleep in the signalling
+ # thread. Therefore, we do not check the "asynchronous" vs "synchronous"
+ # message prefix to avoid flaky test races.
+ self.assertIn("TEST FAILURE", str(exceptions[0]))
def testFenceJoin(self):
sem1 = self.device.create_semaphore(0)
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 2bce419..f3bd0f8 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -45,7 +45,7 @@
l.push_int(10 * 1000 * 1000 * 1000)
self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
- def test_variant_list_buffers(self):
+ def test_variant_list_buffer_view(self):
device = rt.get_device("local-sync")
ET = rt.HalElementType
for dt, et in (
@@ -82,6 +82,17 @@
with self.assertRaises(IndexError):
lst.get_as_object(1, rt.HalBufferView)
+ def test_variant_list_buffer(self):
+ device = rt.get_device("local-sync")
+ lst = rt.VmVariantList(5)
+ buffer = device.allocator.allocate_buffer(
+ memory_type=rt.MemoryType.DEVICE_LOCAL,
+ allowed_usage=rt.BufferUsage.DEFAULT,
+ allocation_size=1024,
+ )
+ lst.push_ref(buffer)
+ _ = (lst.get_as_object(0, rt.HalBuffer),)
+
def test_variant_list_zero_rank_tensor_to_str(self):
device = rt.get_device("local-sync")
lst = rt.VmVariantList(1)