Implement Python bindings for async HAL objects. (#14476)
This is not a complete mapping but is being added demand-driven while
implementing the native PyTorch support.
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 1da899f..747f009 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -18,6 +18,53 @@
namespace {
+static const char kHalDeviceQueueAlloca[] =
+ R"(Reserves and returns a device-local queue-ordered transient buffer.
+
+Args:
+ allocation_size: The size in bytes of the allocation.
+ wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
+ a HalFence. The allocation will be made once these semaphores are
+ satisfied.
+ signal_semaphores: Semaphores/Fence to signal.
+
+Returns:
+ HalBuffer.
+)";
+
+static const char kHalDeviceQueueDealloca[] =
+ R"(Deallocates a queue-ordered transient buffer.
+
+Args:
+ wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
+ a HalFence. The allocation will be made once these semaphores are
+ satisfied.
+ signal_semaphores: Semaphores/Fence to signal.
+
+Returns:
+ HalBuffer.
+)";
+
+static const char kHalDeviceQueueExecute[] =
+ R"(Executes a sequence of command buffers.
+
+Args:
+ command_buffers: Sequence of command buffers to enqueue.
+ wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
+ a HalFence. The allocation will be made once these semaphores are
+ satisfied.
+ signal_semaphores: Semaphores/Fence to signal.
+)";
+
+static const char kHalFenceWait[] =
+ R"(Waits until the fence is signalled or errored.
+
+Three wait cases are supported:
+ * timeout: Relative nanoseconds to wait.
+ * deadline: Absolute nanoseconds to wait.
+ * Neither: Waits for infinite time.
+)";
+
// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
// out of scope.
class PyBufferReleaser {
@@ -138,6 +185,42 @@
py::rv_policy::move);
}
+HalBuffer HalAllocator::AllocateHostStagingBufferCopy(py::handle buffer) {
+ IREE_TRACE_SCOPE_NAMED("HalAllocator::AllocateHostStagingBufferCopy");
+ // Request a view of the buffer (use the raw python C API to avoid
+ // some allocation and copying at the pybind level).
+ Py_buffer py_view;
+ // Note that only C-Contiguous ND-arrays are presently supported, so
+ // only request that via PyBUF_ND. Long term, we should consult an
+ // "oracle" in the runtime to determine the precise required format
+ // and set flags accordingly (and fallback/copy on failure).
+ int flags = PyBUF_FORMAT | PyBUF_ND;
+
+ // Acquire the backing buffer and setup RAII release.
+ if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
+ // The GetBuffer call is required to set an appropriate error.
+ throw py::python_error();
+ }
+ PyBufferReleaser py_view_releaser(py_view);
+
+ iree_hal_buffer_params_t params = {0};
+  std::memset(&params, 0, sizeof(params));
+ params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+ params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+
+ iree_hal_buffer_t* hal_buffer = nullptr;
+ iree_status_t status = iree_ok_status();
+ {
+ py::gil_scoped_release release;
+ status = iree_hal_allocator_allocate_buffer(
+ raw_ptr(), params, py_view.len,
+ iree_make_const_byte_span(py_view.buf, py_view.len), &hal_buffer);
+ }
+ CheckApiStatus(status, "Failed to allocate device visible buffer");
+
+ return HalBuffer::StealFromRawPtr(hal_buffer);
+}
+
//------------------------------------------------------------------------------
// HalBuffer
//------------------------------------------------------------------------------
@@ -240,6 +323,192 @@
"ending device profiling");
}
+HalSemaphore HalDevice::CreateSemaphore(uint64_t initial_value) {
+ iree_hal_semaphore_t* out_sem;
+ CheckApiStatus(iree_hal_semaphore_create(raw_ptr(), initial_value, &out_sem),
+ "creating semaphore");
+ return HalSemaphore::StealFromRawPtr(out_sem);
+}
+
+HalBuffer HalDevice::QueueAlloca(uint64_t allocation_size,
+ py::handle wait_semaphores,
+ py::handle signal_semaphores) {
+ iree_hal_buffer_params_t params;
+  memset(&params, 0, sizeof(params));
+ // TODO: Accept explicit params in API.
+ params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+
+ iree_hal_semaphore_list_t wait_list;
+ iree_hal_semaphore_list_t signal_list;
+
+ // Wait list.
+ if (py::isinstance<HalFence>(wait_semaphores)) {
+ wait_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(wait_semaphores)->raw_ptr());
+ } else {
+ size_t wait_count = py::len(wait_semaphores);
+ wait_list = {
+ wait_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
+ };
+ for (size_t i = 0; i < wait_count; ++i) {
+ py::tuple pair = wait_semaphores[i];
+ wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // Signal list.
+ if (py::isinstance<HalFence>(signal_semaphores)) {
+ signal_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(signal_semaphores)->raw_ptr());
+ } else {
+ size_t signal_count = py::len(signal_semaphores);
+ signal_list = {
+ signal_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
+ };
+ for (size_t i = 0; i < signal_count; ++i) {
+ py::tuple pair = signal_semaphores[i];
+ signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ iree_hal_buffer_t* out_buffer;
+ // TODO: Accept params for queue affinity and pool.
+ CheckApiStatus(iree_hal_device_queue_alloca(
+ raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
+ signal_list, IREE_HAL_ALLOCATOR_POOL_DEFAULT, params,
+ allocation_size, &out_buffer),
+ "allocating memory on queue");
+ return HalBuffer::StealFromRawPtr(out_buffer);
+}
+
+void HalDevice::QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
+ py::handle signal_semaphores) {
+ iree_hal_semaphore_list_t wait_list;
+ iree_hal_semaphore_list_t signal_list;
+
+ // Wait list.
+ if (py::isinstance<HalFence>(wait_semaphores)) {
+ wait_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(wait_semaphores)->raw_ptr());
+ } else {
+ size_t wait_count = py::len(wait_semaphores);
+ wait_list = {
+ wait_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
+ };
+ for (size_t i = 0; i < wait_count; ++i) {
+ py::tuple pair = wait_semaphores[i];
+ wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // Signal list.
+ if (py::isinstance<HalFence>(signal_semaphores)) {
+ signal_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(signal_semaphores)->raw_ptr());
+ } else {
+ size_t signal_count = py::len(signal_semaphores);
+ signal_list = {
+ signal_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
+ };
+ for (size_t i = 0; i < signal_count; ++i) {
+ py::tuple pair = signal_semaphores[i];
+ signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ CheckApiStatus(
+ iree_hal_device_queue_dealloca(raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ wait_list, signal_list, buffer.raw_ptr()),
+ "deallocating memory on queue");
+}
+
+void HalDevice::QueueExecute(py::handle command_buffers,
+ py::handle wait_semaphores,
+ py::handle signal_semaphores) {
+ iree_hal_semaphore_list_t wait_list;
+ iree_hal_semaphore_list_t signal_list;
+
+ // Wait list.
+ if (py::isinstance<HalFence>(wait_semaphores)) {
+ wait_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(wait_semaphores)->raw_ptr());
+ } else {
+ size_t wait_count = py::len(wait_semaphores);
+ wait_list = {
+ wait_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
+ };
+ for (size_t i = 0; i < wait_count; ++i) {
+ py::tuple pair = wait_semaphores[i];
+ wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // Signal list.
+ if (py::isinstance<HalFence>(signal_semaphores)) {
+ signal_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(signal_semaphores)->raw_ptr());
+ } else {
+ size_t signal_count = py::len(signal_semaphores);
+ signal_list = {
+ signal_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
+ };
+ for (size_t i = 0; i < signal_count; ++i) {
+ py::tuple pair = signal_semaphores[i];
+ signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // Unpack command buffers.
+ size_t cb_count = py::len(command_buffers);
+ iree_hal_command_buffer_t** cb_list =
+ static_cast<iree_hal_command_buffer_t**>(
+ alloca(sizeof(iree_hal_command_buffer_t*) * cb_count));
+ for (size_t i = 0; i < cb_count; ++i) {
+ cb_list[i] = py::cast<HalCommandBuffer*>(command_buffers[i])->raw_ptr();
+ }
+
+ CheckApiStatus(
+ iree_hal_device_queue_execute(raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ wait_list, signal_list, cb_count, cb_list),
+ "executing command buffers");
+}
+
//------------------------------------------------------------------------------
// HalDriver
//------------------------------------------------------------------------------
@@ -561,7 +830,22 @@
py::keep_alive<0, 1>())
.def("begin_profiling", &HalDevice::BeginProfiling,
py::arg("mode") = py::none(), py::arg("file_path") = py::none())
- .def("end_profiling", &HalDevice::EndProfiling);
+ .def("end_profiling", &HalDevice::EndProfiling)
+ .def("create_semaphore", &HalDevice::CreateSemaphore,
+ py::arg("initial_value"))
+ .def("queue_alloca", &HalDevice::QueueAlloca, py::arg("allocation_size"),
+ py::arg("wait_semaphores"), py::arg("signal_semaphores"),
+ kHalDeviceQueueAlloca)
+ .def("queue_dealloca", &HalDevice::QueueDealloca, py::arg("buffer"),
+ py::arg("wait_semaphores"), py::arg("signal_semaphores"),
+ kHalDeviceQueueDealloca)
+ .def("queue_execute", &HalDevice::QueueExecute,
+ py::arg("command_buffers"), py::arg("wait_semaphores"),
+ py::arg("signal_semaphores"), kHalDeviceQueueExecute)
+ .def("__repr__", [](HalDevice& self) {
+ auto id_sv = iree_hal_device_id(self.raw_ptr());
+ return std::string(id_sv.data, id_sv.size);
+ });
py::class_<HalDriver>(m, "HalDriver")
.def_static("query", &HalDriver::Query)
@@ -647,20 +931,50 @@
"object. If an element type is specified, wraps in a BufferView "
"matching the characteristics of the Python buffer. The format is "
"requested as ND/C-Contiguous, which may incur copies if not "
- "already in that format.");
+ "already in that format.")
+ .def("allocate_host_staging_buffer_copy",
+ &HalAllocator::AllocateHostStagingBufferCopy,
+ py::arg("initial_contents"), py::keep_alive<0, 1>(),
+ "Allocates a new buffer and initializes it from a Python buffer "
+ "object. The buffer is configured as optimal for use on the device "
+           "as a transfer buffer. For buffers of unknown provenance, this is a "
+ "last resort method for making them compatible for transfer to "
+ "arbitrary devices.");
py::class_<HalBuffer>(m, "HalBuffer")
.def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
py::arg("byte_length"))
.def("create_view", &HalBuffer::CreateView, py::arg("shape"),
py::arg("element_size"), py::keep_alive<0, 1>())
+ .def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>())
.def("__repr__", &HalBuffer::Repr);
auto hal_buffer_view = py::class_<HalBufferView>(m, "HalBufferView");
VmRef::BindRefProtocol(hal_buffer_view, iree_hal_buffer_view_type,
iree_hal_buffer_view_retain_ref,
iree_hal_buffer_view_deref, iree_hal_buffer_view_isa);
- hal_buffer_view.def("map", HalMappedMemory::Create, py::keep_alive<0, 1>())
+ hal_buffer_view.def(
+ "__init__",
+ [](HalBufferView* new_self, HalBuffer& buffer, py::handle shape,
+ iree_hal_element_type_t element_type) {
+ size_t rank = py::len(shape);
+ iree_hal_dim_t* dims =
+ static_cast<iree_hal_dim_t*>(alloca(sizeof(iree_hal_dim_t) * rank));
+ for (size_t i = 0; i < rank; ++i) {
+ dims[i] = py::cast<iree_hal_dim_t>(shape[i]);
+ }
+ iree_hal_buffer_view_t* out_bv;
+ CheckApiStatus(iree_hal_buffer_view_create(
+ buffer.raw_ptr(), rank, dims, element_type,
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
+ iree_allocator_system(), &out_bv),
+ "creating buffer view");
+ new (new_self) HalBufferView();
+ *new_self = HalBufferView::StealFromRawPtr(out_bv);
+ },
+ py::arg("buffer"), py::arg("shape"), py::arg("element_type"));
+ hal_buffer_view
+ .def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>())
.def_prop_ro("shape",
[](HalBufferView& self) {
iree_host_size_t rank =
@@ -679,25 +993,229 @@
})
.def("__repr__", &HalBufferView::Repr);
+ py::class_<HalSemaphore>(m, "HalSemaphore")
+ .def("query",
+ [](HalSemaphore& self) {
+ uint64_t out_value;
+ CheckApiStatus(
+ iree_hal_semaphore_query(self.raw_ptr(), &out_value),
+ "querying semaphore");
+ return out_value;
+ })
+ .def("signal", [](HalSemaphore& self, uint64_t new_value) {
+ CheckApiStatus(iree_hal_semaphore_signal(self.raw_ptr(), new_value),
+ "signaling semaphore");
+ });
+
+ py::class_<HalFence>(m, "HalFence")
+ .def(
+ "__init__",
+ [](HalFence* new_fence, iree_host_size_t capacity) {
+ iree_hal_fence_t* out_fence;
+ CheckApiStatus(iree_hal_fence_create(
+ capacity, iree_allocator_system(), &out_fence),
+ "creating fence");
+ new (new_fence) HalFence();
+ (*new_fence) = HalFence::StealFromRawPtr(out_fence);
+ },
+ py::arg("capacity"))
+ .def_static(
+ "create_at",
+ [](HalSemaphore& sem, uint64_t value) {
+ iree_hal_fence_t* out_fence;
+ CheckApiStatus(
+ iree_hal_fence_create_at(sem.raw_ptr(), value,
+ iree_allocator_system(), &out_fence),
+ "creating fence");
+ return HalFence::StealFromRawPtr(out_fence);
+ },
+ py::arg("sem"), py::arg("value"))
+ .def_static(
+ "join",
+ [](py::sequence fences) {
+ size_t count = py::len(fences);
+ iree_hal_fence_t** fence_ptrs = static_cast<iree_hal_fence_t**>(
+ alloca(sizeof(iree_hal_fence_t*) * count));
+ for (size_t i = 0; i < count; ++i) {
+ fence_ptrs[i] = py::cast<HalFence*>(fences[i])->raw_ptr();
+ }
+ iree_hal_fence_t* out_fence;
+ CheckApiStatus(
+ iree_hal_fence_join(count, fence_ptrs, iree_allocator_system(),
+ &out_fence),
+ "joining fences");
+ return HalFence::StealFromRawPtr(out_fence);
+ },
+ py::arg("fences"))
+ .def_prop_ro("timepoint_count",
+ [](HalFence& self) {
+ return iree_hal_fence_timepoint_count(self.raw_ptr());
+ })
+ .def(
+ "insert",
+ [](HalFence& self, HalSemaphore& sem, uint64_t value) {
+ CheckApiStatus(
+ iree_hal_fence_insert(self.raw_ptr(), sem.raw_ptr(), value),
+                "inserting into fence");
+ },
+ py::arg("sem"), py::arg("value"))
+ .def(
+ "extend",
+ [](HalFence& self, HalFence& from_fence) {
+ CheckApiStatus(
+ iree_hal_fence_extend(self.raw_ptr(), from_fence.raw_ptr()),
+ "extending fence");
+ },
+ py::arg("from_fence"))
+ .def(
+ "wait",
+ [](HalFence& self, std::optional<iree_duration_t> timeout,
+ std::optional<iree_time_t> deadline) {
+ iree_timeout_t t;
+ if (!timeout && !deadline) {
+ t = iree_infinite_timeout();
+ } else if (timeout && deadline) {
+ throw std::invalid_argument(
+ "timeout and deadline cannot both be set");
+ } else if (timeout) {
+ t = iree_make_timeout_ns(*timeout);
+ } else {
+ t = iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline};
+ }
+ iree_status_t status;
+ {
+ py::gil_scoped_release release;
+ status = iree_hal_fence_wait(self.raw_ptr(), t);
+ }
+ CheckApiStatus(status, "waiting for fence");
+ },
+ py::arg("timeout") = py::none(), py::arg("deadline") = py::none(),
+ kHalFenceWait);
+
py::class_<HalMappedMemory>(m, "MappedMemory")
.def(
"asarray",
- [](HalMappedMemory* self, std::vector<iree_host_size_t> shape,
- py::object dtype) {
+ [](HalMappedMemory* self, py::handle shape, py::object dtype_descr) {
py::object py_mapped_memory = py::cast(self);
- static_assert(sizeof(shape[0]) == sizeof(intptr_t),
- "size_t not of same size as intptr_t");
- int typenum = numpy::TypenumFromDescr(dtype);
- return numpy::SimpleNewFromData(
- shape.size(), reinterpret_cast<intptr_t const*>(shape.data()),
- typenum, self->mapped_memory().contents.data, py_mapped_memory);
+ size_t rank = py::len(shape);
+ intptr_t* dims =
+ static_cast<intptr_t*>(alloca(sizeof(intptr_t) * rank));
+ for (size_t i = 0; i < rank; ++i) {
+ dims[i] = py::cast<intptr_t>(shape[i]);
+ }
+ int typenum = numpy::TypenumFromDescr(dtype_descr);
+ return numpy::SimpleNewFromData(rank, dims, typenum,
+ self->mapped_memory().contents.data,
+ py_mapped_memory);
},
- py::arg("shape"), py::arg("element_type"));
+ py::arg("shape"), py::arg("numpy_dtype_descr"));
py::class_<HalShape>(m, "Shape")
.def("__init__", [](HalShape* self, std::vector<iree_hal_dim_t> indices) {
new (self) HalShape(indices);
});
+
+ py::class_<HalCommandBuffer>(m, "HalCommandBuffer")
+ .def(
+ "__init__",
+ [](HalCommandBuffer* new_self, HalDevice& device,
+ iree_host_size_t binding_capacity, bool begin) {
+ iree_hal_command_buffer_t* out_cb;
+ CheckApiStatus(iree_hal_command_buffer_create(
+ device.raw_ptr(),
+ /*mode=*/IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ /*categories=*/IREE_HAL_COMMAND_CATEGORY_ANY,
+ /*queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
+ binding_capacity, &out_cb),
+ "creating command buffer");
+ HalCommandBuffer cb = HalCommandBuffer::StealFromRawPtr(out_cb);
+ if (begin) {
+ CheckApiStatus(iree_hal_command_buffer_begin(cb.raw_ptr()),
+ "command buffer begin");
+ }
+ new (new_self) HalCommandBuffer();
+ *new_self = std::move(cb);
+ },
+ py::arg("device"), py::arg("binding_capacity") = 0,
+ py::arg("begin") = true)
+ .def("begin",
+ [](HalCommandBuffer& self) {
+ CheckApiStatus(iree_hal_command_buffer_begin(self.raw_ptr()),
+ "command buffer begin");
+ })
+ .def("end",
+ [](HalCommandBuffer& self) {
+ CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),
+ "command buffer end");
+ })
+ .def(
+ "copy",
+ [](HalCommandBuffer& self, HalBuffer& source_buffer,
+ HalBuffer& target_buffer, iree_device_size_t source_offset,
+ iree_device_size_t target_offset,
+ std::optional<iree_device_size_t> length, bool end) {
+ iree_device_size_t resolved_length;
+ if (length) {
+ resolved_length = *length;
+ } else {
+ resolved_length =
+ iree_hal_buffer_byte_length(source_buffer.raw_ptr());
+ if (resolved_length !=
+ iree_hal_buffer_byte_length(target_buffer.raw_ptr())) {
+ throw std::invalid_argument(
+                    "If length is not provided, source and target buffer length "
+ "must match and it does not. Provide explicit length=");
+ }
+ }
+ CheckApiStatus(
+ iree_hal_command_buffer_copy_buffer(
+ self.raw_ptr(), source_buffer.raw_ptr(), source_offset,
+ target_buffer.raw_ptr(), target_offset, resolved_length),
+ "copy command");
+ if (end) {
+ CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),
+ "command buffer end");
+ }
+ },
+ py::arg("source_buffer"), py::arg("target_buffer"),
+ py::arg("source_offset") = 0, py::arg("target_offset") = 0,
+ py::arg("length") = py::none(), py::arg("end") = false,
+ "Copies a range from a source to target buffer. If the length is "
+ "not specified, then it is taken from the source/target buffer, "
+ "which must match.")
+ .def(
+ "fill",
+ [](HalCommandBuffer& self, HalBuffer& target_buffer,
+ py::handle pattern, iree_device_size_t target_offset,
+ std::optional<iree_device_size_t> length, bool end) {
+ Py_buffer pattern_view;
+ int flags = PyBUF_FORMAT | PyBUF_ND;
+ if (PyObject_GetBuffer(pattern.ptr(), &pattern_view, flags) != 0) {
+ // The GetBuffer call is required to set an appropriate error.
+ throw py::python_error();
+ }
+ PyBufferReleaser py_pattern_releaser(pattern_view);
+
+ iree_device_size_t resolved_length;
+ if (length) {
+ resolved_length = *length;
+ } else {
+ resolved_length =
+ iree_hal_buffer_byte_length(target_buffer.raw_ptr());
+ }
+ CheckApiStatus(
+ iree_hal_command_buffer_fill_buffer(
+ self.raw_ptr(), target_buffer.raw_ptr(), target_offset,
+ resolved_length, pattern_view.buf, pattern_view.len),
+ "command buffer fill");
+ if (end) {
+ CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),
+ "command buffer end");
+ }
+ },
+ py::arg("target_buffer"), py::arg("pattern"),
+ py::arg("target_offset") = 0, py::arg("length") = py::none(),
+ py::arg("end") = false);
}
} // namespace python
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index bae4d96..e4b681d 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -76,10 +76,41 @@
}
};
+template <>
+struct ApiPtrAdapter<iree_hal_semaphore_t> {
+ static void Retain(iree_hal_semaphore_t* sem) {
+ iree_hal_semaphore_retain(sem);
+ }
+ static void Release(iree_hal_semaphore_t* sem) {
+ iree_hal_semaphore_release(sem);
+ }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_fence_t> {
+ static void Retain(iree_hal_fence_t* fence) { iree_hal_fence_retain(fence); }
+ static void Release(iree_hal_fence_t* fence) {
+ iree_hal_fence_release(fence);
+ }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_command_buffer_t> {
+ static void Retain(iree_hal_command_buffer_t* cb) {
+ iree_hal_command_buffer_retain(cb);
+ }
+ static void Release(iree_hal_command_buffer_t* cb) {
+ iree_hal_command_buffer_release(cb);
+ }
+};
+
//------------------------------------------------------------------------------
// ApiRefCounted types
//------------------------------------------------------------------------------
+class HalBuffer;
+class HalSemaphore;
+
class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
public:
iree_hal_allocator_t* allocator() {
@@ -89,6 +120,13 @@
void BeginProfiling(std::optional<std::string> mode,
std::optional<std::string> file_path);
void EndProfiling();
+ HalSemaphore CreateSemaphore(uint64_t initial_value);
+ HalBuffer QueueAlloca(uint64_t allocation_size, py::handle wait_semaphores,
+ py::handle signal_semaphores);
+ void QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
+ py::handle signal_semaphores);
+ void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
+ py::handle signal_semaphores);
};
class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
@@ -113,6 +151,7 @@
py::object AllocateBufferCopy(
int memory_type, int allowed_usage, py::object buffer,
std::optional<iree_hal_element_types_t> element_type);
+ HalBuffer AllocateHostStagingBufferCopy(py::handle buffer);
};
struct HalShape {
@@ -160,28 +199,35 @@
py::str Repr();
};
-// Wrapper around an iree_hal_buffer_mapping_t and iree_hal_buffer_view_t
+class HalSemaphore : public ApiRefCounted<HalSemaphore, iree_hal_semaphore_t> {
+ public:
+};
+
+class HalFence : public ApiRefCounted<HalFence, iree_hal_fence_t> {
+ public:
+};
+
+// Wrapper around an iree_hal_buffer_mapping_t and iree_hal_buffer_t
// which retains the latter and unmaps/releases on deallocation.
class HalMappedMemory {
public:
HalMappedMemory(iree_hal_buffer_mapping_t mapped_memory,
- iree_hal_buffer_view_t* bv)
- : mapped_memory_(mapped_memory), bv_(bv) {
- iree_hal_buffer_view_retain(bv_);
+ iree_hal_buffer_t* buffer)
+ : mapped_memory_(mapped_memory), buffer_(buffer) {
+ iree_hal_buffer_retain(buffer_);
}
~HalMappedMemory() {
- if (bv_) {
+ if (buffer_) {
iree_hal_buffer_unmap_range(&mapped_memory_);
- iree_hal_buffer_view_release(bv_);
+ iree_hal_buffer_release(buffer_);
}
}
HalMappedMemory(HalMappedMemory&& other)
- : mapped_memory_(other.mapped_memory_), bv_(other.bv_) {
- other.bv_ = nullptr;
+ : mapped_memory_(other.mapped_memory_), buffer_(other.buffer_) {
+ other.buffer_ = nullptr;
}
- static HalMappedMemory Create(HalBufferView& bv) {
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
+ static HalMappedMemory Create(iree_hal_buffer_t* buffer) {
iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
iree_hal_buffer_mapping_t mapped_memory = {{0}};
CheckApiStatus(
@@ -189,16 +235,25 @@
IREE_HAL_MEMORY_ACCESS_READ, 0, byte_length,
&mapped_memory),
"Could not map memory");
- return HalMappedMemory(mapped_memory, bv.raw_ptr());
+ return HalMappedMemory(mapped_memory, buffer);
+ }
+ static HalMappedMemory CreateFromBuffer(HalBuffer& b) {
+ return Create(b.raw_ptr());
+ }
+ static HalMappedMemory CreateFromBufferView(HalBufferView& bv) {
+ return Create(iree_hal_buffer_view_buffer(bv.raw_ptr()));
}
iree_hal_buffer_mapping_t& mapped_memory() { return mapped_memory_; }
private:
iree_hal_buffer_mapping_t mapped_memory_ = {{0}};
- iree_hal_buffer_view_t* bv_ = nullptr;
+ iree_hal_buffer_t* buffer_ = nullptr;
};
+class HalCommandBuffer
+ : public ApiRefCounted<HalCommandBuffer, iree_hal_command_buffer_t> {};
+
void SetupHalBindings(nanobind::module_ m);
} // namespace python
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index cb0bdd6..7363925 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -20,9 +20,13 @@
HalAllocator,
HalBuffer,
HalBufferView,
+ HalCommandBuffer,
HalDevice,
HalDriver,
HalElementType,
+ HalFence,
+ HalSemaphore,
+ MappedMemory,
MemoryAccess,
MemoryType,
PyModuleInterface,
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 38eaf3d..8067ae0 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -128,6 +128,29 @@
)
print("BUFFER:", buffer)
+ def testBufferViewConstructor(self):
+ buffer = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=13,
+ )
+ bv = iree.runtime.HalBufferView(
+ buffer, (1, 2), iree.runtime.HalElementType.INT_16
+ )
+ self.assertEqual(
+ repr(bv),
+ "<HalBufferView (1, 2), element_type=0x10000010, 13 bytes (at offset 0 into 13), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING>",
+ )
+
+ def testBufferMap(self):
+ buffer = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=13,
+ )
+ m = buffer.map()
+ self.assertIsInstance(m, iree.runtime.MappedMemory)
+
def testAllocateBufferCopy(self):
ary = np.zeros([3, 4], dtype=np.int32) + 2
buffer = self.allocator.allocate_buffer_copy(
@@ -153,6 +176,156 @@
"<HalBufferView (3, 4), element_type=0x20000011, 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING>",
)
+ def testAllocateHostStagingBufferCopy(self):
+ buffer = self.allocator.allocate_host_staging_buffer_copy(np.int32(0))
+ self.assertEqual(
+ repr(buffer),
+ "<HalBuffer 4 bytes (at offset 0 into 4), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|MAPPING>",
+ )
+
+ def testSemaphore(self):
+ sem0 = self.device.create_semaphore(0)
+ self.assertEqual(sem0.query(), 0)
+ sem1 = self.device.create_semaphore(1)
+ self.assertEqual(sem1.query(), 1)
+ sem1.signal(2)
+ self.assertEqual(sem1.query(), 2)
+
+ def testTrivialQueueAlloc(self):
+ sem = self.device.create_semaphore(0)
+ buf = self.device.queue_alloca(
+ 1024, wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
+ )
+ self.assertIsInstance(buf, iree.runtime.HalBuffer)
+ self.device.queue_dealloca(
+ buf, wait_semaphores=[(sem, 1)], signal_semaphores=[]
+ )
+
+ def testAllocAcceptsFences(self):
+ # Also tests HalFence, HalFence.insert, HalFence.wait (infinite)
+ sem = self.device.create_semaphore(0)
+ fence0 = iree.runtime.HalFence(1)
+ fence0.insert(sem, 0)
+ fence1 = iree.runtime.HalFence(1)
+ fence1.insert(sem, 1)
+ fence2 = iree.runtime.HalFence(2)
+ fence2.insert(sem, 2)
+ buf = self.device.queue_alloca(
+ 1024, wait_semaphores=fence0, signal_semaphores=fence1
+ )
+ self.assertIsInstance(buf, iree.runtime.HalBuffer)
+ self.device.queue_dealloca(
+ buf, wait_semaphores=fence1, signal_semaphores=fence2
+ )
+ fence2.wait()
+ self.assertEqual(sem.query(), 2)
+
+ def testFenceCreateAt(self):
+ sem = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence.create_at(sem, 1)
+ with self.assertRaisesRegex(RuntimeError, "DEADLINE_EXCEEDED"):
+ fence.wait(deadline=0)
+ sem.signal(1)
+ fence.wait(deadline=0)
+
+ def testFenceJoin(self):
+ sem1 = self.device.create_semaphore(0)
+ sem2 = self.device.create_semaphore(0)
+ fence1 = iree.runtime.HalFence.create_at(sem1, 1)
+ fence2 = iree.runtime.HalFence.create_at(sem2, 1)
+ fence = iree.runtime.HalFence.join([fence1, fence2])
+ self.assertEqual(fence.timepoint_count, 2)
+
+ def testFenceInsert(self):
+ sem1 = self.device.create_semaphore(0)
+ sem2 = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence(2)
+ fence.insert(sem1, 1)
+ self.assertEqual(fence.timepoint_count, 1)
+ fence.insert(sem1, 2)
+ self.assertEqual(fence.timepoint_count, 1)
+ fence.insert(sem2, 2)
+ self.assertEqual(fence.timepoint_count, 2)
+
+ def testFenceExtend(self):
+ sem1 = self.device.create_semaphore(0)
+ sem2 = self.device.create_semaphore(0)
+ fence = iree.runtime.HalFence(2)
+ fence.insert(sem1, 1)
+ self.assertEqual(fence.timepoint_count, 1)
+ fence.extend(iree.runtime.HalFence.create_at(sem2, 2))
+ self.assertEqual(fence.timepoint_count, 2)
+
+ def testCommandBufferStartsByDefault(self):
+ cb = iree.runtime.HalCommandBuffer(self.device)
+ with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
+ cb.begin()
+ cb = iree.runtime.HalCommandBuffer(self.device, begin=False)
+ cb.begin()
+
+ def testCommandBufferCopy(self):
+ # Doesn't test much but that calls succeed.
+ cb = iree.runtime.HalCommandBuffer(self.device)
+ buffer1 = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=13,
+ )
+ buffer2 = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=13,
+ )
+ cb.copy(buffer1, buffer2, end=True)
+ with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
+ cb.end()
+
+ def testCommandBufferFill(self):
+ # Doesn't test much but that calls succeed.
+ cb = iree.runtime.HalCommandBuffer(self.device)
+ buffer1 = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=12,
+ )
+ cb.fill(buffer1, np.int32(1), 0, 12, end=True)
+ with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
+ cb.end()
+
+ def testCommandBufferExecute(self):
+ # Doesn't test much but that calls succeed.
+ cb = iree.runtime.HalCommandBuffer(self.device)
+ buffer1 = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=12,
+ )
+ cb.fill(buffer1, np.int32(1), 0, 12, end=True)
+
+ sem = self.device.create_semaphore(0)
+ self.device.queue_execute(
+ [cb], wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
+ )
+ iree.runtime.HalFence.create_at(sem, 1).wait()
+
+ def testCommandBufferExecuteAcceptsFence(self):
+ # Doesn't test much but that calls succeed.
+ cb = iree.runtime.HalCommandBuffer(self.device)
+ buffer1 = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=12,
+ )
+ cb.fill(buffer1, np.int32(1), 0, 12, end=True)
+
+ sem = self.device.create_semaphore(0)
+ self.device.queue_execute(
+ [cb],
+ wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
+ signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
+ )
+ iree.runtime.HalFence.create_at(sem, 1).wait()
+
if __name__ == "__main__":
unittest.main()