// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "experimental/cuda2/timepoint_pool.h"

#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/event_pool.h"
#include "iree/base/api.h"
#include "iree/base/internal/atomics.h"
#include "iree/base/internal/event_pool.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/semaphore_base.h"

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_timepoint_t
//===----------------------------------------------------------------------===//

static iree_status_t iree_hal_cuda2_timepoint_allocate(
    iree_hal_cuda2_timepoint_pool_t* pool, iree_allocator_t host_allocator,
    iree_hal_cuda2_timepoint_t** out_timepoint) {
  IREE_ASSERT_ARGUMENT(pool);
  IREE_ASSERT_ARGUMENT(out_timepoint);
  *out_timepoint = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_cuda2_timepoint_t* timepoint = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, sizeof(*timepoint),
                                (void**)&timepoint));
  // iree_allocator_malloc zeros out the whole struct.
  timepoint->host_allocator = host_allocator;
  timepoint->pool = pool;

  *out_timepoint = timepoint;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

// Clears all data fields in the given |timepoint| except the original host
// allocator and owning pool.
static void iree_hal_cuda2_timepoint_clear(
    iree_hal_cuda2_timepoint_t* timepoint) {
  iree_allocator_t host_allocator = timepoint->host_allocator;
  iree_hal_cuda2_timepoint_pool_t* pool = timepoint->pool;
  memset(timepoint, 0, sizeof(*timepoint));
  timepoint->host_allocator = host_allocator;
  timepoint->pool = pool;
}

static void iree_hal_cuda2_timepoint_free(
    iree_hal_cuda2_timepoint_t* timepoint) {
  iree_allocator_t host_allocator = timepoint->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_ASSERT(timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_NONE);
  iree_allocator_free(host_allocator, timepoint);
  IREE_TRACE_ZONE_END(z0);
}

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_timepoint_pool_t
//===----------------------------------------------------------------------===//

struct iree_hal_cuda2_timepoint_pool_t {
  // The allocator used to create the timepoint pool.
  iree_allocator_t host_allocator;

  // The pool to acquire host events.
  iree_event_pool_t* host_event_pool;
  // The pool to acquire device events. Internally synchronized.
  iree_hal_cuda2_event_pool_t* device_event_pool;
  // Note that the above pools are internally synchronized, so we don't (and
  // shouldn't) use the following mutex to guard access to them.

  // Guards timepoint related fields in this pool. We don't expect a performant
  // program to frequently allocate timepoints for synchronization purposes;
  // traffic to this pool should be low, so it should be fine to use a mutex
  // here.
  iree_slim_mutex_t timepoint_mutex;

  // Maximum number of timepoint objects that will be maintained in the pool.
  // More timepoints may be allocated at any time, but they will be disposed of
  // directly when they are no longer needed.
  iree_host_size_t available_capacity IREE_GUARDED_BY(timepoint_mutex);
  // Total number of currently available timepoint objects.
  iree_host_size_t available_count IREE_GUARDED_BY(timepoint_mutex);
  // The list of available_count timepoint objects.
  iree_hal_cuda2_timepoint_t* available_list[] IREE_GUARDED_BY(timepoint_mutex);
};
// + Additional inline allocation for holding timepoints up to the capacity.
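//
// A minimal layout sketch of the flexible array member above (sizes are
// illustrative): with available_capacity = 4 the pool is one allocation of
//   [ struct fields | 4 x iree_hal_cuda2_timepoint_t* slots ]
// which is why iree_hal_cuda2_timepoint_pool_allocate() below computes the
// total size as
//   sizeof(*timepoint_pool) +
//       available_capacity * sizeof(*timepoint_pool->available_list).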

iree_status_t iree_hal_cuda2_timepoint_pool_allocate(
    iree_event_pool_t* host_event_pool,
    iree_hal_cuda2_event_pool_t* device_event_pool,
    iree_host_size_t available_capacity, iree_allocator_t host_allocator,
    iree_hal_cuda2_timepoint_pool_t** out_timepoint_pool) {
  IREE_ASSERT_ARGUMENT(host_event_pool);
  IREE_ASSERT_ARGUMENT(device_event_pool);
  IREE_ASSERT_ARGUMENT(out_timepoint_pool);
  *out_timepoint_pool = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
  iree_host_size_t total_size =
      sizeof(*timepoint_pool) +
      available_capacity * sizeof(*timepoint_pool->available_list);
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, total_size,
                                (void**)&timepoint_pool));
  timepoint_pool->host_allocator = host_allocator;
  timepoint_pool->host_event_pool = host_event_pool;
  timepoint_pool->device_event_pool = device_event_pool;

  iree_slim_mutex_initialize(&timepoint_pool->timepoint_mutex);
  timepoint_pool->available_capacity = available_capacity;
  timepoint_pool->available_count = 0;

  iree_status_t status = iree_ok_status();
  for (iree_host_size_t i = 0; i < available_capacity; ++i) {
    status = iree_hal_cuda2_timepoint_allocate(
        timepoint_pool, host_allocator,
        &timepoint_pool->available_list[timepoint_pool->available_count++]);
    if (!iree_status_is_ok(status)) break;
  }

  if (iree_status_is_ok(status)) {
    *out_timepoint_pool = timepoint_pool;
  } else {
    iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}
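
// Example usage (sketch): creating and destroying the pool from a device
// implementation. The event-pool creation steps and the capacity of 32 are
// illustrative assumptions; the actual call sites live in the cuda2 device
// code.
//   iree_event_pool_t* host_event_pool = NULL;  // via iree_event_pool_allocate()
//   iree_hal_cuda2_event_pool_t* device_event_pool = NULL;  // via the CUDA event pool
//   iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_allocate(
//       host_event_pool, device_event_pool, /*available_capacity=*/32,
//       host_allocator, &timepoint_pool));
//   // ... acquire/release timepoints ...
//   iree_hal_cuda2_timepoint_pool_free(timepoint_pool);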

void iree_hal_cuda2_timepoint_pool_free(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool) {
  iree_allocator_t host_allocator = timepoint_pool->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  for (iree_host_size_t i = 0; i < timepoint_pool->available_count; ++i) {
    iree_hal_cuda2_timepoint_free(timepoint_pool->available_list[i]);
  }
  iree_slim_mutex_deinitialize(&timepoint_pool->timepoint_mutex);
  iree_allocator_free(host_allocator, timepoint_pool);

  IREE_TRACE_ZONE_END(z0);
}

// Acquires |timepoint_count| timepoints from the given |timepoint_pool|.
// The caller must further initialize the returned |out_timepoints| with the
// proper kind and payload values.
static iree_status_t iree_hal_cuda2_timepoint_pool_acquire_internal(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_ASSERT_ARGUMENT(timepoint_pool);
  if (!timepoint_count) return iree_ok_status();
  IREE_ASSERT_ARGUMENT(out_timepoints);
  IREE_TRACE_ZONE_BEGIN(z0);

  // We'll try to get what we can from the pool and fall back to initializing
  // new iree_hal_cuda2_timepoint_t objects.
  iree_host_size_t remaining_count = timepoint_count;

  // Try first to grab from the pool.
  iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
  iree_host_size_t from_pool_count =
      iree_min(timepoint_pool->available_count, timepoint_count);
  if (from_pool_count > 0) {
    iree_host_size_t pool_base_index =
        timepoint_pool->available_count - from_pool_count;
    memcpy(out_timepoints, &timepoint_pool->available_list[pool_base_index],
           from_pool_count * sizeof(*timepoint_pool->available_list));
    timepoint_pool->available_count -= from_pool_count;
    remaining_count -= from_pool_count;
  }
  iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);

  // Allocate the rest of the timepoints.
  if (remaining_count > 0) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-acquire");
    iree_status_t status = iree_ok_status();
    for (iree_host_size_t i = 0; i < remaining_count; ++i) {
      status = iree_hal_cuda2_timepoint_allocate(
          timepoint_pool, timepoint_pool->host_allocator,
          &out_timepoints[from_pool_count + i]);
      if (!iree_status_is_ok(status)) {
        // Must release all timepoints we've acquired so far.
        iree_hal_cuda2_timepoint_pool_release(
            timepoint_pool, from_pool_count + i, out_timepoints);
        IREE_TRACE_ZONE_END(z1);
        IREE_TRACE_ZONE_END(z0);
        return status;
      }
    }
    IREE_TRACE_ZONE_END(z1);
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
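
// Worked example for the acquire logic above (illustrative numbers): with
// available_count = 3 and timepoint_count = 2, from_pool_count = 2 and
// pool_base_index = 1, so available_list[1] and available_list[2] are handed
// out, available_count drops to 1, and remaining_count = 0 means no fresh
// allocations are needed.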

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_host_wait(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire host events to wrap into the timepoints. This should happen
  // before acquiring the timepoints to avoid nested locks.
  iree_event_t* host_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.host_wait));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_event_pool_acquire(timepoint_pool->host_event_pool,
                                  timepoint_count, host_events));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
              timepoint_pool, timepoint_count, out_timepoints));
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT;
    out_timepoints[i]->timepoint.host_wait = host_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
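
// Example usage (sketch): acquiring one host-wait timepoint. The signaling and
// waiting steps in the comments are assumptions about the caller (typically
// the semaphore implementation), not behavior defined in this file.
//   iree_hal_cuda2_timepoint_t* timepoint = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_acquire_host_wait(
//       timepoint_pool, 1, &timepoint));
//   // ... a waiter blocks on timepoint->timepoint.host_wait while a signaling
//   // thread eventually sets it via iree_event_set() ...
//   iree_hal_cuda2_timepoint_pool_release(timepoint_pool, 1, &timepoint);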

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_signal(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire device events to wrap into the timepoints. This should happen
  // before acquiring the timepoints to avoid nested locks.
  iree_hal_cuda2_event_t** device_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.device_signal));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
                                            timepoint_count, device_events));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
              timepoint_pool, timepoint_count, out_timepoints));
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL;
    out_timepoints[i]->timepoint.device_signal = device_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
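
// Example usage (sketch): a device-signal timepoint wraps a refcounted CUDA
// event that the caller records on a stream. |symbols| and |stream| are
// assumed to come from the device; iree_hal_cuda2_event_handle() and
// IREE_CUDA_RETURN_IF_ERROR() are assumed to match their declarations in
// event_pool.h and cuda_status_util.h.
//   iree_hal_cuda2_timepoint_t* timepoint = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_acquire_device_signal(
//       timepoint_pool, 1, &timepoint));
//   IREE_CUDA_RETURN_IF_ERROR(
//       symbols,
//       cuEventRecord(
//           iree_hal_cuda2_event_handle(timepoint->timepoint.device_signal),
//           stream),
//       "cuEventRecord");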

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_wait(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire device events to wrap into the timepoints. This should happen
  // before acquiring the timepoints to avoid nested locks.
  iree_hal_cuda2_event_t** device_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.device_wait));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
                                            timepoint_count, device_events));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
              timepoint_pool, timepoint_count, out_timepoints));
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT;
    out_timepoints[i]->timepoint.device_wait = device_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

void iree_hal_cuda2_timepoint_pool_release(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** timepoints) {
  IREE_ASSERT_ARGUMENT(timepoint_pool);
  if (!timepoint_count) return;
  IREE_ASSERT_ARGUMENT(timepoints);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Release the wrapped host/device events. This should happen before
  // acquiring the timepoint pool's lock given that the host/device event
  // pools have their own internal locks too.
  // TODO: Release in batch to avoid lock overhead from separate calls.
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    switch (timepoints[i]->kind) {
      case IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT:
        iree_event_pool_release(timepoint_pool->host_event_pool, 1,
                                &timepoints[i]->timepoint.host_wait);
        break;
      case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL:
        iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_signal);
        break;
      case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT:
        iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_wait);
        break;
      default:
        break;
    }
  }

  // We'll try to release all we can back to the pool and then deinitialize
  // the ones that won't fit.
  iree_host_size_t remaining_count = timepoint_count;

  // Try first to release to the pool.
  iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
  iree_host_size_t to_pool_count = iree_min(
      timepoint_pool->available_capacity - timepoint_pool->available_count,
      timepoint_count);
  if (to_pool_count > 0) {
    for (iree_host_size_t i = 0; i < to_pool_count; ++i) {
      iree_hal_cuda2_timepoint_clear(timepoints[i]);
    }
    iree_host_size_t pool_base_index = timepoint_pool->available_count;
    memcpy(&timepoint_pool->available_list[pool_base_index], timepoints,
           to_pool_count * sizeof(*timepoint_pool->available_list));
    timepoint_pool->available_count += to_pool_count;
    remaining_count -= to_pool_count;
  }
  iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);

  // Deallocate the rest of the timepoints. We still clear each one first so
  // its kind is reset to IREE_HAL_CUDA_TIMEPOINT_KIND_NONE, which freeing
  // asserts.
  if (remaining_count > 0) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-release");
    for (iree_host_size_t i = 0; i < remaining_count; ++i) {
      iree_hal_cuda2_timepoint_clear(timepoints[to_pool_count + i]);
      iree_hal_cuda2_timepoint_free(timepoints[to_pool_count + i]);
    }
    IREE_TRACE_ZONE_END(z1);
  }

  IREE_TRACE_ZONE_END(z0);
}