// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree_pjrt/common/api_impl.h"
#include <optional>
#include <sstream>
#include <utility>
#include "iree/hal/api.h"
#include "iree_pjrt/common/iree_helpers.h"
#include "iree_pjrt/common/tensor_utils.h"
// TODO: Excise. Uses deep XLA internals.
// #include "xla/pjrt/transpose.h"
using iree::vm::retain_ref;
namespace iree::pjrt {
const std::string_view kMlirFormat = "mlir";
// We hardcode the maximum number of dimensions to avoid mallocs.
constexpr int64_t kMaxDims = 9;
// Some general conversion functions for managing around some API layering
// that is in flight. It is expected that most of this goes away over time.
namespace PJRTApiConverter {
namespace {
iree_status_t MapBufferTypeToElementType(
PJRT_Buffer_Type buffer_type, iree_hal_element_type_t* element_type) {
switch (buffer_type) {
case PJRT_Buffer_Type_INVALID:
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
case PJRT_Buffer_Type_PRED:
*element_type = IREE_HAL_ELEMENT_TYPE_BOOL_8;
return iree_ok_status();
case PJRT_Buffer_Type_S4:
*element_type = IREE_HAL_ELEMENT_TYPE_SINT_4;
return iree_ok_status();
case PJRT_Buffer_Type_S8:
*element_type = IREE_HAL_ELEMENT_TYPE_SINT_8;
return iree_ok_status();
case PJRT_Buffer_Type_S16:
*element_type = IREE_HAL_ELEMENT_TYPE_SINT_16;
return iree_ok_status();
case PJRT_Buffer_Type_S32:
*element_type = IREE_HAL_ELEMENT_TYPE_SINT_32;
return iree_ok_status();
case PJRT_Buffer_Type_S64:
*element_type = IREE_HAL_ELEMENT_TYPE_SINT_64;
return iree_ok_status();
case PJRT_Buffer_Type_U4:
*element_type = IREE_HAL_ELEMENT_TYPE_UINT_4;
return iree_ok_status();
case PJRT_Buffer_Type_U8:
*element_type = IREE_HAL_ELEMENT_TYPE_UINT_8;
return iree_ok_status();
case PJRT_Buffer_Type_U16:
*element_type = IREE_HAL_ELEMENT_TYPE_UINT_16;
return iree_ok_status();
case PJRT_Buffer_Type_U32:
*element_type = IREE_HAL_ELEMENT_TYPE_UINT_32;
return iree_ok_status();
case PJRT_Buffer_Type_U64:
*element_type = IREE_HAL_ELEMENT_TYPE_UINT_64;
return iree_ok_status();
case PJRT_Buffer_Type_F16:
*element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_16;
return iree_ok_status();
case PJRT_Buffer_Type_F32:
*element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
return iree_ok_status();
case PJRT_Buffer_Type_F64:
*element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_64;
return iree_ok_status();
case PJRT_Buffer_Type_BF16:
*element_type = IREE_HAL_ELEMENT_TYPE_BFLOAT_16;
return iree_ok_status();
case PJRT_Buffer_Type_C64:
*element_type = IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64;
return iree_ok_status();
case PJRT_Buffer_Type_C128:
*element_type = IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128;
return iree_ok_status();
default:
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"conversion from unknown buffer type %d",
(int)buffer_type);
}
}
iree_status_t MapElementTypeToMlirType(iree_hal_element_type_t element_type,
char const** ty) {
switch (element_type) {
case IREE_HAL_ELEMENT_TYPE_NONE:
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
case IREE_HAL_ELEMENT_TYPE_BOOL_8:
*ty = "i1";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_SINT_4:
*ty = "si4";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_SINT_8:
*ty = "si8";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_SINT_16:
*ty = "si16";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_SINT_32:
*ty = "si32";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_SINT_64:
*ty = "si64";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_UINT_4:
*ty = "ui4";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_UINT_8:
*ty = "ui8";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_UINT_16:
*ty = "ui16";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_UINT_32:
*ty = "ui32";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_UINT_64:
*ty = "ui64";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*ty = "f16";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
*ty = "f32";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
*ty = "f64";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
*ty = "bf16";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
*ty = "complex<f32>";
return iree_ok_status();
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
*ty = "complex<f64>";
return iree_ok_status();
default:
return iree_make_status(
IREE_STATUS_UNIMPLEMENTED,
"conversion from unknown iree hal element type %d",
(int)element_type);
}
}
} // namespace
} // namespace PJRTApiConverter
//===----------------------------------------------------------------------===//
// Error
//===----------------------------------------------------------------------===//
void ErrorInstance::BindApi(PJRT_Api* api) {
api->PJRT_Error_Destroy = +[](PJRT_Error_Destroy_Args* args) {
if (!args->error) return;
delete ErrorInstance::FromError(args->error);
};
api->PJRT_Error_Message = +[](PJRT_Error_Message_Args* args) {
auto* error = ErrorInstance::FromError(args->error);
if (!error) {
args->message = "OK";
args->message_size = 2;
return;
}
const std::string& message = error->message();
args->message = message.data();
args->message_size = message.size();
};
api->PJRT_Error_GetCode = +[](PJRT_Error_GetCode_Args* args) -> PJRT_Error* {
auto* error = ErrorInstance::FromError(args->error);
iree_status_code_t status_code = iree_status_code(error->status());
switch (status_code) {
case IREE_STATUS_CANCELLED:
args->code = PJRT_Error_Code_CANCELLED;
break;
case IREE_STATUS_UNKNOWN:
args->code = PJRT_Error_Code_UNKNOWN;
break;
case IREE_STATUS_INVALID_ARGUMENT:
args->code = PJRT_Error_Code_INVALID_ARGUMENT;
break;
case IREE_STATUS_DEADLINE_EXCEEDED:
args->code = PJRT_Error_Code_DEADLINE_EXCEEDED;
break;
case IREE_STATUS_NOT_FOUND:
args->code = PJRT_Error_Code_NOT_FOUND;
break;
case IREE_STATUS_ALREADY_EXISTS:
args->code = PJRT_Error_Code_ALREADY_EXISTS;
break;
case IREE_STATUS_PERMISSION_DENIED:
args->code = PJRT_Error_Code_PERMISSION_DENIED;
break;
case IREE_STATUS_RESOURCE_EXHAUSTED:
args->code = PJRT_Error_Code_RESOURCE_EXHAUSTED;
break;
case IREE_STATUS_FAILED_PRECONDITION:
args->code = PJRT_Error_Code_FAILED_PRECONDITION;
break;
case IREE_STATUS_ABORTED:
args->code = PJRT_Error_Code_ABORTED;
break;
case IREE_STATUS_OUT_OF_RANGE:
args->code = PJRT_Error_Code_OUT_OF_RANGE;
break;
case IREE_STATUS_UNIMPLEMENTED:
args->code = PJRT_Error_Code_UNIMPLEMENTED;
break;
case IREE_STATUS_INTERNAL:
args->code = PJRT_Error_Code_INTERNAL;
break;
case IREE_STATUS_UNAVAILABLE:
args->code = PJRT_Error_Code_UNAVAILABLE;
break;
case IREE_STATUS_DATA_LOSS:
args->code = PJRT_Error_Code_DATA_LOSS;
break;
case IREE_STATUS_UNAUTHENTICATED:
args->code = PJRT_Error_Code_UNAUTHENTICATED;
break;
case IREE_STATUS_DEFERRED:
args->code = PJRT_Error_Code_UNKNOWN; // No mapping
break;
case IREE_STATUS_INCOMPATIBLE:
args->code = PJRT_Error_Code_NOT_FOUND;
break;
default:
// Should not happen.
args->code = PJRT_Error_Code_UNKNOWN;
}
return nullptr;
};
}
const std::string& ErrorInstance::message() const {
if (cached_message_.empty()) {
std::string buffer;
iree_host_size_t actual_len;
buffer.resize(1024); // Initial guess; grown below if the message truncates.
if (!iree_status_format(status_, buffer.size(), buffer.data(),
&actual_len)) {
buffer.resize(actual_len);
if (!iree_status_format(status_, buffer.size(), buffer.data(),
&actual_len)) {
actual_len = 0;
}
}
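// For illustration (assuming iree_status_format reports the required length
// on truncation): a status whose formatted text is 1500 characters fails the
// first call with actual_len == 1500, the buffer is grown to 1500, and the
// retry above succeeds; messages under 1024 characters format in one pass.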
buffer.resize(actual_len);
cached_message_ = std::move(buffer);
}
return cached_message_;
}
//===----------------------------------------------------------------------===//
// BufferInstance
//===----------------------------------------------------------------------===//
BufferInstance::~BufferInstance() = default;
BufferInstance::BufferInstance(
DeviceInstance& device, iree::vm::ref<iree_hal_buffer_view_t> buffer_view)
: device_(device), buffer_view_(std::move(buffer_view)) {
IREE_CHECK_OK(device.CreateFence(&ready_fence_));
IREE_CHECK_OK(device.CreateFence(&done_fence_));
// Cache the dims.
size_t rank = iree_hal_buffer_view_shape_rank(buffer_view_.get());
const iree_hal_dim_t* dims =
iree_hal_buffer_view_shape_dims(buffer_view_.get());
dims_.resize(rank);
for (size_t i = 0; i < rank; ++i) {
dims_[i] = dims[i];
}
}
void BufferInstance::ComputeLayout() {
iree_hal_encoding_type_t encoding =
iree_hal_buffer_view_encoding_type(buffer_view_.get());
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(buffer_view_.get());
layout_.Reset();
if (encoding == IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR &&
iree_hal_element_is_byte_aligned(element_type)) {
// It is not documented, but PJRT only supports device buffers with a tiled
// layout.
layout_.InitializeDenseRowMajorTiled(dims_.size());
}
}
void BufferInstance::BindApi(PJRT_Api* api) {
api->PJRT_Buffer_Destroy =
+[](PJRT_Buffer_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_Destroy");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
delete buffer;
return nullptr;
};
api->PJRT_Buffer_ElementType =
+[](PJRT_Buffer_ElementType_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_ElementType");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
auto element_type = buffer->element_type();
if (!element_type) {
return MakeError(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"Unsupported PJRT buffer type"));
}
args->type = *element_type;
return nullptr;
};
api->PJRT_Buffer_Dimensions =
+[](PJRT_Buffer_Dimensions_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_Dimensions");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->dims = buffer->dims();
args->num_dims = buffer->num_dims();
return nullptr;
};
api->PJRT_Buffer_UnpaddedDimensions =
+[](PJRT_Buffer_UnpaddedDimensions_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_UnpaddedDimensions");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->unpadded_dims = buffer->dims();
args->num_dims = buffer->num_dims();
return nullptr;
};
api->PJRT_Buffer_DynamicDimensionIndices =
+[](PJRT_Buffer_DynamicDimensionIndices_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_DynamicDimensionIndices");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_DynamicDimensionIndices"));
};
api->PJRT_Buffer_GetMemoryLayout =
+[](PJRT_Buffer_GetMemoryLayout_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_GetMemoryLayout");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
const PJRT_Buffer_MemoryLayout* layout = buffer->layout();
if (!layout) {
return MakeError(
iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"Unsupported PJRT layout for buffer view"));
}
args->layout = *layout;
return nullptr;
};
api->PJRT_Buffer_ToHostBuffer =
+[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_ToHostBuffer");
BufferInstance* buffer = BufferInstance::Unwrap(args->src);
if (!args->dst) {
// Size query.
return MakeError(buffer->GetHostSizeInBytes(&args->dst_size));
} else {
// Initiate transfer.
return MakeError(
buffer->CopyToHost(args->dst, args->dst_size,
reinterpret_cast<EventInstance**>(&args->event)));
}
};
api->PJRT_Buffer_OnDeviceSizeInBytes =
+[](PJRT_Buffer_OnDeviceSizeInBytes_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_OnDeviceSizeInBytes");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
iree_device_size_t size =
iree_hal_buffer_view_byte_length(buffer->buffer_view());
args->on_device_size_in_bytes = size;
return nullptr;
};
api->PJRT_Buffer_Delete = +[](PJRT_Buffer_Delete_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_Delete");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
buffer->Delete();
return nullptr;
};
api->PJRT_Buffer_IsDeleted =
+[](PJRT_Buffer_IsDeleted_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_IsDeleted");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->is_deleted = buffer->is_deleted();
return nullptr;
};
api->PJRT_Buffer_CopyToDevice =
+[](PJRT_Buffer_CopyToDevice_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_CopyToDevice");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_CopyToDevice"));
};
api->PJRT_Buffer_IsOnCpu =
+[](PJRT_Buffer_IsOnCpu_Args* args) -> PJRT_Error* {
args->is_on_cpu = BufferInstance::Unwrap(args->buffer)->is_on_cpu();
return nullptr;
};
api->PJRT_Buffer_Device = +[](PJRT_Buffer_Device_Args* args) -> PJRT_Error* {
args->device = BufferInstance::Unwrap(args->buffer)->device();
return nullptr;
};
api->PJRT_Buffer_Memory = +[](PJRT_Buffer_Memory_Args* args) -> PJRT_Error* {
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Memory"));
};
api->PJRT_Buffer_ReadyEvent =
+[](PJRT_Buffer_ReadyEvent_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Buffer_ReadyEvent");
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
args->event = reinterpret_cast<PJRT_Event*>(
new EventInstance(retain_ref(buffer->ready_fence())));
return nullptr;
};
// TODO: Rework the API to be Aliases(b1, b2) to let the plugin explicitly
// check for aliases.
api->PJRT_Buffer_UnsafePointer =
+[](PJRT_Buffer_UnsafePointer_Args* args) -> PJRT_Error* {
BufferInstance* buffer = BufferInstance::Unwrap(args->buffer);
iree_hal_buffer_t* hal_buffer =
iree_hal_buffer_view_buffer(buffer->buffer_view());
args->buffer_pointer = (uintptr_t)hal_buffer;
return nullptr;
};
}
iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) {
*host_size = iree_hal_buffer_view_byte_length(buffer_view());
return iree_ok_status();
}
iree_status_t BufferInstance::AsyncDeallocate() {
IREE_TRACE_SCOPE();
if (is_deleted_) {
return iree_ok_status();
}
is_deleted_ = true;
return IreeApi::hal_device_queue_dealloca(
device().device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(done_fence()),
/*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
iree_hal_buffer_view_buffer(buffer_view_.get()));
}
iree_status_t BufferInstance::Delete() {
IREE_TRACE_SCOPE();
is_deleted_ = true;
buffer_view_.reset();
return iree_ok_status();
}
iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
EventInstance** out_done_event) {
// Use a data structure to handle intermediary buffer when necessary. This
// needs to include the destination and aligned buffer, along with the size
// so the destination can be mem-copied if necessary.
struct CopyToHostData {
void* alloc;
void* aligned;
void* dst;
size_t size;
// Fence will be signaled when copy to host is complete.
iree::vm::ref<iree_hal_fence_t> copy_done_fence;
};
// Configure a default structure that writes directly to dst.
const size_t alignment = 64;
struct CopyToHostData* copy_to_host_data = new CopyToHostData;
copy_to_host_data->alloc = nullptr;
copy_to_host_data->aligned = dst;
copy_to_host_data->dst = dst;
copy_to_host_data->size = dst_size;
// If the destination is unaligned we need to write to an intermediary buffer.
if (((uintptr_t)dst) & (alignment - 1)) {
const size_t alignment_size = alignment + dst_size + sizeof(uintptr_t);
char* alloc = new char[alignment_size];
copy_to_host_data->alloc = alloc;
copy_to_host_data->aligned =
(void*)((((uintptr_t)alloc + alignment) & ~(uintptr_t)(alignment - 1)));
}
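// For illustration (addresses assumed): with alignment == 64, an allocation
// starting at 0x1007 is rounded up to 0x1040:
//   (0x1007 + 64) & ~(uintptr_t)63 == 0x1040
// and the extra `alignment + sizeof(uintptr_t)` bytes in `alignment_size`
// guarantee `dst_size` usable bytes remain after the aligned pointer.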
// Import the destination (host) buffer as an iree_hal_buffer_t so that we
// can issue copy commands.
iree::vm::ref<iree_hal_buffer_t> dst_buffer;
iree_hal_buffer_params_t dst_buffer_params = {
/*usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET,
// TODO: We should be able to use WRITE access here since the buffer
// is never actually mapped to read back out (just accessed through the
// void* later). However, that seems to cause the memory to never be
// committed and the interaction aborted.
/*access=*/IREE_HAL_MEMORY_ACCESS_ALL,
/*type=*/IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
};
iree_hal_external_buffer_t dst_external_buffer;
memset(&dst_external_buffer, 0, sizeof(dst_external_buffer));
dst_external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION;
dst_external_buffer.flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE;
dst_external_buffer.size = dst_size;
dst_external_buffer.handle.host_allocation.ptr = copy_to_host_data->aligned;
IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_import_buffer(
device_.device_allocator(), dst_buffer_params, &dst_external_buffer,
/*release_callback=*/iree_hal_buffer_release_callback_null(),
&dst_buffer));
// Create the transfer command buffer.
iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
iree_hal_transfer_command_t transfer_command;
memset(&transfer_command, 0, sizeof(transfer_command));
transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
transfer_command.copy.source_buffer =
iree_hal_buffer_view_buffer(buffer_view());
transfer_command.copy.source_offset = 0;
transfer_command.copy.target_buffer = dst_buffer.get();
transfer_command.copy.target_offset = 0;
transfer_command.copy.length = dst_size;
IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
device_.device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
IREE_HAL_QUEUE_AFFINITY_ANY,
/*transfer_count=*/1, &transfer_command, &transfer_cb));
dst_buffer.reset();
iree::vm::ref<iree_hal_semaphore_t> semaphore;
IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(
device_.device(), 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
// Signaled when `dst_buffer` is ready to be consumed.
iree::vm::ref<iree_hal_fence_t> dst_buffer_ready_fence;
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_create_at(
semaphore.get(), 1ull, device_.client().host_allocator(),
&dst_buffer_ready_fence));
// Signaled when copy to host is complete.
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_create_at(
semaphore.get(), 2ull, device_.client().host_allocator(),
&(copy_to_host_data->copy_done_fence)));
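// For illustration (an assumed trace): the single semaphore created above
// advances 0 -> 1 -> 2. Timepoint 1 (dst_buffer_ready_fence) is signaled by
// the queue execution below once the device-to-host copy lands in dst_buffer;
// timepoint 2 (copy_done_fence) is signaled by dst_buffer_callback after any
// unaligned memcpy into dst, so the event returned to the caller observes
// timepoint 2.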
auto dst_buffer_callback = [](PJRT_Error* error, void* user_data) {
const ErrorInstance* error_instance = ErrorInstance::FromError(error);
auto* copy_data = static_cast<CopyToHostData*>(user_data);
if (!error) {
// If there is an allocated intermediate buffer, copy it to the destination.
if (copy_data->alloc) {
std::memcpy(copy_data->dst, copy_data->aligned, copy_data->size);
}
iree_hal_fence_signal(copy_data->copy_done_fence.get());
} else {
iree_hal_fence_fail(copy_data->copy_done_fence.get(),
error_instance->status());
}
if (copy_data->alloc) {
delete[] static_cast<char*>(copy_data->alloc);
}
delete copy_data;
delete error_instance;
};
// This callback simply deletes the `dst_buffer_ready_event`. We could perform
// this deletion in the `dst_buffer_callback`, but this would result in the
// callback thread of `dst_buffer_ready_event` detaching from the main thread,
// potentially resulting in the callback thread outliving the main thread.
auto copy_done_callback = [](PJRT_Error* error, void* user_data) {
EventInstance* dst_buffer_ready_event =
static_cast<EventInstance*>(user_data);
delete dst_buffer_ready_event;
delete ErrorInstance::FromError(error);
};
auto dst_buffer_ready_event =
new EventInstance(retain_ref(dst_buffer_ready_fence));
dst_buffer_ready_event->OnReady(dst_buffer_callback, copy_to_host_data);
auto copy_done_event =
new EventInstance(retain_ref(copy_to_host_data->copy_done_fence));
copy_done_event->OnReady(copy_done_callback, dst_buffer_ready_event);
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
/*signal_semaphore_list=*/
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()),
transfer_cb.get()));
*out_done_event = copy_done_event;
return iree_ok_status();
}
iree_status_t BufferInstance::AdvanceReadyFence(iree_hal_semaphore_t* semaphore,
uint64_t timepoint) {
return IreeApi::hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
}
iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
uint64_t timepoint) {
return IreeApi::hal_fence_insert(done_fence_.get(), semaphore, timepoint);
}
std::optional<PJRT_Buffer_Type> BufferInstance::element_type() {
iree_hal_element_type_t hal_element_type =
iree_hal_buffer_view_element_type(buffer_view());
// TODO: Cascade on bit-field sub-types to avoid large linear scan.
switch (hal_element_type) {
// TODO: How do I interpret signless?
case IREE_HAL_ELEMENT_TYPE_BOOL_8:
return PJRT_Buffer_Type_PRED;
case IREE_HAL_ELEMENT_TYPE_INT_4:
return PJRT_Buffer_Type_S4;
case IREE_HAL_ELEMENT_TYPE_INT_8:
return PJRT_Buffer_Type_S8;
case IREE_HAL_ELEMENT_TYPE_INT_16:
return PJRT_Buffer_Type_S16;
case IREE_HAL_ELEMENT_TYPE_INT_32:
return PJRT_Buffer_Type_S32;
case IREE_HAL_ELEMENT_TYPE_INT_64:
return PJRT_Buffer_Type_S64;
case IREE_HAL_ELEMENT_TYPE_SINT_4:
return PJRT_Buffer_Type_S4;
case IREE_HAL_ELEMENT_TYPE_SINT_8:
return PJRT_Buffer_Type_S8;
case IREE_HAL_ELEMENT_TYPE_SINT_16:
return PJRT_Buffer_Type_S16;
case IREE_HAL_ELEMENT_TYPE_SINT_32:
return PJRT_Buffer_Type_S32;
case IREE_HAL_ELEMENT_TYPE_SINT_64:
return PJRT_Buffer_Type_S64;
case IREE_HAL_ELEMENT_TYPE_UINT_4:
return PJRT_Buffer_Type_U4;
case IREE_HAL_ELEMENT_TYPE_UINT_8:
return PJRT_Buffer_Type_U8;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
return PJRT_Buffer_Type_U16;
case IREE_HAL_ELEMENT_TYPE_UINT_32:
return PJRT_Buffer_Type_U32;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
return PJRT_Buffer_Type_U64;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
return PJRT_Buffer_Type_F16;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
return PJRT_Buffer_Type_F32;
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
return PJRT_Buffer_Type_F64;
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
return PJRT_Buffer_Type_BF16;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
return PJRT_Buffer_Type_C64;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
return PJRT_Buffer_Type_C128;
default:
return {};
}
}
//===----------------------------------------------------------------------===//
// DeviceDescription
//===----------------------------------------------------------------------===//
DeviceDescription::~DeviceDescription() = default;
void DeviceDescription::BindApi(PJRT_Api* api) {
api->PJRT_DeviceDescription_Id =
+[](PJRT_DeviceDescription_Id_Args* args) -> PJRT_Error* {
args->id = DeviceDescription::Unwrap(args->device_description)->client_id();
return nullptr;
};
api->PJRT_DeviceDescription_ProcessIndex =
+[](PJRT_DeviceDescription_ProcessIndex_Args* args) -> PJRT_Error* {
args->process_index =
DeviceDescription::Unwrap(args->device_description)->process_index();
return nullptr;
};
api->PJRT_DeviceDescription_Attributes =
+[](PJRT_DeviceDescription_Attributes_Args* args) -> PJRT_Error* {
// TODO: Implement something.
args->num_attributes = 0;
args->attributes = nullptr;
return nullptr;
};
api->PJRT_DeviceDescription_Kind =
+[](PJRT_DeviceDescription_Kind_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->kind_string();
args->device_kind = sv.data();
args->device_kind_size = sv.size();
return nullptr;
};
api->PJRT_DeviceDescription_DebugString =
+[](PJRT_DeviceDescription_DebugString_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->debug_string();
args->debug_string = sv.data();
args->debug_string_size = sv.size();
return nullptr;
};
api->PJRT_DeviceDescription_ToString =
+[](PJRT_DeviceDescription_ToString_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->user_string();
args->to_string = sv.data();
args->to_string_size = sv.size();
return nullptr;
};
}
//===----------------------------------------------------------------------===//
// DeviceInstance
//===----------------------------------------------------------------------===//
DeviceInstance::~DeviceInstance() = default;
void DeviceInstance::BindApi(PJRT_Api* api) {
api->PJRT_Device_IsAddressable =
+[](PJRT_Device_IsAddressable_Args* args) -> PJRT_Error* {
args->is_addressable =
DeviceInstance::Unwrap(args->device)->is_addressable();
return nullptr;
};
api->PJRT_Device_LocalHardwareId =
+[](PJRT_Device_LocalHardwareId_Args* args) -> PJRT_Error* {
args->local_hardware_id =
DeviceInstance::Unwrap(args->device)->local_hardware_id();
return nullptr;
};
api->PJRT_Device_AddressableMemories =
+[](PJRT_Device_AddressableMemories_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Device_AddressableMemories"));
};
api->PJRT_Device_DefaultMemory =
+[](PJRT_Device_DefaultMemory_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Device_DefaultMemory"));
};
api->PJRT_Device_GetDescription =
+[](PJRT_Device_GetDescription_Args* args) -> PJRT_Error* {
args->device_description = reinterpret_cast<PJRT_DeviceDescription*>(
DeviceInstance::Unwrap(args->device)->device_description());
return nullptr;
};
}
iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
return IreeApi::hal_fence_create(/*capacity=*/2, client_.host_allocator(),
out_fence);
}
iree_status_t DeviceInstance::OpenDevice() {
if (device_) return iree_ok_status();
IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id(
driver_, /*device_id=*/info_.device_id(),
/*param_count=*/0, /*params=*/nullptr, client_.host_allocator(),
&device_));
IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(
device(), 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &main_timeline_));
IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(
device(), 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &transfer_timeline_));
return iree_ok_status();
}
iree_status_t DeviceInstance::HostBufferToDeviceSplat(
const void* data, PJRT_Buffer_Type type, const int64_t* dims,
size_t num_dims, EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer) {
// Map element type:
iree_hal_element_type_t element_type;
IREE_RETURN_IF_ERROR(
PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
// TODO: Do something sensible with sub-byte aligned types.
if (IREE_UNLIKELY(iree_hal_element_bit_count(element_type) == 0) ||
IREE_UNLIKELY(!iree_hal_element_is_byte_aligned(element_type))) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"opaque and sub-byte aligned element types cannot be indexed");
}
iree_device_size_t element_type_byte_size =
iree_hal_element_dense_byte_count(element_type);
// Handle strided layouts and shape.
std::array<iree_hal_dim_t, kMaxDims> shape;
if (num_dims > shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
(int)shape.size(), (int)num_dims);
}
iree_device_size_t byte_length = element_type_byte_size;
for (int i = 0, s = num_dims; i < s; ++i) {
byte_length *= dims[i];
shape[i] = dims[i];
}
iree::vm::ref<iree_hal_buffer_t> buffer;
// Allocate on stream. We serialize across 3 timepoints:
// 0. Last transfer complete
// 1. Allocation
// 2. Fill is complete
// There are various ways to be smarter about this but without more
// information from the caller, this is ok. If we wanted to favor smaller
// allocation scopes, it may be desirable to join with the main execution
// timeline, but that would obviously serialize more.
uint64_t wait_transfer_start = last_transfer_timepoint_;
uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
uint64_t signal_copy_complete = ++last_transfer_timepoint_;
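// For illustration (starting value assumed): if last_transfer_timepoint_ was
// 10 on entry, then wait_transfer_start == 10, signal_alloca_complete == 11,
// and signal_copy_complete == 12; the alloca below waits on 10 and signals 11,
// and the fill execution waits on 11 and signals 12, all on transfer_timeline_.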
iree_hal_buffer_params_t params;
memset(&params, 0, sizeof(params));
params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
params.usage =
IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &wait_transfer_start},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, byte_length, &buffer));
// Queue up the buffer fill for splatting:
iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
IREE_HAL_COMMAND_CATEGORY_ANY, IREE_HAL_QUEUE_AFFINITY_ANY,
/*binding_capacity=*/0, &transfer_cb));
IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
transfer_cb.get(), iree_hal_make_buffer_ref(buffer.get(), 0, byte_length),
data, element_type_byte_size, IREE_HAL_FILL_FLAG_NONE));
IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
// Execute the enqueued splat:
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb.get()));
// Wrap in a buffer view and return:
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
buffer.get(), num_dims, &shape[0], element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
&result_buffer_view));
auto instance = new BufferInstance(*this, std::move(result_buffer_view));
instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
*out_buffer = instance;
// The splat pattern was captured when recording the fill, so the host data is
// no longer required:
*out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
return iree_ok_status();
}
iree_status_t DeviceInstance::HostBufferToDeviceZeroDim(
PJRT_Buffer_Type type, const int64_t* dims, size_t num_dims,
EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer) {
// Map element type:
iree_hal_element_type_t element_type;
IREE_RETURN_IF_ERROR(
PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
std::array<iree_hal_dim_t, kMaxDims> shape;
if (num_dims > shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
(int)shape.size(), (int)num_dims);
}
for (int i = 0, s = num_dims; i < s; ++i) {
shape[i] = dims[i];
}
// We only need to wait for previous transfer and allocate data:
uint64_t wait_transfer_start = last_transfer_timepoint_;
uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
iree_hal_buffer_params_t params;
iree::vm::ref<iree_hal_buffer_t> buffer;
memset(&params, 0, sizeof(params));
params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
params.usage =
IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &wait_transfer_start},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
IREE_HAL_ALLOCATOR_POOL_DEFAULT, params,
iree_hal_element_dense_byte_count(element_type), &buffer));
// Wrap in a buffer view and return.
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
buffer.get(), num_dims, &shape[0], element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
&result_buffer_view));
auto instance = new BufferInstance(*this, std::move(result_buffer_view));
instance->AdvanceReadyFence(transfer_timeline_.get(), signal_alloca_complete);
instance->AdvanceDoneFence(transfer_timeline_.get(), signal_alloca_complete);
*out_buffer = instance;
// Degenerate case ignores the data so we can just return:
*out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
return iree_ok_status();
}
iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer(
BufferInstance* buffer, iree_hal_element_type_t element_type,
const iree_hal_dim_t* input_dims, const iree_hal_dim_t* output_dims,
const int64_t* perms, size_t num_dims,
PJRT_HostBufferSemantics host_buffer_semantics,
EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer) {
if (num_dims > kMaxDims) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"number of dimensions exceeded max supported");
}
std::array<iree_hal_dim_t, kMaxDims> transpose_dims;
for (int i = 0; i < num_dims; ++i) {
transpose_dims[i] = input_dims[perms[i]];
}
auto typeBuilder = [](const iree_hal_dim_t* dims, int64_t num_dims,
const char* ty) {
std::stringstream ss;
ss << "tensor<";
for (int i = 0; i < num_dims; ++i) {
ss << dims[i] << "x";
}
ss << ty << ">";
return ss.str();
};
auto arrayBuilder = [](const int64_t* vals, int64_t sz) {
std::stringstream ss;
ss << " {permutation = dense<[" << vals[0];
for (int i = 1; i < sz; ++i) ss << ", " << vals[i];
ss << "]> : tensor<" << sz << "xi64>}";
return ss.str();
};
auto broadcastBuilder = [](int64_t sz) {
std::stringstream ss;
ss << "{broadcast_dimensions = dense<[0";
for (int i = 1; i < sz; ++i) ss << ", " << i;
ss << "]> : tensor<" << sz << "xi64>}";
return ss.str();
};
const char* mlir_ty;
IREE_RETURN_IF_ERROR(
PJRTApiConverter::MapElementTypeToMlirType(element_type, &mlir_ty));
auto input_ty = typeBuilder(input_dims, num_dims, mlir_ty);
auto transpose_ty = typeBuilder(transpose_dims.data(), num_dims, mlir_ty);
auto output_ty = typeBuilder(output_dims, num_dims, mlir_ty);
auto perms_str = arrayBuilder(perms, num_dims);
auto broadcast_str = broadcastBuilder(num_dims);
const char* program_literal = R"(func.func @main(%%arg0 : %1$s) -> (%3$s) {
%%0 = "stablehlo.transpose"(%%arg0) %4$s : (%1$s) -> %2$s
%%1 = "stablehlo.broadcast_in_dim"(%%0) %5$s : (%2$s) -> %3$s
return %%1 : %3$s
})";
char transpose_program[512];
size_t program_len = std::snprintf(
transpose_program, sizeof(transpose_program), program_literal,
input_ty.c_str(), transpose_ty.c_str(), output_ty.c_str(),
perms_str.c_str(), broadcast_str.c_str());
if (program_len >= sizeof(transpose_program)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"program size exceeded limit");
}
// Create an on stack program:
PJRT_Program program;
program.code = transpose_program;
program.code_size = program_len;
program.format = kMlirFormat.data();
program.format_size = kMlirFormat.size();
// Compile program and check for errors:
LoadedExecutableInstance* executable;
auto* error = this->client().Compile(&program, &executable);
if (error) {
auto errinst = ErrorInstance::FromError(error);
auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"transposition program failed to build");
delete errinst;
return ret;
}
PJRT_Buffer* input = *buffer;
PJRT_Buffer** input_list = &input;
PJRT_Buffer* output;
PJRT_Buffer** output_list = &output;
PJRT_Event* event;
// Build the execution arguments for transposing the loaded memory:
PJRT_LoadedExecutable_Execute_Args execute_args;
memset(&execute_args, 0, sizeof(execute_args));
PJRT_ExecuteOptions execute_options;
memset(&execute_options, 0, sizeof(execute_options));
execute_args.executable = *executable;
execute_args.options = &execute_options;
execute_args.argument_lists = &input_list;
execute_args.output_lists = &output_list;
execute_args.num_devices = 1;
execute_args.num_args = 1;
execute_args.device_complete_events = &event;
// We do not support specifying the device yet.
execute_args.execute_device = nullptr;
auto err = executable->BatchExecute(&execute_args);
delete executable;
if (err) {
return err;
}
*out_buffer = BufferInstance::Unwrap(output);
*out_done_with_host_buffer_event = EventInstance::Unwrap(event);
return iree_ok_status();
}
iree_status_t DeviceInstance::HostBufferToDevice(
const void* data, PJRT_Buffer_Type type, const int64_t* dims,
size_t num_dims, const int64_t* byte_strides, size_t num_byte_strides,
PJRT_HostBufferSemantics host_buffer_semantics,
EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer) {
IREE_RETURN_IF_ERROR(OpenDevice());
// Map element type.
iree_hal_element_type_t element_type;
IREE_RETURN_IF_ERROR(
PJRTApiConverter::MapBufferTypeToElementType(type, &element_type));
// TODO: Do something sensible with sub-byte aligned types.
if (IREE_UNLIKELY(iree_hal_element_bit_count(element_type) == 0) ||
IREE_UNLIKELY(!iree_hal_element_is_byte_aligned(element_type))) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"opaque and sub-byte aligned element types cannot be indexed");
}
iree_device_size_t element_type_byte_size =
iree_hal_element_dense_byte_count(element_type);
// We need to check for special cases (splatting, zerodim):
bool is_splat = element_type_byte_size == 1 || element_type_byte_size == 2 ||
element_type_byte_size == 4;
bool has_zero_dim = false;
iree_device_size_t byte_length = element_type_byte_size;
for (int i = 0; i < num_byte_strides; ++i) {
is_splat &= (dims[i] == 1 || byte_strides[i] == 0);
has_zero_dim |= (dims[i] == 0);
byte_length *= dims[i];
}
byte_length = std::max(element_type_byte_size, byte_length);
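// For illustration (values assumed, with num_byte_strides == num_dims): dims =
// [4, 4] with byte_strides = [0, 0] means every element aliases one scalar, so
// is_splat stays true and the fill path is taken; dims = [0, 5] sets
// has_zero_dim and skips the transfer entirely.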
// If we encounter the zero dim case no transfer is required:
if (has_zero_dim) {
return HostBufferToDeviceZeroDim(
type, dims, num_dims, out_done_with_host_buffer_event, out_buffer);
}
// If we encounter the splat case we can perform a fill instead:
if (is_splat) {
return HostBufferToDeviceSplat(data, type, dims, num_dims,
out_done_with_host_buffer_event, out_buffer);
}
// Handle strided layouts and shape:
std::vector<int64_t> perms(num_dims);
std::array<iree_hal_dim_t, kMaxDims> input_shape;
std::array<iree_hal_dim_t, kMaxDims> transpose_shape;
std::array<iree_hal_dim_t, kMaxDims> output_shape;
if (num_dims > input_shape.size()) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"only supports up to %d dims but got %d",
(int)input_shape.size(), (int)num_dims);
}
// Compute the input shape and permutations for the broadcast.
iree::pjrt::computeBroadcastArgs(
num_dims, element_type_byte_size, byte_strides, dims,
reinterpret_cast<int64_t*>(input_shape.data()), perms.data());
for (int i = 0, s = num_dims; i < s; ++i) {
transpose_shape[i] = input_shape[perms[i]];
output_shape[i] = dims[i];
}
bool is_dense_row_major = true;
for (int i = 0, s = num_dims; i < s; ++i) {
is_dense_row_major &= (input_shape[i] == dims[i]) && (perms[i] == i);
}
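// For illustration (a plausible example; the exact convention is defined by
// computeBroadcastArgs in tensor_utils): a column-major 2x3 f32 buffer with
// dims = [2, 3] and byte_strides = [4, 8] would yield a stored input_shape of
// [3, 2] with perms = [1, 0], so is_dense_row_major is false and the transpose
// path at the bottom of this function is used; a contiguous row-major buffer
// yields input_shape == dims and perms == {0, 1}, keeping the fast path.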
iree::vm::ref<iree_hal_buffer_t> buffer;
// There are multiple ways to implement zero-copy/staged transfers and each
// implementation will have different performance cliffs associated with
// directly operating on imported host buffers. In many actual
// host/device situations, such unified memory is a productivity (not a
// performance) feature and best avoided. As such, we always need to be
// able to decide to do a staged transfer and implement that here. Using
// an imported buffer on the device is left as an optimization for
// implementations on which we believe it will be beneficial.
bool require_snapshot_now = host_buffer_semantics ==
PJRT_HostBufferSemantics_kImmutableOnlyDuringCall;
bool caller_data_done = false;
iree::vm::ref<iree_hal_buffer_t> host_staging_buffer;
IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
iree_make_const_byte_span(data, byte_length), require_snapshot_now,
&caller_data_done, &host_staging_buffer));
if (!caller_data_done) {
return iree_make_status(
IREE_STATUS_UNIMPLEMENTED,
"deferred snapshot of host data not yet implemented");
}
// Allocate on stream. We serialize across 3 timepoints:
// 0. Last transfer complete
// 1. Allocation
// 2. This transfer complete
// There are various ways to be smarter about this but without more
// information from the caller, this is ok. If we wanted to favor smaller
// allocation scopes, it may be desirable to join with the main execution
// timeline, but that would obviously serialize more.
uint64_t wait_transfer_start = last_transfer_timepoint_;
uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
uint64_t signal_copy_complete = ++last_transfer_timepoint_;
iree_hal_buffer_params_t params;
memset(&params, 0, sizeof(params));
params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
params.usage =
IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &wait_transfer_start},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, byte_length, &buffer));
// Queue up the transfer command.
iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
iree_hal_transfer_command_t transfer_command;
memset(&transfer_command, 0, sizeof(transfer_command));
transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
transfer_command.copy.source_buffer = host_staging_buffer.get();
transfer_command.copy.source_offset = 0;
transfer_command.copy.target_buffer = buffer.get();
transfer_command.copy.target_offset = 0;
transfer_command.copy.length = byte_length;
IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
IREE_HAL_QUEUE_AFFINITY_ANY,
/*transfer_count=*/1, &transfer_command, &transfer_cb));
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb.get()));
// Wrap in a buffer view and return.
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
buffer.get(), num_dims, &input_shape[0], element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
&result_buffer_view));
auto instance = new BufferInstance(*this, std::move(result_buffer_view));
instance->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
instance->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
if (is_dense_row_major) {
*out_buffer = instance;
// We snapshotted the caller data when acquiring the host staging buffer,
// so we won't be touching it again.
*out_done_with_host_buffer_event = new EventInstance(/*fence=*/nullptr);
return iree_ok_status();
}
auto err = TransposeBroadcastDeviceBuffer(
instance, element_type, input_shape.data(), output_shape.data(),
perms.data(), num_dims, host_buffer_semantics,
out_done_with_host_buffer_event, out_buffer);
delete instance;
return err;
}
iree_status_t DeviceInstance::AcquireHostStagingBuffer(
iree_const_byte_span_t initial_contents, bool snapshot_initial_contents_now,
bool* initial_contents_snapshotted, iree_hal_buffer_t** out_buffer) {
IREE_TRACE_SCOPE();
// There are multiple ways to do this that have different cost/benefits.
// Here we do the simplest thing and snapshot into a new host allocation.
// This could be replaced with either some form of staging ring buffer
// or importing from a raw pointer (on implementations where the cost of
// unified addressing is zero).
iree_hal_buffer_params_t params;
memset(&params, 0, sizeof(params));
params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_allocate_buffer(
device_allocator(), params, initial_contents.data_length, out_buffer));
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_h2d(
device(), initial_contents.data, *out_buffer, 0,
initial_contents.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout()));
// We did a synchronous snapshot (memcpy).
*initial_contents_snapshotted = true;
return iree_ok_status();
}
iree_status_t DeviceInstance::GetHalDevice(iree_hal_device_t** out_device) {
IREE_RETURN_IF_ERROR(OpenDevice());
*out_device = device_.get();
return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// ClientInstance
//===----------------------------------------------------------------------===//
ClientInstance::ClientInstance(std::unique_ptr<Platform> platform)
: platform_(std::move(platform)) {
host_allocator_ = iree_allocator_system();
IREE_CHECK_OK(
iree_hal_driver_registry_allocate(host_allocator_, &driver_registry_));
cached_platform_version_ = "git"; // TODO: Plumb through version info.
}
ClientInstance::~ClientInstance() {
for (auto* device : devices_) {
delete device;
}
if (device_infos_) {
iree_allocator_free(host_allocator_, device_infos_);
}
// Explicitly releasing vs using a ref so as to better control shut-down
// ordering (bad shutdown ordering of the driver is a frequent cause of
// bugs).
iree_hal_driver_release(driver_);
iree_hal_driver_registry_free(driver_registry_);
}
void ClientInstance::BindApi(PJRT_Api* api) {
// PJRT_Client_Create is polymorphic
api->PJRT_Client_Destroy =
+[](PJRT_Client_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Client_Destroy");
delete ClientInstance::Unwrap(args->client);
return nullptr;
};
api->PJRT_Client_PlatformName =
+[](PJRT_Client_PlatformName_Args* args) -> PJRT_Error* {
auto* client = ClientInstance::Unwrap(args->client);
args->platform_name = client->cached_platform_name().data();
args->platform_name_size = client->cached_platform_name().size();
return nullptr;
};
api->PJRT_Client_ProcessIndex =
+[](PJRT_Client_ProcessIndex_Args* args) -> PJRT_Error* {
args->process_index = 0;
return nullptr;
};
api->PJRT_Client_PlatformVersion =
+[](PJRT_Client_PlatformVersion_Args* args) -> PJRT_Error* {
auto* client = ClientInstance::Unwrap(args->client);
args->platform_version = client->cached_platform_version().data();
args->platform_version_size = client->cached_platform_version().size();
return nullptr;
};
api->PJRT_Client_Devices =
+[](PJRT_Client_Devices_Args* args) -> PJRT_Error* {
auto& devices = ClientInstance::Unwrap(args->client)->devices();
args->devices = const_cast<PJRT_Device**>(
reinterpret_cast<PJRT_Device* const*>(devices.data()));
args->num_devices = devices.size();
return nullptr;
};
api->PJRT_Client_AddressableDevices =
+[](PJRT_Client_AddressableDevices_Args* args) -> PJRT_Error* {
auto& devices = ClientInstance::Unwrap(args->client)->addressable_devices();
args->addressable_devices = const_cast<PJRT_Device**>(
reinterpret_cast<PJRT_Device* const*>(devices.data()));
args->num_addressable_devices = devices.size();
return nullptr;
};
api->PJRT_Client_LookupDevice =
+[](PJRT_Client_LookupDevice_Args* args) -> PJRT_Error* {
auto& devices = ClientInstance::Unwrap(args->client)->devices();
size_t id_as_size = args->id;
if (id_as_size >= devices.size()) {
return MakeError(
iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"because device id %d is invalid (%d devices known)",
(int)id_as_size, (int)devices.size()));
}
args->device = *devices[id_as_size];
return nullptr;
};
api->PJRT_Client_AddressableMemories =
+[](PJRT_Client_AddressableMemories_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Client_AddressableMemories"));
};
api->PJRT_Client_Compile =
+[](PJRT_Client_Compile_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Client_Compile");
// TODO: It is not great that we only get a client here vs a list of
// devices to consider (or something). The issue is that systems often
// have unrelated devices that will not actually be scheduled and those
// will very naturally have different tuning flags. We therefore have to
// guess... which is an accident waiting to happen.
// Looks like what I need is buried in the compile options... need to
// work on that.
auto* client = ClientInstance::Unwrap(args->client);
LoadedExecutableInstance* executable;
// Read compilation options.
// TODO: Port CompileOptionsProto into the project or leave omitted.
// xla::CompileOptionsProto options_proto;
// if (!options_proto.ParseFromArray(args->compile_options,
// args->compile_options_size)) {
// return MakeError(iree_make_status(IREE_STATUS_INTERNAL,
// "could not parse compilation
// options"));
// }
// auto options = xla::CompileOptions::FromProto(options_proto);
// if (!options.ok()) {
// return MakeError(
// iree_make_status(IREE_STATUS_INTERNAL,
// std::string(options.status().message()).c_str()));
// }
auto* error = client->Compile(args->program, /**options,*/ &executable);
if (error) return error;
args->executable = *executable;
return nullptr;
};
api->PJRT_Client_DefaultDeviceAssignment =
+[](PJRT_Client_DefaultDeviceAssignment_Args* args) -> PJRT_Error* {
// TODO: Something sensible.
for (size_t i = 0; i < args->default_assignment_size; ++i) {
args->default_assignment[i] = 0;
}
return nullptr;
};
api->PJRT_Client_BufferFromHostBuffer =
+[](PJRT_Client_BufferFromHostBuffer_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Client_BufferFromHostBuffer");
auto status =
DeviceInstance::Unwrap(args->device)
->HostBufferToDevice(
args->data, args->type, args->dims, args->num_dims,
args->byte_strides, args->num_byte_strides,
args->host_buffer_semantics,
reinterpret_cast<EventInstance**>(&args->done_with_host_buffer),
reinterpret_cast<BufferInstance**>(&args->buffer));
return MakeError(status);
};
api->PJRT_LoadedExecutable_Fingerprint =
+[](PJRT_LoadedExecutable_Fingerprint_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_LoadedExecutable_Fingerprint"));
};
}
PJRT_Error* ClientInstance::Initialize() {
// TODO: Remove calls to iree_status_fprint once JAX properly reports
// initialization errors: https://github.com/google/jax/issues/13763
auto status = CreateDriver(&driver_);
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
return MakeError(status);
}
status = InitializeVM();
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
return MakeError(status);
}
status = PopulateDevices();
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
return MakeError(status);
}
// More initialization.
return nullptr;
}
iree_status_t ClientInstance::InitializeVM() {
IREE_RETURN_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
host_allocator_, &vm_instance_));
IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(vm_instance_.get()));
return iree_ok_status();
}
iree_status_t ClientInstance::PopulateDevices() {
IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices(
driver_, host_allocator_, &device_info_count_, &device_infos_));
devices_.resize(device_info_count_);
for (iree_host_size_t i = 0; i < device_info_count_; ++i) {
// Note that we assume one driver per client here, but each device is modeled
// with its own driver reference in case this ever becomes more heterogeneous.
devices_[i] = new DeviceInstance(i, *this, driver_, &device_infos_[i]);
}
// For now, just make all devices addressable.
addressable_devices_.reserve(devices_.size());
for (auto* device : devices_) {
addressable_devices_.push_back(device);
}
return iree_ok_status();
}
PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
/*xla::CompileOptions options,*/
LoadedExecutableInstance** out_executable) {
std::unique_ptr<ArtifactDumper::Transaction> artifact_tx;
if (platform().artifact_dumper().enabled()) {
artifact_tx = platform().artifact_dumper().CreateTransaction();
}
iree_status_t status;
std::string_view format(program->format, program->format_size);
std::string_view code(program->code, program->code_size);
if (artifact_tx) {
artifact_tx->WriteArtifact(/*label=*/"program", /*extension=*/"mlirbc",
/*index=*/-1, code);
}
if (format != "mlir") {
// See: https://github.com/google/jax/issues/13722
return MakeError(iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"because IREE only supports MLIR input but got something else"));
}
auto MakeCompilerError = [&](CompilerJob& job) {
std::string message = job.GetErrorMessage();
return MakeError(iree_make_status(IREE_STATUS_INVALID_ARGUMENT, ": %s",
message.c_str()));
};
std::vector<std::unique_ptr<CompilerOutput>> retained_outputs;
// Partition.
if (platform().partitioner()) {
std::unique_ptr<CompilerJob> job = platform().partitioner()->StartJob();
if (!job) {
std::string message = platform().partitioner()->GetErrorMessage();
return MakeError(
iree_make_status(IREE_STATUS_CANCELLED, ": %s", message.c_str()));
}
if (artifact_tx) {
job->EnableCrashDumps(artifact_tx.get());
}
// Set flags.
// TODO: Plumb CompileOptions through.
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"partitioner_flags", /*extension=*/"txt", /*index=*/-1,
job->GetFlags());
}
// Parse the source.
if (!job->ParseSourceBuffer(code.data(), code.size())) {
return MakeCompilerError(*job);
}
// Partition.
std::unique_ptr<CompilerOutput> output = job->CompileStandardPipeline();
if (!output) {
return MakeCompilerError(*job);
}
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"partitioned", /*extension=*/"mlir", /*index=*/-1,
std::string_view(static_cast<const char*>(output->GetData()),
output->GetDataSize()));
}
// Update the code alias and retain the backing output for the next
// compilation step.
code = std::string_view(static_cast<const char*>(output->GetData()),
output->GetDataSize());
retained_outputs.push_back(std::move(output));
}
// Main compilation.
{
std::unique_ptr<CompilerJob> job = platform().compiler().StartJob();
if (!job) {
std::string message = platform().compiler().GetErrorMessage();
return MakeError(
iree_make_status(IREE_STATUS_CANCELLED, ": %s", message.c_str()));
}
if (artifact_tx) {
job->EnableCrashDumps(artifact_tx.get());
}
// Set flags.
// TODO: This should be done as part of session setup from a named pool.
// TODO: The HAL backends and other flags should come from the assigned
// devices.
if (!SetDefaultCompilerFlags(job.get())) {
return MakeCompilerError(*job);
}
// TODO: Plumb CompileOptions through.
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"flags", /*extension=*/"txt", /*index=*/-1,
job->GetFlags());
}
// Parse the source.
if (!job->ParseSourceBuffer(code.data(), code.size())) {
return MakeCompilerError(*job);
}
// Perform main compilation.
std::unique_ptr<CompilerOutput> output = job->CompileStandardPipeline();
if (!output) {
return MakeCompilerError(*job);
}
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"program", /*extension=*/"vmfb", /*index=*/-1,
std::string_view(static_cast<const char*>(output->GetData()),
output->GetDataSize()));
}
auto executable = std::make_unique<LoadedExecutableInstance>(
*this,
new ExecutableImage(std::move(output),
std::string(program->code, program->code_size)),
addressable_devices_);
status = executable->LoadAll();
if (!iree_status_is_ok(status)) {
return MakeError(status);
}
*out_executable = executable.release();
}
// On success, cancel the artifact transaction so we don't persist successful
// runs (unless so configured).
if (artifact_tx) {
artifact_tx->Cancel();
}
return nullptr;
}
iree_status_t ClientInstance::PopulateVMModules(
std::vector<iree::vm::ref<iree_vm_module_t>>& modules,
iree_hal_device_t* hal_device,
iree::vm::ref<iree_vm_module_t>& main_module) {
// HAL module.
modules.push_back({});
IREE_RETURN_IF_ERROR(iree_hal_module_create(
vm_instance(), /*device_count=*/1, &hal_device, IREE_HAL_MODULE_FLAG_NONE,
iree_hal_module_debug_sink_stdio(stderr), host_allocator(),
&modules.back()));
// Main module.
modules.push_back(main_module);
return iree_ok_status();
}
std::tuple<uint64_t, uint64_t> ClientInstance::AdvanceTimeline() {
uint64_t current = execution_timeline_;
uint64_t next = current + 1;
execution_timeline_ = next;
return std::make_tuple(current, next);
}
//===----------------------------------------------------------------------===//
// EventInstance
//===----------------------------------------------------------------------===//
EventInstance::EventInstance(iree::vm::ref<iree_hal_fence_t> fence)
: is_ready_(false) {
if (!fence) {
is_ready_ = true;
return;
}
{
std::lock_guard<std::mutex> guard(lock_);
// Create a thread that waits on the fence and executes the callbacks when
// the fence is ready.
signal_thread_ = std::make_unique<std::thread>(
[](EventInstance* event_instance,
iree::vm::ref<iree_hal_fence_t> fence) {
iree_status_t wait_status =
iree_hal_fence_wait(fence.get(), iree_infinite_timeout());
event_instance->SignalReady(wait_status);
},
this, std::move(fence));
}
}
EventInstance::~EventInstance() {
std::lock_guard<std::mutex> guard(lock_);
if (signal_thread_) {
if (std::this_thread::get_id() != signal_thread_->get_id()) {
signal_thread_->join();
} else {
// An `EventInstance` is allowed to delete itself in one of its callbacks,
// resulting in `signal_thread_` being the thread calling the destructor.
// In such cases, we must let the thread continue running independent of
// the destructor to avoid a deadlock.
signal_thread_->detach();
signal_thread_.release();
}
}
iree_status_ignore(status_);
}
void EventInstance::BindApi(PJRT_Api* api) {
api->PJRT_Event_Destroy = +[](PJRT_Event_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Event_Destroy");
auto instance = EventInstance::Unwrap(args->event);
auto delete_event = [](PJRT_Error* error, void* user_data) {
EventInstance* event = static_cast<EventInstance*>(user_data);
delete event;
if (error) {
delete ErrorInstance::FromError(error);
}
};
instance->OnReady(delete_event, args->event);
return nullptr;
};
api->PJRT_Event_IsReady = +[](PJRT_Event_IsReady_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Event_IsReady");
args->is_ready = EventInstance::Unwrap(args->event)->is_ready();
return nullptr;
};
api->PJRT_Event_Error = +[](PJRT_Event_Error_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Event_Error");
return (PJRT_Error*)EventInstance::Unwrap(args->event)->error();
};
api->PJRT_Event_Await = +[](PJRT_Event_Await_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Event_Await");
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Event_Await"));
};
api->PJRT_Event_OnReady = +[](PJRT_Event_OnReady_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Event_OnReady");
return MakeError(EventInstance::Unwrap(args->event)
->OnReady(args->callback, args->user_arg));
};
}
ErrorInstance* EventInstance::error() {
std::lock_guard<std::mutex> guard(lock_);
if (!iree_status_is_ok(status_)) return new ErrorInstance(status_);
return nullptr;
}
bool EventInstance::is_ready() {
std::lock_guard<std::mutex> guard(lock_);
return is_ready_;
}
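// Registers |callback| to run when the event becomes ready. As implemented
// below: if the event has already been signalled the callback is invoked
// synchronously on the calling thread; otherwise it is queued and later
// invoked from the fence-wait thread in SignalReady(). Illustrative use
// (mirrors PJRT_Event_Destroy above; names are hypothetical):
//   event->OnReady(
//       +[](PJRT_Error* error, void* user_data) {
//         if (error) delete ErrorInstance::FromError(error);
//         // ... consume |user_data| ...
//       },
//       user_data);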
iree_status_t EventInstance::OnReady(PJRT_Event_OnReadyCallback callback,
void* user_arg) {
iree_status_t local_status;
{
std::lock_guard<std::mutex> guard(lock_);
if (!is_ready_) {
pending_callbacks_.push_back({callback, user_arg});
return iree_ok_status();
}
local_status = status_;
}
// Already signalled. Invoke the callback outside of the lock scope.
// Note that the callback may destroy the event - so must only operate on
// locals.
callback(
iree_status_is_ok(local_status)
? nullptr
: (PJRT_Error*)new ErrorInstance(iree_status_clone(local_status)),
user_arg);
return iree_ok_status();
}
void EventInstance::SignalReady(iree_status_t status) {
IREE_TRACE_SCOPE();
iree_status_t local_status;
std::vector<std::pair<PJRT_Event_OnReadyCallback, void*>> local_callbacks;
{
std::lock_guard<std::mutex> guard(lock_);
if (is_ready_) {
return;
}
local_callbacks.swap(pending_callbacks_);
is_ready_ = true;
status_ = status;
local_status = status_;
}
// Trigger callbacks outside of the lock.
// Note that the callback may destroy the event - so must only operate on
// locals.
for (auto& cb : local_callbacks) {
IREE_TRACE_SCOPE_NAMED("PJRT_User_Callback_Invoke");
cb.first(
iree_status_is_ok(local_status)
? nullptr
: (PJRT_Error*)new ErrorInstance(iree_status_clone(local_status)),
cb.second);
}
}
//===----------------------------------------------------------------------===//
// ExecutableImage and LoadedExecutableInstance
//===----------------------------------------------------------------------===//
void ExecutableImage::BindApi(PJRT_Api* api) {
api->PJRT_Executable_Destroy =
+[](PJRT_Executable_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Executable_Destroy");
ExecutableImage::Unwrap(args->executable)->DecRef();
return nullptr;
};
api->PJRT_Executable_Name =
+[](PJRT_Executable_Name_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Executable_Name");
const char* dummy_name = "iree_vmfb";
args->executable_name = dummy_name;
args->executable_name_size = strlen(dummy_name);
return nullptr;
};
api->PJRT_Executable_SizeOfGeneratedCodeInBytes =
+[](PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args)
-> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Executable_SizeOfGeneratedCodeInBytes");
args->size_in_bytes =
ExecutableImage::Unwrap(args->executable)->binary->GetDataSize();
return nullptr;
};
api->PJRT_Executable_NumOutputs =
+[](PJRT_Executable_NumOutputs_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_Executable_NumOutputs");
auto* exec = ExecutableImage::Unwrap(args->executable);
assert(exec->metadata_initialized);
args->num_outputs = exec->result_count;
return nullptr;
};
api->PJRT_Executable_NumPartitions =
+[](PJRT_Executable_NumPartitions_Args* args) -> PJRT_Error* {
// This should be updated once IREE supports partitioning.
args->num_partitions = 1;
return nullptr;
};
api->PJRT_Executable_NumReplicas =
+[](PJRT_Executable_NumReplicas_Args* args) -> PJRT_Error* {
// This should be updated once IREE supports replicas.
args->num_replicas = 1;
return nullptr;
};
api->PJRT_Executable_Serialize =
+[](PJRT_Executable_Serialize_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_Serialize"));
};
api->PJRT_Executable_DeserializeAndLoad =
+[](PJRT_Executable_DeserializeAndLoad_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_DeserializeAndLoad"));
};
api->PJRT_Executable_OptimizedProgram =
+[](PJRT_Executable_OptimizedProgram_Args* args) -> PJRT_Error* {
ExecutableImage* executable = ExecutableImage::Unwrap(args->executable);
PJRT_Program* program = args->program;
program->format = kMlirFormat.data();
program->format_size = kMlirFormat.size();
size_t code_size = executable->code.size();
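// PJRT uses a two-call protocol here: callers first pass
// program->code == nullptr to query the required size via code_size, then
// call again with a buffer of at least that size to receive the code.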
if (program->code == nullptr) {
program->code_size = code_size;
} else {
if (program->code_size < code_size) {
return MakeError(
iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"expected code_size >= %zu, got code_size = %zu",
code_size, program->code_size));
}
std::memcpy(program->code, executable->code.c_str(),
executable->code.size());
}
return nullptr;
};
api->PJRT_Executable_GetCostAnalysis =
+[](PJRT_Executable_GetCostAnalysis_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_GetCostAnalysis"));
};
api->PJRT_Executable_OutputElementTypes =
+[](PJRT_Executable_OutputElementTypes_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_OutputElementTypes"));
};
api->PJRT_Executable_OutputDimensions =
+[](PJRT_Executable_OutputDimensions_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_OutputDimensions"));
};
api->PJRT_Executable_OutputMemoryKinds =
+[](PJRT_Executable_OutputMemoryKinds_Args* args) -> PJRT_Error* {
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Executable_OutputMemoryKinds"));
};
}
void LoadedExecutableInstance::BindApi(PJRT_Api* api) {
api->PJRT_LoadedExecutable_Destroy =
+[](PJRT_LoadedExecutable_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_LoadedExecutable_Destroy");
delete LoadedExecutableInstance::Unwrap(args->executable);
return nullptr;
};
api->PJRT_LoadedExecutable_AddressableDevices =
+[](PJRT_LoadedExecutable_AddressableDevices_Args* args) -> PJRT_Error* {
auto& devices = LoadedExecutableInstance::Unwrap(args->executable)
->addressable_devices();
args->addressable_devices = const_cast<PJRT_Device**>(
reinterpret_cast<PJRT_Device* const*>(devices.data()));
args->num_addressable_devices = devices.size();
return nullptr;
};
api->PJRT_LoadedExecutable_Delete =
+[](PJRT_LoadedExecutable_Delete_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_LoadedExecutable_Delete");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_LoadedExecutable_Delete"));
};
api->PJRT_LoadedExecutable_IsDeleted =
+[](PJRT_LoadedExecutable_IsDeleted_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_LoadedExecutable_IsDeleted");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_LoadedExecutable_IsDeleted"));
};
api->PJRT_LoadedExecutable_Execute =
+[](PJRT_LoadedExecutable_Execute_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_LoadedExecutable_Execute");
return MakeError(
LoadedExecutableInstance::Unwrap(args->executable)->BatchExecute(args));
};
api->PJRT_LoadedExecutable_GetExecutable =
+[](PJRT_LoadedExecutable_GetExecutable_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE_NAMED("PJRT_LoadedExecutable_GetExecutable");
auto* loaded_exe =
LoadedExecutableInstance::Unwrap(args->loaded_executable);
ExecutableImage* image = loaded_exe->image_;
if (!image->metadata_initialized) {
auto status = loaded_exe->GetArgResultCount(&image->arg_count,
&image->result_count);
if (!iree_status_is_ok(status)) {
return MakeError(status);
}
image->metadata_initialized = true;
}
image->AddRef();
args->executable = *image;
return nullptr;
};
}
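// Loads the compiled program onto every addressable device, creating one
// VM context (HAL module + main module) per device and caching the main
// function's arg/result counts. Idempotent: returns immediately if
// resident_executables_ is already populated; any per-device failure
// discards the partially built list and leaves the instance unloaded.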
iree_status_t LoadedExecutableInstance::LoadAll() {
IREE_TRACE_SCOPE();
if (!resident_executables_.empty()) return iree_ok_status();
std::vector<ResidentExecutable> new_list;
for (DeviceInstance* device_instance : addressable_devices_) {
iree_hal_device_t* hal_device;
IREE_RETURN_IF_ERROR(device_instance->GetHalDevice(&hal_device));
new_list.push_back({});
ResidentExecutable& loaded = new_list.back();
loaded.device_instance = device_instance;
// Only dereference image_->binary once to get the underlying
// CompilerOutput (mmap).
auto* binary = image_->binary.get();
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
client_.vm_instance(),
iree_make_const_byte_span(binary->GetData(), binary->GetDataSize()),
/*archive_allocator=*/iree_allocator_null(), client_.host_allocator(),
&loaded.main_module));
// Lookup main function.
const char kNameMain[] = "main";
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
loaded.main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{kNameMain, sizeof(kNameMain) - 1},
&loaded.main_function));
// Record number of args/results.
iree_vm_function_signature_t sig =
iree_vm_function_signature(&loaded.main_function);
IREE_RETURN_IF_ERROR(iree_vm_function_call_count_arguments_and_results(
&sig, &loaded.arg_count, &loaded.result_count));
// Defer to the client to populate the stack of modules.
std::vector<iree::vm::ref<iree_vm_module_t>> modules;
IREE_RETURN_IF_ERROR(
client_.PopulateVMModules(modules, hal_device, loaded.main_module));
std::vector<iree_vm_module_t*> module_ptrs;
module_ptrs.resize(modules.size());
for (size_t i = 0; i < modules.size(); ++i) {
module_ptrs[i] = modules[i].get();
}
IREE_CHECK_OK(iree_vm_context_create_with_modules(
client_.vm_instance(), IREE_VM_CONTEXT_FLAG_NONE, module_ptrs.size(),
module_ptrs.data(), iree_allocator_system(), &loaded.vm_context));
}
new_list.swap(resident_executables_);
return iree_ok_status();
}
iree_status_t LoadedExecutableInstance::GetDefaultResidentExecutable(
ResidentExecutable** out_loaded) {
IREE_RETURN_IF_ERROR(LoadAll());
if (resident_executables_.empty()) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"no executables could be loaded");
}
*out_loaded = &resident_executables_.front();
return iree_ok_status();
}
iree_status_t LoadedExecutableInstance::GetArgResultCount(
iree_host_size_t* out_arg_count, iree_host_size_t* out_result_count) {
ResidentExecutable* loaded;
IREE_RETURN_IF_ERROR(GetDefaultResidentExecutable(&loaded));
*out_arg_count = loaded->arg_count;
*out_result_count = loaded->result_count;
return iree_ok_status();
}
iree_status_t LoadedExecutableInstance::BatchExecute(
PJRT_LoadedExecutable_Execute_Args* args) {
// Early exit for unsupported features and illegal input.
if (args->execute_device) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"executing with a specific device not supported");
}
if (args->num_devices != addressable_devices_.size()) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"incorrect number of devices to execute on (%d vs %d)",
(int)args->num_devices, (int)addressable_devices_.size());
}
// Make sure loaded.
IREE_RETURN_IF_ERROR(LoadAll());
// Timeline setup. There are two timelines that we synchronize to:
// the main execution timeline, which preserves as-called ordering of
// executions, and the transfer timeline of each device.
auto [wait_timepoint, signal_timepoint] = client_.AdvanceTimeline();
// Initialize invocations.
auto allocator = client_.host_allocator();
auto& resident_executables = resident_executables_;
struct Invocation {
ResidentExecutable* res_exe;
iree::vm::ref<iree_vm_list_t> inputs;
iree::vm::ref<iree_vm_list_t> outputs;
iree::vm::ref<iree_hal_fence_t> wait_fence;
iree::vm::ref<iree_hal_fence_t> signal_fence;
};
std::vector<Invocation> invs;
invs.resize(args->num_devices);
for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
auto& inv = invs[dev_index];
inv.res_exe = &resident_executables[dev_index];
// Wait fence initial value.
// We allocate it to be able to hold two semaphores (main timeline and
// transfer timeline) and initialize it with the global invocation order
// of the main timeline. As we process inputs, we will also insert their
// transfer ready semaphore value so that execution can only begin once
// all dependencies are ready. This at most represents two unique
// semaphores.
IREE_RETURN_IF_ERROR(
inv.res_exe->device_instance->CreateFence(&inv.wait_fence));
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_insert(
inv.wait_fence.get(), inv.res_exe->device_instance->main_timeline(),
wait_timepoint));
// Signal fence. This signals the next tick on the main execution
// timeline.
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_create_at(
inv.res_exe->device_instance->main_timeline(), signal_timepoint,
client_.host_allocator(), &inv.signal_fence));
IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
args->num_args, allocator,
&inv.inputs));
IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
inv.res_exe->result_count,
allocator, &inv.outputs));
// Populate inputs.
for (size_t i = 0; i < args->num_args; ++i) {
auto* buffer = BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
iree_vm_ref_t bv_ref =
iree_hal_buffer_view_retain_ref(buffer->buffer_view());
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref));
// Extend the execute wait to include the input's ready signal.
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_extend(inv.wait_fence.get(),
buffer->ready_fence()));
// And extend the buffer's done fence to close over this execution.
IREE_RETURN_IF_ERROR(buffer->AdvanceDoneFence(
inv.res_exe->device_instance->main_timeline(), signal_timepoint));
}
// Add (wait, signal) fences as required by the async-external execution
// model.
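// With IREE's async-external execution model the exported main function
// takes these two fences as its trailing arguments (roughly
// main(args..., wait_fence, signal_fence)): execution may begin once
// |wait_fence| is reached and |signal_fence| is signalled when all results
// are ready.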
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_retain(inv.inputs.get(), inv.wait_fence));
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_retain(inv.inputs.get(), inv.signal_fence));
}
// Issue invocations.
// TODO: Switch to using the async API. I've tried to structure this
// so that we can move to that. Obviously important before we have more
// than one device.
iree_status_t status = iree_ok_status();
for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
auto& inv = invs[dev_index];
if (IreeApi::LOGGING_ENABLED) {
IreeApi::LogInvoke(
"vm_invoke[async]",
"context=%p, f=%d, wait_fence=%p {%s}, signal_fence=%p {%s}",
inv.res_exe->vm_context.get(),
(int)inv.res_exe->main_function.ordinal, inv.wait_fence.get(),
IreeApi::FenceToString(inv.wait_fence.get()).c_str(),
inv.signal_fence.get(),
IreeApi::FenceToString(inv.signal_fence.get()).c_str());
}
auto new_status = IreeApi::HandleStatus(
"vm_invoke[async]",
iree_vm_invoke(inv.res_exe->vm_context.get(),
inv.res_exe->main_function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inv.inputs.get(), inv.outputs.get(),
allocator));
// Any invocation that fails needs a barrier so that its signal fence is
// still incremented; otherwise future waits will fail. We do this instead
// of incrementing directly since only a subset of devices may fail.
if (!iree_status_is_ok(new_status)) {
status = new_status;
// We can ignore the barrier's status since we are already propagating the
// earlier error.
IREE_IGNORE_ERROR(IreeApi::hal_device_queue_barrier(
inv.res_exe->device_instance->device(), IREE_HAL_QUEUE_AFFINITY_ANY,
iree_hal_fence_semaphore_list(inv.wait_fence.get()),
iree_hal_fence_semaphore_list(inv.signal_fence.get())));
}
}
// Process results.
// Early exit before committing things to the client if anything failed.
if (!iree_status_is_ok(status)) return status;
for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
auto& inv = invs[dev_index];
for (size_t i = 0; i < inv.res_exe->result_count; ++i) {
iree::vm::ref<iree_hal_buffer_view_t> ret_buffer_view =
retain_ref((iree_hal_buffer_view_t*)iree_vm_list_get_ref_deref(
inv.outputs.get(), i, iree_hal_buffer_view_type()));
// This should not be possible so just hard-assert.
IREE_ASSERT_ARGUMENT(ret_buffer_view);
auto result_buffer = std::make_unique<BufferInstance>(
*inv.res_exe->device_instance, std::move(ret_buffer_view));
IREE_RETURN_IF_ERROR(result_buffer->AdvanceReadyFence(
inv.res_exe->device_instance->main_timeline(), signal_timepoint));
IREE_RETURN_IF_ERROR(result_buffer->AdvanceDoneFence(
inv.res_exe->device_instance->main_timeline(), signal_timepoint));
args->output_lists[dev_index][i] = *(result_buffer.release());
}
if (args->device_complete_events) {
args->device_complete_events[dev_index] =
*(new EventInstance(retain_ref(inv.wait_fence)));
}
}
return status;
}
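// Installs a default UNIMPLEMENTED stub for every PJRT API entry point
// listed in stubs.inc; the real bindings installed afterwards overwrite
// these stubs. Illustratively, _STUB(PJRT_Foo) (hypothetical name) expands
// to roughly:
//   api->PJRT_Foo = +[](PJRT_Foo_Args* args) -> PJRT_Error* {
//     return MakeError(
//         iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Foo"));
//   };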
static void BindUndefineds(PJRT_Api* api) {
#define _STUB(API) \
api->API = +[](API##_Args* args) -> decltype(api->API(args)) { \
return (decltype(api->API(args)))MakeError( \
iree_make_status(IREE_STATUS_UNIMPLEMENTED, #API)); \
}
#include "stubs.inc"
}
//===----------------------------------------------------------------------===//
// Top-level API binding.
//===----------------------------------------------------------------------===//
void BindMonomorphicApi(PJRT_Api* api) {
api->struct_size = PJRT_Api_STRUCT_SIZE;
api->extension_start = nullptr;
api->pjrt_api_version.major_version = PJRT_API_MAJOR;
api->pjrt_api_version.minor_version = PJRT_API_MINOR;
// Start with bare stubs that return UNIMPLEMENTED errors so that newly
// added API functions fail cleanly instead of segfaulting on invocation.
BindUndefineds(api);
ErrorInstance::BindApi(api);
api->PJRT_Plugin_Initialize =
+[](PJRT_Plugin_Initialize_Args* args) -> PJRT_Error* { return nullptr; };
// Bind by object types.
BufferInstance::BindApi(api);
ClientInstance::BindApi(api);
DeviceDescription::BindApi(api);
DeviceInstance::BindApi(api);
EventInstance::BindApi(api);
ExecutableImage::BindApi(api);
LoadedExecutableInstance::BindApi(api);
}
} // namespace iree::pjrt