[Bindings] Implement alloc + copy to local host when map is unavailable. (#14997)

diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index fba5a73..8ca1078 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -57,6 +57,18 @@
   signal_semaphores: Semaphores/Fence to signal.
 )";
 
+static const char kHalDeviceQueueCopy[] =
+    R"(Copy data from a source buffer to destination buffer.
+
+Args:
+  source_buffer: `HalBuffer` that holds src data.
+  target_buffer: `HalBuffer` that will receive data. Its byte length must
+    match `source_buffer`, otherwise a ValueError is raised.
+  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
+    a HalFence. The copy will be made once these semaphores are satisfied.
+  signal_semaphores: Semaphores/Fence to signal.
+)";
+
 static const char kHalFenceWait[] =
     R"(Waits until the fence is signalled or errored.
 
@@ -524,6 +536,69 @@
       "executing command buffers");
 }
 
+void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
+                          py::handle wait_semaphores,
+                          py::handle signal_semaphores) {
+  // Queues a copy of all of |source_buffer| into |target_buffer| using
+  // IREE_HAL_QUEUE_AFFINITY_ANY. |wait_semaphores| and |signal_semaphores| may
+  // each be a HalFence or an indexable sequence of (HalSemaphore, value)
+  // pairs. Throws std::invalid_argument if the buffers' byte lengths differ.
+  // Note: alloca() storage below must come from this frame, not a helper.
+  iree_hal_semaphore_list_t waits;
+  iree_hal_semaphore_list_t signals;
+
+  // Semaphores that must be reached before the copy.
+  if (!py::isinstance<HalFence>(wait_semaphores)) {
+    const size_t count = py::len(wait_semaphores);
+    auto** semaphores = static_cast<iree_hal_semaphore_t**>(
+        alloca(sizeof(iree_hal_semaphore_t*) * count));
+    auto* payloads = static_cast<uint64_t*>(alloca(sizeof(uint64_t) * count));
+    waits = {count, semaphores, payloads};
+    for (size_t i = 0; i < count; ++i) {
+      py::tuple timepoint = wait_semaphores[i];
+      semaphores[i] = py::cast<HalSemaphore*>(timepoint[0])->raw_ptr();
+      payloads[i] = py::cast<uint64_t>(timepoint[1]);
+    }
+  } else {
+    waits = iree_hal_fence_semaphore_list(
+        py::cast<HalFence*>(wait_semaphores)->raw_ptr());
+  }
+
+  // Semaphores to signal (same accepted forms as the wait list).
+  if (!py::isinstance<HalFence>(signal_semaphores)) {
+    const size_t count = py::len(signal_semaphores);
+    auto** semaphores = static_cast<iree_hal_semaphore_t**>(
+        alloca(sizeof(iree_hal_semaphore_t*) * count));
+    auto* payloads = static_cast<uint64_t*>(alloca(sizeof(uint64_t) * count));
+    signals = {count, semaphores, payloads};
+    for (size_t i = 0; i < count; ++i) {
+      py::tuple timepoint = signal_semaphores[i];
+      semaphores[i] = py::cast<HalSemaphore*>(timepoint[0])->raw_ptr();
+      payloads[i] = py::cast<uint64_t>(timepoint[1]);
+    }
+  } else {
+    signals = iree_hal_fence_semaphore_list(
+        py::cast<HalFence*>(signal_semaphores)->raw_ptr());
+  }
+
+  // TODO: Accept params for src_offset and target_offset.
+  const iree_device_size_t copy_length =
+      iree_hal_buffer_byte_length(source_buffer.raw_ptr());
+  const iree_device_size_t target_length =
+      iree_hal_buffer_byte_length(target_buffer.raw_ptr());
+  if (copy_length != target_length) {
+    throw std::invalid_argument(
+        "Source and target buffer length must match and it does not. Please "
+        "check allocations");
+  }
+  iree_status_t status = iree_hal_device_queue_copy(
+      raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, waits, signals,
+      source_buffer.raw_ptr(), /*source_offset=*/0, target_buffer.raw_ptr(),
+      /*target_offset=*/0, copy_length);
+  CheckApiStatus(status, "Copying buffer on queue");
+}
+
 //------------------------------------------------------------------------------
 // HalDriver
 //------------------------------------------------------------------------------
