[cuda] Collect tracing events after command buffer completion (#16158)
Now that we have proper async execution in the CUDA HAL driver, command
buffers may not execute immediately after being enqueued, so we should not
collect the tracing events at enqueue time. Instead, we should collect them
in a deferred, asynchronous manner once we know the command buffers have
completed.
diff --git a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c
index ac519a2..1d0a74a 100644
--- a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c
@@ -742,6 +742,11 @@
return loop_status;
}
+// Cleanup callback adapter for queue execution actions: treats |user_data| as
+// a tracing context and collects its tracing events.
+static void iree_hal_cuda2_device_collect_tracing_context(void* user_data) {
+  iree_hal_cuda2_tracing_context_t* tracing_context =
+      (iree_hal_cuda2_tracing_context_t*)user_data;
+  iree_hal_cuda2_tracing_context_collect(tracing_context);
+}
+
static iree_status_t iree_hal_cuda2_device_queue_execute(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -753,15 +758,16 @@
iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
base_device, device->dispatch_cu_stream, device->callback_cu_stream,
- device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
- command_buffer_count, command_buffers);
+ device->pending_queue_actions,
+ iree_hal_cuda2_device_collect_tracing_context, device->tracing_context,
+ wait_semaphore_list, signal_semaphore_list, command_buffer_count,
+ command_buffers);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda2_pending_queue_actions_issue(
device->pending_queue_actions);
}
- iree_hal_cuda2_tracing_context_collect(device->tracing_context);
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c
index 4886ebc..ab76728 100644
--- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c
@@ -49,6 +49,12 @@
// Retained to make sure it outlives the current action.
iree_hal_cuda2_pending_queue_actions_t* owning_actions;
+ // The callback to run after completing this action and before freeing
+ // all resources.
+ iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback;
+ // User data to pass into the callback.
+ void* callback_user_data;
+
iree_hal_cuda2_queue_action_kind_t kind;
union {
struct {
@@ -403,6 +409,8 @@
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
+ iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback,
+ void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
@@ -417,6 +425,8 @@
(void**)&action));
action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION;
+ action->cleanup_callback = cleanup_callback;
+ action->callback_user_data = callback_user_data;
action->device = device;
action->dispatch_cu_stream = dispatch_stream;
action->callback_cu_stream = callback_stream;
@@ -604,6 +614,8 @@
iree_allocator_t host_allocator = actions->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
+ action->cleanup_callback(action->callback_user_data);
+
iree_hal_resource_set_free(action->resource_set);
iree_hal_cuda2_free_semaphore_list(host_allocator,
&action->wait_semaphore_list);
diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h
index 1484c2b..574d4c3 100644
--- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h
+++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h
@@ -45,11 +45,20 @@
// Destroys the pending |actions| queue.
void iree_hal_cuda2_pending_queue_actions_destroy(iree_hal_resource_t* actions);
+// Callback to execute user code after action completion but before resource
+// releasing. It is invoked with the |user_data| pointer that was supplied when
+// the action was enqueued.
+//
+// Data behind |user_data| must remain alive until the action is released.
+//
+// NOTE(review): the call site visible in pending_queue_actions.c invokes the
+// callback without a NULL check, so callers presumably must always provide a
+// non-NULL callback - confirm.
+typedef void(IREE_API_PTR* iree_hal_cuda2_pending_action_cleanup_callback_t)(
+ void* user_data);
+
// Enqueues the given list of |command_buffers| that waits on
// |wait_semaphore_list| and signals |signal_semaphore_list|.
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
iree_hal_device_t* device, CUstream dispatch_stream,
CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
+ iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback,
+ void* callback_user_data,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,