// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "experimental/cuda2/api.h"
#include "experimental/cuda2/cuda_device.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/nccl_dynamic_symbols.h"
#include "experimental/cuda2/nccl_status_util.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"

// Maximum device name length (in bytes, including the NUL terminator written
// by cuDeviceGetName) supported by the CUDA HAL driver.
#define IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH 128

// Utility macros to convert between CUdevice and iree_hal_device_id_t.
// A device ID is the CUdevice value plus one; the inverse subtracts one.
// NOTE(review): presumably the offset keeps CUdevice 0 distinct from
// IREE_HAL_DEVICE_ID_DEFAULT - confirm against the HAL driver header.
// The whole expansion is parenthesized so the macros behave as a single
// primary expression wherever they are used (e.g. as a sizeof operand).
#define IREE_CUDEVICE_TO_DEVICE_ID(device) \
  ((iree_hal_device_id_t)((device) + 1))
#define IREE_DEVICE_ID_TO_CUDEVICE(device_id) ((CUdevice)((device_id) - 1))

26typedef struct iree_hal_cuda2_driver_t {
27 // Abstract resource used for injecting reference counting and vtable;
28 // must be at offset 0.
29 iree_hal_resource_t resource;
30
31 iree_allocator_t host_allocator;
32
33 // Identifier used for registering the driver in the IREE driver registry.
34 iree_string_view_t identifier;
35 // CUDA driver API dynamic symbols to interact with the CUDA system.
36 iree_hal_cuda2_dynamic_symbols_t cuda_symbols;
37 // NCCL API dynamic symbols to interact with the CUDA system.
38 iree_hal_cuda2_nccl_dynamic_symbols_t nccl_symbols;
39
Lei Zhangc4e01e92023-06-09 17:38:05 -040040 // The default parameters for creating devices using this driver.
41 iree_hal_cuda2_device_params_t device_params;
42
Lei Zhang5c38bcc2023-06-05 17:29:11 -040043 // The index of the default CUDA device to use if multiple ones are available.
44 int default_device_index;
45} iree_hal_cuda2_driver_t;
46
47static const iree_hal_driver_vtable_t iree_hal_cuda2_driver_vtable;
48
49static iree_hal_cuda2_driver_t* iree_hal_cuda2_driver_cast(
50 iree_hal_driver_t* base_value) {
51 IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_driver_vtable);
52 return (iree_hal_cuda2_driver_t*)base_value;
53}
54
55IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize(
56 iree_hal_cuda2_driver_options_t* out_options) {
57 IREE_ASSERT_ARGUMENT(out_options);
58 memset(out_options, 0, sizeof(*out_options));
59 out_options->default_device_index = 0;
60}
61
62static iree_status_t iree_hal_cuda2_driver_create_internal(
63 iree_string_view_t identifier,
64 const iree_hal_cuda2_driver_options_t* options,
Lei Zhangc4e01e92023-06-09 17:38:05 -040065 const iree_hal_cuda2_device_params_t* device_params,
Lei Zhang5c38bcc2023-06-05 17:29:11 -040066 iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
67 iree_hal_cuda2_driver_t* driver = NULL;
68 iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
69 IREE_RETURN_IF_ERROR(
70 iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
71
72 iree_hal_resource_initialize(&iree_hal_cuda2_driver_vtable,
73 &driver->resource);
74 driver->host_allocator = host_allocator;
75 iree_string_view_append_to_buffer(
76 identifier, &driver->identifier,
77 (char*)driver + iree_sizeof_struct(*driver));
78 driver->default_device_index = options->default_device_index;
79
80 iree_status_t status = iree_hal_cuda2_dynamic_symbols_initialize(
81 host_allocator, &driver->cuda_symbols);
82
83 if (iree_status_is_ok(status)) {
84 // Try to dynamically load NCCL. This will fail if NCCL is unavailable or
85 // incompatible. We only fail on unavailability when the user tries to
86 // create a channel and otherwise defer reporting.
87 status = iree_hal_cuda2_nccl_dynamic_symbols_initialize(
88 host_allocator, &driver->cuda_symbols, &driver->nccl_symbols);
89 if (iree_status_is_unavailable(status)) status = iree_status_ignore(status);
90 }
91
Lei Zhangc4e01e92023-06-09 17:38:05 -040092 memcpy(&driver->device_params, device_params, sizeof(driver->device_params));
93
Lei Zhang5c38bcc2023-06-05 17:29:11 -040094 if (iree_status_is_ok(status)) {
95 *out_driver = (iree_hal_driver_t*)driver;
96 } else {
97 iree_hal_driver_release((iree_hal_driver_t*)driver);
98 }
99 return status;
100}
101
102IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create(
103 iree_string_view_t identifier,
104 const iree_hal_cuda2_driver_options_t* options,
Lei Zhangc4e01e92023-06-09 17:38:05 -0400105 const iree_hal_cuda2_device_params_t* device_params,
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400106 iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
107 IREE_ASSERT_ARGUMENT(options);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400108 IREE_ASSERT_ARGUMENT(device_params);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400109 IREE_ASSERT_ARGUMENT(out_driver);
110 IREE_TRACE_ZONE_BEGIN(z0);
111
112 iree_status_t status = iree_hal_cuda2_driver_create_internal(
Lei Zhangc4e01e92023-06-09 17:38:05 -0400113 identifier, options, device_params, host_allocator, out_driver);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400114
115 IREE_TRACE_ZONE_END(z0);
116 return status;
117}
118
119static void iree_hal_cuda2_driver_destroy(iree_hal_driver_t* base_driver) {
120 IREE_ASSERT_ARGUMENT(base_driver);
121
122 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
123 iree_allocator_t host_allocator = driver->host_allocator;
124 IREE_TRACE_ZONE_BEGIN(z0);
125
126 iree_hal_cuda2_nccl_dynamic_symbols_deinitialize(&driver->nccl_symbols);
127 iree_hal_cuda2_dynamic_symbols_deinitialize(&driver->cuda_symbols);
128 iree_allocator_free(host_allocator, driver);
129
130 IREE_TRACE_ZONE_END(z0);
131}
132
133// Initializes the CUDA system.
134static iree_status_t iree_hal_cuda2_init(iree_hal_cuda2_driver_t* driver) {
135 IREE_TRACE_ZONE_BEGIN(z0);
136 iree_status_t status =
137 IREE_CURESULT_TO_STATUS(&driver->cuda_symbols, cuInit(0), "cuInit");
138 IREE_TRACE_ZONE_END(z0);
139 return status;
140}
141
142// Populates device information from the given CUDA physical device handle.
143// |out_device_info| must point to valid memory and additional data will be
144// appended to |buffer_ptr| and the new pointer is returned.
145static iree_status_t iree_hal_cuda2_populate_device_info(
146 CUdevice device, iree_hal_cuda2_dynamic_symbols_t* syms,
147 uint8_t* buffer_ptr, uint8_t** out_buffer_ptr,
148 iree_hal_device_info_t* out_device_info) {
149 *out_buffer_ptr = buffer_ptr;
150
151 char device_name[IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH];
152 IREE_CUDA_RETURN_IF_ERROR(
153 syms, cuDeviceGetName(device_name, sizeof(device_name), device),
154 "cuDeviceGetName");
155 memset(out_device_info, 0, sizeof(*out_device_info));
156 out_device_info->device_id = IREE_CUDEVICE_TO_DEVICE_ID(device);
157
158 // This matches the output of `nvidia-smi -L`.
159 CUuuid device_uuid;
160 IREE_CUDA_RETURN_IF_ERROR(syms, cuDeviceGetUuid(&device_uuid, device),
161 "cuDeviceGetUuid");
162 char device_path_str[4 + 36 + 1] = {0};
163 snprintf(device_path_str, sizeof(device_path_str),
164 "GPU-"
165 "%02x%02x%02x%02x-"
166 "%02x%02x-"
167 "%02x%02x-"
168 "%02x%02x-"
169 "%02x%02x%02x%02x%02x%02x",
170 (uint8_t)device_uuid.bytes[0], (uint8_t)device_uuid.bytes[1],
171 (uint8_t)device_uuid.bytes[2], (uint8_t)device_uuid.bytes[3],
172 (uint8_t)device_uuid.bytes[4], (uint8_t)device_uuid.bytes[5],
173 (uint8_t)device_uuid.bytes[6], (uint8_t)device_uuid.bytes[7],
174 (uint8_t)device_uuid.bytes[8], (uint8_t)device_uuid.bytes[9],
175 (uint8_t)device_uuid.bytes[10], (uint8_t)device_uuid.bytes[11],
176 (uint8_t)device_uuid.bytes[12], (uint8_t)device_uuid.bytes[13],
177 (uint8_t)device_uuid.bytes[14], (uint8_t)device_uuid.bytes[15]);
178 buffer_ptr += iree_string_view_append_to_buffer(
179 iree_make_string_view(device_path_str,
180 IREE_ARRAYSIZE(device_path_str) - 1),
181 &out_device_info->path, (char*)buffer_ptr);
182
183 iree_string_view_t device_name_str =
184 iree_make_string_view(device_name, strlen(device_name));
185 buffer_ptr += iree_string_view_append_to_buffer(
186 device_name_str, &out_device_info->name, (char*)buffer_ptr);
187
188 *out_buffer_ptr = buffer_ptr;
189 return iree_ok_status();
190}
191
192// Returns true if the device meets all the required capabilities.
193static bool iree_hal_cuda2_is_valid_device(iree_hal_cuda2_driver_t* driver,
194 CUdevice device) {
195 return true;
196}
197
198static iree_status_t iree_hal_cuda2_driver_query_available_devices(
199 iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
200 iree_host_size_t* out_device_info_count,
201 iree_hal_device_info_t** out_device_infos) {
202 IREE_ASSERT_ARGUMENT(base_driver);
203 IREE_ASSERT_ARGUMENT(out_device_info_count);
204 IREE_ASSERT_ARGUMENT(out_device_infos);
205 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
206 IREE_TRACE_ZONE_BEGIN(z0);
207
208 // Ensure CUDA is initialized before querying it.
209 IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_cuda2_init(driver));
210
211 // Query the number of available CUDA devices.
212 int device_count = 0;
213 IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(z0, &driver->cuda_symbols,
214 cuDeviceGetCount(&device_count),
215 "cuDeviceGetCount");
216
217 // Allocate the return infos and populate with the devices.
218 iree_hal_device_info_t* device_infos = NULL;
219 iree_host_size_t total_size =
220 device_count * (sizeof(iree_hal_device_info_t) +
221 IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH * sizeof(char));
222 iree_status_t status =
223 iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
224
225 int valid_device_count = 0;
226 if (iree_status_is_ok(status)) {
227 uint8_t* buffer_ptr =
228 (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
229 for (iree_host_size_t i = 0; i < device_count; ++i) {
230 CUdevice device = 0;
231 status = IREE_CURESULT_TO_STATUS(&driver->cuda_symbols,
232 cuDeviceGet(&device, i), "cuDeviceGet");
233 if (!iree_status_is_ok(status)) break;
234 if (!iree_hal_cuda2_is_valid_device(driver, device)) continue;
235 status = iree_hal_cuda2_populate_device_info(
236 device, &driver->cuda_symbols, buffer_ptr, &buffer_ptr,
237 &device_infos[valid_device_count]);
238 if (!iree_status_is_ok(status)) break;
239 valid_device_count++;
240 }
241 }
242 if (iree_status_is_ok(status)) {
243 *out_device_info_count = valid_device_count;
244 *out_device_infos = device_infos;
245 } else {
246 iree_allocator_free(host_allocator, device_infos);
247 }
248
249 IREE_TRACE_ZONE_END(z0);
250 return status;
251}
252
253static iree_status_t iree_hal_cuda2_driver_dump_device_info(
254 iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
255 iree_string_builder_t* builder) {
Lei Zhang686860c2023-06-05 17:44:47 -0400256 IREE_ASSERT_ARGUMENT(base_driver);
257 IREE_ASSERT_ARGUMENT(builder);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400258 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
Lei Zhang686860c2023-06-05 17:44:47 -0400259 CUdevice device = IREE_DEVICE_ID_TO_CUDEVICE(device_id);
260
261#define IREE_CUDA_QUERY_ATTRIBUTE(attribute, value) \
262 IREE_CUDA_RETURN_IF_ERROR( \
263 &driver->cuda_symbols, \
264 cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_##attribute, device), \
265 "cuDeviceGetAttribute");
266
267 int compute_capability_major = 0, compute_capability_minor = 0;
268 IREE_CUDA_QUERY_ATTRIBUTE(COMPUTE_CAPABILITY_MAJOR, compute_capability_major);
269 IREE_CUDA_QUERY_ATTRIBUTE(COMPUTE_CAPABILITY_MINOR, compute_capability_minor);
270 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
271 builder, "\n- gpu-compute-capability: %d.%d", compute_capability_major,
272 compute_capability_minor));
273
274 int driver_version = 0;
275 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols,
276 cuDriverGetVersion(&driver_version),
277 "cuDriverGetVersion");
278 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
279 builder, "\n- driver-max-cuda-version: %d.%d", driver_version / 1000,
280 (driver_version % 1000) / 10));
281
282 // Launch configuration limits.
283 int max_block_dim_x = 0, max_block_dim_y = 0, max_block_dim_z = 0;
284 int max_grid_dim_x = 0, max_grid_dim_y = 0, max_grid_dim_z = 0;
285 IREE_CUDA_QUERY_ATTRIBUTE(MAX_BLOCK_DIM_X, max_block_dim_x);
286 IREE_CUDA_QUERY_ATTRIBUTE(MAX_BLOCK_DIM_Y, max_block_dim_y);
287 IREE_CUDA_QUERY_ATTRIBUTE(MAX_BLOCK_DIM_Z, max_block_dim_z);
288 IREE_CUDA_QUERY_ATTRIBUTE(MAX_GRID_DIM_X, max_grid_dim_x);
289 IREE_CUDA_QUERY_ATTRIBUTE(MAX_GRID_DIM_Y, max_grid_dim_y);
290 IREE_CUDA_QUERY_ATTRIBUTE(MAX_GRID_DIM_Z, max_grid_dim_z);
291
292 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
293 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
294 builder, "\n- launch-max-block-dims: (%d, %d, %d)", max_block_dim_x,
295 max_block_dim_y, max_block_dim_z));
296 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
297 builder, "\n- launch-max-grid-dims: (%d, %d, %d)", max_grid_dim_x,
298 max_grid_dim_y, max_grid_dim_z));
299
300 // Per block resource limits.
301 int max_threads_per_block = 0;
302 int max_registers_per_block = 0;
303 int max_shared_memory_per_block = 0;
304 IREE_CUDA_QUERY_ATTRIBUTE(MAX_THREADS_PER_BLOCK, max_threads_per_block);
305 IREE_CUDA_QUERY_ATTRIBUTE(MAX_REGISTERS_PER_BLOCK, max_registers_per_block);
306 IREE_CUDA_QUERY_ATTRIBUTE(MAX_SHARED_MEMORY_PER_BLOCK,
307 max_shared_memory_per_block);
308
309 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
310 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
311 builder, "\n- block-max-thread-count: %d", max_threads_per_block));
312 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
313 builder, "\n- block-max-32-bit-register-count: %d",
314 max_registers_per_block));
315 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
316 builder, "\n- block-max-shared-memory: %d bytes",
317 max_shared_memory_per_block));
318
319 // Per multiprocessor resource limits.
320 int max_threads_per_multiprocessor = 0;
321 int max_blocks_per_multiprocessor = 0;
322 int max_registers_per_multiprocessor = 0;
323 int max_shared_memory_per_multiprocessor = 0;
324 IREE_CUDA_QUERY_ATTRIBUTE(MAX_THREADS_PER_MULTIPROCESSOR,
325 max_threads_per_multiprocessor);
326 IREE_CUDA_QUERY_ATTRIBUTE(MAX_BLOCKS_PER_MULTIPROCESSOR,
327 max_blocks_per_multiprocessor);
328 IREE_CUDA_QUERY_ATTRIBUTE(MAX_REGISTERS_PER_MULTIPROCESSOR,
329 max_registers_per_multiprocessor);
330 IREE_CUDA_QUERY_ATTRIBUTE(MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
331 max_shared_memory_per_multiprocessor);
332
333 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
334 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
335 builder, "\n- multiprocessor-max-thread-count: %d",
336 max_threads_per_multiprocessor));
337 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
338 builder, "\n- multiprocessor-max-block-count: %d",
339 max_blocks_per_multiprocessor));
340 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
341 builder, "\n- multiprocessor-max-32-bit-register-count: %d",
342 max_registers_per_multiprocessor));
343 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
344 builder, "\n- multiprocessor-max-shared-memory: %d bytes",
345 max_shared_memory_per_multiprocessor));
346
347 // Memory characteristics.
Lei Zhang6a61d9f2023-06-08 12:27:20 -0400348 int is_integrated_memory = 0;
Lei Zhang686860c2023-06-05 17:44:47 -0400349 int has_unified_address_space = 0;
350 int supports_managed_memory = 0;
351 int can_map_host_memory = 0;
352 int supports_pageable_memory_access = 0;
353 int supports_concurrent_managed_access = 0;
354 int supports_memory_pools = 0;
355 int l2_cache_size = 0;
Lei Zhang6a61d9f2023-06-08 12:27:20 -0400356 IREE_CUDA_QUERY_ATTRIBUTE(INTEGRATED, is_integrated_memory);
Lei Zhang686860c2023-06-05 17:44:47 -0400357 IREE_CUDA_QUERY_ATTRIBUTE(UNIFIED_ADDRESSING, has_unified_address_space);
358 IREE_CUDA_QUERY_ATTRIBUTE(MANAGED_MEMORY, supports_managed_memory);
359 IREE_CUDA_QUERY_ATTRIBUTE(CAN_MAP_HOST_MEMORY, can_map_host_memory);
360 IREE_CUDA_QUERY_ATTRIBUTE(PAGEABLE_MEMORY_ACCESS,
361 supports_pageable_memory_access);
362 IREE_CUDA_QUERY_ATTRIBUTE(CONCURRENT_MANAGED_ACCESS,
363 supports_concurrent_managed_access);
364 IREE_CUDA_QUERY_ATTRIBUTE(MEMORY_POOLS_SUPPORTED, supports_memory_pools);
365 IREE_CUDA_QUERY_ATTRIBUTE(L2_CACHE_SIZE, l2_cache_size);
366
367 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
368 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
Lei Zhang6a61d9f2023-06-08 12:27:20 -0400369 builder, "\n- memory-is-integrated-memory: %d", is_integrated_memory));
370 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
Lei Zhang686860c2023-06-05 17:44:47 -0400371 builder, "\n- memory-has-unified-address-space: %d",
372 has_unified_address_space));
373 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
374 builder, "\n- memory-supports-managed-memory: %d",
375 supports_managed_memory));
376 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
377 builder, "\n- memory-can-map-host-memory-to-device: %d",
378 can_map_host_memory));
379 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
380 builder, "\n- memory-supports-pageable-memory-access-from-device: %d",
381 supports_pageable_memory_access));
382 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
383 builder, "\n- memory-supports-concurrent-managed-access: %d",
384 supports_concurrent_managed_access));
385 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
386 builder, "\n- memory-supports-memory-pools: %d", supports_memory_pools));
387 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
388 builder, "\n- memory-l2-cache-size: %d bytes", l2_cache_size));
389
Lei Zhang31c2d242023-06-12 19:48:48 -0400390 int supports_64bit_memops = 0;
391 IREE_CUDA_QUERY_ATTRIBUTE(CAN_USE_64_BIT_STREAM_MEM_OPS,
392 supports_64bit_memops);
393 int supports_timeline_semaphore_interop = 0;
394 IREE_CUDA_QUERY_ATTRIBUTE(TIMELINE_SEMAPHORE_INTEROP_SUPPORTED,
395 supports_timeline_semaphore_interop);
396 int mem_sync_domain_count = 0;
397 IREE_CUDA_QUERY_ATTRIBUTE(MEM_SYNC_DOMAIN_COUNT, mem_sync_domain_count);
398
399 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
400 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
401 builder, "\n- sync-supports-64-bit-stream-mem-ops: %d",
402 supports_64bit_memops));
403 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
404 builder, "\n- sync-supports-timeline-semaphore-interop: %d",
405 supports_timeline_semaphore_interop));
406 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
407 builder, "\n- sync-mem-domain-count: %d", mem_sync_domain_count));
408
Lei Zhang686860c2023-06-05 17:44:47 -0400409 // Other GPU characteristics.
410 int multiprocessor_count = 0;
411 IREE_CUDA_QUERY_ATTRIBUTE(MULTIPROCESSOR_COUNT, multiprocessor_count);
412 int clock_rate = 0;
413 IREE_CUDA_QUERY_ATTRIBUTE(CLOCK_RATE, clock_rate);
414 int warp_size = 0;
415 IREE_CUDA_QUERY_ATTRIBUTE(WARP_SIZE, warp_size);
416 int execution_timeout = 0;
417 IREE_CUDA_QUERY_ATTRIBUTE(KERNEL_EXEC_TIMEOUT, execution_timeout);
418
419 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
420 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
421 builder, "\n- gpu-multiprocessor-count: %d", multiprocessor_count));
422 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
423 builder, "\n- gpu-clock-rate: %d kHz", clock_rate));
424 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
425 builder, "\n- gpu-warp-size: %d", warp_size));
426 IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
427 builder, "\n- kernel-has-execution-timeout: %d", execution_timeout));
428
429 IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));
430
431#undef IREE_CUDA_QUERY_ATTRIBUTE
432
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400433 return iree_ok_status();
434}
435
436static iree_status_t iree_hal_cuda2_driver_select_default_device(
437 iree_hal_driver_t* base_driver, iree_hal_cuda2_dynamic_symbols_t* syms,
438 int default_device_index, iree_allocator_t host_allocator,
439 CUdevice* out_device) {
440 iree_hal_device_info_t* device_infos = NULL;
441 iree_host_size_t device_count = 0;
442 IREE_RETURN_IF_ERROR(iree_hal_cuda2_driver_query_available_devices(
443 base_driver, host_allocator, &device_count, &device_infos));
444
445 iree_status_t status = iree_ok_status();
446 if (device_count == 0) {
447 status = iree_make_status(IREE_STATUS_UNAVAILABLE,
448 "no compatible CUDA devices were found");
449 } else if (default_device_index >= device_count) {
450 status = iree_make_status(IREE_STATUS_NOT_FOUND,
Scott Todd60b07642023-06-15 09:41:01 -0700451 "default device %d not found (of %" PRIhsz
452 " enumerated)",
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400453 default_device_index, device_count);
454 } else {
455 *out_device = IREE_DEVICE_ID_TO_CUDEVICE(
456 device_infos[default_device_index].device_id);
457 }
458 iree_allocator_free(host_allocator, device_infos);
459
460 return status;
461}
462
463static iree_status_t iree_hal_cuda2_driver_create_device_by_id(
464 iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
465 iree_host_size_t param_count, const iree_string_pair_t* params,
466 iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
467 IREE_ASSERT_ARGUMENT(base_driver);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400468 IREE_ASSERT_ARGUMENT(out_device);
469
470 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
471 IREE_TRACE_ZONE_BEGIN(z0);
472
473 // Ensure CUDA is initialized before querying it.
474 IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_cuda2_init(driver));
475
476 // Use either the specified device (enumerated earlier) or whatever default
477 // one was specified when the driver was created.
478 CUdevice device = 0;
479 if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) {
480 IREE_RETURN_AND_END_ZONE_IF_ERROR(
481 z0, iree_hal_cuda2_driver_select_default_device(
482 base_driver, &driver->cuda_symbols,
483 driver->default_device_index, host_allocator, &device));
484 } else {
485 device = IREE_DEVICE_ID_TO_CUDEVICE(device_id);
486 }
Lei Zhangc4e01e92023-06-09 17:38:05 -0400487
488 iree_string_view_t device_name = iree_make_cstring_view("cuda2");
489
490 // Attempt to create the device now.
491 iree_status_t status = iree_hal_cuda2_device_create(
492 base_driver, device_name, &driver->device_params, &driver->cuda_symbols,
Lei Zhang330771e2023-06-13 19:39:23 -0400493 &driver->nccl_symbols, device, host_allocator, out_device);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400494
495 IREE_TRACE_ZONE_END(z0);
Lei Zhangc4e01e92023-06-09 17:38:05 -0400496 return status;
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400497}
498
499static iree_status_t iree_hal_cuda2_driver_create_device_by_uuid(
500 iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
501 const CUuuid* device_uuid, iree_host_size_t param_count,
502 const iree_string_pair_t* params, iree_allocator_t host_allocator,
503 iree_hal_device_t** out_device) {
504 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
505
506 // Ensure CUDA is initialized before querying it.
507 IREE_RETURN_IF_ERROR(iree_hal_cuda2_init(driver));
508
509 // CUDA doesn't have an API to do this so we need to scan all devices to
510 // find the one with the matching UUID.
511 int device_count = 0;
512 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols,
513 cuDeviceGetCount(&device_count),
514 "cuDeviceGetCount");
515 CUdevice device = 0;
516 bool found_device = false;
517 for (int i = 0; i < device_count; i++) {
518 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, cuDeviceGet(&device, i),
519 "cuDeviceGet");
520 CUuuid query_uuid;
521 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols,
522 cuDeviceGetUuid(&query_uuid, device),
523 "cuDeviceGetUuid");
524 if (memcmp(&device_uuid->bytes[0], &query_uuid.bytes[0],
525 sizeof(device_uuid)) == 0) {
526 found_device = true;
527 break;
528 }
529 }
530 if (!found_device) {
531 return iree_make_status(
532 IREE_STATUS_NOT_FOUND,
533 "CUDA device with UUID GPU-"
534 "%02x%02x%02x%02x-"
535 "%02x%02x-"
536 "%02x%02x-"
537 "%02x%02x-"
538 "%02x%02x%02x%02x%02x%02x"
539 " not found",
540 (uint8_t)device_uuid->bytes[0], (uint8_t)device_uuid->bytes[1],
541 (uint8_t)device_uuid->bytes[2], (uint8_t)device_uuid->bytes[3],
542 (uint8_t)device_uuid->bytes[4], (uint8_t)device_uuid->bytes[5],
543 (uint8_t)device_uuid->bytes[6], (uint8_t)device_uuid->bytes[7],
544 (uint8_t)device_uuid->bytes[8], (uint8_t)device_uuid->bytes[9],
545 (uint8_t)device_uuid->bytes[10], (uint8_t)device_uuid->bytes[11],
546 (uint8_t)device_uuid->bytes[12], (uint8_t)device_uuid->bytes[13],
547 (uint8_t)device_uuid->bytes[14], (uint8_t)device_uuid->bytes[15]);
548 }
549
550 iree_status_t status = iree_hal_cuda2_driver_create_device_by_id(
551 base_driver, IREE_CUDEVICE_TO_DEVICE_ID(device), param_count, params,
552 host_allocator, out_device);
553
554 return status;
555}
556
557static iree_status_t iree_hal_cuda2_driver_create_device_by_index(
558 iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
559 int device_index, iree_host_size_t param_count,
560 const iree_string_pair_t* params, iree_allocator_t host_allocator,
561 iree_hal_device_t** out_device) {
562 iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver);
563
564 // Ensure CUDA is initialized before querying it.
565 IREE_RETURN_IF_ERROR(iree_hal_cuda2_init(driver));
566
567 // Query the number of available CUDA devices.
568 int device_count = 0;
569 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols,
570 cuDeviceGetCount(&device_count),
571 "cuDeviceGetCount");
572 if (device_index >= device_count) {
573 return iree_make_status(IREE_STATUS_NOT_FOUND,
574 "device %d not found (of %d enumerated)",
575 device_index, device_count);
576 }
577
578 CUdevice device = 0;
579 IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols,
580 cuDeviceGet(&device, device_index), "cuDeviceGet");
581
582 iree_status_t status = iree_hal_cuda2_driver_create_device_by_id(
583 base_driver, IREE_CUDEVICE_TO_DEVICE_ID(device), param_count, params,
584 host_allocator, out_device);
585
586 return status;
587}
588
589static iree_status_t iree_hal_cuda2_driver_create_device_by_path(
590 iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
591 iree_string_view_t device_path, iree_host_size_t param_count,
592 const iree_string_pair_t* params, iree_allocator_t host_allocator,
593 iree_hal_device_t** out_device) {
594 IREE_ASSERT_ARGUMENT(base_driver);
Lei Zhang5c38bcc2023-06-05 17:29:11 -0400595 IREE_ASSERT_ARGUMENT(out_device);
596
597 if (iree_string_view_is_empty(device_path)) {
598 return iree_hal_cuda2_driver_create_device_by_id(
599 base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params,
600 host_allocator, out_device);
601 }
602
603 if (iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"))) {
604 // UUID as returned by cuDeviceGetUuid.
605 CUuuid device_uuid;
606 if (!iree_string_view_parse_hex_bytes(device_path,
607 IREE_ARRAYSIZE(device_uuid.bytes),
608 (uint8_t*)device_uuid.bytes)) {
609 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
610 "invalid GPU UUID: '%.*s'", (int)device_path.size,
611 device_path.data);
612 }
613 return iree_hal_cuda2_driver_create_device_by_uuid(
614 base_driver, driver_name, &device_uuid, param_count, params,
615 host_allocator, out_device);
616 }
617
618 // Try to parse as a device index.
619 int device_index = 0;
620 if (iree_string_view_atoi_int32(device_path, &device_index)) {
621 return iree_hal_cuda2_driver_create_device_by_index(
622 base_driver, driver_name, device_index, param_count, params,
623 host_allocator, out_device);
624 }
625
626 return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path");
627}
628
629static const iree_hal_driver_vtable_t iree_hal_cuda2_driver_vtable = {
630 .destroy = iree_hal_cuda2_driver_destroy,
631 .query_available_devices = iree_hal_cuda2_driver_query_available_devices,
632 .dump_device_info = iree_hal_cuda2_driver_dump_device_info,
633 .create_device_by_id = iree_hal_cuda2_driver_create_device_by_id,
634 .create_device_by_path = iree_hal_cuda2_driver_create_device_by_path,
635};