Refactor HostToDevice Transfer to separate 0-dim and splat cases (#15285)
HostBufferToDevice currently intermixes several distinct cases in one
code path. Refactor so that the splat, 0-dim, and general transfer
cases are handled by separate paths, cleaning up the branching behavior.
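
The resulting dispatch classifies each request once and hands it to a
dedicated path. A minimal standalone sketch of that classification
(illustrative only; ClassifyTransfer and TransferKind are hypothetical
names, not part of this patch):

    #include <cstddef>
    #include <cstdint>

    enum class TransferKind { kZeroDim, kSplat, kGeneral };

    TransferKind ClassifyTransfer(size_t element_byte_size,
                                  const int64_t* dims,
                                  const int64_t* byte_strides,
                                  size_t rank) {
      // Device-side fills only accept small fixed-size patterns.
      bool is_splat = element_byte_size == 1 || element_byte_size == 2 ||
                      element_byte_size == 4;
      for (size_t i = 0; i < rank; ++i) {
        // A zero-length dim means an empty buffer: allocate only.
        if (dims[i] == 0) return TransferKind::kZeroDim;
        // All elements alias one host element if every dim is size 1
        // or has stride 0.
        is_splat &= (dims[i] == 1 || byte_strides[i] == 0);
      }
      return is_splat ? TransferKind::kSplat : TransferKind::kGeneral;
    }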
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index bd9b3ce..8bb8cc9 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -698,6 +698,160 @@
return iree_ok_status();
}
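+// Handles the splat case: every logical element aliases the same host
+// element, so the buffer contents can be produced with a device-side fill.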
+iree_status_t DeviceInstance::HostBufferToDeviceSplat(
+ const void* data, PJRT_Buffer_Type type, const int64_t* dims,
+ size_t num_dims, EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer) {
+ // Map element type:
+ iree_hal_element_type_t element_type;
+ IREE_RETURN_IF_ERROR(
+ PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
+ // TODO: Do something sensible with sub-byte aligned types.
+ if (IREE_UNLIKELY(iree_hal_element_bit_count(element_type) == 0) ||
+ IREE_UNLIKELY(!iree_hal_element_is_byte_aligned(element_type))) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "opaque and sub-byte aligned element types cannot be indexed");
+ }
+ iree_device_size_t element_type_byte_size =
+ iree_hal_element_dense_byte_count(element_type);
+
+ // Validate the rank and compute the shape and total byte length:
+ std::array<iree_hal_dim_t, 9> shape;
+ if (num_dims > shape.size()) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "only supports up to %d dims but got %d",
+ (int)shape.size(), (int)num_dims);
+ }
+
+ iree_device_size_t byte_length = element_type_byte_size;
+ for (int i = 0, s = num_dims; i < s; ++i) {
+ byte_length *= dims[i];
+ shape[i] = dims[i];
+ }
+
+ iree::vm::ref<iree_hal_buffer_t> buffer;
+
+ // Allocate on stream. We serialize across 3 timepoints:
+ // 0. Last transfer complete
+ // 1. Allocation
+ // 2. Fill is complete
+ // There are various ways to be smarter about this but without more
+ // information from the caller, this is ok. If we wanted to favor smaller
+ // allocation scopes, it may be desirable to join with the main execution
+ // timeline, but that would obviously serialize more.
+ uint64_t wait_transfer_start = last_transfer_timepoint_;
+ uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
+ uint64_t signal_copy_complete = ++last_transfer_timepoint_;
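+ // E.g. if last_transfer_timepoint_ was N on entry, the alloca signals
+ // N+1 and the fill signals N+2 on transfer_timeline_.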
+ iree_hal_buffer_params_t params;
+ memset(&params, 0, sizeof(params));
+ params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+ params.usage =
+ IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
+ IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
+ device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/
+ {1, &transfer_timeline_, &wait_transfer_start},
+ /*signal_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_alloca_complete},
+ IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, byte_length, &buffer));
+
+ // Queue up the buffer fill for splatting:
+ iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+ device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*binding_capacity=*/0, &transfer_cb));
+ IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
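+ // The fill replicates the element-sized pattern at `data` across the
+ // entire target range, which is why splats never need a staging buffer.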
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
+ transfer_cb.get(), buffer.get(), /*target_offset=*/0,
+ /*target_size=*/byte_length, data, element_type_byte_size));
+ IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
+
+ // Execute the enqueued splat:
+ IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
+ device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_alloca_complete},
+ /*signal_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_copy_complete},
+ /*command_buffer_count=*/1, &transfer_cb));
+
+ // Wrap in a buffer view and return:
+ iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+ buffer.get(), num_dims, &shape[0], element_type,
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
+ &result_buffer_view));
+
+ auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+ instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
+ instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+ *out_buffer = instance;
+
+ // The fill pattern was copied while recording the command buffer, so the
+ // host data is no longer required:
+ *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+ return iree_ok_status();
+}
+
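+// Handles shapes that contain a zero-length dimension: the logical array is
+// empty, so only an allocation and a buffer view are created and the host
+// data is never read.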
+iree_status_t DeviceInstance::HostBufferToDeviceZeroDim(
+ PJRT_Buffer_Type type, const int64_t* dims, size_t num_dims,
+ EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer) {
+ // Map element type:
+ iree_hal_element_type_t element_type;
+ IREE_RETURN_IF_ERROR(
+ PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
+
+ std::array<iree_hal_dim_t, 9> shape;
+ if (num_dims > shape.size()) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "only supports up to %d dims but got %d",
+ (int)shape.size(), (int)num_dims);
+ }
+
+ for (int i = 0, s = num_dims; i < s; ++i) {
+ shape[i] = dims[i];
+ }
+
+ // We only need to wait for the previous transfer and allocate the buffer:
+ uint64_t wait_transfer_start = last_transfer_timepoint_;
+ uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
+
+ iree_hal_buffer_params_t params;
+ iree::vm::ref<iree_hal_buffer_t> buffer;
+ memset(&params, 0, sizeof(params));
+ params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+ params.usage =
+ IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
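+ // Allocate one element's worth of bytes even though the view is empty,
+ // sidestepping a zero-byte allocation.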
+ IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
+ device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/
+ {1, &transfer_timeline_, &wait_transfer_start},
+ /*signal_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_alloca_complete},
+ IREE_HAL_ALLOCATOR_POOL_DEFAULT, params,
+ iree_hal_element_dense_byte_count(element_type), &buffer));
+
+ // Wrap in a buffer view and return.
+ iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+ buffer.get(), num_dims, &shape[0], element_type,
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
+ &result_buffer_view));
+
+ auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+ instance->AdvanceReadyFence(transfer_timeline_.get(), signal_alloca_complete);
+ instance->AdvanceDoneFence(transfer_timeline_.get(), signal_alloca_complete);
+ *out_buffer = instance;
+
+ // The degenerate case never reads the host data, so signal completion
+ // immediately:
+ *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+ return iree_ok_status();
+}
+
iree_status_t DeviceInstance::HostBufferToDevice(
const void* data, PJRT_Buffer_Type type, const int64_t* dims,
size_t num_dims, const int64_t* byte_strides, size_t num_byte_strides,
@@ -720,7 +874,33 @@
iree_device_size_t element_type_byte_size =
iree_hal_element_dense_byte_count(element_type);
- // Handle strided layouts and shape.
+ // Check for the special cases (splat and zero-dim) before the general path:
+ bool is_splat = element_type_byte_size == 1 || element_type_byte_size == 2 ||
+ element_type_byte_size == 4;
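+ // Fill patterns are limited to 1, 2, or 4 bytes, hence the size check.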
+ bool has_zero_dim = false;
+ iree_device_size_t byte_length = element_type_byte_size;
+
+ for (size_t i = 0; i < num_byte_strides; ++i) {
+ is_splat &= (dims[i] == 1 || byte_strides[i] == 0);
+ has_zero_dim |= (dims[i] == 0);
+ byte_length *= dims[i];
+ }
+
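+ // Defensively clamp to at least one element's worth of bytes.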
+ byte_length = std::max(element_type_byte_size, byte_length);
+
+ // In the zero-dim case no transfer is required:
+ if (has_zero_dim) {
+ return HostBufferToDeviceZeroDim(
+ type, dims, num_dims, out_done_with_host_buffer_event, out_buffer);
+ }
+
+ // In the splat case we can perform a fill instead of a copy:
+ if (is_splat) {
+ return HostBufferToDeviceSplat(data, type, dims, num_dims,
+ out_done_with_host_buffer_event, out_buffer);
+ }
+
+ // Handle strided layouts and shape:
std::vector<int64_t> perms(num_dims);
std::array<iree_hal_dim_t, 9> input_shape;
std::array<iree_hal_dim_t, 9> output_shape;
@@ -730,43 +910,20 @@
(int)input_shape.size(), (int)num_dims);
}
- bool has_zero_length = false;
for (int i = 0, s = num_dims; i < s; ++i) {
- has_zero_length |= dims[i] == 0;
output_shape[i] = dims[i];
}
- if (has_zero_length) {
- // This is a degenerate case.
- for (int i = 0; i < num_dims; ++i) {
- perms[i] = i;
- input_shape[i] = dims[i];
- }
- } else {
- // Compute the input shape and permutations for the broadcast.
- iree::pjrt::computeBroadcastArgs(
- num_dims, element_type_byte_size, byte_strides, dims,
- reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
- }
+ // Compute the input shape and permutations for the broadcast.
+ iree::pjrt::computeBroadcastArgs(
+ num_dims, element_type_byte_size, byte_strides, dims,
+ reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
- // Splatting requires length 1, 2, or 4
- bool is_splat = element_type_byte_size == 1 || element_type_byte_size == 2 ||
- element_type_byte_size == 4;
bool is_dense_row_major = true;
for (int i = 0, s = num_dims; i < s; ++i) {
is_dense_row_major &= (input_shape[i] == dims[i]) && (perms[i] == i);
- is_splat &= (byte_strides[i] == 0) || (dims[i] == 1);
}
- const bool needs_staging_buffer = !is_splat && !has_zero_length;
-
- iree_device_size_t byte_length = element_type_byte_size;
- for (size_t i = 0; i < num_dims; ++i) {
- byte_length *= dims[i];
- }
-
- byte_length = std::max(element_type_byte_size, byte_length);
-
std::vector<int8_t> transposed_data;
if (!is_dense_row_major) {
client().logger().debug(
@@ -774,24 +931,7 @@
"doing the transpose on device. See: "
"https://github.com/openxla/openxla-pjrt-plugin/issues/201");
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "unable to create transpose plan");
- // std::vector<int64_t> input_strides(num_dims);
- // std::vector<int64_t> input_dims(num_dims);
- // for (size_t i = 0; i < num_dims; i++) {
- // input_strides[perms[i]] = byte_strides[i];
- // input_dims[i] = input_shape[i];
- // }
- // transposed_data.resize(byte_length);
- // // TODO: use caching to improve performance of plan creation
- // auto transpose =
- // xla::TransposePlan::Create(element_type_byte_size, input_dims, perms,
- // xla::TransposePlan::Striding{input_strides});
- // if (!transpose.ok()) {
- // return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- // "unable to create transpose plan");
- // }
- // transpose.value()->Execute(data, transposed_data.data());
- // data = transposed_data.data();
+ "transposition is not currently supported");
}
iree::vm::ref<iree_hal_buffer_t> buffer;
@@ -807,17 +947,14 @@
PJRT_HostBufferSemantics_kImmutableOnlyDuringCall;
bool caller_data_done = false;
- // We only need a staging buffer if we cannot splat the data.
iree::vm::ref<iree_hal_buffer_t> host_staging_buffer;
- if (needs_staging_buffer) {
- IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
- iree_make_const_byte_span(data, byte_length), require_snapshot_now,
- &caller_data_done, &host_staging_buffer));
- if (!caller_data_done) {
- return iree_make_status(
- IREE_STATUS_UNIMPLEMENTED,
- "deferred snapshot of host data not yet implemented");
- }
+ IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
+ iree_make_const_byte_span(data, byte_length), require_snapshot_now,
+ &caller_data_done, &host_staging_buffer));
+ if (!caller_data_done) {
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "deferred snapshot of host data not yet implemented");
}
// Allocate on stream. We serialize across 3 timepoints:
@@ -846,47 +983,26 @@
// Queue up the transfer command.
iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
- if (is_splat) {
- IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
- device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
- IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
- /*binding_capacity=*/0, &transfer_cb));
- IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
- IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
- transfer_cb.get(), buffer.get(), /*target_offset=*/0,
- /*target_size=*/byte_length, data, element_type_byte_size));
- IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
- } else if (!has_zero_length) {
- iree_hal_transfer_command_t transfer_command;
- memset(&transfer_command, 0, sizeof(transfer_command));
- transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
- transfer_command.copy.source_buffer = host_staging_buffer.get(),
- transfer_command.copy.source_offset = 0;
- transfer_command.copy.target_buffer = buffer.get();
- transfer_command.copy.target_offset = 0;
- transfer_command.copy.length = byte_length;
- IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
- device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
- IREE_HAL_QUEUE_AFFINITY_ANY,
- /*transfer_count=*/1, &transfer_command, &transfer_cb));
- }
+ iree_hal_transfer_command_t transfer_command;
+ memset(&transfer_command, 0, sizeof(transfer_command));
+ transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
+ transfer_command.copy.source_buffer = host_staging_buffer.get();
+ transfer_command.copy.source_offset = 0;
+ transfer_command.copy.target_buffer = buffer.get();
+ transfer_command.copy.target_offset = 0;
+ transfer_command.copy.length = byte_length;
+ IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
+ device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*transfer_count=*/1, &transfer_command, &transfer_cb));
- if (has_zero_length) {
- IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_barrier(
- device(), IREE_HAL_QUEUE_AFFINITY_ANY,
- /*wait_semaphore_list=*/
- {1, &transfer_timeline_, &signal_alloca_complete},
- /*signal_semaphore_list=*/
- {1, &transfer_timeline_, &signal_copy_complete}));
- } else {
- IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
- device(), IREE_HAL_QUEUE_AFFINITY_ANY,
- /*wait_semaphore_list=*/
- {1, &transfer_timeline_, &signal_alloca_complete},
- /*signal_semaphore_list=*/
- {1, &transfer_timeline_, &signal_copy_complete},
- /*command_buffer_count=*/1, &transfer_cb));
- }
+ IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
+ device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_alloca_complete},
+ /*signal_semaphore_list=*/
+ {1, &transfer_timeline_, &signal_copy_complete},
+ /*command_buffer_count=*/1, &transfer_cb));
// Wrap in a buffer view and return.
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
@@ -895,11 +1011,10 @@
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
&result_buffer_view));
- *out_buffer = new BufferInstance(*this, std::move(result_buffer_view));
- (*out_buffer)
- ->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
- (*out_buffer)
- ->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+ auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+ instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
+ instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+ *out_buffer = instance;
// We snapshotted the caller data when acquiring the host staging buffer,
// so we won't be touching it again.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index dc405f1..4a84e89 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -214,6 +214,16 @@
bool is_addressable() { return true; }
int local_hardware_id() { return -1; }
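+ // Creates a device buffer for a shape with a zero-length dimension.
+ // No host data is transferred.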
+ iree_status_t HostBufferToDeviceZeroDim(
+ PJRT_Buffer_Type type, const int64_t* dims, size_t num_dims,
+ EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer);
+
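+ // Creates a device buffer by filling it with a single splatted element,
+ // avoiding a host staging buffer.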
+ iree_status_t HostBufferToDeviceSplat(
+ const void* data, PJRT_Buffer_Type type, const int64_t* dims,
+ size_t num_dims, EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer);
+
// Copies a host buffer to the device.
// See PJRT_Client_BufferFromHostBuffer
iree_status_t HostBufferToDevice(