Create oneshot stream command buffer in pending_queue_actions
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 5ef14f3..407b91a 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -88,10 +88,6 @@
// Optional provider used for creating/configuring collective channels.
iree_hal_channel_provider_t* channel_provider;
-
- // A CUDA stream-based command buffer used to apply deferred command buffers.
- // TODO: have one cached per stream once there are multiple streams.
- iree_hal_command_buffer_t* deferred_command_buffer;
} iree_hal_cuda2_device_t;
static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable;
@@ -199,18 +195,6 @@
host_allocator, &device->device_allocator);
}
- if (iree_status_is_ok(status) &&
- params->command_buffer_mode == IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM) {
- status = iree_hal_cuda2_stream_command_buffer_create(
- (iree_hal_device_t*)device, device->cuda_symbols, device->nccl_symbols,
- device->tracing_context,
- IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION |
- IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED,
- IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0,
- device->dispatch_cu_stream, &device->block_pool, device->host_allocator,
- &device->deferred_command_buffer);
- }
-
if (iree_status_is_ok(status)) {
*out_device = (iree_hal_device_t*)device;
} else {
@@ -329,8 +313,6 @@
iree_hal_cuda2_pending_queue_actions_destroy(
(iree_hal_resource_t*)device->pending_queue_actions);
- iree_hal_command_buffer_release(device->deferred_command_buffer);
-
// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
@@ -536,6 +518,19 @@
params.count, device->host_allocator, out_channel);
}
+// Creates a stream command buffer by forwarding to
+// iree_hal_cuda2_stream_command_buffer_create with this device's CUDA/NCCL
+// symbols, tracing context, dispatch stream, block pool, and host allocator.
+// |mode|, |command_categories|, and |binding_capacity| are passed through
+// unchanged; the result is stored in |out_command_buffer|.
+// NOTE(review): presumably the caller owns the returned reference and must
+// release it - confirm against iree_hal_cuda2_stream_command_buffer_create.
+iree_status_t iree_hal_cuda2_device_create_stream_command_buffer(
+ iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_host_size_t binding_capacity,
+ iree_hal_command_buffer_t** out_command_buffer) {
+ iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
+ return iree_hal_cuda2_stream_command_buffer_create(
+ base_device, device->cuda_symbols, device->nccl_symbols,
+ device->tracing_context, mode, command_categories, binding_capacity,
+ device->dispatch_cu_stream, &device->block_pool, device->host_allocator,
+ out_command_buffer);
+}
+
static iree_status_t iree_hal_cuda2_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -759,10 +754,9 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
- device->dispatch_cu_stream, device->callback_cu_stream,
- device->deferred_command_buffer, device->pending_queue_actions,
- wait_semaphore_list, signal_semaphore_list, command_buffer_count,
- command_buffers);
+ base_device, device->dispatch_cu_stream, device->callback_cu_stream,
+ device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
+ command_buffer_count, command_buffers);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda2_pending_queue_actions_issue(
diff --git a/experimental/cuda2/cuda_device.h b/experimental/cuda2/cuda_device.h
index 5def0de..39d86e8 100644
--- a/experimental/cuda2/cuda_device.h
+++ b/experimental/cuda2/cuda_device.h
@@ -25,6 +25,14 @@
const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device,
iree_allocator_t host_allocator, iree_hal_device_t** out_device);
+// Creates a CUDA stream-backed command buffer using resources from the
+// given |base_device|.
+iree_status_t iree_hal_cuda2_device_create_stream_command_buffer(
+ iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_host_size_t binding_capacity,
+ iree_hal_command_buffer_t** out_command_buffer);
+
// Returns the CUDA context bound to the given |device| if it is a CUDA device
// and otherwise returns NULL.
//
diff --git a/experimental/cuda2/pending_queue_actions.c b/experimental/cuda2/pending_queue_actions.c
index e21b70c..bb80d6f 100644
--- a/experimental/cuda2/pending_queue_actions.c
+++ b/experimental/cuda2/pending_queue_actions.c
@@ -8,6 +8,7 @@
#include <stdbool.h>
+#include "experimental/cuda2/cuda_device.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/event_semaphore.h"
@@ -50,16 +51,15 @@
} command_buffers;
} payload;
+ // The device from which to allocate CUDA stream-based command buffers for
+ // applying deferred command buffers.
+ iree_hal_device_t* device;
+
// The stream to launch main GPU workload.
CUstream dispatch_cu_stream;
// The stream to launch CUDA host function callbacks.
CUstream callback_cu_stream;
- // The CUDA stream-based command buffer used to apply deferred in-memory
- // command buffers.
- // Owned by the device; must be issuing to dispatch_cu_stream in the above.
- iree_hal_command_buffer_t* deferred_command_buffer;
-
// Resource set to retain all associated resources by the payload.
iree_hal_resource_set_t* resource_set;
@@ -246,9 +246,8 @@
}
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
- CUstream dispatch_stream, CUstream callback_stream,
- iree_hal_command_buffer_t* deferred_command_buffer,
- iree_hal_cuda2_pending_queue_actions_t* actions,
+ iree_hal_device_t* device, CUstream dispatch_stream,
+ CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
@@ -265,9 +264,9 @@
action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION;
action->payload.command_buffers.count = command_buffer_count;
action->payload.command_buffers.ptr = command_buffers;
+ action->device = device;
action->dispatch_cu_stream = dispatch_stream;
action->callback_cu_stream = callback_stream;
- action->deferred_command_buffer = deferred_command_buffer;
action->events = NULL;
action->event_count = 0;
action->is_pending = true;
@@ -381,9 +380,27 @@
z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream),
"cuGraphLaunch");
} else {
+ // Deferred command buffers are replayed into a fresh one-shot stream
+ // command buffer created from the action's device.
+ iree_hal_command_buffer_t* stream_command_buffer = NULL;
+ iree_hal_command_buffer_mode_t mode =
+ IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+ IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION |
+ IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda2_device_create_stream_command_buffer(
+ action->device, mode, IREE_HAL_COMMAND_CATEGORY_ANY,
+ /*binding_capacity=*/0, &stream_command_buffer));
+ // The action's resource set retains what is inserted into it, so drop the
+ // reference returned by creation whether or not the insert succeeded;
+ // otherwise every replayed command buffer leaks one reference.
+ iree_status_t insert_status = iree_hal_resource_set_insert(
+ action->resource_set, 1, &stream_command_buffer);
+ iree_hal_command_buffer_release(stream_command_buffer);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, insert_status);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_deferred_command_buffer_apply(
- command_buffer, action->deferred_command_buffer,
+ command_buffer, stream_command_buffer,
iree_hal_buffer_binding_table_empty()));
}
}
diff --git a/experimental/cuda2/pending_queue_actions.h b/experimental/cuda2/pending_queue_actions.h
index 590c8e4..036c063 100644
--- a/experimental/cuda2/pending_queue_actions.h
+++ b/experimental/cuda2/pending_queue_actions.h
@@ -48,9 +48,8 @@
// Enqueues the given list of |command_buffers| that waits on
// |wait_semaphore_list| and signals |signal_semaphore_lsit|.
iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution(
- CUstream dispatch_stream, CUstream callback_stream,
- iree_hal_command_buffer_t* deferred_command_buffer,
- iree_hal_cuda2_pending_queue_actions_t* actions,
+ iree_hal_device_t* device, CUstream dispatch_stream,
+ CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,