blob: a5e878857e804353cbf3dac40669551fb969e223 [file] [log] [blame]
Lei Zhangc4e01e92023-06-09 17:38:05 -04001// Copyright 2023 The IREE Authors
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7#include "experimental/cuda2/cuda_device.h"
8
Lei Zhangc4e01e92023-06-09 17:38:05 -04009#include <stddef.h>
10#include <stdint.h>
11#include <string.h>
12
13#include "experimental/cuda2/cuda_allocator.h"
14#include "experimental/cuda2/cuda_buffer.h"
15#include "experimental/cuda2/cuda_dynamic_symbols.h"
16#include "experimental/cuda2/cuda_status_util.h"
Lei Zhang1bb26cb2023-07-24 19:28:27 -040017#include "experimental/cuda2/event_pool.h"
18#include "experimental/cuda2/event_semaphore.h"
Lei Zhang69a54812023-07-10 23:29:29 -040019#include "experimental/cuda2/graph_command_buffer.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040020#include "experimental/cuda2/memory_pools.h"
Lei Zhang330771e2023-06-13 19:39:23 -040021#include "experimental/cuda2/nccl_channel.h"
22#include "experimental/cuda2/nccl_dynamic_symbols.h"
Lei Zhang85a1a562023-06-13 19:51:42 -040023#include "experimental/cuda2/nop_executable_cache.h"
Lei Zhang1bb26cb2023-07-24 19:28:27 -040024#include "experimental/cuda2/pending_queue_actions.h"
Lei Zhang45a3eb42023-06-12 14:16:00 -040025#include "experimental/cuda2/pipeline_layout.h"
Lei Zhang1bb26cb2023-07-24 19:28:27 -040026#include "experimental/cuda2/timepoint_pool.h"
Lei Zhang85eb21b2023-06-13 19:55:42 -040027#include "experimental/cuda2/tracing.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040028#include "iree/base/internal/arena.h"
Lei Zhang1bb26cb2023-07-24 19:28:27 -040029#include "iree/base/internal/event_pool.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040030#include "iree/base/internal/math.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040031#include "iree/hal/utils/buffer_transfer.h"
32#include "iree/hal/utils/deferred_command_buffer.h"
Ben Vanikf022d292023-08-15 18:51:46 -070033#include "iree/hal/utils/file_transfer.h"
34#include "iree/hal/utils/memory_file.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040035
36//===----------------------------------------------------------------------===//
37// iree_hal_cuda2_device_t
38//===----------------------------------------------------------------------===//
39
// State for a single CUDA2 HAL device instance. Allocated as one block with
// the identifier string stored inline after the struct.
typedef struct iree_hal_cuda2_device_t {
  // Abstract resource used for injecting reference counting and vtable;
  // must be at offset 0.
  iree_hal_resource_t resource;
  // Device identifier; points into the trailing storage allocated together
  // with this struct (see iree_hal_cuda2_device_create_internal).
  iree_string_view_t identifier;

  // Block pool used for command buffers with a larger block size (as command
  // buffers can contain inlined data uploads).
  iree_arena_block_pool_t block_pool;

  // Optional driver that owns the CUDA symbols. We retain it for our lifetime
  // to ensure the symbols remain valid.
  iree_hal_driver_t* driver;

  // Dynamically-loaded CUDA driver API symbols; owned by the driver above.
  const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols;
  // Dynamically-loaded NCCL symbols for collectives; the dylib may be absent
  // when NCCL is unavailable (checked before channel creation).
  const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols;

  // Parameters used to control device behavior.
  iree_hal_cuda2_device_params_t params;

  // Primary context retained for cu_device; released on destroy.
  CUcontext cu_context;
  CUdevice cu_device;
  // TODO: Support multiple device streams.
  // The CUstream used to issue device kernels and allocations.
  CUstream dispatch_cu_stream;
  // The CUstream used to issue host callback functions.
  CUstream callback_cu_stream;

  // Optional stream tracing context; NULL when params.stream_tracing is off.
  iree_hal_cuda2_tracing_context_t* tracing_context;

  // Allocator used for this struct and host-side sub-resources.
  iree_allocator_t host_allocator;

  // Host/device event pools, used for backing semaphore timepoints.
  iree_event_pool_t* host_event_pool;
  iree_hal_cuda2_event_pool_t* device_event_pool;
  // Timepoint pools, shared by various semaphores.
  iree_hal_cuda2_timepoint_pool_t* timepoint_pool;

  // A queue to order device workloads and release to the GPU when constraints
  // are met. It buffers submissions and allocations internally before they
  // are ready. This queue couples with HAL semaphores backed by iree_event_t
  // and CUevent objects.
  iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;

  // Device memory pools and allocators.
  bool supports_memory_pools;
  iree_hal_cuda2_memory_pools_t memory_pools;
  iree_hal_allocator_t* device_allocator;

  // Optional provider used for creating/configuring collective channels.
  iree_hal_channel_provider_t* channel_provider;
} iree_hal_cuda2_device_t;
92
93static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable;
94
95static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast(
96 iree_hal_device_t* base_value) {
97 IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable);
98 return (iree_hal_cuda2_device_t*)base_value;
99}
100
101static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe(
102 iree_hal_device_t* base_value) {
103 return (iree_hal_cuda2_device_t*)base_value;
104}
105
106IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize(
107 iree_hal_cuda2_device_params_t* out_params) {
108 memset(out_params, 0, sizeof(*out_params));
109 out_params->arena_block_size = 32 * 1024;
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400110 out_params->event_pool_capacity = 32;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400111 out_params->queue_count = 1;
Lei Zhang85eb21b2023-06-13 19:55:42 -0400112 out_params->stream_tracing = false;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400113 out_params->async_allocations = true;
114}
115
116static iree_status_t iree_hal_cuda2_device_check_params(
117 const iree_hal_cuda2_device_params_t* params) {
118 if (params->arena_block_size < 4096) {
119 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
120 "arena block size too small (< 4096 bytes)");
121 }
122 if (params->queue_count == 0) {
123 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
124 "at least one queue is required");
125 }
126 return iree_ok_status();
127}
128
// Allocates and initializes the iree_hal_cuda2_device_t. On success the
// device takes ownership of |context| and the two streams. On failure the
// partially-initialized device is released, which tears down any
// sub-resources (and the streams/context) via the destroy path.
static iree_status_t iree_hal_cuda2_device_create_internal(
    iree_hal_driver_t* driver, iree_string_view_t identifier,
    const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
    CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
    const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
    const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  iree_hal_cuda2_device_t* device = NULL;
  // One allocation holds both the struct and the identifier string copy.
  iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, total_size, (void**)&device));
  memset(device, 0, total_size);

  iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable,
                               &device->resource);
  // Copy the identifier into the trailing storage after the struct.
  iree_string_view_append_to_buffer(
      identifier, &device->identifier,
      (char*)device + iree_sizeof_struct(*device));
  iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
                                   &device->block_pool);
  device->driver = driver;
  iree_hal_driver_retain(device->driver);
  device->cuda_symbols = cuda_symbols;
  device->nccl_symbols = nccl_symbols;
  device->params = *params;
  device->cu_context = context;
  device->cu_device = cu_device;
  device->dispatch_cu_stream = dispatch_stream;
  device->callback_cu_stream = callback_stream;
  device->host_allocator = host_allocator;

  iree_status_t status = iree_hal_cuda2_pending_queue_actions_create(
      cuda_symbols, &device->block_pool, host_allocator,
      &device->pending_queue_actions);

  // Enable tracing for the (currently only) stream - no-op if disabled.
  if (iree_status_is_ok(status) && device->params.stream_tracing) {
    status = iree_hal_cuda2_tracing_context_allocate(
        device->cuda_symbols, device->identifier, dispatch_stream,
        &device->block_pool, host_allocator, &device->tracing_context);
  }

  // Memory pool support is conditional.
  if (iree_status_is_ok(status) && params->async_allocations) {
    int supports_memory_pools = 0;
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols,
        cuDeviceGetAttribute(&supports_memory_pools,
                             CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
                             cu_device),
        "cuDeviceGetAttribute");
    device->supports_memory_pools = supports_memory_pools != 0;
  }

  // Create memory pools first so that we can share them with the allocator.
  if (iree_status_is_ok(status) && device->supports_memory_pools) {
    status = iree_hal_cuda2_memory_pools_initialize(
        cuda_symbols, cu_device, &params->memory_pools, host_allocator,
        &device->memory_pools);
  }

  if (iree_status_is_ok(status)) {
    status = iree_hal_cuda2_allocator_create(
        cuda_symbols, cu_device, dispatch_stream,
        device->supports_memory_pools ? &device->memory_pools : NULL,
        host_allocator, &device->device_allocator);
  }

  if (iree_status_is_ok(status)) {
    *out_device = (iree_hal_device_t*)device;
  } else {
    // Releasing drops the refcount to zero and runs the destroy path, which
    // cleans up everything initialized above.
    iree_hal_device_release((iree_hal_device_t*)device);
  }
  return status;
}
204
205iree_status_t iree_hal_cuda2_device_create(
206 iree_hal_driver_t* driver, iree_string_view_t identifier,
207 const iree_hal_cuda2_device_params_t* params,
Lei Zhang330771e2023-06-13 19:39:23 -0400208 const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
209 const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device,
Lei Zhangc4e01e92023-06-09 17:38:05 -0400210 iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
211 IREE_ASSERT_ARGUMENT(driver);
212 IREE_ASSERT_ARGUMENT(params);
Lei Zhang330771e2023-06-13 19:39:23 -0400213 IREE_ASSERT_ARGUMENT(cuda_symbols);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400214 IREE_ASSERT_ARGUMENT(out_device);
215 IREE_TRACE_ZONE_BEGIN(z0);
216
217 iree_status_t status = iree_hal_cuda2_device_check_params(params);
218
219 // Get the main context for the device.
220 CUcontext context = NULL;
221 if (iree_status_is_ok(status)) {
222 status = IREE_CURESULT_TO_STATUS(
Lei Zhang330771e2023-06-13 19:39:23 -0400223 cuda_symbols, cuDevicePrimaryCtxRetain(&context, device));
Lei Zhangc4e01e92023-06-09 17:38:05 -0400224 }
225 if (iree_status_is_ok(status)) {
Lei Zhang330771e2023-06-13 19:39:23 -0400226 status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
Lei Zhangc4e01e92023-06-09 17:38:05 -0400227 }
228
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400229 // Create the default dispatch stream for the device.
230 CUstream dispatch_stream = NULL;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400231 if (iree_status_is_ok(status)) {
232 status = IREE_CURESULT_TO_STATUS(
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400233 cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
234 }
235 // Create the default callback stream for the device.
236 CUstream callback_stream = NULL;
237 if (iree_status_is_ok(status)) {
238 status = IREE_CURESULT_TO_STATUS(
239 cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
240 }
241
Lei Zhang6e048162023-08-29 21:54:32 -0400242 if (iree_status_is_ok(status)) {
243 status = iree_hal_cuda2_device_create_internal(
244 driver, identifier, params, device, dispatch_stream, callback_stream,
245 context, cuda_symbols, nccl_symbols, host_allocator, out_device);
246 } else {
247 // Release resources we have accquired thus far.
248 if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
249 if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
250 if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
251 }
252
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400253 iree_event_pool_t* host_event_pool = NULL;
254 if (iree_status_is_ok(status)) {
255 status = iree_event_pool_allocate(params->event_pool_capacity,
256 host_allocator, &host_event_pool);
257 }
258
259 iree_hal_cuda2_event_pool_t* device_event_pool = NULL;
260 if (iree_status_is_ok(status)) {
261 status = iree_hal_cuda2_event_pool_allocate(
Lei Zhang6e048162023-08-29 21:54:32 -0400262 *out_device, cuda_symbols, params->event_pool_capacity, host_allocator,
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400263 &device_event_pool);
264 }
265
266 iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
267 if (iree_status_is_ok(status)) {
268 status = iree_hal_cuda2_timepoint_pool_allocate(
269 host_event_pool, device_event_pool, params->event_pool_capacity,
270 host_allocator, &timepoint_pool);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400271 }
272
273 if (iree_status_is_ok(status)) {
Lei Zhang6e048162023-08-29 21:54:32 -0400274 iree_hal_cuda2_device_t* cuda_device =
275 iree_hal_cuda2_device_cast(*out_device);
276 cuda_device->host_event_pool = host_event_pool;
277 cuda_device->device_event_pool = device_event_pool;
278 cuda_device->timepoint_pool = timepoint_pool;
279 } else {
280 // Release resources we have accquired after HAL device creation.
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400281 if (timepoint_pool) iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
Lei Zhang6e048162023-08-29 21:54:32 -0400282 if (device_event_pool) iree_hal_cuda2_event_pool_release(device_event_pool);
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400283 if (host_event_pool) iree_event_pool_free(host_event_pool);
Lei Zhang6e048162023-08-29 21:54:32 -0400284 // Release other resources via the HAL device.
285 iree_hal_device_release(*out_device);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400286 }
287
288 IREE_TRACE_ZONE_END(z0);
289 return status;
290}
291
292CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) {
293 iree_hal_cuda2_device_t* device =
294 iree_hal_cuda2_device_cast_unsafe(base_device);
295 return device->cu_context;
296}
297
298const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols(
299 iree_hal_device_t* base_device) {
300 iree_hal_cuda2_device_t* device =
301 iree_hal_cuda2_device_cast_unsafe(base_device);
302 return device->cuda_symbols;
303}
304
// Tears down the device in reverse dependency order: pending queue actions
// first, then allocators/pools, tracing, synchronization pools, CUDA streams,
// the primary context, and finally host-side bookkeeping and the struct
// itself. Called when the device refcount reaches zero.
static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
  // Keep the symbols reachable after the struct is freed below.
  const iree_hal_cuda2_dynamic_symbols_t* symbols = device->cuda_symbols;
  IREE_TRACE_ZONE_BEGIN(z0);

  // Destroy the pending workload queue.
  iree_hal_cuda2_pending_queue_actions_destroy(
      (iree_hal_resource_t*)device->pending_queue_actions);

  // There should be no more buffers live that use the allocator.
  iree_hal_allocator_release(device->device_allocator);

  // Buffers may have been retaining collective resources.
  iree_hal_channel_provider_release(device->channel_provider);

  // Destroy memory pools that hold on to reserved memory.
  iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);

  iree_hal_cuda2_tracing_context_free(device->tracing_context);

  // Destroy various pools for synchronization. These may be NULL if device
  // creation failed partway through (see iree_hal_cuda2_device_create).
  if (device->timepoint_pool) {
    iree_hal_cuda2_timepoint_pool_free(device->timepoint_pool);
  }
  if (device->device_event_pool) {
    iree_hal_cuda2_event_pool_release(device->device_event_pool);
  }
  if (device->host_event_pool) iree_event_pool_free(device->host_event_pool);

  // Errors are ignored: there is no way to report failure from destroy.
  IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
  IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));

  IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));

  iree_arena_block_pool_deinitialize(&device->block_pool);

  // Finally, destroy the device.
  iree_hal_driver_release(device->driver);

  iree_allocator_free(host_allocator, device);

  IREE_TRACE_ZONE_END(z0);
}
349
350static iree_string_view_t iree_hal_cuda2_device_id(
351 iree_hal_device_t* base_device) {
352 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
353 return device->identifier;
354}
355
356static iree_allocator_t iree_hal_cuda2_device_host_allocator(
357 iree_hal_device_t* base_device) {
358 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
359 return device->host_allocator;
360}
361
362static iree_hal_allocator_t* iree_hal_cuda2_device_allocator(
363 iree_hal_device_t* base_device) {
364 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
365 return device->device_allocator;
366}
367
368static void iree_hal_cuda2_replace_device_allocator(
369 iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) {
370 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
371 iree_hal_allocator_retain(new_allocator);
372 iree_hal_allocator_release(device->device_allocator);
373 device->device_allocator = new_allocator;
374}
375
376static void iree_hal_cuda2_replace_channel_provider(
377 iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) {
Lei Zhang330771e2023-06-13 19:39:23 -0400378 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
379 iree_hal_channel_provider_retain(new_provider);
380 iree_hal_channel_provider_release(device->channel_provider);
381 device->channel_provider = new_provider;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400382}
383
384static iree_status_t iree_hal_cuda2_device_trim(
385 iree_hal_device_t* base_device) {
386 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
387 iree_arena_block_pool_trim(&device->block_pool);
388 IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator));
389 if (device->supports_memory_pools) {
390 IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim(
391 &device->memory_pools, &device->params.memory_pools));
392 }
393 return iree_ok_status();
394}
395
396static iree_status_t iree_hal_cuda2_device_query_attribute(
397 iree_hal_cuda2_device_t* device, CUdevice_attribute attribute,
398 int64_t* out_value) {
399 int value = 0;
400 IREE_CUDA_RETURN_IF_ERROR(
401 device->cuda_symbols,
402 cuDeviceGetAttribute(&value, attribute, device->cu_device),
403 "cuDeviceGetAttribute");
404 *out_value = value;
405 return iree_ok_status();
406}
407
408static iree_status_t iree_hal_cuda2_device_query_i64(
409 iree_hal_device_t* base_device, iree_string_view_t category,
410 iree_string_view_t key, int64_t* out_value) {
411 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
412 *out_value = 0;
413
414 if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
415 *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
416 return iree_ok_status();
417 }
418
419 if (iree_string_view_equal(category, IREE_SV("cuda.device"))) {
420 if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) {
421 return iree_hal_cuda2_device_query_attribute(
422 device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value);
423 } else if (iree_string_view_equal(key,
424 IREE_SV("compute_capability_minor"))) {
425 return iree_hal_cuda2_device_query_attribute(
426 device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value);
427 }
428 }
429
430 return iree_make_status(
431 IREE_STATUS_NOT_FOUND,
432 "unknown device configuration key value '%.*s :: %.*s'",
433 (int)category.size, category.data, (int)key.size, key.data);
434}
435
436static iree_status_t iree_hal_cuda2_device_create_channel(
437 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
438 iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
Lei Zhang330771e2023-06-13 19:39:23 -0400439 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
440 if (!device->nccl_symbols || !device->nccl_symbols->dylib) {
441 return iree_make_status(
442 IREE_STATUS_UNAVAILABLE,
443 "NCCL runtime library (%d.%d.%d) not available; ensure installed and "
444 "the shared library is on your PATH/LD_LIBRARY_PATH "
445 "(nccl.dll/libnccl.so)",
446 NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
447 }
448
449 // Today we only allow a single logical device per channel.
450 // We could multiplex channels but it'd be better to surface that to the
451 // compiler so that it can emit the right rank math.
452 int requested_count = iree_math_count_ones_u64(queue_affinity);
453 // TODO(#12206): properly assign affinity in the compiler.
454 if (requested_count != 64 && requested_count != 1) {
455 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
456 "exactly one participant is allowed in a "
457 "channel but %d were specified",
458 requested_count);
459 }
460
461 // Ask the channel provider (if configured) for the default rank and count
462 // if the user did not set them.
463 if (device->channel_provider &&
464 (params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT ||
465 params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) {
466 IREE_RETURN_IF_ERROR(
467 iree_hal_channel_provider_query_default_rank_and_count(
468 device->channel_provider, &params.rank, &params.count),
469 "querying default collective group rank and count");
470 }
471
472 // An ID is required to initialize NCCL. On the root it'll be the local ID and
473 // on all other participants it'll be the root ID.
474 iree_hal_cuda2_nccl_id_t id;
475 memset(&id, 0, sizeof(id));
476 if (iree_const_byte_span_is_empty(params.id)) {
477 // User wants the default ID.
478 if (!device->channel_provider) {
479 return iree_make_status(
480 IREE_STATUS_INVALID_ARGUMENT,
481 "default collective channel ID requested but no channel provider has "
482 "been set on the device to provide it");
483 }
484 if (params.rank == 0) {
485 // Bootstrap NCCL to get the root ID.
486 IREE_RETURN_IF_ERROR(
487 iree_hal_cuda2_nccl_get_unique_id(device->nccl_symbols, &id),
488 "bootstrapping NCCL root");
489 }
490 // Exchange NCCL ID with all participants.
491 IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id(
492 device->channel_provider,
493 iree_make_byte_span((void*)&id, sizeof(id))),
494 "exchanging NCCL ID with other participants");
495 } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) {
496 // User provided something but it's not what we expect.
Scott Todd60b07642023-06-15 09:41:01 -0700497 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
498 "NCCL ID must be %zu bytes matching the "
499 "ncclUniqueId struct but caller provided %zu bytes",
500 IREE_ARRAYSIZE(id.data), sizeof(id));
Lei Zhang330771e2023-06-13 19:39:23 -0400501 } else {
502 // User provided the ID - we treat it as opaque here and let NCCL validate.
503 memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data));
504 }
505
506 if (iree_hal_cuda2_nccl_id_is_empty(&id)) {
507 // TODO: maybe this is ok? a localhost alias or something?
508 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
509 "no default NCCL ID specified (all zeros)");
510 }
511
512 // TODO: when we support multiple logical devices we'll want to pass in the
513 // context of the device mapped to the queue_affinity. For now since this
514 // implementation only supports one device we pass in the only one we have.
515 return iree_hal_cuda2_nccl_channel_create(
516 device->cuda_symbols, device->nccl_symbols, &id, params.rank,
517 params.count, device->host_allocator, out_channel);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400518}
519
520static iree_status_t iree_hal_cuda2_device_create_command_buffer(
521 iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
522 iree_hal_command_category_t command_categories,
523 iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
524 iree_hal_command_buffer_t** out_command_buffer) {
Lei Zhang69a54812023-07-10 23:29:29 -0400525 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
526 return iree_hal_cuda2_graph_command_buffer_create(
527 base_device, device->cuda_symbols, device->cu_context, mode,
528 command_categories, queue_affinity, binding_capacity, &device->block_pool,
529 device->host_allocator, out_command_buffer);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400530}
531
532static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout(
533 iree_hal_device_t* base_device,
534 iree_hal_descriptor_set_layout_flags_t flags,
535 iree_host_size_t binding_count,
536 const iree_hal_descriptor_set_layout_binding_t* bindings,
537 iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
Lei Zhang45a3eb42023-06-12 14:16:00 -0400538 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
539 return iree_hal_cuda2_descriptor_set_layout_create(
540 flags, binding_count, bindings, device->host_allocator,
541 out_descriptor_set_layout);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400542}
543
544static iree_status_t iree_hal_cuda2_device_create_event(
545 iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
546 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
547 "event not yet implmeneted");
548}
549
550static iree_status_t iree_hal_cuda2_device_create_executable_cache(
551 iree_hal_device_t* base_device, iree_string_view_t identifier,
552 iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) {
Lei Zhang85a1a562023-06-13 19:51:42 -0400553 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
554 return iree_hal_cuda2_nop_executable_cache_create(
555 identifier, device->cuda_symbols, device->cu_device,
556 device->host_allocator, out_executable_cache);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400557}
558
Ben Vanikf022d292023-08-15 18:51:46 -0700559static iree_status_t iree_hal_cuda2_device_import_file(
560 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
Ben Vanik3add4572023-10-05 09:47:10 -0700561 iree_hal_memory_access_t access, iree_io_file_handle_t* handle,
562 iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) {
563 if (iree_io_file_handle_type(handle) !=
564 IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
Ben Vanikf022d292023-08-15 18:51:46 -0700565 return iree_make_status(
566 IREE_STATUS_UNAVAILABLE,
567 "implementation does not support the external file type");
568 }
569 return iree_hal_memory_file_wrap(
Ben Vanik3add4572023-10-05 09:47:10 -0700570 queue_affinity, access, handle, iree_hal_device_allocator(base_device),
Ben Vanikf022d292023-08-15 18:51:46 -0700571 iree_hal_device_host_allocator(base_device), out_file);
572}
573
Lei Zhangc4e01e92023-06-09 17:38:05 -0400574static iree_status_t iree_hal_cuda2_device_create_pipeline_layout(
575 iree_hal_device_t* base_device, iree_host_size_t push_constants,
576 iree_host_size_t set_layout_count,
577 iree_hal_descriptor_set_layout_t* const* set_layouts,
578 iree_hal_pipeline_layout_t** out_pipeline_layout) {
Lei Zhang45a3eb42023-06-12 14:16:00 -0400579 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
580 return iree_hal_cuda2_pipeline_layout_create(
581 set_layout_count, set_layouts, push_constants, device->host_allocator,
582 out_pipeline_layout);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400583}
584
585static iree_status_t iree_hal_cuda2_device_create_semaphore(
586 iree_hal_device_t* base_device, uint64_t initial_value,
587 iree_hal_semaphore_t** out_semaphore) {
Lei Zhang04beef22023-07-10 17:07:36 -0400588 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400589 return iree_hal_cuda2_event_semaphore_create(
590 initial_value, device->cuda_symbols, device->timepoint_pool,
591 device->pending_queue_actions, device->host_allocator, out_semaphore);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400592}
593
594static iree_hal_semaphore_compatibility_t
595iree_hal_cuda2_device_query_semaphore_compatibility(
596 iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
597 // TODO: implement CUDA semaphores.
Lei Zhang04beef22023-07-10 17:07:36 -0400598 return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400599}
600
// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
// Queue-ordered allocation: blocks on the wait semaphores, allocates (from
// the async memory pools when supported and the buffer is not host-visible,
// otherwise from the device allocator), then signals completion.
static iree_status_t iree_hal_cuda2_device_queue_alloca(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
    iree_device_size_t allocation_size,
    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);

  // NOTE: block on the semaphores here; we could avoid this by properly
  // sequencing device work with semaphores. The CUDA HAL is not currently
  // asynchronous.
  IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
                                                    iree_infinite_timeout()));

  // Allocate from the pool; likely to fail in cases of virtual memory
  // exhaustion but the error may be deferred until a later synchronization.
  // If pools are not supported we allocate a buffer as normal from whatever
  // allocator is set on the device.
  iree_status_t status = iree_ok_status();
  if (device->supports_memory_pools &&
      !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
    status = iree_hal_cuda2_memory_pools_alloca(
        &device->memory_pools, device->dispatch_cu_stream, pool, params,
        allocation_size, out_buffer);
  } else {
    status = iree_hal_allocator_allocate_buffer(
        iree_hal_device_allocator(base_device), params, allocation_size,
        out_buffer);
  }

  // Only signal if not returning a synchronous error - synchronous failure
  // indicates that the stream is unchanged (it's not really since we waited
  // above, but we at least won't deadlock like this).
  if (iree_status_is_ok(status)) {
    status = iree_hal_semaphore_list_signal(signal_semaphore_list);
  }
  return status;
}
644
// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
// Queue-ordered deallocation: blocks on the wait semaphores, returns the
// buffer to the async memory pools when supported (otherwise a no-op — the
// memory is freed when the buffer is released), then signals completion.
static iree_status_t iree_hal_cuda2_device_queue_dealloca(
    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* buffer) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);

  // NOTE: block on the semaphores here; we could avoid this by properly
  // sequencing device work with semaphores. The CUDA HAL is not currently
  // asynchronous.
  IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
                                                    iree_infinite_timeout()));

  // Schedule the buffer deallocation if we got it from a pool and otherwise
  // drop it on the floor and let it be freed when the buffer is released.
  iree_status_t status = iree_ok_status();
  if (device->supports_memory_pools) {
    status = iree_hal_cuda2_memory_pools_dealloca(
        &device->memory_pools, device->dispatch_cu_stream, buffer);
  }

  // Only signal if not returning a synchronous error - synchronous failure
  // indicates that the stream is unchanged (it's not really since we waited
  // above, but we at least won't deadlock like this).
  if (iree_status_is_ok(status)) {
    status = iree_hal_semaphore_list_signal(signal_semaphore_list);
  }
  return status;
}
678
Ben Vanikf022d292023-08-15 18:51:46 -0700679static iree_status_t iree_hal_cuda2_device_queue_read(
680 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
681 const iree_hal_semaphore_list_t wait_semaphore_list,
682 const iree_hal_semaphore_list_t signal_semaphore_list,
683 iree_hal_file_t* source_file, uint64_t source_offset,
684 iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
685 iree_device_size_t length, uint32_t flags) {
686 // TODO: expose streaming chunk count/size options.
687 iree_status_t loop_status = iree_ok_status();
688 iree_hal_file_transfer_options_t options = {
689 .loop = iree_loop_inline(&loop_status),
690 .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
691 .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
692 };
693 IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming(
694 base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
695 source_file, source_offset, target_buffer, target_offset, length, flags,
696 options));
697 return loop_status;
698}
699
700static iree_status_t iree_hal_cuda2_device_queue_write(
701 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
702 const iree_hal_semaphore_list_t wait_semaphore_list,
703 const iree_hal_semaphore_list_t signal_semaphore_list,
704 iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
705 iree_hal_file_t* target_file, uint64_t target_offset,
706 iree_device_size_t length, uint32_t flags) {
707 // TODO: expose streaming chunk count/size options.
708 iree_status_t loop_status = iree_ok_status();
709 iree_hal_file_transfer_options_t options = {
710 .loop = iree_loop_inline(&loop_status),
711 .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
712 .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
713 };
714 IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming(
715 base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
716 source_buffer, source_offset, target_file, target_offset, length, flags,
717 options));
718 return loop_status;
719}
720
Lei Zhangc4e01e92023-06-09 17:38:05 -0400721static iree_status_t iree_hal_cuda2_device_queue_execute(
722 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
723 const iree_hal_semaphore_list_t wait_semaphore_list,
724 const iree_hal_semaphore_list_t signal_semaphore_list,
725 iree_host_size_t command_buffer_count,
726 iree_hal_command_buffer_t* const* command_buffers) {
Lei Zhang69a54812023-07-10 23:29:29 -0400727 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400728 IREE_TRACE_ZONE_BEGIN(z0);
Lei Zhang69a54812023-07-10 23:29:29 -0400729
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400730 iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
731 device->dispatch_cu_stream, device->callback_cu_stream,
732 device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
733 command_buffer_count, command_buffers);
734 if (iree_status_is_ok(status)) {
735 // Try to advance the pending workload queue.
736 status = iree_hal_cuda2_pending_queue_actions_issue(
737 device->pending_queue_actions);
Lei Zhang69a54812023-07-10 23:29:29 -0400738 }
739
Lei Zhang69a54812023-07-10 23:29:29 -0400740 iree_hal_cuda2_tracing_context_collect(device->tracing_context);
741 IREE_TRACE_ZONE_END(z0);
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400742 return status;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400743}
744
745static iree_status_t iree_hal_cuda2_device_queue_flush(
746 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400747 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
748 IREE_TRACE_ZONE_BEGIN(z0);
749 // Try to advance the pending workload queue.
750 iree_status_t status =
751 iree_hal_cuda2_pending_queue_actions_issue(device->pending_queue_actions);
752 IREE_TRACE_ZONE_END(z0);
753 return status;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400754}
755
756static iree_status_t iree_hal_cuda2_device_wait_semaphores(
757 iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
758 const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
759 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
Lei Zhang1bb26cb2023-07-24 19:28:27 -0400760 "waiting multiple semaphores not yet implemented");
Lei Zhangc4e01e92023-06-09 17:38:05 -0400761}
762
763static iree_status_t iree_hal_cuda2_device_profiling_begin(
764 iree_hal_device_t* base_device,
765 const iree_hal_device_profiling_options_t* options) {
766 // Unimplemented (and that's ok).
767 // We could hook in to CUPTI here or use the much simpler cuProfilerStart API.
768 return iree_ok_status();
769}
770
Ben Vanik82be9252023-08-25 11:12:18 -0700771static iree_status_t iree_hal_cuda2_device_profiling_flush(
772 iree_hal_device_t* base_device) {
773 // Unimplemented (and that's ok).
774 return iree_ok_status();
775}
776
Lei Zhangc4e01e92023-06-09 17:38:05 -0400777static iree_status_t iree_hal_cuda2_device_profiling_end(
778 iree_hal_device_t* base_device) {
779 // Unimplemented (and that's ok).
780 return iree_ok_status();
781}
782
// HAL device vtable binding the CUDA2 implementations above (and earlier in
// this file) to the iree_hal_device_t interface.
static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = {
    // Lifetime and identity.
    .destroy = iree_hal_cuda2_device_destroy,
    .id = iree_hal_cuda2_device_id,
    // Allocators and collective channel providers.
    .host_allocator = iree_hal_cuda2_device_host_allocator,
    .device_allocator = iree_hal_cuda2_device_allocator,
    .replace_device_allocator = iree_hal_cuda2_replace_device_allocator,
    .replace_channel_provider = iree_hal_cuda2_replace_channel_provider,
    .trim = iree_hal_cuda2_device_trim,
    .query_i64 = iree_hal_cuda2_device_query_i64,
    // Resource creation.
    .create_channel = iree_hal_cuda2_device_create_channel,
    .create_command_buffer = iree_hal_cuda2_device_create_command_buffer,
    .create_descriptor_set_layout =
        iree_hal_cuda2_device_create_descriptor_set_layout,
    .create_event = iree_hal_cuda2_device_create_event,
    .create_executable_cache = iree_hal_cuda2_device_create_executable_cache,
    .import_file = iree_hal_cuda2_device_import_file,
    .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout,
    .create_semaphore = iree_hal_cuda2_device_create_semaphore,
    .query_semaphore_compatibility =
        iree_hal_cuda2_device_query_semaphore_compatibility,
    // Transfers use the generic submit-and-wait utility.
    .transfer_range = iree_hal_device_submit_transfer_range_and_wait,
    // Queue operations (see implementations above; currently one stream and
    // queue_affinity is ignored).
    .queue_alloca = iree_hal_cuda2_device_queue_alloca,
    .queue_dealloca = iree_hal_cuda2_device_queue_dealloca,
    .queue_read = iree_hal_cuda2_device_queue_read,
    .queue_write = iree_hal_cuda2_device_queue_write,
    .queue_execute = iree_hal_cuda2_device_queue_execute,
    .queue_flush = iree_hal_cuda2_device_queue_flush,
    // Synchronization and profiling (wait_semaphores is unimplemented;
    // profiling entry points are intentional no-ops).
    .wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
    .profiling_begin = iree_hal_cuda2_device_profiling_begin,
    .profiling_flush = iree_hal_cuda2_device_profiling_flush,
    .profiling_end = iree_hal_cuda2_device_profiling_end,
};