[cuda] Implement HAL semaphore using CUevent objects (#14426)
This commit adds a HAL semaphore implementation for the CUDA driver
backed by iree_event_t and CUevent objects for different synchronization
directions.
Fixes https://github.com/openxla/iree/issues/4727
Progress towards https://github.com/openxla/iree/issues/13245
diff --git a/experimental/cuda2/CMakeLists.txt b/experimental/cuda2/CMakeLists.txt
index 873c3b9..de91884 100644
--- a/experimental/cuda2/CMakeLists.txt
+++ b/experimental/cuda2/CMakeLists.txt
@@ -24,6 +24,10 @@
"cuda_device.c"
"cuda_device.h"
"cuda_driver.c"
+ "event_pool.c"
+ "event_pool.h"
+ "event_semaphore.c"
+ "event_semaphore.h"
"graph_command_buffer.c"
"graph_command_buffer.h"
"memory_pools.c"
@@ -36,8 +40,12 @@
"nop_semaphore.h"
"nccl_channel.c"
"nccl_channel.h"
+ "pending_queue_actions.c"
+ "pending_queue_actions.h"
"pipeline_layout.c"
"pipeline_layout.h"
+ "timepoint_pool.c"
+ "timepoint_pool.h"
"tracing.c"
"tracing.h"
DEPS
@@ -45,6 +53,7 @@
iree::base
iree::base::internal
iree::base::internal::arena
+ iree::base::internal::event_pool
iree::base::internal::flatcc::parsing
iree::hal
iree::hal::utils::buffer_transfer
diff --git a/experimental/cuda2/README.md b/experimental/cuda2/README.md
new file mode 100644
index 0000000..7825c75
--- /dev/null
+++ b/experimental/cuda2/README.md
@@ -0,0 +1,164 @@
+# IREE CUDA HAL Driver
+
+This document lists technical details regarding the CUDA implementation of
+IREE's [Hardware Abstraction Layer (HAL)][iree-hal], called a CUDA HAL driver.
+
+Note that there is an existing CUDA HAL driver under the
+[`iree/hal/drivers/cuda/`][iree-cuda] directory; what this directory holds is
+a rewrite for it. Once this rewrite is mature enough, it will replace the
+existing one. For the rewrite rationale, goals, and plans, please see
+[Issue #13245][iree-cuda-rewrite].
+
+## Synchronization
+
+### HAL Semaphore
+
+The IREE HAL uses semaphores to synchronize work between host CPU threads and
+device GPU streams. It is a unified primitive that covers all directions--host
+to host, host to device, device to host, and device to device--and allows
+flexible signal and wait ordering--signal before wait, or wait before signal.
+There is also no limit on the number of waits on the same value.
+
+The core state of a HAL semaphore consists of a monotonically increasing 64-bit
+integer value, which forms a timeline--signaling the semaphore to a larger
+value advances the timeline and unblocks work waiting on some earlier values.
+The semantics closely mirror
+[Vulkan timeline semaphores][vulkan-timeline-semaphore].
+
+In CUDA, there is no direct equivalent primitive providing all the
+capabilities needed by the HAL semaphore abstraction:
+
+* [Stream memory operations][cu-mem-ops] provide `cuStreamWriteValue64()` and
+  `cuStreamWaitValue64()`, which can implement HAL semaphore 64-bit integer
+  value signal and wait (see the sketch after this list). However, these
+  operations require device pointers and cannot accept pointers to managed
+  memory buffers, meaning no support for the host. Additionally, per the spec,
+  "synchronization ordering established through these APIs is not visible to
+  CUDA. CUDA tasks that are (even indirectly) ordered by these APIs should
+  also have that order expressed with CUDA-visible dependencies such as
+  events." So they are not suitable for integration with other CUDA components.
+* For [external resource interoperability][cu-external-resource], we have APIs
+ like `cuSignalExternalSemaphoresAsync()` and `cuWaitExternalSemaphoresAsync()`,
+  which can directly map to Vulkan timeline semaphores. However, these APIs
+  are meant for external resources--there is no way to create
+  `CUexternalSemaphore` objects directly other than via
+  `cuImportExternalSemaphore()`.
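+
+For illustration, here is a sketch of what a purely device-side timeline could
+look like with stream memory operations, and why it falls short: the 64-bit
+value must live behind a `CUdeviceptr`, so host threads cannot directly signal
+or wait on it, and the established ordering is invisible to the rest of CUDA.
+`stream_a`/`stream_b` are assumed existing streams; error handling is elided:
+
+```c
+// Hypothetical device-side timeline emulated with stream memory operations.
+CUdeviceptr timeline = 0;
+cuMemAlloc(&timeline, sizeof(uint64_t));  // Device memory; no host access.
+
+// GPU-side signal: advance the timeline to 5 after prior work on stream_a.
+cuStreamWriteValue64(stream_a, timeline, 5, CU_STREAM_WRITE_VALUE_DEFAULT);
+
+// GPU-side wait: block later work on stream_b until the timeline reaches >= 5.
+cuStreamWaitValue64(stream_b, timeline, 5, CU_STREAM_WAIT_VALUE_GEQ);
+```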
+
+Therefore, to implement the support, we need to leverage multiple native CPU or
+CUDA primitives under the hood.
+
+#### `CUevent` capabilities
+
+The main synchronization mechanism is [CUDA event--`CUevent`][cu-event].
+As a functionality and integration baseline, we use `CUevent` to implement the
+IREE HAL semaphore abstraction.
+
+`CUevent` natively supports the following capabilities:
+
+* State: binary; either unsignaled or signaled. There can exist multiple
+ waits (e.g., via `cuEventSynchronize()` or `cuGraphAddEventWaitNode()`) for
+ the same `CUevent` signal (e.g., via `cuEventRecord()` or
+ `cuGraphAddEventRecordNode()`).
+* Ordering: must be signal before wait. Waiting before a signal would mean
+  waiting on an empty set of work, or on previously recorded work.
+* Direction: device to device and device to host (see the sketch below).
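+
+As a quick reference, the baseline `CUevent` usage assumed above looks like the
+following sketch (stream names are placeholders; error handling is elided):
+
+```c
+CUevent event = NULL;
+cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
+
+// Device to device: capture the current tail of producer_stream...
+cuEventRecord(event, producer_stream);
+// ...and make later work on consumer_stream wait for it.
+cuStreamWaitEvent(consumer_stream, event, CU_EVENT_WAIT_DEFAULT);
+
+// Device to host: block the calling host thread until the event completes.
+cuEventSynchronize(event);
+cuEventDestroy(event);
+```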
+
+We need to fill the remaining capability gaps. Before going into details,
+the overall approach is:
+
+* State: we need a 64-bit integer value timeline. Given the binary state of
+ a `CUevent`, each `CUevent` would just be a "timepoint" on the timeline.
+* Ordering: we need to defer releasing the workload to the GPU until the
+  waited semaphore values are reached on the host, or until we have some
+  device `CUevent` to wait on.
+* Direction: host to host and host to device are missing; we can support them
+  with host synchronization mechanisms.
+
+#### Signal to wait analysis
+
+Concretely, for a given HAL semaphore, looking at the four directions:
+
+##### CPU signal
+
+A CPU thread signals the semaphore timeline to a new value.
+
+If there are CPU waits, it is purely on the CPU side. We just need to use common
+CPU notification mechanisms. In IREE we have `iree_event_t` wrapping various
+low-level OS primitives for it. So we can just use that to represent a wait
+timepoint. We need to keep track of all CPU wait timepoints in the timeline;
+after a new value is signaled, we go through the timeline and notify all
+timepoints waiting on earlier values.
+
+If there are GPU waits, given that there is no way to signal a `CUevent` from
+the CPU, one way to handle this is to cache and defer the submission batches
+ourselves until the CPU signals past the desired value. To support this, we
+need to implement a deferred/pending actions queue (see the sketch below).
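+
+A hypothetical sketch of the readiness check such a queue performs (the helper
+name here is invented for illustration; the actual implementation lives in
+`pending_queue_actions.c` in this change):
+
+```c
+// Returns true if all semaphores in |wait_semaphore_list| have been signaled
+// past their required payload values, i.e. the action may go to the GPU.
+static bool action_is_ready(iree_hal_semaphore_list_t wait_semaphore_list) {
+  for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
+    uint64_t current_value = 0;
+    iree_status_ignore(iree_hal_semaphore_query(
+        wait_semaphore_list.semaphores[i], &current_value));
+    if (current_value < wait_semaphore_list.payload_values[i]) {
+      return false;  // Still pending; keep the action buffered.
+    }
+  }
+  return true;  // All waits satisfied; release the workload to the GPU.
+}
+```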
+
+##### GPU signal
+
+GPU signals can only be through a `CUevent` object, which has a binary state.
+We need to advance the timeline too. One way is to use `cuLaunchHostFunc()`
+to perform the advance from the CPU side. This additionally means we can
+reuse the logic from CPU signaling to unblock CPU waits.
+
+For GPU waits, we can also leverage the same logic--using CPU signaling to
+unblock deferred GPU queue actions. This is not performant, though, given
+that the CPU gets involved in GPU-internal synchronization. So we want to use
+`CUevent` objects directly where possible:
+
+* We keep track of all GPU signals in the timeline. Once we see a GPU wait
+  request, we scan the timeline to find a GPU signal that advances the
+  timeline past the desired value, and wait on that instead.
+* We may not see a GPU signal before seeing GPU wait requests, so we also
+  keep track of all GPU waits in the timeline. Later, once we see either a
+  CPU or GPU signal advancing past the waited value, we can handle the waits
+  accordingly--submitting immediately or associating the `CUevent`.
+  This also guarantees the `CUevent` requirement that recording happens
+  before waiting.
+* We can use the same `CUevent` to unblock multiple GPU waits. That is
+  allowed, though it means we need to be careful regarding `CUevent` lifetime
+  management. Here we can use reference counting to track how many timepoints
+  are using an event and automatically return it to a pool once done.
+
+Another problem is that per the `cuLaunchHostFunc()` doc, "the function will
+be called after currently enqueued work and will block work added after it."
+We don't want this blocking behavior to involve the host. So we can use a
+dedicated `CUstream` for launching the host function, making it wait on the
+`CUevent` recorded on the original stream (see the sketch below). We can also
+handle resource deallocation together there.
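+
+The following sketch shows this scheme, assuming the `CUevent` was already
+acquired and `host_advance_fn` is a hypothetical `CUhostFn` that advances the
+semaphore timeline on the CPU side (error handling is elided):
+
+```c
+static void CUDA_CB host_advance_fn(void* user_data) {
+  // Advance the semaphore timeline and unblock CPU waiters (not shown).
+}
+
+static void schedule_gpu_signal(CUstream dispatch_stream,
+                                CUstream callback_stream, CUevent event,
+                                void* semaphore_state) {
+  // Capture the tail of the dispatch stream...
+  cuEventRecord(event, dispatch_stream);
+  // ...order the callback stream after it...
+  cuStreamWaitEvent(callback_stream, event, CU_EVENT_WAIT_DEFAULT);
+  // ...and run the host function there so it does not block work enqueued
+  // later on the dispatch stream.
+  cuLaunchHostFunc(callback_stream, host_advance_fn, semaphore_state);
+}
+```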
+
+#### Data structures
+
+To summarize, we need the following data structures to implement the HAL
+semaphore:
+
+* `iree_event_t`: CPU notification mechanism wrapping low-level OS primitives.
+ Used by host wait timepoints.
+* `iree_event_pool_t`: a pool for CPU `iree_event_t` objects to recycle.
+* `iree_hal_cuda2_event_t`: GPU notification mechanism wrapping a `CUevent` and
+  a reference count. Used by device signal and wait timepoints. Associated with
+  an `iree_hal_cuda2_event_pool_t` pool--returned to the pool directly once the
+  reference count goes to 0.
+* `iree_hal_cuda2_event_pool_t`: a pool for GPU `iree_hal_cuda2_event_t` objects
+ to recycle.
+* `iree_hal_cuda2_timepoint_t`: an object that wraps a CPU `iree_event_t` or
+  GPU `iree_hal_cuda2_event_t` to represent wait/signal of a timepoint on a
+  timeline (a possible layout is sketched after this list).
+* `iree_hal_cuda2_timepoint_pool_t`: a pool for `iree_hal_cuda2_timepoint_t`
+  objects to recycle. This pool builds upon the CPU and GPU event pools--it
+  acquires CPU/GPU event objects from them.
+* `iree_hal_cuda2_semaphore_t`: contains a list of CPU wait and GPU
+ wait/signal timepoints.
+* `iree_hal_cuda2_queue_action_t`: a pending queue action (kernel launch or
+ stream-ordered allocation).
+* `iree_hal_cuda2_pending_queue_actions_t`: a data structure to manage pending
+  queue actions. It provides APIs to enqueue actions and advance the queue on
+  demand--queue actions are released to the GPU when all their wait semaphores
+  are signaled past the desired values, or when there is already a `CUevent`
+  object recorded to some `CUstream` to wait on.
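+
+For reference, a possible layout of `iree_hal_cuda2_timepoint_t` consistent
+with its usage in `event_semaphore.c` below (the real definition lives in
+`timepoint_pool.h`, which is not part of this excerpt, so treat the exact
+fields as hypothetical):
+
+```c
+typedef enum iree_hal_cuda2_timepoint_kind_e {
+  IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT,
+  IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL,
+  IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT,
+} iree_hal_cuda2_timepoint_kind_t;
+
+typedef struct iree_hal_cuda2_timepoint_t {
+  iree_hal_semaphore_timepoint_t base;  // Links into the semaphore timeline.
+  iree_hal_cuda2_timepoint_pool_t* pool;  // Owning pool to release back to.
+  iree_hal_cuda2_timepoint_kind_t kind;
+  union {
+    iree_event_t host_wait;                 // CPU wait timepoint.
+    iree_hal_cuda2_event_t* device_signal;  // GPU signal timepoint.
+    iree_hal_cuda2_event_t* device_wait;    // GPU wait timepoint.
+  } timepoint;
+} iree_hal_cuda2_timepoint_t;
+```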
+
+
+[iree-hal]: https://github.com/openxla/iree/tree/main/runtime/src/iree/hal
+[iree-cuda]: https://github.com/openxla/iree/tree/main/runtime/src/iree/hal/drivers/cuda
+[iree-cuda-rewrite]: https://github.com/openxla/iree/issues/13245
+[vulkan-timeline-semaphore]: https://www.khronos.org/blog/vulkan-timeline-semaphores
+[cu-mem-ops]: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEMOP.html
+[cu-external-resource]: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXTRES__INTEROP.html
+[cu-event]: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html
diff --git a/experimental/cuda2/api.h b/experimental/cuda2/api.h
index 5df17c5..62130cb 100644
--- a/experimental/cuda2/api.h
+++ b/experimental/cuda2/api.h
@@ -59,6 +59,13 @@
// transient allocations while also increasing memory consumption.
iree_host_size_t arena_block_size;
+ // The host and device event pool capacity.
+ // The CUDA driver implements semaphores with host and device events. This
+ // parameter controls the size of those pools. Larger values make acquiring
+ // semaphore timepoints quicker, though with increased memory consumption.
+ iree_host_size_t event_pool_capacity;
+
// Enables tracing of command buffers when IREE tracing is enabled.
// May take advantage of additional extensions for more accurate timing or
// hardware-specific performance counters.
diff --git a/experimental/cuda2/cts/CMakeLists.txt b/experimental/cuda2/cts/CMakeLists.txt
index c6f15a5..c6a07fa 100644
--- a/experimental/cuda2/cts/CMakeLists.txt
+++ b/experimental/cuda2/cts/CMakeLists.txt
@@ -20,8 +20,6 @@
EXCLUDED_TESTS
# HAL event is unimplemented for now.
"event"
- # HAL semaphore is in the process of being implemented.
- "semaphore"
LABELS
driver=cuda2
requires-gpu-nvidia
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 0d72c36..c6bc989 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -14,15 +14,19 @@
#include "experimental/cuda2/cuda_buffer.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
+#include "experimental/cuda2/event_pool.h"
+#include "experimental/cuda2/event_semaphore.h"
#include "experimental/cuda2/graph_command_buffer.h"
#include "experimental/cuda2/memory_pools.h"
#include "experimental/cuda2/nccl_channel.h"
#include "experimental/cuda2/nccl_dynamic_symbols.h"
#include "experimental/cuda2/nop_executable_cache.h"
-#include "experimental/cuda2/nop_semaphore.h"
+#include "experimental/cuda2/pending_queue_actions.h"
#include "experimental/cuda2/pipeline_layout.h"
+#include "experimental/cuda2/timepoint_pool.h"
#include "experimental/cuda2/tracing.h"
#include "iree/base/internal/arena.h"
+#include "iree/base/internal/event_pool.h"
#include "iree/base/internal/math.h"
#include "iree/hal/utils/buffer_transfer.h"
#include "iree/hal/utils/deferred_command_buffer.h"
@@ -53,13 +57,28 @@
CUcontext cu_context;
CUdevice cu_device;
- // TODO: support multiple streams.
- CUstream cu_stream;
+ // TODO: Support multiple device streams.
+ // The CUstream used to issue device kernels and allocations.
+ CUstream dispatch_cu_stream;
+ // The CUstream used to issue host callback functions.
+ CUstream callback_cu_stream;
iree_hal_cuda2_tracing_context_t* tracing_context;
iree_allocator_t host_allocator;
+ // Host/device event pools, used for backing semaphore timepoints.
+ iree_event_pool_t* host_event_pool;
+ iree_hal_cuda2_event_pool_t* device_event_pool;
+ // Timepoint pools, shared by various semaphores.
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool;
+
+ // A queue to order device workloads and release them to the GPU when
+ // constraints are met. It buffers submissions and allocations internally
+ // before they are ready. This queue couples with HAL semaphores backed by
+ // iree_event_t and CUevent objects.
+ iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;
+
// Device memory pools and allocators.
bool supports_memory_pools;
iree_hal_cuda2_memory_pools_t memory_pools;
@@ -86,6 +105,7 @@
iree_hal_cuda2_device_params_t* out_params) {
memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
+ out_params->event_pool_capacity = 32;
out_params->queue_count = 1;
out_params->stream_tracing = false;
out_params->async_allocations = true;
@@ -107,9 +127,12 @@
static iree_status_t iree_hal_cuda2_device_create_internal(
iree_hal_driver_t* driver, iree_string_view_t identifier,
const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
- CUstream stream, CUcontext context,
+ CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
+ iree_event_pool_t* host_event_pool,
+ iree_hal_cuda2_event_pool_t* device_event_pool,
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
iree_hal_cuda2_device_t* device = NULL;
iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
@@ -131,15 +154,22 @@
device->params = *params;
device->cu_context = context;
device->cu_device = cu_device;
- device->cu_stream = stream;
+ device->dispatch_cu_stream = dispatch_stream;
+ device->callback_cu_stream = callback_stream;
device->host_allocator = host_allocator;
+ device->host_event_pool = host_event_pool;
+ device->device_event_pool = device_event_pool;
+ device->timepoint_pool = timepoint_pool;
+
+ iree_status_t status = iree_hal_cuda2_pending_queue_actions_create(
+ cuda_symbols, &device->block_pool, host_allocator,
+ &device->pending_queue_actions);
// Enable tracing for the (currently only) stream - no-op if disabled.
- iree_status_t status = iree_ok_status();
- if (device->params.stream_tracing) {
+ if (iree_status_is_ok(status) && device->params.stream_tracing) {
status = iree_hal_cuda2_tracing_context_allocate(
- device->cuda_symbols, device->identifier, stream, &device->block_pool,
- host_allocator, &device->tracing_context);
+ device->cuda_symbols, device->identifier, dispatch_stream,
+ &device->block_pool, host_allocator, &device->tracing_context);
}
// Memory pool support is conditional.
@@ -163,7 +193,7 @@
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_allocator_create(
- (iree_hal_device_t*)device, cuda_symbols, cu_device, stream,
+ (iree_hal_device_t*)device, cuda_symbols, cu_device, dispatch_stream,
device->supports_memory_pools ? &device->memory_pools : NULL,
host_allocator, &device->device_allocator);
}
@@ -200,20 +230,52 @@
status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
}
- // Create the default stream for the device.
- CUstream stream = NULL;
+ // Create the default dispatch stream for the device.
+ CUstream dispatch_stream = NULL;
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(
- cuda_symbols, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+ cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
+ }
+ // Create the default callback stream for the device.
+ CUstream callback_stream = NULL;
+ if (iree_status_is_ok(status)) {
+ status = IREE_CURESULT_TO_STATUS(
+ cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
+ }
+
+ iree_event_pool_t* host_event_pool = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_event_pool_allocate(params->event_pool_capacity,
+ host_allocator, &host_event_pool);
+ }
+
+ iree_hal_cuda2_event_pool_t* device_event_pool = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_cuda2_event_pool_allocate(
+ cuda_symbols, params->event_pool_capacity, host_allocator,
+ &device_event_pool);
+ }
+
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_cuda2_timepoint_pool_allocate(
+ host_event_pool, device_event_pool, params->event_pool_capacity,
+ host_allocator, &timepoint_pool);
}
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_device_create_internal(
- driver, identifier, params, device, stream, context, cuda_symbols,
- nccl_symbols, host_allocator, out_device);
+ driver, identifier, params, device, dispatch_stream, callback_stream,
+ context, cuda_symbols, nccl_symbols, host_event_pool, device_event_pool,
+ timepoint_pool, host_allocator, out_device);
}
+
if (!iree_status_is_ok(status)) {
- if (stream) cuda_symbols->cuStreamDestroy(stream);
+ if (timepoint_pool) iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
+ if (device_event_pool) iree_hal_cuda2_event_pool_free(device_event_pool);
+ if (host_event_pool) iree_event_pool_free(host_event_pool);
+ if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
+ if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
}
@@ -237,8 +299,13 @@
static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
+ const iree_hal_cuda2_dynamic_symbols_t* symbols = device->cuda_symbols;
IREE_TRACE_ZONE_BEGIN(z0);
+ // Destroy the pending workload queue.
+ iree_hal_cuda2_pending_queue_actions_destroy(
+ (iree_hal_resource_t*)device->pending_queue_actions);
+
// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
@@ -248,13 +315,17 @@
// Destroy memory pools that hold on to reserved memory.
iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);
- // TODO: support multiple streams.
iree_hal_cuda2_tracing_context_free(device->tracing_context);
- IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
- cuStreamDestroy(device->cu_stream));
- IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
- cuDevicePrimaryCtxRelease(device->cu_device));
+ // Destroy various pools for synchronization.
+ iree_hal_cuda2_timepoint_pool_free(device->timepoint_pool);
+ iree_hal_cuda2_event_pool_free(device->device_event_pool);
+ iree_event_pool_free(device->host_event_pool);
+
+ IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
+ IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));
+
+ IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));
iree_arena_block_pool_deinitialize(&device->block_pool);
@@ -490,8 +561,9 @@
iree_hal_device_t* base_device, uint64_t initial_value,
iree_hal_semaphore_t** out_semaphore) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
- return iree_hal_cuda2_semaphore_create(initial_value, device->host_allocator,
- out_semaphore);
+ return iree_hal_cuda2_event_semaphore_create(
+ initial_value, device->cuda_symbols, device->timepoint_pool,
+ device->pending_queue_actions, device->host_allocator, out_semaphore);
}
static iree_hal_semaphore_compatibility_t
@@ -526,9 +598,9 @@
// allocator is set on the device.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools) {
- status = iree_hal_cuda2_memory_pools_alloca(&device->memory_pools,
- device->cu_stream, pool, params,
- allocation_size, out_buffer);
+ status = iree_hal_cuda2_memory_pools_alloca(
+ &device->memory_pools, device->dispatch_cu_stream, pool, params,
+ allocation_size, out_buffer);
} else {
status = iree_hal_allocator_allocate_buffer(
iree_hal_device_allocator(base_device), params, allocation_size,
@@ -565,8 +637,8 @@
// drop it on the floor and let it be freed when the buffer is released.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools) {
- status = iree_hal_cuda2_memory_pools_dealloca(&device->memory_pools,
- device->cu_stream, buffer);
+ status = iree_hal_cuda2_memory_pools_dealloca(
+ &device->memory_pools, device->dispatch_cu_stream, buffer);
}
// Only signal if not returning a synchronous error - synchronous failure
@@ -585,40 +657,39 @@
iree_host_size_t command_buffer_count,
iree_hal_command_buffer_t* const* command_buffers) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
+ IREE_TRACE_ZONE_BEGIN(z0);
- // TODO(benvanik): trace around the entire submission.
-
- for (iree_host_size_t i = 0; i < command_buffer_count; i++) {
- CUgraphExec exec =
- iree_hal_cuda2_graph_command_buffer_handle(command_buffers[i]);
- IREE_CUDA_RETURN_IF_ERROR(device->cuda_symbols,
- cuGraphLaunch(exec, device->cu_stream),
- "cuGraphLaunch");
+ iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
+ device->dispatch_cu_stream, device->callback_cu_stream,
+ device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
+ command_buffer_count, command_buffers);
+ if (iree_status_is_ok(status)) {
+ // Try to advance the pending workload queue.
+ status = iree_hal_cuda2_pending_queue_actions_issue(
+ device->pending_queue_actions);
}
- // TODO(antiagainst): implement semaphores - for now this conservatively
- // synchronizes after every submit.
- IREE_TRACE_ZONE_BEGIN_NAMED(z0, "cuStreamSynchronize");
- IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(z0, device->cuda_symbols,
- cuStreamSynchronize(device->cu_stream),
- "cuStreamSynchronize");
iree_hal_cuda2_tracing_context_collect(device->tracing_context);
IREE_TRACE_ZONE_END(z0);
-
- return iree_ok_status();
+ return status;
}
static iree_status_t iree_hal_cuda2_device_queue_flush(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
- // Currently unused; we flush as submissions are made.
- return iree_ok_status();
+ iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // Try to advance the pending workload queue.
+ iree_status_t status =
+ iree_hal_cuda2_pending_queue_actions_issue(device->pending_queue_actions);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
}
static iree_status_t iree_hal_cuda2_device_wait_semaphores(
iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "semaphore not yet implemented");
+ "waiting multiple semaphores not yet implemented");
}
static iree_status_t iree_hal_cuda2_device_profiling_begin(
diff --git a/experimental/cuda2/cuda_dynamic_symbol_table.h b/experimental/cuda2/cuda_dynamic_symbol_table.h
index aa030c9..f57c65d 100644
--- a/experimental/cuda2/cuda_dynamic_symbol_table.h
+++ b/experimental/cuda2/cuda_dynamic_symbol_table.h
@@ -84,3 +84,4 @@
IREE_CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void**, void**)
+IREE_CU_PFN_DECL(cuLaunchHostFunc, CUstream, CUhostFn, void*)
diff --git a/experimental/cuda2/event_pool.c b/experimental/cuda2/event_pool.c
new file mode 100644
index 0000000..d90bf79
--- /dev/null
+++ b/experimental/cuda2/event_pool.c
@@ -0,0 +1,278 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/event_pool.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <string.h>
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "experimental/cuda2/cuda_status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/hal/api.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_event_t
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_cuda2_event_t {
+ // The allocator used to create the event.
+ iree_allocator_t host_allocator;
+ // The symbols used to create and destroy CUevent objects.
+ const iree_hal_cuda2_dynamic_symbols_t* symbols;
+
+ // The event pool that owns this event. This cannot be NULL.
+ iree_hal_cuda2_event_pool_t* pool;
+ // The underlying CUevent object.
+ CUevent cu_event;
+
+ // A reference count used to manage resource lifetime. Its value range:
+ // * 1 - when inside the event pool and to be acquired;
+ // * >= 1 - when acquired outside of the event pool;
+ // * 0 - right before being released back to the pool or destroyed.
+ iree_atomic_ref_count_t ref_count;
+};
+
+CUevent iree_hal_cuda2_event_handle(const iree_hal_cuda2_event_t* event) {
+ return event->cu_event;
+}
+
+static inline void iree_hal_cuda2_event_destroy(iree_hal_cuda2_event_t* event) {
+ iree_allocator_t host_allocator = event->host_allocator;
+ const iree_hal_cuda2_dynamic_symbols_t* symbols = event->symbols;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count);
+ IREE_CUDA_IGNORE_ERROR(symbols, cuEventDestroy(event->cu_event));
+ iree_allocator_free(host_allocator, event);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static inline iree_status_t iree_hal_cuda2_event_create(
+ const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_hal_cuda2_event_pool_t* pool, iree_allocator_t host_allocator,
+ iree_hal_cuda2_event_t** out_event) {
+ IREE_ASSERT_ARGUMENT(symbols);
+ IREE_ASSERT_ARGUMENT(pool);
+ IREE_ASSERT_ARGUMENT(out_event);
+ *out_event = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_event_t* event = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_allocator_malloc(host_allocator, sizeof(*event), (void**)&event));
+ event->host_allocator = host_allocator;
+ event->symbols = symbols;
+ event->pool = pool;
+ event->cu_event = NULL;
+ iree_atomic_ref_count_init(&event->ref_count); // -> 1
+
+ iree_status_t status = IREE_CURESULT_TO_STATUS(
+ symbols, cuEventCreate(&event->cu_event, CU_EVENT_DISABLE_TIMING),
+ "cuEventCreate");
+ if (iree_status_is_ok(status)) {
+ *out_event = event;
+ } else {
+ iree_atomic_ref_count_dec(&event->ref_count); // -> 0
+ iree_hal_cuda2_event_destroy(event);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+void iree_hal_cuda2_event_retain(iree_hal_cuda2_event_t* event) {
+ iree_atomic_ref_count_inc(&event->ref_count);
+}
+
+static void iree_hal_cuda2_event_pool_release(
+ iree_hal_cuda2_event_pool_t* event_pool, iree_host_size_t event_count,
+ iree_hal_cuda2_event_t** events);
+
+void iree_hal_cuda2_event_release(iree_hal_cuda2_event_t* event) {
+ if (iree_atomic_ref_count_dec(&event->ref_count) == 1) {
+ // Release back to the pool if the reference count becomes 0.
+ iree_hal_cuda2_event_pool_release(event->pool, 1, &event);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_event_pool_t
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_cuda2_event_pool_t {
+ // The allocator used to create the event pool.
+ iree_allocator_t host_allocator;
+ // The symbols used to create and destroy CUevent objects.
+ const iree_hal_cuda2_dynamic_symbols_t* symbols;
+
+ // Guards event related fields in the pool. We don't expect a performant
+ // program to frequently allocate events for synchronization purposes; the
+ // traffic to this pool should be low, so it should be fine to use a mutex
+ // to guard here.
+ iree_slim_mutex_t event_mutex;
+
+ // Maximum number of event objects that will be maintained in the pool.
+ // More events may be allocated at any time, but they will be disposed
+ // directly when they are no longer needed.
+ iree_host_size_t available_capacity IREE_GUARDED_BY(event_mutex);
+ // Total number of currently available event objects.
+ iree_host_size_t available_count IREE_GUARDED_BY(event_mutex);
+ // The list of available_count event objects.
+ iree_hal_cuda2_event_t* available_list[] IREE_GUARDED_BY(event_mutex);
+};
+// + Additional inline allocation for holding events up to the capacity.
+
+iree_status_t iree_hal_cuda2_event_pool_allocate(
+ const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+ iree_hal_cuda2_event_pool_t** out_event_pool) {
+ IREE_ASSERT_ARGUMENT(symbols);
+ IREE_ASSERT_ARGUMENT(out_event_pool);
+ *out_event_pool = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_event_pool_t* event_pool = NULL;
+ iree_host_size_t total_size =
+ sizeof(*event_pool) +
+ available_capacity * sizeof(*event_pool->available_list);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_allocator_malloc(host_allocator, total_size, (void**)&event_pool));
+ event_pool->host_allocator = host_allocator;
+ event_pool->symbols = symbols;
+ iree_slim_mutex_initialize(&event_pool->event_mutex);
+ event_pool->available_capacity = available_capacity;
+ event_pool->available_count = 0;
+
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < available_capacity; ++i) {
+ status = iree_hal_cuda2_event_create(
+ symbols, event_pool, host_allocator,
+ &event_pool->available_list[event_pool->available_count++]);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ if (iree_status_is_ok(status)) {
+ *out_event_pool = event_pool;
+ } else {
+ iree_hal_cuda2_event_pool_free(event_pool);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+void iree_hal_cuda2_event_pool_free(iree_hal_cuda2_event_pool_t* event_pool) {
+ iree_allocator_t host_allocator = event_pool->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ for (iree_host_size_t i = 0; i < event_pool->available_count; ++i) {
+ iree_hal_cuda2_event_t* event = event_pool->available_list[i];
+ iree_atomic_ref_count_dec(&event->ref_count); // -> 0
+ iree_hal_cuda2_event_destroy(event);
+ }
+ iree_slim_mutex_deinitialize(&event_pool->event_mutex);
+ iree_allocator_free(host_allocator, event_pool);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_cuda2_event_pool_acquire(
+ iree_hal_cuda2_event_pool_t* event_pool, iree_host_size_t event_count,
+ iree_hal_cuda2_event_t** out_events) {
+ IREE_ASSERT_ARGUMENT(event_pool);
+ if (!event_count) return iree_ok_status();
+ IREE_ASSERT_ARGUMENT(out_events);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // We'll try to get what we can from the pool and fall back to initializing
+ // new iree_hal_cuda2_event_t objects.
+ iree_host_size_t remaining_count = event_count;
+
+ // Try first to grab from the pool.
+ iree_slim_mutex_lock(&event_pool->event_mutex);
+ iree_host_size_t from_pool_count =
+ iree_min(event_pool->available_count, event_count);
+ if (from_pool_count > 0) {
+ iree_host_size_t pool_base_index =
+ event_pool->available_count - from_pool_count;
+ memcpy(out_events, &event_pool->available_list[pool_base_index],
+ from_pool_count * sizeof(*event_pool->available_list));
+ event_pool->available_count -= from_pool_count;
+ remaining_count -= from_pool_count;
+ }
+ iree_slim_mutex_unlock(&event_pool->event_mutex);
+
+ // Allocate the rest of the events.
+ if (remaining_count > 0) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire");
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < remaining_count; ++i) {
+ status = iree_hal_cuda2_event_create(event_pool->symbols, event_pool,
+ event_pool->host_allocator,
+ &out_events[from_pool_count + i]);
+ if (!iree_status_is_ok(status)) {
+ // Must release all events we've acquired so far.
+ iree_hal_cuda2_event_pool_release(event_pool, from_pool_count + i,
+ out_events);
+ IREE_TRACE_ZONE_END(z1);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+ }
+ IREE_TRACE_ZONE_END(z1);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void iree_hal_cuda2_event_pool_release(
+ iree_hal_cuda2_event_pool_t* event_pool, iree_host_size_t event_count,
+ iree_hal_cuda2_event_t** events) {
+ IREE_ASSERT_ARGUMENT(event_pool);
+ if (!event_count) return;
+ IREE_ASSERT_ARGUMENT(events);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // We'll try to release all we can back to the pool and then deinitialize
+ // the ones that won't fit.
+ iree_host_size_t remaining_count = event_count;
+
+ // Try first to release to the pool.
+ iree_slim_mutex_lock(&event_pool->event_mutex);
+ iree_host_size_t to_pool_count =
+ iree_min(event_pool->available_capacity - event_pool->available_count,
+ event_count);
+ if (to_pool_count > 0) {
+ for (iree_host_size_t i = 0; i < to_pool_count; ++i) {
+ IREE_ASSERT_REF_COUNT_ZERO(&events[i]->ref_count);
+ iree_hal_cuda2_event_retain(events[i]); // -> 1
+ }
+ iree_host_size_t pool_base_index = event_pool->available_count;
+ memcpy(&event_pool->available_list[pool_base_index], events,
+ to_pool_count * sizeof(*event_pool->available_list));
+ event_pool->available_count += to_pool_count;
+ remaining_count -= to_pool_count;
+ }
+ iree_slim_mutex_unlock(&event_pool->event_mutex);
+
+ // Deallocate the rest of the events. We don't bother resetting them as we are
+ // getting rid of them.
+ if (remaining_count > 0) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-release");
+ for (iree_host_size_t i = 0; i < remaining_count; ++i) {
+ iree_hal_cuda2_event_destroy(events[to_pool_count + i]);
+ }
+ IREE_TRACE_ZONE_END(z1);
+ }
+ IREE_TRACE_ZONE_END(z0);
+}
diff --git a/experimental/cuda2/event_pool.h b/experimental/cuda2/event_pool.h
new file mode 100644
index 0000000..50f9624
--- /dev/null
+++ b/experimental/cuda2/event_pool.h
@@ -0,0 +1,78 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXPERIMENTAL_CUDA2_EVENT_POOL_H_
+#define EXPERIMENTAL_CUDA2_EVENT_POOL_H_
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_event_t
+//===----------------------------------------------------------------------===//
+
+// A struct that wraps a CUevent object with a reference count for lifetime
+// management.
+//
+// iree_hal_cuda2_event_t objects cannot be directly created; they should be
+// acquired from the event pool and released back to the pool once done.
+//
+// Thread-safe; multiple threads may retain and release the same event.
+typedef struct iree_hal_cuda2_event_t iree_hal_cuda2_event_t;
+
+// Returns the underlying CUevent handle behind |event|.
+CUevent iree_hal_cuda2_event_handle(const iree_hal_cuda2_event_t* event);
+
+// Retains the given |event| by increasing its reference count.
+void iree_hal_cuda2_event_retain(iree_hal_cuda2_event_t* event);
+
+// Releases the given |event| by decreasing its reference count.
+//
+// |event| will be returned to its owning pool when the reference count is 0.
+void iree_hal_cuda2_event_release(iree_hal_cuda2_event_t* event);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_event_pool_t
+//===----------------------------------------------------------------------===//
+
+// A simple pool of iree_hal_cuda2_event_t objects to recycle.
+//
+// Thread-safe; multiple threads may acquire and release events from the pool.
+typedef struct iree_hal_cuda2_event_pool_t iree_hal_cuda2_event_pool_t;
+
+// Allocates a new event pool with up to |available_capacity| events.
+//
+// Extra events requested beyond the capacity are directly created and
+// destroyed without pooling.
+iree_status_t iree_hal_cuda2_event_pool_allocate(
+ const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+ iree_hal_cuda2_event_pool_t** out_event_pool);
+
+// Deallocates an event pool and destroys all events.
+//
+// All events that were acquired from the pool must have already been released
+// back to it prior to deallocation.
+void iree_hal_cuda2_event_pool_free(iree_hal_cuda2_event_pool_t* event_pool);
+
+// Acquires one or more events from the event pool.
+//
+// Each returned event has an initial reference count of 1. The returned
+// CUevent objects may retain captured state from previous uses; callers
+// should record again to overwrite it.
+iree_status_t iree_hal_cuda2_event_pool_acquire(
+ iree_hal_cuda2_event_pool_t* event_pool, iree_host_size_t event_count,
+ iree_hal_cuda2_event_t** out_events);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // EXPERIMENTAL_CUDA2_EVENT_POOL_H_
diff --git a/experimental/cuda2/event_semaphore.c b/experimental/cuda2/event_semaphore.c
new file mode 100644
index 0000000..00ef215
--- /dev/null
+++ b/experimental/cuda2/event_semaphore.c
@@ -0,0 +1,414 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/event_semaphore.h"
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "experimental/cuda2/cuda_headers.h"
+#include "experimental/cuda2/cuda_status_util.h"
+#include "experimental/cuda2/timepoint_pool.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/hal/utils/semaphore_base.h"
+
+// Sentinel to indicate the semaphore has failed and an error status is set.
+#define IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE UINT64_MAX
+
+typedef struct iree_hal_cuda2_semaphore_t {
+ // Abstract resource used for injecting reference counting and vtable;
+ // must be at offset 0.
+ iree_hal_semaphore_t base;
+
+ // The allocator used to create this semaphore.
+ iree_allocator_t host_allocator;
+ // The symbols used to issue CUDA API calls.
+ const iree_hal_cuda2_dynamic_symbols_t* symbols;
+
+ // The timepoint pool to acquire timepoint objects.
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool;
+
+ // The list of pending queue actions that this semaphore needs to advance on
+ // new signaled values.
+ iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;
+
+ // Guards value and status. We expect low contention on semaphores and since
+ // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
+ // than trying to make the entire structure lock-free.
+ iree_slim_mutex_t mutex;
+
+ // Current signaled value. May be IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE to
+ // indicate that the semaphore has been signaled for failure and
+ // |failure_status| contains the error.
+ uint64_t current_value IREE_GUARDED_BY(mutex);
+
+ // OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore.
+ iree_status_t failure_status IREE_GUARDED_BY(mutex);
+} iree_hal_cuda2_semaphore_t;
+
+static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable;
+
+static iree_hal_cuda2_semaphore_t* iree_hal_cuda2_semaphore_cast(
+ iree_hal_semaphore_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_semaphore_vtable);
+ return (iree_hal_cuda2_semaphore_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda2_event_semaphore_create(
+ uint64_t initial_value, const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions,
+ iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) {
+ IREE_ASSERT_ARGUMENT(symbols);
+ IREE_ASSERT_ARGUMENT(timepoint_pool);
+ IREE_ASSERT_ARGUMENT(pending_queue_actions);
+ IREE_ASSERT_ARGUMENT(out_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_semaphore_t* semaphore = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
+ (void**)&semaphore));
+
+ iree_hal_semaphore_initialize(&iree_hal_cuda2_semaphore_vtable,
+ &semaphore->base);
+ semaphore->host_allocator = host_allocator;
+ semaphore->symbols = symbols;
+ semaphore->timepoint_pool = timepoint_pool;
+ semaphore->pending_queue_actions = pending_queue_actions;
+ iree_slim_mutex_initialize(&semaphore->mutex);
+ semaphore->current_value = initial_value;
+ semaphore->failure_status = iree_ok_status();
+
+ *out_semaphore = &semaphore->base;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void iree_hal_cuda2_semaphore_destroy(
+ iree_hal_semaphore_t* base_semaphore) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ iree_allocator_t host_allocator = semaphore->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_status_ignore(semaphore->failure_status);
+ iree_slim_mutex_deinitialize(&semaphore->mutex);
+
+ iree_hal_semaphore_deinitialize(&semaphore->base);
+ iree_allocator_free(host_allocator, semaphore);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_cuda2_semaphore_query(
+ iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+
+ *out_value = semaphore->current_value;
+
+ iree_status_t status = iree_ok_status();
+ if (*out_value >= IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE) {
+ status = iree_status_clone(semaphore->failure_status);
+ }
+
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static iree_status_t iree_hal_cuda2_semaphore_signal(
+ iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+
+ if (new_value <= semaphore->current_value) {
+ uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value;
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "semaphore values must be monotonically "
+ "increasing; current_value=%" PRIu64
+ ", new_value=%" PRIu64,
+ current_value, new_value);
+ }
+
+ semaphore->current_value = new_value;
+
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
+ // Notify timepoints - note that this must happen outside the lock.
+ iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK);
+
+ // Advance the pending queue actions if possible. This also must happen
+ // outside the lock to avoid nesting.
+ iree_status_t status = iree_hal_cuda2_pending_queue_actions_issue(
+ semaphore->pending_queue_actions);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_cuda2_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
+ iree_status_t status) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ const iree_status_code_t status_code = iree_status_code(status);
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+
+ // Try to set our local status - we only preserve the first failure so only
+ // do this if we are going from a valid semaphore to a failed one.
+ if (!iree_status_is_ok(semaphore->failure_status)) {
+ // Previous status was not OK; drop our new status.
+ IREE_IGNORE_ERROR(status);
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return;
+ }
+
+ // Signal to our failure sentinel value.
+ semaphore->current_value = IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE;
+ semaphore->failure_status = status;
+
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
+ // Notify timepoints - note that this must happen outside the lock.
+ iree_hal_semaphore_notify(&semaphore->base,
+ IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE, status_code);
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Handles host wait timepoints on the host when the |semaphore| timeline
+// advances past the given |value|.
+//
+// Note that this callback is invoked by a host thread.
+static iree_status_t iree_hal_cuda2_semaphore_timepoint_host_wait_callback(
+ void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
+ iree_status_code_t status_code) {
+ iree_hal_cuda2_timepoint_t* timepoint =
+ (iree_hal_cuda2_timepoint_t*)user_data;
+ iree_event_set(&timepoint->timepoint.host_wait);
+ return iree_ok_status();
+}
+
+// Acquires a timepoint to wait for the timeline to reach at least the given
+// |min_value| from the host.
+static iree_status_t iree_hal_cuda2_semaphore_acquire_timepoint_host_wait(
+ iree_hal_cuda2_semaphore_t* semaphore, uint64_t min_value,
+ iree_timeout_t timeout, iree_hal_cuda2_timepoint_t** out_timepoint) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_host_wait(
+ semaphore->timepoint_pool, 1, out_timepoint));
+ // Initialize the timepoint with the value and callback, and connect it to
+ // this semaphore.
+ iree_hal_semaphore_acquire_timepoint(
+ &semaphore->base, min_value, timeout,
+ (iree_hal_semaphore_callback_t){
+ .fn = iree_hal_cuda2_semaphore_timepoint_host_wait_callback,
+ .user_data = *out_timepoint,
+ },
+ &(*out_timepoint)->base);
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda2_semaphore_wait(
+ iree_hal_semaphore_t* base_semaphore, uint64_t value,
+ iree_timeout_t timeout) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (!iree_status_is_ok(semaphore->failure_status)) {
+ // Fastest path: failed; return an error to tell callers to query for it.
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_status_from_code(IREE_STATUS_ABORTED);
+ }
+ if (semaphore->current_value >= value) {
+ // Fast path: already satisfied.
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+ if (iree_timeout_is_immediate(timeout)) {
+ // Not satisfied but a poll, so can avoid the expensive wait handle work.
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
+ iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+
+ // Slow path: acquire a timepoint. This should happen outside of the lock
+ // given that acquiring has its own internal locks.
+ iree_hal_cuda2_timepoint_t* timepoint = NULL;
+ iree_status_t status = iree_hal_cuda2_semaphore_acquire_timepoint_host_wait(
+ semaphore, value, timeout, &timepoint);
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+ // Wait until the timepoint resolves.
+ // If satisfied the timepoint is automatically cleaned up and we are done. If
+ // the deadline is reached before satisfied then we have to clean it up.
+ status = iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns);
+ if (!iree_status_is_ok(status)) {
+ iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base);
+ }
+ iree_hal_cuda2_timepoint_pool_release(semaphore->timepoint_pool, 1,
+ &timepoint);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Handles device signal timepoints on the host when the |semaphore| timeline
+// advances past the given |value|.
+//
+// Note that this callback is invoked by a host thread after the CUDA host
+// function callback is triggered in the CUDA driver.
+static iree_status_t iree_hal_cuda2_semaphore_timepoint_device_signal_callback(
+ void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
+ iree_status_code_t status_code) {
+ iree_hal_cuda2_timepoint_t* timepoint =
+ (iree_hal_cuda2_timepoint_t*)user_data;
+ // Just release the timepoint back to the pool. This will decrease the
+ // reference count of the underlying CUDA event internally.
+ iree_hal_cuda2_timepoint_pool_release(timepoint->pool, 1, &timepoint);
+ return iree_ok_status();
+}
+
+// Acquires a timepoint to signal the timeline to the given |to_value| from the
+// device.
+iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_signal(
+ iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
+ CUevent* out_event) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ iree_hal_cuda2_timepoint_t* signal_timepoint = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_device_signal(
+ semaphore->timepoint_pool, 1, &signal_timepoint));
+
+ // Initialize the timepoint with the value and callback, and connect it to
+ // this semaphore.
+ iree_hal_semaphore_acquire_timepoint(
+ &semaphore->base, to_value, iree_infinite_timeout(),
+ (iree_hal_semaphore_callback_t){
+ .fn = iree_hal_cuda2_semaphore_timepoint_device_signal_callback,
+ .user_data = signal_timepoint,
+ },
+ &signal_timepoint->base);
+ iree_hal_cuda2_event_t* event = signal_timepoint->timepoint.device_signal;
+
+ // Scan through the timepoint list and update device wait timepoints to wait
+ // for this device signal when possible. We need to lock with the timepoint
+ // list mutex here.
+ iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
+ for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
+ tp != NULL; tp = tp->next) {
+ iree_hal_cuda2_timepoint_t* wait_timepoint =
+ (iree_hal_cuda2_timepoint_t*)tp;
+ if (wait_timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT &&
+ wait_timepoint->timepoint.device_wait == NULL &&
+ wait_timepoint->base.minimum_value <= to_value) {
+ iree_hal_cuda2_event_retain(event);
+ wait_timepoint->timepoint.device_wait = event;
+ }
+ }
+ iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
+
+ *out_event = iree_hal_cuda2_event_handle(event);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Handles device wait timepoints on the host when the |semaphore| timeline
+// advances past the given |value|.
+//
+// Note that this callback is invoked by a host thread.
+static iree_status_t iree_hal_cuda2_semaphore_timepoint_device_wait_callback(
+ void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
+ iree_status_code_t status_code) {
+ iree_hal_cuda2_timepoint_t* timepoint =
+ (iree_hal_cuda2_timepoint_t*)user_data;
+ // Just release the timepoint back to the pool. This will decrease the
+ // reference count of the underlying CUDA event internally.
+ iree_hal_cuda2_timepoint_pool_release(timepoint->pool, 1, &timepoint);
+ return iree_ok_status();
+}
+
+// Acquires a timepoint to wait for the timeline to reach at least the given
+// |min_value| on the device.
+iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_wait(
+ iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
+ CUevent* out_event) {
+ iree_hal_cuda2_semaphore_t* semaphore =
+ iree_hal_cuda2_semaphore_cast(base_semaphore);
+ iree_hal_cuda2_timepoint_t* wait_timepoint = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_device_wait(
+ semaphore->timepoint_pool, 1, &wait_timepoint));
+
+ // Initialize the timepoint with the value and callback, and connect it to
+ // this semaphore.
+ iree_hal_semaphore_acquire_timepoint(
+ &semaphore->base, min_value, iree_infinite_timeout(),
+ (iree_hal_semaphore_callback_t){
+ .fn = iree_hal_cuda2_semaphore_timepoint_device_wait_callback,
+ .user_data = wait_timepoint,
+ },
+ &wait_timepoint->base);
+
+ // Scan through the timepoint list and try to find a device event signal to
+ // wait on. We need to lock with the timepoint list mutex here.
+ iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
+ for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
+ tp != NULL; tp = tp->next) {
+ iree_hal_cuda2_timepoint_t* signal_timepoint =
+ (iree_hal_cuda2_timepoint_t*)tp;
+ if (signal_timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL &&
+ signal_timepoint->base.minimum_value >= min_value) {
+ iree_hal_cuda2_event_t* event = signal_timepoint->timepoint.device_signal;
+ iree_hal_cuda2_event_retain(event);
+ wait_timepoint->timepoint.device_wait = event;
+ }
+ }
+ iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
+
+ *out_event =
+ iree_hal_cuda2_event_handle(wait_timepoint->timepoint.device_wait);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static const iree_hal_semaphore_vtable_t iree_hal_cuda2_semaphore_vtable = {
+ .destroy = iree_hal_cuda2_semaphore_destroy,
+ .query = iree_hal_cuda2_semaphore_query,
+ .signal = iree_hal_cuda2_semaphore_signal,
+ .fail = iree_hal_cuda2_semaphore_fail,
+ .wait = iree_hal_cuda2_semaphore_wait,
+};
diff --git a/experimental/cuda2/event_semaphore.h b/experimental/cuda2/event_semaphore.h
new file mode 100644
index 0000000..1ec09ed
--- /dev/null
+++ b/experimental/cuda2/event_semaphore.h
@@ -0,0 +1,56 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXPERIMENTAL_CUDA2_EVENT_SEMAPHORE_H_
+#define EXPERIMENTAL_CUDA2_EVENT_SEMAPHORE_H_
+
+#include <stdint.h>
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "experimental/cuda2/pending_queue_actions.h"
+#include "experimental/cuda2/timepoint_pool.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates an IREE HAL semaphore with the given |initial_value|.
+//
+// The HAL semaphore is backed by iree_event_t or CUevent objects for
+// different timepoints along the timeline under the hood. Those timepoints
+// will be allocated from the |timepoint_pool|.
+//
+// This semaphore is meant to be used together with a pending queue actions
+// list; it may advance the given |pending_queue_actions| if new values are
+// signaled.
+//
+// Thread-safe; multiple threads may signal/wait values on the same semaphore.
+iree_status_t iree_hal_cuda2_event_semaphore_create(
+ uint64_t initial_value, const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions,
+ iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore);
+
+// Acquires a timepoint to signal the timeline to the given |to_value| from the
+// device. The underlying CUDA event is written into |out_event| for interacting
+// with CUDA APIs.
+iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_signal(
+ iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
+ CUevent* out_event);
+
+// Acquires a timepoint to wait for the timeline to reach at least the given
+// |min_value| on the device. The underlying CUDA event is written into
+// |out_event| for interacting with CUDA APIs.
+iree_status_t iree_hal_cuda2_event_semaphore_acquire_timepoint_device_wait(
+ iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
+ CUevent* out_event);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // EXPERIMENTAL_CUDA2_EVENT_SEMAPHORE_H_
diff --git a/experimental/cuda2/pending_queue_actions.c b/experimental/cuda2/pending_queue_actions.c
new file mode 100644
index 0000000..b096c78
--- /dev/null
+++ b/experimental/cuda2/pending_queue_actions.c
@@ -0,0 +1,516 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/pending_queue_actions.h"
+
+#include <stdbool.h>
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "experimental/cuda2/cuda_status_util.h"
+#include "experimental/cuda2/event_semaphore.h"
+#include "experimental/cuda2/graph_command_buffer.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/hal/api.h"
+#include "iree/hal/utils/resource_set.h"
+
+//===----------------------------------------------------------------------===//
+// Queue action
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_cuda2_queue_action_kind_e {
+ IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION,
+ // TODO: Add support for queue alloca and dealloca.
+} iree_hal_cuda2_queue_action_kind_t;
+
+// A pending queue action.
+//
+// Note that this struct does not have internal synchronization; it's expected
+// to work together with the pending action queue, which synchronizes accesses.
+typedef struct iree_hal_cuda2_queue_action_t {
+ // Intrusive doubly-linked list next entry pointer.
+ struct iree_hal_cuda2_queue_action_t* next;
+ // Intrusive doubly-linked list previous entry pointer.
+ struct iree_hal_cuda2_queue_action_t* prev;
+
+ // The owning pending actions queue. We use its allocators and pools.
+ // Retained to make sure it outlives the current action.
+ iree_hal_cuda2_pending_queue_actions_t* owning_actions;
+
+ iree_hal_cuda2_queue_action_kind_t kind;
+ union {
+ struct {
+ iree_host_size_t count;
+ iree_hal_command_buffer_t* const* ptr;
+ } command_buffers;
+ } payload;
+
+ // The stream to launch the main GPU workload.
+ CUstream dispatch_cu_stream;
+ // The stream to launch CUDA host function callbacks.
+ CUstream callback_cu_stream;
+
+ // Resource set to retain all resources associated with the payload.
+ iree_hal_resource_set_t* resource_set;
+
+ // Semaphore list to wait on for the payload to start on the GPU.
+ iree_hal_semaphore_list_t wait_semaphore_list;
+ // Semaphore list to signal after the payload completes on the GPU.
+ iree_hal_semaphore_list_t signal_semaphore_list;
+
+ // Scratch fields for analyzing whether actions are ready to issue.
+ CUevent* events;
+ iree_host_size_t event_count;
+ bool is_pending;
+} iree_hal_cuda2_queue_action_t;
+
+//===----------------------------------------------------------------------===//
+// Queue action list
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_cuda2_queue_action_list_t {
+ iree_hal_cuda2_queue_action_t* head;
+ iree_hal_cuda2_queue_action_t* tail;
+} iree_hal_cuda2_queue_action_list_t;
+
+// Returns true if the action list is empty.
+static inline bool iree_hal_cuda2_queue_action_list_is_empty(
+ const iree_hal_cuda2_queue_action_list_t* list) {
+ return list->head == NULL;
+}
+
+// Pushes |action| onto the end of the given action |list|.
+static void iree_hal_cuda2_queue_action_list_push_back(
+ iree_hal_cuda2_queue_action_list_t* list,
+ iree_hal_cuda2_queue_action_t* action) {
+ if (list->tail) {
+ list->tail->next = action;
+ } else {
+ list->head = action;
+ }
+ action->next = NULL;
+ action->prev = list->tail;
+ list->tail = action;
+}
+
+// Erases |action| from |list|.
+static void iree_hal_cuda2_queue_action_list_erase(
+ iree_hal_cuda2_queue_action_list_t* list,
+ iree_hal_cuda2_queue_action_t* action) {
+ iree_hal_cuda2_queue_action_t* next = action->next;
+ iree_hal_cuda2_queue_action_t* prev = action->prev;
+ if (prev) {
+ prev->next = next;
+ action->prev = NULL;
+ } else {
+ list->head = next;
+ }
+ if (next) {
+ next->prev = prev;
+ action->next = NULL;
+ } else {
+ list->tail = prev;
+ }
+}
+
+// Takes all actions from |available_list| and moves them into |ready_list|.
+static void iree_hal_cuda2_queue_action_list_take_all(
+ iree_hal_cuda2_queue_action_list_t* available_list,
+ iree_hal_cuda2_queue_action_list_t* ready_list) {
+ IREE_ASSERT(available_list != ready_list);
+ ready_list->head = available_list->head;
+ ready_list->tail = available_list->tail;
+ available_list->head = NULL;
+ available_list->tail = NULL;
+}
+
+//===----------------------------------------------------------------------===//
+// Pending queue actions
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_cuda2_pending_queue_actions_t {
+ // Abstract resource used for injecting reference counting and vtable;
+ // must be at offset 0.
+ iree_hal_resource_t resource;
+
+ // The allocator used to create this pending actions queue.
+ iree_allocator_t host_allocator;
+ // The block pool to allocate resource sets from.
+ iree_arena_block_pool_t* block_pool;
+
+ // The symbols used to create and destroy CUevent objects.
+ const iree_hal_cuda2_dynamic_symbols_t* symbols;
+
+ // Non-recursive mutex guarding access to the action list.
+ iree_slim_mutex_t action_mutex;
+
+ // The doubly-linked list of pending actions.
+ iree_hal_cuda2_queue_action_list_t action_list IREE_GUARDED_BY(action_mutex);
+};
+
+static const iree_hal_resource_vtable_t
+ iree_hal_cuda2_pending_queue_actions_vtable;
+
+iree_status_t iree_hal_cuda2_pending_queue_actions_create(
+ const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_cuda2_pending_queue_actions_t** out_actions) {
+ IREE_ASSERT_ARGUMENT(symbols);
+ IREE_ASSERT_ARGUMENT(block_pool);
+ IREE_ASSERT_ARGUMENT(out_actions);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_pending_queue_actions_t* actions = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, sizeof(*actions),
+ (void**)&actions));
+ iree_hal_resource_initialize(&iree_hal_cuda2_pending_queue_actions_vtable,
+ &actions->resource);
+ actions->host_allocator = host_allocator;
+ actions->block_pool = block_pool;
+ actions->symbols = symbols;
+ iree_slim_mutex_initialize(&actions->action_mutex);
+ memset(&actions->action_list, 0, sizeof(actions->action_list));
+ *out_actions = actions;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_hal_cuda2_pending_queue_actions_t*
+iree_hal_cuda2_pending_queue_actions_cast(iree_hal_resource_t* base_value) {
+ return (iree_hal_cuda2_pending_queue_actions_t*)base_value;
+}
+
+void iree_hal_cuda2_pending_queue_actions_destroy(
+ iree_hal_resource_t* base_actions) {
+ iree_hal_cuda2_pending_queue_actions_t* actions =
+ iree_hal_cuda2_pending_queue_actions_cast(base_actions);
+ iree_allocator_t host_allocator = actions->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_ASSERT(iree_hal_cuda2_queue_action_list_is_empty(&actions->action_list));
+
+ iree_slim_mutex_deinitialize(&actions->action_mutex);
+ iree_allocator_free(host_allocator, actions);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t
+ iree_hal_cuda2_pending_queue_actions_vtable = {
+ .destroy = iree_hal_cuda2_pending_queue_actions_destroy,
+};
+
+// Copies the given |in_list| to |out_list| to retain the semaphore and
+// value lists.
+static iree_status_t iree_hal_cuda2_copy_semaphore_list(
+ iree_hal_semaphore_list_t in_list, iree_allocator_t host_allocator,
+ iree_hal_semaphore_list_t* out_list) {
+ if (in_list.count == 0) {
+ memset(out_list, 0, sizeof(*out_list));
+ } else {
+ out_list->count = in_list.count;
+
+ iree_host_size_t semaphore_size =
+ in_list.count * sizeof(*in_list.semaphores);
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, semaphore_size,
+ (void**)&out_list->semaphores));
+ memcpy(out_list->semaphores, in_list.semaphores, semaphore_size);
+
+ iree_host_size_t value_size =
+ in_list.count * sizeof(*in_list.payload_values);
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+ host_allocator, value_size, (void**)&out_list->payload_values));
+ memcpy(out_list->payload_values, in_list.payload_values, value_size);
+ }
+ return iree_ok_status();
+}
+
+// Frees the semaphore and value list inside |semaphore_list|.
+static void iree_hal_cuda2_free_semaphore_list(
+ iree_allocator_t host_allocator,
+ iree_hal_semaphore_list_t* semaphore_list) {
+ iree_allocator_free(host_allocator, semaphore_list->semaphores);
+ iree_allocator_free(host_allocator, semaphore_list->payload_values);
+}
+
+iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
+ CUstream dispatch_stream, CUstream callback_stream,
+ iree_hal_cuda2_pending_queue_actions_t* actions,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_host_size_t command_buffer_count,
+ iree_hal_command_buffer_t* const* command_buffers) {
+ IREE_ASSERT_ARGUMENT(actions);
+ IREE_ASSERT_ARGUMENT(command_buffer_count == 0 || command_buffers);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_queue_action_t* action = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(actions->host_allocator, sizeof(*action),
+ (void**)&action));
+
+ action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION;
+ action->payload.command_buffers.count = command_buffer_count;
+ action->payload.command_buffers.ptr = command_buffers;
+ action->dispatch_cu_stream = dispatch_stream;
+ action->callback_cu_stream = callback_stream;
+ action->events = NULL;
+ action->event_count = 0;
+ action->is_pending = true;
+
+ // Retain all command buffers and semaphores.
+ iree_hal_resource_set_t* resource_set = NULL;
+ iree_status_t status =
+ iree_hal_resource_set_allocate(actions->block_pool, &resource_set);
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ status = iree_hal_resource_set_insert(resource_set, command_buffer_count,
+ command_buffers);
+ }
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ status =
+ iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+ wait_semaphore_list.semaphores);
+ }
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ status =
+ iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
+ signal_semaphore_list.semaphores);
+ }
+
+ // Copy the semaphore and value list for later access.
+ // TODO: avoid host allocator malloc; use some pool for the allocation.
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ status = iree_hal_cuda2_copy_semaphore_list(wait_semaphore_list,
+ actions->host_allocator,
+ &action->wait_semaphore_list);
+ }
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ status = iree_hal_cuda2_copy_semaphore_list(signal_semaphore_list,
+ actions->host_allocator,
+ &action->signal_semaphore_list);
+ }
+
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ action->owning_actions = actions;
+ iree_hal_resource_retain(actions);
+
+ action->resource_set = resource_set;
+
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_cuda2_queue_action_list_push_back(&actions->action_list, action);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ } else {
+ iree_hal_cuda2_free_semaphore_list(actions->host_allocator,
+ &action->wait_semaphore_list);
+ iree_hal_cuda2_free_semaphore_list(actions->host_allocator,
+ &action->signal_semaphore_list);
+ iree_hal_resource_set_free(resource_set);
+ iree_allocator_free(actions->host_allocator, action);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_cuda2_pending_queue_actions_cleanup_execution(
+ iree_hal_cuda2_queue_action_t* action);
+
+// Releases resources after action completion on the GPU and advances the
+// timeline and the pending actions queue.
+//
+// This is the CUDA host function callback passed to cuLaunchHostFunc(),
+// invoked by a CUDA driver thread.
+static void iree_hal_cuda2_execution_device_signal_host_callback(
+ void* user_data) {
+ iree_hal_cuda2_queue_action_t* action =
+ (iree_hal_cuda2_queue_action_t*)user_data;
+ iree_hal_cuda2_pending_queue_actions_t* actions = action->owning_actions;
+ // Advance semaphore timelines by calling into the host signaling function.
+ IREE_IGNORE_ERROR(
+ iree_hal_semaphore_list_signal(action->signal_semaphore_list));
+ // Destroy the current action given it's done now--this also frees all
+ // retained resources.
+ iree_hal_cuda2_pending_queue_actions_cleanup_execution(action);
+ // Try to release more pending actions to the GPU now.
+ IREE_IGNORE_ERROR(iree_hal_cuda2_pending_queue_actions_issue(actions));
+}
+
+// Issues the given kernel dispatch |action| to the GPU.
+static iree_status_t iree_hal_cuda2_pending_queue_actions_issue_execution(
+ iree_hal_cuda2_queue_action_t* action) {
+ IREE_ASSERT(action->events != NULL);
+ IREE_ASSERT(action->is_pending == false);
+ const iree_hal_cuda2_dynamic_symbols_t* symbols =
+ action->owning_actions->symbols;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // No need to lock given that this action is already detached from the
+ // pending actions list; only this thread is seeing it now.
+
+ // First wait on all the device CUevent objects in the dispatch stream.
+ for (iree_host_size_t i = 0; i < action->event_count; ++i) {
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols,
+ cuStreamWaitEvent(action->dispatch_cu_stream, action->events[i],
+ CU_EVENT_WAIT_DEFAULT),
+ "cuStreamWaitEvent");
+ }
+
+ // Then launch all command buffers on the dispatch stream.
+ for (iree_host_size_t i = 0; i < action->payload.command_buffers.count; ++i) {
+ CUgraphExec exec = iree_hal_cuda2_graph_command_buffer_handle(
+ action->payload.command_buffers.ptr[i]);
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream),
+ "cuGraphLaunch");
+ }
+
+ // Finally, record CUevent signals in the dispatch stream.
+ for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
+ // Grab a CUevent for this semaphore value signaling.
+ CUevent event = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_event_semaphore_acquire_timepoint_device_signal(
+ action->signal_semaphore_list.semaphores[i],
+ action->signal_semaphore_list.payload_values[i], &event));
+
+ // Record the event signaling in the dispatch stream.
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols, cuEventRecord(event, action->dispatch_cu_stream),
+ "cuEventRecord");
+ // Let the callback stream wait on the CUevent.
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols,
+ cuStreamWaitEvent(action->callback_cu_stream, event,
+ CU_EVENT_WAIT_DEFAULT),
+ "cuStreamWaitEvent");
+ }
+
+ // Now launch a host function on the callback stream to advance the semaphore
+ // timeline.
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, symbols,
+ cuLaunchHostFunc(action->callback_cu_stream,
+ iree_hal_cuda2_execution_device_signal_host_callback,
+ action),
+ "cuStreamWaitEvent");
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Releases resources after completing the given kernel dispatch |action|.
+static void iree_hal_cuda2_pending_queue_actions_cleanup_execution(
+ iree_hal_cuda2_queue_action_t* action) {
+ iree_hal_cuda2_pending_queue_actions_t* actions = action->owning_actions;
+ iree_allocator_t host_allocator = actions->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_resource_set_free(action->resource_set);
+ iree_hal_cuda2_free_semaphore_list(host_allocator,
+ &action->wait_semaphore_list);
+ iree_hal_cuda2_free_semaphore_list(host_allocator,
+ &action->signal_semaphore_list);
+ iree_hal_resource_release(actions);
+
+ iree_allocator_free(host_allocator, action);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_cuda2_pending_queue_actions_issue(
+ iree_hal_cuda2_pending_queue_actions_t* actions) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_queue_action_list_t pending_list = {NULL, NULL};
+ iree_hal_cuda2_queue_action_list_t ready_list = {NULL, NULL};
+
+ iree_slim_mutex_lock(&actions->action_mutex);
+
+ if (iree_hal_cuda2_queue_action_list_is_empty(&actions->action_list)) {
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+
+ // Scan through the list and categorize actions into pending and ready lists.
+ for (iree_hal_cuda2_queue_action_t* action = actions->action_list.head;
+ action != NULL;) {
+ iree_hal_cuda2_queue_action_t* next_action = action->next;
+ action->next = NULL;
+
+ iree_host_size_t semaphore_count = action->wait_semaphore_list.count;
+ iree_hal_semaphore_t** semaphores = action->wait_semaphore_list.semaphores;
+ uint64_t* values = action->wait_semaphore_list.payload_values;
+
+ // We allocate stack space here, assuming that there won't be many waits
+ // and that all further references to this field happen within function
+ // calls made from this function, so the alloca'ed memory stays valid.
+ action->events = iree_alloca(semaphore_count * sizeof(CUevent));
+ action->event_count = 0;
+ action->is_pending = false;
+
+ // Look at all wait semaphores.
+ for (iree_host_size_t i = 0; i < semaphore_count; ++i) {
+ // If this semaphore has already signaled past the desired value, we can
+ // just ignore it.
+ uint64_t value = 0;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_semaphore_query(semaphores[i], &value));
+ if (value >= values[i]) continue;
+
+ // Try to acquire a CUevent from a device wait timepoint. If so, we can
+ // use that CUevent to wait on the device. Otherwise, this action is still
+ // not ready.
+ CUevent event = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_event_semaphore_acquire_timepoint_device_wait(
+ semaphores[i], values[i], &event));
+ if (event) {
+ action->events[action->event_count++] = event;
+ } else {
+ // Clear the scratch fields.
+ action->events = NULL;
+ action->event_count = 0;
+ action->is_pending = true;
+ break;
+ }
+ }
+
+ if (action->is_pending) {
+ iree_hal_cuda2_queue_action_list_push_back(&pending_list, action);
+ } else {
+ iree_hal_cuda2_queue_action_list_push_back(&ready_list, action);
+ }
+
+ action = next_action;
+ }
+
+ // Preserve the still-pending actions in the queue.
+ actions->action_list = pending_list;
+
+ iree_slim_mutex_unlock(&actions->action_mutex);
+
+ // Now go through the ready list and issue the actions to the GPU.
+ for (iree_hal_cuda2_queue_action_t* action = ready_list.head;
+ action != NULL;) {
+ iree_hal_cuda2_queue_action_t* next_action = action->next;
+ action->next = NULL;
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_pending_queue_actions_issue_execution(action));
+ action->events = NULL;
+ action->event_count = 0;
+
+ action = next_action;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
diff --git a/experimental/cuda2/pending_queue_actions.h b/experimental/cuda2/pending_queue_actions.h
new file mode 100644
index 0000000..efcf7ec
--- /dev/null
+++ b/experimental/cuda2/pending_queue_actions.h
@@ -0,0 +1,66 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXPERIMENTAL_CUDA2_PENDING_QUEUE_ACTIONS_H_
+#define EXPERIMENTAL_CUDA2_PENDING_QUEUE_ACTIONS_H_
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// A data structure to manage pending queue actions (kernel launches and async
+// allocations).
+//
+// This is needed in order to satisfy queue action dependencies. IREE uses HAL
+// semaphores as the unified mechanism for all synchronization directions,
+// including host to host, host to device, device to device, and device to
+// host. Plus, it allows wait before signal. These flexible capabilities are
+// not all supported by CUevent objects, so we need supporting data structures
+// to implement them on top of CUevent objects--hence this pending queue
+// actions data structure.
+//
+// This buffers pending queue actions and their associated resources. It
+// provides an API to advance the wait list on demand--queue actions are
+// released to the GPU when all their wait semaphores are either signaled past
+// the desired value or have a CUevent already recorded to some CUDA stream
+// to wait on.
+//
+// Thread-safe; multiple threads may enqueue workloads.
+typedef struct iree_hal_cuda2_pending_queue_actions_t
+ iree_hal_cuda2_pending_queue_actions_t;
+
+// Creates a pending actions queue.
+iree_status_t iree_hal_cuda2_pending_queue_actions_create(
+ const iree_hal_cuda2_dynamic_symbols_t* symbols,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_cuda2_pending_queue_actions_t** out_actions);
+
+// Destroys the pending |actions| queue.
+void iree_hal_cuda2_pending_queue_actions_destroy(iree_hal_resource_t* actions);
+
+// Enqueues the given list of |command_buffers| that waits on
+// |wait_semaphore_list| and signals |signal_semaphore_list|.
+iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
+ CUstream dispatch_stream, CUstream callback_stream,
+ iree_hal_cuda2_pending_queue_actions_t* actions,
+ const iree_hal_semaphore_list_t wait_semaphore_list,
+ const iree_hal_semaphore_list_t signal_semaphore_list,
+ iree_host_size_t command_buffer_count,
+ iree_hal_command_buffer_t* const* command_buffers);
+
+// Tries to scan the pending actions and release ready ones to the GPU.
+iree_status_t iree_hal_cuda2_pending_queue_actions_issue(
+ iree_hal_cuda2_pending_queue_actions_t* actions);
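+
+// A minimal end-to-end sketch, assuming |symbols|, |block_pool|, the two
+// CUstream handles, and the semaphore lists come from the owning device
+// implementation:
+//
+//   iree_hal_cuda2_pending_queue_actions_t* actions = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_pending_queue_actions_create(
+//       symbols, block_pool, host_allocator, &actions));
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_cuda2_pending_queue_actions_enqueue_execution(
+//           dispatch_stream, callback_stream, actions, wait_semaphore_list,
+//           signal_semaphore_list, command_buffer_count, command_buffers));
+//   // Releases ready actions; ones with unsatisfied waits stay queued until
+//   // a later issue call (e.g., triggered by a semaphore signal).
+//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_pending_queue_actions_issue(actions));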
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // EXPERIMENTAL_CUDA2_PENDING_QUEUE_ACTIONS_H_
diff --git a/experimental/cuda2/timepoint_pool.c b/experimental/cuda2/timepoint_pool.c
new file mode 100644
index 0000000..a8ffa69
--- /dev/null
+++ b/experimental/cuda2/timepoint_pool.c
@@ -0,0 +1,354 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/timepoint_pool.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <string.h>
+
+#include "experimental/cuda2/cuda_dynamic_symbols.h"
+#include "experimental/cuda2/cuda_status_util.h"
+#include "experimental/cuda2/event_pool.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/internal/event_pool.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/hal/api.h"
+#include "iree/hal/utils/semaphore_base.c"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_timepoint_t
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_cuda2_timepoint_allocate(
+ iree_hal_cuda2_timepoint_pool_t* pool, iree_allocator_t host_allocator,
+ iree_hal_cuda2_timepoint_t** out_timepoint) {
+ IREE_ASSERT_ARGUMENT(pool);
+ IREE_ASSERT_ARGUMENT(out_timepoint);
+ *out_timepoint = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_timepoint_t* timepoint = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, sizeof(*timepoint),
+ (void**)&timepoint));
+ // iree_allocator_malloc zeros out the whole struct.
+ timepoint->host_allocator = host_allocator;
+ timepoint->pool = pool;
+
+ *out_timepoint = timepoint;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Clears all data fields in the given |timepoint| except the original host
+// allocator and owning pool.
+static void iree_hal_cuda2_timepoint_clear(
+ iree_hal_cuda2_timepoint_t* timepoint) {
+ iree_allocator_t host_allocator = timepoint->host_allocator;
+ iree_hal_cuda2_timepoint_pool_t* pool = timepoint->pool;
+ memset(timepoint, 0, sizeof(*timepoint));
+ timepoint->host_allocator = host_allocator;
+ timepoint->pool = pool;
+}
+
+static void iree_hal_cuda2_timepoint_free(
+ iree_hal_cuda2_timepoint_t* timepoint) {
+ iree_allocator_t host_allocator = timepoint->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_ASSERT(timepoint->kind == IREE_HAL_CUDA_TIMEPOINT_KIND_NONE);
+ iree_allocator_free(host_allocator, timepoint);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_timepoint_pool_t
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_cuda2_timepoint_pool_t {
+ // The allocator used to create the timepoint pool.
+ iree_allocator_t host_allocator;
+
+ // The pool to acquire host events.
+ iree_event_pool_t* host_event_pool;
+ // The pool to acquire device events. Internally synchronized.
+ iree_hal_cuda2_event_pool_t* device_event_pool;
+
+ // Note that the above pools are internally synchronized, so we don't and
+ // shouldn't use the following mutex to guard access to them.
+
+ // Guards timepoint-related fields in this pool. We don't expect a
+ // performant program to frequently allocate timepoints for synchronization
+ // purposes; the traffic to this pool should be low. So it should be fine to
+ // use a mutex to guard here.
+ iree_slim_mutex_t timepoint_mutex;
+
+ // Maximum number of timepoint objects that will be maintained in the pool.
+ // More timepoints may be allocated at any time, but they will be disposed
+ // directly when they are no longer needed.
+ iree_host_size_t available_capacity IREE_GUARDED_BY(timepoint_mutex);
+ // Total number of currently available timepoint objects.
+ iree_host_size_t available_count IREE_GUARDED_BY(timepoint_mutex);
+ // The list of available_count timepoint objects.
+ iree_hal_cuda2_timepoint_t* available_list[] IREE_GUARDED_BY(timepoint_mutex);
+};
+// + Additional inline allocation for holding timepoints up to the capacity.
+
+iree_status_t iree_hal_cuda2_timepoint_pool_allocate(
+ iree_event_pool_t* host_event_pool,
+ iree_hal_cuda2_event_pool_t* device_event_pool,
+ iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+ iree_hal_cuda2_timepoint_pool_t** out_timepoint_pool) {
+ IREE_ASSERT_ARGUMENT(host_event_pool);
+ IREE_ASSERT_ARGUMENT(device_event_pool);
+ IREE_ASSERT_ARGUMENT(out_timepoint_pool);
+ *out_timepoint_pool = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
+ iree_host_size_t total_size =
+ sizeof(*timepoint_pool) +
+ available_capacity * sizeof(*timepoint_pool->available_list);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, total_size,
+ (void**)&timepoint_pool));
+ timepoint_pool->host_allocator = host_allocator;
+ timepoint_pool->host_event_pool = host_event_pool;
+ timepoint_pool->device_event_pool = device_event_pool;
+
+ iree_slim_mutex_initialize(&timepoint_pool->timepoint_mutex);
+ timepoint_pool->available_capacity = available_capacity;
+ timepoint_pool->available_count = 0;
+
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < available_capacity; ++i) {
+ status = iree_hal_cuda2_timepoint_allocate(
+ timepoint_pool, host_allocator,
+ &timepoint_pool->available_list[timepoint_pool->available_count++]);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ if (iree_status_is_ok(status)) {
+ *out_timepoint_pool = timepoint_pool;
+ } else {
+ iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+void iree_hal_cuda2_timepoint_pool_free(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool) {
+ iree_allocator_t host_allocator = timepoint_pool->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ for (iree_host_size_t i = 0; i < timepoint_pool->available_count; ++i) {
+ iree_hal_cuda2_timepoint_free(timepoint_pool->available_list[i]);
+ }
+ iree_slim_mutex_deinitialize(&timepoint_pool->timepoint_mutex);
+ iree_allocator_free(host_allocator, timepoint_pool);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Acquires |timepoint_count| timepoints from the given |timepoint_pool|.
+// The returned |out_timepoints| need to be further initialized with the
+// proper kind and payload values.
+static iree_status_t iree_hal_cuda2_timepoint_pool_acquire_internal(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints) {
+ IREE_ASSERT_ARGUMENT(timepoint_pool);
+ if (!timepoint_count) return iree_ok_status();
+ IREE_ASSERT_ARGUMENT(out_timepoints);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // We'll try to get what we can from the pool and fall back to initializing
+ // new iree_hal_cuda2_timepoint_t objects.
+ iree_host_size_t remaining_count = timepoint_count;
+
+ // Try first to grab from the pool.
+ iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
+ iree_host_size_t from_pool_count =
+ iree_min(timepoint_pool->available_count, timepoint_count);
+ if (from_pool_count > 0) {
+ iree_host_size_t pool_base_index =
+ timepoint_pool->available_count - from_pool_count;
+ memcpy(out_timepoints, &timepoint_pool->available_list[pool_base_index],
+ from_pool_count * sizeof(*timepoint_pool->available_list));
+ timepoint_pool->available_count -= from_pool_count;
+ remaining_count -= from_pool_count;
+ }
+ iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);
+
+ // Allocate the rest of the timepoints.
+ if (remaining_count > 0) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-acquire");
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < remaining_count; ++i) {
+ status = iree_hal_cuda2_timepoint_allocate(
+ timepoint_pool, timepoint_pool->host_allocator,
+ &out_timepoints[from_pool_count + i]);
+ if (!iree_status_is_ok(status)) {
+ // Must release all timepoints we've acquired so far.
+ iree_hal_cuda2_timepoint_pool_release(
+ timepoint_pool, from_pool_count + i, out_timepoints);
+ IREE_TRACE_ZONE_END(z1);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+ }
+ IREE_TRACE_ZONE_END(z1);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_host_wait(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Acquire host events to wrap into timepoints. This should happen before
+ // acquiring the timepoints to avoid nested locks.
+ iree_event_t* host_events = iree_alloca(
+ timepoint_count * sizeof((*out_timepoints)->timepoint.host_wait));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_event_pool_acquire(timepoint_pool->host_event_pool,
+ timepoint_count, host_events));
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
+ timepoint_pool, timepoint_count, out_timepoints));
+ for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+ out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT;
+ out_timepoints[i]->timepoint.host_wait = host_events[i];
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_signal(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Acquire device events to wrap into timepoints. This should happen before
+ // acquiring the timepoints to avoid nested locks.
+ iree_hal_cuda2_event_t** device_events = iree_alloca(
+ timepoint_count * sizeof((*out_timepoints)->timepoint.device_signal));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
+ timepoint_count, device_events));
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
+ timepoint_pool, timepoint_count, out_timepoints));
+ for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+ out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL;
+ out_timepoints[i]->timepoint.device_signal = device_events[i];
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_wait(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Acquire device events to wrap into timepoints. This should happen before
+ // acquiring the timepoints to avoid nested locks.
+ iree_hal_cuda2_event_t** device_events = iree_alloca(
+ timepoint_count * sizeof((*out_timepoints)->timepoint.device_wait));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_event_pool_acquire(timepoint_pool->device_event_pool,
+ timepoint_count, device_events));
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_timepoint_pool_acquire_internal(
+ timepoint_pool, timepoint_count, out_timepoints));
+ for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+ out_timepoints[i]->kind = IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT;
+ out_timepoints[i]->timepoint.device_wait = device_events[i];
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+void iree_hal_cuda2_timepoint_pool_release(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count, iree_hal_cuda2_timepoint_t** timepoints) {
+ IREE_ASSERT_ARGUMENT(timepoint_pool);
+ if (!timepoint_count) return;
+ IREE_ASSERT_ARGUMENT(timepoints);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Release the wrapped host/device events. This should happen before
+ // acquiring the timepoint pool's lock given that the host/device event
+ // pools acquire their internal locks too.
+ // TODO: Release in batch to avoid lock overhead from separate calls.
+ for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+ switch (timepoints[i]->kind) {
+ case IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT:
+ iree_event_pool_release(timepoint_pool->host_event_pool, 1,
+ &timepoints[i]->timepoint.host_wait);
+ break;
+ case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL:
+ iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_signal);
+ break;
+ case IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT:
+ iree_hal_cuda2_event_release(timepoints[i]->timepoint.device_wait);
+ break;
+ default:
+ break;
+ }
+ }
+
+ // We'll try to release all we can back to the pool and then deinitialize
+ // the ones that won't fit.
+ iree_host_size_t remaining_count = timepoint_count;
+
+ // Try first to release to the pool.
+ iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
+ iree_host_size_t to_pool_count = iree_min(
+ timepoint_pool->available_capacity - timepoint_pool->available_count,
+ timepoint_count);
+ if (to_pool_count > 0) {
+ for (iree_host_size_t i = 0; i < to_pool_count; ++i) {
+ iree_hal_cuda2_timepoint_clear(timepoints[i]);
+ }
+ iree_host_size_t pool_base_index = timepoint_pool->available_count;
+ memcpy(&timepoint_pool->available_list[pool_base_index], timepoints,
+ to_pool_count * sizeof(*timepoint_pool->available_list));
+ timepoint_pool->available_count += to_pool_count;
+ remaining_count -= to_pool_count;
+ }
+ iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);
+
+ // Clear and deallocate the rest of the timepoints that don't fit back in
+ // the pool. (They still need clearing so the kind assertion in the free
+ // function holds.)
+ if (remaining_count > 0) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-release");
+ for (iree_host_size_t i = 0; i < remaining_count; ++i) {
+ iree_hal_cuda2_timepoint_clear(timepoints[to_pool_count + i]);
+ iree_hal_cuda2_timepoint_free(timepoints[to_pool_count + i]);
+ }
+ IREE_TRACE_ZONE_END(z1);
+ }
+ IREE_TRACE_ZONE_END(z0);
+}
diff --git a/experimental/cuda2/timepoint_pool.h b/experimental/cuda2/timepoint_pool.h
new file mode 100644
index 0000000..b0d71f5
--- /dev/null
+++ b/experimental/cuda2/timepoint_pool.h
@@ -0,0 +1,119 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXPERIMENTAL_CUDA2_TIMEPOINT_POOL_H_
+#define EXPERIMENTAL_CUDA2_TIMEPOINT_POOL_H_
+
+#include "experimental/cuda2/event_pool.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/event_pool.h"
+#include "iree/hal/utils/semaphore_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_timepoint_t
+//===----------------------------------------------------------------------===//
+
+// Forward declaration of the timepoint pool.
+typedef struct iree_hal_cuda2_timepoint_pool_t iree_hal_cuda2_timepoint_pool_t;
+
+// An enum to identify the timepoint kind in iree_hal_cuda2_timepoint_t
+// objects.
+typedef enum iree_hal_cuda2_timepoint_kind_e {
+ // None; for uninitialized timepoint objects.
+ IREE_HAL_CUDA_TIMEPOINT_KIND_NONE = 0,
+ // A timepoint waited by the host.
+ IREE_HAL_CUDA_TIMEPOINT_KIND_HOST_WAIT,
+ // A timepoint signaled by the device.
+ IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_SIGNAL,
+ // A timepoint waited by the device.
+ IREE_HAL_CUDA_TIMEPOINT_KIND_DEVICE_WAIT,
+} iree_hal_cuda2_timepoint_kind_t;
+
+// An object that wraps a host iree_event_t or device iree_hal_cuda2_event_t to
+// represent wait/signal of a timepoint on a timeline.
+//
+// iree_hal_cuda2_timepoint_t objects cannot be directly created; they should
+// be acquired from the timepoint pool and released back to the pool once done.
+//
+// Thread-compatible; a timepoint is typically only accessed by one thread.
+typedef struct iree_hal_cuda2_timepoint_t {
+ // Base timepoint structure providing intrusive linked list pointers and
+ // timepoint callback mechanisms.
+ iree_hal_semaphore_timepoint_t base;
+
+ // The allocator used to create the timepoint.
+ iree_allocator_t host_allocator;
+
+ // The timepoint pool that owns this timepoint.
+ iree_hal_cuda2_timepoint_pool_t* pool;
+
+ iree_hal_cuda2_timepoint_kind_t kind;
+ union {
+ iree_event_t host_wait;
+ iree_hal_cuda2_event_t* device_signal;
+ // The device event to wait on. NULL means no device event is available to
+ // wait on for this timepoint at the moment.
+ iree_hal_cuda2_event_t* device_wait;
+ } timepoint;
+} iree_hal_cuda2_timepoint_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda2_timepoint_pool_t
+//===----------------------------------------------------------------------===//
+
+// A simple pool of iree_hal_cuda2_timepoint_t objects to recycle.
+//
+// Thread-safe; multiple threads may acquire and release timepoints from the
+// pool.
+typedef struct iree_hal_cuda2_timepoint_pool_t iree_hal_cuda2_timepoint_pool_t;
+
+// Allocates a new timepoint pool with up to |available_capacity| timepoints.
+//
+// Extra timepoint requests beyond the capacity are directly created and
+// destroyed without pooling.
+iree_status_t iree_hal_cuda2_timepoint_pool_allocate(
+ iree_event_pool_t* host_event_pool,
+ iree_hal_cuda2_event_pool_t* device_event_pool,
+ iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+ iree_hal_cuda2_timepoint_pool_t** out_timepoint_pool);
+
+// Deallocates a timepoint pool and destroys all timepoints.
+//
+// All timepoints that were acquired from the pool must have already been
+// released back to it prior to deallocation.
+void iree_hal_cuda2_timepoint_pool_free(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool);
+
+// Acquires one or more timepoints from the timepoint pool.
+//
+// |out_timepoints| are owned by the caller and must be kept live until the
+// timepoints have been reached, or cancelled by the caller.
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_host_wait(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints);
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_signal(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints);
+iree_status_t iree_hal_cuda2_timepoint_pool_acquire_device_wait(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count,
+ iree_hal_cuda2_timepoint_t** out_timepoints);
+
+// Releases one or more timepoints back to the timepoint pool.
+void iree_hal_cuda2_timepoint_pool_release(
+ iree_hal_cuda2_timepoint_pool_t* timepoint_pool,
+ iree_host_size_t timepoint_count, iree_hal_cuda2_timepoint_t** timepoints);
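+
+// A minimal acquire/release sketch for a single host-wait timepoint; the
+// pool is assumed to already exist and the actual wait is elided:
+//
+//   iree_hal_cuda2_timepoint_t* timepoint = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_timepoint_pool_acquire_host_wait(
+//       timepoint_pool, 1, &timepoint));
+//   // ... wait on timepoint->timepoint.host_wait ...
+//   iree_hal_cuda2_timepoint_pool_release(timepoint_pool, 1, &timepoint);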
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // EXPERIMENTAL_CUDA2_TIMEPOINT_POOL_H_
diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h
index d6de947..9342957 100644
--- a/runtime/src/iree/hal/cts/semaphore_submission_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h
@@ -183,6 +183,25 @@
iree_hal_semaphore_release(signal_semaphore_2);
}
+// TODO: test device -> device synchronization: submit two batches with a
+// semaphore signal -> wait dependency.
+//
+// TODO: test device -> device synchronization: submit multiple batches with
+// multiple later batches waiting on the same signaling from a former batch.
+//
+// TODO: test device -> device synchronization: submit multiple batches with
+// a former batch signaling a value greater than all other batches' (different)
+// wait values.
+
+// TODO: test host + device -> device synchronization: submit two batches
+// with a later batch waiting on both a host and device signal to proceed.
+
+// TODO: test device -> host + device synchronization: submit two batches
+// with a former batch signaling to enable both host and device to proceed.
+
+// TODO: test signaling a larger value before/after enqueuing a wait on a
+// smaller value to the device.
+
} // namespace cts
} // namespace hal
} // namespace iree
diff --git a/runtime/src/iree/hal/cts/semaphore_test.h b/runtime/src/iree/hal/cts/semaphore_test.h
index 5e70639..154491c 100644
--- a/runtime/src/iree/hal/cts/semaphore_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_test.h
@@ -230,6 +230,9 @@
iree_hal_semaphore_release(b2a);
}
+// TODO: test waiting the same value multiple times.
+// TODO: test waiting for a finite amount of time.
+
} // namespace cts
} // namespace hal
} // namespace iree