Refactor HostToDevice Transfer to separate 0-dim and splat cases (#15285)

HostBufferToDevice currently intermixes several cases. Refactor so that
the splat, 0-dim, and plain transfer cases are handled separately,
cleaning up the branching behavior.
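
At a glance, the entry point now classifies the incoming host buffer and
dispatches to a dedicated helper per case. Below is a minimal standalone
sketch of that classification, not part of the patch itself
(ClassifyHostTransfer and TransferKind are illustrative names, and it
assumes byte_strides carries one entry per dim):

    #include <cstddef>
    #include <cstdint>

    enum class TransferKind { kZeroDim, kSplat, kCopy };

    TransferKind ClassifyHostTransfer(const int64_t* dims, size_t num_dims,
                                      const int64_t* byte_strides,
                                      size_t element_byte_size) {
      // Buffer fills only support 1, 2, or 4 byte patterns.
      bool is_splat = element_byte_size == 1 || element_byte_size == 2 ||
                      element_byte_size == 4;
      bool has_zero_dim = false;
      for (size_t i = 0; i < num_dims; ++i) {
        // A unit dim or a zero byte stride repeats the same host element.
        is_splat &= (dims[i] == 1 || byte_strides[i] == 0);
        has_zero_dim |= (dims[i] == 0);
      }
      if (has_zero_dim) return TransferKind::kZeroDim;  // allocate only
      if (is_splat) return TransferKind::kSplat;        // allocate + fill
      return TransferKind::kCopy;  // stage on host, then copy to device
    }

For example, a float32 buffer with dims = {2, 3} and byte_strides = {0, 0}
classifies as a splat: one 4-byte pattern fills the whole 24-byte device
allocation. A buffer with dims = {0, 5} takes the zero-dim path, which only
queues an allocation and never reads the host data.
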
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index bd9b3ce..8bb8cc9 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -698,6 +698,160 @@
   return iree_ok_status();
 }
 
+iree_status_t DeviceInstance::HostBufferToDeviceSplat(
+    const void* data, PJRT_Buffer_Type type, const int64_t* dims,
+    size_t num_dims, EventInstance** out_done_with_host_buffer_event,
+    BufferInstance** out_buffer) {
+  // Map element type:
+  iree_hal_element_type_t element_type;
+  IREE_RETURN_IF_ERROR(
+      PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
+  // TODO: Do something sensible with sub-byte aligned types.
+  if (IREE_UNLIKELY(iree_hal_element_bit_count(element_type) == 0) ||
+      IREE_UNLIKELY(!iree_hal_element_is_byte_aligned(element_type))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "opaque and sub-byte aligned element types cannot be indexed");
+  }
+  iree_device_size_t element_type_byte_size =
+      iree_hal_element_dense_byte_count(element_type);
+
+  // Handle strided layouts and shape.
+  std::array<iree_hal_dim_t, 9> shape;
+  if (num_dims > shape.size()) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "only supports up to %d dims but got %d",
+                            (int)shape.size(), (int)num_dims);
+  }
+
+  iree_device_size_t byte_length = element_type_byte_size;
+  for (int i = 0, s = num_dims; i < s; ++i) {
+    byte_length *= dims[i];
+    shape[i] = dims[i];
+  }
+
+  iree::vm::ref<iree_hal_buffer_t> buffer;
+
+  // Allocate on stream. We serialize across 3 timepoints:
+  //   0. Last transfer complete
+  //   1. Allocation
+  //   2. Fill is complete
+  // There are various ways to be smarter about this but without more
+  // information from the caller, this is ok. If we wanted to favor smaller
+  // allocation scopes, it may be desirable to join with the main execution
+  // timeline, but that would obviously serialize more.
+  uint64_t wait_transfer_start = last_transfer_timepoint_;
+  uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
+  uint64_t signal_copy_complete = ++last_transfer_timepoint_;
+  iree_hal_buffer_params_t params;
+  memset(&params, 0, sizeof(params));
+  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  params.usage =
+      IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &wait_transfer_start},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, byte_length, &buffer));
+
+  // Queue up the buffer fill for splatting:
+  iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+      device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*binding_capacity=*/0, &transfer_cb));
+  IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
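+  // The 1/2/4-byte fill pattern is recorded by value into the command buffer,
+  // so the caller's host data does not need to outlive this call.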
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
+      transfer_cb.get(), buffer.get(), /*target_offset=*/0,
+      /*target_size=*/byte_length, data, element_type_byte_size));
+  IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
+
+  // Execute the enqueued splat:
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_copy_complete},
+      /*command_buffer_count=*/1, &transfer_cb));
+
+  // Wrap in a buffer view and return:
+  iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      buffer.get(), num_dims, &shape[0], element_type,
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
+      &result_buffer_view));
+
+  auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+  instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
+  instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+  *out_buffer = instance;
+
+  // Splat case: the host data is no longer required:
+  *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+  return iree_ok_status();
+}
+
+iree_status_t DeviceInstance::HostBufferToDeviceZeroDim(
+    PJRT_Buffer_Type type, const int64_t* dims, size_t num_dims,
+    EventInstance** out_done_with_host_buffer_event,
+    BufferInstance** out_buffer) {
+  // Map element type:
+  iree_hal_element_type_t element_type;
+  IREE_RETURN_IF_ERROR(
+      PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
+
+  std::array<iree_hal_dim_t, 9> shape;
+  if (num_dims > shape.size()) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "only supports up to %d dims but got %d",
+                            (int)shape.size(), (int)num_dims);
+  }
+
+  for (int i = 0, s = num_dims; i < s; ++i) {
+    shape[i] = dims[i];
+  }
+
+  // We only need to wait for previous transfer and allocate data:
+  uint64_t wait_transfer_start = last_transfer_timepoint_;
+  uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
+
+  iree_hal_buffer_params_t params;
+  iree::vm::ref<iree_hal_buffer_t> buffer;
+  memset(&params, 0, sizeof(params));
+  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  params.usage =
+      IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &wait_transfer_start},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      IREE_HAL_ALLOCATOR_POOL_DEFAULT, params,
+      iree_hal_element_dense_byte_count(element_type), &buffer));
+
+  // Wrap in a buffer view and return.
+  iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      buffer.get(), num_dims, &shape[0], element_type,
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
+      &result_buffer_view));
+
+  auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+  instance->AdvanceReadyFence(transfer_timeline_.get(), signal_alloca_complete);
+  instance->AdvanceDoneFence(transfer_timeline_.get(), signal_alloca_complete);
+  *out_buffer = instance;
+
+  // Degenerate case ignores the data so we can just return:
+  *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+  return iree_ok_status();
+}
+
 iree_status_t DeviceInstance::HostBufferToDevice(
     const void* data, PJRT_Buffer_Type type, const int64_t* dims,
     size_t num_dims, const int64_t* byte_strides, size_t num_byte_strides,
@@ -720,7 +874,33 @@
   iree_device_size_t element_type_byte_size =
       iree_hal_element_dense_byte_count(element_type);
 
-  // Handle strided layouts and shape.
+  // Check for special cases (splat and zero-dim):
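+  // Splatting requires a 1, 2, or 4 byte element so the value can serve as a
+  // fill pattern, and every dim must be unit or have a zero byte stride.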
+  bool is_splat = element_type_byte_size == 1 || element_type_byte_size == 2 ||
+                  element_type_byte_size == 4;
+  bool has_zero_dim = false;
+  iree_device_size_t byte_length = element_type_byte_size;
+
+  for (int i = 0; i < num_byte_strides; ++i) {
+    is_splat &= (dims[i] == 1 || byte_strides[i] == 0);
+    has_zero_dim |= (dims[i] == 0);
+    byte_length *= dims[i];
+  }
+
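+  // Clamp to at least one element so an empty shape cannot yield a zero-byte
+  // length (the zero-dim case below returns before this value is used).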
+  byte_length = std::max(element_type_byte_size, byte_length);
+
+  // If we encounter the zero-dim case, no transfer is required:
+  if (has_zero_dim) {
+    return HostBufferToDeviceZeroDim(
+        type, dims, num_dims, out_done_with_host_buffer_event, out_buffer);
+  }
+
+  // If we encounter the splat case, we can perform a fill instead:
+  if (is_splat) {
+    return HostBufferToDeviceSplat(data, type, dims, num_dims,
+                                   out_done_with_host_buffer_event, out_buffer);
+  }
+
+  // Handle strided layouts and shape:
   std::vector<int64_t> perms(num_dims);
   std::array<iree_hal_dim_t, 9> input_shape;
   std::array<iree_hal_dim_t, 9> output_shape;
@@ -730,43 +910,20 @@
                             (int)input_shape.size(), (int)num_dims);
   }
 
-  bool has_zero_length = false;
   for (int i = 0, s = num_dims; i < s; ++i) {
-    has_zero_length |= dims[i] == 0;
     output_shape[i] = dims[i];
   }
 
-  if (has_zero_length) {
-    // This is a degenerate case.
-    for (int i = 0; i < num_dims; ++i) {
-      perms[i] = i;
-      input_shape[i] = dims[i];
-    }
-  } else {
-    // Compute the input shape and permutations for the broadcast.
-    iree::pjrt::computeBroadcastArgs(
-        num_dims, element_type_byte_size, byte_strides, dims,
-        reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
-  }
+  // Compute the input shape and permutations for the broadcast.
+  iree::pjrt::computeBroadcastArgs(
+      num_dims, element_type_byte_size, byte_strides, dims,
+      reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
 
-  // Splatting requires length 1, 2, or 4
-  bool is_splat = element_type_byte_size == 1 || element_type_byte_size == 2 ||
-                  element_type_byte_size == 4;
   bool is_dense_row_major = true;
   for (int i = 0, s = num_dims; i < s; ++i) {
     is_dense_row_major &= (input_shape[i] == dims[i]) && (perms[i] == i);
-    is_splat &= (byte_strides[i] == 0) || (dims[i] == 1);
   }
 
-  const bool needs_staging_buffer = !is_splat && !has_zero_length;
-
-  iree_device_size_t byte_length = element_type_byte_size;
-  for (size_t i = 0; i < num_dims; ++i) {
-    byte_length *= dims[i];
-  }
-
-  byte_length = std::max(element_type_byte_size, byte_length);
-
   std::vector<int8_t> transposed_data;
   if (!is_dense_row_major) {
     client().logger().debug(
@@ -774,24 +931,7 @@
         "doing the transpose on device. See: "
         "https://github.com/openxla/openxla-pjrt-plugin/issues/201");
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "unable to create transpose plan");
-    // std::vector<int64_t> input_strides(num_dims);
-    // std::vector<int64_t> input_dims(num_dims);
-    // for (size_t i = 0; i < num_dims; i++) {
-    //   input_strides[perms[i]] = byte_strides[i];
-    //   input_dims[i] = input_shape[i];
-    // }
-    // transposed_data.resize(byte_length);
-    // // TODO: use caching to improve performance of plan creation
-    // auto transpose =
-    //     xla::TransposePlan::Create(element_type_byte_size, input_dims, perms,
-    //                                xla::TransposePlan::Striding{input_strides});
-    // if (!transpose.ok()) {
-    //   return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-    //                           "unable to create transpose plan");
-    // }
-    // transpose.value()->Execute(data, transposed_data.data());
-    // data = transposed_data.data();
+                            "transposition is not currently supported");
   }
 
   iree::vm::ref<iree_hal_buffer_t> buffer;
@@ -807,17 +947,14 @@
                               PJRT_HostBufferSemantics_kImmutableOnlyDuringCall;
   bool caller_data_done = false;
 
-  // We only need a staging buffer if we cannot splat the data.
   iree::vm::ref<iree_hal_buffer_t> host_staging_buffer;
-  if (needs_staging_buffer) {
-    IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
-        iree_make_const_byte_span(data, byte_length), require_snapshot_now,
-        &caller_data_done, &host_staging_buffer));
-    if (!caller_data_done) {
-      return iree_make_status(
-          IREE_STATUS_UNIMPLEMENTED,
-          "deferred snapshot of host data not yet implemented");
-    }
+  IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
+      iree_make_const_byte_span(data, byte_length), require_snapshot_now,
+      &caller_data_done, &host_staging_buffer));
+  if (!caller_data_done) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "deferred snapshot of host data not yet implemented");
   }
 
   // Allocate on stream. We serialize across 3 timepoints:
@@ -846,47 +983,26 @@
 
   // Queue up the transfer command.
   iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
-  if (is_splat) {
-    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
-        device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
-        IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
-        /*binding_capacity=*/0, &transfer_cb));
-    IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
-    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
-        transfer_cb.get(), buffer.get(), /*target_offset=*/0,
-        /*target_size=*/byte_length, data, element_type_byte_size));
-    IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
-  } else if (!has_zero_length) {
-    iree_hal_transfer_command_t transfer_command;
-    memset(&transfer_command, 0, sizeof(transfer_command));
-    transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
-    transfer_command.copy.source_buffer = host_staging_buffer.get(),
-    transfer_command.copy.source_offset = 0;
-    transfer_command.copy.target_buffer = buffer.get();
-    transfer_command.copy.target_offset = 0;
-    transfer_command.copy.length = byte_length;
-    IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
-        device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
-        IREE_HAL_QUEUE_AFFINITY_ANY,
-        /*transfer_count=*/1, &transfer_command, &transfer_cb));
-  }
+  iree_hal_transfer_command_t transfer_command;
+  memset(&transfer_command, 0, sizeof(transfer_command));
+  transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
+  transfer_command.copy.source_buffer = host_staging_buffer.get();
+  transfer_command.copy.source_offset = 0;
+  transfer_command.copy.target_buffer = buffer.get();
+  transfer_command.copy.target_offset = 0;
+  transfer_command.copy.length = byte_length;
+  IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
+      device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*transfer_count=*/1, &transfer_command, &transfer_cb));
 
-  if (has_zero_length) {
-    IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_barrier(
-        device(), IREE_HAL_QUEUE_AFFINITY_ANY,
-        /*wait_semaphore_list=*/
-        {1, &transfer_timeline_, &signal_alloca_complete},
-        /*signal_semaphore_list=*/
-        {1, &transfer_timeline_, &signal_copy_complete}));
-  } else {
-    IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
-        device(), IREE_HAL_QUEUE_AFFINITY_ANY,
-        /*wait_semaphore_list=*/
-        {1, &transfer_timeline_, &signal_alloca_complete},
-        /*signal_semaphore_list=*/
-        {1, &transfer_timeline_, &signal_copy_complete},
-        /*command_buffer_count=*/1, &transfer_cb));
-  }
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_copy_complete},
+      /*command_buffer_count=*/1, &transfer_cb));
 
   // Wrap in a buffer view and return.
   iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
@@ -895,11 +1011,10 @@
       IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
       &result_buffer_view));
 
-  *out_buffer = new BufferInstance(*this, std::move(result_buffer_view));
-  (*out_buffer)
-      ->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
-  (*out_buffer)
-      ->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+  auto instance = new BufferInstance(*this, std::move(result_buffer_view));
+  instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
+  instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+  *out_buffer = instance;
 
   // We snapshotted the caller data when acquiring the host staging buffer,
   // so we won't be touching it again.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index dc405f1..4a84e89 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -214,6 +214,16 @@
   bool is_addressable() { return true; }
   int local_hardware_id() { return -1; }
 
+  iree_status_t HostBufferToDeviceZeroDim(
+      PJRT_Buffer_Type type, const int64_t* dims, size_t num_dims,
+      EventInstance** out_done_with_host_buffer_event,
+      BufferInstance** out_buffer);
+
+  iree_status_t HostBufferToDeviceSplat(
+      const void* data, PJRT_Buffer_Type type, const int64_t* dims,
+      size_t num_dims, EventInstance** out_done_with_host_buffer_event,
+      BufferInstance** out_buffer);
+
   // Copies a host buffer to the device.
   // See PJRT_Client_BufferFromHostBuffer
   iree_status_t HostBufferToDevice(