Implement transposition / broadcast on host transfers (#15300)

It is possible to use a specialized program for device transfers. This
can minimize the amount of data transferred and leverage the accelerator
for peforming the transposition while decreasing the required host
resources.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 6ac6f47..714b2a4 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -7,6 +7,8 @@
 #include "iree_pjrt/common/api_impl.h"
 
 #include <optional>
+#include <sstream>
+#include <utility>
 
 #include "iree/hal/api.h"
 #include "iree_pjrt/common/iree_helpers.h"
@@ -20,6 +22,9 @@
 
 const std::string_view kMlirFormat = "mlir";
 
+// We hardcode the maximum number of dimensions to avoid mallocs.
+constexpr int64_t kMaxDims = 9;
+
 // Some general conversion functions for managing around some API layering
 // that is in flight. It is expected that most of this goes away over time.
 namespace PJRTApiConverter {
@@ -88,6 +93,70 @@
   }
 }
 
+iree_status_t MapElementTypeToMlirType(iree_hal_element_type_t element_type,
+                                       char const** ty) {
+  switch (element_type) {
+    case PJRT_Buffer_Type_INVALID:
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
+    case IREE_HAL_ELEMENT_TYPE_BOOL_8:
+      *ty = "i1";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_SINT_4:
+      *ty = "si4";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_SINT_8:
+      *ty = "si8";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_SINT_16:
+      *ty = "si16";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_SINT_32:
+      *ty = "si32";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_SINT_64:
+      *ty = "si64";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_UINT_4:
+      *ty = "ui4";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_UINT_8:
+      *ty = "ui8";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_UINT_16:
+      *ty = "ui16";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_UINT_32:
+      *ty = "ui32";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_UINT_64:
+      *ty = "ui64";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+      *ty = "f16";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+      *ty = "f32";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+      *ty = "f64";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+      *ty = "bf16";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
+      *ty = "complex<f32>";
+      return iree_ok_status();
+    case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
+      *ty = "complex<f64>";
+      return iree_ok_status();
+    default:
+      return iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "conversion from unknown iree hal element type %d",
+          (int)element_type);
+  }
+}
+
 }  // namespace
 }  // namespace PJRTApiConverter
 
@@ -717,7 +786,7 @@
       iree_hal_element_dense_byte_count(element_type);
 
   // Handle strided layouts and shape.