@@ -861,6 +936,9 @@
       .def("queue_execute", &HalDevice::QueueExecute,
            py::arg("command_buffers"), py::arg("wait_semaphores"),
            py::arg("signal_semaphores"), kHalDeviceQueueExecute)
+      .def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"),
+           py::arg("target_buffer"), py::arg("wait_semaphores"),
+           py::arg("signal_semaphores"), kHalDeviceQueueCopy)
       .def("__repr__", [](HalDevice& self) {
         auto id_sv = iree_hal_device_id(self.raw_ptr());
         return std::string(id_sv.data, id_sv.size);
@@ -963,6 +1041,9 @@
   py::class_<HalBuffer>(m, "HalBuffer")
       .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
            py::arg("byte_length"))
+      .def("byte_length", &HalBuffer::byte_length)
+      .def("memory_type", &HalBuffer::memory_type)
+      .def("allowed_usage", &HalBuffer::allowed_usage)
       .def("create_view", &HalBuffer::CreateView, py::arg("shape"),
            py::arg("element_size"), py::keep_alive<0, 1>())
       .def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>())
@@ -994,6 +1075,8 @@
       py::arg("buffer"), py::arg("shape"), py::arg("element_type"));
   hal_buffer_view
       .def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>())
+      .def("get_buffer", HalBuffer::CreateFromBufferView,
+           py::keep_alive<0, 1>())
       .def_prop_ro("shape",
                    [](HalBufferView& self) {
                      iree_host_size_t rank =
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 0c3bc63..5e18cfa 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -128,6 +128,8 @@
                      py::handle signal_semaphores);
   void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
                     py::handle signal_semaphores);
+  void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer,
+                 py::handle wait_semaphores, py::handle signal_semaphores);
 };
 
 class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
@@ -176,6 +178,10 @@
     return iree_hal_buffer_byte_length(raw_ptr());
   }
 
+  // Raw memory-type / allowed-usage bits of the wrapped HAL buffer, as ints.
+  int memory_type() const { return iree_hal_buffer_memory_type(raw_ptr()); }
+  int allowed_usage() const { return iree_hal_buffer_allowed_usage(raw_ptr()); }
+
   void FillZero(iree_device_size_t byte_offset,
                 iree_device_size_t byte_length) {
     CheckApiStatus(
@@ -197,6 +203,11 @@
     return HalBufferView::StealFromRawPtr(bv);
   }
 
+  static HalBuffer CreateFromBufferView(HalBufferView& bv) {
+    iree_hal_buffer_t* backing = iree_hal_buffer_view_buffer(bv.raw_ptr());
+    return HalBuffer::BorrowFromRawPtr(backing);
+  }
+
   py::str Repr();
 };
 
diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py
index 096fc9b..fb67b21 100644
--- a/runtime/bindings/python/iree/runtime/array_interop.py
+++ b/runtime/bindings/python/iree/runtime/array_interop.py
@@ -17,6 +17,7 @@
     HalElementType,
     MappedMemory,
     MemoryType,
+    HalFence,
 )
 
 __all__ = [
@@ -106,6 +107,20 @@
         self._transfer_to_host(False)
         return self._host_array
 
+    def _is_mappable(self) -> bool:
+        """Returns whether the backing buffer can be mapped for host access.
+
+        Requires HOST_VISIBLE memory and MAPPING_SCOPED usage. Masks are
+        converted to int so the checks never compare an int against an enum.
+        """
+        buffer = self._buffer_view.get_buffer()
+        host_visible = int(MemoryType.HOST_VISIBLE)
+        mapping_scoped = int(BufferUsage.MAPPING_SCOPED)
+        return (
+            (buffer.memory_type() & host_visible) == host_visible
+            and (buffer.allowed_usage() & mapping_scoped) == mapping_scoped
+        )
+
     def _transfer_to_host(self, implicit):
         if self._host_array is not None:
             return
@@ -114,7 +129,10 @@
                 "DeviceArray cannot be implicitly transferred to the host: "
                 "if necessary, do an explicit transfer via .to_host()"
             )
-        self._mapped_memory, self._host_array = self._map_to_host()
+        if self._is_mappable():
+            self._mapped_memory, self._host_array = self._map_to_host()
+        else:
+            self._host_array = self._copy_to_host()
 
     def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
         # TODO: When synchronization is enabled, need to block here.
