[cuda] Implement HAL semaphore using CUevent objects (#14426)

This commit adds a HAL semaphore implementation for the CUDA driver,
backed by iree_event_t objects for host-side synchronization and
CUevent objects for device-side synchronization.

Fixes https://github.com/openxla/iree/issues/4727
Progress towards https://github.com/openxla/iree/issues/13245
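
For reference, below is a minimal, self-contained sketch (not the code in
event_semaphore.c; the function name and structure are illustrative only) of
the CUevent-based pattern the new pools build on: a signal records a CUevent
on one stream, and a wait either makes another stream wait on that event or
blocks the host thread on it.

#include <cuda.h>

// Illustrative sketch: "signal" a timepoint on |signal_stream|, then wait for
// it, first from |wait_stream| (device-side wait) and then from the host.
static CUresult sketch_timepoint_signal_and_wait(CUstream signal_stream,
                                                 CUstream wait_stream) {
  CUevent event = NULL;
  // Synchronization-only events do not need timing data.
  CUresult result = cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
  if (result != CUDA_SUCCESS) return result;

  // Signal: the event fires once all prior work on |signal_stream| completes.
  result = cuEventRecord(event, signal_stream);

  // Device-side wait: later work on |wait_stream| is held until the event
  // fires, without blocking the host.
  if (result == CUDA_SUCCESS) {
    result = cuStreamWaitEvent(wait_stream, event, /*Flags=*/0);
  }

  // Host-side wait: block the calling thread until the event fires.
  if (result == CUDA_SUCCESS) {
    result = cuEventSynchronize(event);
  }

  cuEventDestroy(event);
  return result;
}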
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 0d72c36..c6bc989 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -14,15 +14,19 @@
 #include "experimental/cuda2/cuda_buffer.h"
 #include "experimental/cuda2/cuda_dynamic_symbols.h"
 #include "experimental/cuda2/cuda_status_util.h"
+#include "experimental/cuda2/event_pool.h"
+#include "experimental/cuda2/event_semaphore.h"
 #include "experimental/cuda2/graph_command_buffer.h"
 #include "experimental/cuda2/memory_pools.h"
 #include "experimental/cuda2/nccl_channel.h"
 #include "experimental/cuda2/nccl_dynamic_symbols.h"
 #include "experimental/cuda2/nop_executable_cache.h"
-#include "experimental/cuda2/nop_semaphore.h"
+#include "experimental/cuda2/pending_queue_actions.h"
 #include "experimental/cuda2/pipeline_layout.h"
+#include "experimental/cuda2/timepoint_pool.h"
 #include "experimental/cuda2/tracing.h"
 #include "iree/base/internal/arena.h"
+#include "iree/base/internal/event_pool.h"
 #include "iree/base/internal/math.h"
 #include "iree/hal/utils/buffer_transfer.h"
 #include "iree/hal/utils/deferred_command_buffer.h"
@@ -53,13 +57,28 @@
 
   CUcontext cu_context;
   CUdevice cu_device;
-  // TODO: support multiple streams.
-  CUstream cu_stream;
+  // TODO: Support multiple device streams.
+  // The CUstream used to issue device kernels and allocations.
+  CUstream dispatch_cu_stream;
+  // The CUstream used to issue host callback functions.
+  CUstream callback_cu_stream;
 
   iree_hal_cuda2_tracing_context_t* tracing_context;
 
   iree_allocator_t host_allocator;
 
+  // Host/device event pools, used for backing semaphore timepoints.
+  iree_event_pool_t* host_event_pool;
+  iree_hal_cuda2_event_pool_t* device_event_pool;
+  // Timepoint pools, shared by various semaphores.
+  iree_hal_cuda2_timepoint_pool_t* timepoint_pool;
+
+  // A queue to order device workloads and release them to the GPU when
+  // constraints are met. It buffers submissions and allocations internally
+  // before they are ready. This queue couples with HAL semaphores backed by
+  // iree_event_t and CUevent objects.
+  iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;
+
   // Device memory pools and allocators.
   bool supports_memory_pools;
   iree_hal_cuda2_memory_pools_t memory_pools;
@@ -86,6 +105,7 @@
     iree_hal_cuda2_device_params_t* out_params) {
   memset(out_params, 0, sizeof(*out_params));
   out_params->arena_block_size = 32 * 1024;
+  out_params->event_pool_capacity = 32;
   out_params->queue_count = 1;
   out_params->stream_tracing = false;
   out_params->async_allocations = true;
@@ -107,9 +127,12 @@
 static iree_status_t iree_hal_cuda2_device_create_internal(
     iree_hal_driver_t* driver, iree_string_view_t identifier,
     const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
-    CUstream stream, CUcontext context,
+    CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
     const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
     const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
+    iree_event_pool_t* host_event_pool,
+    iree_hal_cuda2_event_pool_t* device_event_pool,
+    iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
     iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
   iree_hal_cuda2_device_t* device = NULL;
   iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
@@ -131,15 +154,22 @@
   device->params = *params;
   device->cu_context = context;
   device->cu_device = cu_device;
-  device->cu_stream = stream;
+  device->dispatch_cu_stream = dispatch_stream;
+  device->callback_cu_stream = callback_stream;
   device->host_allocator = host_allocator;
+  device->host_event_pool = host_event_pool;
+  device->device_event_pool = device_event_pool;
+  device->timepoint_pool = timepoint_pool;
+
+  iree_status_t status = iree_hal_cuda2_pending_queue_actions_create(
+      cuda_symbols, &device->block_pool, host_allocator,
+      &device->pending_queue_actions);
 
   // Enable tracing for the (currently only) stream - no-op if disabled.
-  iree_status_t status = iree_ok_status();
-  if (device->params.stream_tracing) {
+  if (iree_status_is_ok(status) && device->params.stream_tracing) {
     status = iree_hal_cuda2_tracing_context_allocate(
-        device->cuda_symbols, device->identifier, stream, &device->block_pool,
-        host_allocator, &device->tracing_context);
+        device->cuda_symbols, device->identifier, dispatch_stream,
+        &device->block_pool, host_allocator, &device->tracing_context);
   }
 
   // Memory pool support is conditional.
@@ -163,7 +193,7 @@
 
   if (iree_status_is_ok(status)) {
     status = iree_hal_cuda2_allocator_create(
-        (iree_hal_device_t*)device, cuda_symbols, cu_device, stream,
+        (iree_hal_device_t*)device, cuda_symbols, cu_device, dispatch_stream,
         device->supports_memory_pools ? &device->memory_pools : NULL,
         host_allocator, &device->device_allocator);
   }
@@ -200,20 +230,52 @@
     status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
   }
 
-  // Create the default stream for the device.
-  CUstream stream = NULL;
+  // Create the default dispatch stream for the device.
+  CUstream dispatch_stream = NULL;
   if (iree_status_is_ok(status)) {
     status = IREE_CURESULT_TO_STATUS(
-        cuda_symbols, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+        cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
+  }
+  // Create the default callback stream for the device.
+  CUstream callback_stream = NULL;
+  if (iree_status_is_ok(status)) {
+    status = IREE_CURESULT_TO_STATUS(
+        cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
+  }
+
+  iree_event_pool_t* host_event_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_event_pool_allocate(params->event_pool_capacity,
+                                      host_allocator, &host_event_pool);
+  }
+
+  iree_hal_cuda2_event_pool_t* device_event_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda2_event_pool_allocate(
+        cuda_symbols, params->event_pool_capacity, host_allocator,
+        &device_event_pool);
+  }
+
+  iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda2_timepoint_pool_allocate(
+        host_event_pool, device_event_pool, params->event_pool_capacity,
+        host_allocator, &timepoint_pool);
   }
 
   if (iree_status_is_ok(status)) {
     status = iree_hal_cuda2_device_create_internal(
-        driver, identifier, params, device, stream, context, cuda_symbols,
-        nccl_symbols, host_allocator, out_device);
+        driver, identifier, params, device, dispatch_stream, callback_stream,
+        context, cuda_symbols, nccl_symbols, host_event_pool, device_event_pool,
+        timepoint_pool, host_allocator, out_device);
   }
+
   if (!iree_status_is_ok(status)) {
-    if (stream) cuda_symbols->cuStreamDestroy(stream);
+    if (timepoint_pool) iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
+    if (device_event_pool) iree_hal_cuda2_event_pool_free(device_event_pool);
+    if (host_event_pool) iree_event_pool_free(host_event_pool);
+    if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
+    if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
     if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
   }
 
@@ -237,8 +299,13 @@
 static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
   iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
+  const iree_hal_cuda2_dynamic_symbols_t* symbols = device->cuda_symbols;
   IREE_TRACE_ZONE_BEGIN(z0);
 
+  // Destroy the pending workload queue.
+  iree_hal_cuda2_pending_queue_actions_destroy(
+      (iree_hal_resource_t*)device->pending_queue_actions);
+
   // There should be no more buffers live that use the allocator.
   iree_hal_allocator_release(device->device_allocator);
 
@@ -248,13 +315,17 @@
   // Destroy memory pools that hold on to reserved memory.
   iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);
 
-  // TODO: support multiple streams.
   iree_hal_cuda2_tracing_context_free(device->tracing_context);
-  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
-                         cuStreamDestroy(device->cu_stream));
 
-  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
-                         cuDevicePrimaryCtxRelease(device->cu_device));
+  // Destroy various pools for synchronization.
+  iree_hal_cuda2_timepoint_pool_free(device->timepoint_pool);
+  iree_hal_cuda2_event_pool_free(device->device_event_pool);
+  iree_event_pool_free(device->host_event_pool);
+
+  IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
+  IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));
+
+  IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));
 
   iree_arena_block_pool_deinitialize(&device->block_pool);
 
@@ -490,8 +561,9 @@
     iree_hal_device_t* base_device, uint64_t initial_value,
     iree_hal_semaphore_t** out_semaphore) {
   iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
-  return iree_hal_cuda2_semaphore_create(initial_value, device->host_allocator,
-                                         out_semaphore);
+  return iree_hal_cuda2_event_semaphore_create(
+      initial_value, device->cuda_symbols, device->timepoint_pool,
+      device->pending_queue_actions, device->host_allocator, out_semaphore);
 }
 
 static iree_hal_semaphore_compatibility_t
@@ -526,9 +598,9 @@
   // allocator is set on the device.
   iree_status_t status = iree_ok_status();
   if (device->supports_memory_pools) {
-    status = iree_hal_cuda2_memory_pools_alloca(&device->memory_pools,
-                                                device->cu_stream, pool, params,
-                                                allocation_size, out_buffer);
+    status = iree_hal_cuda2_memory_pools_alloca(
+        &device->memory_pools, device->dispatch_cu_stream, pool, params,
+        allocation_size, out_buffer);
   } else {
     status = iree_hal_allocator_allocate_buffer(
         iree_hal_device_allocator(base_device), params, allocation_size,
@@ -565,8 +637,8 @@
   // drop it on the floor and let it be freed when the buffer is released.
   iree_status_t status = iree_ok_status();
   if (device->supports_memory_pools) {
-    status = iree_hal_cuda2_memory_pools_dealloca(&device->memory_pools,
-                                                  device->cu_stream, buffer);
+    status = iree_hal_cuda2_memory_pools_dealloca(
+        &device->memory_pools, device->dispatch_cu_stream, buffer);
   }
 
   // Only signal if not returning a synchronous error - synchronous failure
@@ -585,40 +657,39 @@
     iree_host_size_t command_buffer_count,
     iree_hal_command_buffer_t* const* command_buffers) {
   iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
 
-  // TODO(benvanik): trace around the entire submission.
-
-  for (iree_host_size_t i = 0; i < command_buffer_count; i++) {
-    CUgraphExec exec =
-        iree_hal_cuda2_graph_command_buffer_handle(command_buffers[i]);
-    IREE_CUDA_RETURN_IF_ERROR(device->cuda_symbols,
-                              cuGraphLaunch(exec, device->cu_stream),
-                              "cuGraphLaunch");
+  iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
+      device->dispatch_cu_stream, device->callback_cu_stream,
+      device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
+      command_buffer_count, command_buffers);
+  if (iree_status_is_ok(status)) {
+    // Try to advance the pending workload queue.
+    status = iree_hal_cuda2_pending_queue_actions_issue(
+        device->pending_queue_actions);
   }
 
-  // TODO(antiagainst): implement semaphores - for now this conservatively
-  // synchronizes after every submit.
-  IREE_TRACE_ZONE_BEGIN_NAMED(z0, "cuStreamSynchronize");
-  IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(z0, device->cuda_symbols,
-                                         cuStreamSynchronize(device->cu_stream),
-                                         "cuStreamSynchronize");
   iree_hal_cuda2_tracing_context_collect(device->tracing_context);
   IREE_TRACE_ZONE_END(z0);
-
-  return iree_ok_status();
+  return status;
 }
 
 static iree_status_t iree_hal_cuda2_device_queue_flush(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
-  // Currently unused; we flush as submissions are made.
-  return iree_ok_status();
+  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Try to advance the pending workload queue.
+  iree_status_t status =
+      iree_hal_cuda2_pending_queue_actions_issue(device->pending_queue_actions);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
 }
 
 static iree_status_t iree_hal_cuda2_device_wait_semaphores(
     iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
     const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "semaphore not yet implemented");
+                          "waiting on multiple semaphores not yet implemented");
 }
 
 static iree_status_t iree_hal_cuda2_device_profiling_begin(