Implement transposition / broadcast on host transfers (#15300)
It is possible to use a specialized program to perform the transposition
and broadcast on the device during host-to-device transfers. This
minimizes the amount of data transferred and leverages the accelerator
for performing the transposition while decreasing the required host
resources.
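
For illustration only, here is a minimal standalone C++ sketch (not part
of this patch) of the program generation done below in
DeviceInstance::TransposeBroadcastDeviceBuffer. The shapes are assumed
example values: a column-major 2x3 f32 host buffer whose strided device
view is 3x2 with permutation [1, 0]. The helper names TensorType and
I64DenseAttr are illustrative and do not appear in the patch.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Builds an MLIR tensor type string such as "tensor<3x2xf32>".
static std::string TensorType(const std::vector<int64_t>& dims,
                              const std::string& ty) {
  std::stringstream ss;
  ss << "tensor<";
  for (int64_t d : dims) ss << d << "x";
  ss << ty << ">";
  return ss.str();
}

// Builds a dense i64 attribute such as
// "{permutation = dense<[1, 0]> : tensor<2xi64>}".
static std::string I64DenseAttr(const std::string& name,
                                const std::vector<int64_t>& vals) {
  std::stringstream ss;
  ss << "{" << name << " = dense<[" << vals[0];
  for (size_t i = 1; i < vals.size(); ++i) ss << ", " << vals[i];
  ss << "]> : tensor<" << vals.size() << "xi64>}";
  return ss.str();
}

int main() {
  // Example values; the plugin derives these from the PJRT byte_strides
  // via iree::pjrt::computeBroadcastArgs.
  std::vector<int64_t> input_dims = {3, 2};   // strided view staged on device
  std::vector<int64_t> output_dims = {2, 3};  // logical PJRT dims
  std::vector<int64_t> perms = {1, 0};        // transpose permutation
  std::string ty = "f32";

  std::vector<int64_t> transpose_dims(perms.size());
  std::vector<int64_t> bcast_dims(perms.size());
  for (size_t i = 0; i < perms.size(); ++i) {
    transpose_dims[i] = input_dims[perms[i]];
    bcast_dims[i] = static_cast<int64_t>(i);
  }

  std::string in_ty = TensorType(input_dims, ty);
  std::string tr_ty = TensorType(transpose_dims, ty);
  std::string out_ty = TensorType(output_dims, ty);

  // Equivalent to the snprintf-expanded program_literal below.
  std::stringstream program;
  program << "func.func @main(%arg0 : " << in_ty << ") -> (" << out_ty
          << ") {\n"
          << "  %0 = \"stablehlo.transpose\"(%arg0) "
          << I64DenseAttr("permutation", perms) << " : (" << in_ty << ") -> "
          << tr_ty << "\n"
          << "  %1 = \"stablehlo.broadcast_in_dim\"(%0) "
          << I64DenseAttr("broadcast_dimensions", bcast_dims) << " : ("
          << tr_ty << ") -> " << out_ty << "\n"
          << "  return %1 : " << out_ty << "\n}";
  std::cout << program.str() << std::endl;
  return 0;
}

The printed module transposes tensor<3x2xf32> back to the logical
tensor<2x3xf32>; the broadcast_in_dim is an identity for this example
but covers the broadcast case handled by the same code path.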
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 6ac6f47..714b2a4 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -7,6 +7,8 @@
#include "iree_pjrt/common/api_impl.h"
#include <optional>
+#include <sstream>
+#include <utility>
#include "iree/hal/api.h"
#include "iree_pjrt/common/iree_helpers.h"
@@ -20,6 +22,9 @@
const std::string_view kMlirFormat = "mlir";
+// We hardcode the maximum number of dimensions to avoid mallocs.
+constexpr int64_t kMaxDims = 9;
+
// Some general conversion functions for managing around some API layering
// that is in flight. It is expected that most of this goes away over time.
namespace PJRTApiConverter {
@@ -88,6 +93,70 @@
}
}
+iree_status_t MapElementTypeToMlirType(iree_hal_element_type_t element_type,
+ char const** ty) {
+ switch (element_type) {
+ case IREE_HAL_ELEMENT_TYPE_NONE:
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
+ case IREE_HAL_ELEMENT_TYPE_BOOL_8:
+ *ty = "i1";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_SINT_4:
+ *ty = "si4";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_SINT_8:
+ *ty = "si8";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_SINT_16:
+ *ty = "si16";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_SINT_32:
+ *ty = "si32";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_SINT_64:
+ *ty = "si64";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_UINT_4:
+ *ty = "ui4";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_UINT_8:
+ *ty = "ui8";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_UINT_16:
+ *ty = "ui16";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_UINT_32:
+ *ty = "ui32";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_UINT_64:
+ *ty = "ui64";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+ *ty = "f16";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+ *ty = "f32";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+ *ty = "f64";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ *ty = "bf16";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
+ *ty = "complex<f32>";
+ return iree_ok_status();
+ case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
+ *ty = "complex<f64>";
+ return iree_ok_status();
+ default:
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "conversion from unknown iree hal element type %d",
+ (int)element_type);
+ }
+}
+
} // namespace
} // namespace PJRTApiConverter
@@ -717,7 +786,7 @@
iree_hal_element_dense_byte_count(element_type);
// Handle strided layouts and shape.
- std::array<iree_hal_dim_t, 9> shape;
+ std::array<iree_hal_dim_t, kMaxDims> shape;
if (num_dims > shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
@@ -804,7 +873,7 @@
IREE_RETURN_IF_ERROR(
PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
- std::array<iree_hal_dim_t, 9> shape;
+ std::array<iree_hal_dim_t, kMaxDims> shape;
if (num_dims > shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
@@ -852,6 +921,131 @@
return iree_ok_status();
}
+iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer(
+ BufferInstance* buffer, iree_hal_element_type_t element_type,
+ const iree_hal_dim_t* input_dims, const iree_hal_dim_t* output_dims,
+ const int64_t* perms, size_t num_dims,
+ PJRT_HostBufferSemantics host_buffer_semantics,
+ EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer) {
+ if (num_dims > kMaxDims) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "number of dimensions exceeded max supported");
+ }
+
+ std::array<iree_hal_dim_t, kMaxDims> transpose_dims;
+ for (int i = 0; i < num_dims; ++i) {
+ transpose_dims[i] = input_dims[perms[i]];
+ }
+
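+ // Helpers to assemble the textual MLIR program: the tensor types, the
+ // transpose permutation attribute, and the identity broadcast_dimensions
+ // attribute.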
+ auto typeBuilder = [](const iree_hal_dim_t* dims, int64_t num_dims,
+ const char* ty) {
+ std::stringstream ss;
+ ss << "tensor<";
+ for (int i = 0; i < num_dims; ++i) {
+ ss << dims[i] << "x";
+ }
+
+ ss << ty << ">";
+ return ss.str();
+ };
+
+ auto arrayBuilder = [](const int64_t* vals, int64_t sz) {
+ std::stringstream ss;
+ ss << " {permutation = dense<[" << vals[0];
+ for (int i = 1; i < sz; ++i) ss << ", " << vals[i];
+ ss << "]> : tensor<" << sz << "xi64>}";
+ return ss.str();
+ };
+
+ auto broadcastBuilder = [](int64_t sz) {
+ std::stringstream ss;
+ ss << "{broadcast_dimensions = dense<[0";
+ for (int i = 1; i < sz; ++i) ss << ", " << i;
+ ss << "]> : tensor<" << sz << "xi64>}";
+ return ss.str();
+ };
+
+ const char* mlir_ty;
+ IREE_RETURN_IF_ERROR(
+ PJRTApiConverter::MapElementTypeToMlirType(element_type, &mlir_ty));
+
+ auto input_ty = typeBuilder(input_dims, num_dims, mlir_ty);
+ auto transpose_ty = typeBuilder(transpose_dims.data(), num_dims, mlir_ty);
+ auto output_ty = typeBuilder(output_dims, num_dims, mlir_ty);
+ auto perms_str = arrayBuilder(perms, num_dims);
+ auto broadcast_str = broadcastBuilder(num_dims);
+
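+ // Positional snprintf arguments: %1$s/%2$s/%3$s are the input, transposed,
+ // and output tensor types and %4$s/%5$s are the permutation and
+ // broadcast_dimensions attributes built above; %% escapes a literal '%'
+ // for the MLIR SSA value names.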
+ const char* program_literal = R"(func.func @main(%%arg0 : %1$s) -> (%3$s) {
+ %%0 = "stablehlo.transpose"(%%arg0) %4$s : (%1$s) -> %2$s
+ %%1 = "stablehlo.broadcast_in_dim"(%%0) %5$s : (%2$s) -> %3$s
+ return %%1 : %3$s
+ })";
+ char transpose_program[512];
+ size_t program_len = std::snprintf(
+ transpose_program, sizeof(transpose_program), program_literal,
+ input_ty.c_str(), transpose_ty.c_str(), output_ty.c_str(),
+ perms_str.c_str(), broadcast_str.c_str());
+ if (program_len >= sizeof(transpose_program)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "program size exceeded limit");
+ }
+
+ // Create an on-stack program:
+ PJRT_Program program;
+ program.code = transpose_program;
+ program.code_size = program_len;
+ program.format = kMlirFormat.data();
+ program.format_size = kMlirFormat.size();
+
+ // Compile program and check for errors:
+ LoadedExecutableInstance* executable;
+ auto* error = this->client().Compile(&program, &executable);
+ if (error) {
+ auto errinst = ErrorInstance::FromError(error);
+ auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "transposition program failed to build");
+ delete errinst;
+ return ret;
+ }
+
+ PJRT_Buffer* input = *buffer;
+ PJRT_Buffer** input_list = &input;
+
+ PJRT_Buffer* output;
+ PJRT_Buffer** output_list = &output;
+ PJRT_Event* event;
+
+ // Build the execution arguments for transposing the loaded memory:
+ PJRT_LoadedExecutable_Execute_Args execute_args;
+ memset(&execute_args, 0, sizeof(execute_args));
+
+ PJRT_ExecuteOptions execute_options;
+ memset(&execute_options, 0, sizeof(execute_options));
+ execute_args.executable = *executable;
+ execute_args.options = &execute_options;
+ execute_args.argument_lists = &input_list;
+ execute_args.output_lists = &output_list;
+ execute_args.num_devices = 1;
+ execute_args.num_args = 1;
+ execute_args.device_complete_events = &event;
+
+ // We do not support specifying the device yet.
+ execute_args.execute_device = nullptr;
+
+ auto err = executable->BatchExecute(&execute_args);
+ delete executable;
+
+ if (err) {
+ return err;
+ }
+
+ *out_buffer = BufferInstance::Unwrap(output);
+ *out_done_with_host_buffer_event = EventInstance::Unwrap(event);
+
+ return iree_ok_status();
+}
+
iree_status_t DeviceInstance::HostBufferToDevice(
const void* data, PJRT_Buffer_Type type, const int64_t* dims,
size_t num_dims, const int64_t* byte_strides, size_t num_byte_strides,
@@ -902,38 +1096,30 @@
// Handle strided layouts and shape:
std::vector<int64_t> perms(num_dims);
- std::array<iree_hal_dim_t, 9> input_shape;
- std::array<iree_hal_dim_t, 9> output_shape;
+ std::array<iree_hal_dim_t, kMaxDims> input_shape;
+ std::array<iree_hal_dim_t, kMaxDims> transpose_shape;
+ std::array<iree_hal_dim_t, kMaxDims> output_shape;
if (num_dims > input_shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
(int)input_shape.size(), (int)num_dims);
}
- for (int i = 0, s = num_dims; i < s; ++i) {
- output_shape[i] = dims[i];
- }
-
// Compute the input shape and permutations for the broadcast.
iree::pjrt::computeBroadcastArgs(
num_dims, element_type_byte_size, byte_strides, dims,
reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
+ for (int i = 0, s = num_dims; i < s; ++i) {
+ transpose_shape[i] = input_shape[perms[i]];
+ output_shape[i] = dims[i];
+ }
+
bool is_dense_row_major = true;
for (int i = 0, s = num_dims; i < s; ++i) {
is_dense_row_major &= (input_shape[i] == dims[i]) && (perms[i] == i);
}
- std::vector<int8_t> transposed_data;
- if (!is_dense_row_major) {
- client().logger().debug(
- "Performing transpose on host. This uses 2x memory and is slower than "
- "doing the transpose on device. See: "
- "https://github.com/openxla/openxla-pjrt-plugin/issues/201");
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "transposition is not currently supported");
- }
-
iree::vm::ref<iree_hal_buffer_t> buffer;
// There are multiple ways to implement zero-copy/staged transfers and each
// implementation will have different performance cliffs associated with
@@ -1007,20 +1193,30 @@
// Wrap in a buffer view and return.
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
- buffer.get(), num_dims, &output_shape[0], element_type,
+ buffer.get(), num_dims, &input_shape[0], element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
&result_buffer_view));
auto instance = new BufferInstance(*this, std::move(result_buffer_view));
instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
- *out_buffer = instance;
- // We snapshotted the caller data when acquiring the host staging buffer,
- // so we won't be touching it again.
- *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
- return iree_ok_status();
+ if (is_dense_row_major) {
+ *out_buffer = instance;
+ // We snapshotted the caller data when acquiring the host staging buffer,
+ // so we won't be touching it again.
+ *out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
+
+ return iree_ok_status();
+ }
+
+ auto err = TransposeBroadcastDeviceBuffer(
+ instance, element_type, input_shape.data(), output_shape.data(),
+ perms.data(), num_dims, host_buffer_semantics,
+ out_done_with_host_buffer_event, out_buffer);
+ delete instance;
+ return err;
}
iree_status_t DeviceInstance::AcquireHostStagingBuffer(
@@ -2064,7 +2260,7 @@
api->struct_size = PJRT_Api_STRUCT_SIZE;
api->extension_start = nullptr;
api->pjrt_api_version.major_version = PJRT_API_MAJOR;
- api->pjrt_api_version.minor_version = PJRT_API_MINOR;
+ api->pjrt_api_version.minor_version = PJRT_API_MINOR - 1;
// This is a bare implementation throwing UNDEFINED errors. This way new
// functions will not segmentation fault on invocation.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index 4a84e89..ecbea06 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -224,6 +224,14 @@
size_t num_dims, EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer);
+ iree_status_t TransposeBroadcastDeviceBuffer(
+ BufferInstance* buffer, iree_hal_element_type_t type,
+ const iree_hal_dim_t* input_dims, const iree_hal_dim_t* output_dims,
+ const int64_t* perms, size_t num_dims,
+ PJRT_HostBufferSemantics host_buffer_semantics,
+ EventInstance** out_done_with_host_buffer_event,
+ BufferInstance** out_buffer);
+
// Copies a host buffer to the device.
// See PJRT_Client_BufferFromHostBuffer
iree_status_t HostBufferToDevice(