// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "experimental/cuda2/timepoint_pool.h"

#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/event_pool.h"
#include "iree/base/api.h"
#include "iree/base/internal/atomics.h"
#include "iree/base/internal/event_pool.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/semaphore_base.h"

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_timepoint_t
//===----------------------------------------------------------------------===//

static iree_status_t iree_hal_cuda2_timepoint_allocate(
    iree_hal_cuda2_timepoint_pool_t* pool, iree_allocator_t host_allocator,
    iree_hal_cuda2_timepoint_t** out_timepoint) {
  IREE_ASSERT_ARGUMENT(pool);
  IREE_ASSERT_ARGUMENT(out_timepoint);
  *out_timepoint = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_cuda2_timepoint_t* timepoint = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, sizeof(*timepoint),
                                (void**)&timepoint));
  // iree_allocator_malloc zeros out the whole struct.
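  // That also initializes timepoint->kind to the zero value,
  // IREE_HAL_CUDA_TIMEPOINT_KIND_NONE, which iree_hal_cuda2_timepoint_free()
  // asserts on before deallocation.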
  timepoint->host_allocator = host_allocator;
  timepoint->pool = pool;

  *out_timepoint = timepoint;

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

// Clears all data fields in the given |timepoint| except the original host
// allocator and owning pool.
static void iree_hal_cuda2_timepoint_clear(
    iree_hal_cuda2_timepoint_t* timepoint) {
  iree_allocator_t host_allocator = timepoint->host_allocator;
  iree_hal_cuda2_timepoint_pool_t* pool = timepoint->pool;
  memset(timepoint, 0, sizeof(*timepoint));
  timepoint->host_allocator = host_allocator;
  timepoint->pool = pool;
}

static void iree_hal_cuda2_timepoint_free(
    iree_hal_cuda2_timepoint_t* timepoint) {
  iree_allocator_t host_allocator = timepoint->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_ASSERT(timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_NONE);
  iree_allocator_free(host_allocator, timepoint);

  IREE_TRACE_ZONE_END(z0);
}

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_timepoint_pool_t
//===----------------------------------------------------------------------===//

struct iree_hal_cuda2_timepoint_pool_t {
  // The allocator used to create the timepoint pool.
  iree_allocator_t host_allocator;

  // The pool to acquire host events.
  iree_event_pool_t* host_event_pool;
  // The pool to acquire device events. Internally synchronized.
  iree_hal_cuda2_event_pool_t* device_event_pool;

  // Note that the above pools are internally synchronized, so we don't and
  // shouldn't use the following mutex to guard access to them.

  // Guards timepoint-related fields in this pool. We don't expect a
  // performant program to frequently allocate timepoints for synchronization
  // purposes; the traffic to this pool should be low. So it should be fine to
  // use a mutex to guard here.
  iree_slim_mutex_t timepoint_mutex;

  // Maximum number of timepoint objects that will be maintained in the pool.
  // More timepoints may be allocated at any time, but they will be disposed
  // directly when they are no longer needed.
  iree_host_size_t available_capacity IREE_GUARDED_BY(timepoint_mutex);
  // Total number of currently available timepoint objects.
  iree_host_size_t available_count IREE_GUARDED_BY(timepoint_mutex);
  // The list of available_count timepoint objects.
  iree_hal_cuda2_timepoint_t* available_list[] IREE_GUARDED_BY(timepoint_mutex);
};
// + Additional inline allocation for holding timepoints up to the capacity.
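//
// For example, with available_capacity = 4 the pool is created as one
// allocation laid out roughly as:
//   [ iree_hal_cuda2_timepoint_pool_t | 4 * (iree_hal_cuda2_timepoint_t*) ]
// with the flexible array member available_list addressing the trailing
// storage; see the total_size computation in
// iree_hal_cuda2_timepoint_pool_allocate() below.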

iree_status_t iree_hal_cuda2_timepoint_pool_allocate(
    iree_event_pool_t* host_event_pool,
    iree_hal_cuda2_event_pool_t* device_event_pool,
    iree_host_size_t available_capacity, iree_allocator_t host_allocator,
    iree_hal_cuda2_timepoint_pool_t** out_timepoint_pool) {
  IREE_ASSERT_ARGUMENT(host_event_pool);
  IREE_ASSERT_ARGUMENT(device_event_pool);
  IREE_ASSERT_ARGUMENT(out_timepoint_pool);
  *out_timepoint_pool = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
  iree_host_size_t total_size =
      sizeof(*timepoint_pool) +
      available_capacity * sizeof(*timepoint_pool->available_list);
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, total_size,
                                (void**)&timepoint_pool));
  timepoint_pool->host_allocator = host_allocator;
  timepoint_pool->host_event_pool = host_event_pool;
  timepoint_pool->device_event_pool = device_event_pool;

  iree_slim_mutex_initialize(&timepoint_pool->timepoint_mutex);
  timepoint_pool->available_capacity = available_capacity;
  timepoint_pool->available_count = 0;

  iree_status_t status = iree_ok_status();
  for (iree_host_size_t i = 0; i < available_capacity; ++i) {
    status = iree_hal_cuda2_timepoint_allocate(
        timepoint_pool, host_allocator, &timepoint_pool->available_list[i]);
    if (!iree_status_is_ok(status)) break;
    // Only count successfully allocated timepoints so that on failure
    // iree_hal_cuda2_timepoint_pool_free() does not free a NULL entry.
    ++timepoint_pool->available_count;
  }

  if (iree_status_is_ok(status)) {
    *out_timepoint_pool = timepoint_pool;
  } else {
    iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}
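
// Example usage (a minimal sketch; the capacity below is illustrative and the
// event pools are assumed to be created by the surrounding device code):
//   iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_allocate(
//       host_event_pool, device_event_pool, /*available_capacity=*/32,
//       host_allocator, &timepoint_pool));
//   // ... acquire and release timepoints ...
//   iree_hal_cuda2_timepoint_pool_free(timepoint_pool);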

void iree_hal_cuda2_timepoint_pool_free(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool) {
  iree_allocator_t host_allocator = timepoint_pool->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  for (iree_host_size_t i = 0; i < timepoint_pool->available_count; ++i) {
    iree_hal_cuda2_timepoint_free(timepoint_pool->available_list[i]);
  }
  iree_slim_mutex_deinitialize(&timepoint_pool->timepoint_mutex);
  iree_allocator_free(host_allocator, timepoint_pool);

  IREE_TRACE_ZONE_END(z0);
}

// Acquires |timepoint_count| timepoints from the given |timepoint_pool|.
// The |out_timepoints| needs to be further initialized with proper kind and
// payload values.
static iree_status_t iree_hal_cuda2_timepoint_pool_acquire_internal(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_ASSERT_ARGUMENT(timepoint_pool);
  if (!timepoint_count) return iree_ok_status();
  IREE_ASSERT_ARGUMENT(out_timepoints);
  IREE_TRACE_ZONE_BEGIN(z0);

  // We'll try to get what we can from the pool and fall back to initializing
  // new iree_hal_cuda2_timepoint_t objects.
  iree_host_size_t remaining_count = timepoint_count;

  // Try first to grab from the pool.
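  // Acquisition takes from the tail of available_list while release appends
  // at the tail, so the pool behaves as a LIFO stack of recently recycled
  // timepoints.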
  iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
  iree_host_size_t from_pool_count =
      iree_min(timepoint_pool->available_count, timepoint_count);
  if (from_pool_count > 0) {
    iree_host_size_t pool_base_index =
        timepoint_pool->available_count - from_pool_count;
    memcpy(out_timepoints, &timepoint_pool->available_list[pool_base_index],
           from_pool_count * sizeof(*timepoint_pool->available_list));
    timepoint_pool->available_count -= from_pool_count;
    remaining_count -= from_pool_count;
  }
  iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);

  // Allocate the rest of the timepoints.
  if (remaining_count > 0) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-acquire");
    iree_status_t status = iree_ok_status();
    for (iree_host_size_t i = 0; i < remaining_count; ++i) {
      status = iree_hal_cuda2_timepoint_allocate(
          timepoint_pool, timepoint_pool->host_allocator,
          &out_timepoints[from_pool_count + i]);
      if (!iree_status_is_ok(status)) {
        // Must release all timepoints we've acquired so far.
        iree_hal_cuda2_timepoint_pool_release(
            timepoint_pool, from_pool_count + i, out_timepoints);
        IREE_TRACE_ZONE_END(z1);
        IREE_TRACE_ZONE_END(z0);
        return status;
      }
    }
    IREE_TRACE_ZONE_END(z1);
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_host_wait(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire host events to wrap up. This should happen before acquiring the
  // timepoints to avoid nested locks.
  iree_event_t* host_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.host_wait));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_event_pool_acquire(timepoint_pool->host_event_pool,
                                  timepoint_count, host_events));

  iree_status_t status = iree_hal_cuda2_timepoint_pool_acquire_internal(
      timepoint_pool, timepoint_count, out_timepoints);
  if (!iree_status_is_ok(status)) {
    // Return the host events to their pool so they are not leaked when we
    // fail to get timepoints to wrap them.
    iree_event_pool_release(timepoint_pool->host_event_pool, timepoint_count,
                            host_events);
    IREE_TRACE_ZONE_END(z0);
    return status;
  }
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT;
    out_timepoints[i]->timepoint.host_wait = host_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
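
// Example usage (a minimal sketch; error handling elided and the count is
// illustrative):
//   iree_hal_cuda2_timepoint_t* timepoints[2] = {NULL, NULL};
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_acquire_host_wait(
//       timepoint_pool, IREE_ARRAYSIZE(timepoints), timepoints));
//   // ... wait on timepoints[i]->timepoint.host_wait from the host ...
//   iree_hal_cuda2_timepoint_pool_release(
//       timepoint_pool, IREE_ARRAYSIZE(timepoints), timepoints);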

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_signal(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire device events to wrap up. This should happen before acquiring the
  // timepoints to avoid nested locks.
  iree_hal_cuda2_event_t** device_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.device_signal));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
                                            timepoint_count, device_events));

  iree_status_t status = iree_hal_cuda2_timepoint_pool_acquire_internal(
      timepoint_pool, timepoint_count, out_timepoints);
  if (!iree_status_is_ok(status)) {
    // Release the device events so they are not leaked when we fail to get
    // timepoints to wrap them.
    for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
      iree_hal_cuda2_event_release(device_events[i]);
    }
    IREE_TRACE_ZONE_END(z0);
    return status;
  }
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL;
    out_timepoints[i]->timepoint.device_signal = device_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_wait(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count,
    iree_hal_cuda2_timepoint_t** out_timepoints) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Acquire device events to wrap up. This should happen before acquiring the
  // timepoints to avoid nested locks.
  iree_hal_cuda2_event_t** device_events = iree_alloca(
      timepoint_count * sizeof((*out_timepoints)->timepoint.device_wait));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
                                            timepoint_count, device_events));

  iree_status_t status = iree_hal_cuda2_timepoint_pool_acquire_internal(
      timepoint_pool, timepoint_count, out_timepoints);
  if (!iree_status_is_ok(status)) {
    // Release the device events so they are not leaked when we fail to get
    // timepoints to wrap them.
    for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
      iree_hal_cuda2_event_release(device_events[i]);
    }
    IREE_TRACE_ZONE_END(z0);
    return status;
  }
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT;
    out_timepoints[i]->timepoint.device_wait = device_events[i];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

void iree_hal_cuda2_timepoint_pool_release(
    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
    iree_host_size_t timepoint_count, iree_hal_cuda2_timepoint_t** timepoints) {
  IREE_ASSERT_ARGUMENT(timepoint_pool);
  if (!timepoint_count) return;
  IREE_ASSERT_ARGUMENT(timepoints);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Release the wrapped host/device events. This should happen before
  // acquiring the timepoint pool's lock given that the host/device event
  // pools use their own internal locks too.
  // TODO: Release in batch to avoid lock overhead from separate calls.
  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
    switch (timepoints[i]->kind) {
      case IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT:
        iree_event_pool_release(timepoint_pool->host_event_pool, 1,
                                &timepoints[i]->timepoint.host_wait);
        break;
      case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL:
        iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_signal);
        break;
      case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT:
        iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_wait);
        break;
      default:
        break;
    }
  }

  // We'll try to release all we can back to the pool and then deinitialize
  // the ones that won't fit.
  iree_host_size_t remaining_count = timepoint_count;

  // Try first to release to the pool.
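  // Returned timepoints are appended right after the currently available
  // ones, keeping available_list contiguous and pairing with the tail-first
  // acquisition above.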
  iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
  iree_host_size_t to_pool_count = iree_min(
      timepoint_pool->available_capacity - timepoint_pool->available_count,
      timepoint_count);
  if (to_pool_count > 0) {
    for (iree_host_size_t i = 0; i < to_pool_count; ++i) {
      iree_hal_cuda2_timepoint_clear(timepoints[i]);
    }
    iree_host_size_t pool_base_index = timepoint_pool->available_count;
    memcpy(&timepoint_pool->available_list[pool_base_index], timepoints,
           to_pool_count * sizeof(*timepoint_pool->available_list));
    timepoint_pool->available_count += to_pool_count;
    remaining_count -= to_pool_count;
  }
  iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);

  // Deallocate the rest of the timepoints. We still clear them first so their
  // kind returns to IREE_HAL_CUDA_TIMEPOINT_KIND_NONE, as asserted by
  // iree_hal_cuda2_timepoint_free().
  if (remaining_count > 0) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-release");
    for (iree_host_size_t i = 0; i < remaining_count; ++i) {
      iree_hal_cuda2_timepoint_clear(timepoints[to_pool_count + i]);
      iree_hal_cuda2_timepoint_free(timepoints[to_pool_count + i]);
    }
    IREE_TRACE_ZONE_END(z1);
  }
  IREE_TRACE_ZONE_END(z0);
}