[cuda] Port over CUDA stream-based command buffer impl This commit uses existing CUDA HAL driver's stream command buffer implementation. Improvements include removing context wrapper and adding various Tracy tracking markers. Various places are cleaned up, e.g., barrier/event impl. Inline execution support is dropped due to its incompatibility with semaphores and recorded command buffers. With it now we can enable CUDA stream-based tests.
diff --git a/experimental/cuda2/CMakeLists.txt b/experimental/cuda2/CMakeLists.txt index b1e8a9c..d3033f0 100644 --- a/experimental/cuda2/CMakeLists.txt +++ b/experimental/cuda2/CMakeLists.txt
@@ -42,6 +42,8 @@ "pending_queue_actions.h" "pipeline_layout.c" "pipeline_layout.h" + "stream_command_buffer.c" + "stream_command_buffer.h" "timepoint_pool.c" "timepoint_pool.h" "tracing.c" @@ -56,6 +58,7 @@ iree::hal iree::hal::utils::buffer_transfer iree::hal::utils::collective_batch + iree::hal::utils::deferred_command_buffer iree::hal::utils::file_transfer iree::hal::utils::memory_file iree::hal::utils::resource_set
diff --git a/experimental/cuda2/api.h b/experimental/cuda2/api.h index 62130cb..403b3dc 100644 --- a/experimental/cuda2/api.h +++ b/experimental/cuda2/api.h
@@ -20,6 +20,14 @@ // iree_hal_cuda2_device_t //===----------------------------------------------------------------------===// +// How command buffers are recorded and executed. +typedef enum iree_hal_cuda_command_buffer_mode_e { + // Command buffers are recorded into CUDA graphs. + IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH = 0, + // Command buffers are directly issued against a CUDA stream. + IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM = 1, +} iree_hal_cuda2_command_buffer_mode_t; + // ncclUniqueId exposed without exporting the NCCL headers. typedef struct { char data[128]; @@ -66,6 +74,9 @@ // consumption. iree_host_size_t event_pool_capacity; + // Specifies how command buffers are recorded and executed. + iree_hal_cuda2_command_buffer_mode_t command_buffer_mode; + // Enables tracing of command buffers when IREE tracing is enabled. // May take advantage of additional extensions for more accurate timing or // hardware-specific performance counters.
diff --git a/experimental/cuda2/cts/CMakeLists.txt b/experimental/cuda2/cts/CMakeLists.txt index c6a07fa..e48f470 100644 --- a/experimental/cuda2/cts/CMakeLists.txt +++ b/experimental/cuda2/cts/CMakeLists.txt
@@ -7,6 +7,8 @@ iree_hal_cts_test_suite( DRIVER_NAME cuda2 + VARIANT_SUFFIX + graph DRIVER_REGISTRATION_HDR "experimental/cuda2/registration/driver_module.h" DRIVER_REGISTRATION_FN @@ -15,6 +17,33 @@ "cuda" EXECUTABLE_FORMAT "\"PTXE\"" + ARGS + "--cuda2_use_streams=false" + DEPS + iree::experimental::cuda2::registration + EXCLUDED_TESTS + # HAL event is unimplemented for now. + "event" + LABELS + driver=cuda2 + requires-gpu-nvidia +) + +iree_hal_cts_test_suite( + DRIVER_NAME + cuda2 + VARIANT_SUFFIX + stream + DRIVER_REGISTRATION_HDR + "experimental/cuda2/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_cuda2_driver_module_register" + COMPILER_TARGET_BACKEND + "cuda" + EXECUTABLE_FORMAT + "\"PTXE\"" + ARGS + "--cuda2_use_streams=true" DEPS iree::experimental::cuda2::registration EXCLUDED_TESTS
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c index a5e8788..5ef14f3 100644 --- a/experimental/cuda2/cuda_device.c +++ b/experimental/cuda2/cuda_device.c
@@ -11,7 +11,6 @@ #include <string.h> #include "experimental/cuda2/cuda_allocator.h" -#include "experimental/cuda2/cuda_buffer.h" #include "experimental/cuda2/cuda_dynamic_symbols.h" #include "experimental/cuda2/cuda_status_util.h" #include "experimental/cuda2/event_pool.h" @@ -23,6 +22,7 @@ #include "experimental/cuda2/nop_executable_cache.h" #include "experimental/cuda2/pending_queue_actions.h" #include "experimental/cuda2/pipeline_layout.h" +#include "experimental/cuda2/stream_command_buffer.h" #include "experimental/cuda2/timepoint_pool.h" #include "experimental/cuda2/tracing.h" #include "iree/base/internal/arena.h" @@ -88,6 +88,10 @@ // Optional provider used for creating/configuring collective channels. iree_hal_channel_provider_t* channel_provider; + + // A CUDA stream-based command buffer used to apply deferred command buffers. + // TODO: have one cached per stream once there are multiple streams. + iree_hal_command_buffer_t* deferred_command_buffer; } iree_hal_cuda2_device_t; static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable; @@ -109,6 +113,7 @@ out_params->arena_block_size = 32 * 1024; out_params->event_pool_capacity = 32; out_params->queue_count = 1; + out_params->command_buffer_mode = IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH; out_params->stream_tracing = false; out_params->async_allocations = true; } @@ -194,6 +199,18 @@ host_allocator, &device->device_allocator); } + if (iree_status_is_ok(status) && + params->command_buffer_mode == IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM) { + status = iree_hal_cuda2_stream_command_buffer_create( + (iree_hal_device_t*)device, device->cuda_symbols, device->nccl_symbols, + device->tracing_context, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | + IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, + IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0, + device->dispatch_cu_stream, &device->block_pool, device->host_allocator, + &device->deferred_command_buffer); + } + if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -312,6 +329,8 @@ iree_hal_cuda2_pending_queue_actions_destroy( (iree_hal_resource_t*)device->pending_queue_actions); + iree_hal_command_buffer_release(device->deferred_command_buffer); + // There should be no more buffers live that use the allocator. iree_hal_allocator_release(device->device_allocator); @@ -523,10 +542,22 @@ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); - return iree_hal_cuda2_graph_command_buffer_create( - base_device, device->cuda_symbols, device->cu_context, mode, - command_categories, queue_affinity, binding_capacity, &device->block_pool, - device->host_allocator, out_command_buffer); + + switch (device->params.command_buffer_mode) { + case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH: + return iree_hal_cuda2_graph_command_buffer_create( + base_device, device->cuda_symbols, device->cu_context, mode, + command_categories, queue_affinity, binding_capacity, + &device->block_pool, device->host_allocator, out_command_buffer); + case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM: + return iree_hal_deferred_command_buffer_create( + base_device, mode, command_categories, binding_capacity, + &device->block_pool, iree_hal_device_host_allocator(base_device), + out_command_buffer); + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid command buffer mode"); + } } static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout( @@ -729,8 +760,9 @@ iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution( device->dispatch_cu_stream, device->callback_cu_stream, - device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list, - command_buffer_count, command_buffers); + device->deferred_command_buffer, device->pending_queue_actions, + wait_semaphore_list, signal_semaphore_list, command_buffer_count, + command_buffers); if (iree_status_is_ok(status)) { // Try to advance the pending workload queue. status = iree_hal_cuda2_pending_queue_actions_issue(
diff --git a/experimental/cuda2/graph_command_buffer.h b/experimental/cuda2/graph_command_buffer.h index ad2d4b9..413e46a 100644 --- a/experimental/cuda2/graph_command_buffer.h +++ b/experimental/cuda2/graph_command_buffer.h
@@ -20,7 +20,8 @@ // Creates a command buffer that records into a CUDA graph. // -// NOTE: the |block_pool| must remain live for the lifetime of the command +// |block_pool| will be used by the graph command buffer to retain copies of +// input data until reset. It must remain live for the lifetime of the command // buffers that use it. iree_status_t iree_hal_cuda2_graph_command_buffer_create( iree_hal_device_t* device,
diff --git a/experimental/cuda2/pending_queue_actions.c b/experimental/cuda2/pending_queue_actions.c index b096c78..e21b70c 100644 --- a/experimental/cuda2/pending_queue_actions.c +++ b/experimental/cuda2/pending_queue_actions.c
@@ -16,6 +16,7 @@ #include "iree/base/internal/arena.h" #include "iree/base/internal/synchronization.h" #include "iree/hal/api.h" +#include "iree/hal/utils/deferred_command_buffer.h" #include "iree/hal/utils/resource_set.h" //===----------------------------------------------------------------------===// @@ -54,6 +55,11 @@ // The stream to launch CUDA host function callbacks. CUstream callback_cu_stream; + // The CUDA stream-based command buffer used to apply deferred in-memory + // command buffers. + // Owned by the device; must be issuing to dispatch_cu_stream in the above. + iree_hal_command_buffer_t* deferred_command_buffer; + // Resource set to retain all associated resources by the payload. iree_hal_resource_set_t* resource_set; @@ -241,6 +247,7 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( CUstream dispatch_stream, CUstream callback_stream, + iree_hal_command_buffer_t* deferred_command_buffer, iree_hal_cuda2_pending_queue_actions_t* actions, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, @@ -260,6 +267,7 @@ action->payload.command_buffers.ptr = command_buffers; action->dispatch_cu_stream = dispatch_stream; action->callback_cu_stream = callback_stream; + action->deferred_command_buffer = deferred_command_buffer; action->events = NULL; action->event_count = 0; action->is_pending = true; @@ -364,11 +372,20 @@ // Then launch all command buffers to the dispatch stream. for (iree_host_size_t i = 0; i < action->payload.command_buffers.count; ++i) { - CUgraphExec exec = iree_hal_cuda2_graph_command_buffer_handle( - action->payload.command_buffers.ptr[i]); - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream), - "cuGraphLaunch"); + iree_hal_command_buffer_t* command_buffer = + action->payload.command_buffers.ptr[i]; + if (iree_hal_cuda2_graph_command_buffer_isa(command_buffer)) { + CUgraphExec exec = iree_hal_cuda2_graph_command_buffer_handle( + action->payload.command_buffers.ptr[i]); + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream), + "cuGraphLaunch"); + } else { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_deferred_command_buffer_apply( + command_buffer, action->deferred_command_buffer, + iree_hal_buffer_binding_table_empty())); + } } // Last record CUevent signals in the dispatch stream.
diff --git a/experimental/cuda2/pending_queue_actions.h b/experimental/cuda2/pending_queue_actions.h index efcf7ec..590c8e4 100644 --- a/experimental/cuda2/pending_queue_actions.h +++ b/experimental/cuda2/pending_queue_actions.h
@@ -49,6 +49,7 @@ // |wait_semaphore_list| and signals |signal_semaphore_lsit|. iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( CUstream dispatch_stream, CUstream callback_stream, + iree_hal_command_buffer_t* deferred_command_buffer, iree_hal_cuda2_pending_queue_actions_t* actions, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list,
diff --git a/experimental/cuda2/registration/driver_module.c b/experimental/cuda2/registration/driver_module.c index 9ebf116..166286f 100644 --- a/experimental/cuda2/registration/driver_module.c +++ b/experimental/cuda2/registration/driver_module.c
@@ -14,6 +14,14 @@ #include "iree/base/internal/flags.h" IREE_FLAG( + bool, cuda2_use_streams, false, + "Use CUDA streams (instead of graphs) for executing command buffers."); + +IREE_FLAG(bool, cuda2_allow_inline_execution, false, + "Allow command buffers to execute inline against CUDA streams when\n" + "possible."); + +IREE_FLAG( bool, cuda2_async_allocations, true, "Enables CUDA asynchronous stream-ordered allocations when supported."); @@ -90,6 +98,10 @@ iree_hal_cuda2_device_params_t device_params; iree_hal_cuda2_device_params_initialize(&device_params); + if (FLAG_cuda2_use_streams) { + device_params.command_buffer_mode = + IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM; + } device_params.stream_tracing = FLAG_cuda2_tracing; device_params.async_allocations = FLAG_cuda2_async_allocations;
diff --git a/experimental/cuda2/stream_command_buffer.c b/experimental/cuda2/stream_command_buffer.c index 55588a1..1dec100 100644 --- a/experimental/cuda2/stream_command_buffer.c +++ b/experimental/cuda2/stream_command_buffer.c
@@ -4,14 +4,13 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/hal/drivers/cuda/stream_command_buffer.h" +#include "experimental/cuda2/stream_command_buffer.h" -#include "iree/hal/drivers/cuda/cuda_buffer.h" -#include "iree/hal/drivers/cuda/cuda_event.h" -#include "iree/hal/drivers/cuda/native_executable.h" -#include "iree/hal/drivers/cuda/nccl_channel.h" -#include "iree/hal/drivers/cuda/pipeline_layout.h" -#include "iree/hal/drivers/cuda/status_util.h" +#include "experimental/cuda2/cuda_buffer.h" +#include "experimental/cuda2/cuda_status_util.h" +#include "experimental/cuda2/native_executable.h" +#include "experimental/cuda2/nccl_channel.h" +#include "experimental/cuda2/pipeline_layout.h" #include "iree/hal/utils/collective_batch.h" #include "iree/hal/utils/resource_set.h" @@ -19,14 +18,20 @@ // Kernel arguments contains binding and push constants. #define IREE_HAL_CUDA_MAX_KERNEL_ARG 128 -typedef struct { +typedef struct iree_hal_cuda2_stream_command_buffer_t { iree_hal_command_buffer_t base; - iree_hal_cuda2_context_wrapper_t* context; - iree_hal_cuda2_tracing_context_t* tracing_context; - CUstream stream; + iree_allocator_t host_allocator; - // Maintains a reference to all resources used within the command buffer. - // Reset on each begin. + const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols; + const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols; + + // Per-stream CUDA tracing context. + iree_hal_cuda2_tracing_context_t* tracing_context; + + CUstream cu_stream; + + // A resource set to maintain references to all resources used within the + // command buffer. Reset on each begin. iree_hal_resource_set_t* resource_set; // Staging arena used for host->device transfers. @@ -37,10 +42,10 @@ // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; - int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; + int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - // Keep track of the current set of kernel arguments. - void* current_descriptor[IREE_HAL_CUDA_MAX_KERNEL_ARG]; + // The current set of kernel arguments. + void* current_descriptors[IREE_HAL_CUDA_MAX_KERNEL_ARG]; CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_KERNEL_ARG]; } iree_hal_cuda2_stream_command_buffer_t; @@ -56,15 +61,18 @@ } iree_status_t iree_hal_cuda2_stream_command_buffer_create( - iree_hal_device_t* device, iree_hal_cuda2_context_wrapper_t* context, + iree_hal_device_t* device, + const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, + const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, iree_hal_cuda2_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, CUstream stream, - iree_arena_block_pool_t* block_pool, + iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(cuda_symbols); + IREE_ASSERT_ARGUMENT(nccl_symbols); IREE_ASSERT_ARGUMENT(out_command_buffer); *out_command_buffer = NULL; @@ -77,25 +85,28 @@ IREE_TRACE_ZONE_BEGIN(z0); iree_hal_cuda2_stream_command_buffer_t* command_buffer = NULL; - iree_status_t status = - iree_allocator_malloc(context->host_allocator, sizeof(*command_buffer), - (void**)&command_buffer); - if (iree_status_is_ok(status)) { - iree_hal_command_buffer_initialize( - device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, - binding_capacity, &iree_hal_cuda2_stream_command_buffer_vtable, - &command_buffer->base); - command_buffer->context = context; - command_buffer->tracing_context = tracing_context; - command_buffer->stream = stream; - iree_arena_initialize(block_pool, &command_buffer->arena); - for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) { - command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i]; - } + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), + (void**)&command_buffer)); - status = iree_hal_resource_set_allocate(block_pool, - &command_buffer->resource_set); + iree_hal_command_buffer_initialize( + device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, &iree_hal_cuda2_stream_command_buffer_vtable, + &command_buffer->base); + command_buffer->host_allocator = host_allocator; + command_buffer->cuda_symbols = cuda_symbols; + command_buffer->nccl_symbols = nccl_symbols; + command_buffer->tracing_context = tracing_context; + command_buffer->cu_stream = stream; + iree_arena_initialize(block_pool, &command_buffer->arena); + + for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) { + command_buffer->current_descriptors[i] = &command_buffer->device_ptrs[i]; } + + iree_status_t status = + iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set); + if (iree_status_is_ok(status)) { iree_hal_collective_batch_initialize(&command_buffer->arena, command_buffer->resource_set, @@ -111,12 +122,13 @@ iree_hal_command_buffer_t* base_command_buffer) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); iree_hal_resource_set_free(command_buffer->resource_set); iree_arena_deinitialize(&command_buffer->arena); - iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + iree_allocator_free(host_allocator, command_buffer); IREE_TRACE_ZONE_END(z0); } @@ -141,8 +153,8 @@ } IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_cuda2_nccl_submit_batch( - command_buffer->context, command_buffer->tracing_context, - &command_buffer->collective_batch, command_buffer->stream); + command_buffer->nccl_symbols, command_buffer->tracing_context, + &command_buffer->collective_batch, command_buffer->cu_stream); iree_hal_collective_batch_clear(&command_buffer->collective_batch); IREE_TRACE_ZONE_END(z0); return status; @@ -155,9 +167,9 @@ (void)command_buffer; IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL( - command_buffer->tracing_context, command_buffer->stream, - /*file_name=*/NULL, 0, - /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_cuda2_stream_command_buffer", + command_buffer->tracing_context, command_buffer->cu_stream, + /*file_name=*/NULL, 0, /*line=*/0, /*func_name=*/NULL, 0, + "iree_hal_cuda2_stream_command_buffer", strlen("iree_hal_cuda2_stream_command_buffer")); return iree_ok_status(); @@ -167,8 +179,10 @@ iree_hal_command_buffer_t* base_command_buffer) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_IF_ERROR( + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); // Reset the arena as there should be nothing using it now that we've @@ -182,15 +196,17 @@ iree_arena_reset(&command_buffer->arena); iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); iree_hal_resource_set_free(command_buffer->resource_set); - IREE_RETURN_IF_ERROR(iree_hal_resource_set_allocate( - command_buffer->arena.block_pool, &command_buffer->resource_set)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_allocate(command_buffer->arena.block_pool, + &command_buffer->resource_set)); iree_hal_collective_batch_initialize(&command_buffer->arena, command_buffer->resource_set, &command_buffer->collective_batch); IREE_CUDA_TRACE_ZONE_END(command_buffer->tracing_context, - command_buffer->stream); + command_buffer->cu_stream); + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -203,7 +219,7 @@ (void)command_buffer; IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL( - command_buffer->tracing_context, command_buffer->stream, + command_buffer->tracing_context, command_buffer->cu_stream, location ? location->file.data : NULL, location ? location->file.size : 0, location ? location->line : 0, /*func_name=*/NULL, 0, label.data, label.size); @@ -220,7 +236,7 @@ // TODO: pass along to CUPTI if available. IREE_CUDA_TRACE_ZONE_END(command_buffer->tracing_context, - command_buffer->stream); + command_buffer->cu_stream); } static iree_status_t iree_hal_cuda2_stream_command_buffer_execution_barrier( @@ -234,32 +250,40 @@ const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); + + if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || + iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "barrier involving host not yet supported"); + } + + if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-zero barrier flag not yet supported"); + } + IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_IF_ERROR( iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); - // TODO(jinchen62): implement CUDA barrier + + // Nothing to do for barriers between memory operations or dispatches--CUDA + // stream semantics guarantees execution and memory visibility in program + // order. + + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } static iree_status_t iree_hal_cuda2_stream_command_buffer_signal_event( iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { - iree_hal_cuda2_stream_command_buffer_t* command_buffer = - iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( - iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); - // TODO(jinchen62): implement CUDA barrier - return iree_ok_status(); + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); } static iree_status_t iree_hal_cuda2_stream_command_buffer_reset_event( iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { - iree_hal_cuda2_stream_command_buffer_t* command_buffer = - iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( - iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); - // TODO(jinchen62): implement CUDA barrier - return iree_ok_status(); + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); } static iree_status_t iree_hal_cuda2_stream_command_buffer_wait_events( @@ -271,12 +295,7 @@ const iree_hal_memory_barrier_t* memory_barriers, iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) { - iree_hal_cuda2_stream_command_buffer_t* command_buffer = - iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( - iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); - // TODO(jinchen62): implement CUDA barrier - return iree_ok_status(); + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); } static iree_status_t iree_hal_cuda2_stream_command_buffer_discard_buffer( @@ -293,7 +312,10 @@ iree_host_size_t pattern_length) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); CUdeviceptr target_device_buffer = iree_hal_cuda2_buffer_device_pointer( @@ -301,36 +323,39 @@ target_offset += iree_hal_buffer_byte_offset(target_buffer); CUdeviceptr dst = target_device_buffer + target_offset; size_t num_elements = length / pattern_length; + switch (pattern_length) { case 4: { - CUDA_RETURN_IF_ERROR( - command_buffer->context->syms, + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, cuMemsetD32Async(dst, *(const uint32_t*)(pattern), num_elements, - command_buffer->stream), + command_buffer->cu_stream), "cuMemsetD32Async"); break; } case 2: { - CUDA_RETURN_IF_ERROR( - command_buffer->context->syms, + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, cuMemsetD16Async(dst, *(const uint16_t*)(pattern), num_elements, - command_buffer->stream), + command_buffer->cu_stream), "cuMemsetD16Async"); break; } case 1: { - CUDA_RETURN_IF_ERROR( - command_buffer->context->syms, + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, cuMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, - command_buffer->stream), + command_buffer->cu_stream), "cuMemsetD8Async"); break; } default: + IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INTERNAL, "unsupported fill pattern length"); } + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -340,7 +365,10 @@ iree_device_size_t target_offset, iree_device_size_t length) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); // Allocate scratch space in the arena for the data and copy it in. @@ -352,7 +380,8 @@ const uint8_t* src = (const uint8_t*)source_buffer + source_offset; if (command_buffer->arena.block_pool) { uint8_t* storage = NULL; - IREE_RETURN_IF_ERROR( + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_arena_allocate(&command_buffer->arena, length, (void**)&storage)); memcpy(storage, src, length); src = storage; @@ -363,11 +392,12 @@ iree_hal_buffer_allocated_buffer(target_buffer)); CUdeviceptr dst = target_device_buffer + iree_hal_buffer_byte_offset(target_buffer) + target_offset; - CUDA_RETURN_IF_ERROR( - command_buffer->context->syms, - cuMemcpyHtoDAsync(dst, src, length, command_buffer->stream), + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, + cuMemcpyHtoDAsync(dst, src, length, command_buffer->cu_stream), "cuMemcpyHtoDAsync"); + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -378,7 +408,10 @@ iree_device_size_t length) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); CUdeviceptr target_device_buffer = iree_hal_cuda2_buffer_device_pointer( @@ -389,10 +422,13 @@ source_offset += iree_hal_buffer_byte_offset(source_buffer); CUdeviceptr dst = target_device_buffer + target_offset; CUdeviceptr src = source_device_buffer + source_offset; - CUDA_RETURN_IF_ERROR(command_buffer->context->syms, - cuMemcpyAsync(dst, src, length, command_buffer->stream), - "cuMemcpyAsync"); + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, + cuMemcpyAsync(dst, src, length, command_buffer->cu_stream), + "cuMemcpyAsync"); + + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -403,9 +439,14 @@ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - return iree_hal_collective_batch_append(&command_buffer->collective_batch, - channel, op, param, send_binding, - recv_binding, element_count); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_hal_collective_batch_append( + &command_buffer->collective_batch, channel, op, param, send_binding, + recv_binding, element_count); + + IREE_TRACE_ZONE_END(z0); + return status; } static iree_status_t iree_hal_cuda2_stream_command_buffer_push_constants( @@ -414,13 +455,15 @@ const void* values, iree_host_size_t values_length) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); iree_host_size_t constant_base_index = offset / sizeof(int32_t); for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { - command_buffer->push_constant[i + constant_base_index] = + command_buffer->push_constants[i + constant_base_index] = ((uint32_t*)values)[i]; } + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -446,9 +489,10 @@ const iree_hal_descriptor_set_binding_t* bindings) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); iree_host_size_t base_binding = - iree_hal_cuda2_base_binding_index(pipeline_layout, set); + iree_hal_cuda2_pipeline_layout_base_binding_index(pipeline_layout, set); // Convention with the compiler side. We map bindings to kernel argument. // We compact the bindings to get a dense set of arguments and keep them order @@ -476,10 +520,11 @@ iree_hal_buffer_allocated_buffer(binding.buffer)) + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset) : 0; - *((CUdeviceptr*)command_buffer->current_descriptor[i + base_binding]) = + *((CUdeviceptr*)command_buffer->current_descriptors[i + base_binding]) = device_ptr; } + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } @@ -489,44 +534,48 @@ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { iree_hal_cuda2_stream_command_buffer_t* command_buffer = iree_hal_cuda2_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR( + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_stream_command_buffer_flush_collectives(command_buffer)); // Lookup kernel parameters used for side-channeling additional launch // information from the compiler. - iree_hal_cuda2_kernel_params_t kernel_params; - IREE_RETURN_IF_ERROR( - iree_hal_cuda2_native_executable_entry_point_kernel_params( - executable, entry_point, &kernel_params)); + iree_hal_cuda2_kernel_info_t kernel_params; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_native_executable_entry_point_kernel_info( + executable, entry_point, &kernel_params)); IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL( - command_buffer->tracing_context, command_buffer->stream, + command_buffer->tracing_context, command_buffer->cu_stream, kernel_params.source_filename.data, kernel_params.source_filename.size, kernel_params.source_line, /*func_name=*/NULL, 0, kernel_params.function_name.data, kernel_params.function_name.size); // Patch the push constants in the kernel arguments. iree_host_size_t num_constants = - iree_hal_cuda2_pipeline_layout_num_constants(kernel_params.layout); + iree_hal_cuda2_pipeline_layout_push_constant_count(kernel_params.layout); iree_host_size_t constant_base_index = - iree_hal_cuda2_push_constant_index(kernel_params.layout); + iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_params.layout); for (iree_host_size_t i = 0; i < num_constants; i++) { - *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) = - command_buffer->push_constant[i]; + *((uint32_t*)command_buffer->current_descriptors[i + constant_base_index]) = + command_buffer->push_constants[i]; } - CUDA_RETURN_IF_ERROR( - command_buffer->context->syms, - cuLaunchKernel(kernel_params.function, workgroup_x, workgroup_y, - workgroup_z, kernel_params.block_size[0], - kernel_params.block_size[1], kernel_params.block_size[2], - kernel_params.shared_memory_size, command_buffer->stream, - command_buffer->current_descriptor, NULL), + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, + cuLaunchKernel( + kernel_params.function, workgroup_x, workgroup_y, workgroup_z, + kernel_params.block_size[0], kernel_params.block_size[1], + kernel_params.block_size[2], kernel_params.shared_memory_size, + command_buffer->cu_stream, command_buffer->current_descriptors, NULL), "cuLaunchKernel"); IREE_CUDA_TRACE_ZONE_END(command_buffer->tracing_context, - command_buffer->stream); + command_buffer->cu_stream); + IREE_TRACE_ZONE_END(z0); return iree_ok_status(); }
diff --git a/experimental/cuda2/stream_command_buffer.h b/experimental/cuda2/stream_command_buffer.h index dc38cdf..6544856 100644 --- a/experimental/cuda2/stream_command_buffer.h +++ b/experimental/cuda2/stream_command_buffer.h
@@ -7,35 +7,37 @@ #ifndef EXPERIMENTAL_CUDA2_STREAM_COMMAND_BUFFER_H_ #define EXPERIMENTAL_CUDA2_STREAM_COMMAND_BUFFER_H_ +#include "experimental/cuda2/cuda_dynamic_symbols.h" +#include "experimental/cuda2/cuda_headers.h" +#include "experimental/cuda2/nccl_dynamic_symbols.h" +#include "experimental/cuda2/tracing.h" #include "iree/base/internal/arena.h" #include "iree/hal/api.h" -#include "iree/hal/drivers/cuda/context_wrapper.h" -#include "iree/hal/drivers/cuda/cuda_headers.h" -#include "iree/hal/drivers/cuda/dynamic_symbols.h" -#include "iree/hal/drivers/cuda/tracing.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus -// Creates a cuda stream command buffer that immediately issues commands against -// the given |stream|. Access to |stream| must be synchronized by the user. +// Creates command buffer that immediately issues commands against the given +// CUDA |stream|. Access to |stream| must be synchronized by the user. // // If |block_pool| is non-NULL then the stream command buffer will retain copies // of input data until reset. If NULL then the caller must ensure the lifetime // of input data outlives the command buffer. // -// This command buffer is used to both replay deferred command buffers and -// perform inline execution. When replaying the scratch data required for things -// like buffer updates is retained by the source deferred command buffer and as -// such the |block_pool| and can be NULL to avoid a double copy. +// This command buffer is used to replay deferred command buffers. When +// replaying the scratch data required for things like buffer updates is +// retained by the source deferred command buffer and as such the |block_pool| +// and can be NULL to avoid a double copy. iree_status_t iree_hal_cuda2_stream_command_buffer_create( - iree_hal_device_t* device, iree_hal_cuda2_context_wrapper_t* context, + iree_hal_device_t* device, + const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, + const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, iree_hal_cuda2_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, CUstream stream, - iree_arena_block_pool_t* block_pool, + iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer); // Returns true if |command_buffer| is a CUDA stream-based command buffer.
diff --git a/experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt b/experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt index a7bb135..482f154 100644 --- a/experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt +++ b/experimental/cuda2/tests/stablehlo_ops/CMakeLists.txt
@@ -76,6 +76,88 @@ "--iree-input-type=stablehlo" # TODO(#13984): We need memset emulation to workaround CUDA graph issues for now. "--iree-stream-emulate-memset" + RUNNER_ARGS + "--cuda2_use_streams=false" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-nvidia" +) + +iree_check_single_backend_test_suite( + NAME + check_cuda2_stream + SRCS + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/abs.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/add.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/batch_norm_inference.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/bitcast_convert.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/broadcast.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/broadcast_add.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/broadcast_in_dim.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/clamp.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/compare.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/complex.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/concatenate.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/constant.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/convert.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/convolution.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/cosine.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/divide.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/dot.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/dot_bf16.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/dot_general.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/dynamic_slice.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/dynamic_update_slice.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/exponential.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/exponential_fp16.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/exponential_minus_one.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/fft.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/finite.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/floor.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/gather.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/iota.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/log.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/log_plus_one.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/maximum.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/minimum.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/multiply.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/negate.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/pad.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/philox.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/pow.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/reduce.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/reduce_window.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/remainder.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/reshape.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/reverse.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/rng_normal.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/rng_uniform.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/round.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/rsqrt.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/scatter.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/scatter_dynamic.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/select.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/sine.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/slice.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/sort.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/sqrt.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/subtract.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/tanh.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/three_fry.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/torch_index_select.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/transpose.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/stablehlo_ops/while.mlir" + TARGET_BACKEND + "cuda" + DRIVER + "cuda2" + COMPILER_FLAGS + "--iree-input-type=stablehlo" + RUNNER_ARGS + "--cuda2_use_streams=true" LABELS "noasan" "nomsan"
diff --git a/experimental/cuda2/tests/tosa_ops/CMakeLists.txt b/experimental/cuda2/tests/tosa_ops/CMakeLists.txt index 7b56ffa..88752fe 100644 --- a/experimental/cuda2/tests/tosa_ops/CMakeLists.txt +++ b/experimental/cuda2/tests/tosa_ops/CMakeLists.txt
@@ -57,6 +57,69 @@ "--iree-input-type=tosa" # TODO(#13984): We need memset emulation to workaround CUDA graph issues for now. "--iree-stream-emulate-memset" + RUNNER_ARGS + "--cuda2_use_streams=false" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-nvidia" +) + +iree_check_single_backend_test_suite( + NAME + check_cuda2_stream + SRCS + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/abs.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/add.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/arithmetic_right_shift.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/bitwise_and.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/bitwise_or.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/bitwise_xor.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/ceil.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/clamp.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/clz.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/const.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/equal.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/exp.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/floor.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/fully_connected.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/gather.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/greater.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/greater_equal.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/if.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/log.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/logical_left_shift.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/logical_right_shift.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/logical_right_shift_16.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/matmul.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/max_pool.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/maximum.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/minimum.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/mul.mlir" + # "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/mul_shift.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/negate.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/pad.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/reciprocal.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/reduce.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/reshape.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/rsqrt.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/select.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/sigmoid.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/sub.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/table.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/tanh.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/transpose.mlir" + "${IREE_SOURCE_DIR}/tests/e2e/tosa_ops/while.mlir" + TARGET_BACKEND + "cuda" + DRIVER + "cuda2" + COMPILER_FLAGS + "--iree-input-type=tosa" + RUNNER_ARGS + "--cuda2_use_streams=true" LABELS "noasan" "nomsan"
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c index 808d6cc..e21326a 100644 --- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c +++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -6,7 +6,6 @@ #include "iree/hal/drivers/cuda/registration/driver_module.h" -#include <inttypes.h> #include <stddef.h> #include <stdlib.h>