[runtime][cuda] Propagate errors through semaphores (#18095)
Mirror the change 3fd336f2850e3965a26fab27a70e160d87588552 from the HAL
HIP driver.
Adds proper handling of errors that occur when executing operations on
the device or when a semaphore fails in the wait list. These errors will
propagate to downstream semaphores that are in the signal list of the
operation.
This change includes some refactoring of the pending queue actions:
* Make the context hold a sticky status instead of just an status code.
* Remove the worker threads' "error" state. This can be handled by the
context's status.
* Remove "exit committed" thread state in favor of standard thread
joining.
* Make the "exit requested" thread state a separate boolean variable and
guard against submitting more work after an exit is requested. Also wait
on all work to complete before exiting worker threads, not just on the
currently ran actions.
* Make pending work items increment immediately when an action is
enqueued instead of when scheduled on the CUDA stream. This is required
to properly count outstanding work.
* Remove and merge some of the redundant state for the worker and
completion threads.
* Remove reference counting from the pending queue actions context. It
has a clear owner, which is the device.
* Rework when the threads exit, which is pretty much only when exit is
requested and there is no more queued or executing actions. Errors don't
cause the threads to exit.
Here is not included moving the destruction and cleanup of actions from
the worker thread to the completion thread. This is an optimization and
code simplification that is now possible since we are not using HIP
stream callbacks, so we could do that right after an action completes.
Technically, actions get destroyed on the completion thread as well when
not on the happy path and actions fail.
diff --git a/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
index 99f2bb6..ae6843e 100644
--- a/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
@@ -4,16 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-unset(FILTER_TESTS)
-string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
-string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
-string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
-string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
-string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
-set(FILTER_TESTS_ARGS
- "--gtest_filter=-${FILTER_TESTS}"
-)
-
iree_hal_cts_test_suite(
DRIVER_NAME
cuda
@@ -29,7 +19,6 @@
"\"PTXE\""
ARGS
"--cuda_use_streams=false"
- ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
@@ -55,7 +44,6 @@
"\"PTXE\""
ARGS
"--cuda_use_streams=true"
- ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
index ebacb16..e301986 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
@@ -19,7 +19,7 @@
// Maximum device name length supported by the CUDA HAL driver.
#define IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH 128
-// Utility macros to convert between CUDevice and iree_hal_device_id_t.
+// Utility macros to convert between CUdevice and iree_hal_device_id_t.
#define IREE_CUDEVICE_TO_DEVICE_ID(device) (iree_hal_device_id_t)((device) + 1)
#define IREE_DEVICE_ID_TO_CUDEVICE(device_id) (CUdevice)((device_id) - 1)
diff --git a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c
index 10c0b9f..a2bc05f 100644
--- a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c
+++ b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c
@@ -187,6 +187,13 @@
// Notify timepoints - note that this must happen outside the lock.
iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
status_code);
+
+ // Advance the pending queue actions if possible. This also must happen
+ // outside the lock to avoid nesting.
+ status = iree_hal_cuda_pending_queue_actions_issue(
+ semaphore->pending_queue_actions);
+ iree_status_ignore(status);
+
IREE_TRACE_ZONE_END(z0);
}
@@ -317,6 +324,14 @@
return iree_ok_status();
}
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_ABORTED);
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
// Wait until the timepoint resolves.
// If satisfied the timepoint is automatically cleaned up and we are done. If
// the deadline is reached before satisfied then we have to clean it up.
@@ -329,6 +344,17 @@
}
iree_hal_cuda_timepoint_pool_release(semaphore->timepoint_pool, 1,
&timepoint);
+ if (!iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ status = iree_make_status(IREE_STATUS_ABORTED);
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -409,6 +435,23 @@
iree_wait_set_free(wait_set);
iree_arena_deinitialize(&arena);
+ if (!iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+ for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+ iree_hal_cuda_semaphore_t* semaphore =
+ iree_hal_cuda_semaphore_cast(semaphore_list.semaphores[i]);
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ status = iree_make_status(IREE_STATUS_ABORTED);
+ break;
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ }
+
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
index 2e964e8..5b64dfa 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
@@ -45,6 +45,13 @@
IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE,
} iree_hal_cuda_queue_action_state_t;
+// How many work items must complete in order for an action to complete.
+// We keep track of the remaining work for an action so we don't exit worker
+// threads prematurely.
+// +1 for issuing an execution of an action.
+// +1 for cleaning up a zombie action.
+static const iree_host_size_t total_work_items_to_complete_an_action = 2;
+
// A pending queue action.
//
// Note that this struct does not have internal synchronization; it's expected
@@ -101,6 +108,9 @@
bool is_pending;
} iree_hal_cuda_queue_action_t;
+static void iree_hal_cuda_queue_action_fail_locked(
+ iree_hal_cuda_queue_action_t* action, iree_status_t status);
+
static void iree_hal_cuda_queue_action_clear_events(
iree_hal_cuda_queue_action_t* action) {
for (iree_host_size_t i = 0; i < action->event_count; ++i) {
@@ -248,7 +258,7 @@
typedef struct iree_hal_cuda_completion_list_node_t {
// The callback and user data for that callback. To be called
// when the associated event has completed.
- iree_status_t (*callback)(void* user_data);
+ iree_status_t (*callback)(iree_status_t, void* user_data);
void* user_data;
// The event to wait for on the completion thread.
CUevent event;
@@ -316,7 +326,7 @@
iree_slim_mutex_deinitialize(&list->guard_mutex);
}
-static iree_hal_cuda_queue_action_t* iree_hal_cuda_atomic_slist_entry_pop_front(
+static iree_hal_cuda_queue_action_t* iree_hal_cuda_entry_list_node_pop_front(
iree_hal_cuda_entry_list_node_t* list) {
IREE_ASSERT(list->ready_list_head);
@@ -331,17 +341,27 @@
return action;
}
+static void iree_hal_cuda_entry_list_node_push_front(
+ iree_hal_cuda_entry_list_node_t* entry,
+ iree_hal_cuda_queue_action_t* action) {
+ IREE_ASSERT(!action->next && !action->prev);
+
+ iree_hal_cuda_queue_action_t* head = entry->ready_list_head;
+ entry->ready_list_head = action;
+ if (head) {
+ action->next = head;
+ head->prev = action;
+ }
+}
+
// The ready-list processing worker's working/exiting state.
//
// States in the list has increasing priorities--meaning normally ones appearing
// earlier can overwrite ones appearing later without checking; but not the
// reverse order.
typedef enum iree_hal_cuda_worker_state_e {
- IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING = 0, // Worker to main thread
- IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING = 1, // Main to worker thread
- IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED = -1, // Main to worker thread
- IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED = -2, // Worker to main thread
- IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR = -3, // Worker to main thread
+ IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING = 0, // Worker to any thread
+ IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING = 1, // Any to worker thread
} iree_hal_cuda_worker_state_t;
// The data structure needed by a ready-list processing worker thread to issue
@@ -359,29 +379,8 @@
typedef struct iree_hal_cuda_working_area_t {
// Notification from the parent thread to request worker state changes.
iree_notification_t state_notification;
- // Notification to the parent thread to indicate the worker committed exiting.
- // TODO: maybe remove this. We can just wait on the worker thread to exit.
- iree_notification_t exit_notification;
- iree_hal_cuda_entry_list_t ready_worklist;
- iree_atomic_int32_t worker_state; // atomic
- // TODO: use status to provide more context for the error.
- iree_atomic_intptr_t error_code; // atomic
-
- // The number of asynchronous work items that are scheduled and not
- // complete.
- // These are
- // * the number of callbacks that are scheduled on the host stream.
- // * the number of pending action cleanup.
- // We need to wait for them to finish before destroying the context.
- iree_slim_mutex_t pending_work_items_count_mutex;
- iree_notification_t pending_work_items_count_notification;
- int32_t pending_work_items_count
- IREE_GUARDED_BY(pending_work_items_count_mutex);
-
- iree_allocator_t host_allocator; // const
-
- const iree_hal_cuda_dynamic_symbols_t* symbols;
- CUcontext context;
+ iree_hal_cuda_entry_list_t ready_worklist; // atomic
+ iree_atomic_int32_t worker_state; // atomic
} iree_hal_cuda_working_area_t;
// This data structure is shared by the parent thread. It is responsible
@@ -390,103 +389,59 @@
// This replaces the use of cuLaunchHostFunc, which causes the stream to block
// and wait for the CPU work to complete. It also picks up completed
// events with significantly less latency than cuLaunchHostFunc.
-
typedef struct iree_hal_cuda_completion_area_t {
// Notification from the parent thread to request completion state changes.
iree_notification_t state_notification;
- // Notification to the parent thread to indicate the worker committed exiting.
- iree_notification_t exit_notification;
- iree_hal_cuda_completion_list_t completion_list;
- iree_atomic_int32_t worker_state; // atomic
-
- iree_atomic_intptr_t error_code; // atomic
-
- // The number of asynchronous completions items that are scheduled and not
- // yet waited on.
- // We need to wait for them to finish before destroying the context.
- iree_slim_mutex_t pending_completion_count_mutex;
- iree_notification_t pending_completion_count_notification;
- int32_t pending_completion_count
- IREE_GUARDED_BY(pending_completion_count_mutex);
-
- iree_allocator_t host_allocator;
-
- const iree_hal_cuda_dynamic_symbols_t* symbols;
- CUcontext context;
+ iree_hal_cuda_completion_list_t completion_list; // atomic
+ iree_atomic_int32_t worker_state; // atomic
} iree_hal_cuda_completion_area_t;
static void iree_hal_cuda_working_area_initialize(
- iree_allocator_t host_allocator, CUcontext context,
+ iree_allocator_t host_allocator, CUdevice device,
const iree_hal_cuda_dynamic_symbols_t* symbols,
iree_hal_cuda_working_area_t* working_area) {
iree_notification_initialize(&working_area->state_notification);
- iree_notification_initialize(&working_area->exit_notification);
- iree_hal_cuda_ready_action_list_initialize(&working_area->ready_worklist);
+ iree_hal_cuda_ready_action_list_deinitialize(&working_area->ready_worklist,
+ host_allocator);
iree_atomic_store_int32(&working_area->worker_state,
IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
iree_memory_order_release);
- iree_atomic_store_int32(&working_area->error_code, IREE_STATUS_OK,
- iree_memory_order_release);
- iree_slim_mutex_initialize(&working_area->pending_work_items_count_mutex);
- iree_notification_initialize(
- &working_area->pending_work_items_count_notification);
- working_area->pending_work_items_count = 0;
- working_area->host_allocator = host_allocator;
- working_area->symbols = symbols;
- working_area->context = context;
}
static void iree_hal_cuda_working_area_deinitialize(
- iree_hal_cuda_working_area_t* working_area) {
+ iree_hal_cuda_working_area_t* working_area,
+ iree_allocator_t host_allocator) {
iree_hal_cuda_ready_action_list_deinitialize(&working_area->ready_worklist,
- working_area->host_allocator);
- iree_notification_deinitialize(&working_area->exit_notification);
+ host_allocator);
iree_notification_deinitialize(&working_area->state_notification);
- iree_slim_mutex_deinitialize(&working_area->pending_work_items_count_mutex);
- iree_notification_deinitialize(
- &working_area->pending_work_items_count_notification);
}
static void iree_hal_cuda_completion_area_initialize(
- iree_allocator_t host_allocator, CUcontext context,
+ iree_allocator_t host_allocator, CUdevice device,
const iree_hal_cuda_dynamic_symbols_t* symbols,
iree_hal_cuda_completion_area_t* completion_area) {
iree_notification_initialize(&completion_area->state_notification);
- iree_notification_initialize(&completion_area->exit_notification);
iree_hal_cuda_completion_list_initialize(&completion_area->completion_list);
iree_atomic_store_int32(&completion_area->worker_state,
IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
iree_memory_order_release);
- iree_atomic_store_int32(&completion_area->error_code, IREE_STATUS_OK,
- iree_memory_order_release);
- iree_slim_mutex_initialize(&completion_area->pending_completion_count_mutex);
- iree_notification_initialize(
- &completion_area->pending_completion_count_notification);
- completion_area->pending_completion_count = 0;
- completion_area->host_allocator = host_allocator;
- completion_area->symbols = symbols;
- completion_area->context = context;
}
static void iree_hal_cuda_completion_area_deinitialize(
- iree_hal_cuda_completion_area_t* completion_area) {
+ iree_hal_cuda_completion_area_t* completion_area,
+ const iree_hal_cuda_dynamic_symbols_t* symbols,
+ iree_allocator_t host_allocator) {
iree_hal_cuda_completion_list_deinitialize(&completion_area->completion_list,
- completion_area->symbols,
- completion_area->host_allocator);
- iree_notification_deinitialize(&completion_area->exit_notification);
+ symbols, host_allocator);
iree_notification_deinitialize(&completion_area->state_notification);
- iree_slim_mutex_deinitialize(
- &completion_area->pending_completion_count_mutex);
- iree_notification_deinitialize(
- &completion_area->pending_completion_count_notification);
}
// The main function for the ready-list processing worker thread.
static int iree_hal_cuda_worker_execute(
- iree_hal_cuda_working_area_t* working_area);
+ iree_hal_cuda_pending_queue_actions_t* actions);
static int iree_hal_cuda_completion_execute(
- iree_hal_cuda_completion_area_t* working_area);
+ iree_hal_cuda_pending_queue_actions_t* actions);
//===----------------------------------------------------------------------===//
// Pending queue actions
@@ -505,7 +460,7 @@
// The symbols used to create and destroy CUevent objects.
const iree_hal_cuda_dynamic_symbols_t* symbols;
- // Non-recursive mutex guarding access to the action list.
+ // Non-recursive mutex guarding access.
iree_slim_mutex_t action_mutex;
// The double-linked list of pending actions.
@@ -525,12 +480,29 @@
// Completion thread's working area.
iree_hal_cuda_completion_area_t completion_area;
+ // Atomic of type iree_status_t. It is a sticky error.
+ // Once set with an error, all subsequent actions that have not completed
+ // will fail with this error.
+ iree_status_t status IREE_GUARDED_BY(action_mutex);
+
// The associated cuda device.
CUdevice device;
-};
+ CUcontext cuda_context;
-static const iree_hal_resource_vtable_t
- iree_hal_cuda_pending_queue_actions_vtable;
+ // The number of asynchronous work items that are scheduled and not
+ // complete.
+ // These are
+ // * the number of actions issued.
+ // * the number of pending action cleanups.
+ // The work and completion threads can exit only when there are no more
+ // pending work items.
+ iree_host_size_t pending_work_items_count IREE_GUARDED_BY(action_mutex);
+
+ // The owner can request an exit of the worker threads.
+ // Once all pending enqueued work is complete the threads will exit.
+ // No actions can be enqueued after requesting an exit.
+ bool exit_requested IREE_GUARDED_BY(action_mutex);
+};
iree_status_t iree_hal_cuda_pending_queue_actions_create(
const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device,
@@ -546,22 +518,21 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*actions),
(void**)&actions));
- iree_hal_resource_initialize(&iree_hal_cuda_pending_queue_actions_vtable,
- &actions->resource);
actions->host_allocator = host_allocator;
actions->block_pool = block_pool;
actions->symbols = symbols;
actions->device = device;
+ actions->cuda_context = context;
iree_slim_mutex_initialize(&actions->action_mutex);
memset(&actions->action_list, 0, sizeof(actions->action_list));
// Initialize the working area for the ready-list processing worker.
iree_hal_cuda_working_area_t* working_area = &actions->working_area;
- iree_hal_cuda_working_area_initialize(host_allocator, context, symbols,
+ iree_hal_cuda_working_area_initialize(host_allocator, device, symbols,
working_area);
iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
- iree_hal_cuda_completion_area_initialize(host_allocator, context, symbols,
+ iree_hal_cuda_completion_area_initialize(host_allocator, device, symbols,
completion_area);
// Create the ready-list processing worker itself.
@@ -570,15 +541,15 @@
params.name = IREE_SV("deferque_worker");
params.create_suspended = false;
iree_status_t status = iree_thread_create(
- (iree_thread_entry_t)iree_hal_cuda_worker_execute, working_area, params,
+ (iree_thread_entry_t)iree_hal_cuda_worker_execute, actions, params,
actions->host_allocator, &actions->worker_thread);
params.name = IREE_SV("done_worker");
params.create_suspended = false;
if (iree_status_is_ok(status)) {
status = iree_thread_create(
- (iree_thread_entry_t)iree_hal_cuda_completion_execute, completion_area,
- params, actions->host_allocator, &actions->completion_thread);
+ (iree_thread_entry_t)iree_hal_cuda_completion_execute, actions, params,
+ actions->host_allocator, &actions->completion_thread);
}
if (iree_status_is_ok(status)) {
@@ -596,59 +567,62 @@
return (iree_hal_cuda_pending_queue_actions_t*)base_value;
}
-static bool iree_hal_cuda_worker_committed_exiting(
- iree_hal_cuda_working_area_t* working_area);
+static void iree_hal_cuda_pending_queue_actions_notify_worker_thread(
+ iree_hal_cuda_working_area_t* working_area) {
+ iree_atomic_store_int32(&working_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
+ iree_memory_order_release);
+ iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+}
-static bool iree_hal_cuda_completion_committed_exiting(
- iree_hal_cuda_completion_area_t* working_area);
+static void iree_hal_cuda_pending_queue_actions_notify_completion_thread(
+ iree_hal_cuda_completion_area_t* completion_area) {
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
+ iree_memory_order_release);
+ iree_notification_post(&completion_area->state_notification,
+ IREE_ALL_WAITERS);
+}
+
+// Notifies worker and completion threads that there is work available to
+// process.
+static void iree_hal_cuda_pending_queue_actions_notify_threads(
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ iree_hal_cuda_pending_queue_actions_notify_worker_thread(
+ &actions->working_area);
+ iree_hal_cuda_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
+}
+
+static void iree_hal_cuda_pending_queue_actions_request_exit(
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ actions->exit_requested = true;
+ iree_slim_mutex_unlock(&actions->action_mutex);
+
+ iree_hal_cuda_pending_queue_actions_notify_threads(actions);
+}
void iree_hal_cuda_pending_queue_actions_destroy(
iree_hal_resource_t* base_actions) {
+ IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda_pending_queue_actions_t* actions =
iree_hal_cuda_pending_queue_actions_cast(base_actions);
iree_allocator_t host_allocator = actions->host_allocator;
- iree_hal_cuda_working_area_t* working_area = &actions->working_area;
- iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
- IREE_TRACE_ZONE_BEGIN(z0);
- // Request the worker to exit.
- iree_hal_cuda_worker_state_t prev_state =
- (iree_hal_cuda_worker_state_t)iree_atomic_exchange_int32(
- &working_area->worker_state,
- IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED, iree_memory_order_acq_rel);
- iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+ // Request the workers to exit.
+ iree_hal_cuda_pending_queue_actions_request_exit(actions);
- // Check potential exit states from the worker.
- if (prev_state != IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
- // Wait until the worker acknowledged exiting.
- iree_notification_await(
- &working_area->exit_notification,
- (iree_condition_fn_t)iree_hal_cuda_worker_committed_exiting,
- working_area, iree_infinite_timeout());
- }
-
- // Now we can delete worker related resources.
+ iree_thread_join(actions->worker_thread);
iree_thread_release(actions->worker_thread);
- iree_hal_cuda_working_area_deinitialize(working_area);
- // Request the completion thread to exit.
- prev_state = (iree_hal_cuda_worker_state_t)iree_atomic_exchange_int32(
- &completion_area->worker_state, IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED,
- iree_memory_order_acq_rel);
- iree_notification_post(&completion_area->state_notification,
- IREE_ALL_WAITERS);
-
- // Check potential exit states from the completion thread.
- if (prev_state != IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
- // Wait until the completion thread acknowledged exiting.
- iree_notification_await(
- &completion_area->exit_notification,
- (iree_condition_fn_t)iree_hal_cuda_completion_committed_exiting,
- completion_area, iree_infinite_timeout());
- }
-
+ iree_thread_join(actions->completion_thread);
iree_thread_release(actions->completion_thread);
- iree_hal_cuda_completion_area_deinitialize(completion_area);
+
+ iree_hal_cuda_working_area_deinitialize(&actions->working_area,
+ actions->host_allocator);
+ iree_hal_cuda_completion_area_deinitialize(
+ &actions->completion_area, actions->symbols, actions->host_allocator);
iree_slim_mutex_deinitialize(&actions->action_mutex);
iree_hal_cuda_queue_action_list_destroy(actions->action_list.head);
@@ -657,11 +631,6 @@
IREE_TRACE_ZONE_END(z0);
}
-static const iree_hal_resource_vtable_t
- iree_hal_cuda_pending_queue_actions_vtable = {
- .destroy = iree_hal_cuda_pending_queue_actions_destroy,
-};
-
static void iree_hal_cuda_queue_action_destroy(
iree_hal_cuda_queue_action_t* action) {
IREE_TRACE_ZONE_BEGIN(z0);
@@ -686,30 +655,10 @@
}
static void iree_hal_cuda_queue_decrement_work_items_count(
- iree_hal_cuda_working_area_t* working_area) {
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- --working_area->pending_work_items_count;
- if (working_area->pending_work_items_count == 0) {
- // Notify inside the lock to make sure that we are done touching anything
- // since the context may get destroyed in the meantime.
- iree_notification_post(&working_area->pending_work_items_count_notification,
- IREE_ALL_WAITERS);
- }
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
-}
-
-static void iree_hal_cuda_queue_decrement_completion_count(
- iree_hal_cuda_completion_area_t* completion_area) {
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- --completion_area->pending_completion_count;
- if (completion_area->pending_completion_count == 0) {
- // Notify inside the lock to make sure that we are done touching anything
- // since the context may get destroyed in the meantime.
- iree_notification_post(
- &completion_area->pending_completion_count_notification,
- IREE_ALL_WAITERS);
- }
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ --actions->pending_work_items_count;
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution(
@@ -855,12 +804,21 @@
}
if (iree_status_is_ok(status)) {
- // Retain the owning queue to make sure the action outlives it.
- iree_hal_resource_retain(actions);
-
// Now everything is okay and we can enqueue the action.
iree_slim_mutex_lock(&actions->action_mutex);
- iree_hal_cuda_queue_action_list_push_back(&actions->action_list, action);
+ if (actions->exit_requested) {
+ status = iree_make_status(
+ IREE_STATUS_ABORTED,
+ "can not issue more executions, exit already requested");
+ iree_hal_cuda_queue_action_fail_locked(action, status);
+ } else {
+ iree_hal_cuda_queue_action_list_push_back(&actions->action_list, action);
+ // One work item is the callback that makes it across from the
+ // completion thread.
+ // The other is the cleanup of the action.
+ actions->pending_work_items_count +=
+ total_work_items_to_complete_an_action;
+ }
iree_slim_mutex_unlock(&actions->action_mutex);
} else {
iree_hal_resource_set_free(action->resource_set);
@@ -871,45 +829,116 @@
return status;
}
-static void iree_hal_cuda_post_error_to_worker_state(
- iree_hal_cuda_working_area_t* working_area, iree_status_code_t code) {
- // Write error code, but don't overwrite existing error codes.
- intptr_t prev_error_code = IREE_STATUS_OK;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->error_code, /*expected=*/&prev_error_code,
- /*desired=*/code,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
-
- // This state has the highest priority so just overwrite.
- iree_atomic_store_int32(&working_area->worker_state,
- IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR,
- iree_memory_order_release);
- iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+// Does not consume |status|.
+static void iree_hal_cuda_pending_queue_actions_fail_status_locked(
+ iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) {
+ if (iree_status_is_ok(actions->status) && status != actions->status) {
+ actions->status = iree_status_clone(status);
+ }
}
-static void iree_hal_cuda_post_error_to_completion_state(
- iree_hal_cuda_completion_area_t* completion_area, iree_status_code_t code) {
- // Write error code, but don't overwrite existing error codes.
- intptr_t prev_error_code = IREE_STATUS_OK;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->error_code, /*expected=*/&prev_error_code,
- /*desired=*/code,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+// Fails and destroys the action.
+// Does not consume |status|.
+// Decrements pending work items count accordingly based on the unfulfilled
+// number of work items.
+static void iree_hal_cuda_queue_action_fail_locked(
+ iree_hal_cuda_queue_action_t* action, iree_status_t status) {
+ IREE_ASSERT(!iree_status_is_ok(status));
+ iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions;
- // This state has the highest priority so just overwrite.
- iree_atomic_store_int32(&completion_area->worker_state,
- IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR,
- iree_memory_order_release);
- iree_notification_post(&completion_area->state_notification,
- IREE_ALL_WAITERS);
+ // Unlock since failing the semaphore will use |actions|.
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_hal_semaphore_list_fail(action->signal_semaphore_list,
+ iree_status_clone(status));
+
+ iree_host_size_t work_items_remaining = 0;
+ switch (action->state) {
+ case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE:
+ work_items_remaining = total_work_items_to_complete_an_action;
+ break;
+ case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE:
+ work_items_remaining = 1;
+ break;
+ default:
+ // Someone forgot to handle all enum values?
+ iree_abort();
+ }
+ iree_slim_mutex_lock(&actions->action_mutex);
+ action->owning_actions->pending_work_items_count -= work_items_remaining;
+ iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_cuda_queue_action_destroy(action);
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_cuda_queue_action_fail(
+ iree_hal_cuda_queue_action_t* action, iree_status_t status) {
+ iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions;
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_cuda_queue_action_fail_locked(action, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_cuda_queue_action_raw_list_fail_locked(
+ iree_hal_cuda_queue_action_t* head_action, iree_status_t status) {
+ while (head_action) {
+ iree_hal_cuda_queue_action_t* next_action = head_action->next;
+ iree_hal_cuda_queue_action_fail_locked(head_action, status);
+ head_action = next_action;
+ }
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_cuda_ready_action_list_fail_locked(
+ iree_hal_cuda_entry_list_t* list, iree_status_t status) {
+ iree_hal_cuda_entry_list_node_t* entry = iree_hal_cuda_entry_list_pop(list);
+ while (entry) {
+ iree_hal_cuda_queue_action_raw_list_fail_locked(entry->ready_list_head,
+ status);
+ entry = iree_hal_cuda_entry_list_pop(list);
+ }
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_cuda_queue_action_list_fail_locked(
+ iree_hal_cuda_queue_action_list_t* list, iree_status_t status) {
+ iree_hal_cuda_queue_action_t* action;
+ if (iree_hal_cuda_queue_action_list_is_empty(list)) {
+ return;
+ }
+ do {
+ action = iree_hal_cuda_queue_action_list_pop_front(list);
+ iree_hal_cuda_queue_action_fail_locked(action, status);
+ } while (action);
+}
+
+// Fails and destroys all actions and sets status of |actions|.
+// Does not consume |status|.
+// Assumes the caller is holding the action_mutex.
+static void iree_hal_cuda_pending_queue_actions_fail_locked(
+ iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) {
+ iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list, status);
+ iree_hal_cuda_ready_action_list_fail_locked(
+ &actions->working_area.ready_worklist, status);
+}
+
+// Does not consume |status|.
+static void iree_hal_cuda_pending_queue_actions_fail(
+ iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_cuda_pending_queue_actions_fail_locked(actions, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
// Releases resources after action completion on the GPU and advances timeline
// and pending actions queue.
static iree_status_t iree_hal_cuda_execution_device_signal_host_callback(
- void* user_data) {
+ iree_status_t status, void* user_data) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_cuda_queue_action_t* action =
(iree_hal_cuda_queue_action_t*)user_data;
@@ -917,7 +946,11 @@
IREE_ASSERT_EQ(action->state, IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE);
iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions;
- iree_status_t status;
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_queue_action_fail(action, status);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
// Need to signal the list before zombifying the action, because in the mean
// time someone else may issue the pending queue actions.
@@ -925,9 +958,9 @@
// may run while we are still using the semaphore list, causing a crash.
status = iree_hal_semaphore_list_signal(action->signal_semaphore_list);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- IREE_ASSERT(false && "cannot signal semaphores in host callback");
- iree_hal_cuda_post_error_to_worker_state(&actions->working_area,
- iree_status_code(status));
+ iree_hal_cuda_queue_action_fail(action, status);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
}
// Flip the action state to zombie and enqueue it again so that we can let
@@ -937,18 +970,12 @@
action->state = IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE;
iree_slim_mutex_lock(&actions->action_mutex);
iree_hal_cuda_queue_action_list_push_back(&actions->action_list, action);
+ // The callback (work item) is complete.
+ --actions->pending_work_items_count;
iree_slim_mutex_unlock(&actions->action_mutex);
// We need to trigger execution of this action again, so it gets cleaned up.
status = iree_hal_cuda_pending_queue_actions_issue(actions);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- IREE_ASSERT(false && "cannot issue action for cleanup in host callback");
- iree_hal_cuda_post_error_to_worker_state(&actions->working_area,
- iree_status_code(status));
- }
-
- // The callback (work item) is complete.
- iree_hal_cuda_queue_decrement_work_items_count(&actions->working_area);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -959,8 +986,8 @@
iree_hal_cuda_queue_action_t* action) {
IREE_ASSERT_EQ(action->kind, IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION);
IREE_ASSERT_EQ(action->is_pending, false);
- const iree_hal_cuda_dynamic_symbols_t* symbols =
- action->owning_actions->symbols;
+ iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions;
+ const iree_hal_cuda_dynamic_symbols_t* symbols = actions->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
// No need to lock given that this action is already detched from the pending
@@ -1060,21 +1087,10 @@
z0, symbols, cuEventRecord(completion_event, action->dispatch_cu_stream),
"cuEventRecord");
- iree_slim_mutex_lock(
- &action->owning_actions->working_area.pending_work_items_count_mutex);
- // One work item is the callback that makes it across from the
- // completion thread.
- // The other is the cleanup of the action.
- action->owning_actions->working_area.pending_work_items_count += 2;
- iree_slim_mutex_unlock(
- &action->owning_actions->working_area.pending_work_items_count_mutex);
- // Now launch a host function on the callback stream to advance the semaphore
- // timeline.
-
iree_hal_cuda_completion_list_node_t* entry = NULL;
// TODO: avoid host allocator malloc; use some pool for the allocation.
- iree_status_t status = iree_allocator_malloc(
- action->owning_actions->host_allocator, sizeof(*entry), (void**)&entry);
+ iree_status_t status = iree_allocator_malloc(actions->host_allocator,
+ sizeof(*entry), (void**)&entry);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
IREE_TRACE_ZONE_END(z0);
@@ -1087,39 +1103,11 @@
entry->created_event = created_event;
entry->callback = iree_hal_cuda_execution_device_signal_host_callback;
entry->user_data = action;
- iree_hal_cuda_completion_list_push(
- &action->owning_actions->completion_area.completion_list, entry);
+ iree_hal_cuda_completion_list_push(&actions->completion_area.completion_list,
+ entry);
- iree_slim_mutex_lock(
- &action->owning_actions->completion_area.pending_completion_count_mutex);
-
- action->owning_actions->completion_area.pending_completion_count += 1;
-
- iree_slim_mutex_unlock(
- &action->owning_actions->completion_area.pending_completion_count_mutex);
-
- // We can only overwrite the worker state if the previous state is idle
- // waiting; we cannot overwrite exit related states. so we need to perform
- // atomic compare and exchange here.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING;
- iree_atomic_compare_exchange_strong_int32(
- &action->owning_actions->completion_area.worker_state,
- /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(
- &action->owning_actions->completion_area.state_notification,
- IREE_ALL_WAITERS);
-
- // Handle potential error cases from the worker thread.
- if (prev_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
- iree_status_code_t code = iree_atomic_load_int32(
- &action->owning_actions->completion_area.error_code,
- iree_memory_order_acquire);
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_status_from_code(code));
- }
+ iree_hal_cuda_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
@@ -1135,7 +1123,7 @@
// Now we fully executed and cleaned up this action. Decrease the work items
// counter.
- iree_hal_cuda_queue_decrement_work_items_count(&actions->working_area);
+ iree_hal_cuda_queue_decrement_work_items_count(actions);
IREE_TRACE_ZONE_END(z0);
}
@@ -1155,8 +1143,16 @@
return iree_ok_status();
}
- // Scan through the list and categorize actions into pending and ready lists.
+ if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) {
+ iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list,
+ actions->status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+
iree_status_t status = iree_ok_status();
+ // Scan through the list and categorize actions into pending and ready lists.
while (!iree_hal_cuda_queue_action_list_is_empty(&actions->action_list)) {
iree_hal_cuda_queue_action_t* action =
iree_hal_cuda_queue_action_list_pop_front(&actions->action_list);
@@ -1165,6 +1161,7 @@
uint64_t* values = action->wait_semaphore_list.payload_values;
action->is_pending = false;
+ bool action_failed = false;
// Cleanup actions are immediately ready to release. Otherwise, look at all
// wait semaphores to make sure that they are either already ready or we can
@@ -1174,8 +1171,14 @@
// If this semaphore has already signaled past the desired value, we can
// just ignore it.
uint64_t value = 0;
- status = iree_hal_semaphore_query(semaphores[i], &value);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+ iree_status_t semaphore_status =
+ iree_hal_semaphore_query(semaphores[i], &value);
+ if (IREE_UNLIKELY(!iree_status_is_ok(semaphore_status))) {
+ iree_hal_cuda_queue_action_fail_locked(action, semaphore_status);
+ iree_status_ignore(semaphore_status);
+ action_failed = true;
+ break;
+ }
if (value >= values[i]) {
// No need to wait on this timepoint as it has already occurred and
// we can remove it from the wait list.
@@ -1204,6 +1207,10 @@
IREE_STATUS_RESOURCE_EXHAUSTED,
"exceeded maximum queue action wait event limit");
iree_hal_cuda_event_release(wait_event);
+ if (iree_status_is_ok(actions->status)) {
+ actions->status = status;
+ }
+ iree_hal_cuda_queue_action_fail_locked(action, status);
break;
}
action->events[action->event_count++] = wait_event;
@@ -1215,11 +1222,16 @@
}
}
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Some error happened during processing the current action.
- // Put it back to the pending list so we don't leak.
- action->is_pending = true;
- iree_hal_cuda_queue_action_list_push_back(&pending_list, action);
+ if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) {
+ if (!action_failed) {
+ iree_hal_cuda_queue_action_fail_locked(action, actions->status);
+ }
+ iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list,
+ actions->status);
+ break;
+ }
+
+ if (action_failed) {
break;
}
@@ -1249,9 +1261,10 @@
}
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Release all actions in the ready list to avoid leaking.
- iree_hal_cuda_queue_action_list_destroy(ready_list.head);
- iree_allocator_free(actions->host_allocator, entry);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_cuda_queue_action_list_fail_locked(&ready_list, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -1261,25 +1274,8 @@
entry->ready_list_head = ready_list.head;
iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist, entry);
- // We can only overwrite the worker state if the previous state is idle
- // waiting; we cannot overwrite exit related states. so we need to perform
- // atomic compare and exchange here.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING;
- iree_atomic_compare_exchange_strong_int32(
- &actions->working_area.worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&actions->working_area.state_notification,
- IREE_ALL_WAITERS);
-
- // Handle potential error cases from the worker thread.
- if (prev_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
- iree_status_code_t code = iree_atomic_load_int32(
- &actions->working_area.error_code, iree_memory_order_acquire);
- status = iree_status_from_code(code);
- }
+ iree_hal_cuda_pending_queue_actions_notify_worker_thread(
+ &actions->working_area);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -1289,247 +1285,153 @@
// Worker routines
//===----------------------------------------------------------------------===//
-static bool iree_hal_cuda_worker_has_incoming_request_or_error(
+static bool iree_hal_cuda_worker_has_incoming_request(
iree_hal_cuda_working_area_t* working_area) {
iree_hal_cuda_worker_state_t value = iree_atomic_load_int32(
&working_area->worker_state, iree_memory_order_acquire);
- return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING ||
- value == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED ||
- value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
+ return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
}
-static bool iree_hal_cuda_completion_has_incoming_request_or_error(
+static bool iree_hal_cuda_completion_has_incoming_request(
iree_hal_cuda_completion_area_t* completion_area) {
iree_hal_cuda_worker_state_t value = iree_atomic_load_int32(
&completion_area->worker_state, iree_memory_order_acquire);
- return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING ||
- value == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED ||
- value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
-}
-
-static bool iree_hal_cuda_worker_committed_exiting(
- iree_hal_cuda_working_area_t* working_area) {
- return iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire) ==
- IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
-}
-
-static bool iree_hal_cuda_completion_committed_exiting(
- iree_hal_cuda_completion_area_t* working_area) {
- return iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire) ==
- IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
+ return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
}
// Processes all ready actions in the given |worklist|.
-static iree_status_t iree_hal_cuda_worker_process_ready_list(
- iree_allocator_t host_allocator, iree_hal_cuda_entry_list_t* worklist) {
+static void iree_hal_cuda_worker_process_ready_list(
+ iree_hal_cuda_pending_queue_actions_t* actions) {
IREE_TRACE_ZONE_BEGIN(z0);
- iree_status_t status = iree_ok_status();
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_status_t status = actions->status;
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_ready_action_list_fail_locked(
+ &actions->working_area.ready_worklist, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_status_ignore(status);
+ return;
+ }
+ iree_slim_mutex_unlock(&actions->action_mutex);
+
while (true) {
iree_hal_cuda_entry_list_node_t* entry =
- iree_hal_cuda_entry_list_pop(worklist);
+ iree_hal_cuda_entry_list_pop(&actions->working_area.ready_worklist);
if (!entry) break;
// Process the current batch of ready actions.
while (entry->ready_list_head) {
iree_hal_cuda_queue_action_t* action =
- iree_hal_cuda_atomic_slist_entry_pop_front(entry);
-
+ iree_hal_cuda_entry_list_node_pop_front(entry);
switch (action->state) {
case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE:
status = iree_hal_cuda_pending_queue_actions_issue_execution(action);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_cuda_queue_action_destroy(action);
- }
break;
case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE:
iree_hal_cuda_pending_queue_actions_issue_cleanup(action);
break;
}
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_entry_list_node_push_front(entry, action);
+ iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist,
+ entry);
+ break;
+ }
}
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Let common destruction path take care of destroying the worklist.
- // When we know all host stream callbacks are done and not touching
- // anything.
- iree_hal_cuda_entry_list_push(worklist, entry);
break;
}
- iree_allocator_free(host_allocator, entry);
+ iree_allocator_free(actions->host_allocator, entry);
+ }
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
}
IREE_TRACE_ZONE_END(z0);
- return status;
}
-static bool iree_hal_cuda_worker_has_no_pending_work_items(
- iree_hal_cuda_working_area_t* working_area) {
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- bool result = (working_area->pending_work_items_count == 0);
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
- return result;
-}
+static void iree_hal_cuda_worker_process_completion(
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_cuda_completion_list_t* worklist =
+ &actions->completion_area.completion_list;
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_status_t status = iree_status_clone(actions->status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
-static bool iree_hal_cuda_completion_has_no_pending_completion_items(
- iree_hal_cuda_completion_area_t* completion_area) {
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- bool result = (completion_area->pending_completion_count == 0);
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
- return result;
-}
-
-// Wait for all work items to finish.
-static void iree_hal_cuda_worker_wait_pending_work_items(
- iree_hal_cuda_working_area_t* working_area) {
- iree_notification_await(
- &working_area->pending_work_items_count_notification,
- (iree_condition_fn_t)iree_hal_cuda_worker_has_no_pending_work_items,
- working_area, iree_infinite_timeout());
- // Lock then unlock to make sure that all callbacks are really done.
- // Not even touching the notification.
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
-}
-
-// Wait for all work items to finish.
-static void iree_hal_cuda_completion_wait_pending_completion_items(
- iree_hal_cuda_completion_area_t* completion_area) {
- iree_notification_await(
- &completion_area->pending_completion_count_notification,
- (iree_condition_fn_t)
- iree_hal_cuda_completion_has_no_pending_completion_items,
- completion_area, iree_infinite_timeout());
- // Lock then unlock to make sure that all callbacks are really done.
- // Not even touching the notification.
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
-}
-
-static iree_status_t iree_hal_cuda_worker_process_completion(
- iree_hal_cuda_completion_list_t* worklist,
- iree_hal_cuda_completion_area_t* completion_area) {
- iree_status_t status = iree_ok_status();
while (true) {
iree_hal_cuda_completion_list_node_t* entry =
iree_hal_cuda_completion_list_pop(worklist);
if (!entry) break;
- IREE_TRACE_ZONE_BEGIN_NAMED(z1, "cuEventSynchronize");
- CUresult result =
- completion_area->symbols->cuEventSynchronize(entry->event);
- IREE_TRACE_ZONE_END(z1);
- if (IREE_UNLIKELY(result != CUDA_SUCCESS)) {
- // Let common destruction path take care of destroying the worklist.
- // When we know all host stream callbacks are done and not touching
- // anything.
- iree_hal_cuda_completion_list_push(worklist, entry);
- status =
- iree_make_status(IREE_STATUS_ABORTED, "could not wait on cuda event");
- break;
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "cuEventSynchronize");
+ status = IREE_CURESULT_TO_STATUS(actions->symbols,
+ cuEventSynchronize(entry->event));
+ IREE_TRACE_ZONE_END(z1);
}
- status = entry->callback(entry->user_data);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- break;
- }
+
+ status =
+ iree_status_join(status, entry->callback(status, entry->user_data));
if (IREE_UNLIKELY(entry->created_event)) {
- IREE_CUDA_IGNORE_ERROR(completion_area->symbols,
- cuEventDestroy(entry->event));
+ status = iree_status_join(
+ status, IREE_CURESULT_TO_STATUS(actions->symbols,
+ cuEventDestroy(entry->event)));
}
- iree_allocator_free(completion_area->host_allocator, entry);
-
- // Now we fully executed and cleaned up this entry. Decrease the work
- // items counter.
- iree_hal_cuda_queue_decrement_completion_count(completion_area);
+ iree_allocator_free(actions->host_allocator, entry);
}
- return status;
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_cuda_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
}
// The main function for the completion worker thread.
static int iree_hal_cuda_completion_execute(
- iree_hal_cuda_completion_area_t* completion_area) {
- iree_hal_cuda_completion_list_t* worklist = &completion_area->completion_list;
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
iree_status_t status = IREE_CURESULT_TO_STATUS(
- completion_area->symbols, cuCtxSetCurrent(completion_area->context),
+ actions->symbols, cuCtxSetCurrent(actions->cuda_context),
"cuCtxSetCurrent");
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
- iree_hal_cuda_post_error_to_completion_state(completion_area,
- iree_status_code(status));
- return -1;
+ iree_hal_cuda_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
}
while (true) {
iree_notification_await(
&completion_area->state_notification,
- (iree_condition_fn_t)
- iree_hal_cuda_completion_has_incoming_request_or_error,
+ (iree_condition_fn_t)iree_hal_cuda_completion_has_incoming_request,
completion_area, iree_infinite_timeout());
- IREE_TRACE_ZONE_BEGIN(z0);
// Immediately flip the state to idle waiting if and only if the previous
// state is workload pending. We do it before processing ready list to make
// sure that we don't accidentally ignore new workload pushed after done
// ready list processing but before overwriting the state from this worker
- // thread. Also we don't want to overwrite other exit states. So we need to
- // perform atomic compare and exchange here.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+ // thread.
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+ iree_memory_order_release);
+ iree_hal_cuda_worker_process_completion(actions);
- int32_t worker_state = iree_atomic_load_int32(
- &completion_area->worker_state, iree_memory_order_acquire);
- // Exit if CUDA callbacks have posted any errors.
- if (IREE_UNLIKELY(worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR)) {
- iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
- IREE_TRACE_ZONE_END(z0);
- return -1;
- }
- // Check if we received request to stop processing and exit this thread.
- bool should_exit =
- (worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED);
-
- iree_status_t status =
- iree_hal_cuda_worker_process_completion(worklist, completion_area);
-
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
- iree_hal_cuda_post_error_to_completion_state(completion_area,
- iree_status_code(status));
- IREE_TRACE_ZONE_END(z0);
- return -1;
- }
-
- if (IREE_UNLIKELY(should_exit &&
- iree_hal_cuda_completion_has_no_pending_completion_items(
- completion_area))) {
- iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
- // Signal that this thread is committed to exit.
- // This state has a priority that is only lower than error exit.
- // A CUDA callback may have posted an error, make sure we don't
- // overwrite this error state.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&completion_area->exit_notification,
- IREE_ALL_WAITERS);
- IREE_TRACE_ZONE_END(z0);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ if (IREE_UNLIKELY(actions->exit_requested &&
+ !actions->pending_work_items_count)) {
+ iree_slim_mutex_unlock(&actions->action_mutex);
return 0;
}
- IREE_TRACE_ZONE_END(z0);
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
return 0;
@@ -1537,20 +1439,20 @@
// The main function for the ready-list processing worker thread.
static int iree_hal_cuda_worker_execute(
- iree_hal_cuda_working_area_t* working_area) {
- iree_hal_cuda_entry_list_t* worklist = &working_area->ready_worklist;
+ iree_hal_cuda_pending_queue_actions_t* actions) {
+ iree_hal_cuda_working_area_t* working_area = &actions->working_area;
// Cuda stores thread-local data based on the device. Some cuda commands pull
- // the device from there, this will cause failures when using it with other
- // devices (or streams from other devices). Force the correct device onto
- // this thread.
+ // the device from there, and it defaults to device 0 (e.g. cuEventCreate),
+ // this will cause failures when using it with other devices (or streams from
+ // other devices). Force the correct device onto this thread.
iree_status_t status = IREE_CURESULT_TO_STATUS(
- working_area->symbols, cuCtxSetCurrent(working_area->context),
+ actions->symbols, cuCtxSetCurrent(actions->cuda_context),
"cuCtxSetCurrent");
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_cuda_worker_wait_pending_work_items(working_area);
- iree_hal_cuda_post_error_to_worker_state(working_area,
- iree_status_code(status));
+ iree_hal_cuda_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
+ // We can safely exit here because there are no actions in flight yet.
return -1;
}
@@ -1564,63 +1466,29 @@
// host stream callbacks.
iree_notification_await(
&working_area->state_notification,
- (iree_condition_fn_t)iree_hal_cuda_worker_has_incoming_request_or_error,
+ (iree_condition_fn_t)iree_hal_cuda_worker_has_incoming_request,
working_area, iree_infinite_timeout());
// Immediately flip the state to idle waiting if and only if the previous
// state is workload pending. We do it before processing ready list to make
// sure that we don't accidentally ignore new workload pushed after done
// ready list processing but before overwriting the state from this worker
- // thread. Also we don't want to overwrite other exit states. So we need to
- // perform atomic compare and exchange here.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+ // thread.
+ iree_atomic_store_int32(&working_area->worker_state,
+ IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+ iree_memory_order_release);
- int32_t worker_state = iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire);
- // Exit if CUDA callbacks have posted any errors.
- if (IREE_UNLIKELY(worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR)) {
- iree_hal_cuda_worker_wait_pending_work_items(working_area);
- return -1;
- }
- // Check if we received request to stop processing and exit this thread.
- bool should_exit =
- (worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED);
+ iree_hal_cuda_worker_process_ready_list(actions);
- // Process the ready list. We also want this even requested to exit.
- iree_status_t status = iree_hal_cuda_worker_process_ready_list(
- working_area->host_allocator, worklist);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_cuda_worker_wait_pending_work_items(working_area);
- iree_hal_cuda_post_error_to_worker_state(working_area,
- iree_status_code(status));
- return -1;
- }
-
- if (IREE_UNLIKELY(
- should_exit &&
- iree_hal_cuda_worker_has_no_pending_work_items(working_area))) {
- iree_hal_cuda_worker_wait_pending_work_items(working_area);
- // Signal that this thread is committed to exit.
- // This state has a priority that is only lower than error exit.
- // A CUDA callback may have posted an error, make sure we don't
- // overwrite this error state.
- iree_hal_cuda_worker_state_t prev_state =
- IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&working_area->exit_notification,
- IREE_ALL_WAITERS);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ if (IREE_UNLIKELY(actions->exit_requested &&
+ !actions->pending_work_items_count)) {
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_hal_cuda_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
return 0;
}
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
return 0;
}
diff --git a/runtime/src/iree/hal/drivers/cuda/tracing.c b/runtime/src/iree/hal/drivers/cuda/tracing.c
index 68df052..ea508be 100644
--- a/runtime/src/iree/hal/drivers/cuda/tracing.c
+++ b/runtime/src/iree/hal/drivers/cuda/tracing.c
@@ -16,7 +16,7 @@
// To prevent spilling pages we leave some room for the context structure.
#define IREE_HAL_CUDA_TRACING_DEFAULT_QUERY_CAPACITY (16 * 1024 - 256)
-// iree_hal_hip_tracing_context_event_t contains a cuEvent that is used to
+// iree_hal_cuda_tracing_context_event_t contains a cuEvent that is used to
// record timestamps for tracing GPU execution. In this struct, there are also
// two linked lists that the current event may be added to during its lifetime.
//