// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "experimental/cuda2/cuda_device.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "experimental/cuda2/cuda_allocator.h"
#include "experimental/cuda2/cuda_buffer.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/memory_pools.h"
#include "experimental/cuda2/nccl_channel.h"
#include "experimental/cuda2/nccl_dynamic_symbols.h"
#include "experimental/cuda2/nop_executable_cache.h"
#include "experimental/cuda2/pipeline_layout.h"
#include "iree/base/internal/arena.h"
#include "iree/base/internal/math.h"
#include "iree/hal/utils/buffer_transfer.h"
#include "iree/hal/utils/deferred_command_buffer.h"

//===----------------------------------------------------------------------===//
// iree_hal_cuda2_device_t
//===----------------------------------------------------------------------===//

typedef struct iree_hal_cuda2_device_t {
  // Abstract resource used for injecting reference counting and vtable;
  // must be at offset 0.
  iree_hal_resource_t resource;
  iree_string_view_t identifier;

  // Block pool used for command buffers with a larger block size (as command
  // buffers can contain inlined data uploads).
  iree_arena_block_pool_t block_pool;

  // Optional driver that owns the CUDA symbols. We retain it for our lifetime
  // to ensure the symbols remain valid.
  iree_hal_driver_t* driver;

  const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols;
  const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols;

  // Parameters used to control device behavior.
  iree_hal_cuda2_device_params_t params;

  CUcontext cu_context;
  CUdevice cu_device;
  // TODO: support multiple streams.
  CUstream cu_stream;

  iree_allocator_t host_allocator;

  // Device memory pools and allocators.
  bool supports_memory_pools;
  iree_hal_cuda2_memory_pools_t memory_pools;
  iree_hal_allocator_t* device_allocator;

  // Optional provider used for creating/configuring collective channels.
  iree_hal_channel_provider_t* channel_provider;
} iree_hal_cuda2_device_t;

static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable;

static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast(
    iree_hal_device_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable);
  return (iree_hal_cuda2_device_t*)base_value;
}

static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe(
    iree_hal_device_t* base_value) {
  return (iree_hal_cuda2_device_t*)base_value;
}

IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize(
    iree_hal_cuda2_device_params_t* out_params) {
  memset(out_params, 0, sizeof(*out_params));
  out_params->arena_block_size = 32 * 1024;
  out_params->queue_count = 1;
  out_params->async_allocations = true;
}

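// Illustrative usage sketch (not exercised by this file): a caller typically
// takes the defaults set above and overrides individual fields before device
// creation. The override values below are arbitrary examples.
//
//   iree_hal_cuda2_device_params_t params;
//   iree_hal_cuda2_device_params_initialize(&params);
//   params.arena_block_size = 64 * 1024;  // larger blocks for inline uploads
//   params.async_allocations = false;     // opt out of stream-ordered allocs
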
static iree_status_t iree_hal_cuda2_device_check_params(
    const iree_hal_cuda2_device_params_t* params) {
  if (params->arena_block_size < 4096) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "arena block size too small (< 4096 bytes)");
  }
  if (params->queue_count == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "at least one queue is required");
  }
  return iree_ok_status();
}

static iree_status_t iree_hal_cuda2_device_create_internal(
    iree_hal_driver_t* driver, iree_string_view_t identifier,
    const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
    CUstream stream, CUcontext context,
    const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
    const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  iree_hal_cuda2_device_t* device = NULL;
  iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, total_size, (void**)&device));
  memset(device, 0, total_size);

  iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable,
                               &device->resource);
  iree_string_view_append_to_buffer(
      identifier, &device->identifier,
      (char*)device + iree_sizeof_struct(*device));
  iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
                                   &device->block_pool);
  device->driver = driver;
  iree_hal_driver_retain(device->driver);
  device->cuda_symbols = cuda_symbols;
  device->nccl_symbols = nccl_symbols;
  device->params = *params;
  device->cu_context = context;
  device->cu_device = cu_device;
  device->cu_stream = stream;
  device->host_allocator = host_allocator;

  iree_status_t status = iree_ok_status();

  // Memory pool support is conditional.
  if (iree_status_is_ok(status) && params->async_allocations) {
    int supports_memory_pools = 0;
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols,
        cuDeviceGetAttribute(&supports_memory_pools,
                             CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
                             cu_device),
        "cuDeviceGetAttribute");
    device->supports_memory_pools = supports_memory_pools != 0;
  }

  // Create memory pools first so that we can share them with the allocator.
  if (iree_status_is_ok(status) && device->supports_memory_pools) {
    status = iree_hal_cuda2_memory_pools_initialize(
        cuda_symbols, cu_device, &params->memory_pools, host_allocator,
        &device->memory_pools);
  }

  if (iree_status_is_ok(status)) {
    status = iree_hal_cuda2_allocator_create(
        (iree_hal_device_t*)device, cuda_symbols, cu_device, stream,
        device->supports_memory_pools ? &device->memory_pools : NULL,
        host_allocator, &device->device_allocator);
  }

  if (iree_status_is_ok(status)) {
    *out_device = (iree_hal_device_t*)device;
  } else {
    iree_hal_device_release((iree_hal_device_t*)device);
  }
  return status;
}

iree_status_t iree_hal_cuda2_device_create(
    iree_hal_driver_t* driver, iree_string_view_t identifier,
    const iree_hal_cuda2_device_params_t* params,
    const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
    const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  IREE_ASSERT_ARGUMENT(driver);
  IREE_ASSERT_ARGUMENT(params);
  IREE_ASSERT_ARGUMENT(cuda_symbols);
  IREE_ASSERT_ARGUMENT(out_device);
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_status_t status = iree_hal_cuda2_device_check_params(params);

  // Get the main context for the device.
  CUcontext context = NULL;
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols, cuDevicePrimaryCtxRetain(&context, device));
  }
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
  }

  // Create the default stream for the device.
  CUstream stream = NULL;
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  }

  if (iree_status_is_ok(status)) {
    status = iree_hal_cuda2_device_create_internal(
        driver, identifier, params, device, stream, context, cuda_symbols,
        nccl_symbols, host_allocator, out_device);
  }
  if (!iree_status_is_ok(status)) {
    if (stream) cuda_symbols->cuStreamDestroy(stream);
    if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}

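// Illustrative creation sequence (a sketch, not part of this file): the CUDA
// driver normally calls iree_hal_cuda2_device_create() after loading the
// dynamic symbols and selecting a CUdevice; the driver/symbol/allocator names
// below are placeholders for whatever the caller already has in hand.
//
//   iree_hal_cuda2_device_params_t params;
//   iree_hal_cuda2_device_params_initialize(&params);
//   iree_hal_device_t* device = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_cuda2_device_create(
//       driver, IREE_SV("cuda2"), &params, cuda_symbols, nccl_symbols,
//       /*device=*/cu_device, host_allocator, &device));
//   // ... use the device ...
//   iree_hal_device_release(device);
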
CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device =
      iree_hal_cuda2_device_cast_unsafe(base_device);
  return device->cu_context;
}

const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols(
    iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device =
      iree_hal_cuda2_device_cast_unsafe(base_device);
  return device->cuda_symbols;
}

static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
  IREE_TRACE_ZONE_BEGIN(z0);

  // There should be no more buffers live that use the allocator.
  iree_hal_allocator_release(device->device_allocator);

  // Buffers may have been retaining collective resources.
  iree_hal_channel_provider_release(device->channel_provider);

  // Destroy memory pools that hold on to reserved memory.
  iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);

  // TODO: support multiple streams.
  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
                         cuStreamDestroy(device->cu_stream));

  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
                         cuDevicePrimaryCtxRelease(device->cu_device));

  iree_arena_block_pool_deinitialize(&device->block_pool);

  // Finally, destroy the device.
  iree_hal_driver_release(device->driver);

  iree_allocator_free(host_allocator, device);

  IREE_TRACE_ZONE_END(z0);
}

static iree_string_view_t iree_hal_cuda2_device_id(
    iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return device->identifier;
}

static iree_allocator_t iree_hal_cuda2_device_host_allocator(
    iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return device->host_allocator;
}

static iree_hal_allocator_t* iree_hal_cuda2_device_allocator(
    iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return device->device_allocator;
}

static void iree_hal_cuda2_replace_device_allocator(
    iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_hal_allocator_retain(new_allocator);
  iree_hal_allocator_release(device->device_allocator);
  device->device_allocator = new_allocator;
}

static void iree_hal_cuda2_replace_channel_provider(
    iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_hal_channel_provider_retain(new_provider);
  iree_hal_channel_provider_release(device->channel_provider);
  device->channel_provider = new_provider;
}

static iree_status_t iree_hal_cuda2_device_trim(
    iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_arena_block_pool_trim(&device->block_pool);
  IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator));
  if (device->supports_memory_pools) {
    IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim(
        &device->memory_pools, &device->params.memory_pools));
  }
  return iree_ok_status();
}

static iree_status_t iree_hal_cuda2_device_query_attribute(
    iree_hal_cuda2_device_t* device, CUdevice_attribute attribute,
    int64_t* out_value) {
  int value = 0;
  IREE_CUDA_RETURN_IF_ERROR(
      device->cuda_symbols,
      cuDeviceGetAttribute(&value, attribute, device->cu_device),
      "cuDeviceGetAttribute");
  *out_value = value;
  return iree_ok_status();
}

static iree_status_t iree_hal_cuda2_device_query_i64(
    iree_hal_device_t* base_device, iree_string_view_t category,
    iree_string_view_t key, int64_t* out_value) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  *out_value = 0;

  if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
    *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
    return iree_ok_status();
  }

  if (iree_string_view_equal(category, IREE_SV("cuda.device"))) {
    if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) {
      return iree_hal_cuda2_device_query_attribute(
          device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value);
    } else if (iree_string_view_equal(key,
                                      IREE_SV("compute_capability_minor"))) {
      return iree_hal_cuda2_device_query_attribute(
          device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value);
    }
  }

  return iree_make_status(
      IREE_STATUS_NOT_FOUND,
      "unknown device configuration key value '%.*s :: %.*s'",
      (int)category.size, category.data, (int)key.size, key.data);
}

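// Example of how a client might probe these keys through the public HAL API
// (a sketch; iree_hal_device_query_i64 routes to the vtable entry below):
//
//   int64_t sm_major = 0;
//   int64_t sm_minor = 0;
//   IREE_RETURN_IF_ERROR(iree_hal_device_query_i64(
//       device, IREE_SV("cuda.device"), IREE_SV("compute_capability_major"),
//       &sm_major));
//   IREE_RETURN_IF_ERROR(iree_hal_device_query_i64(
//       device, IREE_SV("cuda.device"), IREE_SV("compute_capability_minor"),
//       &sm_minor));
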
static iree_status_t iree_hal_cuda2_device_create_channel(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  if (!device->nccl_symbols || !device->nccl_symbols->dylib) {
    return iree_make_status(
        IREE_STATUS_UNAVAILABLE,
        "NCCL runtime library (%d.%d.%d) not available; ensure it is "
        "installed and that the shared library (nccl.dll/libnccl.so) is on "
        "your PATH/LD_LIBRARY_PATH",
        NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
  }

  // Today we only allow a single logical device per channel.
  // We could multiplex channels but it'd be better to surface that to the
  // compiler so that it can emit the right rank math.
  int requested_count = iree_math_count_ones_u64(queue_affinity);
  // TODO(#12206): properly assign affinity in the compiler.
  if (requested_count != 64 && requested_count != 1) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "exactly one participant is allowed in a "
                            "channel but %d were specified",
                            requested_count);
  }

  // Ask the channel provider (if configured) for the default rank and count
  // if the user did not set them.
  if (device->channel_provider &&
      (params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT ||
       params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) {
    IREE_RETURN_IF_ERROR(
        iree_hal_channel_provider_query_default_rank_and_count(
            device->channel_provider, &params.rank, &params.count),
        "querying default collective group rank and count");
  }

  // An ID is required to initialize NCCL. On the root it'll be the local ID
  // and on all other participants it'll be the root ID.
  iree_hal_cuda2_nccl_id_t id;
  memset(&id, 0, sizeof(id));
  if (iree_const_byte_span_is_empty(params.id)) {
    // User wants the default ID.
    if (!device->channel_provider) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "default collective channel ID requested but no channel provider "
          "has been set on the device to provide it");
    }
    if (params.rank == 0) {
      // Bootstrap NCCL to get the root ID.
      IREE_RETURN_IF_ERROR(
          iree_hal_cuda2_nccl_get_unique_id(device->nccl_symbols, &id),
          "bootstrapping NCCL root");
    }
    // Exchange NCCL ID with all participants.
    IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id(
                             device->channel_provider,
                             iree_make_byte_span((void*)&id, sizeof(id))),
                         "exchanging NCCL ID with other participants");
  } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) {
    // User provided something but it's not what we expect.
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "NCCL ID must be %" PRIhsz
        " bytes matching the ncclUniqueId struct but caller provided %" PRIhsz
        " bytes",
        IREE_ARRAYSIZE(id.data), params.id.data_length);
  } else {
    // User provided the ID - we treat it as opaque here and let NCCL validate.
    memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data));
  }

  if (iree_hal_cuda2_nccl_id_is_empty(&id)) {
    // TODO: maybe this is ok? a localhost alias or something?
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "no default NCCL ID specified (all zeros)");
  }

  // TODO: when we support multiple logical devices we'll want to pass in the
  // context of the device mapped to the queue_affinity. For now since this
  // implementation only supports one device we pass in the only one we have.
  return iree_hal_cuda2_nccl_channel_create(
      device->cuda_symbols, device->nccl_symbols, &id, params.rank,
      params.count, device->host_allocator, out_channel);
}

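// Collective channel creation sketch (assumptions: a channel provider has
// already been registered on the device, e.g. by a multi-process launcher,
// and default rank/count/ID resolution is desired; field and helper names
// come from the public HAL channel API):
//
//   iree_hal_channel_params_t channel_params = {
//       .flags = IREE_HAL_CHANNEL_FLAG_NONE,
//       .id = iree_const_byte_span_empty(),  // ask the provider for the ID
//       .group = iree_string_view_empty(),
//       .rank = IREE_HAL_CHANNEL_RANK_DEFAULT,
//       .count = IREE_HAL_CHANNEL_COUNT_DEFAULT,
//   };
//   iree_hal_channel_t* channel = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_device_create_channel(
//       device, IREE_HAL_QUEUE_AFFINITY_ANY, channel_params, &channel));
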
static iree_status_t iree_hal_cuda2_device_create_command_buffer(
    iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
    iree_hal_command_buffer_t** out_command_buffer) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "command buffer not yet implemented");
}

static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout(
    iree_hal_device_t* base_device,
    iree_hal_descriptor_set_layout_flags_t flags,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_layout_binding_t* bindings,
    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return iree_hal_cuda2_descriptor_set_layout_create(
      flags, binding_count, bindings, device->host_allocator,
      out_descriptor_set_layout);
}

static iree_status_t iree_hal_cuda2_device_create_event(
    iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "event not yet implemented");
}

static iree_status_t iree_hal_cuda2_device_create_executable_cache(
    iree_hal_device_t* base_device, iree_string_view_t identifier,
    iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return iree_hal_cuda2_nop_executable_cache_create(
      identifier, device->cuda_symbols, device->cu_device,
      device->host_allocator, out_executable_cache);
}

static iree_status_t iree_hal_cuda2_device_create_pipeline_layout(
    iree_hal_device_t* base_device, iree_host_size_t push_constants,
    iree_host_size_t set_layout_count,
    iree_hal_descriptor_set_layout_t* const* set_layouts,
    iree_hal_pipeline_layout_t** out_pipeline_layout) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  return iree_hal_cuda2_pipeline_layout_create(
      set_layout_count, set_layouts, push_constants, device->host_allocator,
      out_pipeline_layout);
}

static iree_status_t iree_hal_cuda2_device_create_semaphore(
    iree_hal_device_t* base_device, uint64_t initial_value,
    iree_hal_semaphore_t** out_semaphore) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "semaphore not yet implemented");
}

static iree_hal_semaphore_compatibility_t
iree_hal_cuda2_device_query_semaphore_compatibility(
    iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
  // TODO: implement CUDA semaphores.
  return IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
}

// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
static iree_status_t iree_hal_cuda2_device_queue_alloca(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
    iree_device_size_t allocation_size,
    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "queue alloca not yet implemented");
}

// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
static iree_status_t iree_hal_cuda2_device_queue_dealloca(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* buffer) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "queue dealloca not yet implemented");
}

static iree_status_t iree_hal_cuda2_device_queue_execute(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_host_size_t command_buffer_count,
    iree_hal_command_buffer_t* const* command_buffers) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "queue execution not yet implemented");
}

static iree_status_t iree_hal_cuda2_device_queue_flush(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
  // Currently unused; we flush as submissions are made.
  return iree_ok_status();
}

static iree_status_t iree_hal_cuda2_device_wait_semaphores(
    iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
    const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "semaphore not yet implemented");
}

static iree_status_t iree_hal_cuda2_device_profiling_begin(
    iree_hal_device_t* base_device,
    const iree_hal_device_profiling_options_t* options) {
  // Unimplemented (and that's ok).
  // We could hook in to CUPTI here or use the much simpler cuProfilerStart API.
  return iree_ok_status();
}

static iree_status_t iree_hal_cuda2_device_profiling_end(
    iree_hal_device_t* base_device) {
  // Unimplemented (and that's ok).
  return iree_ok_status();
}

static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = {
    .destroy = iree_hal_cuda2_device_destroy,
    .id = iree_hal_cuda2_device_id,
    .host_allocator = iree_hal_cuda2_device_host_allocator,
    .device_allocator = iree_hal_cuda2_device_allocator,
    .replace_device_allocator = iree_hal_cuda2_replace_device_allocator,
    .replace_channel_provider = iree_hal_cuda2_replace_channel_provider,
    .trim = iree_hal_cuda2_device_trim,
    .query_i64 = iree_hal_cuda2_device_query_i64,
    .create_channel = iree_hal_cuda2_device_create_channel,
    .create_command_buffer = iree_hal_cuda2_device_create_command_buffer,
    .create_descriptor_set_layout =
        iree_hal_cuda2_device_create_descriptor_set_layout,
    .create_event = iree_hal_cuda2_device_create_event,
    .create_executable_cache = iree_hal_cuda2_device_create_executable_cache,
    .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout,
    .create_semaphore = iree_hal_cuda2_device_create_semaphore,
    .query_semaphore_compatibility =
        iree_hal_cuda2_device_query_semaphore_compatibility,
    .transfer_range = iree_hal_device_submit_transfer_range_and_wait,
    .queue_alloca = iree_hal_cuda2_device_queue_alloca,
    .queue_dealloca = iree_hal_cuda2_device_queue_dealloca,
    .queue_execute = iree_hal_cuda2_device_queue_execute,
    .queue_flush = iree_hal_cuda2_device_queue_flush,
    .wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
    .profiling_begin = iree_hal_cuda2_device_profiling_begin,
    .profiling_end = iree_hal_cuda2_device_profiling_end,
};