[cuda] Port over changes made in hip backend to cuda (#18079)

This brings two changes over from the HIP backend into the CUDA backend.

Specifically, it replaces `cuLaunchHostFunc` with a dedicated
event-waiting thread, and replaces the use of `atomic_slist` (which has
non-deterministic ordering) with a well-ordered, mutex-guarded FIFO
linked list.

The original commits on the main branch were #17925 and #18048.
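
For reference, here is a minimal standalone sketch of the shape of both
changes (illustrative only: the names are made up and error handling is
omitted; the actual implementation lives in pending_queue_actions.c).
Instead of handing a callback to `cuLaunchHostFunc`, the producer records
a `CUevent` on the stream and pushes it onto a mutex-guarded FIFO list,
and a dedicated thread waits on each event in submission order before
running the callback:

```c
/* Illustrative sketch only; not the IREE implementation. */
#include <cuda.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdlib.h>

/* One queued completion: an event plus the callback to run after it. */
typedef struct node_t {
  CUevent event;
  void (*callback)(void*);
  void* user_data;
  struct node_t* next;
} node_t;

/* Mutex-guarded FIFO list; unlike an atomic slist, pop order matches
 * push order. */
typedef struct waiter_t {
  pthread_mutex_t mu;
  pthread_cond_t cv;
  node_t* head;  /* pop from the head... */
  node_t* tail;  /* ...push at the tail */
  bool exiting;  /* set under the lock by the owner, then cv broadcast */
  CUcontext ctx;
} waiter_t;

/* Producer side; stands in for what cuLaunchHostFunc used to do. */
static void enqueue_completion(waiter_t* w, CUstream stream,
                               void (*cb)(void*), void* user_data) {
  node_t* n = (node_t*)calloc(1, sizeof(*n));
  cuEventCreate(&n->event, CU_EVENT_DISABLE_TIMING);
  cuEventRecord(n->event, stream);  /* captures all prior stream work */
  n->callback = cb;
  n->user_data = user_data;
  pthread_mutex_lock(&w->mu);
  if (w->tail) w->tail->next = n;
  else w->head = n;
  w->tail = n;
  pthread_cond_signal(&w->cv);
  pthread_mutex_unlock(&w->mu);
}

/* Dedicated waiting thread: drains the FIFO one event at a time. */
static void* waiter_main(void* arg) {
  waiter_t* w = (waiter_t*)arg;
  cuCtxSetCurrent(w->ctx);  /* the current CUDA context is thread-local */
  for (;;) {
    pthread_mutex_lock(&w->mu);
    while (!w->head && !w->exiting) pthread_cond_wait(&w->cv, &w->mu);
    node_t* n = w->head;
    if (n) {
      w->head = n->next;
      if (!w->head) w->tail = NULL;
    }
    pthread_mutex_unlock(&w->mu);
    if (!n) return NULL;           /* exit requested and queue drained */
    cuEventSynchronize(n->event);  /* blocks this thread, not the stream */
    n->callback(n->user_data);     /* free to call CUDA APIs, unlike a
                                      cuLaunchHostFunc callback */
    cuEventDestroy(n->event);
    free(n);
  }
}
```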

---------

Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index a78dc6f..011b5d9 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -61,8 +61,6 @@
   // TODO: Support multiple device streams.
   // The CUstream used to issue device kernels and allocations.
   CUstream dispatch_cu_stream;
-  // The CUstream used to issue host callback functions.
-  CUstream callback_cu_stream;
 
   iree_hal_cuda_tracing_context_t* tracing_context;
 
@@ -129,7 +127,7 @@
 static iree_status_t iree_hal_cuda_device_create_internal(
     iree_hal_driver_t* driver, iree_string_view_t identifier,
     const iree_hal_cuda_device_params_t* params, CUdevice cu_device,
-    CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
+    CUstream dispatch_stream, CUcontext context,
     const iree_hal_cuda_dynamic_symbols_t* cuda_symbols,
     const iree_hal_cuda_nccl_dynamic_symbols_t* nccl_symbols,
     iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
@@ -152,11 +150,10 @@
   device->cu_context = context;
   device->cu_device = cu_device;
   device->dispatch_cu_stream = dispatch_stream;
-  device->callback_cu_stream = callback_stream;
   device->host_allocator = host_allocator;
 
   iree_status_t status = iree_hal_cuda_pending_queue_actions_create(
-      cuda_symbols, &device->block_pool, host_allocator,
+      cuda_symbols, cu_device, context, &device->block_pool, host_allocator,
       &device->pending_queue_actions);
 
   // Enable tracing for the (currently only) stream - no-op if disabled.
@@ -230,20 +227,13 @@
     status = IREE_CURESULT_TO_STATUS(
         cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
   }
-  // Create the default callback stream for the device.
-  CUstream callback_stream = NULL;
-  if (iree_status_is_ok(status)) {
-    status = IREE_CURESULT_TO_STATUS(
-        cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
-  }
 
   if (iree_status_is_ok(status)) {
     status = iree_hal_cuda_device_create_internal(
-        driver, identifier, params, device, dispatch_stream, callback_stream,
-        context, cuda_symbols, nccl_symbols, host_allocator, out_device);
+        driver, identifier, params, device, dispatch_stream, context,
+        cuda_symbols, nccl_symbols, host_allocator, out_device);
   } else {
    // Release resources we have acquired thus far.
-    if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
     if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
     if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
   }
@@ -331,7 +321,6 @@
   if (device->host_event_pool) iree_event_pool_free(device->host_event_pool);
 
   IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
-  IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));
 
   IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));
 
@@ -776,8 +765,7 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
-      base_device, device->dispatch_cu_stream, device->callback_cu_stream,
-      device->pending_queue_actions,
+      base_device, device->dispatch_cu_stream, device->pending_queue_actions,
       iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
       wait_semaphore_list, signal_semaphore_list, command_buffer_count,
       command_buffers, binding_tables);
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
index 42f1f86..2e964e8 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
@@ -85,8 +85,6 @@
 
   // The stream to launch main GPU workload.
   CUstream dispatch_cu_stream;
-  // The stream to launch CUDA host function callbacks.
-  CUstream callback_cu_stream;
 
   // Resource set to retain all associated resources by the payload.
   iree_hal_resource_set_t* resource_set;
@@ -186,31 +184,140 @@
 //===----------------------------------------------------------------------===//
 
-// Ready action atomic slist entry struct.
+// Ready action linked list node struct.
-typedef struct iree_hal_cuda_atomic_slist_entry_t {
+typedef struct iree_hal_cuda_entry_list_node_t {
   iree_hal_cuda_queue_action_t* ready_list_head;
-  iree_atomic_slist_intrusive_ptr_t slist_next;
-} iree_hal_cuda_atomic_slist_entry_t;
+  struct iree_hal_cuda_entry_list_node_t* next;
+} iree_hal_cuda_entry_list_node_t;
 
-// Ready action atomic slist.
-IREE_TYPED_ATOMIC_SLIST_WRAPPER(iree_hal_cuda_ready_action,
-                                iree_hal_cuda_atomic_slist_entry_t,
-                                offsetof(iree_hal_cuda_atomic_slist_entry_t,
-                                         slist_next));
+typedef struct iree_hal_cuda_entry_list_t {
+  iree_slim_mutex_t guard_mutex;
 
-static void iree_hal_cuda_ready_action_slist_destroy(
-    iree_hal_cuda_ready_action_slist_t* list, iree_allocator_t host_allocator) {
-  while (true) {
-    iree_hal_cuda_atomic_slist_entry_t* entry =
-        iree_hal_cuda_ready_action_slist_pop(list);
-    if (!entry) break;
-    iree_hal_cuda_queue_action_list_destroy(entry->ready_list_head);
-    iree_allocator_free(host_allocator, entry);
+  iree_hal_cuda_entry_list_node_t* head IREE_GUARDED_BY(guard_mutex);
+  iree_hal_cuda_entry_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
+} iree_hal_cuda_entry_list_t;
+
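+// Pops the oldest entry off |list| (FIFO order); returns NULL when empty.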
+static iree_hal_cuda_entry_list_node_t* iree_hal_cuda_entry_list_pop(
+    iree_hal_cuda_entry_list_t* list) {
+  iree_hal_cuda_entry_list_node_t* out = NULL;
+  iree_slim_mutex_lock(&list->guard_mutex);
+  if (list->head) {
+    out = list->head;
+    list->head = list->head->next;
+    if (out == list->tail) {
+      list->tail = NULL;
+    }
   }
-  iree_hal_cuda_ready_action_slist_deinitialize(list);
+  iree_slim_mutex_unlock(&list->guard_mutex);
+  return out;
+}
+
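+// Appends |next| at the tail of |list| so entries pop in push order.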
+void iree_hal_cuda_entry_list_push(iree_hal_cuda_entry_list_t* list,
+                                   iree_hal_cuda_entry_list_node_t* next) {
+  iree_slim_mutex_lock(&list->guard_mutex);
+  next->next = NULL;
+  if (list->tail) {
+    list->tail->next = next;
+    list->tail = next;
+  } else {
+    list->head = next;
+    list->tail = next;
+  }
+  iree_slim_mutex_unlock(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_ready_action_list_deinitialize(
+    iree_hal_cuda_entry_list_t* list, iree_allocator_t host_allocator) {
+  iree_hal_cuda_entry_list_node_t* head = list->head;
+  while (head) {
+    iree_hal_cuda_entry_list_node_t* next = head->next;
+    iree_hal_cuda_queue_action_list_destroy(head->ready_list_head);
+    iree_allocator_free(host_allocator, head);
+    head = next;  // Advance via the saved pointer; |head| was just freed.
+  }
+  iree_slim_mutex_deinitialize(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_ready_action_list_initialize(
+    iree_hal_cuda_entry_list_t* list) {
+  list->head = NULL;
+  list->tail = NULL;
+  iree_slim_mutex_initialize(&list->guard_mutex);
+}
+
+// Completion list node struct.
+typedef struct iree_hal_cuda_completion_list_node_t {
+  // The callback and user data for that callback. To be called
+  // when the associated event has completed.
+  iree_status_t (*callback)(void* user_data);
+  void* user_data;
+  // The event to wait for on the completion thread.
+  CUevent event;
+  // If this event was created just for the completion thread, and therefore
+  // needs to be cleaned up.
+  bool created_event;
+  struct iree_hal_cuda_completion_list_node_t* next;
+} iree_hal_cuda_completion_list_node_t;
+
+typedef struct iree_hal_cuda_completion_list_t {
+  iree_slim_mutex_t guard_mutex;
+  iree_hal_cuda_completion_list_node_t* head IREE_GUARDED_BY(guard_mutex);
+  iree_hal_cuda_completion_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
+} iree_hal_cuda_completion_list_t;
+
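+// Pops the oldest completion entry off |list|; returns NULL when empty.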
+static iree_hal_cuda_completion_list_node_t* iree_hal_cuda_completion_list_pop(
+    iree_hal_cuda_completion_list_t* list) {
+  iree_hal_cuda_completion_list_node_t* out = NULL;
+  iree_slim_mutex_lock(&list->guard_mutex);
+  if (list->head) {
+    out = list->head;
+    list->head = list->head->next;
+    if (out == list->tail) {
+      list->tail = NULL;
+    }
+  }
+  iree_slim_mutex_unlock(&list->guard_mutex);
+  return out;
+}
+
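+// Appends |next| at the tail of |list|, preserving submission order.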
+void iree_hal_cuda_completion_list_push(
+    iree_hal_cuda_completion_list_t* list,
+    iree_hal_cuda_completion_list_node_t* next) {
+  iree_slim_mutex_lock(&list->guard_mutex);
+  next->next = NULL;
+  if (list->tail) {
+    list->tail->next = next;
+    list->tail = next;
+  } else {
+    list->head = next;
+    list->tail = next;
+  }
+  iree_slim_mutex_unlock(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_completion_list_initialize(
+    iree_hal_cuda_completion_list_t* list) {
+  list->head = NULL;
+  list->tail = NULL;
+  iree_slim_mutex_initialize(&list->guard_mutex);
+}
+
+static void iree_hal_cuda_completion_list_deinitialize(
+    iree_hal_cuda_completion_list_t* list,
+    const iree_hal_cuda_dynamic_symbols_t* symbols,
+    iree_allocator_t host_allocator) {
+  iree_hal_cuda_completion_list_node_t* head = list->head;
+  while (head) {
+    iree_hal_cuda_completion_list_node_t* next = head->next;
+    if (head->created_event) {
+      IREE_CUDA_IGNORE_ERROR(symbols, cuEventDestroy(head->event));
+    }
+    iree_allocator_free(host_allocator, head);
+    head = next;  // Advance via the saved pointer; |head| was just freed.
+  }
+  iree_slim_mutex_deinitialize(&list->guard_mutex);
 }
 
 static iree_hal_cuda_queue_action_t* iree_hal_cuda_atomic_slist_entry_pop_front(
-    iree_hal_cuda_atomic_slist_entry_t* list) {
+    iree_hal_cuda_entry_list_node_t* list) {
   IREE_ASSERT(list->ready_list_head);
 
   iree_hal_cuda_queue_action_t* action = list->ready_list_head;
@@ -255,8 +362,8 @@
   // Notification to the parent thread to indicate the worker committed exiting.
   // TODO: maybe remove this. We can just wait on the worker thread to exit.
   iree_notification_t exit_notification;
-  iree_hal_cuda_ready_action_slist_t ready_worklist;  // atomic
-  iree_atomic_int32_t worker_state;                   // atomic
+  iree_hal_cuda_entry_list_t ready_worklist;
+  iree_atomic_int32_t worker_state;  // atomic
   // TODO: use status to provide more context for the error.
   iree_atomic_intptr_t error_code;  // atomic
 
@@ -272,14 +379,49 @@
       IREE_GUARDED_BY(pending_work_items_count_mutex);
 
   iree_allocator_t host_allocator;  // const
+
+  const iree_hal_cuda_dynamic_symbols_t* symbols;
+  CUcontext context;
 } iree_hal_cuda_working_area_t;
 
+// This data structure is shared between the parent thread and the completion
+// worker thread, which is responsible for dispatching callbacks when work
+// items complete.
+//
+// This replaces the use of cuLaunchHostFunc, which blocks the stream while
+// waiting for the CPU work to complete. A dedicated thread also picks up
+// completed events with significantly less latency than cuLaunchHostFunc.
+
+typedef struct iree_hal_cuda_completion_area_t {
+  // Notification from the parent thread to request completion state changes.
+  iree_notification_t state_notification;
+  // Notification to the parent thread to indicate the worker committed exiting.
+  iree_notification_t exit_notification;
+  iree_hal_cuda_completion_list_t completion_list;
+  iree_atomic_int32_t worker_state;  // atomic
+
+  iree_atomic_intptr_t error_code;  // atomic
+
+  // The number of asynchronous completion items that are scheduled and not
+  // yet waited on.
+  // We need to wait for them to finish before destroying the context.
+  iree_slim_mutex_t pending_completion_count_mutex;
+  iree_notification_t pending_completion_count_notification;
+  int32_t pending_completion_count
+      IREE_GUARDED_BY(pending_completion_count_mutex);
+
+  iree_allocator_t host_allocator;
+
+  const iree_hal_cuda_dynamic_symbols_t* symbols;
+  CUcontext context;
+} iree_hal_cuda_completion_area_t;
+
 static void iree_hal_cuda_working_area_initialize(
-    iree_allocator_t host_allocator,
+    iree_allocator_t host_allocator, CUcontext context,
+    const iree_hal_cuda_dynamic_symbols_t* symbols,
     iree_hal_cuda_working_area_t* working_area) {
   iree_notification_initialize(&working_area->state_notification);
   iree_notification_initialize(&working_area->exit_notification);
-  iree_hal_cuda_ready_action_slist_initialize(&working_area->ready_worklist);
+  iree_hal_cuda_ready_action_list_initialize(&working_area->ready_worklist);
   iree_atomic_store_int32(&working_area->worker_state,
                           IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
                           iree_memory_order_release);
@@ -290,12 +432,14 @@
       &working_area->pending_work_items_count_notification);
   working_area->pending_work_items_count = 0;
   working_area->host_allocator = host_allocator;
+  working_area->symbols = symbols;
+  working_area->context = context;
 }
 
 static void iree_hal_cuda_working_area_deinitialize(
     iree_hal_cuda_working_area_t* working_area) {
-  iree_hal_cuda_ready_action_slist_destroy(&working_area->ready_worklist,
-                                           working_area->host_allocator);
+  iree_hal_cuda_ready_action_list_deinitialize(&working_area->ready_worklist,
+                                               working_area->host_allocator);
   iree_notification_deinitialize(&working_area->exit_notification);
   iree_notification_deinitialize(&working_area->state_notification);
   iree_slim_mutex_deinitialize(&working_area->pending_work_items_count_mutex);
@@ -303,10 +447,47 @@
       &working_area->pending_work_items_count_notification);
 }
 
+static void iree_hal_cuda_completion_area_initialize(
+    iree_allocator_t host_allocator, CUcontext context,
+    const iree_hal_cuda_dynamic_symbols_t* symbols,
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_notification_initialize(&completion_area->state_notification);
+  iree_notification_initialize(&completion_area->exit_notification);
+  iree_hal_cuda_completion_list_initialize(&completion_area->completion_list);
+  iree_atomic_store_int32(&completion_area->worker_state,
+                          IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+                          iree_memory_order_release);
+  iree_atomic_store_int32(&completion_area->error_code, IREE_STATUS_OK,
+                          iree_memory_order_release);
+  iree_slim_mutex_initialize(&completion_area->pending_completion_count_mutex);
+  iree_notification_initialize(
+      &completion_area->pending_completion_count_notification);
+  completion_area->pending_completion_count = 0;
+  completion_area->host_allocator = host_allocator;
+  completion_area->symbols = symbols;
+  completion_area->context = context;
+}
+
+static void iree_hal_cuda_completion_area_deinitialize(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_hal_cuda_completion_list_deinitialize(&completion_area->completion_list,
+                                             completion_area->symbols,
+                                             completion_area->host_allocator);
+  iree_notification_deinitialize(&completion_area->exit_notification);
+  iree_notification_deinitialize(&completion_area->state_notification);
+  iree_slim_mutex_deinitialize(
+      &completion_area->pending_completion_count_mutex);
+  iree_notification_deinitialize(
+      &completion_area->pending_completion_count_notification);
+}
+
 // The main function for the ready-list processing worker thread.
 static int iree_hal_cuda_worker_execute(
     iree_hal_cuda_working_area_t* working_area);
 
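+// The main function for the completion worker thread.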
+static int iree_hal_cuda_completion_execute(
+    iree_hal_cuda_completion_area_t* completion_area);
+
 //===----------------------------------------------------------------------===//
 // Pending queue actions
 //===----------------------------------------------------------------------===//
@@ -333,16 +514,28 @@
   // The worker thread that monitors incoming requests and issues ready actions
   // to the GPU.
   iree_thread_t* worker_thread;
+
+  // Worker thread that waits on completion events instead of running
+  // synchronous completion callbacks on the CUDA driver thread.
+  iree_thread_t* completion_thread;
+
   // The worker's working area; data exchange place with the parent thread.
   iree_hal_cuda_working_area_t working_area;
+
+  // Completion thread's working area.
+  iree_hal_cuda_completion_area_t completion_area;
+
+  // The associated CUDA device.
+  CUdevice device;
 };
 
 static const iree_hal_resource_vtable_t
     iree_hal_cuda_pending_queue_actions_vtable;
 
 iree_status_t iree_hal_cuda_pending_queue_actions_create(
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+    const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device,
+    CUcontext context, iree_arena_block_pool_t* block_pool,
+    iree_allocator_t host_allocator,
     iree_hal_cuda_pending_queue_actions_t** out_actions) {
   IREE_ASSERT_ARGUMENT(symbols);
   IREE_ASSERT_ARGUMENT(block_pool);
@@ -358,12 +551,18 @@
   actions->host_allocator = host_allocator;
   actions->block_pool = block_pool;
   actions->symbols = symbols;
+  actions->device = device;
   iree_slim_mutex_initialize(&actions->action_mutex);
   memset(&actions->action_list, 0, sizeof(actions->action_list));
 
   // Initialize the working area for the ready-list processing worker.
   iree_hal_cuda_working_area_t* working_area = &actions->working_area;
-  iree_hal_cuda_working_area_initialize(host_allocator, working_area);
+  iree_hal_cuda_working_area_initialize(host_allocator, context, symbols,
+                                        working_area);
+
+  iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
+  iree_hal_cuda_completion_area_initialize(host_allocator, context, symbols,
+                                           completion_area);
 
   // Create the ready-list processing worker itself.
   iree_thread_create_params_t params;
@@ -374,6 +573,14 @@
       (iree_thread_entry_t)iree_hal_cuda_worker_execute, working_area, params,
       actions->host_allocator, &actions->worker_thread);
 
+  params.name = IREE_SV("done_worker");
+  params.create_suspended = false;
+  if (iree_status_is_ok(status)) {
+    status = iree_thread_create(
+        (iree_thread_entry_t)iree_hal_cuda_completion_execute, completion_area,
+        params, actions->host_allocator, &actions->completion_thread);
+  }
+
   if (iree_status_is_ok(status)) {
     *out_actions = actions;
   } else {
@@ -392,12 +599,16 @@
 static bool iree_hal_cuda_worker_committed_exiting(
     iree_hal_cuda_working_area_t* working_area);
 
+static bool iree_hal_cuda_completion_committed_exiting(
+    iree_hal_cuda_completion_area_t* completion_area);
+
 void iree_hal_cuda_pending_queue_actions_destroy(
     iree_hal_resource_t* base_actions) {
   iree_hal_cuda_pending_queue_actions_t* actions =
       iree_hal_cuda_pending_queue_actions_cast(base_actions);
   iree_allocator_t host_allocator = actions->host_allocator;
   iree_hal_cuda_working_area_t* working_area = &actions->working_area;
+  iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area;
   IREE_TRACE_ZONE_BEGIN(z0);
 
   // Request the worker to exit.
@@ -420,6 +631,25 @@
   iree_thread_release(actions->worker_thread);
   iree_hal_cuda_working_area_deinitialize(working_area);
 
+  // Request the completion thread to exit.
+  prev_state = (iree_hal_cuda_worker_state_t)iree_atomic_exchange_int32(
+      &completion_area->worker_state, IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED,
+      iree_memory_order_acq_rel);
+  iree_notification_post(&completion_area->state_notification,
+                         IREE_ALL_WAITERS);
+
+  // Check potential exit states from the completion thread.
+  if (prev_state != IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
+    // Wait until the completion thread acknowledged exiting.
+    iree_notification_await(
+        &completion_area->exit_notification,
+        (iree_condition_fn_t)iree_hal_cuda_completion_committed_exiting,
+        completion_area, iree_infinite_timeout());
+  }
+
+  iree_thread_release(actions->completion_thread);
+  iree_hal_cuda_completion_area_deinitialize(completion_area);
+
   iree_slim_mutex_deinitialize(&actions->action_mutex);
   iree_hal_cuda_queue_action_list_destroy(actions->action_list.head);
   iree_allocator_free(host_allocator, actions);
@@ -468,9 +698,23 @@
   iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
 }
 
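+// Decrements the number of pending completion items and wakes waiters once
+// the count reaches zero.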
+static void iree_hal_cuda_queue_decrement_completion_count(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+  --completion_area->pending_completion_count;
+  if (completion_area->pending_completion_count == 0) {
+    // Notify inside the lock to make sure that we are done touching anything
+    // since the context may get destroyed in the meantime.
+    iree_notification_post(
+        &completion_area->pending_completion_count_notification,
+        IREE_ALL_WAITERS);
+  }
+  iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+}
+
 iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution(
     iree_hal_device_t* device, CUstream dispatch_stream,
-    CUstream callback_stream, iree_hal_cuda_pending_queue_actions_t* actions,
+    iree_hal_cuda_pending_queue_actions_t* actions,
     iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback,
     void* callback_user_data,
     const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -519,7 +763,6 @@
   action->kind = IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION;
   action->device = device;
   action->dispatch_cu_stream = dispatch_stream;
-  action->callback_cu_stream = callback_stream;
 
   // Initialize scratch fields.
   action->event_count = 0;
@@ -645,13 +888,27 @@
   iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
 }
 
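+// Records |code| as the completion thread error (without overwriting an
+// earlier error) and transitions the completion thread toward error exit.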
+static void iree_hal_cuda_post_error_to_completion_state(
+    iree_hal_cuda_completion_area_t* completion_area, iree_status_code_t code) {
+  // Write error code, but don't overwrite existing error codes.
+  intptr_t prev_error_code = IREE_STATUS_OK;
+  iree_atomic_compare_exchange_strong_int32(
+      &completion_area->error_code, /*expected=*/&prev_error_code,
+      /*desired=*/code,
+      /*order_succ=*/iree_memory_order_acq_rel,
+      /*order_fail=*/iree_memory_order_acquire);
+
+  // This state has the highest priority so just overwrite.
+  iree_atomic_store_int32(&completion_area->worker_state,
+                          IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR,
+                          iree_memory_order_release);
+  iree_notification_post(&completion_area->state_notification,
+                         IREE_ALL_WAITERS);
+}
+
 // Releases resources after action completion on the GPU and advances timeline
 // and pending actions queue.
-//
-// This is the CUDA host function callback to cudaLaunchHostFunc(), invoked by a
-// CUDA driver thread. Note that code in this function MUST NOT invoke any GPU
-// API under the hood to avoid potential deadlock.
-static void iree_hal_cuda_execution_device_signal_host_callback(
+static iree_status_t iree_hal_cuda_execution_device_signal_host_callback(
     void* user_data) {
   IREE_TRACE_ZONE_BEGIN(z0);
   iree_hal_cuda_queue_action_t* action =
@@ -694,6 +951,7 @@
   iree_hal_cuda_queue_decrement_work_items_count(&actions->working_area);
 
   IREE_TRACE_ZONE_END(z0);
+  return status;
 }
 
 // Issues the given kernel dispatch |action| to the GPU.
@@ -770,6 +1028,7 @@
   }
   IREE_TRACE_ZONE_END(z_dispatch_command_buffers);
 
+  CUevent completion_event = NULL;
   // Last record CUevent signals in the dispatch stream.
   for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
     // Grab a CUevent for this semaphore value signaling.
@@ -783,17 +1042,28 @@
     IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
         z0, symbols, cuEventRecord(event, action->dispatch_cu_stream),
         "cuEventRecord");
-    // Let the callback stream to wait on the CUevent.
-    IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, symbols,
-        cuStreamWaitEvent(action->callback_cu_stream, event,
-                          CU_EVENT_WAIT_DEFAULT),
-        "cuStreamWaitEvent");
+    completion_event = event;
   }
 
+  bool created_event = false;
+  // When the execution has signal semaphores we can reuse the last recorded
+  // semaphore event as the completion event. If there are no signals then we
+  // create a dedicated event; in practice that is not a common case.
+  if (IREE_UNLIKELY(!completion_event)) {
+    IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, symbols, cuEventCreate(&completion_event, CU_EVENT_DISABLE_TIMING),
+        "cuEventCreate");
+    created_event = true;
+  }
+
+  IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, symbols, cuEventRecord(completion_event, action->dispatch_cu_stream),
+      "cuEventRecord");
+
   iree_slim_mutex_lock(
       &action->owning_actions->working_area.pending_work_items_count_mutex);
-  // One work item is the host stream callback.
+  // One work item is the callback dispatched from the completion thread.
   // The other is the cleanup of the action.
   action->owning_actions->working_area.pending_work_items_count += 2;
   iree_slim_mutex_unlock(
@@ -801,12 +1071,55 @@
-  // Now launch a host function on the callback stream to advance the semaphore
-  // timeline.
+  // Now hand the recorded completion event off to the completion thread; it
+  // will wait for the event and then advance the semaphore timeline.
 
-  IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, symbols,
-      cuLaunchHostFunc(action->callback_cu_stream,
-                       iree_hal_cuda_execution_device_signal_host_callback,
-                       action),
-      "cuLaunchHostFunc");
+  iree_hal_cuda_completion_list_node_t* entry = NULL;
+  // TODO: avoid host allocator malloc; use some pool for the allocation.
+  iree_status_t status = iree_allocator_malloc(
+      action->owning_actions->host_allocator, sizeof(*entry), (void**)&entry);
+
+  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Fill the entry with the event to wait on and the callback to invoke,
+  // then hand it off to the completion thread.
+  entry->event = completion_event;
+  entry->created_event = created_event;
+  entry->callback = iree_hal_cuda_execution_device_signal_host_callback;
+  entry->user_data = action;
+  iree_hal_cuda_completion_list_push(
+      &action->owning_actions->completion_area.completion_list, entry);
+
+  iree_slim_mutex_lock(
+      &action->owning_actions->completion_area.pending_completion_count_mutex);
+
+  action->owning_actions->completion_area.pending_completion_count += 1;
+
+  iree_slim_mutex_unlock(
+      &action->owning_actions->completion_area.pending_completion_count_mutex);
+
+  // We can only overwrite the worker state if the previous state is idle
+  // waiting; we cannot overwrite exit-related states, so we need to perform
+  // an atomic compare and exchange here.
+  iree_hal_cuda_worker_state_t prev_state =
+      IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING;
+  iree_atomic_compare_exchange_strong_int32(
+      &action->owning_actions->completion_area.worker_state,
+      /*expected=*/&prev_state,
+      /*desired=*/IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
+      /*order_succ=*/iree_memory_order_acq_rel,
+      /*order_fail=*/iree_memory_order_acquire);
+  iree_notification_post(
+      &action->owning_actions->completion_area.state_notification,
+      IREE_ALL_WAITERS);
+
+  // Handle potential error cases from the completion thread.
+  if (prev_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR) {
+    iree_status_code_t code = iree_atomic_load_int32(
+        &action->owning_actions->completion_area.error_code,
+        iree_memory_order_acquire);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_status_from_code(code));
+  }
 
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
@@ -928,7 +1241,7 @@
     return status;
   }
 
-  iree_hal_cuda_atomic_slist_entry_t* entry = NULL;
+  iree_hal_cuda_entry_list_node_t* entry = NULL;
   // TODO: avoid host allocator malloc; use some pool for the allocation.
   if (iree_status_is_ok(status)) {
     status = iree_allocator_malloc(actions->host_allocator, sizeof(*entry),
@@ -946,8 +1259,7 @@
  // Now push the ready list to the worker and have it issue the actions to
   // the GPU.
   entry->ready_list_head = ready_list.head;
-  iree_hal_cuda_ready_action_slist_push(&actions->working_area.ready_worklist,
-                                        entry);
+  iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist, entry);
 
   // We can only overwrite the worker state if the previous state is idle
   // waiting; we cannot overwrite exit related states. so we need to perform
@@ -986,6 +1298,15 @@
          value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
 }
 
+static bool iree_hal_cuda_completion_has_incoming_request_or_error(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_hal_cuda_worker_state_t value = iree_atomic_load_int32(
+      &completion_area->worker_state, iree_memory_order_acquire);
+  return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING ||
+         value == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED ||
+         value == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR;
+}
+
 static bool iree_hal_cuda_worker_committed_exiting(
     iree_hal_cuda_working_area_t* working_area) {
   return iree_atomic_load_int32(&working_area->worker_state,
@@ -993,16 +1314,22 @@
          IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
 }
 
+static bool iree_hal_cuda_completion_committed_exiting(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  return iree_atomic_load_int32(&completion_area->worker_state,
+                                iree_memory_order_acquire) ==
+         IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED;
+}
+
 // Processes all ready actions in the given |worklist|.
 static iree_status_t iree_hal_cuda_worker_process_ready_list(
-    iree_allocator_t host_allocator,
-    iree_hal_cuda_ready_action_slist_t* worklist) {
+    iree_allocator_t host_allocator, iree_hal_cuda_entry_list_t* worklist) {
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_status_t status = iree_ok_status();
   while (true) {
-    iree_hal_cuda_atomic_slist_entry_t* entry =
-        iree_hal_cuda_ready_action_slist_pop(worklist);
+    iree_hal_cuda_entry_list_node_t* entry =
+        iree_hal_cuda_entry_list_pop(worklist);
     if (!entry) break;
 
     // Process the current batch of ready actions.
@@ -1028,7 +1355,7 @@
       // Let common destruction path take care of destroying the worklist.
       // When we know all host stream callbacks are done and not touching
       // anything.
-      iree_hal_cuda_ready_action_slist_push(worklist, entry);
+      iree_hal_cuda_entry_list_push(worklist, entry);
       break;
     }
 
@@ -1047,6 +1374,14 @@
   return result;
 }
 
+static bool iree_hal_cuda_completion_has_no_pending_completion_items(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+  bool result = (completion_area->pending_completion_count == 0);
+  iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+  return result;
+}
+
 // Wait for all work items to finish.
 static void iree_hal_cuda_worker_wait_pending_work_items(
     iree_hal_cuda_working_area_t* working_area) {
@@ -1060,10 +1395,164 @@
   iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
 }
 
+// Wait for all pending completion items to finish.
+static void iree_hal_cuda_completion_wait_pending_completion_items(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_notification_await(
+      &completion_area->pending_completion_count_notification,
+      (iree_condition_fn_t)
+          iree_hal_cuda_completion_has_no_pending_completion_items,
+      completion_area, iree_infinite_timeout());
+  // Lock then unlock to make sure that all callbacks are really done.
+  // Not even touching the notification.
+  iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
+  iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+}
+
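+// Processes all entries in |worklist|: waits for each recorded event to
+// complete and then invokes the entry's callback.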
+static iree_status_t iree_hal_cuda_worker_process_completion(
+    iree_hal_cuda_completion_list_t* worklist,
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_status_t status = iree_ok_status();
+  while (true) {
+    iree_hal_cuda_completion_list_node_t* entry =
+        iree_hal_cuda_completion_list_pop(worklist);
+    if (!entry) break;
+
+    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "cuEventSynchronize");
+    CUresult result =
+        completion_area->symbols->cuEventSynchronize(entry->event);
+    IREE_TRACE_ZONE_END(z1);
+    if (IREE_UNLIKELY(result != CUDA_SUCCESS)) {
+      // Let the common destruction path take care of destroying the worklist
+      // once we know all completion callbacks are done and not touching
+      // anything.
+      iree_hal_cuda_completion_list_push(worklist, entry);
+      status =
+          iree_make_status(IREE_STATUS_ABORTED, "could not wait on cuda event");
+      break;
+    }
+    status = entry->callback(entry->user_data);
+    if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+      break;
+    }
+
+    if (IREE_UNLIKELY(entry->created_event)) {
+      IREE_CUDA_IGNORE_ERROR(completion_area->symbols,
+                             cuEventDestroy(entry->event));
+    }
+    iree_allocator_free(completion_area->host_allocator, entry);
+
+    // Now we fully executed and cleaned up this entry. Decrease the work
+    // items counter.
+    iree_hal_cuda_queue_decrement_completion_count(completion_area);
+  }
+  return status;
+}
+
+// The main function for the completion worker thread.
+static int iree_hal_cuda_completion_execute(
+    iree_hal_cuda_completion_area_t* completion_area) {
+  iree_hal_cuda_completion_list_t* worklist = &completion_area->completion_list;
+
+  iree_status_t status = IREE_CURESULT_TO_STATUS(
+      completion_area->symbols, cuCtxSetCurrent(completion_area->context),
+      "cuCtxSetCurrent");
+  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+    iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+    iree_hal_cuda_post_error_to_completion_state(completion_area,
+                                                 iree_status_code(status));
+    return -1;
+  }
+
+  while (true) {
+    iree_notification_await(
+        &completion_area->state_notification,
+        (iree_condition_fn_t)
+            iree_hal_cuda_completion_has_incoming_request_or_error,
+        completion_area, iree_infinite_timeout());
+
+    IREE_TRACE_ZONE_BEGIN(z0);
+    // Immediately flip the state to idle waiting if and only if the previous
+    // state is workload pending. We do it before processing the completion
+    // list to make sure that we don't accidentally ignore new work pushed
+    // after processing is done but before this thread overwrites the state.
+    // Also we don't want to overwrite other exit states, so we need to
+    // perform an atomic compare and exchange here.
+    iree_hal_cuda_worker_state_t prev_state =
+        IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING;
+    iree_atomic_compare_exchange_strong_int32(
+        &completion_area->worker_state, /*expected=*/&prev_state,
+        /*desired=*/IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
+        /*order_succ=*/iree_memory_order_acq_rel,
+        /*order_fail=*/iree_memory_order_acquire);
+
+    int32_t worker_state = iree_atomic_load_int32(
+        &completion_area->worker_state, iree_memory_order_acquire);
+    // Exit if CUDA callbacks have posted any errors.
+    if (IREE_UNLIKELY(worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_ERROR)) {
+      iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+      IREE_TRACE_ZONE_END(z0);
+      return -1;
+    }
+    // Check if we received a request to stop processing and exit this thread.
+    bool should_exit =
+        (worker_state == IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED);
+
+    iree_status_t status =
+        iree_hal_cuda_worker_process_completion(worklist, completion_area);
+
+    if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+      iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+      iree_hal_cuda_post_error_to_completion_state(completion_area,
+                                                   iree_status_code(status));
+      IREE_TRACE_ZONE_END(z0);
+      return -1;
+    }
+
+    if (IREE_UNLIKELY(should_exit &&
+                      iree_hal_cuda_completion_has_no_pending_completion_items(
+                          completion_area))) {
+      iree_hal_cuda_completion_wait_pending_completion_items(completion_area);
+      // Signal that this thread is committed to exit.
+      // This state has a priority that is only lower than error exit.
+      // A CUDA callback may have posted an error; make sure we don't
+      // overwrite that error state.
+      iree_hal_cuda_worker_state_t prev_state =
+          IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED;
+      iree_atomic_compare_exchange_strong_int32(
+          &completion_area->worker_state, /*expected=*/&prev_state,
+          /*desired=*/IREE_HAL_CUDA_WORKER_STATE_EXIT_COMMITTED,
+          /*order_succ=*/iree_memory_order_acq_rel,
+          /*order_fail=*/iree_memory_order_acquire);
+      iree_notification_post(&completion_area->exit_notification,
+                             IREE_ALL_WAITERS);
+      IREE_TRACE_ZONE_END(z0);
+      return 0;
+    }
+    IREE_TRACE_ZONE_END(z0);
+  }
+
+  return 0;
+}
+
 // The main function for the ready-list processing worker thread.
 static int iree_hal_cuda_worker_execute(
     iree_hal_cuda_working_area_t* working_area) {
-  iree_hal_cuda_ready_action_slist_t* worklist = &working_area->ready_worklist;
+  iree_hal_cuda_entry_list_t* worklist = &working_area->ready_worklist;
+
+  // CUDA stores the current context in thread-local storage, and some CUDA
+  // calls pull the device from there; this causes failures when the thread
+  // is used with other devices (or streams from other devices). Force the
+  // correct context onto this thread.
+  iree_status_t status = IREE_CURESULT_TO_STATUS(
+      working_area->symbols, cuCtxSetCurrent(working_area->context),
+      "cuCtxSetCurrent");
+  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+    iree_hal_cuda_worker_wait_pending_work_items(working_area);
+    iree_hal_cuda_post_error_to_worker_state(working_area,
+                                             iree_status_code(status));
+    return -1;
+  }
 
   while (true) {
     // Block waiting for incoming requests.
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
index b889428..fa16e1f 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h
@@ -38,8 +38,9 @@
 
 // Creates a pending actions queue.
 iree_status_t iree_hal_cuda_pending_queue_actions_create(
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+    const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device,
+    CUcontext context, iree_arena_block_pool_t* block_pool,
+    iree_allocator_t host_allocator,
     iree_hal_cuda_pending_queue_actions_t** out_actions);
 
 // Destroys the pending |actions| queue.
@@ -59,7 +60,7 @@
 // before releasing all retained resources.
 iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution(
     iree_hal_device_t* device, CUstream dispatch_stream,
-    CUstream callback_stream, iree_hal_cuda_pending_queue_actions_t* actions,
+    iree_hal_cuda_pending_queue_actions_t* actions,
     iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback,
     void* callback_user_data,
     const iree_hal_semaphore_list_t wait_semaphore_list,