blob: 36bf8ab8fdf65f1f4c2454dc4e0f43b6f05a516a [file] [log] [blame]
Lei Zhangc4e01e92023-06-09 17:38:05 -04001// Copyright 2023 The IREE Authors
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7#include "experimental/cuda2/cuda_device.h"
8
Lei Zhangc4e01e92023-06-09 17:38:05 -04009#include <stddef.h>
10#include <stdint.h>
11#include <string.h>
12
13#include "experimental/cuda2/cuda_allocator.h"
14#include "experimental/cuda2/cuda_buffer.h"
15#include "experimental/cuda2/cuda_dynamic_symbols.h"
16#include "experimental/cuda2/cuda_status_util.h"
17#include "experimental/cuda2/memory_pools.h"
Lei Zhang330771e2023-06-13 19:39:23 -040018#include "experimental/cuda2/nccl_channel.h"
19#include "experimental/cuda2/nccl_dynamic_symbols.h"
Lei Zhang85a1a562023-06-13 19:51:42 -040020#include "experimental/cuda2/nop_executable_cache.h"
Lei Zhang45a3eb42023-06-12 14:16:00 -040021#include "experimental/cuda2/pipeline_layout.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040022#include "iree/base/internal/arena.h"
23#include "iree/base/internal/math.h"
Lei Zhangc4e01e92023-06-09 17:38:05 -040024#include "iree/hal/utils/buffer_transfer.h"
25#include "iree/hal/utils/deferred_command_buffer.h"
26
27//===----------------------------------------------------------------------===//
28// iree_hal_cuda2_device_t
29//===----------------------------------------------------------------------===//
30
// State for a single CUDA device exposed through the HAL.
// Owns (or retains) the CUDA context/stream, symbol tables, and allocators
// used by all resources created from this device.
typedef struct iree_hal_cuda2_device_t {
  // Abstract resource used for injecting reference counting and vtable;
  // must be at offset 0.
  iree_hal_resource_t resource;
  // Device identifier string; storage lives in the trailing bytes of this
  // same allocation (see iree_hal_cuda2_device_create_internal).
  iree_string_view_t identifier;

  // Block pool used for command buffers with a larger block size (as command
  // buffers can contain inlined data uploads).
  iree_arena_block_pool_t block_pool;

  // Optional driver that owns the CUDA symbols. We retain it for our lifetime
  // to ensure the symbols remain valid.
  iree_hal_driver_t* driver;

  // Unowned symbol tables resolved by the driver; valid while |driver| is
  // retained.
  const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols;
  const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols;

  // Parameters used to control device behavior.
  iree_hal_cuda2_device_params_t params;

  CUcontext cu_context;
  CUdevice cu_device;
  // TODO: support multiple streams.
  CUstream cu_stream;

  iree_allocator_t host_allocator;

  // Device memory pools and allocators.
  // |supports_memory_pools| is set from CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
  // when async allocations are requested; pools are only initialized if true.
  bool supports_memory_pools;
  iree_hal_cuda2_memory_pools_t memory_pools;
  iree_hal_allocator_t* device_allocator;

  // Optional provider used for creating/configuring collective channels.
  iree_hal_channel_provider_t* channel_provider;
} iree_hal_cuda2_device_t;
66
67static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable;
68
69static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast(
70 iree_hal_device_t* base_value) {
71 IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable);
72 return (iree_hal_cuda2_device_t*)base_value;
73}
74
75static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe(
76 iree_hal_device_t* base_value) {
77 return (iree_hal_cuda2_device_t*)base_value;
78}
79
80IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize(
81 iree_hal_cuda2_device_params_t* out_params) {
82 memset(out_params, 0, sizeof(*out_params));
83 out_params->arena_block_size = 32 * 1024;
84 out_params->queue_count = 1;
85 out_params->async_allocations = true;
86}
87
88static iree_status_t iree_hal_cuda2_device_check_params(
89 const iree_hal_cuda2_device_params_t* params) {
90 if (params->arena_block_size < 4096) {
91 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
92 "arena block size too small (< 4096 bytes)");
93 }
94 if (params->queue_count == 0) {
95 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
96 "at least one queue is required");
97 }
98 return iree_ok_status();
99}
100
// Allocates and initializes the device struct, taking ownership of |stream|
// and the retained primary |context| on success. On failure the partially
// constructed device is released and the caller is responsible for cleaning
// up |stream|/|context| (see iree_hal_cuda2_device_create).
static iree_status_t iree_hal_cuda2_device_create_internal(
    iree_hal_driver_t* driver, iree_string_view_t identifier,
    const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
    CUstream stream, CUcontext context,
    const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
    const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  iree_hal_cuda2_device_t* device = NULL;
  // Single allocation for the struct plus trailing identifier storage.
  iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, total_size, (void**)&device));
  memset(device, 0, total_size);

  iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable,
                               &device->resource);
  // Copy the identifier into the trailing bytes of the allocation.
  iree_string_view_append_to_buffer(
      identifier, &device->identifier,
      (char*)device + iree_sizeof_struct(*device));
  iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
                                   &device->block_pool);
  // Retain the driver so the dynamic symbol tables outlive us.
  device->driver = driver;
  iree_hal_driver_retain(device->driver);
  device->cuda_symbols = cuda_symbols;
  device->nccl_symbols = nccl_symbols;
  device->params = *params;
  device->cu_context = context;
  device->cu_device = cu_device;
  device->cu_stream = stream;
  device->host_allocator = host_allocator;

  iree_status_t status = iree_ok_status();

  // Memory pool support is conditional.
  if (iree_status_is_ok(status) && params->async_allocations) {
    int supports_memory_pools = 0;
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols,
        cuDeviceGetAttribute(&supports_memory_pools,
                             CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
                             cu_device),
        "cuDeviceGetAttribute");
    device->supports_memory_pools = supports_memory_pools != 0;
  }

  // Create memory pools first so that we can share them with the allocator.
  if (iree_status_is_ok(status) && device->supports_memory_pools) {
    status = iree_hal_cuda2_memory_pools_initialize(
        cuda_symbols, cu_device, &params->memory_pools, host_allocator,
        &device->memory_pools);
  }

  if (iree_status_is_ok(status)) {
    status = iree_hal_cuda2_allocator_create(
        (iree_hal_device_t*)device, cuda_symbols, cu_device, stream,
        device->supports_memory_pools ? &device->memory_pools : NULL,
        host_allocator, &device->device_allocator);
  }

  if (iree_status_is_ok(status)) {
    *out_device = (iree_hal_device_t*)device;
  } else {
    // Releasing runs the full destroy path, tearing down anything that was
    // initialized above.
    iree_hal_device_release((iree_hal_device_t*)device);
  }
  return status;
}
166
// Creates a CUDA2 HAL device wrapping |device|.
// Retains the device's primary context and creates the default stream; both
// are handed off to the device struct on success and cleaned up here on
// failure.
iree_status_t iree_hal_cuda2_device_create(
    iree_hal_driver_t* driver, iree_string_view_t identifier,
    const iree_hal_cuda2_device_params_t* params,
    const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
    const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  IREE_ASSERT_ARGUMENT(driver);
  IREE_ASSERT_ARGUMENT(params);
  IREE_ASSERT_ARGUMENT(cuda_symbols);
  IREE_ASSERT_ARGUMENT(out_device);
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_status_t status = iree_hal_cuda2_device_check_params(params);

  // Get the main context for the device.
  CUcontext context = NULL;
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols, cuDevicePrimaryCtxRetain(&context, device));
  }
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
  }

  // Create the default stream for the device.
  CUstream stream = NULL;
  if (iree_status_is_ok(status)) {
    status = IREE_CURESULT_TO_STATUS(
        cuda_symbols, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  }

  if (iree_status_is_ok(status)) {
    // Ownership of |stream| and the retained primary context transfers to the
    // device on success.
    status = iree_hal_cuda2_device_create_internal(
        driver, identifier, params, device, stream, context, cuda_symbols,
        nccl_symbols, host_allocator, out_device);
  }
  if (!iree_status_is_ok(status)) {
    // Undo whichever of the acquisitions above succeeded; note that primary
    // context release takes the CUdevice, not the CUcontext.
    if (stream) cuda_symbols->cuStreamDestroy(stream);
    if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
211
212CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) {
213 iree_hal_cuda2_device_t* device =
214 iree_hal_cuda2_device_cast_unsafe(base_device);
215 return device->cu_context;
216}
217
218const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols(
219 iree_hal_device_t* base_device) {
220 iree_hal_cuda2_device_t* device =
221 iree_hal_cuda2_device_cast_unsafe(base_device);
222 return device->cuda_symbols;
223}
224
// Tears down the device in reverse-dependency order: allocators and pools
// before CUDA stream/context, block pool before the driver that owns the
// symbols used above.
static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
  iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
  iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
  IREE_TRACE_ZONE_BEGIN(z0);

  // There should be no more buffers live that use the allocator.
  iree_hal_allocator_release(device->device_allocator);

  // Buffers may have been retaining collective resources.
  iree_hal_channel_provider_release(device->channel_provider);

  // Destroy memory pools that hold on to reserved memory.
  iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);

  // TODO: support multiple streams.
  // Errors are ignored during teardown; there is no way to recover here.
  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
                         cuStreamDestroy(device->cu_stream));

  // Release the primary context retained in iree_hal_cuda2_device_create.
  IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
                         cuDevicePrimaryCtxRelease(device->cu_device));

  iree_arena_block_pool_deinitialize(&device->block_pool);

  // Finally, destroy the device.
  iree_hal_driver_release(device->driver);

  iree_allocator_free(host_allocator, device);

  IREE_TRACE_ZONE_END(z0);
}
255
256static iree_string_view_t iree_hal_cuda2_device_id(
257 iree_hal_device_t* base_device) {
258 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
259 return device->identifier;
260}
261
262static iree_allocator_t iree_hal_cuda2_device_host_allocator(
263 iree_hal_device_t* base_device) {
264 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
265 return device->host_allocator;
266}
267
268static iree_hal_allocator_t* iree_hal_cuda2_device_allocator(
269 iree_hal_device_t* base_device) {
270 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
271 return device->device_allocator;
272}
273
274static void iree_hal_cuda2_replace_device_allocator(
275 iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) {
276 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
277 iree_hal_allocator_retain(new_allocator);
278 iree_hal_allocator_release(device->device_allocator);
279 device->device_allocator = new_allocator;
280}
281
282static void iree_hal_cuda2_replace_channel_provider(
283 iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) {
Lei Zhang330771e2023-06-13 19:39:23 -0400284 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
285 iree_hal_channel_provider_retain(new_provider);
286 iree_hal_channel_provider_release(device->channel_provider);
287 device->channel_provider = new_provider;
Lei Zhangc4e01e92023-06-09 17:38:05 -0400288}
289
290static iree_status_t iree_hal_cuda2_device_trim(
291 iree_hal_device_t* base_device) {
292 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
293 iree_arena_block_pool_trim(&device->block_pool);
294 IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator));
295 if (device->supports_memory_pools) {
296 IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim(
297 &device->memory_pools, &device->params.memory_pools));
298 }
299 return iree_ok_status();
300}
301
// Queries a single CUdevice_attribute and widens the int result to int64_t.
static iree_status_t iree_hal_cuda2_device_query_attribute(
    iree_hal_cuda2_device_t* device, CUdevice_attribute attribute,
    int64_t* out_value) {
  int value = 0;
  IREE_CUDA_RETURN_IF_ERROR(
      device->cuda_symbols,
      cuDeviceGetAttribute(&value, attribute, device->cu_device),
      "cuDeviceGetAttribute");
  *out_value = value;
  return iree_ok_status();
}
313
314static iree_status_t iree_hal_cuda2_device_query_i64(
315 iree_hal_device_t* base_device, iree_string_view_t category,
316 iree_string_view_t key, int64_t* out_value) {
317 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
318 *out_value = 0;
319
320 if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
321 *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
322 return iree_ok_status();
323 }
324
325 if (iree_string_view_equal(category, IREE_SV("cuda.device"))) {
326 if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) {
327 return iree_hal_cuda2_device_query_attribute(
328 device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value);
329 } else if (iree_string_view_equal(key,
330 IREE_SV("compute_capability_minor"))) {
331 return iree_hal_cuda2_device_query_attribute(
332 device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value);
333 }
334 }
335
336 return iree_make_status(
337 IREE_STATUS_NOT_FOUND,
338 "unknown device configuration key value '%.*s :: %.*s'",
339 (int)category.size, category.data, (int)key.size, key.data);
340}
341
342static iree_status_t iree_hal_cuda2_device_create_channel(
343 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
344 iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
Lei Zhang330771e2023-06-13 19:39:23 -0400345 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
346 if (!device->nccl_symbols || !device->nccl_symbols->dylib) {
347 return iree_make_status(
348 IREE_STATUS_UNAVAILABLE,
349 "NCCL runtime library (%d.%d.%d) not available; ensure installed and "
350 "the shared library is on your PATH/LD_LIBRARY_PATH "
351 "(nccl.dll/libnccl.so)",
352 NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
353 }
354
355 // Today we only allow a single logical device per channel.
356 // We could multiplex channels but it'd be better to surface that to the
357 // compiler so that it can emit the right rank math.
358 int requested_count = iree_math_count_ones_u64(queue_affinity);
359 // TODO(#12206): properly assign affinity in the compiler.
360 if (requested_count != 64 && requested_count != 1) {
361 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
362 "exactly one participant is allowed in a "
363 "channel but %d were specified",
364 requested_count);
365 }
366
367 // Ask the channel provider (if configured) for the default rank and count
368 // if the user did not set them.
369 if (device->channel_provider &&
370 (params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT ||
371 params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) {
372 IREE_RETURN_IF_ERROR(
373 iree_hal_channel_provider_query_default_rank_and_count(
374 device->channel_provider, &params.rank, &params.count),
375 "querying default collective group rank and count");
376 }
377
378 // An ID is required to initialize NCCL. On the root it'll be the local ID and
379 // on all other participants it'll be the root ID.
380 iree_hal_cuda2_nccl_id_t id;
381 memset(&id, 0, sizeof(id));
382 if (iree_const_byte_span_is_empty(params.id)) {
383 // User wants the default ID.
384 if (!device->channel_provider) {
385 return iree_make_status(
386 IREE_STATUS_INVALID_ARGUMENT,
387 "default collective channel ID requested but no channel provider has "
388 "been set on the device to provide it");
389 }
390 if (params.rank == 0) {
391 // Bootstrap NCCL to get the root ID.
392 IREE_RETURN_IF_ERROR(
393 iree_hal_cuda2_nccl_get_unique_id(device->nccl_symbols, &id),
394 "bootstrapping NCCL root");
395 }
396 // Exchange NCCL ID with all participants.
397 IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id(
398 device->channel_provider,
399 iree_make_byte_span((void*)&id, sizeof(id))),
400 "exchanging NCCL ID with other participants");
401 } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) {
402 // User provided something but it's not what we expect.
403 return iree_make_status(
404 IREE_STATUS_INVALID_ARGUMENT,
405 "NCCL ID must be %" PRIhsz
406 " bytes matching the ncclUniqueId struct but caller provided %" PRIhsz
407 " bytes",
408 IREE_ARRAYSIZE(id.data), sizeof(id));
409 } else {
410 // User provided the ID - we treat it as opaque here and let NCCL validate.
411 memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data));
412 }
413
414 if (iree_hal_cuda2_nccl_id_is_empty(&id)) {
415 // TODO: maybe this is ok? a localhost alias or something?
416 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
417 "no default NCCL ID specified (all zeros)");
418 }
419
420 // TODO: when we support multiple logical devices we'll want to pass in the
421 // context of the device mapped to the queue_affinity. For now since this
422 // implementation only supports one device we pass in the only one we have.
423 return iree_hal_cuda2_nccl_channel_create(
424 device->cuda_symbols, device->nccl_symbols, &id, params.rank,
425 params.count, device->host_allocator, out_channel);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400426}
427
428static iree_status_t iree_hal_cuda2_device_create_command_buffer(
429 iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
430 iree_hal_command_category_t command_categories,
431 iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
432 iree_hal_command_buffer_t** out_command_buffer) {
433 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
434 "command buffer not yet implmeneted");
435}
436
437static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout(
438 iree_hal_device_t* base_device,
439 iree_hal_descriptor_set_layout_flags_t flags,
440 iree_host_size_t binding_count,
441 const iree_hal_descriptor_set_layout_binding_t* bindings,
442 iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
Lei Zhang45a3eb42023-06-12 14:16:00 -0400443 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
444 return iree_hal_cuda2_descriptor_set_layout_create(
445 flags, binding_count, bindings, device->host_allocator,
446 out_descriptor_set_layout);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400447}
448
449static iree_status_t iree_hal_cuda2_device_create_event(
450 iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
451 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
452 "event not yet implmeneted");
453}
454
455static iree_status_t iree_hal_cuda2_device_create_executable_cache(
456 iree_hal_device_t* base_device, iree_string_view_t identifier,
457 iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) {
Lei Zhang85a1a562023-06-13 19:51:42 -0400458 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
459 return iree_hal_cuda2_nop_executable_cache_create(
460 identifier, device->cuda_symbols, device->cu_device,
461 device->host_allocator, out_executable_cache);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400462}
463
464static iree_status_t iree_hal_cuda2_device_create_pipeline_layout(
465 iree_hal_device_t* base_device, iree_host_size_t push_constants,
466 iree_host_size_t set_layout_count,
467 iree_hal_descriptor_set_layout_t* const* set_layouts,
468 iree_hal_pipeline_layout_t** out_pipeline_layout) {
Lei Zhang45a3eb42023-06-12 14:16:00 -0400469 iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
470 return iree_hal_cuda2_pipeline_layout_create(
471 set_layout_count, set_layouts, push_constants, device->host_allocator,
472 out_pipeline_layout);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400473}
474
475static iree_status_t iree_hal_cuda2_device_create_semaphore(
476 iree_hal_device_t* base_device, uint64_t initial_value,
477 iree_hal_semaphore_t** out_semaphore) {
478 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
479 "semaphore not yet implmeneted");
480}
481
482static iree_hal_semaphore_compatibility_t
483iree_hal_cuda2_device_query_semaphore_compatibility(
484 iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
485 // TODO: implement CUDA semaphores.
486 return IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
487}
488
489// TODO: implement multiple streams; today we only have one and queue_affinity
490// is ignored.
491// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
492// the barrier here.
493static iree_status_t iree_hal_cuda2_device_queue_alloca(
494 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
495 const iree_hal_semaphore_list_t wait_semaphore_list,
496 const iree_hal_semaphore_list_t signal_semaphore_list,
497 iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
498 iree_device_size_t allocation_size,
499 iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
500 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
501 "queue alloca not yet implmeneted");
502}
503
504// TODO: implement multiple streams; today we only have one and queue_affinity
505// is ignored.
506// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
507// the barrier here.
508static iree_status_t iree_hal_cuda2_device_queue_dealloca(
509 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
510 const iree_hal_semaphore_list_t wait_semaphore_list,
511 const iree_hal_semaphore_list_t signal_semaphore_list,
512 iree_hal_buffer_t* buffer) {
513 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
514 "queue dealloca not yet implmeneted");
515}
516
517static iree_status_t iree_hal_cuda2_device_queue_execute(
518 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
519 const iree_hal_semaphore_list_t wait_semaphore_list,
520 const iree_hal_semaphore_list_t signal_semaphore_list,
521 iree_host_size_t command_buffer_count,
522 iree_hal_command_buffer_t* const* command_buffers) {
523 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
524 "queue execution not yet implmeneted");
525}
526
527static iree_status_t iree_hal_cuda2_device_queue_flush(
528 iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
529 // Currently unused; we flush as submissions are made.
530 return iree_ok_status();
531}
532
533static iree_status_t iree_hal_cuda2_device_wait_semaphores(
534 iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
535 const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
536 return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
537 "semaphore not yet implemented");
538}
539
540static iree_status_t iree_hal_cuda2_device_profiling_begin(
541 iree_hal_device_t* base_device,
542 const iree_hal_device_profiling_options_t* options) {
543 // Unimplemented (and that's ok).
544 // We could hook in to CUPTI here or use the much simpler cuProfilerStart API.
545 return iree_ok_status();
546}
547
548static iree_status_t iree_hal_cuda2_device_profiling_end(
549 iree_hal_device_t* base_device) {
550 // Unimplemented (and that's ok).
551 return iree_ok_status();
552}
553
// Function table binding the CUDA2 implementations to the HAL device
// interface; referenced by iree_hal_cuda2_device_cast for type checks.
static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = {
    .destroy = iree_hal_cuda2_device_destroy,
    .id = iree_hal_cuda2_device_id,
    .host_allocator = iree_hal_cuda2_device_host_allocator,
    .device_allocator = iree_hal_cuda2_device_allocator,
    .replace_device_allocator = iree_hal_cuda2_replace_device_allocator,
    .replace_channel_provider = iree_hal_cuda2_replace_channel_provider,
    .trim = iree_hal_cuda2_device_trim,
    .query_i64 = iree_hal_cuda2_device_query_i64,
    .create_channel = iree_hal_cuda2_device_create_channel,
    .create_command_buffer = iree_hal_cuda2_device_create_command_buffer,
    .create_descriptor_set_layout =
        iree_hal_cuda2_device_create_descriptor_set_layout,
    .create_event = iree_hal_cuda2_device_create_event,
    .create_executable_cache = iree_hal_cuda2_device_create_executable_cache,
    .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout,
    .create_semaphore = iree_hal_cuda2_device_create_semaphore,
    .query_semaphore_compatibility =
        iree_hal_cuda2_device_query_semaphore_compatibility,
    // Generic emulated transfer path until native command buffers exist.
    .transfer_range = iree_hal_device_submit_transfer_range_and_wait,
    .queue_alloca = iree_hal_cuda2_device_queue_alloca,
    .queue_dealloca = iree_hal_cuda2_device_queue_dealloca,
    .queue_execute = iree_hal_cuda2_device_queue_execute,
    .queue_flush = iree_hal_cuda2_device_queue_flush,
    .wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
    .profiling_begin = iree_hal_cuda2_device_profiling_begin,
    .profiling_end = iree_hal_cuda2_device_profiling_end,
};