blob: 7577619c4012fc069d99c317ffec198554f1c670 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <cinttypes>
#include <cstdarg>
#include <cstdio>
#include <string>

#include "iree/hal/api.h"
namespace iree::pjrt {
// Namespace (with a file-local anonymous inner namespace) containing helpers
// and wrappers for IREE API functions that can log each invocation verbosely
// when enabled. Each wrapper matches an IREE API function with the |iree_|
// prefix elided, so IreeApi::hal_allocator_allocate_buffer(...) is used as a
// drop-in for iree_hal_allocator_allocate_buffer(...). (Exception:
// hal_device_queue_execute takes no binding table and always passes an empty
// one.)
namespace IreeApi {
namespace {
// Compile-time switch controlling whether the wrappers in this namespace print
// each invocation and its resulting status to stderr. Declared constexpr so
// every `if (LOGGING_ENABLED)` test is a constant expression the compiler can
// fold away when logging is disabled. We may want to make this more
// configurable in the future.
constexpr bool LOGGING_ENABLED = false;
// Writes ":: IREE INVOKE (<func>): " followed by the printf-style message
// formatted from |fmt| and the trailing arguments, then flushes stderr.
// No newline is emitted here; the wrappers follow up with HandleStatus, which
// terminates the line. A no-op unless LOGGING_ENABLED is set.
IREE_PRINTF_ATTRIBUTE(2, 3)
void LogInvoke(const char* func, const char* fmt, ...) {
  if (!LOGGING_ENABLED) return;

  fprintf(stderr, ":: IREE INVOKE (%s): ", func);
  va_list varargs;
  va_start(varargs, fmt);
  vfprintf(stderr, fmt, varargs);
  va_end(varargs);
  fflush(stderr);
}
// Finishes the log line begun by LogInvoke: appends " (OK)" for a successful
// |status|, otherwise the status text in parentheses, plus a newline. Always
// returns |status| itself so wrappers can write
// `return HandleStatus(__func__, iree_...(...));`. |func| is currently unused.
iree_status_t HandleStatus(const char* func, iree_status_t status) {
  if (!LOGGING_ENABLED) return status;

  if (iree_status_is_ok(status)) {
    fprintf(stderr, " (OK)\n");
    return status;
  }
  fprintf(stderr, " (");
  iree_status_fprint(stderr, status);
  fprintf(stderr, ")\n");
  return status;
}
// Renders |sl| for log output as a ", "-separated sequence of
// "<semaphore pointer>:<payload value>" entries, one per element of the list.
// The pointer text follows the platform's %p rendering. Returns an empty
// string when sl.count is 0.
std::string SemaphoreListToString(const iree_hal_semaphore_list_t sl) {
  std::string result;
  // Room for a %p pointer (typically "0x" + 16 hex digits on 64-bit hosts),
  // the ':' separator and up to 20 decimal digits of a uint64 payload.
  char fmtBuffer[64];
  for (iree_host_size_t i = 0; i < sl.count; ++i) {
    // %p is specified only for pointer-to-void arguments, so convert the typed
    // semaphore handle explicitly instead of passing it through varargs as-is.
    snprintf(fmtBuffer, sizeof(fmtBuffer), "%p:%" PRIu64,
             static_cast<const void*>(sl.semaphores[i]),
             sl.payload_values[i]);
    if (i > 0) {
      result.append(", ");
    }
    result.append(fmtBuffer);
  }
  return result;
}
// Formats the semaphore/payload pairs reported by
// iree_hal_fence_semaphore_list(|fence|) via SemaphoreListToString.
std::string FenceToString(iree_hal_fence_t* fence) {
  const iree_hal_semaphore_list_t fence_semaphores =
      iree_hal_fence_semaphore_list(fence);
  return SemaphoreListToString(fence_semaphores);
}
// Logging wrapper for iree_hal_allocator_allocate_buffer. The real call runs
// first so the produced buffer handle can be included in the log line; the
// resulting status is returned unchanged through HandleStatus.
// NOTE(review): *out_buffer is read even when the call fails; this assumes the
// callee initializes the out parameter before any early error return — confirm.
iree_status_t hal_allocator_allocate_buffer(
    iree_hal_allocator_t* IREE_RESTRICT allocator,
    iree_hal_buffer_params_t params, iree_device_size_t allocation_size,
    iree_hal_buffer_t** out_buffer) {
  const iree_status_t status = iree_hal_allocator_allocate_buffer(
      allocator, params, allocation_size, out_buffer);
  if (LOGGING_ENABLED) {
    const size_t logged_size = static_cast<size_t>(allocation_size);
    LogInvoke(__func__, "allocator=%p, size=%zu, buffer=%p", allocator,
              logged_size, *out_buffer);
  }
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_allocator_import_buffer. Logs the external
// buffer pointer, then forwards every argument unchanged and returns the
// callee's status through HandleStatus.
iree_status_t hal_allocator_import_buffer(
    iree_hal_allocator_t* IREE_RESTRICT allocator,
    iree_hal_buffer_params_t params,
    iree_hal_external_buffer_t* IREE_RESTRICT external_buffer,
    iree_hal_buffer_release_callback_t release_callback,
    iree_hal_buffer_t** out_buffer) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "external_buffer=%p", external_buffer);
  }
  const iree_status_t status = iree_hal_allocator_import_buffer(
      allocator, params, external_buffer, release_callback, out_buffer);
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_device_queue_alloca. Logs the device, the
// requested allocation size and the wait/signal semaphore lists, then forwards
// every argument unchanged and returns the callee's status via HandleStatus.
iree_status_t hal_device_queue_alloca(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
    iree_device_size_t allocation_size,
    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
  if (LOGGING_ENABLED) {
    // %zu is the conversion for a size_t argument (%zd expects a signed
    // value); this also matches the size format used by the other wrappers.
    LogInvoke(__func__, "device=%p, size=%zu, wait={%s}, signal={%s}", device,
              static_cast<size_t>(allocation_size),
              SemaphoreListToString(wait_semaphore_list).c_str(),
              SemaphoreListToString(signal_semaphore_list).c_str());
  }
  return HandleStatus(__func__, iree_hal_device_queue_alloca(
                                    device, queue_affinity, wait_semaphore_list,
                                    signal_semaphore_list, pool, params,
                                    allocation_size, out_buffer));
}
// Logging wrapper for iree_hal_device_queue_dealloca. Logs the device, the
// buffer being released and the wait/signal semaphore lists, then forwards
// every argument unchanged and returns the callee's status via HandleStatus.
iree_status_t hal_device_queue_dealloca(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* buffer) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, buffer=%p, wait={%s}, signal={%s}", device,
              buffer, waits.c_str(), signals.c_str());
  }
  const iree_status_t status = iree_hal_device_queue_dealloca(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
      buffer);
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_device_queue_barrier. Logs the device and the
// wait/signal semaphore lists, then forwards every argument unchanged and
// returns the callee's status via HandleStatus.
iree_status_t hal_device_queue_barrier(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
              waits.c_str(), signals.c_str());
  }
  const iree_status_t status = iree_hal_device_queue_barrier(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list);
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_device_queue_execute. Logs the device and the
// wait/signal semaphore lists, then submits |command_buffer|. Unlike the
// underlying API, this wrapper takes no binding table: it always passes
// iree_hal_buffer_binding_table_empty(). The callee's status is returned via
// HandleStatus.
iree_status_t hal_device_queue_execute(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_command_buffer_t* command_buffer) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
              waits.c_str(), signals.c_str());
  }
  const iree_status_t status = iree_hal_device_queue_execute(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
      command_buffer, iree_hal_buffer_binding_table_empty());
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_fence_create. Creates the fence first so the new
// handle can be logged alongside the requested capacity, then returns the
// status via HandleStatus.
// NOTE(review): *out_fence is read even on failure; assumes the callee
// initializes it before any early error return — confirm.
iree_status_t hal_fence_create(iree_host_size_t capacity,
                               iree_allocator_t host_allocator,
                               iree_hal_fence_t** out_fence) {
  const iree_status_t status =
      iree_hal_fence_create(capacity, host_allocator, out_fence);
  if (LOGGING_ENABLED) {
    const size_t logged_capacity = static_cast<size_t>(capacity);
    LogInvoke(__func__, "capacity=%zu, fence=%p", logged_capacity, *out_fence);
  }
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_fence_create_at. Creates the fence first so the
// new handle can be logged together with the semaphore and payload value it
// was created at, then returns the status via HandleStatus.
// NOTE(review): *out_fence is read even on failure; assumes the callee
// initializes it before any early error return — confirm.
iree_status_t hal_fence_create_at(iree_hal_semaphore_t* semaphore,
                                  uint64_t value,
                                  iree_allocator_t host_allocator,
                                  iree_hal_fence_t** out_fence) {
  const iree_status_t status =
      iree_hal_fence_create_at(semaphore, value, host_allocator, out_fence);
  if (LOGGING_ENABLED) {
    iree_hal_fence_t* const created_fence = *out_fence;
    LogInvoke(__func__, "semaphore=%p, value=%" PRIu64 ", fence=%p", semaphore,
              value, created_fence);
  }
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_fence_extend. Logs both fence handles, then
// forwards the call and returns its status via HandleStatus.
iree_status_t hal_fence_extend(iree_hal_fence_t* into_fence,
                               iree_hal_fence_t* from_fence) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "into_fence=%p, from_fence=%p", into_fence, from_fence);
  }
  const iree_status_t status = iree_hal_fence_extend(into_fence, from_fence);
  return HandleStatus(__func__, status);
}
// Logging wrapper for iree_hal_fence_insert. Logs the fence, the semaphore and
// the payload value being inserted, then forwards the call and returns its
// status via HandleStatus.
iree_status_t hal_fence_insert(iree_hal_fence_t* fence,
                               iree_hal_semaphore_t* semaphore,
                               uint64_t value) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "fence=%p, semaphore=%p, value=%" PRIu64, fence,
              semaphore, value);
  }
  const iree_status_t status = iree_hal_fence_insert(fence, semaphore, value);
  return HandleStatus(__func__, status);
}
} // namespace
} // namespace IreeApi
} // namespace iree::pjrt