NFC: Rename cuda to cuda2
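
This mechanically applies the iree_hal_cuda2_ prefix throughout the
experimental CUDA2 stream command buffer; no functional change is intended.
As an illustrative sketch only (not part of this diff, and assuming a valid
command_buffer handle is in scope), callers that type-check a command buffer
against this implementation now spell the public entry point with the cuda2
prefix:

  // Hypothetical caller-side check using the renamed API.
  if (iree_hal_cuda2_stream_command_buffer_isa(command_buffer)) {
    // command_buffer is a stream-based command buffer from experimental/cuda2.
  }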
diff --git a/experimental/cuda2/stream_command_buffer.c b/experimental/cuda2/stream_command_buffer.c
index 3b3d490..55588a1 100644
--- a/experimental/cuda2/stream_command_buffer.c
+++ b/experimental/cuda2/stream_command_buffer.c
@@ -21,8 +21,8 @@
 typedef struct {
   iree_hal_command_buffer_t base;
-  iree_hal_cuda_context_wrapper_t* context;
-  iree_hal_cuda_tracing_context_t* tracing_context;
+  iree_hal_cuda2_context_wrapper_t* context;
+  iree_hal_cuda2_tracing_context_t* tracing_context;
   CUstream stream;
   // Maintains a reference to all resources used within the command buffer.
@@ -42,21 +42,22 @@
   // Keep track of the current set of kernel arguments.
   void* current_descriptor[IREE_HAL_CUDA_MAX_KERNEL_ARG];
   CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_KERNEL_ARG];
-} iree_hal_cuda_stream_command_buffer_t;
+} iree_hal_cuda2_stream_command_buffer_t;
 static const iree_hal_command_buffer_vtable_t
-    iree_hal_cuda_stream_command_buffer_vtable;
+    iree_hal_cuda2_stream_command_buffer_vtable;
-static iree_hal_cuda_stream_command_buffer_t*
-iree_hal_cuda_stream_command_buffer_cast(
+static iree_hal_cuda2_stream_command_buffer_t*
+iree_hal_cuda2_stream_command_buffer_cast(
     iree_hal_command_buffer_t* base_value) {
-  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_stream_command_buffer_vtable);
-  return (iree_hal_cuda_stream_command_buffer_t*)base_value;
+  IREE_HAL_ASSERT_TYPE(base_value,
+                       &iree_hal_cuda2_stream_command_buffer_vtable);
+  return (iree_hal_cuda2_stream_command_buffer_t*)base_value;
 }
-iree_status_t iree_hal_cuda_stream_command_buffer_create(
-    iree_hal_device_t* device, iree_hal_cuda_context_wrapper_t* context,
-    iree_hal_cuda_tracing_context_t* tracing_context,
+iree_status_t iree_hal_cuda2_stream_command_buffer_create(
+    iree_hal_device_t* device, iree_hal_cuda2_context_wrapper_t* context,
+    iree_hal_cuda2_tracing_context_t* tracing_context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
     iree_host_size_t binding_capacity, CUstream stream,
@@ -75,14 +76,14 @@
   IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_cuda_stream_command_buffer_t* command_buffer = NULL;
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer = NULL;
   iree_status_t status = iree_allocator_malloc(
       context->host_allocator, sizeof(*command_buffer),
       (void**)&command_buffer);
   if (iree_status_is_ok(status)) {
     iree_hal_command_buffer_initialize(
         device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
-        binding_capacity, &iree_hal_cuda_stream_command_buffer_vtable,
+        binding_capacity, &iree_hal_cuda2_stream_command_buffer_vtable,
         &command_buffer->base);
     command_buffer->context = context;
     command_buffer->tracing_context = tracing_context;
@@ -106,10 +107,10 @@
   return status;
 }
-static void iree_hal_cuda_stream_command_buffer_destroy(
+static void iree_hal_cuda2_stream_command_buffer_destroy(
     iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_TRACE_ZONE_BEGIN(z0);
   iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch);
@@ -120,17 +121,17 @@
   IREE_TRACE_ZONE_END(z0);
 }
-bool iree_hal_cuda_stream_command_buffer_isa(
+bool iree_hal_cuda2_stream_command_buffer_isa(
     iree_hal_command_buffer_t* command_buffer) {
   return iree_hal_resource_is(&command_buffer->resource,
-                              &iree_hal_cuda_stream_command_buffer_vtable);
+                              &iree_hal_cuda2_stream_command_buffer_vtable);
 }
 // Flushes any pending batched collective operations.
 // Must be called before any other non-collective nodes are added to the graph
 // or a barrier is encountered.
-static iree_status_t iree_hal_cuda_stream_command_buffer_flush_collectives(
-    iree_hal_cuda_stream_command_buffer_t* command_buffer) {
+static iree_status_t iree_hal_cuda2_stream_command_buffer_flush_collectives(
+    iree_hal_cuda2_stream_command_buffer_t* command_buffer) {
   // NOTE: we could move this out into callers by way of an always-inline shim
   // that would make this a single compare against the command buffer state we
   // are likely to access immediately after anyway and keep overheads minimal.
@@ -139,7 +140,7 @@
     return iree_ok_status();
   }
   IREE_TRACE_ZONE_BEGIN(z0);
-  iree_status_t status = iree_hal_cuda_nccl_submit_batch(
+  iree_status_t status = iree_hal_cuda2_nccl_submit_batch(
       command_buffer->context, command_buffer->tracing_context,
       &command_buffer->collective_batch, command_buffer->stream);
   iree_hal_collective_batch_clear(&command_buffer->collective_batch);
@@ -147,28 +148,28 @@
   return status;
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_begin(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_begin(
     iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   (void)command_buffer;
   IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL(
       command_buffer->tracing_context, command_buffer->stream,
       /*file_name=*/NULL, 0,
-      /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_cuda_stream_command_buffer",
-      strlen("iree_hal_cuda_stream_command_buffer"));
+      /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_cuda2_stream_command_buffer",
+      strlen("iree_hal_cuda2_stream_command_buffer"));
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_end(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_end(
    iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // Reset the arena as there should be nothing using it now that we've
   // dispatched all our operations inline.
@@ -193,12 +194,12 @@
   return iree_ok_status();
 }
-static void iree_hal_cuda_stream_command_buffer_begin_debug_group(
+static void iree_hal_cuda2_stream_command_buffer_begin_debug_group(
     iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
     iree_hal_label_color_t label_color,
     const iree_hal_label_location_t* location) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   (void)command_buffer;
   IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL(
@@ -210,10 +211,10 @@
   // TODO: pass along to CUPTI if available.
 }
-static void iree_hal_cuda_stream_command_buffer_end_debug_group(
+static void iree_hal_cuda2_stream_command_buffer_end_debug_group(
     iree_hal_command_buffer_t* base_command_buffer) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   (void)command_buffer;
   // TODO: pass along to CUPTI if available.
@@ -222,7 +223,7 @@
       command_buffer->stream);
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_execution_barrier(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_execution_barrier(
     iree_hal_command_buffer_t* base_command_buffer,
     iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
@@ -231,37 +232,37 @@
     const iree_hal_memory_barrier_t* memory_barriers,
     iree_host_size_t buffer_barrier_count,
     const iree_hal_buffer_barrier_t* buffer_barriers) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // TODO(jinchen62): implement CUDA barrier
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_signal_event(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_signal_event(
     iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
     iree_hal_execution_stage_t source_stage_mask) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // TODO(jinchen62): implement CUDA barrier
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_reset_event(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_reset_event(
     iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
     iree_hal_execution_stage_t source_stage_mask) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // TODO(jinchen62): implement CUDA barrier
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_wait_events(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_wait_events(
     iree_hal_command_buffer_t* base_command_buffer,
     iree_host_size_t event_count, const iree_hal_event_t** events,
     iree_hal_execution_stage_t source_stage_mask,
@@ -270,32 +271,32 @@
     const iree_hal_memory_barrier_t* memory_barriers,
     iree_host_size_t buffer_barrier_count,
     const iree_hal_buffer_barrier_t* buffer_barriers) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // TODO(jinchen62): implement CUDA barrier
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_discard_buffer(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_discard_buffer(
     iree_hal_command_buffer_t* base_command_buffer,
     iree_hal_buffer_t* buffer) {
   // We could mark the memory as invalidated so that if managed CUDA does not
   // try to copy it back to the host.
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_fill_buffer(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_fill_buffer(
     iree_hal_command_buffer_t* base_command_buffer,
     iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
     iree_device_size_t length, const void* pattern,
     iree_host_size_t pattern_length) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
-  CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
+  CUdeviceptr target_device_buffer = iree_hal_cuda2_buffer_device_pointer(
       iree_hal_buffer_allocated_buffer(target_buffer));
   target_offset += iree_hal_buffer_byte_offset(target_buffer);
   CUdeviceptr dst = target_device_buffer + target_offset;
@@ -333,14 +334,14 @@
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_update_buffer(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_update_buffer(
     iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
     iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
     iree_device_size_t target_offset, iree_device_size_t length) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // Allocate scratch space in the arena for the data and copy it in.
   // The update buffer API requires that the command buffer capture the host
@@ -358,7 +359,7 @@
   }
   // Issue the copy using the scratch memory as the source.
-  CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
+  CUdeviceptr target_device_buffer = iree_hal_cuda2_buffer_device_pointer(
       iree_hal_buffer_allocated_buffer(target_buffer));
   CUdeviceptr dst = target_device_buffer +
                     iree_hal_buffer_byte_offset(target_buffer) + target_offset;
@@ -370,20 +371,20 @@
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_copy_buffer(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_copy_buffer(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
    iree_device_size_t length) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
-  CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
+  CUdeviceptr target_device_buffer = iree_hal_cuda2_buffer_device_pointer(
       iree_hal_buffer_allocated_buffer(target_buffer));
   target_offset += iree_hal_buffer_byte_offset(target_buffer);
-  CUdeviceptr source_device_buffer = iree_hal_cuda_buffer_device_pointer(
+  CUdeviceptr source_device_buffer = iree_hal_cuda2_buffer_device_pointer(
       iree_hal_buffer_allocated_buffer(source_buffer));
   source_offset += iree_hal_buffer_byte_offset(source_buffer);
   CUdeviceptr dst = target_device_buffer + target_offset;
@@ -395,24 +396,24 @@
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_collective(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_collective(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
    iree_hal_collective_op_t op, uint32_t param,
    iree_hal_buffer_binding_t send_binding,
    iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   return iree_hal_collective_batch_append(&command_buffer->collective_batch,
                                           channel, op, param, send_binding,
                                           recv_binding, element_count);
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_push_constants(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_push_constants(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
    const void* values, iree_host_size_t values_length) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   iree_host_size_t constant_base_index = offset / sizeof(int32_t);
   for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) {
@@ -427,41 +428,42 @@
 typedef struct {
   uint32_t index;
   uint32_t binding;
-} iree_hal_cuda_binding_mapping_t;
+} iree_hal_cuda2_binding_mapping_t;
 // Helper to sort the binding based on their binding index.
 static int compare_binding_index(const void* a, const void* b) {
-  const iree_hal_cuda_binding_mapping_t buffer_a =
-      *(const iree_hal_cuda_binding_mapping_t*)a;
-  const iree_hal_cuda_binding_mapping_t buffer_b =
-      *(const iree_hal_cuda_binding_mapping_t*)b;
+  const iree_hal_cuda2_binding_mapping_t buffer_a =
+      *(const iree_hal_cuda2_binding_mapping_t*)a;
+  const iree_hal_cuda2_binding_mapping_t buffer_b =
+      *(const iree_hal_cuda2_binding_mapping_t*)b;
   return buffer_a.binding < buffer_b.binding ? -1 : 1;
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_push_descriptor_set(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_push_descriptor_set(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   iree_host_size_t base_binding =
-      iree_hal_cuda_base_binding_index(pipeline_layout, set);
+      iree_hal_cuda2_base_binding_index(pipeline_layout, set);
   // Convention with the compiler side. We map bindings to kernel argument.
   // We compact the bindings to get a dense set of arguments and keep them order
   // based on the binding index.
   // Sort the binding based on the binding index and map the array index to the
   // argument index.
-  iree_hal_cuda_binding_mapping_t binding_used[IREE_HAL_CUDA_MAX_BINDING_COUNT];
+  iree_hal_cuda2_binding_mapping_t
+      binding_used[IREE_HAL_CUDA_MAX_BINDING_COUNT];
   for (iree_host_size_t i = 0; i < binding_count; i++) {
-    iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding};
+    iree_hal_cuda2_binding_mapping_t buffer = {i, bindings[i].binding};
     binding_used[i] = buffer;
   }
   // TODO: remove this sort - it's thankfully small (1-8 on average) but we
   // should be able to avoid it like we do on the CPU side with a bitmap.
-  qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t),
+  qsort(binding_used, binding_count, sizeof(iree_hal_cuda2_binding_mapping_t),
        compare_binding_index);
   assert(binding_count < IREE_HAL_CUDA_MAX_BINDING_COUNT &&
          "binding count larger than the max expected.");
@@ -470,7 +472,7 @@
     iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
     CUdeviceptr device_ptr =
         binding.buffer
-            ? (iree_hal_cuda_buffer_device_pointer(
+            ? (iree_hal_cuda2_buffer_device_pointer(
                   iree_hal_buffer_allocated_buffer(binding.buffer)) +
               iree_hal_buffer_byte_offset(binding.buffer) + binding.offset)
            : 0;
@@ -481,20 +483,20 @@
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_dispatch(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
-  iree_hal_cuda_stream_command_buffer_t* command_buffer =
-      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_hal_cuda2_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer);
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+      iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer));
   // Lookup kernel parameters used for side-channeling additional launch
   // information from the compiler.
-  iree_hal_cuda_kernel_params_t kernel_params;
+  iree_hal_cuda2_kernel_params_t kernel_params;
   IREE_RETURN_IF_ERROR(
-      iree_hal_cuda_native_executable_entry_point_kernel_params(
+      iree_hal_cuda2_native_executable_entry_point_kernel_params(
          executable, entry_point, &kernel_params));
   IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL(
@@ -505,9 +507,9 @@
   // Patch the push constants in the kernel arguments.
   iree_host_size_t num_constants =
-      iree_hal_cuda_pipeline_layout_num_constants(kernel_params.layout);
+      iree_hal_cuda2_pipeline_layout_num_constants(kernel_params.layout);
   iree_host_size_t constant_base_index =
-      iree_hal_cuda_push_constant_index(kernel_params.layout);
+      iree_hal_cuda2_push_constant_index(kernel_params.layout);
   for (iree_host_size_t i = 0; i < num_constants; i++) {
     *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
        command_buffer->push_constant[i];
@@ -528,7 +530,7 @@
   return iree_ok_status();
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_dispatch_indirect(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    iree_hal_buffer_t* workgroups_buffer,
@@ -537,7 +539,7 @@
                          "need cuda implementation of dispatch indirect");
 }
-static iree_status_t iree_hal_cuda_stream_command_buffer_execute_commands(
+static iree_status_t iree_hal_cuda2_stream_command_buffer_execute_commands(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_command_buffer_t* base_commands,
    iree_hal_buffer_binding_table_t binding_table) {
@@ -548,29 +550,29 @@
 }
 static const iree_hal_command_buffer_vtable_t
-    iree_hal_cuda_stream_command_buffer_vtable = {
-        .destroy = iree_hal_cuda_stream_command_buffer_destroy,
-        .begin = iree_hal_cuda_stream_command_buffer_begin,
-        .end = iree_hal_cuda_stream_command_buffer_end,
+    iree_hal_cuda2_stream_command_buffer_vtable = {
+        .destroy = iree_hal_cuda2_stream_command_buffer_destroy,
+        .begin = iree_hal_cuda2_stream_command_buffer_begin,
+        .end = iree_hal_cuda2_stream_command_buffer_end,
         .begin_debug_group =
-            iree_hal_cuda_stream_command_buffer_begin_debug_group,
-        .end_debug_group = iree_hal_cuda_stream_command_buffer_end_debug_group,
+            iree_hal_cuda2_stream_command_buffer_begin_debug_group,
+        .end_debug_group = iree_hal_cuda2_stream_command_buffer_end_debug_group,
         .execution_barrier =
-            iree_hal_cuda_stream_command_buffer_execution_barrier,
-        .signal_event = iree_hal_cuda_stream_command_buffer_signal_event,
-        .reset_event = iree_hal_cuda_stream_command_buffer_reset_event,
-        .wait_events = iree_hal_cuda_stream_command_buffer_wait_events,
-        .discard_buffer = iree_hal_cuda_stream_command_buffer_discard_buffer,
-        .fill_buffer = iree_hal_cuda_stream_command_buffer_fill_buffer,
-        .update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer,
-        .copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer,
-        .collective = iree_hal_cuda_stream_command_buffer_collective,
-        .push_constants = iree_hal_cuda_stream_command_buffer_push_constants,
+            iree_hal_cuda2_stream_command_buffer_execution_barrier,
+        .signal_event = iree_hal_cuda2_stream_command_buffer_signal_event,
+        .reset_event = iree_hal_cuda2_stream_command_buffer_reset_event,
+        .wait_events = iree_hal_cuda2_stream_command_buffer_wait_events,
+        .discard_buffer = iree_hal_cuda2_stream_command_buffer_discard_buffer,
+        .fill_buffer = iree_hal_cuda2_stream_command_buffer_fill_buffer,
+        .update_buffer = iree_hal_cuda2_stream_command_buffer_update_buffer,
+        .copy_buffer = iree_hal_cuda2_stream_command_buffer_copy_buffer,
+        .collective = iree_hal_cuda2_stream_command_buffer_collective,
+        .push_constants = iree_hal_cuda2_stream_command_buffer_push_constants,
        .push_descriptor_set =
-            iree_hal_cuda_stream_command_buffer_push_descriptor_set,
-        .dispatch = iree_hal_cuda_stream_command_buffer_dispatch,
+            iree_hal_cuda2_stream_command_buffer_push_descriptor_set,
+        .dispatch = iree_hal_cuda2_stream_command_buffer_dispatch,
        .dispatch_indirect =
-            iree_hal_cuda_stream_command_buffer_dispatch_indirect,
+            iree_hal_cuda2_stream_command_buffer_dispatch_indirect,
        .execute_commands =
-            iree_hal_cuda_stream_command_buffer_execute_commands,
+            iree_hal_cuda2_stream_command_buffer_execute_commands,
 };
diff --git a/experimental/cuda2/stream_command_buffer.h b/experimental/cuda2/stream_command_buffer.h
index 2922c16..dc38cdf 100644
--- a/experimental/cuda2/stream_command_buffer.h
+++ b/experimental/cuda2/stream_command_buffer.h
@@ -1,11 +1,11 @@
-// Copyright 2021 The IREE Authors
+// Copyright 2023 The IREE Authors
 //
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_HAL_DRIVERS_CUDA_STREAM_COMMAND_BUFFER_H_
-#define IREE_HAL_DRIVERS_CUDA_STREAM_COMMAND_BUFFER_H_
+#ifndef EXPERIMENTAL_CUDA2_STREAM_COMMAND_BUFFER_H_
+#define EXPERIMENTAL_CUDA2_STREAM_COMMAND_BUFFER_H_
 #include "iree/base/internal/arena.h"
 #include "iree/hal/api.h"
@@ -29,9 +29,9 @@
 // perform inline execution. When replaying the scratch data required for things
 // like buffer updates is retained by the source deferred command buffer and as
 // such the |block_pool| and can be NULL to avoid a double copy.
-iree_status_t iree_hal_cuda_stream_command_buffer_create(
-    iree_hal_device_t* device, iree_hal_cuda_context_wrapper_t* context,
-    iree_hal_cuda_tracing_context_t* tracing_context,
+iree_status_t iree_hal_cuda2_stream_command_buffer_create(
+    iree_hal_device_t* device, iree_hal_cuda2_context_wrapper_t* context,
+    iree_hal_cuda2_tracing_context_t* tracing_context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
     iree_host_size_t binding_capacity, CUstream stream,
@@ -39,11 +39,11 @@
     iree_hal_command_buffer_t** out_command_buffer);
 // Returns true if |command_buffer| is a CUDA stream-based command buffer.
-bool iree_hal_cuda_stream_command_buffer_isa(
+bool iree_hal_cuda2_stream_command_buffer_isa(
     iree_hal_command_buffer_t* command_buffer);
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
-#endif  // IREE_HAL_DRIVERS_CUDA_STREAM_COMMAND_BUFFER_H_
+#endif  // EXPERIMENTAL_CUDA2_STREAM_COMMAND_BUFFER_H_