@@ -129,6 +147,35 @@
             host_array = host_array.astype(self._override_dtype)
         return mapped_memory, host_array
 
+    def _copy_to_host(self) -> np.ndarray:
+        """Copies the device buffer into a host buffer and returns it as an array.
+
+        Fallback for buffers that cannot be mapped directly. Blocks until the
+        queued copy has signaled completion.
+        """
+        # TODO: When synchronization is enabled, need to block here.
+        source = self._buffer_view.get_buffer()
+        staging = self._device.allocator.allocate_buffer(
+            memory_type=(MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE),
+            allowed_usage=(BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED),
+            allocation_size=source.byte_length(),
+        )
+        # Gate the copy on a fresh semaphore (wait at 0, signal 1), then block.
+        sem = self._device.create_semaphore(0)
+        gate, done = HalFence.create_at(sem, 0), HalFence.create_at(sem, 1)
+        self._device.queue_copy(
+            source, staging, wait_semaphores=gate, signal_semaphores=done
+        )
+        HalFence.create_at(sem, 1).wait()
+        # Expose the staging memory as an ndarray of the raw element dtype.
+        raw_dtype = self._get_raw_dtype()
+        mapped_memory = staging.map()
+        host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
+        # An override dtype (used to support bools) forces an explicit conversion.
+        if self._override_dtype is not None and self._override_dtype != raw_dtype:
+            return host_array.astype(self._override_dtype)
+        return host_array
+
     def _get_raw_dtype(self):
         return HalElementType.map_to_dtype(self._buffer_view.element_type)
 
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 00076cf..c7407bc 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -265,6 +265,52 @@
         fence.extend(iree.runtime.HalFence.create_at(sem2, 2))
         self.assertEqual(fence.timepoint_count, 2)
 
+    def testRoundTripQueueCopy(self):
+        """Data sent through queue_copy reads back unchanged from the target."""
+        rt = iree.runtime
+        expected = np.full([3, 4], 2, dtype=np.int32)
+        source_view = self.allocator.allocate_buffer_copy(
+            memory_type=rt.MemoryType.DEVICE_LOCAL,
+            allowed_usage=rt.BufferUsage.DEFAULT,
+            device=self.device,
+            buffer=expected,
+            element_type=rt.HalElementType.SINT_32,
+        )
+        src = source_view.get_buffer()
+        dst = self.allocator.allocate_buffer(
+            memory_type=rt.MemoryType.DEVICE_LOCAL,
+            allowed_usage=rt.BufferUsage.DEFAULT,
+            allocation_size=src.byte_length(),
+        )
+        # Copy gated on a fresh semaphore: wait at 0, signal 1, then block on 1.
+        sem = self.device.create_semaphore(0)
+        gate = rt.HalFence.create_at(sem, 0)
+        done = rt.HalFence.create_at(sem, 1)
+        self.device.queue_copy(src, dst, wait_semaphores=gate, signal_semaphores=done)
+        rt.HalFence.create_at(sem, 1).wait()
+        actual = dst.map().asarray(expected.shape, expected.dtype)
+        np.testing.assert_array_equal(expected, actual)
+
+    def testDifferentSizeQueueCopy(self):
+        """queue_copy rejects buffers whose byte lengths differ (12 vs 13)."""
+        rt = iree.runtime
+        src, dst = (
+            self.allocator.allocate_buffer(
+                memory_type=rt.MemoryType.DEVICE_LOCAL,
+                allowed_usage=rt.BufferUsage.DEFAULT,
+                allocation_size=size,
+            )
+            for size in (12, 13)
+        )
+        sem = self.device.create_semaphore(0)
+        gate = rt.HalFence.create_at(sem, 0)
+        done = rt.HalFence.create_at(sem, 1)
+        # The size mismatch is rejected with ValueError before a copy is queued.
+        with self.assertRaisesRegex(ValueError, "length must match"):
+            self.device.queue_copy(
+                src, dst, wait_semaphores=gate, signal_semaphores=done
+            )
+
     def testCommandBufferStartsByDefault(self):
         cb = iree.runtime.HalCommandBuffer(self.device)
         with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):