[Bindings] Implement alloc + copy to local host when map is unavailable. (#14997)
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index fba5a73..8ca1078 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -57,6 +57,18 @@
signal_semaphores: Semaphores/Fence to signal.
)";
+static const char kHalDeviceQueueCopy[] =
+ R"(Copy data from a source buffer to destination buffer.
+
+Args:
+ source_buffer: `HalBuffer` that holds src data.
+ target_buffer: `HalBuffer` that will receive data.
+ wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
+ a HalFence. The allocation will be made once these semaphores are
+ satisfied.
+ signal_semaphores: Semaphores/Fence to signal.
+)";
+
static const char kHalFenceWait[] =
R"(Waits until the fence is signalled or errored.
@@ -524,6 +536,69 @@
"executing command buffers");
}
+void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
+ py::handle wait_semaphores,
+ py::handle signal_semaphores) {
+ iree_hal_semaphore_list_t wait_list;
+ iree_hal_semaphore_list_t signal_list;
+
+ // Wait list.
+ if (py::isinstance<HalFence>(wait_semaphores)) {
+ wait_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(wait_semaphores)->raw_ptr());
+ } else {
+ size_t wait_count = py::len(wait_semaphores);
+ wait_list = {
+ wait_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
+ };
+ for (size_t i = 0; i < wait_count; ++i) {
+ py::tuple pair = wait_semaphores[i];
+ wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // Signal list.
+ if (py::isinstance<HalFence>(signal_semaphores)) {
+ signal_list = iree_hal_fence_semaphore_list(
+ py::cast<HalFence*>(signal_semaphores)->raw_ptr());
+ } else {
+ size_t signal_count = py::len(signal_semaphores);
+ signal_list = {
+ signal_count,
+ /*semaphores=*/
+ static_cast<iree_hal_semaphore_t**>(
+ alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
+ /*payload_values=*/
+ static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
+ };
+ for (size_t i = 0; i < signal_count; ++i) {
+ py::tuple pair = signal_semaphores[i];
+ signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
+ signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
+ }
+ }
+
+ // TODO: Accept params for src_offset and target_offset.
+ iree_device_size_t source_length =
+ iree_hal_buffer_byte_length(source_buffer.raw_ptr());
+ if (source_length != iree_hal_buffer_byte_length(target_buffer.raw_ptr())) {
+ throw std::invalid_argument(
+ "Source and target buffer length must match and it does not. Please "
+ "check allocations");
+ }
+ CheckApiStatus(iree_hal_device_queue_copy(
+ raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
+ signal_list, source_buffer.raw_ptr(), 0,
+ target_buffer.raw_ptr(), 0, source_length),
+ "Copying buffer on queue");
+}
+
//------------------------------------------------------------------------------
// HalDriver
//------------------------------------------------------------------------------
@@ -861,6 +936,9 @@
.def("queue_execute", &HalDevice::QueueExecute,
py::arg("command_buffers"), py::arg("wait_semaphores"),
py::arg("signal_semaphores"), kHalDeviceQueueExecute)
+ .def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"),
+ py::arg("target_buffer"), py::arg("wait_semaphores"),
+ py::arg("signal_semaphores"), kHalDeviceQueueCopy)
.def("__repr__", [](HalDevice& self) {
auto id_sv = iree_hal_device_id(self.raw_ptr());
return std::string(id_sv.data, id_sv.size);
@@ -963,6 +1041,9 @@
py::class_<HalBuffer>(m, "HalBuffer")
.def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
py::arg("byte_length"))
+ .def("byte_length", &HalBuffer::byte_length)
+ .def("memory_type", &HalBuffer::memory_type)
+ .def("allowed_usage", &HalBuffer::allowed_usage)
.def("create_view", &HalBuffer::CreateView, py::arg("shape"),
py::arg("element_size"), py::keep_alive<0, 1>())
.def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>())
@@ -994,6 +1075,8 @@
py::arg("buffer"), py::arg("shape"), py::arg("element_type"));
hal_buffer_view
.def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>())
+ .def("get_buffer", HalBuffer::CreateFromBufferView,
+ py::keep_alive<0, 1>())
.def_prop_ro("shape",
[](HalBufferView& self) {
iree_host_size_t rank =
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 0c3bc63..5e18cfa 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -128,6 +128,8 @@
py::handle signal_semaphores);
void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
py::handle signal_semaphores);
+ void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer,
+ py::handle wait_semaphores, py::handle signal_semaphores);
};
class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
@@ -176,6 +178,10 @@
return iree_hal_buffer_byte_length(raw_ptr());
}
+ int memory_type() const { return iree_hal_buffer_memory_type(raw_ptr()); }
+
+ int allowed_usage() const { return iree_hal_buffer_allowed_usage(raw_ptr()); }
+
void FillZero(iree_device_size_t byte_offset,
iree_device_size_t byte_length) {
CheckApiStatus(
@@ -197,6 +203,11 @@
return HalBufferView::StealFromRawPtr(bv);
}
+ static HalBuffer CreateFromBufferView(HalBufferView& bv) {
+ return HalBuffer::BorrowFromRawPtr(
+ iree_hal_buffer_view_buffer(bv.raw_ptr()));
+ }
+
py::str Repr();
};
diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py
index 096fc9b..fb67b21 100644
--- a/runtime/bindings/python/iree/runtime/array_interop.py
+++ b/runtime/bindings/python/iree/runtime/array_interop.py
@@ -17,6 +17,7 @@
HalElementType,
MappedMemory,
MemoryType,
+ HalFence,
)
__all__ = [
@@ -106,6 +107,20 @@
self._transfer_to_host(False)
return self._host_array
+ def _is_mappable(self) -> bool:
+ buffer = self._buffer_view.get_buffer()
+ if (
+ buffer.memory_type() & int(MemoryType.HOST_VISIBLE)
+ != MemoryType.HOST_VISIBLE
+ ):
+ return False
+ if (
+ buffer.allowed_usage() & int(BufferUsage.MAPPING_SCOPED)
+ != BufferUsage.MAPPING_SCOPED
+ ):
+ return False
+ return True
+
def _transfer_to_host(self, implicit):
if self._host_array is not None:
return
@@ -114,7 +129,10 @@
"DeviceArray cannot be implicitly transferred to the host: "
"if necessary, do an explicit transfer via .to_host()"
)
- self._mapped_memory, self._host_array = self._map_to_host()
+ if self._is_mappable():
+ self._mapped_memory, self._host_array = self._map_to_host()
+ else:
+ self._host_array = self._copy_to_host()
def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
# TODO: When synchronization is enabled, need to block here.
@@ -129,6 +147,35 @@
host_array = host_array.astype(self._override_dtype)
return mapped_memory, host_array
+ def _copy_to_host(self) -> np.ndarray:
+ # TODO: When synchronization is enabled, need to block here.
+ source_buffer = self._buffer_view.get_buffer()
+ host_buffer = self._device.allocator.allocate_buffer(
+ memory_type=(MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE),
+ allowed_usage=(BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED),
+ allocation_size=source_buffer.byte_length(),
+ )
+ # Copy and wait for buffer to be copied from source buffer.
+ sem = self._device.create_semaphore(0)
+ self._device.queue_copy(
+ source_buffer,
+ host_buffer,
+ wait_semaphores=HalFence.create_at(sem, 0),
+ signal_semaphores=HalFence.create_at(sem, 1),
+ )
+ HalFence.create_at(sem, 1).wait()
+ # Map and reformat buffer as np.array.
+ raw_dtype = self._get_raw_dtype()
+ mapped_memory = host_buffer.map()
+ host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
+ # Detect if we need to force an explicit conversion. This happens when
+ # we were requested to pretend that the array is in a specific dtype,
+ # even if that is not representable on the device. You guessed it:
+ # this is to support bools.
+ if self._override_dtype is not None and self._override_dtype != raw_dtype:
+ host_array = host_array.astype(self._override_dtype)
+ return host_array
+
def _get_raw_dtype(self):
return HalElementType.map_to_dtype(self._buffer_view.element_type)
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 00076cf..c7407bc 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -265,6 +265,52 @@
fence.extend(iree.runtime.HalFence.create_at(sem2, 2))
self.assertEqual(fence.timepoint_count, 2)
+ def testRoundTripQueueCopy(self):
+ original_ary = np.zeros([3, 4], dtype=np.int32) + 2
+ source_bv = self.allocator.allocate_buffer_copy(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ device=self.device,
+ buffer=original_ary,
+ element_type=iree.runtime.HalElementType.SINT_32,
+ )
+ source_buffer = source_bv.get_buffer()
+ target_buffer = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=source_buffer.byte_length(),
+ )
+ sem = self.device.create_semaphore(0)
+ self.device.queue_copy(
+ source_buffer,
+ target_buffer,
+ wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
+ signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
+ )
+ iree.runtime.HalFence.create_at(sem, 1).wait()
+ copy_ary = target_buffer.map().asarray(original_ary.shape, original_ary.dtype)
+ np.testing.assert_array_equal(original_ary, copy_ary)
+
+ def testDifferentSizeQueueCopy(self):
+ source_buffer = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=12,
+ )
+ target_buffer = self.allocator.allocate_buffer(
+ memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+ allowed_usage=iree.runtime.BufferUsage.DEFAULT,
+ allocation_size=13,
+ )
+ sem = self.device.create_semaphore(0)
+ with self.assertRaisesRegex(ValueError, "length must match"):
+ self.device.queue_copy(
+ source_buffer,
+ target_buffer,
+ wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
+ signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
+ )
+
def testCommandBufferStartsByDefault(self):
cb = iree.runtime.HalCommandBuffer(self.device)
with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):