// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

| 7 | #include "experimental/cuda2/cuda_device.h" |
| 8 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 9 | #include <stddef.h> |
| 10 | #include <stdint.h> |
| 11 | #include <string.h> |
| 12 | |
| 13 | #include "experimental/cuda2/cuda_allocator.h" |
| 14 | #include "experimental/cuda2/cuda_buffer.h" |
| 15 | #include "experimental/cuda2/cuda_dynamic_symbols.h" |
| 16 | #include "experimental/cuda2/cuda_status_util.h" |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 17 | #include "experimental/cuda2/event_pool.h" |
| 18 | #include "experimental/cuda2/event_semaphore.h" |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 19 | #include "experimental/cuda2/graph_command_buffer.h" |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 20 | #include "experimental/cuda2/memory_pools.h" |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 21 | #include "experimental/cuda2/nccl_channel.h" |
| 22 | #include "experimental/cuda2/nccl_dynamic_symbols.h" |
Lei Zhang | 85a1a56 | 2023-06-13 19:51:42 -0400 | [diff] [blame] | 23 | #include "experimental/cuda2/nop_executable_cache.h" |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 24 | #include "experimental/cuda2/pending_queue_actions.h" |
Lei Zhang | 45a3eb4 | 2023-06-12 14:16:00 -0400 | [diff] [blame] | 25 | #include "experimental/cuda2/pipeline_layout.h" |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 26 | #include "experimental/cuda2/timepoint_pool.h" |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 27 | #include "experimental/cuda2/tracing.h" |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 28 | #include "iree/base/internal/arena.h" |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 29 | #include "iree/base/internal/event_pool.h" |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 30 | #include "iree/base/internal/math.h" |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 31 | #include "iree/hal/utils/buffer_transfer.h" |
| 32 | #include "iree/hal/utils/deferred_command_buffer.h" |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 33 | #include "iree/hal/utils/file_transfer.h" |
| 34 | #include "iree/hal/utils/memory_file.h" |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 35 | |
//===----------------------------------------------------------------------===//
// iree_hal_cuda2_device_t
//===----------------------------------------------------------------------===//
| 39 | |
| 40 | typedef struct iree_hal_cuda2_device_t { |
| 41 | // Abstract resource used for injecting reference counting and vtable; |
| 42 | // must be at offset 0. |
| 43 | iree_hal_resource_t resource; |
| 44 | iree_string_view_t identifier; |
| 45 | |
| 46 | // Block pool used for command buffers with a larger block size (as command |
| 47 | // buffers can contain inlined data uploads). |
| 48 | iree_arena_block_pool_t block_pool; |
| 49 | |
| 50 | // Optional driver that owns the CUDA symbols. We retain it for our lifetime |
| 51 | // to ensure the symbols remains valid. |
| 52 | iree_hal_driver_t* driver; |
| 53 | |
| 54 | const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols; |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 55 | const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 56 | |
| 57 | // Parameters used to control device behavior. |
| 58 | iree_hal_cuda2_device_params_t params; |
| 59 | |
| 60 | CUcontext cu_context; |
| 61 | CUdevice cu_device; |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 62 | // TODO: Support multiple device streams. |
| 63 | // The CUstream used to issue device kernels and allocations. |
| 64 | CUstream dispatch_cu_stream; |
| 65 | // The CUstream used to issue host callback functions. |
| 66 | CUstream callback_cu_stream; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 67 | |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 68 | iree_hal_cuda2_tracing_context_t* tracing_context; |
| 69 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 70 | iree_allocator_t host_allocator; |
| 71 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 72 | // Host/device event pools, used for backing semaphore timepoints. |
| 73 | iree_event_pool_t* host_event_pool; |
| 74 | iree_hal_cuda2_event_pool_t* device_event_pool; |
| 75 | // Timepoint pools, shared by various semaphores. |
| 76 | iree_hal_cuda2_timepoint_pool_t* timepoint_pool; |
| 77 | |
| 78 | // A queue to order device workloads and relase to the GPU when constraints |
| 79 | // are met. It buffers submissions and allocations internally before they |
| 80 | // are ready. This queue couples with HAL semaphores backed by iree_event_t |
| 81 | // and CUevent objects. |
| 82 | iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions; |
| 83 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 84 | // Device memory pools and allocators. |
| 85 | bool supports_memory_pools; |
| 86 | iree_hal_cuda2_memory_pools_t memory_pools; |
| 87 | iree_hal_allocator_t* device_allocator; |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 88 | |
| 89 | // Optional provider used for creating/configuring collective channels. |
| 90 | iree_hal_channel_provider_t* channel_provider; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 91 | } iree_hal_cuda2_device_t; |
| 92 | |
| 93 | static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable; |
| 94 | |
| 95 | static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast( |
| 96 | iree_hal_device_t* base_value) { |
| 97 | IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable); |
| 98 | return (iree_hal_cuda2_device_t*)base_value; |
| 99 | } |
| 100 | |
| 101 | static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe( |
| 102 | iree_hal_device_t* base_value) { |
| 103 | return (iree_hal_cuda2_device_t*)base_value; |
| 104 | } |
| 105 | |
| 106 | IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize( |
| 107 | iree_hal_cuda2_device_params_t* out_params) { |
| 108 | memset(out_params, 0, sizeof(*out_params)); |
| 109 | out_params->arena_block_size = 32 * 1024; |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 110 | out_params->event_pool_capacity = 32; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 111 | out_params->queue_count = 1; |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 112 | out_params->stream_tracing = false; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 113 | out_params->async_allocations = true; |
| 114 | } |
| 115 | |
| 116 | static iree_status_t iree_hal_cuda2_device_check_params( |
| 117 | const iree_hal_cuda2_device_params_t* params) { |
| 118 | if (params->arena_block_size < 4096) { |
| 119 | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| 120 | "arena block size too small (< 4096 bytes)"); |
| 121 | } |
| 122 | if (params->queue_count == 0) { |
| 123 | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| 124 | "at least one queue is required"); |
| 125 | } |
| 126 | return iree_ok_status(); |
| 127 | } |
| 128 | |
| 129 | static iree_status_t iree_hal_cuda2_device_create_internal( |
| 130 | iree_hal_driver_t* driver, iree_string_view_t identifier, |
| 131 | const iree_hal_cuda2_device_params_t* params, CUdevice cu_device, |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 132 | CUstream dispatch_stream, CUstream callback_stream, CUcontext context, |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 133 | const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, |
| 134 | const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 135 | iree_allocator_t host_allocator, iree_hal_device_t** out_device) { |
| 136 | iree_hal_cuda2_device_t* device = NULL; |
| 137 | iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; |
| 138 | IREE_RETURN_IF_ERROR( |
| 139 | iree_allocator_malloc(host_allocator, total_size, (void**)&device)); |
| 140 | memset(device, 0, total_size); |
| 141 | |
| 142 | iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable, |
| 143 | &device->resource); |
| 144 | iree_string_view_append_to_buffer( |
| 145 | identifier, &device->identifier, |
| 146 | (char*)device + iree_sizeof_struct(*device)); |
| 147 | iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, |
| 148 | &device->block_pool); |
| 149 | device->driver = driver; |
| 150 | iree_hal_driver_retain(device->driver); |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 151 | device->cuda_symbols = cuda_symbols; |
| 152 | device->nccl_symbols = nccl_symbols; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 153 | device->params = *params; |
| 154 | device->cu_context = context; |
| 155 | device->cu_device = cu_device; |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 156 | device->dispatch_cu_stream = dispatch_stream; |
| 157 | device->callback_cu_stream = callback_stream; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 158 | device->host_allocator = host_allocator; |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 159 | |
| 160 | iree_status_t status = iree_hal_cuda2_pending_queue_actions_create( |
| 161 | cuda_symbols, &device->block_pool, host_allocator, |
| 162 | &device->pending_queue_actions); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 163 | |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 164 | // Enable tracing for the (currently only) stream - no-op if disabled. |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 165 | if (iree_status_is_ok(status) && device->params.stream_tracing) { |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 166 | status = iree_hal_cuda2_tracing_context_allocate( |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 167 | device->cuda_symbols, device->identifier, dispatch_stream, |
| 168 | &device->block_pool, host_allocator, &device->tracing_context); |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 169 | } |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 170 | |
| 171 | // Memory pool support is conditional. |
| 172 | if (iree_status_is_ok(status) && params->async_allocations) { |
| 173 | int supports_memory_pools = 0; |
| 174 | status = IREE_CURESULT_TO_STATUS( |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 175 | cuda_symbols, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 176 | cuDeviceGetAttribute(&supports_memory_pools, |
| 177 | CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, |
| 178 | cu_device), |
| 179 | "cuDeviceGetAttribute"); |
| 180 | device->supports_memory_pools = supports_memory_pools != 0; |
| 181 | } |
| 182 | |
| 183 | // Create memory pools first so that we can share them with the allocator. |
| 184 | if (iree_status_is_ok(status) && device->supports_memory_pools) { |
| 185 | status = iree_hal_cuda2_memory_pools_initialize( |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 186 | cuda_symbols, cu_device, ¶ms->memory_pools, host_allocator, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 187 | &device->memory_pools); |
| 188 | } |
| 189 | |
| 190 | if (iree_status_is_ok(status)) { |
| 191 | status = iree_hal_cuda2_allocator_create( |
Ben Vanik | 42b983c | 2023-08-15 21:31:09 -0700 | [diff] [blame] | 192 | cuda_symbols, cu_device, dispatch_stream, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 193 | device->supports_memory_pools ? &device->memory_pools : NULL, |
| 194 | host_allocator, &device->device_allocator); |
| 195 | } |
| 196 | |
| 197 | if (iree_status_is_ok(status)) { |
| 198 | *out_device = (iree_hal_device_t*)device; |
| 199 | } else { |
| 200 | iree_hal_device_release((iree_hal_device_t*)device); |
| 201 | } |
| 202 | return status; |
| 203 | } |
| 204 | |
| 205 | iree_status_t iree_hal_cuda2_device_create( |
| 206 | iree_hal_driver_t* driver, iree_string_view_t identifier, |
| 207 | const iree_hal_cuda2_device_params_t* params, |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 208 | const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols, |
| 209 | const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 210 | iree_allocator_t host_allocator, iree_hal_device_t** out_device) { |
| 211 | IREE_ASSERT_ARGUMENT(driver); |
| 212 | IREE_ASSERT_ARGUMENT(params); |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 213 | IREE_ASSERT_ARGUMENT(cuda_symbols); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 214 | IREE_ASSERT_ARGUMENT(out_device); |
| 215 | IREE_TRACE_ZONE_BEGIN(z0); |
| 216 | |
| 217 | iree_status_t status = iree_hal_cuda2_device_check_params(params); |
| 218 | |
| 219 | // Get the main context for the device. |
| 220 | CUcontext context = NULL; |
| 221 | if (iree_status_is_ok(status)) { |
| 222 | status = IREE_CURESULT_TO_STATUS( |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 223 | cuda_symbols, cuDevicePrimaryCtxRetain(&context, device)); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 224 | } |
| 225 | if (iree_status_is_ok(status)) { |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 226 | status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context)); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 227 | } |
| 228 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 229 | // Create the default dispatch stream for the device. |
| 230 | CUstream dispatch_stream = NULL; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 231 | if (iree_status_is_ok(status)) { |
| 232 | status = IREE_CURESULT_TO_STATUS( |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 233 | cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING)); |
| 234 | } |
| 235 | // Create the default callback stream for the device. |
| 236 | CUstream callback_stream = NULL; |
| 237 | if (iree_status_is_ok(status)) { |
| 238 | status = IREE_CURESULT_TO_STATUS( |
| 239 | cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING)); |
| 240 | } |
| 241 | |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 242 | if (iree_status_is_ok(status)) { |
| 243 | status = iree_hal_cuda2_device_create_internal( |
| 244 | driver, identifier, params, device, dispatch_stream, callback_stream, |
| 245 | context, cuda_symbols, nccl_symbols, host_allocator, out_device); |
| 246 | } else { |
| 247 | // Release resources we have accquired thus far. |
| 248 | if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream); |
| 249 | if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream); |
| 250 | if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device); |
| 251 | } |
| 252 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 253 | iree_event_pool_t* host_event_pool = NULL; |
| 254 | if (iree_status_is_ok(status)) { |
| 255 | status = iree_event_pool_allocate(params->event_pool_capacity, |
| 256 | host_allocator, &host_event_pool); |
| 257 | } |
| 258 | |
| 259 | iree_hal_cuda2_event_pool_t* device_event_pool = NULL; |
| 260 | if (iree_status_is_ok(status)) { |
| 261 | status = iree_hal_cuda2_event_pool_allocate( |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 262 | *out_device, cuda_symbols, params->event_pool_capacity, host_allocator, |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 263 | &device_event_pool); |
| 264 | } |
| 265 | |
| 266 | iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL; |
| 267 | if (iree_status_is_ok(status)) { |
| 268 | status = iree_hal_cuda2_timepoint_pool_allocate( |
| 269 | host_event_pool, device_event_pool, params->event_pool_capacity, |
| 270 | host_allocator, &timepoint_pool); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 271 | } |
| 272 | |
| 273 | if (iree_status_is_ok(status)) { |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 274 | iree_hal_cuda2_device_t* cuda_device = |
| 275 | iree_hal_cuda2_device_cast(*out_device); |
| 276 | cuda_device->host_event_pool = host_event_pool; |
| 277 | cuda_device->device_event_pool = device_event_pool; |
| 278 | cuda_device->timepoint_pool = timepoint_pool; |
| 279 | } else { |
| 280 | // Release resources we have accquired after HAL device creation. |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 281 | if (timepoint_pool) iree_hal_cuda2_timepoint_pool_free(timepoint_pool); |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 282 | if (device_event_pool) iree_hal_cuda2_event_pool_release(device_event_pool); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 283 | if (host_event_pool) iree_event_pool_free(host_event_pool); |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 284 | // Release other resources via the HAL device. |
| 285 | iree_hal_device_release(*out_device); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 286 | } |
| 287 | |
| 288 | IREE_TRACE_ZONE_END(z0); |
| 289 | return status; |
| 290 | } |
| 291 | |
| 292 | CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) { |
| 293 | iree_hal_cuda2_device_t* device = |
| 294 | iree_hal_cuda2_device_cast_unsafe(base_device); |
| 295 | return device->cu_context; |
| 296 | } |
| 297 | |
| 298 | const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols( |
| 299 | iree_hal_device_t* base_device) { |
| 300 | iree_hal_cuda2_device_t* device = |
| 301 | iree_hal_cuda2_device_cast_unsafe(base_device); |
| 302 | return device->cuda_symbols; |
| 303 | } |
| 304 | |
| 305 | static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) { |
| 306 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 307 | iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 308 | const iree_hal_cuda2_dynamic_symbols_t* symbols = device->cuda_symbols; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 309 | IREE_TRACE_ZONE_BEGIN(z0); |
| 310 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 311 | // Destroy the pending workload queue. |
| 312 | iree_hal_cuda2_pending_queue_actions_destroy( |
| 313 | (iree_hal_resource_t*)device->pending_queue_actions); |
| 314 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 315 | // There should be no more buffers live that use the allocator. |
| 316 | iree_hal_allocator_release(device->device_allocator); |
| 317 | |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 318 | // Buffers may have been retaining collective resources. |
| 319 | iree_hal_channel_provider_release(device->channel_provider); |
| 320 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 321 | // Destroy memory pools that hold on to reserved memory. |
| 322 | iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools); |
| 323 | |
Lei Zhang | 85eb21b | 2023-06-13 19:55:42 -0400 | [diff] [blame] | 324 | iree_hal_cuda2_tracing_context_free(device->tracing_context); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 325 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 326 | // Destroy various pools for synchronization. |
Lei Zhang | 6e04816 | 2023-08-29 21:54:32 -0400 | [diff] [blame] | 327 | if (device->timepoint_pool) { |
| 328 | iree_hal_cuda2_timepoint_pool_free(device->timepoint_pool); |
| 329 | } |
| 330 | if (device->device_event_pool) { |
| 331 | iree_hal_cuda2_event_pool_release(device->device_event_pool); |
| 332 | } |
| 333 | if (device->host_event_pool) iree_event_pool_free(device->host_event_pool); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 334 | |
| 335 | IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream)); |
| 336 | IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream)); |
| 337 | |
| 338 | IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device)); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 339 | |
| 340 | iree_arena_block_pool_deinitialize(&device->block_pool); |
| 341 | |
| 342 | // Finally, destroy the device. |
| 343 | iree_hal_driver_release(device->driver); |
| 344 | |
| 345 | iree_allocator_free(host_allocator, device); |
| 346 | |
| 347 | IREE_TRACE_ZONE_END(z0); |
| 348 | } |
| 349 | |
| 350 | static iree_string_view_t iree_hal_cuda2_device_id( |
| 351 | iree_hal_device_t* base_device) { |
| 352 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 353 | return device->identifier; |
| 354 | } |
| 355 | |
| 356 | static iree_allocator_t iree_hal_cuda2_device_host_allocator( |
| 357 | iree_hal_device_t* base_device) { |
| 358 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 359 | return device->host_allocator; |
| 360 | } |
| 361 | |
| 362 | static iree_hal_allocator_t* iree_hal_cuda2_device_allocator( |
| 363 | iree_hal_device_t* base_device) { |
| 364 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 365 | return device->device_allocator; |
| 366 | } |
| 367 | |
| 368 | static void iree_hal_cuda2_replace_device_allocator( |
| 369 | iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) { |
| 370 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 371 | iree_hal_allocator_retain(new_allocator); |
| 372 | iree_hal_allocator_release(device->device_allocator); |
| 373 | device->device_allocator = new_allocator; |
| 374 | } |
| 375 | |
| 376 | static void iree_hal_cuda2_replace_channel_provider( |
| 377 | iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) { |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 378 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 379 | iree_hal_channel_provider_retain(new_provider); |
| 380 | iree_hal_channel_provider_release(device->channel_provider); |
| 381 | device->channel_provider = new_provider; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 382 | } |
| 383 | |
| 384 | static iree_status_t iree_hal_cuda2_device_trim( |
| 385 | iree_hal_device_t* base_device) { |
| 386 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 387 | iree_arena_block_pool_trim(&device->block_pool); |
| 388 | IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); |
| 389 | if (device->supports_memory_pools) { |
| 390 | IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim( |
| 391 | &device->memory_pools, &device->params.memory_pools)); |
| 392 | } |
| 393 | return iree_ok_status(); |
| 394 | } |
| 395 | |
| 396 | static iree_status_t iree_hal_cuda2_device_query_attribute( |
| 397 | iree_hal_cuda2_device_t* device, CUdevice_attribute attribute, |
| 398 | int64_t* out_value) { |
| 399 | int value = 0; |
| 400 | IREE_CUDA_RETURN_IF_ERROR( |
| 401 | device->cuda_symbols, |
| 402 | cuDeviceGetAttribute(&value, attribute, device->cu_device), |
| 403 | "cuDeviceGetAttribute"); |
| 404 | *out_value = value; |
| 405 | return iree_ok_status(); |
| 406 | } |
| 407 | |
| 408 | static iree_status_t iree_hal_cuda2_device_query_i64( |
| 409 | iree_hal_device_t* base_device, iree_string_view_t category, |
| 410 | iree_string_view_t key, int64_t* out_value) { |
| 411 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 412 | *out_value = 0; |
| 413 | |
| 414 | if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { |
| 415 | *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0; |
| 416 | return iree_ok_status(); |
| 417 | } |
| 418 | |
| 419 | if (iree_string_view_equal(category, IREE_SV("cuda.device"))) { |
| 420 | if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) { |
| 421 | return iree_hal_cuda2_device_query_attribute( |
| 422 | device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value); |
| 423 | } else if (iree_string_view_equal(key, |
| 424 | IREE_SV("compute_capability_minor"))) { |
| 425 | return iree_hal_cuda2_device_query_attribute( |
| 426 | device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value); |
| 427 | } |
| 428 | } |
| 429 | |
| 430 | return iree_make_status( |
| 431 | IREE_STATUS_NOT_FOUND, |
| 432 | "unknown device configuration key value '%.*s :: %.*s'", |
| 433 | (int)category.size, category.data, (int)key.size, key.data); |
| 434 | } |
| 435 | |
| 436 | static iree_status_t iree_hal_cuda2_device_create_channel( |
| 437 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 438 | iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 439 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 440 | if (!device->nccl_symbols || !device->nccl_symbols->dylib) { |
| 441 | return iree_make_status( |
| 442 | IREE_STATUS_UNAVAILABLE, |
| 443 | "NCCL runtime library (%d.%d.%d) not available; ensure installed and " |
| 444 | "the shared library is on your PATH/LD_LIBRARY_PATH " |
| 445 | "(nccl.dll/libnccl.so)", |
| 446 | NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH); |
| 447 | } |
| 448 | |
| 449 | // Today we only allow a single logical device per channel. |
| 450 | // We could multiplex channels but it'd be better to surface that to the |
| 451 | // compiler so that it can emit the right rank math. |
| 452 | int requested_count = iree_math_count_ones_u64(queue_affinity); |
| 453 | // TODO(#12206): properly assign affinity in the compiler. |
| 454 | if (requested_count != 64 && requested_count != 1) { |
| 455 | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| 456 | "exactly one participant is allowed in a " |
| 457 | "channel but %d were specified", |
| 458 | requested_count); |
| 459 | } |
| 460 | |
| 461 | // Ask the channel provider (if configured) for the default rank and count |
| 462 | // if the user did not set them. |
| 463 | if (device->channel_provider && |
| 464 | (params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT || |
| 465 | params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) { |
| 466 | IREE_RETURN_IF_ERROR( |
| 467 | iree_hal_channel_provider_query_default_rank_and_count( |
| 468 | device->channel_provider, ¶ms.rank, ¶ms.count), |
| 469 | "querying default collective group rank and count"); |
| 470 | } |
| 471 | |
| 472 | // An ID is required to initialize NCCL. On the root it'll be the local ID and |
| 473 | // on all other participants it'll be the root ID. |
| 474 | iree_hal_cuda2_nccl_id_t id; |
| 475 | memset(&id, 0, sizeof(id)); |
| 476 | if (iree_const_byte_span_is_empty(params.id)) { |
| 477 | // User wants the default ID. |
| 478 | if (!device->channel_provider) { |
| 479 | return iree_make_status( |
| 480 | IREE_STATUS_INVALID_ARGUMENT, |
| 481 | "default collective channel ID requested but no channel provider has " |
| 482 | "been set on the device to provide it"); |
| 483 | } |
| 484 | if (params.rank == 0) { |
| 485 | // Bootstrap NCCL to get the root ID. |
| 486 | IREE_RETURN_IF_ERROR( |
| 487 | iree_hal_cuda2_nccl_get_unique_id(device->nccl_symbols, &id), |
| 488 | "bootstrapping NCCL root"); |
| 489 | } |
| 490 | // Exchange NCCL ID with all participants. |
| 491 | IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id( |
| 492 | device->channel_provider, |
| 493 | iree_make_byte_span((void*)&id, sizeof(id))), |
| 494 | "exchanging NCCL ID with other participants"); |
| 495 | } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) { |
| 496 | // User provided something but it's not what we expect. |
Scott Todd | 60b0764 | 2023-06-15 09:41:01 -0700 | [diff] [blame] | 497 | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| 498 | "NCCL ID must be %zu bytes matching the " |
| 499 | "ncclUniqueId struct but caller provided %zu bytes", |
| 500 | IREE_ARRAYSIZE(id.data), sizeof(id)); |
Lei Zhang | 330771e | 2023-06-13 19:39:23 -0400 | [diff] [blame] | 501 | } else { |
| 502 | // User provided the ID - we treat it as opaque here and let NCCL validate. |
| 503 | memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data)); |
| 504 | } |
| 505 | |
| 506 | if (iree_hal_cuda2_nccl_id_is_empty(&id)) { |
| 507 | // TODO: maybe this is ok? a localhost alias or something? |
| 508 | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| 509 | "no default NCCL ID specified (all zeros)"); |
| 510 | } |
| 511 | |
| 512 | // TODO: when we support multiple logical devices we'll want to pass in the |
| 513 | // context of the device mapped to the queue_affinity. For now since this |
| 514 | // implementation only supports one device we pass in the only one we have. |
| 515 | return iree_hal_cuda2_nccl_channel_create( |
| 516 | device->cuda_symbols, device->nccl_symbols, &id, params.rank, |
| 517 | params.count, device->host_allocator, out_channel); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 518 | } |
| 519 | |
| 520 | static iree_status_t iree_hal_cuda2_device_create_command_buffer( |
| 521 | iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, |
| 522 | iree_hal_command_category_t command_categories, |
| 523 | iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, |
| 524 | iree_hal_command_buffer_t** out_command_buffer) { |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 525 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 526 | return iree_hal_cuda2_graph_command_buffer_create( |
| 527 | base_device, device->cuda_symbols, device->cu_context, mode, |
| 528 | command_categories, queue_affinity, binding_capacity, &device->block_pool, |
| 529 | device->host_allocator, out_command_buffer); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 530 | } |
| 531 | |
| 532 | static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout( |
| 533 | iree_hal_device_t* base_device, |
| 534 | iree_hal_descriptor_set_layout_flags_t flags, |
| 535 | iree_host_size_t binding_count, |
| 536 | const iree_hal_descriptor_set_layout_binding_t* bindings, |
| 537 | iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { |
Lei Zhang | 45a3eb4 | 2023-06-12 14:16:00 -0400 | [diff] [blame] | 538 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 539 | return iree_hal_cuda2_descriptor_set_layout_create( |
| 540 | flags, binding_count, bindings, device->host_allocator, |
| 541 | out_descriptor_set_layout); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 542 | } |
| 543 | |
| 544 | static iree_status_t iree_hal_cuda2_device_create_event( |
| 545 | iree_hal_device_t* base_device, iree_hal_event_t** out_event) { |
| 546 | return iree_make_status(IREE_STATUS_UNIMPLEMENTED, |
| 547 | "event not yet implmeneted"); |
| 548 | } |
| 549 | |
| 550 | static iree_status_t iree_hal_cuda2_device_create_executable_cache( |
| 551 | iree_hal_device_t* base_device, iree_string_view_t identifier, |
| 552 | iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { |
Lei Zhang | 85a1a56 | 2023-06-13 19:51:42 -0400 | [diff] [blame] | 553 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 554 | return iree_hal_cuda2_nop_executable_cache_create( |
| 555 | identifier, device->cuda_symbols, device->cu_device, |
| 556 | device->host_allocator, out_executable_cache); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 557 | } |
| 558 | |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 559 | static iree_status_t iree_hal_cuda2_device_import_file( |
| 560 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
Ben Vanik | 3add457 | 2023-10-05 09:47:10 -0700 | [diff] [blame] | 561 | iree_hal_memory_access_t access, iree_io_file_handle_t* handle, |
| 562 | iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { |
| 563 | if (iree_io_file_handle_type(handle) != |
| 564 | IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 565 | return iree_make_status( |
| 566 | IREE_STATUS_UNAVAILABLE, |
| 567 | "implementation does not support the external file type"); |
| 568 | } |
| 569 | return iree_hal_memory_file_wrap( |
Ben Vanik | 3add457 | 2023-10-05 09:47:10 -0700 | [diff] [blame] | 570 | queue_affinity, access, handle, iree_hal_device_allocator(base_device), |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 571 | iree_hal_device_host_allocator(base_device), out_file); |
| 572 | } |
| 573 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 574 | static iree_status_t iree_hal_cuda2_device_create_pipeline_layout( |
| 575 | iree_hal_device_t* base_device, iree_host_size_t push_constants, |
| 576 | iree_host_size_t set_layout_count, |
| 577 | iree_hal_descriptor_set_layout_t* const* set_layouts, |
| 578 | iree_hal_pipeline_layout_t** out_pipeline_layout) { |
Lei Zhang | 45a3eb4 | 2023-06-12 14:16:00 -0400 | [diff] [blame] | 579 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 580 | return iree_hal_cuda2_pipeline_layout_create( |
| 581 | set_layout_count, set_layouts, push_constants, device->host_allocator, |
| 582 | out_pipeline_layout); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 583 | } |
| 584 | |
| 585 | static iree_status_t iree_hal_cuda2_device_create_semaphore( |
| 586 | iree_hal_device_t* base_device, uint64_t initial_value, |
| 587 | iree_hal_semaphore_t** out_semaphore) { |
Lei Zhang | 04beef2 | 2023-07-10 17:07:36 -0400 | [diff] [blame] | 588 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 589 | return iree_hal_cuda2_event_semaphore_create( |
| 590 | initial_value, device->cuda_symbols, device->timepoint_pool, |
| 591 | device->pending_queue_actions, device->host_allocator, out_semaphore); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 592 | } |
| 593 | |
| 594 | static iree_hal_semaphore_compatibility_t |
| 595 | iree_hal_cuda2_device_query_semaphore_compatibility( |
| 596 | iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) { |
| 597 | // TODO: implement CUDA semaphores. |
Lei Zhang | 04beef2 | 2023-07-10 17:07:36 -0400 | [diff] [blame] | 598 | return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 599 | } |
| 600 | |
| 601 | // TODO: implement multiple streams; today we only have one and queue_affinity |
| 602 | // is ignored. |
| 603 | // TODO: implement proper semaphores in CUDA to ensure ordering and avoid |
| 604 | // the barrier here. |
| 605 | static iree_status_t iree_hal_cuda2_device_queue_alloca( |
| 606 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 607 | const iree_hal_semaphore_list_t wait_semaphore_list, |
| 608 | const iree_hal_semaphore_list_t signal_semaphore_list, |
| 609 | iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, |
| 610 | iree_device_size_t allocation_size, |
| 611 | iree_hal_buffer_t** IREE_RESTRICT out_buffer) { |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 612 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 613 | |
| 614 | // NOTE: block on the semaphores here; we could avoid this by properly |
| 615 | // sequencing device work with semaphores. The CUDA HAL is not currently |
| 616 | // asynchronous. |
| 617 | IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, |
| 618 | iree_infinite_timeout())); |
| 619 | |
| 620 | // Allocate from the pool; likely to fail in cases of virtual memory |
| 621 | // exhaustion but the error may be deferred until a later synchronization. |
| 622 | // If pools are not supported we allocate a buffer as normal from whatever |
| 623 | // allocator is set on the device. |
| 624 | iree_status_t status = iree_ok_status(); |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 625 | if (device->supports_memory_pools && |
Ben Vanik | 63381a8 | 2023-10-19 11:09:49 -0700 | [diff] [blame^] | 626 | !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 627 | status = iree_hal_cuda2_memory_pools_alloca( |
| 628 | &device->memory_pools, device->dispatch_cu_stream, pool, params, |
| 629 | allocation_size, out_buffer); |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 630 | } else { |
| 631 | status = iree_hal_allocator_allocate_buffer( |
| 632 | iree_hal_device_allocator(base_device), params, allocation_size, |
Ben Vanik | 42b983c | 2023-08-15 21:31:09 -0700 | [diff] [blame] | 633 | out_buffer); |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 634 | } |
| 635 | |
| 636 | // Only signal if not returning a synchronous error - synchronous failure |
| 637 | // indicates that the stream is unchanged (it's not really since we waited |
| 638 | // above, but we at least won't deadlock like this). |
| 639 | if (iree_status_is_ok(status)) { |
| 640 | status = iree_hal_semaphore_list_signal(signal_semaphore_list); |
| 641 | } |
| 642 | return status; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 643 | } |
| 644 | |
| 645 | // TODO: implement multiple streams; today we only have one and queue_affinity |
| 646 | // is ignored. |
| 647 | // TODO: implement proper semaphores in CUDA to ensure ordering and avoid |
| 648 | // the barrier here. |
| 649 | static iree_status_t iree_hal_cuda2_device_queue_dealloca( |
| 650 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 651 | const iree_hal_semaphore_list_t wait_semaphore_list, |
| 652 | const iree_hal_semaphore_list_t signal_semaphore_list, |
| 653 | iree_hal_buffer_t* buffer) { |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 654 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 655 | |
| 656 | // NOTE: block on the semaphores here; we could avoid this by properly |
| 657 | // sequencing device work with semaphores. The CUDA HAL is not currently |
| 658 | // asynchronous. |
| 659 | IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, |
| 660 | iree_infinite_timeout())); |
| 661 | |
| 662 | // Schedule the buffer deallocation if we got it from a pool and otherwise |
| 663 | // drop it on the floor and let it be freed when the buffer is released. |
| 664 | iree_status_t status = iree_ok_status(); |
| 665 | if (device->supports_memory_pools) { |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 666 | status = iree_hal_cuda2_memory_pools_dealloca( |
| 667 | &device->memory_pools, device->dispatch_cu_stream, buffer); |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 668 | } |
| 669 | |
| 670 | // Only signal if not returning a synchronous error - synchronous failure |
| 671 | // indicates that the stream is unchanged (it's not really since we waited |
| 672 | // above, but we at least won't deadlock like this). |
| 673 | if (iree_status_is_ok(status)) { |
| 674 | status = iree_hal_semaphore_list_signal(signal_semaphore_list); |
| 675 | } |
| 676 | return status; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 677 | } |
| 678 | |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 679 | static iree_status_t iree_hal_cuda2_device_queue_read( |
| 680 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 681 | const iree_hal_semaphore_list_t wait_semaphore_list, |
| 682 | const iree_hal_semaphore_list_t signal_semaphore_list, |
| 683 | iree_hal_file_t* source_file, uint64_t source_offset, |
| 684 | iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, |
| 685 | iree_device_size_t length, uint32_t flags) { |
| 686 | // TODO: expose streaming chunk count/size options. |
| 687 | iree_status_t loop_status = iree_ok_status(); |
| 688 | iree_hal_file_transfer_options_t options = { |
| 689 | .loop = iree_loop_inline(&loop_status), |
| 690 | .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, |
| 691 | .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, |
| 692 | }; |
| 693 | IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( |
| 694 | base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, |
| 695 | source_file, source_offset, target_buffer, target_offset, length, flags, |
| 696 | options)); |
| 697 | return loop_status; |
| 698 | } |
| 699 | |
| 700 | static iree_status_t iree_hal_cuda2_device_queue_write( |
| 701 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 702 | const iree_hal_semaphore_list_t wait_semaphore_list, |
| 703 | const iree_hal_semaphore_list_t signal_semaphore_list, |
| 704 | iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, |
| 705 | iree_hal_file_t* target_file, uint64_t target_offset, |
| 706 | iree_device_size_t length, uint32_t flags) { |
| 707 | // TODO: expose streaming chunk count/size options. |
| 708 | iree_status_t loop_status = iree_ok_status(); |
| 709 | iree_hal_file_transfer_options_t options = { |
| 710 | .loop = iree_loop_inline(&loop_status), |
| 711 | .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, |
| 712 | .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, |
| 713 | }; |
| 714 | IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( |
| 715 | base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, |
| 716 | source_buffer, source_offset, target_file, target_offset, length, flags, |
| 717 | options)); |
| 718 | return loop_status; |
| 719 | } |
| 720 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 721 | static iree_status_t iree_hal_cuda2_device_queue_execute( |
| 722 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, |
| 723 | const iree_hal_semaphore_list_t wait_semaphore_list, |
| 724 | const iree_hal_semaphore_list_t signal_semaphore_list, |
| 725 | iree_host_size_t command_buffer_count, |
| 726 | iree_hal_command_buffer_t* const* command_buffers) { |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 727 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 728 | IREE_TRACE_ZONE_BEGIN(z0); |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 729 | |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 730 | iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution( |
| 731 | device->dispatch_cu_stream, device->callback_cu_stream, |
| 732 | device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list, |
| 733 | command_buffer_count, command_buffers); |
| 734 | if (iree_status_is_ok(status)) { |
| 735 | // Try to advance the pending workload queue. |
| 736 | status = iree_hal_cuda2_pending_queue_actions_issue( |
| 737 | device->pending_queue_actions); |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 738 | } |
| 739 | |
Lei Zhang | 69a5481 | 2023-07-10 23:29:29 -0400 | [diff] [blame] | 740 | iree_hal_cuda2_tracing_context_collect(device->tracing_context); |
| 741 | IREE_TRACE_ZONE_END(z0); |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 742 | return status; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 743 | } |
| 744 | |
| 745 | static iree_status_t iree_hal_cuda2_device_queue_flush( |
| 746 | iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) { |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 747 | iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device); |
| 748 | IREE_TRACE_ZONE_BEGIN(z0); |
| 749 | // Try to advance the pending workload queue. |
| 750 | iree_status_t status = |
| 751 | iree_hal_cuda2_pending_queue_actions_issue(device->pending_queue_actions); |
| 752 | IREE_TRACE_ZONE_END(z0); |
| 753 | return status; |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 754 | } |
| 755 | |
| 756 | static iree_status_t iree_hal_cuda2_device_wait_semaphores( |
| 757 | iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, |
| 758 | const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { |
| 759 | return iree_make_status(IREE_STATUS_UNIMPLEMENTED, |
Lei Zhang | 1bb26cb | 2023-07-24 19:28:27 -0400 | [diff] [blame] | 760 | "waiting multiple semaphores not yet implemented"); |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 761 | } |
| 762 | |
| 763 | static iree_status_t iree_hal_cuda2_device_profiling_begin( |
| 764 | iree_hal_device_t* base_device, |
| 765 | const iree_hal_device_profiling_options_t* options) { |
| 766 | // Unimplemented (and that's ok). |
| 767 | // We could hook in to CUPTI here or use the much simpler cuProfilerStart API. |
| 768 | return iree_ok_status(); |
| 769 | } |
| 770 | |
Ben Vanik | 82be925 | 2023-08-25 11:12:18 -0700 | [diff] [blame] | 771 | static iree_status_t iree_hal_cuda2_device_profiling_flush( |
| 772 | iree_hal_device_t* base_device) { |
| 773 | // Unimplemented (and that's ok). |
| 774 | return iree_ok_status(); |
| 775 | } |
| 776 | |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 777 | static iree_status_t iree_hal_cuda2_device_profiling_end( |
| 778 | iree_hal_device_t* base_device) { |
| 779 | // Unimplemented (and that's ok). |
| 780 | return iree_ok_status(); |
| 781 | } |
| 782 | |
| 783 | static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = { |
| 784 | .destroy = iree_hal_cuda2_device_destroy, |
| 785 | .id = iree_hal_cuda2_device_id, |
| 786 | .host_allocator = iree_hal_cuda2_device_host_allocator, |
| 787 | .device_allocator = iree_hal_cuda2_device_allocator, |
| 788 | .replace_device_allocator = iree_hal_cuda2_replace_device_allocator, |
| 789 | .replace_channel_provider = iree_hal_cuda2_replace_channel_provider, |
| 790 | .trim = iree_hal_cuda2_device_trim, |
| 791 | .query_i64 = iree_hal_cuda2_device_query_i64, |
| 792 | .create_channel = iree_hal_cuda2_device_create_channel, |
| 793 | .create_command_buffer = iree_hal_cuda2_device_create_command_buffer, |
| 794 | .create_descriptor_set_layout = |
| 795 | iree_hal_cuda2_device_create_descriptor_set_layout, |
| 796 | .create_event = iree_hal_cuda2_device_create_event, |
| 797 | .create_executable_cache = iree_hal_cuda2_device_create_executable_cache, |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 798 | .import_file = iree_hal_cuda2_device_import_file, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 799 | .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout, |
| 800 | .create_semaphore = iree_hal_cuda2_device_create_semaphore, |
| 801 | .query_semaphore_compatibility = |
| 802 | iree_hal_cuda2_device_query_semaphore_compatibility, |
| 803 | .transfer_range = iree_hal_device_submit_transfer_range_and_wait, |
| 804 | .queue_alloca = iree_hal_cuda2_device_queue_alloca, |
| 805 | .queue_dealloca = iree_hal_cuda2_device_queue_dealloca, |
Ben Vanik | f022d29 | 2023-08-15 18:51:46 -0700 | [diff] [blame] | 806 | .queue_read = iree_hal_cuda2_device_queue_read, |
| 807 | .queue_write = iree_hal_cuda2_device_queue_write, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 808 | .queue_execute = iree_hal_cuda2_device_queue_execute, |
| 809 | .queue_flush = iree_hal_cuda2_device_queue_flush, |
| 810 | .wait_semaphores = iree_hal_cuda2_device_wait_semaphores, |
| 811 | .profiling_begin = iree_hal_cuda2_device_profiling_begin, |
Ben Vanik | 82be925 | 2023-08-25 11:12:18 -0700 | [diff] [blame] | 812 | .profiling_flush = iree_hal_cuda2_device_profiling_flush, |
Lei Zhang | c4e01e9 | 2023-06-09 17:38:05 -0400 | [diff] [blame] | 813 | .profiling_end = iree_hal_cuda2_device_profiling_end, |
| 814 | }; |