-  std::array<iree_hal_dim_t, 9> shape;
+  std::array<iree_hal_dim_t, kMaxDims> shape;
   if (num_dims > shape.size()) {
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                             "only supports up to %d dims but got %d",
@@ -804,7 +873,7 @@
   IREE_RETURN_IF_ERROR(
       PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
 
-  std::array<iree_hal_dim_t, 9> shape;
+  std::array<iree_hal_dim_t, kMaxDims> shape;
   if (num_dims > shape.size()) {
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                             "only supports up to %d dims but got %d",
@@ -852,6 +921,131 @@
   return iree_ok_status();
 }
 
+iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer(
+    BufferInstance* buffer, iree_hal_element_type_t element_type,
+    const iree_hal_dim_t* input_dims, const iree_hal_dim_t* output_dims,
+    const int64_t* perms, size_t num_dims,
+    PJRT_HostBufferSemantics host_buffer_semantics,
+    EventInstance** out_done_with_host_buffer_event,
+    BufferInstance** out_buffer) {
+  if (num_dims > kMaxDims) {
+    auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "number of dimensions exceeded max supported");
+  }
+
+  std::array<iree_hal_dim_t, kMaxDims> transpose_dims;
+  for (int i = 0; i < num_dims; ++i) {
+    transpose_dims[i] = input_dims[perms[i]];
+  }
+
+  auto typeBuilder = [](const iree_hal_dim_t* dims, int64_t num_dims,
+                        const char* ty) {
+    std::stringstream ss;
+    ss << "tensor<";
+    for (int i = 0; i < num_dims; ++i) {
+      ss << dims[i] << "x";
+    }
+
+    ss << ty << ">";
+    return ss.str();
+  };
+
+  auto arrayBuilder = [](const int64_t* vals, int64_t sz) {
+    std::stringstream ss;
+    ss << " {permutation = dense<[" << vals[0];
+    for (int i = 1; i < sz; ++i) ss << ", " << vals[i];
+    ss << "]> : tensor<" << sz << "xi64>}";
+    return ss.str();
+  };
+
+  auto broadcastBuilder = [](int64_t sz) {
+    std::stringstream ss;
+    ss << "{broadcast_dimensions = dense<[0";
+    for (int i = 1; i < sz; ++i) ss << ", " << i;
+    ss << "]> : tensor<" << sz << "xi64>}";
+    return ss.str();
+  };
+
+  const char* mlir_ty;
+  IREE_RETURN_IF_ERROR(
+      PJRTApiConverter::MapElementTypeToMlirType(element_type, &mlir_ty));
+
+  auto input_ty = typeBuilder(input_dims, num_dims, mlir_ty);
+  auto transpose_ty = typeBuilder(transpose_dims.data(), num_dims, mlir_ty);
+  auto output_ty = typeBuilder(output_dims, num_dims, mlir_ty);
+  auto perms_str = arrayBuilder(perms, num_dims);
+  auto broadcast_str = broadcastBuilder(num_dims);
+
+  const char* program_literal = R"(func.func @main(%%arg0 : %1$s) -> (%3$s) {
+   %%0 = "stablehlo.transpose"(%%arg0) %4$s : (%1$s) -> %2$s
+   %%1 = "stablehlo.broadcast_in_dim"(%%0) %5$s : (%2$s) -> %3$s
+   return %%1 : %3$s
+  })";
+  char transpose_program[512];
+  size_t program_len = std::snprintf(
+      transpose_program, sizeof(transpose_program), program_literal,
+      input_ty.c_str(), transpose_ty.c_str(), output_ty.c_str(),
+      perms_str.c_str(), broadcast_str.c_str());
+  if (program_len > sizeof(transpose_program)) {
+    auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "program size exceeded limit");
+  }
+
+  // Create an on stack program:
+  PJRT_Program program;
+  program.code = transpose_program;
+  program.code_size = program_len;
+  program.format = kMlirFormat.data();
+  program.format_size = kMlirFormat.size();
+
+  // Compile program and check for errors:
+  LoadedExecutableInstance* executable;
+  auto* error = this->client().Compile(&program, &executable);
+  if (error) {
+    auto errinst = ErrorInstance::FromError(error);
+    auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "transposition program failed to build");
+    delete errinst;
+    return ret;
+  }
+
+  PJRT_Buffer* input = *buffer;
+  PJRT_Buffer** input_list = &input;
+
+  PJRT_Buffer* output;
+  PJRT_Buffer** output_list = &output;
+  PJRT_Event* event;
+
+  // Build the execution arguments for transposing the loaded memory:
+  PJRT_LoadedExecutable_Execute_Args execute_args;
+  memset(&execute_args, 0, sizeof(execute_args));
+
+  PJRT_ExecuteOptions execute_options;
+  memset(&execute_options, 0, sizeof(execute_options));
+  execute_args.executable = *executable;
+  execute_args.options = &execute_options;
+  execute_args.argument_lists = &input_list;
+  execute_args.output_lists = &output_list;
+  execute_args.num_devices = 1;
+  execute_args.num_args = 1;
+  execute_args.device_complete_events = &event;
+
+  // We do no support specifying the device yet.
+  execute_args.execute_device = nullptr;
+
+  auto err = executable->BatchExecute(&execute_args);
+  delete executable;
+
+  if (err) {
+    return err;
+  }
+
+  *out_buffer = BufferInstance::Unwrap(output);
+  *out_done_with_host_buffer_event = EventInstance::Unwrap(event);
+
+  return iree_ok_status();
+}
+
 iree_status_t DeviceInstance::HostBufferToDevice(
     const void* data, PJRT_Buffer_Type type, const int64_t* dims,
     size_t num_dims, const int64_t* byte_strides, size_t num_byte_strides,
@@ -902,38 +1096,30 @@
 
   // Handle strided layouts and shape:
   std::vector<int64_t> perms(num_dims);
-  std::array<iree_hal_dim_t, 9> input_shape;
-  std::array<iree_hal_dim_t, 9> output_shape;
+  std::array<iree_hal_dim_t, kMaxDims> input_shape;
+  std::array<iree_hal_dim_t, kMaxDims> transpose_shape;
+  std::array<iree_hal_dim_t, kMaxDims> output_shape;
   if (num_dims > input_shape.size()) {
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                             "only supports up to %d dims but got %d",
                             (int)input_shape.size(), (int)num_dims);
   }
 
