Replacing most iree_hal_buffer_read_data usage with device transfer.
diff --git a/experimental/web/sample_static/main.c b/experimental/web/sample_static/main.c
index 7f0ca31..720cf3a 100644
--- a/experimental/web/sample_static/main.c
+++ b/experimental/web/sample_static/main.c
@@ -154,9 +154,10 @@
// confidence values for each digit in [0, 9].
float predictions[1 * 10] = {0.0f};
if (iree_status_is_ok(status)) {
- status =
- iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view),
- 0, predictions, sizeof(predictions));
+ status = iree_hal_device_transfer_d2h(
+ state->device, iree_hal_buffer_view_buffer(ret_buffer_view), 0,
+ predictions, sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout());
}
iree_hal_buffer_view_release(ret_buffer_view);
diff --git a/iree/compiler/ConstEval/Runtime.cpp b/iree/compiler/ConstEval/Runtime.cpp
index cf6d155..4a45a41 100644
--- a/iree/compiler/ConstEval/Runtime.cpp
+++ b/iree/compiler/ConstEval/Runtime.cpp
@@ -230,7 +230,7 @@
iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bufferView);
// Map the memory and construct.
- // TODO(benvanik): fallback to alloc + iree_hal_buffer_read_data if
+ // TODO(benvanik): fallback to alloc + iree_hal_device_transfer_range if
// mapping is not available. Today with the CPU backends it's always
// possible but would not work with accelerators.
iree_hal_buffer_mapping_t mapping;
diff --git a/iree/hal/cts/command_buffer_dispatch_test.h b/iree/hal/cts/command_buffer_dispatch_test.h
index 9ad32c7..cb36055 100644
--- a/iree/hal/cts/command_buffer_dispatch_test.h
+++ b/iree/hal/cts/command_buffer_dispatch_test.h
@@ -132,10 +132,12 @@
IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_DISPATCH,
command_buffer));
- float out_value = 0.0f;
- IREE_ASSERT_OK(iree_hal_buffer_read_data(output_buffer, /*source_offset=*/0,
- &out_value, sizeof(out_value)));
- EXPECT_EQ(2.5f, out_value);
+ float output_value = 0.0f;
+ IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
+ device_, output_buffer,
+ /*source_offset=*/0, &output_value, sizeof(output_value),
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
+ EXPECT_EQ(2.5f, output_value);
iree_hal_command_buffer_release(command_buffer);
iree_hal_buffer_release(output_buffer);
diff --git a/iree/hal/cts/command_buffer_test.h b/iree/hal/cts/command_buffer_test.h
index 10e0639..6dcb621 100644
--- a/iree/hal/cts/command_buffer_test.h
+++ b/iree/hal/cts/command_buffer_test.h
@@ -67,10 +67,11 @@
// Read data for returning.
std::vector<uint8_t> actual_data(buffer_size);
- IREE_CHECK_OK(
- iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- /*target_buffer=*/actual_data.data(),
- /*data_length=*/buffer_size));
+ IREE_CHECK_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0,
+ /*target_buffer=*/actual_data.data(),
+ /*data_length=*/buffer_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
// Cleanup and return.
iree_hal_command_buffer_release(command_buffer);
@@ -168,10 +169,11 @@
// Read the device buffer and compare.
std::vector<uint8_t> actual_data(kDefaultAllocationSize);
- IREE_ASSERT_OK(
- iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- /*target_buffer=*/actual_data.data(),
- /*data_length=*/kDefaultAllocationSize));
+ IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0,
+ /*target_buffer=*/actual_data.data(),
+ /*data_length=*/kDefaultAllocationSize,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
// Must release the command buffer before resources used by it.
@@ -237,10 +239,11 @@
// Read the device buffer and compare.
std::vector<uint8_t> actual_data(kDefaultAllocationSize);
- IREE_ASSERT_OK(
- iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- /*target_buffer=*/actual_data.data(),
- /*data_length=*/kDefaultAllocationSize));
+ IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0,
+ /*target_buffer=*/actual_data.data(),
+ /*data_length=*/kDefaultAllocationSize,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
// Must release the command buffer before resources used by it.
@@ -445,9 +448,10 @@
// Check that the contents match what we expect.
std::vector<uint8_t> actual_data(target_buffer_size);
- IREE_CHECK_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- actual_data.data(),
- actual_data.size()));
+ IREE_CHECK_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0, actual_data.data(),
+ actual_data.size(), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
EXPECT_THAT(actual_data, ContainerEq(source_buffer));
iree_hal_command_buffer_release(command_buffer);
@@ -481,9 +485,10 @@
// Check that the contents match what we expect.
std::vector<uint8_t> actual_data(target_buffer_size);
- IREE_CHECK_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- actual_data.data(),
- actual_data.size()));
+ IREE_CHECK_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0, actual_data.data(),
+ actual_data.size(), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
std::vector<uint8_t> reference_buffer{0x00, 0x00, 0x00, 0x00, //
0x05, 0x06, 0x07, 0x08, //
0xA1, 0xA2, 0xA3, 0xA4, //
@@ -527,9 +532,10 @@
// Check that the contents match what we expect.
std::vector<uint8_t> actual_data(target_buffer_size);
- IREE_ASSERT_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
- actual_data.data(),
- actual_data.size()));
+ IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
+ device_, device_buffer, /*source_offset=*/0, actual_data.data(),
+ actual_data.size(), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
std::vector<uint8_t> reference_buffer{0x00, 0x00, 0x00, 0x00, //
0x00, 0x00, 0x00, 0x00, //
0x05, 0x06, 0x07, 0x08, //
@@ -537,9 +543,10 @@
EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
// Also check the subspan.
std::vector<uint8_t> actual_data_subspan(subspan_length);
- IREE_ASSERT_OK(iree_hal_buffer_read_data(buffer_subspan, /*source_offset=*/0,
- actual_data_subspan.data(),
- actual_data_subspan.size()));
+ IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
+ device_, buffer_subspan, /*source_offset=*/0, actual_data_subspan.data(),
+ actual_data_subspan.size(), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
std::vector<uint8_t> reference_buffer_subspan{0x00, 0x00, 0x00, 0x00, //
0x05, 0x06, 0x07, 0x08};
EXPECT_THAT(actual_data_subspan, ContainerEq(reference_buffer_subspan));
diff --git a/iree/hal/device.c b/iree/hal/device.c
index 0fc08eb..d389906 100644
--- a/iree/hal/device.c
+++ b/iree/hal/device.c
@@ -110,6 +110,39 @@
return status;
}
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_h2d(
+ iree_hal_device_t* device, const void* source, iree_hal_buffer_t* target,
+ iree_device_size_t target_offset, iree_device_size_t data_length,
+ iree_hal_transfer_buffer_flags_t flags, iree_timeout_t timeout) {
+ return iree_hal_device_transfer_range(
+ device,
+ iree_hal_make_host_transfer_buffer_span((void*)source, data_length), 0,
+ iree_hal_make_device_transfer_buffer(target), target_offset, data_length,
+ flags, timeout);
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_d2h(
+ iree_hal_device_t* device, iree_hal_buffer_t* source,
+ iree_device_size_t source_offset, void* target,
+ iree_device_size_t data_length, iree_hal_transfer_buffer_flags_t flags,
+ iree_timeout_t timeout) {
+ return iree_hal_device_transfer_range(
+ device, iree_hal_make_device_transfer_buffer(source), source_offset,
+ iree_hal_make_host_transfer_buffer_span(target, data_length), 0,
+ data_length, flags, timeout);
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_d2d(
+ iree_hal_device_t* device, iree_hal_buffer_t* source,
+ iree_device_size_t source_offset, iree_hal_buffer_t* target,
+ iree_device_size_t target_offset, iree_device_size_t data_length,
+ iree_hal_transfer_buffer_flags_t flags, iree_timeout_t timeout) {
+ return iree_hal_device_transfer_range(
+ device, iree_hal_make_device_transfer_buffer(source), source_offset,
+ iree_hal_make_device_transfer_buffer(target), target_offset, data_length,
+ flags, timeout);
+}
+
IREE_API_EXPORT iree_status_t iree_hal_device_transfer_and_wait(
iree_hal_device_t* device, iree_hal_semaphore_t* wait_semaphore,
uint64_t wait_value, iree_host_size_t transfer_count,
diff --git a/iree/hal/device.h b/iree/hal/device.h
index f228555..d1096ad 100644
--- a/iree/hal/device.h
+++ b/iree/hal/device.h
@@ -230,6 +230,29 @@
iree_device_size_t target_offset, iree_device_size_t data_length,
iree_hal_transfer_buffer_flags_t flags, iree_timeout_t timeout);
+// Synchronously copies data from host |source| into device |target|.
+// Convience wrapper around iree_hal_device_transfer_range.
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_h2d(
+ iree_hal_device_t* device, const void* source, iree_hal_buffer_t* target,
+ iree_device_size_t target_offset, iree_device_size_t data_length,
+ iree_hal_transfer_buffer_flags_t flags, iree_timeout_t timeout);
+
+// Synchronously copies data from device |source| into host |target|.
+// Convience wrapper around iree_hal_device_transfer_range.
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_d2h(
+ iree_hal_device_t* device, iree_hal_buffer_t* source,
+ iree_device_size_t source_offset, void* target,
+ iree_device_size_t data_length, iree_hal_transfer_buffer_flags_t flags,
+ iree_timeout_t timeout);
+
+// Synchronously copies data from device |source| into device |target|.
+// Convience wrapper around iree_hal_device_transfer_range.
+IREE_API_EXPORT iree_status_t iree_hal_device_transfer_d2d(
+ iree_hal_device_t* device, iree_hal_buffer_t* source,
+ iree_device_size_t source_offset, iree_hal_buffer_t* target,
+ iree_device_size_t target_offset, iree_device_size_t data_length,
+ iree_hal_transfer_buffer_flags_t flags, iree_timeout_t timeout);
+
// Synchronously executes one or more transfer operations against a queue.
// All buffers must be compatible with |device| and ranges must not overlap
// (same as with memcpy).
diff --git a/iree/samples/dynamic_shapes/main.c b/iree/samples/dynamic_shapes/main.c
index e11affa..c75c6c3 100644
--- a/iree/samples/dynamic_shapes/main.c
+++ b/iree/samples/dynamic_shapes/main.c
@@ -45,8 +45,11 @@
iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view);
}
if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(buffer_view),
- 0, out_result, sizeof(*out_result));
+ status = iree_hal_device_transfer_d2h(
+ iree_runtime_session_device(session),
+ iree_hal_buffer_view_buffer(buffer_view), 0, out_result,
+ sizeof(*out_result), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout());
}
iree_hal_buffer_view_release(buffer_view);
diff --git a/iree/samples/simple_embedding/simple_embedding.c b/iree/samples/simple_embedding/simple_embedding.c
index 57d4f09..cce2143 100644
--- a/iree/samples/simple_embedding/simple_embedding.c
+++ b/iree/samples/simple_embedding/simple_embedding.c
@@ -142,9 +142,10 @@
// Read back the results and ensure we got the right values.
float results[] = {0.0f, 0.0f, 0.0f, 0.0f};
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view), 0,
- results, sizeof(results)));
+ IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
+ device, iree_hal_buffer_view_buffer(ret_buffer_view), 0, results,
+ sizeof(results), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(results); ++i) {
if (results[i] != 8.0f) {
return iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c
index d787ab3..2da897e 100644
--- a/iree/samples/static_library/static_library_demo.c
+++ b/iree/samples/static_library/static_library_demo.c
@@ -174,9 +174,10 @@
// Read back the results and ensure we got the right values.
float results[] = {0.0f, 0.0f, 0.0f, 0.0f};
if (iree_status_is_ok(status)) {
- status =
- iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view),
- 0, results, sizeof(results));
+ status = iree_hal_device_transfer_d2h(
+ device, iree_hal_buffer_view_buffer(ret_buffer_view), 0, results,
+ sizeof(results), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout());
}
if (iree_status_is_ok(status)) {
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(results); ++i) {
diff --git a/iree/samples/variables_and_state/main.c b/iree/samples/variables_and_state/main.c
index 0014ac6..45c8f9c 100644
--- a/iree/samples/variables_and_state/main.c
+++ b/iree/samples/variables_and_state/main.c
@@ -24,8 +24,11 @@
iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view);
}
if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(buffer_view),
- 0, out_value, sizeof(*out_value));
+ status = iree_hal_device_transfer_d2h(
+ iree_runtime_session_device(session),
+ iree_hal_buffer_view_buffer(buffer_view), 0, out_value,
+ sizeof(*out_value), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout());
}
iree_hal_buffer_view_release(buffer_view);
diff --git a/iree/samples/vision/iree-run-mnist-module.c b/iree/samples/vision/iree-run-mnist-module.c
index d525bdd..9168e8f 100644
--- a/iree/samples/vision/iree-run-mnist-module.c
+++ b/iree/samples/vision/iree-run-mnist-module.c
@@ -74,9 +74,11 @@
// Read back the results. The output of the mnist model is a 1x10 prediction
// confidence values for each digit in [0, 9].
float predictions[1 * 10] = {0.0f};
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view), 0,
- predictions, sizeof(predictions)));
+ IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
+ iree_runtime_session_device(session),
+ iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
+ sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+ iree_infinite_timeout()));
iree_hal_buffer_view_release(ret_buffer_view);
// Get the highest index from the output.
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc
index 8ec9c5a..8888e16 100644
--- a/iree/samples/vulkan/vulkan_inference_gui.cc
+++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -413,9 +413,10 @@
auto* output_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(
iree_vm_list_get_ref_deref(outputs.get(), 0,
iree_hal_buffer_view_get_descriptor()));
- IREE_CHECK_OK(iree_hal_buffer_read_data(
- iree_hal_buffer_view_buffer(output_buffer_view), 0, latest_output,
- sizeof(latest_output)));
+ IREE_CHECK_OK(iree_hal_device_transfer_d2h(
+ iree_vk_device, iree_hal_buffer_view_buffer(output_buffer_view), 0,
+ latest_output, sizeof(latest_output),
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
dirty = false;
}