// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/cuda2/event_semaphore.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_headers.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/timepoint_pool.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/utils/semaphore_base.h"
typedef struct iree_hal_cuda2_semaphore_t {
// Abstract resource used for injecting reference counting and vtable;
// must be at offset 0.
iree_hal_semaphore_t base;
// The allocator used to create this semaphore.
iree_allocator_t host_allocator;
// The symbols used to issue CUDA API calls.
const iree_hal_cuda2_dynamic_symbols_t* symbols;
// The timepoint pool to acquire timepoint objects.
iree_hal_cuda2_timepoint_pool_t* timepoint_pool;
  // The list of pending queue actions that this semaphore needs to advance
  // upon newly signaled values.
iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;
// Guards value and status. We expect low contention on semaphores and since
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
// than trying to make the entire structure lock-free.
iree_slim_mutex_t mutex;
// Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to
// indicate that the semaphore has been signaled for failure and
// |failure_status| contains the error.
uint64_t current_value IREE_GUARDED_BY(mutex);
// OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore.
iree_status_t failure_status IREE_GUARDED_BY(mutex);
} iree_hal_cuda2_semaphore_t;
static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable;
static iree_hal_cuda2_semaphore_t* iree_hal_cuda2_semaphore_cast(
iree_hal_semaphore_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_semaphore_vtable);
return (iree_hal_cuda2_semaphore_t*)base_value;
}
iree_status_t iree_hal_cuda2_event_semaphore_create(
uint64_t initial_value, const iree_hal_cuda2_dynamic_symbols_t* symbols,
iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions,
iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(timepoint_pool);
IREE_ASSERT_ARGUMENT(pending_queue_actions);
IREE_ASSERT_ARGUMENT(out_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda2_semaphore_t* semaphore = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
(void**)&semaphore));
iree_hal_semaphore_initialize(&iree_hal_cuda2_semaphore_vtable,
&semaphore->base);
semaphore->host_allocator = host_allocator;
semaphore->symbols = symbols;
semaphore->timepoint_pool = timepoint_pool;
semaphore->pending_queue_actions = pending_queue_actions;
iree_slim_mutex_initialize(&semaphore->mutex);
semaphore->current_value = initial_value;
semaphore->failure_status = iree_ok_status();
*out_semaphore = &semaphore->base;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
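// Illustrative sketch (not part of the implementation): creating a semaphore
// through this factory and dropping the caller's reference when done. The
// |symbols|, |timepoint_pool|, and |pending_queue_actions| values are assumed
// to be owned by the device and to outlive the semaphore.
static IREE_ATTRIBUTE_UNUSED iree_status_t
iree_hal_cuda2_event_semaphore_create_example(
    const iree_hal_cuda2_dynamic_symbols_t* symbols,
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions,
    iree_allocator_t host_allocator) {
  iree_hal_semaphore_t* semaphore = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_cuda2_event_semaphore_create(
      /*initial_value=*/0ull, symbols, timepoint_pool, pending_queue_actions,
      host_allocator, &semaphore));
  // ... signal/wait via the iree_hal_semaphore_* APIs ...
  iree_hal_semaphore_release(semaphore);  // frees via the vtable destroy
  return iree_ok_status();
}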
static void iree_hal_cuda2_semaphore_destroy(
iree_hal_semaphore_t* base_semaphore) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
iree_allocator_t host_allocator = semaphore->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_ignore(semaphore->failure_status);
iree_slim_mutex_deinitialize(&semaphore->mutex);
iree_hal_semaphore_deinitialize(&semaphore->base);
iree_allocator_free(host_allocator, semaphore);
IREE_TRACE_ZONE_END(z0);
}
static iree_status_t iree_hal_cuda2_semaphore_query(
iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);
iree_slim_mutex_lock(&semaphore->mutex);
*out_value = semaphore->current_value;
iree_status_t status = iree_ok_status();
if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
status = iree_status_clone(semaphore->failure_status);
}
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return status;
}
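// Illustrative caller-side sketch of the failure encoding above: a payload at
// or beyond IREE_HAL_SEMAPHORE_FAILURE_VALUE means the returned status is a
// clone of the stored failure that the caller now owns.
static IREE_ATTRIBUTE_UNUSED void iree_hal_cuda2_semaphore_query_example(
    iree_hal_semaphore_t* semaphore) {
  uint64_t value = 0;
  iree_status_t status = iree_hal_semaphore_query(semaphore, &value);
  if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
    // The timeline failed; |status| carries a clone of the failure details.
  }
  iree_status_ignore(status);  // a real caller would propagate or log this
}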
static iree_status_t iree_hal_cuda2_semaphore_signal(
iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);
iree_slim_mutex_lock(&semaphore->mutex);
if (new_value <= semaphore->current_value) {
uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value;
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"semaphore values must be monotonically "
"increasing; current_value=%" PRIu64
", new_value=%" PRIu64,
current_value, new_value);
}
semaphore->current_value = new_value;
iree_slim_mutex_unlock(&semaphore->mutex);
// Notify timepoints - note that this must happen outside the lock.
iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK);
// Advance the pending queue actions if possible. This also must happen
// outside the lock to avoid nesting.
iree_status_t status = iree_hal_cuda2_pending_queue_actions_issue(
semaphore->pending_queue_actions);
IREE_TRACE_ZONE_END(z0);
return status;
}
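// Illustrative sketch of the monotonicity contract enforced above, assuming a
// fresh semaphore created with initial value 0: payloads must strictly
// increase, so signaling a value that is not greater than the current one
// fails with IREE_STATUS_OUT_OF_RANGE.
static IREE_ATTRIBUTE_UNUSED void iree_hal_cuda2_semaphore_signal_example(
    iree_hal_semaphore_t* semaphore) {
  IREE_IGNORE_ERROR(iree_hal_semaphore_signal(semaphore, 1ull));  // OK: 0 -> 1
  IREE_IGNORE_ERROR(iree_hal_semaphore_signal(semaphore, 2ull));  // OK: 1 -> 2
  // Not monotonically increasing: returns IREE_STATUS_OUT_OF_RANGE.
  iree_status_t status = iree_hal_semaphore_signal(semaphore, 2ull);
  iree_status_ignore(status);
}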
static void iree_hal_cuda2_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
iree_status_t status) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);
const iree_status_code_t status_code = iree_status_code(status);
iree_slim_mutex_lock(&semaphore->mutex);
// Try to set our local status - we only preserve the first failure so only
// do this if we are going from a valid semaphore to a failed one.
if (!iree_status_is_ok(semaphore->failure_status)) {
// Previous status was not OK; drop our new status.
IREE_IGNORE_ERROR(status);
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return;
}
  // Advance the timeline to our failure sentinel value.
semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE;
semaphore->failure_status = status;
iree_slim_mutex_unlock(&semaphore->mutex);
// Notify timepoints - note that this must happen outside the lock.
iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
status_code);
IREE_TRACE_ZONE_END(z0);
}
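// Illustrative sketch of the first-failure-wins behavior above: the semaphore
// takes ownership of the first failing status and silently drops later ones,
// so callers do not free the statuses they pass in.
static IREE_ATTRIBUTE_UNUSED void iree_hal_cuda2_semaphore_fail_example(
    iree_hal_semaphore_t* semaphore) {
  // Recorded: the timeline jumps to IREE_HAL_SEMAPHORE_FAILURE_VALUE.
  iree_hal_semaphore_fail(semaphore,
                          iree_make_status(IREE_STATUS_INTERNAL, "first"));
  // Dropped internally: a failure status is already set.
  iree_hal_semaphore_fail(semaphore,
                          iree_make_status(IREE_STATUS_INTERNAL, "second"));
}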
// Handles host wait timepoints on the host when the |semaphore| timeline
// advances past the given |value|.
//
// Note that this callback is invoked by a host thread.
static iree_status_t iree_hal_cuda2_semaphore_timepoint_host_wait_callback(
void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
iree_status_code_t status_code) {
iree_hal_cuda2_timepoint_t* timepoint =
(iree_hal_cuda2_timepoint_t*)user_data;
iree_event_set(&timepoint->timepoint.host_wait);
return iree_ok_status();
}
// Acquires a timepoint for waiting on the host until the timeline reaches at
// least the given |min_value|.
static iree_status_t iree_hal_cuda2_semaphore_acquire_timepoint_host_wait(
iree_hal_cuda2_semaphore_t* semaphore, uint64_t min_value,
iree_timeout_t timeout, iree_hal_cuda2_timepoint_t** out_timepoint) {
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_timepoint_pool_acquire_host_wait(
semaphore->timepoint_pool, 1, out_timepoint));
// Initialize the timepoint with the value and callback, and connect it to
// this semaphore.
iree_hal_semaphore_acquire_timepoint(
&semaphore->base, min_value, timeout,
(iree_hal_semaphore_callback_t){
.fn = iree_hal_cuda2_semaphore_timepoint_host_wait_callback,
.user_data = *out_timepoint,
},
&(*out_timepoint)->base);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_semaphore_wait(
iree_hal_semaphore_t* base_semaphore, uint64_t value,
iree_timeout_t timeout) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);
iree_slim_mutex_lock(&semaphore->mutex);
if (!iree_status_is_ok(semaphore->failure_status)) {
// Fastest path: failed; return an error to tell callers to query for it.
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return iree_status_from_code(IREE_STATUS_ABORTED);
}
if (semaphore->current_value >= value) {
// Fast path: already satisfied.
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
if (iree_timeout_is_immediate(timeout)) {
// Not satisfied but a poll, so can avoid the expensive wait handle work.
iree_slim_mutex_unlock(&semaphore->mutex);
IREE_TRACE_ZONE_END(z0);
return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
}
iree_slim_mutex_unlock(&semaphore->mutex);
iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
  // Slow path: acquire a timepoint. This should happen outside of the lock
  // given that acquiring has its own internal locks.
iree_hal_cuda2_timepoint_t* timepoint = NULL;
iree_status_t status = iree_hal_cuda2_semaphore_acquire_timepoint_host_wait(
semaphore, value, timeout, &timepoint);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
IREE_TRACE_ZONE_END(z0);
return status;
}
// Wait until the timepoint resolves.
  // If the wait is satisfied the timepoint is automatically cleaned up and we
  // are done. If the deadline is reached first then we must clean it up.
status = iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns);
if (!iree_status_is_ok(status)) {
iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base);
}
iree_hal_cuda2_timepoint_pool_release(semaphore->timepoint_pool, 1,
&timepoint);
IREE_TRACE_ZONE_END(z0);
return status;
}
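// Illustrative sketch of the two wait paths above: an immediate timeout is a
// pure poll that fails with IREE_STATUS_DEADLINE_EXCEEDED when unsatisfied,
// while a finite timeout goes through a pooled host-wait timepoint.
static IREE_ATTRIBUTE_UNUSED void iree_hal_cuda2_semaphore_wait_example(
    iree_hal_semaphore_t* semaphore) {
  // Non-blocking poll for payload >= 1.
  iree_status_t status =
      iree_hal_semaphore_wait(semaphore, 1ull, iree_immediate_timeout());
  if (iree_status_is_deadline_exceeded(status)) {
    iree_status_ignore(status);
    // Blocking wait with a 1s budget; times out the same way if unsatisfied.
    status = iree_hal_semaphore_wait(semaphore, 1ull,
                                     iree_make_timeout_ms(1000));
  }
  iree_status_ignore(status);
}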
// Handles device signal timepoints on the host when the |semaphore| timeline
// advances past the given |value|.
//
// Note that this callback is invoked by a host thread after the CUDA host
// function callback is triggered in the CUDA driver.
static iree_status_t iree_hal_cuda2_semaphore_timepoint_device_signal_callback(
void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
iree_status_code_t status_code) {
iree_hal_cuda2_timepoint_t* timepoint =
(iree_hal_cuda2_timepoint_t*)user_data;
// Just release the timepoint back to the pool. This will decrease the
// reference count of the underlying CUDA event internally.
iree_hal_cuda2_timepoint_pool_release(timepoint->pool, 1, &timepoint);
return iree_ok_status();
}
// Acquires a timepoint to signal the timeline to the given |to_value| from the
// device.
iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_signal(
iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
CUevent* out_event) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
iree_hal_cuda2_timepoint_t* signal_timepoint = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_timepoint_pool_acquire_device_signal(
semaphore->timepoint_pool, 1, &signal_timepoint));
// Initialize the timepoint with the value and callback, and connect it to
// this semaphore.
iree_hal_semaphore_acquire_timepoint(
&semaphore->base, to_value, iree_infinite_timeout(),
(iree_hal_semaphore_callback_t){
.fn = iree_hal_cuda2_semaphore_timepoint_device_signal_callback,
.user_data = signal_timepoint,
},
&signal_timepoint->base);
iree_hal_cuda2_event_t* event = signal_timepoint->timepoint.device_signal;
// Scan through the timepoint list and update device wait timepoints to wait
// for this device signal when possible. We need to lock with the timepoint
// list mutex here.
iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
tp != NULL; tp = tp->next) {
iree_hal_cuda2_timepoint_t* wait_timepoint =
(iree_hal_cuda2_timepoint_t*)tp;
if (wait_timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT &&
wait_timepoint->timepoint.device_wait == NULL &&
wait_timepoint->base.minimum_value <= to_value) {
iree_hal_cuda2_event_retain(event);
wait_timepoint->timepoint.device_wait = event;
}
}
iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
*out_event = iree_hal_cuda2_event_handle(event);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
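// Illustrative sketch of how a queue submission might use the helper above,
// assuming a caller-owned |stream| and the IREE_CUDA_RETURN_IF_ERROR
// convention from the other cuda2 sources. Recording the acquired event makes
// its completion (and thus the semaphore advancing to |to_value|) follow all
// work previously queued on the stream.
static IREE_ATTRIBUTE_UNUSED iree_status_t
iree_hal_cuda2_device_signal_example(
    iree_hal_semaphore_t* base_semaphore,
    const iree_hal_cuda2_dynamic_symbols_t* symbols, CUstream stream,
    uint64_t to_value) {
  CUevent event = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_cuda2_event_semaphore_acquire_timepoint_device_signal(
          base_semaphore, to_value, &event));
  // When the event fires a CUDA host function callback eventually signals the
  // semaphore to |to_value| and releases the timepoint back to its pool.
  IREE_CUDA_RETURN_IF_ERROR(symbols, cuEventRecord(event, stream),
                            "cuEventRecord");
  return iree_ok_status();
}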
// Handles device wait timepoints on the host when the |semaphore| timeline
// advances past the given |value|.
//
// Note that this callback is invoked by a host thread.
static iree_status_t iree_hal_cuda2_semaphore_timepoint_device_wait_callback(
void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
iree_status_code_t status_code) {
iree_hal_cuda2_timepoint_t* timepoint =
(iree_hal_cuda2_timepoint_t*)user_data;
// Just release the timepoint back to the pool. This will decrease the
// reference count of the underlying CUDA event internally.
iree_hal_cuda2_timepoint_pool_release(timepoint->pool, 1, &timepoint);
return iree_ok_status();
}
// Acquires a timepoint for waiting on the device until the timeline reaches at
// least the given |min_value|.
iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_wait(
iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
CUevent* out_event) {
iree_hal_cuda2_semaphore_t* semaphore =
iree_hal_cuda2_semaphore_cast(base_semaphore);
iree_hal_cuda2_timepoint_t* wait_timepoint = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_timepoint_pool_acquire_device_wait(
semaphore->timepoint_pool, 1, &wait_timepoint));
// Initialize the timepoint with the value and callback, and connect it to
// this semaphore.
iree_hal_semaphore_acquire_timepoint(
&semaphore->base, min_value, iree_infinite_timeout(),
(iree_hal_semaphore_callback_t){
.fn = iree_hal_cuda2_semaphore_timepoint_device_wait_callback,
.user_data = wait_timepoint,
},
&wait_timepoint->base);
  // Scan through the timepoint list and try to find a device signal event to
  // wait on. We need to lock with the timepoint list mutex here.
iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
tp != NULL; tp = tp->next) {
iree_hal_cuda2_timepoint_t* signal_timepoint =
(iree_hal_cuda2_timepoint_t*)tp;
if (signal_timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL &&
signal_timepoint->base.minimum_value >= min_value) {
      // We've found an existing signal timepoint to wait on, so we no longer
      // need a standalone wait event. Release the wait timepoint's own event
      // back to the pool before overwriting it with the retained signal event.
iree_hal_cuda2_event_release(wait_timepoint->timepoint.device_wait);
iree_hal_cuda2_event_t* event = signal_timepoint->timepoint.device_signal;
iree_hal_cuda2_event_retain(event);
wait_timepoint->timepoint.device_wait = event;
}
}
iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
*out_event =
iree_hal_cuda2_event_handle(wait_timepoint->timepoint.device_wait);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
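// Illustrative sketch of the device-wait counterpart, again assuming a
// caller-owned |stream| and the IREE_CUDA_RETURN_IF_ERROR convention: work
// queued on the stream after this call is ordered behind the event recorded
// by a matching device signal to at least |min_value|.
static IREE_ATTRIBUTE_UNUSED iree_status_t
iree_hal_cuda2_device_wait_example(
    iree_hal_semaphore_t* base_semaphore,
    const iree_hal_cuda2_dynamic_symbols_t* symbols, CUstream stream,
    uint64_t min_value) {
  CUevent event = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_cuda2_event_semaphore_acquire_timepoint_device_wait(
          base_semaphore, min_value, &event));
  IREE_CUDA_RETURN_IF_ERROR(
      symbols, cuStreamWaitEvent(stream, event, CU_EVENT_WAIT_DEFAULT),
      "cuStreamWaitEvent");
  return iree_ok_status();
}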
static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable = {
.destroy = iree_hal_cuda2_semaphore_destroy,
.query = iree_hal_cuda2_semaphore_query,
.signal = iree_hal_cuda2_semaphore_signal,
.fail = iree_hal_cuda2_semaphore_fail,
.wait = iree_hal_cuda2_semaphore_wait,
};