Add helper to allow dynamic loading of CUDA driver library. (#4823)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d586279..243e3e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -114,6 +114,7 @@
# List of all HAL drivers to be built by default:
set(IREE_ALL_HAL_DRIVERS
+ Cuda
DyLib
VMLA
Vulkan
@@ -126,6 +127,10 @@
if(APPLE)
list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Vulkan)
endif()
+ # Remove Cuda from Android and Apple platforms.
+ if(ANDROID OR APPLE)
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Cuda)
+ endif()
endif()
message(STATUS "Building HAL drivers: ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -382,6 +387,7 @@
include(flatbuffer_c_library)
add_subdirectory(third_party/benchmark EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/cuda_headers EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/flatcc EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/half EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/pffft EXCLUDE_FROM_ALL)
diff --git a/build_tools/bazel/workspace.bzl b/build_tools/bazel/workspace.bzl
index ed68102..7e1afb4 100644
--- a/build_tools/bazel/workspace.bzl
+++ b/build_tools/bazel/workspace.bzl
@@ -123,3 +123,10 @@
build_file = iree_repo_alias + "//:build_tools/third_party/spirv_cross/BUILD.overlay",
path = paths.join(iree_path, "third_party/spirv_cross"),
)
+
+ maybe(
+ native.new_local_repository,
+ name = "cuda_headers",
+ build_file = iree_repo_alias + "//:build_tools/third_party/cuda_headers/BUILD.overlay",
+ path = paths.join(iree_path, "third_party/cuda_headers"),
+ )
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 496e3ed..ad35390 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -51,6 +51,8 @@
"@llvm-project//mlir:TensorDialect": ["MLIRTensor"],
+ # Cuda
+ "@cuda_headers//:cuda_headers": ["cuda_headers"],
# Vulkan
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
# The Bazel target maps to the IMPORTED target defined by FindVulkan().
"@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
# Misc single targets
diff --git a/build_tools/third_party/cuda_headers/BUILD.overlay b/build_tools/third_party/cuda_headers/BUILD.overlay
new file mode 100644
index 0000000..ba40179
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/BUILD.overlay
@@ -0,0 +1,21 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+# Header-only target exposing the CUDA driver API header (cuda.h).
+cc_library(
+ name = "cuda_headers",
+ hdrs = ["cuda.h"],
+)
diff --git a/build_tools/third_party/cuda_headers/CMakeLists.txt b/build_tools/third_party/cuda_headers/CMakeLists.txt
new file mode 100644
index 0000000..8c3992e
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Exposes third_party/cuda_headers/cuda.h as the `cuda_headers` CMake target.
+set(CUDA_HEADERS_API_ROOT "${IREE_ROOT_DIR}/third_party/cuda_headers/")
+
+external_cc_library(
+ PACKAGE
+ cuda_headers
+ NAME
+ cuda_headers
+ ROOT
+ "${CUDA_HEADERS_API_ROOT}"
+ HDRS
+ "cuda.h"
+ INCLUDES
+ "${CUDA_HEADERS_API_ROOT}"
+)
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD
new file mode 100644
index 0000000..18fc69b
--- /dev/null
+++ b/iree/hal/cuda/BUILD
@@ -0,0 +1,60 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(  # Same guard appears at the top of this package's CMakeLists.txt.
+ content = """
+if(NOT ${IREE_HAL_DRIVER_CUDA})
+ return()
+endif()
+""",
+)
+
+cc_library(
+ name = "dynamic_symbols",  # Resolves CUDA driver API entry points at runtime.
+ srcs = [
+ "cuda_headers.h",
+ "dynamic_symbols.cc",
+ "dynamic_symbols_tables.h",  # Symbol list textually included by the .h and the .cc.
+ ],
+ hdrs = [
+ "dynamic_symbols.h",
+ ],
+ deps = [
+ "//iree/base:core_headers",
+ "//iree/base:dynamic_library",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "@com_google_absl//absl/types:span",
+ "@cuda_headers",  # Provides cuda.h.
+ ],
+)
+
+cc_test(
+ name = "dynamic_symbols_test",
+ srcs = ["dynamic_symbols_test.cc"],
+ tags = ["driver=cuda"],  # NOTE(review): presumably used to filter tests by HAL driver; confirm.
+ deps = [
+ ":dynamic_symbols",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
new file mode 100644
index 0000000..cc7667f
--- /dev/null
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -0,0 +1,51 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_HAL_DRIVER_CUDA})  # Skip this directory when the CUDA HAL driver is off.
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(  # Resolves CUDA driver API entry points at runtime.
+ NAME
+ dynamic_symbols
+ HDRS
+ "dynamic_symbols.h"
+ SRCS
+ "cuda_headers.h"
+ "dynamic_symbols.cc"
+ "dynamic_symbols_tables.h"  # Symbol list textually included by the .h and the .cc.
+ DEPS
+ absl::span
+ cuda_headers  # Provides cuda.h.
+ iree::base::core_headers
+ iree::base::dynamic_library
+ iree::base::status
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ dynamic_symbols_test
+ SRCS
+ "dynamic_symbols_test.cc"
+ DEPS
+ ::dynamic_symbols
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "driver=cuda"  # NOTE(review): presumably used to filter tests by HAL driver; confirm.
+)
diff --git a/iree/hal/cuda/cuda_headers.h b/iree/hal/cuda/cuda_headers.h
new file mode 100644
index 0000000..f5fd736
--- /dev/null
+++ b/iree/hal/cuda/cuda_headers.h
@@ -0,0 +1,20 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CUDA_HEADERS_H_
+#define IREE_HAL_CUDA_CUDA_HEADERS_H_
+
+#include "cuda.h"  // CUDA driver API declarations (cu* functions and types).
+
+#endif // IREE_HAL_CUDA_CUDA_HEADERS_H_
diff --git a/iree/hal/cuda/dynamic_symbols.cc b/iree/hal/cuda/dynamic_symbols.cc
new file mode 100644
index 0000000..0927116
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.cc
@@ -0,0 +1,82 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include <cstddef>
+#include <type_traits>
+
+#include "absl/types/span.h"
+#include "iree/base/status.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+// Candidate file names for the CUDA driver library. NOTE(review): assumes
+// DynamicLibrary::Load() tries them in order until one loads; confirm.
+static const char* kCudaLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+ "nvcuda.dll",
+#else
+ "libcuda.so",
+ // Versioned soname, for installs that only ship the versioned file name.
+ "libcuda.so.1",
+#endif
+};
+
+// Stringifies |name| after macro expansion. decltype(::name) below resolves
+// through any remapping macro in cuda.h (presumably e.g. cuCtxCreate ->
+// cuCtxCreate_v2; confirm against the vendored header), so the exported name
+// that is looked up must be the remapped one too. A plain #name would
+// stringify the unexpanded spelling and could pair the wrong export with the
+// declared signature.
+#define IREE_CU_PFN_STR_IMPL(name) #name
+#define IREE_CU_PFN_STR(name) IREE_CU_PFN_STR_IMPL(name)
+
+// Opens the CUDA driver library and resolves every entry point listed in
+// dynamic_symbols_tables.h into the member of the same name. Propagates the
+// DynamicLibrary::Load() error if the library cannot be opened and returns
+// UNAVAILABLE, naming the symbol, for the first entry point that is missing.
+Status DynamicSymbols::LoadSymbols() {
+ IREE_TRACE_SCOPE();
+
+ IREE_RETURN_IF_ERROR(DynamicLibrary::Load(
+ absl::MakeSpan(kCudaLoaderSearchNames), &loader_library_));
+
+ // Expanded once per table entry by the #include below.
+#define CU_PFN_DECL(cudaSymbolName) \
+ { \
+ using FuncPtrT = std::add_pointer<decltype(::cudaSymbolName)>::type; \
+ static const char* kName = IREE_CU_PFN_STR(cudaSymbolName); \
+ cudaSymbolName = loader_library_->GetSymbol<FuncPtrT>(kName); \
+ if (!cudaSymbolName) { \
+ return iree_make_status( \
+ IREE_STATUS_UNAVAILABLE, \
+ "symbol '" IREE_CU_PFN_STR(cudaSymbolName) "' not found"); \
+ } \
+ }
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+#undef IREE_CU_PFN_STR
+#undef IREE_CU_PFN_STR_IMPL
+
+ return OkStatus();
+}
+
+} // namespace cuda
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/cuda/dynamic_symbols.h b/iree/hal/cuda/dynamic_symbols.h
new file mode 100644
index 0000000..9d2c40e
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.h
@@ -0,0 +1,57 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+#define IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <type_traits>
+
+#include "iree/base/dynamic_library.h"
+#include "iree/base/status.h"
+#include "iree/hal/cuda/cuda_headers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+/// DynamicSymbols loads a subset of the CUDA driver API at runtime. It
+/// resolves every function listed in `dynamic_symbols_tables.h` and fails if
+/// any of them is unavailable. Each member is a function pointer whose
+/// signature matches the corresponding declaration in `cuda.h`.
+struct DynamicSymbols {
+ /// Loads the CUDA driver library and fills in the function pointers below.
+ /// Returns a non-OK status if the library or any listed symbol is missing.
+ Status LoadSymbols();
+
+ // One function pointer member per table entry; stays null until
+ // LoadSymbols() resolves it.
+#define CU_PFN_DECL(cudaSymbolName) \
+ std::add_pointer<decltype(::cudaSymbolName)>::type cudaSymbolName = nullptr;
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+
+ private:
+ // Handle to the CUDA driver library loaded by LoadSymbols().
+ std::unique_ptr<DynamicLibrary> loader_library_;
+};
+
+} // namespace cuda
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/cuda/dynamic_symbols_tables.h b/iree/hal/cuda/dynamic_symbols_tables.h
new file mode 100644
index 0000000..5adece6
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_tables.h
@@ -0,0 +1,99 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// X-macro table of the CUDA driver API entry points that are loaded
+// dynamically. Deliberately has no include guard: each includer defines
+// CU_PFN_DECL(symbol) before the #include and #undefs it afterwards, so the
+// same list can be expanded more than once per translation unit.
+
+#ifndef CU_PFN_DECL
+#error "Define CU_PFN_DECL(symbol) before including dynamic_symbols_tables.h"
+#endif
+
+CU_PFN_DECL(cuCtxCreate)
+CU_PFN_DECL(cuCtxDestroy)
+CU_PFN_DECL(cuCtxEnablePeerAccess)
+CU_PFN_DECL(cuCtxGetCurrent)
+CU_PFN_DECL(cuCtxGetDevice)
+CU_PFN_DECL(cuCtxGetSharedMemConfig)
+CU_PFN_DECL(cuCtxSetCurrent)
+CU_PFN_DECL(cuCtxSetSharedMemConfig)
+CU_PFN_DECL(cuCtxSynchronize)
+CU_PFN_DECL(cuDeviceCanAccessPeer)
+CU_PFN_DECL(cuDeviceGet)
+CU_PFN_DECL(cuDeviceGetAttribute)
+CU_PFN_DECL(cuDeviceGetCount)
+CU_PFN_DECL(cuDeviceGetName)
+CU_PFN_DECL(cuDeviceGetPCIBusId)
+CU_PFN_DECL(cuDevicePrimaryCtxGetState)
+CU_PFN_DECL(cuDevicePrimaryCtxRelease)
+CU_PFN_DECL(cuDevicePrimaryCtxRetain)
+CU_PFN_DECL(cuDevicePrimaryCtxSetFlags)
+CU_PFN_DECL(cuDeviceTotalMem)
+CU_PFN_DECL(cuDriverGetVersion)
+CU_PFN_DECL(cuEventCreate)
+CU_PFN_DECL(cuEventDestroy)
+CU_PFN_DECL(cuEventElapsedTime)
+CU_PFN_DECL(cuEventQuery)
+CU_PFN_DECL(cuEventRecord)
+CU_PFN_DECL(cuEventSynchronize)
+CU_PFN_DECL(cuFuncGetAttribute)
+CU_PFN_DECL(cuFuncSetCacheConfig)
+CU_PFN_DECL(cuGetErrorName)
+CU_PFN_DECL(cuGetErrorString)
+CU_PFN_DECL(cuGraphAddMemcpyNode)
+CU_PFN_DECL(cuGraphAddMemsetNode)
+CU_PFN_DECL(cuGraphAddKernelNode)
+CU_PFN_DECL(cuGraphCreate)
+CU_PFN_DECL(cuGraphDestroy)
+CU_PFN_DECL(cuGraphExecDestroy)
+CU_PFN_DECL(cuGraphGetNodes)
+CU_PFN_DECL(cuGraphInstantiate)
+CU_PFN_DECL(cuGraphLaunch)
+CU_PFN_DECL(cuInit)
+CU_PFN_DECL(cuLaunchKernel)
+CU_PFN_DECL(cuMemAlloc)
+CU_PFN_DECL(cuMemAllocManaged)
+CU_PFN_DECL(cuMemFree)
+CU_PFN_DECL(cuMemFreeHost)
+CU_PFN_DECL(cuMemGetAddressRange)
+CU_PFN_DECL(cuMemGetInfo)
+CU_PFN_DECL(cuMemHostAlloc)
+CU_PFN_DECL(cuMemHostGetDevicePointer)
+CU_PFN_DECL(cuMemHostRegister)
+CU_PFN_DECL(cuMemHostUnregister)
+CU_PFN_DECL(cuMemcpyDtoD)
+CU_PFN_DECL(cuMemcpyDtoDAsync)
+CU_PFN_DECL(cuMemcpyDtoH)
+CU_PFN_DECL(cuMemcpyDtoHAsync)
+CU_PFN_DECL(cuMemcpyHtoD)
+CU_PFN_DECL(cuMemcpyHtoDAsync)
+CU_PFN_DECL(cuMemsetD32)
+CU_PFN_DECL(cuMemsetD32Async)
+CU_PFN_DECL(cuMemsetD8)
+CU_PFN_DECL(cuMemsetD8Async)
+CU_PFN_DECL(cuModuleGetFunction)
+CU_PFN_DECL(cuModuleGetGlobal)
+CU_PFN_DECL(cuModuleLoadDataEx)
+CU_PFN_DECL(cuModuleLoadFatBinary)
+CU_PFN_DECL(cuModuleUnload)
+CU_PFN_DECL(cuOccupancyMaxActiveBlocksPerMultiprocessor)
+CU_PFN_DECL(cuOccupancyMaxPotentialBlockSize)
+CU_PFN_DECL(cuPointerGetAttribute)
+CU_PFN_DECL(cuStreamAddCallback)
+CU_PFN_DECL(cuStreamCreate)
+CU_PFN_DECL(cuStreamDestroy)
+CU_PFN_DECL(cuStreamQuery)
+CU_PFN_DECL(cuStreamSynchronize)
+CU_PFN_DECL(cuStreamWaitEvent)
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
new file mode 100644
index 0000000..6a7967c
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -0,0 +1,53 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+// Fails the enclosing test if a CUDA driver call does not return CUDA_SUCCESS.
+#define CUDA_CHECK_ERRORS(expr) \
+ do { \
+ CUresult cuda_result = (expr); \
+ ASSERT_EQ(CUDA_SUCCESS, cuda_result); \
+ } while (0)
+
+// Loads the driver through DynamicSymbols and issues a few basic driver calls.
+// Skipped, with the reason in the test report, when the symbols cannot load.
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+ DynamicSymbols symbols;
+ Status status = symbols.LoadSymbols();
+ if (!status.ok()) {
+ GTEST_SKIP() << "Symbols cannot be loaded, skipping test.";
+ }
+
+ int device_count = 0;
+ CUDA_CHECK_ERRORS(symbols.cuInit(0));
+ CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+ if (device_count > 0) {
+ CUdevice device;
+ CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+ }
+}
+
+} // namespace
+} // namespace cuda
+} // namespace hal
+} // namespace iree