Adding dyn_cast support to iree_hal_command_buffer_t. This allows for shimming the command buffers at runtime. Future changes will enable this for iree_hal_allocator_t as well and then we can decide if we want to do it for everything (could be used for statistics/reporting on executables, buffers, etc).
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c index db28d91..5f51803 100644 --- a/experimental/rocm/direct_command_buffer.c +++ b/experimental/rocm/direct_command_buffer.c
@@ -94,6 +94,21 @@ IREE_TRACE_ZONE_END(z0); } +bool iree_hal_rocm_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast( + command_buffer, &iree_hal_rocm_direct_command_buffer_vtable); +} + +static void* iree_hal_rocm_direct_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_rocm_direct_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static iree_hal_command_buffer_mode_t iree_hal_rocm_direct_command_buffer_mode( const iree_hal_command_buffer_t* base_command_buffer) { const iree_hal_rocm_direct_command_buffer_t* command_buffer = @@ -359,6 +374,7 @@ const iree_hal_command_buffer_vtable_t iree_hal_rocm_direct_command_buffer_vtable = { .destroy = iree_hal_rocm_direct_command_buffer_destroy, + .dyn_cast = iree_hal_rocm_direct_command_buffer_dyn_cast, .mode = iree_hal_rocm_direct_command_buffer_mode, .allowed_categories = iree_hal_rocm_direct_command_buffer_allowed_categories,
diff --git a/experimental/rocm/direct_command_buffer.h b/experimental/rocm/direct_command_buffer.h index bd665bf..7ce4947 100644 --- a/experimental/rocm/direct_command_buffer.h +++ b/experimental/rocm/direct_command_buffer.h
@@ -37,6 +37,10 @@ iree_hal_queue_affinity_t queue_affinity, iree_hal_command_buffer_t** out_command_buffer); +// Returns true if |command_buffer| is a ROCM command buffer. +bool iree_hal_rocm_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus
diff --git a/iree/hal/command_buffer.c b/iree/hal/command_buffer.c index 90296c6..d137b83 100644 --- a/iree/hal/command_buffer.c +++ b/iree/hal/command_buffer.c
@@ -57,6 +57,13 @@ return status; } +IREE_API_EXPORT void* iree_hal_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + IREE_ASSERT_ARGUMENT(command_buffer); + if (iree_hal_resource_is(command_buffer, vtable)) return command_buffer; + return _VTABLE_DISPATCH(command_buffer, dyn_cast)(command_buffer, vtable); +} + IREE_API_EXPORT iree_hal_command_buffer_mode_t iree_hal_command_buffer_mode(const iree_hal_command_buffer_t* command_buffer) { IREE_ASSERT_ARGUMENT(command_buffer);
diff --git a/iree/hal/command_buffer.h b/iree/hal/command_buffer.h index 82b3deb..79716b6 100644 --- a/iree/hal/command_buffer.h +++ b/iree/hal/command_buffer.h
@@ -261,6 +261,9 @@ IREE_API_EXPORT void iree_hal_command_buffer_release( iree_hal_command_buffer_t* command_buffer); +IREE_API_EXPORT void* iree_hal_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable); + // Returns a bitmask indicating the behavior of the command buffer. IREE_API_EXPORT iree_hal_command_buffer_mode_t iree_hal_command_buffer_mode(const iree_hal_command_buffer_t* command_buffer); @@ -499,6 +502,9 @@ void(IREE_API_PTR* destroy)(iree_hal_command_buffer_t* command_buffer); + void*(IREE_API_PTR* dyn_cast)(iree_hal_command_buffer_t* command_buffer, + const void* vtable); + iree_hal_command_buffer_mode_t(IREE_API_PTR* mode)( const iree_hal_command_buffer_t* command_buffer); iree_hal_command_category_t(IREE_API_PTR* allowed_categories)(
diff --git a/iree/hal/command_buffer_validation.c b/iree/hal/command_buffer_validation.c index 143a7fc..2db1100 100644 --- a/iree/hal/command_buffer_validation.c +++ b/iree/hal/command_buffer_validation.c
@@ -25,9 +25,11 @@ iree_hal_resource_t resource; iree_hal_device_t* device; iree_hal_command_buffer_t* target_command_buffer; + iree_hal_command_buffer_mode_t mode; iree_hal_command_category_t allowed_categories; bool is_recording; + int32_t debug_group_depth; // TODO(benvanik): current executable layout/descriptor set layout info. // TODO(benvanik): valid push constant bit ranges. } iree_hal_validating_command_buffer_t; @@ -35,6 +37,19 @@ static const iree_hal_command_buffer_vtable_t iree_hal_validating_command_buffer_vtable; +static iree_hal_validating_command_buffer_t* +iree_hal_validating_command_buffer_cast(iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_validating_command_buffer_vtable); + return (iree_hal_validating_command_buffer_t*)base_value; +} + +static const iree_hal_validating_command_buffer_t* +iree_hal_validating_command_buffer_const_cast( + const iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_validating_command_buffer_vtable); + return (const iree_hal_validating_command_buffer_t*)base_value; +} + // Returns success iff the queue supports the given command categories. static iree_status_t iree_hal_command_buffer_validate_categories( const iree_hal_validating_command_buffer_t* command_buffer, @@ -104,6 +119,8 @@ iree_hal_device_retain(command_buffer->device); command_buffer->target_command_buffer = target_command_buffer; iree_hal_command_buffer_retain(command_buffer->target_command_buffer); + command_buffer->mode = + iree_hal_command_buffer_mode(command_buffer->target_command_buffer); command_buffer->allowed_categories = iree_hal_command_buffer_allowed_categories( command_buffer->target_command_buffer); @@ -120,7 +137,7 @@ iree_hal_command_buffer_t* base_command_buffer) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); iree_allocator_t host_allocator = iree_hal_device_host_allocator(command_buffer->device); iree_hal_command_buffer_release(command_buffer->target_command_buffer); @@ -129,18 +146,37 @@ IREE_TRACE_ZONE_END(z0); } +static void* iree_hal_validating_command_buffer_dyn_cast( + iree_hal_command_buffer_t* base_command_buffer, const void* vtable) { + if (vtable == &iree_hal_validating_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(base_command_buffer, vtable); + return base_command_buffer; + } + iree_hal_validating_command_buffer_t* command_buffer = + iree_hal_validating_command_buffer_cast(base_command_buffer); + return iree_hal_command_buffer_dyn_cast(command_buffer->target_command_buffer, + vtable); +} + +static iree_hal_command_buffer_mode_t iree_hal_validating_command_buffer_mode( + const iree_hal_command_buffer_t* base_command_buffer) { + const iree_hal_validating_command_buffer_t* command_buffer = + iree_hal_validating_command_buffer_const_cast(base_command_buffer); + return command_buffer->mode; +} + static iree_hal_command_category_t iree_hal_validating_command_buffer_allowed_categories( const iree_hal_command_buffer_t* base_command_buffer) { - iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + const iree_hal_validating_command_buffer_t* command_buffer = + iree_hal_validating_command_buffer_const_cast(base_command_buffer); return command_buffer->allowed_categories; } static iree_status_t iree_hal_validating_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); if (command_buffer->is_recording) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, @@ -154,8 +190,13 @@ static iree_status_t iree_hal_validating_command_buffer_end( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); + if (command_buffer->debug_group_depth != 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "unbalanced debug group depth (expected 0, is %d)", + command_buffer->debug_group_depth); + } if (!command_buffer->is_recording) { return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "command buffer is not in a recording state"); @@ -165,6 +206,25 @@ return iree_hal_command_buffer_end(command_buffer->target_command_buffer); } +static void iree_hal_validating_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + iree_hal_validating_command_buffer_t* command_buffer = + iree_hal_validating_command_buffer_cast(base_command_buffer); + iree_hal_command_buffer_begin_debug_group( + command_buffer->target_command_buffer, label, label_color, location); +} + +static void iree_hal_validating_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_validating_command_buffer_t* command_buffer = + iree_hal_validating_command_buffer_cast(base_command_buffer); + --command_buffer->debug_group_depth; + iree_hal_command_buffer_end_debug_group( + command_buffer->target_command_buffer); +} + static iree_status_t iree_hal_validating_command_buffer_execution_barrier( iree_hal_command_buffer_t* base_command_buffer, iree_hal_execution_stage_t source_stage_mask, @@ -175,7 +235,7 @@ iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_ANY)); @@ -192,7 +252,7 @@ iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -207,7 +267,7 @@ iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -228,7 +288,7 @@ iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -244,7 +304,7 @@ static iree_status_t iree_hal_validating_command_buffer_discard_buffer( iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); @@ -263,7 +323,7 @@ iree_device_size_t length, const void* pattern, iree_host_size_t pattern_length) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); @@ -312,7 +372,7 @@ iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); @@ -344,7 +404,7 @@ iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); @@ -410,7 +470,7 @@ iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, const void* values, iree_host_size_t values_length) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -434,7 +494,7 @@ iree_host_size_t binding_count, const iree_hal_descriptor_set_binding_t* bindings) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -455,7 +515,7 @@ iree_host_size_t dynamic_offset_count, const iree_device_size_t* dynamic_offsets) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -473,7 +533,7 @@ iree_hal_executable_t* executable, int32_t entry_point, uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -491,7 +551,7 @@ iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset) { iree_hal_validating_command_buffer_t* command_buffer = - (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_hal_validating_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); @@ -523,10 +583,15 @@ static const iree_hal_command_buffer_vtable_t iree_hal_validating_command_buffer_vtable = { .destroy = iree_hal_validating_command_buffer_destroy, + .dyn_cast = iree_hal_validating_command_buffer_dyn_cast, + .mode = iree_hal_validating_command_buffer_mode, .allowed_categories = iree_hal_validating_command_buffer_allowed_categories, .begin = iree_hal_validating_command_buffer_begin, .end = iree_hal_validating_command_buffer_end, + .begin_debug_group = + iree_hal_validating_command_buffer_begin_debug_group, + .end_debug_group = iree_hal_validating_command_buffer_end_debug_group, .execution_barrier = iree_hal_validating_command_buffer_execution_barrier, .signal_event = iree_hal_validating_command_buffer_signal_event,
diff --git a/iree/hal/cuda/api.h b/iree/hal/cuda/api.h index 0a7ec62..50def91 100644 --- a/iree/hal/cuda/api.h +++ b/iree/hal/cuda/api.h
@@ -16,6 +16,14 @@ extern "C" { #endif // __cplusplus +// Defines how command buffers are recorded and executed. +typedef enum iree_hal_cuda_command_buffer_mode_e { + // Command buffers are recorded into CUDA graphs. + IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH = 0, + // Command buffers are directly issued against a CUDA stream. + IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM = 1, +} iree_hal_cuda_command_buffer_mode_t; + // Parameters configuring an iree_hal_cuda_device_t. // Must be initialized with iree_hal_cuda_device_params_initialize prior to use. typedef struct iree_hal_cuda_device_params_t { @@ -29,8 +37,8 @@ // transient allocations while also increasing memory consumption. iree_host_size_t arena_block_size; - // Switch for using deferred command buffer or default graph command buffer - bool use_deferred_submission; + // Specifies how command buffers are recorded and executed. + iree_hal_cuda_command_buffer_mode_t command_buffer_mode; } iree_hal_cuda_device_params_t; // Initializes |out_params| to default values.
diff --git a/iree/hal/cuda/cuda_device.c b/iree/hal/cuda/cuda_device.c index ca38db7..338af6c 100644 --- a/iree/hal/cuda/cuda_device.c +++ b/iree/hal/cuda/cuda_device.c
@@ -48,8 +48,11 @@ iree_hal_cuda_context_wrapper_t context_wrapper; iree_hal_allocator_t* device_allocator; - // Switch for using deferred command buffer or default graph command buffer - bool use_deferred_submission; + // The command buffer type that should be used when recording commands. + iree_hal_cuda_command_buffer_mode_t command_buffer_mode; + // Cache of the direct stream command buffer initialized when in stream mode. + // TODO: have one cached per stream once there are multiple streams. + iree_hal_command_buffer_t* stream_command_buffer; } iree_hal_cuda_device_t; extern const iree_hal_device_vtable_t iree_hal_cuda_device_vtable; @@ -64,7 +67,7 @@ iree_hal_cuda_device_params_t* out_params) { out_params->arena_block_size = 32 * 1024; out_params->queue_count = 8; - out_params->use_deferred_submission = false; + out_params->command_buffer_mode = IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH; } static iree_status_t iree_hal_cuda_device_check_params( @@ -80,25 +83,6 @@ return iree_ok_status(); } -static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) { - iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); - IREE_TRACE_ZONE_BEGIN(z0); - - // There should be no more buffers live that use the allocator. - iree_hal_allocator_release(device->device_allocator); - CUDA_IGNORE_ERROR(device->context_wrapper.syms, - cuStreamDestroy(device->stream)); - - iree_arena_block_pool_deinitialize(&device->block_pool); - // Finally, destroy the device. - iree_hal_driver_release(device->driver); - - iree_allocator_free(host_allocator, device); - - IREE_TRACE_ZONE_END(z0); -} - static iree_status_t iree_hal_cuda_device_create_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, const iree_hal_cuda_device_params_t* params, CUdevice cu_device, @@ -122,9 +106,20 @@ iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, &device->block_pool); device->context_wrapper.syms = syms; - device->use_deferred_submission = params->use_deferred_submission; + iree_status_t status = iree_hal_cuda_allocator_create( &device->context_wrapper, cu_device, stream, &device->device_allocator); + + device->command_buffer_mode = params->command_buffer_mode; + if (iree_status_is_ok(status) && + device->command_buffer_mode == IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM) { + status = iree_hal_cuda_stream_command_buffer_create( + &device->context_wrapper, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_ANY, device->stream, + &device->stream_command_buffer); + } + if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -164,6 +159,27 @@ return status; } +static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) { + iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + + // There should be no more buffers live that use the allocator. + iree_hal_command_buffer_release(device->stream_command_buffer); + iree_hal_allocator_release(device->device_allocator); + CUDA_IGNORE_ERROR(device->context_wrapper.syms, + cuStreamDestroy(device->stream)); + + iree_arena_block_pool_deinitialize(&device->block_pool); + + // Finally, destroy the device. + iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + static iree_string_view_t iree_hal_cuda_device_id( iree_hal_device_t* base_device) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); @@ -209,14 +225,19 @@ iree_hal_queue_affinity_t queue_affinity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - if (device->use_deferred_submission) { - return iree_hal_deferred_command_buffer_create( - mode, command_categories, &device->block_pool, - iree_hal_device_host_allocator(base_device), out_command_buffer); + switch (device->command_buffer_mode) { + case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH: + return iree_hal_cuda_graph_command_buffer_create( + &device->context_wrapper, mode, command_categories, queue_affinity, + out_command_buffer); + case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM: + return iree_hal_deferred_command_buffer_create( + mode, command_categories, &device->block_pool, + iree_hal_device_host_allocator(base_device), out_command_buffer); + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid command buffer mode"); } - return iree_hal_cuda_graph_command_buffer_create( - &device->context_wrapper, mode, command_categories, queue_affinity, - out_command_buffer); } static iree_status_t iree_hal_cuda_device_create_descriptor_set( @@ -280,29 +301,18 @@ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - if (device->use_deferred_submission) { - iree_hal_command_buffer_t* stream_command_buffer; - iree_status_t status = iree_hal_cuda_stream_command_buffer_create( - &device->context_wrapper, - IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, command_categories, - device->stream, &stream_command_buffer); - if (iree_status_is_ok(status)) { - for (int i = 0; i < batch_count; i++) { - for (int j = 0; j < batches[i].command_buffer_count; j++) { - iree_hal_deferred_command_buffer_apply(batches[i].command_buffers[j], - stream_command_buffer); - } - } - } - iree_hal_command_buffer_release(stream_command_buffer); - } else { - for (int i = 0; i < batch_count; i++) { - for (int j = 0; j < batches[i].command_buffer_count; j++) { + for (int i = 0; i < batch_count; i++) { + for (int j = 0; j < batches[i].command_buffer_count; j++) { + iree_hal_command_buffer_t* command_buffer = batches[i].command_buffers[j]; + if (iree_hal_cuda_graph_command_buffer_isa(command_buffer)) { CUgraphExec exec = iree_hal_cuda_graph_command_buffer_exec( batches[i].command_buffers[j]); CUDA_RETURN_IF_ERROR(device->context_wrapper.syms, cuGraphLaunch(exec, device->stream), "cuGraphLaunch"); + } else { + IREE_RETURN_IF_ERROR(iree_hal_deferred_command_buffer_apply( + batches[i].command_buffers[j], device->stream_command_buffer)); } } }
diff --git a/iree/hal/cuda/graph_command_buffer.c b/iree/hal/cuda/graph_command_buffer.c index 19ddb12..b42dfef 100644 --- a/iree/hal/cuda/graph_command_buffer.c +++ b/iree/hal/cuda/graph_command_buffer.c
@@ -121,6 +121,21 @@ return command_buffer->exec; } +bool iree_hal_cuda_graph_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast( + command_buffer, &iree_hal_cuda_graph_command_buffer_vtable); +} + +static void* iree_hal_cuda_graph_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_cuda_graph_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static iree_hal_command_buffer_mode_t iree_hal_cuda_graph_command_buffer_mode( const iree_hal_command_buffer_t* base_command_buffer) { const iree_hal_cuda_graph_command_buffer_t* command_buffer = @@ -410,7 +425,6 @@ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { iree_hal_cuda_graph_command_buffer_t* command_buffer = iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); - iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); int32_t block_size_x, block_size_y, block_size_z; IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_block_size( @@ -447,15 +461,18 @@ } CUgraphExec iree_hal_cuda_graph_command_buffer_exec( - const iree_hal_command_buffer_t* base_command_buffer) { - const iree_hal_cuda_graph_command_buffer_t* command_buffer = - (const iree_hal_cuda_graph_command_buffer_t*)(base_command_buffer); + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_cuda_graph_command_buffer_t* command_buffer = + (iree_hal_cuda_graph_command_buffer_t*)iree_hal_command_buffer_dyn_cast( + base_command_buffer, &iree_hal_cuda_graph_command_buffer_vtable); + IREE_ASSERT_TRUE(command_buffer); return command_buffer->exec; } const iree_hal_command_buffer_vtable_t iree_hal_cuda_graph_command_buffer_vtable = { .destroy = iree_hal_cuda_graph_command_buffer_destroy, + .dyn_cast = iree_hal_cuda_graph_command_buffer_dyn_cast, .mode = iree_hal_cuda_graph_command_buffer_mode, .allowed_categories = iree_hal_cuda_graph_command_buffer_allowed_categories,
diff --git a/iree/hal/cuda/graph_command_buffer.h b/iree/hal/cuda/graph_command_buffer.h index c50ccf9..eb2ed8d 100644 --- a/iree/hal/cuda/graph_command_buffer.h +++ b/iree/hal/cuda/graph_command_buffer.h
@@ -25,9 +25,13 @@ iree_hal_queue_affinity_t queue_affinity, iree_hal_command_buffer_t** out_command_buffer); +// Returns true if |command_buffer| is a CUDA graph-based command buffer. +bool iree_hal_cuda_graph_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + // Returns the native cuda graph associated to the command buffer. CUgraphExec iree_hal_cuda_graph_command_buffer_exec( - const iree_hal_command_buffer_t* command_buffer); + iree_hal_command_buffer_t* command_buffer); #ifdef __cplusplus } // extern "C"
diff --git a/iree/hal/cuda/registration/driver_module.c b/iree/hal/cuda/registration/driver_module.c index 0a51ed9..fb489ca 100644 --- a/iree/hal/cuda/registration/driver_module.c +++ b/iree/hal/cuda/registration/driver_module.c
@@ -41,15 +41,14 @@ driver_id); } IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_cuda_device_params_t default_params; iree_hal_cuda_device_params_initialize(&default_params); - // TODO(jinchen62): set up default_params.use_deferred_submission by flag - // When we expose more than one driver (different cuda versions, etc) we - // can name them here: - iree_string_view_t identifier = iree_make_cstring_view("cuda"); iree_hal_cuda_driver_options_t driver_options; iree_hal_cuda_driver_options_initialize(&driver_options); + + iree_string_view_t identifier = iree_make_cstring_view("cuda"); iree_status_t status = iree_hal_cuda_driver_create( identifier, &default_params, &driver_options, allocator, out_driver); IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/cuda/stream_command_buffer.c b/iree/hal/cuda/stream_command_buffer.c index 0927320..995eadd 100644 --- a/iree/hal/cuda/stream_command_buffer.c +++ b/iree/hal/cuda/stream_command_buffer.c
@@ -83,6 +83,21 @@ IREE_TRACE_ZONE_END(z0); } +bool iree_hal_cuda_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast( + command_buffer, &iree_hal_cuda_stream_command_buffer_vtable); +} + +static void* iree_hal_cuda_stream_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_cuda_stream_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static iree_hal_command_buffer_mode_t iree_hal_cuda_stream_command_buffer_mode( const iree_hal_command_buffer_t* base_command_buffer) { const iree_hal_cuda_stream_command_buffer_t* command_buffer = @@ -328,6 +343,7 @@ const iree_hal_command_buffer_vtable_t iree_hal_cuda_stream_command_buffer_vtable = { .destroy = iree_hal_cuda_stream_command_buffer_destroy, + .dyn_cast = iree_hal_cuda_stream_command_buffer_dyn_cast, .mode = iree_hal_cuda_stream_command_buffer_mode, .allowed_categories = iree_hal_cuda_stream_command_buffer_allowed_categories,
diff --git a/iree/hal/cuda/stream_command_buffer.h b/iree/hal/cuda/stream_command_buffer.h index b4b901a..7e7b8b5 100644 --- a/iree/hal/cuda/stream_command_buffer.h +++ b/iree/hal/cuda/stream_command_buffer.h
@@ -28,6 +28,10 @@ iree_hal_command_category_t command_categories, CUstream stream, iree_hal_command_buffer_t **out_command_buffer); +// Returns true if |command_buffer| is a CUDA stream-based command buffer. +bool iree_hal_cuda_stream_command_buffer_isa( + iree_hal_command_buffer_t *command_buffer); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus
diff --git a/iree/hal/local/inline_command_buffer.c b/iree/hal/local/inline_command_buffer.c index 1dc123e..5e0061e 100644 --- a/iree/hal/local/inline_command_buffer.c +++ b/iree/hal/local/inline_command_buffer.c
@@ -137,6 +137,21 @@ IREE_TRACE_ZONE_END(z0); } +bool iree_hal_inline_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast( + command_buffer, &iree_hal_inline_command_buffer_vtable); +} + +static void* iree_hal_inline_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_inline_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static iree_hal_command_buffer_mode_t iree_hal_inline_command_buffer_mode( const iree_hal_command_buffer_t* base_command_buffer) { return ((const iree_hal_inline_command_buffer_t*)base_command_buffer)->mode; @@ -502,6 +517,7 @@ static const iree_hal_command_buffer_vtable_t iree_hal_inline_command_buffer_vtable = { .destroy = iree_hal_inline_command_buffer_destroy, + .dyn_cast = iree_hal_inline_command_buffer_dyn_cast, .mode = iree_hal_inline_command_buffer_mode, .allowed_categories = iree_hal_inline_command_buffer_allowed_categories, .begin = iree_hal_inline_command_buffer_begin,
diff --git a/iree/hal/local/inline_command_buffer.h b/iree/hal/local/inline_command_buffer.h index 56d77f6..750b674 100644 --- a/iree/hal/local/inline_command_buffer.h +++ b/iree/hal/local/inline_command_buffer.h
@@ -29,6 +29,10 @@ iree_hal_queue_affinity_t queue_affinity, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer); +// Returns true if |command_buffer| is an inline command buffer. +bool iree_hal_inline_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus
diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c index 0eed66f..5645799 100644 --- a/iree/hal/local/task_command_buffer.c +++ b/iree/hal/local/task_command_buffer.c
@@ -172,6 +172,21 @@ IREE_TRACE_ZONE_END(z0); } +bool iree_hal_task_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast(command_buffer, + &iree_hal_task_command_buffer_vtable); +} + +static void* iree_hal_task_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_task_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static iree_hal_command_buffer_mode_t iree_hal_task_command_buffer_mode( const iree_hal_command_buffer_t* base_command_buffer) { return ((const iree_hal_task_command_buffer_t*)base_command_buffer)->mode; @@ -343,7 +358,9 @@ iree_hal_task_queue_state_t* queue_state, iree_task_t* retire_task, iree_arena_allocator_t* arena, iree_task_submission_t* pending_submission) { iree_hal_task_command_buffer_t* command_buffer = - iree_hal_task_command_buffer_cast(base_command_buffer); + iree_hal_command_buffer_dyn_cast(base_command_buffer, + &iree_hal_task_command_buffer_vtable); + IREE_ASSERT_TRUE(command_buffer); // If the command buffer is empty (valid!) then we are a no-op. bool has_root_tasks = !iree_task_list_is_empty(&command_buffer->root_tasks); @@ -958,6 +975,7 @@ static const iree_hal_command_buffer_vtable_t iree_hal_task_command_buffer_vtable = { .destroy = iree_hal_task_command_buffer_destroy, + .dyn_cast = iree_hal_task_command_buffer_dyn_cast, .mode = iree_hal_task_command_buffer_mode, .allowed_categories = iree_hal_task_command_buffer_allowed_categories, .begin = iree_hal_task_command_buffer_begin,
diff --git a/iree/hal/local/task_command_buffer.h b/iree/hal/local/task_command_buffer.h index d3b1a4a..71cc62b 100644 --- a/iree/hal/local/task_command_buffer.h +++ b/iree/hal/local/task_command_buffer.h
@@ -25,6 +25,10 @@ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer); +// Returns true if |command_buffer| is a task system command buffer. +bool iree_hal_task_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + // Issues a recorded command buffer using the serial |queue_state|. // |queue_state| is used to track the synchronization scope of the queue from // prior commands such as signaled events and will be mutated as events are
diff --git a/iree/hal/local/task_queue.c b/iree/hal/local/task_queue.c index e7eb23d..5090e56 100644 --- a/iree/hal/local/task_queue.c +++ b/iree/hal/local/task_queue.c
@@ -200,9 +200,15 @@ // submission was purely for synchronization. if (cmd->command_buffer_count > 0) { for (iree_host_size_t i = 0; i < cmd->command_buffer_count; ++i) { - status = iree_hal_task_command_buffer_issue( - cmd->command_buffers[i], &cmd->queue->state, - cmd->task.header.completion_task, cmd->arena, pending_submission); + if (iree_hal_task_command_buffer_isa(cmd->command_buffers[i])) { + status = iree_hal_task_command_buffer_issue( + cmd->command_buffers[i], &cmd->queue->state, + cmd->task.header.completion_task, cmd->arena, pending_submission); + } else { + status = iree_make_status( + IREE_STATUS_UNIMPLEMENTED, + "unsupported command buffer type for task queue submission"); + } if (IREE_UNLIKELY(!iree_status_is_ok(status))) break; } }
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc index 734fac1..cdce08a 100644 --- a/iree/hal/vulkan/direct_command_buffer.cc +++ b/iree/hal/vulkan/direct_command_buffer.cc
@@ -137,6 +137,21 @@ IREE_IGNORE_ERROR(command_buffer->descriptor_set_group.Reset()); } +bool iree_hal_vulkan_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_command_buffer_dyn_cast( + command_buffer, &iree_hal_vulkan_direct_command_buffer_vtable); +} + +static void* iree_hal_vulkan_direct_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_vulkan_direct_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + static void iree_hal_vulkan_direct_command_buffer_destroy( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = @@ -159,7 +174,10 @@ VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = - iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + (iree_hal_vulkan_direct_command_buffer_t*) + iree_hal_command_buffer_dyn_cast( + base_command_buffer, + &iree_hal_vulkan_direct_command_buffer_vtable); return command_buffer->handle; } @@ -770,6 +788,7 @@ const iree_hal_command_buffer_vtable_t iree_hal_vulkan_direct_command_buffer_vtable = { /*.destroy=*/iree_hal_vulkan_direct_command_buffer_destroy, + /*.dyn_cast=*/iree_hal_vulkan_direct_command_buffer_dyn_cast, /*.mode=*/ iree_hal_vulkan_direct_command_buffer_mode, /*.allowed_categories=*/
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h index 606e859..3bd729b 100644 --- a/iree/hal/vulkan/direct_command_buffer.h +++ b/iree/hal/vulkan/direct_command_buffer.h
@@ -34,6 +34,10 @@ VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle( iree_hal_command_buffer_t* command_buffer); +// Returns true if |command_buffer| is a Vulkan command buffer. +bool iree_hal_vulkan_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus
diff --git a/iree/task/worker.c b/iree/task/worker.c index 6d5bc27..cca9142 100644 --- a/iree/task/worker.c +++ b/iree/task/worker.c
@@ -264,6 +264,9 @@ // in one doesn't bring down the whole system we pretend we executed // something here by falling through. IREE_ASSERT_TRUE(iree_status_is_ok(status)); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + } iree_status_ignore(status); IREE_TRACE_ZONE_END(z0);