[cuda] Port over changes made in hip backend to cuda (#18079)
This brings two changes from the HIP backend over into the CUDA backend.
Specifically replacing `cuLaunchHostFunc` with a dedicated waiting
thread, and replacing the use of atomic_slist (which has
non-deterministic ordering) with a well-ordered linked list.
The original commits on main branch were #17925 and #18048
---------
Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index a78dc6f..011b5d9 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -61,8 +61,6 @@
// TODO: Support multiple device streams.
// The CUstream used to issue device kernels and allocations.
CUstream dispatch_cu_stream;
- // The CUstream used to issue host callback functions.
- CUstream callback_cu_stream;
iree_hal_cuda_tracing_context_t* tracing_context;
@@ -129,7 +127,7 @@
static iree_status_t iree_hal_cuda_device_create_internal(
iree_hal_driver_t* driver, iree_string_view_t identifier,
const iree_hal_cuda_device_params_t* params, CUdevice cu_device,
- CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
+ CUstream dispatch_stream, CUcontext context,
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols,
const iree_hal_cuda_nccl_dynamic_symbols_t* nccl_symbols,
iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
@@ -152,11 +150,10 @@
device->cu_context = context;
device->cu_device = cu_device;
device->dispatch_cu_stream = dispatch_stream;
- device->callback_cu_stream = callback_stream;
device->host_allocator = host_allocator;
iree_status_t status = iree_hal_cuda_pending_queue_actions_create(
- cuda_symbols, &device->block_pool, host_allocator,
+ cuda_symbols, cu_device, context, &device->block_pool, host_allocator,
&device->pending_queue_actions);
// Enable tracing for the (currently only) stream - no-op if disabled.
@@ -230,20 +227,13 @@
status = IREE_CURESULT_TO_STATUS(
cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
}
- // Create the default callback stream for the device.
- CUstream callback_stream = NULL;
- if (iree_status_is_ok(status)) {
- status = IREE_CURESULT_TO_STATUS(
- cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
- }
if (iree_status_is_ok(status)) {
status = iree_hal_cuda_device_create_internal(
- driver, identifier, params, device, dispatch_stream, callback_stream,
- context, cuda_symbols, nccl_symbols, host_allocator, out_device);
+ driver, identifier, params, device, dispatch_stream, context,
+ cuda_symbols, nccl_symbols, host_allocator, out_device);
} else {
// Release resources we have accquired thus far.
- if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
}
@@ -331,7 +321,6 @@
if (device->host_event_pool) iree_event_pool_free(device->host_event_pool);
IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
- IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));
IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));
@@ -776,8 +765,7 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
- base_device, device->dispatch_cu_stream, device->callback_cu_stream,
- device->pending_queue_actions,
+ base_device, device->dispatch_cu_stream, device->pending_queue_actions,
iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
wait_semaphore_list, signal_semaphore_list, command_buffer_count,
command_buffers, binding_tables);
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
index 42f1f86..2e964e8 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
@@ -85,8 +85,6 @@
// The stream to launch main GPU workload.
CUstream dispatch_cu_stream;
- // The stream to launch CUDA host function callbacks.
- CUstream callback_cu_stream;
// Resource set to retain all associated resources by the payload.
iree_hal_resource_set_t* resource_set;
@@ -186,31 +184,140 @@
//===----------------------------------------------------------------------===//
// Ready action atomic slist entry struct.
-typedef struct iree_hal_cuda_atomic_slist_entry_t {
+typedef struct iree_hal_cuda_entry_list_node_t {
iree_hal_cuda_queue_action_t* ready_list_head;
- iree_atomic_slist_intrusive_ptr_t slist_next;
-} iree_hal_cuda_atomic_slist_entry_t;
+ struct iree_hal_cuda_entry_list_node_t* next;
+} iree_hal_cuda_entry_list_node_t;
-// Ready action atomic slist.
-IREE_TYPED_ATOMIC_SLIST_WRAPPER(iree_hal_cuda_ready_action,
- iree_hal_cuda_atomic_slist_entry_t,
- offsetof(iree_hal_cuda_atomic_slist_entry_t,
- slist_next));
+typedef struct iree_hal_cuda_entry_list_t {
+ iree_slim_mutex_t guard_mutex;
-static void iree_hal_cuda_ready_action_slist_destroy(
- iree_hal_cuda_ready_action_slist_t* list, iree_allocator_t host_allocator) {
- while (true) {
- iree_hal_cuda_atomic_slist_entry_t* entry =
- iree_hal_cuda_ready_action_slist_pop(list);
- if (!entry) break;
- iree_hal_cuda_queue_action_list_destroy(entry->ready_list_head);
- iree_allocator_free(host_allocator, entry);
+ iree_hal_cuda_entry_list_node_t* head IREE_GUARDED_BY(guard_mutex);
+ iree_hal_cuda_entry_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
+} iree_hal_cuda_entry_list_t;
+
+static iree_hal_cuda_entry_list_node_t* iree_hal_cuda_entry_list_pop(
+ iree_hal_cuda_entry_list_t* list) {
+ iree_hal_cuda_entry_list_node_t* out = NULL;
+ iree_slim_mutex_lock(&list->guard_mutex);
+ if (list->head) {
+ out = list->head;
+ list->head = list->head->next;
+ if (out == list->tail) {
+ list->tail = NULL;
+ }
}
- iree_hal_cuda_ready_action_slist_deinitialize(list);
+ iree_slim_mutex_unlock(&list->guard_mutex);
+ return out;
+}
+
+void iree_hal_cuda_entry_list_push(iree_hal_cuda_entry_list_t* list,
+ iree_hal_cuda_entry_list_node_t* next) {
+ iree_slim_mutex_lock(&list->guard_mutex);
+ next->next = NULL;
+ if (list->tail) {
+ list->tail->next = next;
+ list->tail = next;
+ } else {
+ list->head = next;
+ list->tail = next;
+ }
+ iree_slim_mutex_unlock(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_ready_action_list_deinitialize(
+ iree_hal_cuda_entry_list_t* list, iree_allocator_t host_allocator) {
+ iree_hal_cuda_entry_list_node_t* head = list->head;
+ while (head) {
+ if (!head) break;
+ iree_hal_cuda_queue_action_list_destroy(head->ready_list_head);
+ list->head = head->next;
+ iree_allocator_free(host_allocator, head);
+ }
+ iree_slim_mutex_deinitialize(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_ready_action_list_initialize(
+ iree_hal_cuda_entry_list_t* list) {
+ list->head = NULL;
+ list->tail = NULL;
+ iree_slim_mutex_initialize(&list->guard_mutex);
+}
+
+// Completion list entry struct.
+typedef struct iree_hal_cuda_completion_list_node_t {
+ // The callback and user data for that callback. To be called
+ // when the associated event has completed.
+ iree_status_t (*callback)(void* user_data);
+ void* user_data;
+ // The event to wait for on the completion thread.
+ CUevent event;
+ // If this event was created just for the completion thread, and therefore
+ // needs to be cleaned up.
+ bool created_event;
+ struct iree_hal_cuda_completion_list_node_t* next;
+} iree_hal_cuda_completion_list_node_t;
+
+typedef struct iree_hal_cuda_completion_list_t {
+ iree_slim_mutex_t guard_mutex;
+ iree_hal_cuda_completion_list_node_t* head IREE_GUARDED_BY(guard_mutex);
+ iree_hal_cuda_completion_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
+} iree_hal_cuda_completion_list_t;
+
+static iree_hal_cuda_completion_list_node_t* iree_hal_cuda_completion_list_pop(
+ iree_hal_cuda_completion_list_t* list) {
+ iree_hal_cuda_completion_list_node_t* out = NULL;
+ iree_slim_mutex_lock(&list->guard_mutex);
+ if (list->head) {
+ out = list->head;
+ list->head = list->head->next;
+ if (out == list->tail) {
+ list->tail = NULL;
+ }
+ }
+ iree_slim_mutex_unlock(&list->guard_mutex);
+ return out;
+}
+
+void iree_hal_cuda_completion_list_push(
+ iree_hal_cuda_completion_list_t* list,
+ iree_hal_cuda_completion_list_node_t* next) {
+ iree_slim_mutex_lock(&list->guard_mutex);
+ next->next = NULL;
+ if (list->tail) {
+ list->tail->next = next;
+ list->tail = next;
+ } else {
+ list->head = next;
+ list->tail = next;
+ }
+ iree_slim_mutex_unlock(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_completion_list_initialize(
+ iree_hal_cuda_completion_list_t* list) {
+ list->head = NULL;
+ list->tail = NULL;
+ iree_slim_mutex_initialize(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_completion_list_deinitialize(
+ iree_hal_cuda_completion_list_t* list,
+ const iree_hal_cuda_dynamic_symbols_t* symbols,
+ iree_allocator_t host_allocator) {
+ iree_hal_cuda_completion_list_node_t* head = list->head;
+ while (head) {
+ if (head->created_event) {
+ IREE_CUDA_IGNORE_ERROR(symbols, cuEventDestroy(head->event));
+ }
+ list->head = list->head->next;
+ iree_allocator_free(host_allocator, head);
+ }
+ iree_slim_mutex_deinitialize(&list->guard_mutex);
}
static iree_hal_cuda_queue_action_t* iree_hal_cuda_atomic_slist_entry_pop_front(
- iree_hal_cuda_atomic_slist_entry_t* list) {
+ iree_hal_cuda_entry_list_node_t* list) {
IREE_ASSERT(list->ready_list_head);
iree_hal_cuda_queue_action_t* action = list->ready_list_head;
@@ -255,8 +362,8 @@
// Notification to the parent thread to indicate the worker committed exiting.
// TODO: maybe remove this. We can just wait on the worker thread to exit.
iree_notification_t exit_notification;
- iree_hal_cuda_ready_action_slist_t ready_worklist; // atomic
- iree_atomic_int32_t worker_state; // atomic
+ iree_hal_cuda_entry_list_t ready_worklist;
+ iree_atomic_int32_t worker_state; // atomic
// TODO: use status to provide more context for the error.
iree_atomic_intptr_t error_code; // atomic
@@ -272,14 +379,49 @@
IREE_GUARDED_BY(pending_work_items_count_mutex);
iree_allocator_t host_allocator; // const
+
+ const iree_hal_cuda_dynamic_symbols_t* symbols;
+ CUcontext context;
} iree_hal_cuda_working_area_t;
+// This data structure is shared by the parent thread. It is responsible
+// for dispatching callbacks when work items complete.
+
+// This replaces the use of cuLaunchHostFunc, which causes the stream to block
+// and wait for the CPU work to complete. It also picks up completed
+// events with significantly less latency than cuLaunchHostFunc.
+
+typedef struct iree_hal_cuda_completion_area_t {
+ // Notification from the parent thread to request completion state changes.
+ iree_notification_t state_notification;
+ // Notification to the parent thread to indicate the worker committed exiting.
+ iree_notification_t exit_notification;
+ iree_hal_cuda_completion_list_t completion_list;
+ iree_atomic_int32_t worker_state; // atomic
+
+ iree_atomic_intptr_t error_code; // atomic
+
+  // The number of asynchronous completion items that are scheduled and not
+  // yet waited on.
+ // We need to wait for them to finish before destroying the context.
+ iree_slim_mutex_t pending_completion_count_mutex;
+ iree_notification_t pending_completion_count_notification;
+ int32_t pending_completion_count
+ IREE_GUARDED_BY(pending_completion_count_mutex);
+
+ iree_allocator_t host_allocator;
+
+ const iree_hal_cuda_dynamic_symbols_t* symbols;
+ CUcontext context;
+} iree_hal_cuda_completion_area_t;
+
static void iree_hal_cuda_working_area_initialize(
- iree_allocator_t host_allocator,
+ iree_allocator_t host_allocator, CUcontext context,
+ const iree_hal_cuda_dynamic_symbols_t* symbols,
iree_hal_cuda_working_area_t* working_area) {
iree_notification_initialize(&working_area->state_notification);
iree_notification_initialize(&working_area->exit_notification);
- iree_hal_cuda_ready_action_slist_initialize(&working_area->ready_worklist);
+ iree_hal_cuda_ready_action_list_initialize(&working_area->ready_worklist);
iree_atomic_store_int32(&working_area->worker_state,
IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
iree_memory_order_release);
@@ -290,12 +432,14 @@
&working_area->pending_work_items_count_notification);
working_area->pending_work_items_count = 0;
working_area->host_allocator = host_allocator;
+ working_area->symbols = symbols;
+ working_area->context = context;
}
static void iree_hal_cuda_working_area_deinitialize(
iree_hal_cuda_working_area_t* working_area) {
- iree_hal_cuda_ready_action_slist_destroy(&working_area->ready_worklist,
- working_area->host_allocator);
+ iree_hal_cuda_ready_action_list_deinitialize(&working_area->ready_worklist,
+ working_area->host_allocator);
iree_notification_deinitialize(&working_area->exit_notification);
iree_notification_deinitialize(&working_area->state_notification);
iree_slim_mutex_deinitialize(&working_area->pending_work_items_count_mutex);
@@ -303,10 +447,47 @@
&working_area->pending_work_items_count_notification);
}
+static void iree_hal_cuda_completion_area_initialize(
+ iree_allocator_t host_allocator, CUcontext context,
+ const iree_hal_cuda_dynamic_symbols_t* symbols,
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_notification_initialize(&completion_area->state_notification);
+ iree_notification_initialize(&completion_area->exit_notification);
+ iree_hal_cuda_completion_list_initialize(&completion_area->completion_list);
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+ iree_memory_order_release);
+ iree_atomic_store_int32(&completion_area->error_code, IREE_STATUS_OK,
+ iree_memory_order_release);
+ iree_slim_mutex_initialize(&completion_area->pending_completion_count_mutex);
+ iree_notification_initialize(
+ &completion_area->pending_completion_count_notification);
+ completion_area->pending_completion_count = 0;
+ completion_area->host_allocator = host_allocator;
+ completion_area->symbols = symbols;
+ completion_area->context = context;
+}
+
+static void iree_hal_cuda_completion_area_deinitialize(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_hal_cuda_completion_list_deinitialize(&completion_area->completion_list,
+ completion_area->symbols,
+ completion_area->host_allocator);
+ iree_notification_deinitialize(&completion_area->exit_notification);
+ iree_notification_deinitialize(&completion_area->state_notification);
+ iree_slim_mutex_deinitialize(
+ &completion_area->pending_completion_count_mutex);
+ iree_notification_deinitialize(
+ &completion_area->pending_completion_count_notification);
+}
+
// The main function for the ready-list processing worker thread.
static int iree_hal_cuda_worker_execute(
iree_hal_cuda_working_area_t* working_area);
+static int iree_hal_cuda_completion_execute(
+ iree_hal_cuda_completion_area_t* working_area);
+
//===----------------------------------------------------------------------===//
// Pending queue actions
//===----------------------------------------------------------------------===//
@@ -333,16 +514,28 @@
// The worker thread that monitors incoming requests and issues ready actions
// to the GPU.
iree_thread_t* worker_thread;
+
+ // Worker thread to wait on completion events instead of running
+ // synchronous completion callbacks
+ iree_thread_t* completion_thread;
+
// The worker's working area; data exchange place with the parent thread.
iree_hal_cuda_working_area_t working_area;
+
+ // Completion thread's working area.
+ iree_hal_cuda_completion_area_t completion_area;
+
+ // The associated cuda device.
+ CUdevice device;
};
static const iree_hal_resource_vtable_t
iree_hal_cuda_pending_queue_actions_vtable;
iree_status_t iree_hal_cuda_pending_queue_actions_create(
- const iree_hal_cuda_dynamic_symbols_t* symbols,
- iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device,
+ CUcontext context, iree_arena_block_pool_t* block_pool,
+ iree_allocator_t host_allocator,
iree_hal_cuda_pending_queue_actions_t** out_actions) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(block_pool);
@@ -358,12 +551,18 @@
actions->host_allocator = host_allocator;
actions->block_pool = block_pool;
actions->symbols = symbols;
+ actions->device = device;
iree_slim_mutex_initialize(&actions->action_mutex);
memset(&actions->action_list, 0, sizeof(actions->action_list));
// Initialize the working area for the ready-list processing worker.
iree_hal_cuda_working_area_t* working_area = &actions->working_area;
- iree_hal_cuda_working_area_initialize(host_allocator, working_area);
+ iree_hal_cuda_working_area_initialize(host_allocator, context, symbols,
+ working_area);
+
+ iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
+ iree_hal_cuda_completion_area_initialize(host_allocator, context, symbols,
+ completion_area);
// Create the ready-list processing worker itself.
iree_thread_create_params_t params;
@@ -374,6 +573,14 @@
(iree_thread_entry_t)iree_hal_cuda_worker_execute, working_area, params,
actions->host_allocator, &actions->worker_thread);
+ params.name = IREE_SV("done_worker");
+ params.create_suspended = false;
+ if (iree_status_is_ok(status)) {
+ status = iree_thread_create(
+ (iree_thread_entry_t)iree_hal_cuda_completion_execute, completion_area,
+ params, actions->host_allocator, &actions->completion_thread);
+ }
+
if (iree_status_is_ok(status)) {
*out_actions = actions;
} else {
@@ -392,12 +599,16 @@
static bool iree_hal_cuda_worker_committed_exiting(
iree_hal_cuda_working_area_t* working_area);
+static bool iree_hal_cuda_completion_committed_exiting(
+ iree_hal_cuda_completion_area_t* working_area);
+
void iree_hal_cuda_pending_queue_actions_destroy(
iree_hal_resource_t* base_actions) {
iree_hal_cuda_pending_queue_actions_t* actions =
iree_hal_cuda_pending_queue_actions_cast(base_actions);
iree_allocator_t host_allocator = actions->host_allocator;
iree_hal_cuda_working_area_t* working_area = &actions->working_area;
+ iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
IREE_TRACE_ZONE_BEGIN(z0);
// Request the worker to exit.
@@ -420,6 +631,25 @@
iree_thread_release(actions->worker_thread);
iree_hal_cuda_working_area_deinitialize(working_area);
+ // Request the completion thread to exit.
+ prev_state = (iree_hal_cuda_worker_state_t)iree_atomic_exchange_int32(
+ &completion_area->worker_state, IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED,
+ iree_memory_order_acq_rel);
+ iree_notification_post(&completion_area->state_notification,
+ IREE_ALL_WAITERS);
+
+ // Check potential exit states from the completion thread.
+ if (prev_state != IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
+ // Wait until the completion thread acknowledged exiting.
+ iree_notification_await(
+ &completion_area->exit_notification,
+ (iree_condition_fn_t)iree_hal_cuda_completion_committed_exiting,
+ completion_area, iree_infinite_timeout());
+ }
+
+ iree_thread_release(actions->completion_thread);
+ iree_hal_cuda_completion_area_deinitialize(completion_area);
+
iree_slim_mutex_deinitialize(&actions->action_mutex);
iree_hal_cuda_queue_action_list_destroy(actions->action_list.head);
iree_allocator_free(host_allocator, actions);
@@ -468,9 +698,23 @@
iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
}
+static void iree_hal_cuda_queue_decrement_completion_count(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+ --completion_area->pending_completion_count;
+ if (completion_area->pending_completion_count == 0) {
+ // Notify inside the lock to make sure that we are done touching anything
+ // since the context may get destroyed in the meantime.
+ iree_notification_post(
+ &completion_area->pending_completion_count_notification,
+ IREE_ALL_WAITERS);
+ }
+ iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+}
+
iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
- CUstream callback_stream, iree_hal_cuda_pending_queue_actions_t* actions,
+ iree_hal_cuda_pending_queue_actions_t* actions,
iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback,
void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -519,7 +763,6 @@
action->kind = IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION;
action->device = device;
action->dispatch_cu_stream = dispatch_stream;
- action->callback_cu_stream = callback_stream;
// Initialize scratch fields.
action->event_count = 0;
@@ -645,13 +888,27 @@
iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
}
+static void iree_hal_cuda_post_error_to_completion_state(
+ iree_hal_cuda_completion_area_t* completion_area, iree_status_code_t code) {
+ // Write error code, but don't overwrite existing error codes.
+ intptr_t prev_error_code = IREE_STATUS_OK;
+ iree_atomic_compare_exchange_strong_int32(
+ &completion_area->error_code, /*expected=*/&prev_error_code,
+ /*desired=*/code,
+ /*order_succ=*/iree_memory_order_acq_rel,
+ /*order_fail=*/iree_memory_order_acquire);
+
+ // This state has the highest priority so just overwrite.
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR,
+ iree_memory_order_release);
+ iree_notification_post(&completion_area->state_notification,
+ IREE_ALL_WAITERS);
+}
+
// Releases resources after action completion on the GPU and advances timeline
// and pending actions queue.
-//
-// This is the CUDA host function callback to cudaLaunchHostFunc(), invoked by a
-// CUDA driver thread. Note that code in this function MUST NOT invoke any GPU
-// API under the hood to avoid potential deadlock.
-static void iree_hal_cuda_execution_device_signal_host_callback(
+static iree_status_t iree_hal_cuda_execution_device_signal_host_callback(
void* user_data) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda_queue_action_t* action =
@@ -694,6 +951,7 @@
iree_hal_cuda_queue_decrement_work_items_count(&actions->working_area);
IREE_TRACE_ZONE_END(z0);
+ return status;
}
// Issues the given kernel dispatch |action| to the GPU.
@@ -770,6 +1028,7 @@
}
IREE_TRACE_ZONE_END(z_dispatch_command_buffers);
+ CUevent completion_event = NULL;
// Last record CUevent signals in the dispatch stream.
for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
// Grab a CUevent for this semaphore value signaling.
@@ -783,17 +1042,28 @@
IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, symbols, cuEventRecord(event, action->dispatch_cu_stream),
"cuEventRecord");
- // Let the callback stream to wait on the CUevent.
- IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
- z0, symbols,
- cuStreamWaitEvent(action->callback_cu_stream, event,
- CU_EVENT_WAIT_DEFAULT),
- "cuStreamWaitEvent");
+ completion_event = event;
}
+ bool created_event = false;
+ // In the case where we issue an execution and there are signal semaphores
+ // we can re-use those as a wait event. However if there are no signals
+ // then we create one. In my testing this is not a common case.
+ if (IREE_UNLIKELY(!completion_event)) {
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols, cuEventCreate(&completion_event, CU_EVENT_DISABLE_TIMING),
+ "cuEventCreate");
+ created_event = true;
+ }
+
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols, cuEventRecord(completion_event, action->dispatch_cu_stream),
+ "cuEventRecord");
+
iree_slim_mutex_lock(
&action->owning_actions->working_area.pending_work_items_count_mutex);
- // One work item is the host stream callback.
+ // One work item is the callback that makes it across from the
+ // completion thread.
// The other is the cleanup of the action.
action->owning_actions->working_area.pending_work_items_count += 2;
iree_slim_mutex_unlock(
@@ -801,12 +1071,55 @@
// Now launch a host function on the callback stream to advance the semaphore
// timeline.
- IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
- z0, symbols,
- cuLaunchHostFunc(action->callback_cu_stream,
- iree_hal_cuda_execution_device_signal_host_callback,
- action),
- "cuLaunchHostFunc");
+ iree_hal_cuda_completion_list_node_t* entry = NULL;
+ // TODO: avoid host allocator malloc; use some pool for the allocation.
+ iree_status_t status = iree_allocator_malloc(
+ action->owning_actions->host_allocator, sizeof(*entry), (void**)&entry);
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+  // Now push the completion entry to the completion thread and have it wait
+  // on the event before running the callback.
+ entry->event = completion_event;
+ entry->created_event = created_event;
+ entry->callback = iree_hal_cuda_execution_device_signal_host_callback;
+ entry->user_data = action;
+ iree_hal_cuda_completion_list_push(
+ &action->owning_actions->completion_area.completion_list, entry);
+
+ iree_slim_mutex_lock(
+ &action->owning_actions->completion_area.pending_completion_count_mutex);
+
+ action->owning_actions->completion_area.pending_completion_count += 1;
+
+ iree_slim_mutex_unlock(
+ &action->owning_actions->completion_area.pending_completion_count_mutex);
+
+ // We can only overwrite the worker state if the previous state is idle
+ // waiting; we cannot overwrite exit related states. so we need to perform
+ // atomic compare and exchange here.
+ iree_hal_cuda_worker_state_t prev_state =
+ IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING;
+ iree_atomic_compare_exchange_strong_int32(
+ &action->owning_actions->completion_area.worker_state,
+ /*expected=*/&prev_state,
+ /*desired=*/IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
+ /*order_succ=*/iree_memory_order_acq_rel,
+ /*order_fail=*/iree_memory_order_acquire);
+ iree_notification_post(
+ &action->owning_actions->completion_area.state_notification,
+ IREE_ALL_WAITERS);
+
+ // Handle potential error cases from the worker thread.
+ if (prev_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
+ iree_status_code_t code = iree_atomic_load_int32(
+ &action->owning_actions->completion_area.error_code,
+ iree_memory_order_acquire);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_status_from_code(code));
+ }
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
@@ -928,7 +1241,7 @@
return status;
}
- iree_hal_cuda_atomic_slist_entry_t* entry = NULL;
+ iree_hal_cuda_entry_list_node_t* entry = NULL;
// TODO: avoid host allocator malloc; use some pool for the allocation.
if (iree_status_is_ok(status)) {
status = iree_allocator_malloc(actions->host_allocator, sizeof(*entry),
@@ -946,8 +1259,7 @@
// Now push the ready list to the worker and have it to issue the actions to
// the GPU.
entry->ready_list_head = ready_list.head;
- iree_hal_cuda_ready_action_slist_push(&actions->working_area.ready_worklist,
- entry);
+ iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist, entry);
// We can only overwrite the worker state if the previous state is idle
// waiting; we cannot overwrite exit related states. so we need to perform
@@ -986,6 +1298,15 @@
value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
}
+static bool iree_hal_cuda_completion_has_incoming_request_or_error(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_hal_cuda_worker_state_t value = iree_atomic_load_int32(
+ &completion_area->worker_state, iree_memory_order_acquire);
+ return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING ||
+ value == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED ||
+ value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
+}
+
static bool iree_hal_cuda_worker_committed_exiting(
iree_hal_cuda_working_area_t* working_area) {
return iree_atomic_load_int32(&working_area->worker_state,
@@ -993,16 +1314,22 @@
IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
}
+static bool iree_hal_cuda_completion_committed_exiting(
+ iree_hal_cuda_completion_area_t* working_area) {
+ return iree_atomic_load_int32(&working_area->worker_state,
+ iree_memory_order_acquire) ==
+ IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
+}
+
// Processes all ready actions in the given |worklist|.
static iree_status_t iree_hal_cuda_worker_process_ready_list(
- iree_allocator_t host_allocator,
- iree_hal_cuda_ready_action_slist_t* worklist) {
+ iree_allocator_t host_allocator, iree_hal_cuda_entry_list_t* worklist) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_ok_status();
while (true) {
- iree_hal_cuda_atomic_slist_entry_t* entry =
- iree_hal_cuda_ready_action_slist_pop(worklist);
+ iree_hal_cuda_entry_list_node_t* entry =
+ iree_hal_cuda_entry_list_pop(worklist);
if (!entry) break;
// Process the current batch of ready actions.
@@ -1028,7 +1355,7 @@
// Let common destruction path take care of destroying the worklist.
// When we know all host stream callbacks are done and not touching
// anything.
- iree_hal_cuda_ready_action_slist_push(worklist, entry);
+ iree_hal_cuda_entry_list_push(worklist, entry);
break;
}
@@ -1047,6 +1374,14 @@
return result;
}
+static bool iree_hal_cuda_completion_has_no_pending_completion_items(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+ bool result = (completion_area->pending_completion_count == 0);
+ iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+ return result;
+}
+
// Wait for all work items to finish.
static void iree_hal_cuda_worker_wait_pending_work_items(
iree_hal_cuda_working_area_t* working_area) {
@@ -1060,10 +1395,164 @@
iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
}
+// Wait for all work items to finish.
+static void iree_hal_cuda_completion_wait_pending_completion_items(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_notification_await(
+ &completion_area->pending_completion_count_notification,
+ (iree_condition_fn_t)
+ iree_hal_cuda_completion_has_no_pending_completion_items,
+ completion_area, iree_infinite_timeout());
+ // Lock then unlock to make sure that all callbacks are really done.
+ // Not even touching the notification.
+ iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+ iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+}
+
+static iree_status_t iree_hal_cuda_worker_process_completion(
+ iree_hal_cuda_completion_list_t* worklist,
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_status_t status = iree_ok_status();
+ while (true) {
+ iree_hal_cuda_completion_list_node_t* entry =
+ iree_hal_cuda_completion_list_pop(worklist);
+ if (!entry) break;
+
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "cuEventSynchronize");
+ CUresult result =
+ completion_area->symbols->cuEventSynchronize(entry->event);
+ IREE_TRACE_ZONE_END(z1);
+ if (IREE_UNLIKELY(result != CUDA_SUCCESS)) {
+ // Let common destruction path take care of destroying the worklist.
+ // When we know all host stream callbacks are done and not touching
+ // anything.
+ iree_hal_cuda_completion_list_push(worklist, entry);
+ status =
+ iree_make_status(IREE_STATUS_ABORTED, "could not wait on cuda event");
+ break;
+ }
+ status = entry->callback(entry->user_data);
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ break;
+ }
+
+ if (IREE_UNLIKELY(entry->created_event)) {
+ IREE_CUDA_IGNORE_ERROR(completion_area->symbols,
+ cuEventDestroy(entry->event));
+ }
+ iree_allocator_free(completion_area->host_allocator, entry);
+
+ // Now we fully executed and cleaned up this entry. Decrease the work
+ // items counter.
+ iree_hal_cuda_queue_decrement_completion_count(completion_area);
+ }
+ return status;
+}
+
+// The main function for the completion worker thread.
+static int iree_hal_cuda_completion_execute(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_hal_cuda_completion_list_t* worklist = &completion_area->completion_list;
+
+ iree_status_t status = IREE_CURESULT_TO_STATUS(
+ completion_area->symbols, cuCtxSetCurrent(completion_area->context),
+ "cuCtxSetCurrent");
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+ iree_hal_cuda_post_error_to_completion_state(completion_area,
+ iree_status_code(status));
+ return -1;
+ }
+
+ while (true) {
+ iree_notification_await(
+ &completion_area->state_notification,
+ (iree_condition_fn_t)
+ iree_hal_cuda_completion_has_incoming_request_or_error,
+ completion_area, iree_infinite_timeout());
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // Immediately flip the state to idle waiting if and only if the previous
+ // state is workload pending. We do it before processing ready list to make
+ // sure that we don't accidentally ignore new workload pushed after done
+ // ready list processing but before overwriting the state from this worker
+ // thread. Also we don't want to overwrite other exit states. So we need to
+ // perform atomic compare and exchange here.
+ iree_hal_cuda_worker_state_t prev_state =
+ IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
+ iree_atomic_compare_exchange_strong_int32(
+ &completion_area->worker_state, /*expected=*/&prev_state,
+ /*desired=*/IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+ /*order_succ=*/iree_memory_order_acq_rel,
+ /*order_fail=*/iree_memory_order_acquire);
+
+ int32_t worker_state = iree_atomic_load_int32(
+ &completion_area->worker_state, iree_memory_order_acquire);
+ // Exit if CUDA callbacks have posted any errors.
+ if (IREE_UNLIKELY(worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR)) {
+ iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+ IREE_TRACE_ZONE_END(z0);
+ return -1;
+ }
+ // Check if we received request to stop processing and exit this thread.
+ bool should_exit =
+ (worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED);
+
+ iree_status_t status =
+ iree_hal_cuda_worker_process_completion(worklist, completion_area);
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+ iree_hal_cuda_post_error_to_completion_state(completion_area,
+ iree_status_code(status));
+ IREE_TRACE_ZONE_END(z0);
+ return -1;
+ }
+
+ if (IREE_UNLIKELY(should_exit &&
+ iree_hal_cuda_completion_has_no_pending_completion_items(
+ completion_area))) {
+ iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+ // Signal that this thread is committed to exit.
+ // This state has a priority that is only lower than error exit.
+ // A CUDA callback may have posted an error, make sure we don't
+ // overwrite this error state.
+ iree_hal_cuda_worker_state_t prev_state =
+ IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED;
+ iree_atomic_compare_exchange_strong_int32(
+ &completion_area->worker_state, /*expected=*/&prev_state,
+ /*desired=*/IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED,
+ /*order_succ=*/iree_memory_order_acq_rel,
+ /*order_fail=*/iree_memory_order_acquire);
+ iree_notification_post(&completion_area->exit_notification,
+ IREE_ALL_WAITERS);
+ IREE_TRACE_ZONE_END(z0);
+ return 0;
+ }
+ IREE_TRACE_ZONE_END(z0);
+ }
+
+ return 0;
+}
+
// The main function for the ready-list processing worker thread.
static int iree_hal_cuda_worker_execute(
iree_hal_cuda_working_area_t* working_area) {
- iree_hal_cuda_ready_action_slist_t* worklist = &working_area->ready_worklist;
+ iree_hal_cuda_entry_list_t* worklist = &working_area->ready_worklist;
+
+ // Cuda stores thread-local data based on the device. Some cuda commands pull
+ // the device from there, this will cause failures when using it with other
+ // devices (or streams from other devices). Force the correct device onto
+ // this thread.
+ iree_status_t status = IREE_CURESULT_TO_STATUS(
+ working_area->symbols, cuCtxSetCurrent(working_area->context),
+ "cuCtxSetCurrent");
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_worker_wait_pending_work_items(working_area);
+ iree_hal_cuda_post_error_to_worker_state(working_area,
+ iree_status_code(status));
+ return -1;
+ }
while (true) {
// Block waiting for incoming requests.
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
index b889428..fa16e1f 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
@@ -38,8 +38,9 @@
// Creates a pending actions queue.
iree_status_t iree_hal_cuda_pending_queue_actions_create(
- const iree_hal_cuda_dynamic_symbols_t* symbols,
- iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device,
+ CUcontext context, iree_arena_block_pool_t* block_pool,
+ iree_allocator_t host_allocator,
iree_hal_cuda_pending_queue_actions_t** out_actions);
// Destroys the pending |actions| queue.
@@ -59,7 +60,7 @@
// before releasing all retained resources.
iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
- CUstream callback_stream, iree_hal_cuda_pending_queue_actions_t* actions,
+ iree_hal_cuda_pending_queue_actions_t* actions,
iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback,
void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,