-  for (int i = 0, s = num_dims; i < s; ++i) {
-    output_shape[i] = dims[i];
-  }
-
   // Compute the input shape and permutations for the broadcast.
   iree::pjrt::computeBroadcastArgs(
       num_dims, element_type_byte_size, byte_strides, dims,
       reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
 
+  for (int i = 0, s = num_dims; i < s; ++i) {
+    transpose_shape[i] = input_shape[perms[i]];
+    output_shape[i] = dims[i];
+  }
+
   bool is_dense_row_major = true;
   for (int i = 0, s = num_dims; i < s; ++i) {
     is_dense_row_major &= (input_shape[i] == dims[i]) && (perms[i] == i);
   }
 
-  std::vector<int8_t> transposed_data;
-  if (!is_dense_row_major) {
-    client().logger().debug(
-        "Performing transpose on host. This uses 2x memory and is slower than "
-        "doing the transpose on device. See: "
-        "https://github.com/openxla/openxla-pjrt-plugin/issues/201");
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "transposition is not currently supported");
-  }
-
   iree::vm::ref<iree_hal_buffer_t> buffer;
   // There are multiple ways to implement zero-copy/staged transfers and each
   // implementation will have different performance cliffs associated with
@@ -1007,20 +1193,30 @@
   // Wrap in a buffer view and return.
   iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
   IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
-      buffer.get(), num_dims, &output_shape[0], element_type,
+      buffer.get(), num_dims, &input_shape[0], element_type,
       IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
       &result_buffer_view));
 
   auto instance = new BufferInstance(*this, std::move(result_buffer_view));
   instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
   instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
-  *out_buffer = instance;
 
-  // We snapshotted the caller data when acquiring the host staging buffer,
-  // so we won't be touching it again.
-  *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+  if (is_dense_row_major) {
+    *out_buffer = instance;
 
-  return iree_ok_status();
+    // We snapshotted the caller data when acquiring the host staging buffer,
+    // so we won't be touching it again.
+    *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+    return iree_ok_status();
+  }
+
+  auto err = TransposeBroadcastDeviceBuffer(
+      instance, element_type, input_shape.data(), output_shape.data(),
+      perms.data(), num_dims, host_buffer_semantics,
+      out_done_with_host_buffer_event, out_buffer);
+  delete instance;
+  return err;
 }
 
 iree_status_t DeviceInstance::AcquireHostStagingBuffer(
@@ -2064,7 +2260,7 @@
   api->struct_size = PJRT_Api_STRUCT_SIZE;
   api->extension_start = nullptr;
   api->pjrt_api_version.major_version = PJRT_API_MAJOR;
-  api->pjrt_api_version.minor_version = PJRT_API_MINOR;
+  api->pjrt_api_version.minor_version = PJRT_API_MINOR - 1;
 
   // This is a bare implementation throwing UNDEFINED errors. This way new
   // functions will not segmentation fault on invocation.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index 4a84e89..ecbea06 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -224,6 +224,14 @@
       size_t num_dims, EventInstance** out_done_with_host_buffer_event,
       BufferInstance** out_buffer);
 
+  iree_status_t TransposeBroadcastDeviceBuffer(
+      BufferInstance* buffer, iree_hal_element_type_t type,
+      const iree_hal_dim_t* input_dims, const iree_hal_dim_t* output_dims,
+      const int64_t* perms, size_t num_dims,
+      PJRT_HostBufferSemantics host_buffer_semantics,
+      EventInstance** out_done_with_host_buffer_event,
+      BufferInstance** out_buffer);
+
   // Copies a host buffer to the device.
   // See PJRT_Client_BufferFromHostBuffer
   iree_status_t HostBufferToDevice(