blob: fa6580969eedf1b378ed7f5955d3a8ec27196727 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include <string.h>
#include "experimental/cuda2/cuda_status_util.h"
#include "iree/base/api.h"
#include "iree/base/internal/dynamic_library.h"
// Candidate file names for the CUDA driver library, handed to
// iree_dynamic_library_load_from_files() in this order. That helper is assumed
// to keep the first candidate that loads successfully; verify against
// iree/base/internal/dynamic_library.h if the order ever matters more.
//
// On Linux the unversioned "libcuda.so" name is commonly only a development
// symlink, while the driver itself is installed under its versioned SONAME
// "libcuda.so.1". Runtime-only installs, e.g. GPU containers, may provide only
// the latter. The versioned name is therefore appended as a fallback. Any
// system where "libcuda.so" resolves keeps loading exactly the same library as
// before.
static const char* iree_hal_cuda_dylib_names[] = {
#if defined(IREE_PLATFORM_WINDOWS)
    "nvcuda.dll",
#else
    "libcuda.so",
    "libcuda.so.1",
#endif  // IREE_PLATFORM_WINDOWS
};
// CUDA API version for cuGetProcAddress.
// Encoded as 1000 * major + 10 * minor, so 11030 means CUDA 11.3.
//
// This value is passed as the version argument of every cuGetProcAddress query
// made while resolving the symbol table. Do not raise it to >= 12000 (CUDA
// 12.0) without also switching that call to the cuGetProcAddress_v2 signature
// introduced in CUDA 12.0.
#define IREE_CUDA_DRIVER_API_VERSION 11030
// Resolves every CUDA driver entry point named in the symbol table header into
// the matching function-pointer field of |syms|.
//
// Bootstrapping happens in two steps:
//   1. cuGetProcAddress itself is looked up directly in the loaded driver
//      library (|syms->dylib|).
//   2. Every entry of cuda_dynamic_symbol_table.h is then queried through
//      cuGetProcAddress at IREE_CUDA_DRIVER_API_VERSION. cuGetProcAddress is
//      itself listed in that table, so the pointer from step 1 is re-resolved
//      (and overwritten) here.
//
// NOTE: cuGetProcAddress_v2 was added in CUDA 12.0 with a different function
// signature. If IREE_CUDA_DRIVER_API_VERSION is raised to >= 12.0, the call
// below must be switched to the matching signature.
//
// Returns the first failure encountered. Fields resolved before that failure
// are left populated.
static iree_status_t iree_hal_cuda2_dynamic_symbols_resolve_all(
    iree_hal_cuda2_dynamic_symbols_t* syms) {
  // Step 1: fetch the bootstrap entry point straight from the dylib.
  IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(
      syms->dylib, "cuGetProcAddress", (void**)&syms->cuGetProcAddress));

  // Step 2: X-macro expansion. Each IREE_CU_PFN_DECL(symbol, ...) entry in the
  // table becomes one scoped cuGetProcAddress query that fills |syms->symbol|.
  // The braces (rather than do/while) keep the expansion valid whether or not
  // the table terminates entries with a semicolon.
#define IREE_CU_PFN_DECL(symbol, ...)                                \
  {                                                                  \
    const char* name = #symbol;                                      \
    IREE_CUDA_RETURN_IF_ERROR(                                       \
        syms,                                                        \
        cuGetProcAddress(name, (void**)&syms->symbol,                \
                         IREE_CUDA_DRIVER_API_VERSION,               \
                         CU_GET_PROC_ADDRESS_DEFAULT),               \
        "when resolving " #symbol " using cuGetProcAddress");        \
  }
#include "experimental/cuda2/cuda_dynamic_symbol_table.h"  // IWYU pragma: keep
#undef IREE_CU_PFN_DECL

  return iree_ok_status();
}
// Loads the CUDA driver library and resolves all of its entry points into
// |out_syms|.
//
// |out_syms| is zeroed before anything else happens. If the driver library
// cannot be found, a NOT_FOUND result from the loader is replaced with an
// IREE_STATUS_UNAVAILABLE status carrying an installation hint. On any failure
// |out_syms| is torn down again via
// iree_hal_cuda2_dynamic_symbols_deinitialize() before the status is returned.
//
// |host_allocator| is only forwarded to the dynamic library loader.
iree_status_t iree_hal_cuda2_dynamic_symbols_initialize(
    iree_allocator_t host_allocator,
    iree_hal_cuda2_dynamic_symbols_t* out_syms) {
  IREE_ASSERT_ARGUMENT(out_syms);
  IREE_TRACE_ZONE_BEGIN(z0);

  memset(out_syms, 0, sizeof(*out_syms));

  iree_status_t status = iree_dynamic_library_load_from_files(
      IREE_ARRAYSIZE(iree_hal_cuda_dylib_names), iree_hal_cuda_dylib_names,
      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->dylib);

  if (iree_status_is_ok(status)) {
    // Library is loaded: pull every entry point out of it.
    status = iree_hal_cuda2_dynamic_symbols_resolve_all(out_syms);
  } else if (iree_status_is_not_found(status)) {
    // Swap the loader's generic not-found error for an actionable message.
    iree_status_ignore(status);
    status = iree_make_status(
        IREE_STATUS_UNAVAILABLE,
        "CUDA driver library 'libcuda.so'/'nvcuda.dll' not available; please "
        "ensure installed and in dynamic library search path");
  }

  // Never hand back a partially-populated symbol table.
  if (!iree_status_is_ok(status)) {
    iree_hal_cuda2_dynamic_symbols_deinitialize(out_syms);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Releases the CUDA driver library handle held by |syms| and resets every
// field of |syms| to zero.
//
// iree_hal_cuda2_dynamic_symbols_initialize() also calls this on its failure
// paths, where |syms->dylib| may still be NULL. This presumes
// iree_dynamic_library_release() tolerates NULL; confirm against
// dynamic_library.h.
void iree_hal_cuda2_dynamic_symbols_deinitialize(
    iree_hal_cuda2_dynamic_symbols_t* syms) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // The resolved entry points were all obtained through this library, so the
  // whole table is wiped right after the handle is released.
  iree_dynamic_library_release(syms->dylib);
  memset(syms, 0, sizeof *syms);

  IREE_TRACE_ZONE_END(z0);
}