[Python] Add Fence VMRef Binding to enable async-exec on py. (#15263)
To use `iree-execution-model=async-*`, we'd need to feed in `HalFence`
as inputs to `Vm.Invoke`. Which means we'd need to be able to push_ref
`HalFence` into `VmVariantList`, this PR enables that support.
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 8ca1078..e08d5a3 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -1109,7 +1109,11 @@
"signaling semaphore");
});
- py::class_<HalFence>(m, "HalFence")
+ auto hal_fence = py::class_<HalFence>(m, "HalFence");
+ VmRef::BindRefProtocol(hal_fence, iree_hal_fence_type,
+ iree_hal_fence_retain_ref, iree_hal_fence_deref,
+ iree_hal_fence_isa);
+ hal_fence
.def(
"__init__",
[](HalFence* new_fence, iree_host_size_t capacity) {
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
index 71ecbbf..2bce419 100644
--- a/runtime/bindings/python/tests/vm_types_test.py
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -96,6 +96,12 @@
lst.push_ref(buffer_view)
self.assertEqual(str(lst), "<VmVariantList(1): [HalBufferView(:0x20000011)]>")
+ def test_variant_list_fence_to_str(self):
+ lst = rt.VmVariantList(1)
+ fence = rt.HalFence(2)
+ lst.push_ref(fence)
+ self.assertEqual(str(lst), "<VmVariantList(1): [fence(0)]>")
+
def test_variant_list_list(self):
lst1 = rt.VmVariantList(5)
lst2 = rt.VmVariantList(5)
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 18c8e8d..a5c6859 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -742,6 +742,12 @@
out.append("...circular...");
}
out.append("]");
+ } else if (iree_hal_fence_isa(variant.ref)) {
+ out.append("fence(");
+ auto* hal_fence = iree_hal_fence_deref(variant.ref);
+ iree_host_size_t timepoint_count =
+ iree_hal_fence_timepoint_count(hal_fence);
+ out.append(std::to_string(timepoint_count) + ")");
} else {
out += "Unknown(" +
std::to_string(iree_vm_type_def_as_ref(variant.type)) + ")";