Add helper to allow dynamic loading of CUDA driver library. (#4823)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d586279..243e3e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -114,6 +114,7 @@
 
 # List of all HAL drivers to be built by default:
 set(IREE_ALL_HAL_DRIVERS
+  Cuda
   DyLib
   VMLA
   Vulkan
@@ -126,6 +127,10 @@
   if(APPLE)
     list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Vulkan)
   endif()
+  # Remove Cuda from Android and Apple platforms.
+  if(ANDROID OR APPLE)
+    list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Cuda)
+  endif()
 endif()
 message(STATUS "Building HAL drivers: ${IREE_HAL_DRIVERS_TO_BUILD}")
 
@@ -382,6 +387,7 @@
 include(flatbuffer_c_library)
 
 add_subdirectory(third_party/benchmark EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/cuda_headers EXCLUDE_FROM_ALL)
 add_subdirectory(build_tools/third_party/flatcc EXCLUDE_FROM_ALL)
 add_subdirectory(build_tools/third_party/half EXCLUDE_FROM_ALL)
 add_subdirectory(build_tools/third_party/pffft EXCLUDE_FROM_ALL)
diff --git a/build_tools/bazel/workspace.bzl b/build_tools/bazel/workspace.bzl
index ed68102..7e1afb4 100644
--- a/build_tools/bazel/workspace.bzl
+++ b/build_tools/bazel/workspace.bzl
@@ -123,3 +123,10 @@
         build_file = iree_repo_alias + "//:build_tools/third_party/spirv_cross/BUILD.overlay",
         path = paths.join(iree_path, "third_party/spirv_cross"),
     )
+
+    maybe(
+        native.new_local_repository,
+        name = "cuda_headers",
+        build_file = iree_repo_alias + "//:build_tools/third_party/cuda_headers/BUILD.overlay",
+        path = paths.join(iree_path, "third_party/cuda_headers"),
+    )
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 496e3ed..ad35390 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -51,6 +51,8 @@
     "@llvm-project//mlir:TensorDialect": ["MLIRTensor"],
     # Vulkan
     "@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
+    # Cuda
+    "@cuda_headers//:cuda_headers": ["cuda_headers"],
     # The Bazel target maps to the IMPORTED target defined by FindVulkan().
     "@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
     # Misc single targets
diff --git a/build_tools/third_party/cuda_headers/BUILD.overlay b/build_tools/third_party/cuda_headers/BUILD.overlay
new file mode 100644
index 0000000..ba40179
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/BUILD.overlay
@@ -0,0 +1,21 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "cuda_headers",
+    hdrs = ["cuda.h"],
+)
+
diff --git a/build_tools/third_party/cuda_headers/CMakeLists.txt b/build_tools/third_party/cuda_headers/CMakeLists.txt
new file mode 100644
index 0000000..8c3992e
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(CUDA_HEADERS_API_ROOT "${IREE_ROOT_DIR}/third_party/cuda_headers/")
+
+external_cc_library(
+  PACKAGE
+    cuda_headers
+  NAME
+    cuda_headers
+  ROOT
+    ${CUDA_HEADERS_API_ROOT}
+  HDRS
+    "cuda.h"
+  INCLUDES
+    ${CUDA_HEADERS_API_ROOT}
+)
+
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD
new file mode 100644
index 0000000..18fc69b
--- /dev/null
+++ b/iree/hal/cuda/BUILD
@@ -0,0 +1,60 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_cmake_extra_content(
+    content = """
+if(NOT ${IREE_HAL_DRIVER_CUDA})
+  return()
+endif()
+""",
+)
+
+cc_library(
+    name = "dynamic_symbols",
+    srcs = [
+        "cuda_headers.h",
+        "dynamic_symbols.cc",
+        "dynamic_symbols_tables.h",
+    ],
+    hdrs = [
+        "dynamic_symbols.h",
+    ],
+    deps = [
+        "//iree/base:core_headers",
+        "//iree/base:dynamic_library",
+        "//iree/base:status",
+        "//iree/base:tracing",
+        "@com_google_absl//absl/types:span",
+        "@cuda_headers",
+    ],
+)
+
+cc_test(
+    name = "dynamic_symbols_test",
+    srcs = ["dynamic_symbols_test.cc"],
+    tags = ["driver=cuda"],
+    deps = [
+        ":dynamic_symbols",
+        "//iree/testing:gtest",
+        "//iree/testing:gtest_main",
+    ],
+)
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
new file mode 100644
index 0000000..cc7667f
--- /dev/null
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -0,0 +1,51 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_HAL_DRIVER_CUDA})
+  return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    dynamic_symbols
+  HDRS
+    "dynamic_symbols.h"
+  SRCS
+    "cuda_headers.h"
+    "dynamic_symbols.cc"
+    "dynamic_symbols_tables.h"
+  DEPS
+    absl::span
+    cuda_headers
+    iree::base::core_headers
+    iree::base::dynamic_library
+    iree::base::status
+    iree::base::tracing
+  PUBLIC
+)
+
+iree_cc_test(
+  NAME
+    dynamic_symbols_test
+  SRCS
+    "dynamic_symbols_test.cc"
+  DEPS
+    ::dynamic_symbols
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=cuda"
+)
diff --git a/iree/hal/cuda/cuda_headers.h b/iree/hal/cuda/cuda_headers.h
new file mode 100644
index 0000000..f5fd736
--- /dev/null
+++ b/iree/hal/cuda/cuda_headers.h
@@ -0,0 +1,20 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CUDA_HEADERS_H_
+#define IREE_HAL_CUDA_CUDA_HEADERS_H_
+
+#include "cuda.h"
+
+#endif  // IREE_HAL_CUDA_CUDA_HEADERS_H_
diff --git a/iree/hal/cuda/dynamic_symbols.cc b/iree/hal/cuda/dynamic_symbols.cc
new file mode 100644
index 0000000..0927116
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.cc
@@ -0,0 +1,79 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include <cstddef>
+#include <type_traits>
+
+#include "absl/types/span.h"
+#include "iree/base/status.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+// File names passed to the loader when opening the CUDA driver library.
+static const char* kCudaLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "nvcuda.dll",
+#else
+    "libcuda.so",
+    // Driver-only installs may ship just the versioned name without the
+    // unversioned development symlink.
+    "libcuda.so.1",
+#endif
+};
+
+// Opens the CUDA driver library and resolves every function listed in
+// dynamic_symbols_tables.h into the matching member pointer. Returns an
+// UNAVAILABLE status naming the first symbol that cannot be found.
+Status DynamicSymbols::LoadSymbols() {
+  IREE_TRACE_SCOPE();
+
+  IREE_RETURN_IF_ERROR(DynamicLibrary::Load(
+      absl::MakeSpan(kCudaLoaderSearchNames), &loader_library_));
+
+// Stringizes |x| after macro expansion. cuda.h may remap an API name to a
+// versioned export (e.g. `#define cuCtxCreate cuCtxCreate_v2`); a direct `#x`
+// would look up the unexpanded legacy name while decltype() below sees the
+// remapped declaration, binding a pointer to a symbol with a different ABI.
+#define CU_PFN_STRINGIZE_IMPL(x) #x
+#define CU_PFN_STRINGIZE(x) CU_PFN_STRINGIZE_IMPL(x)
+
+#define CU_PFN_DECL(cudaSymbolName)                                         \
+  {                                                                         \
+    using FuncPtrT = std::add_pointer<decltype(::cudaSymbolName)>::type;    \
+    static const char* kName = CU_PFN_STRINGIZE(cudaSymbolName);            \
+    cudaSymbolName = loader_library_->GetSymbol<FuncPtrT>(kName);           \
+    if (!cudaSymbolName) {                                                  \
+      return iree_make_status(                                              \
+          IREE_STATUS_UNAVAILABLE,                                          \
+          "symbol '" CU_PFN_STRINGIZE(cudaSymbolName) "' not found");       \
+    }                                                                       \
+  }
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+#undef CU_PFN_STRINGIZE
+#undef CU_PFN_STRINGIZE_IMPL
+
+  return OkStatus();
+}
+
+}  // namespace cuda
+}  // namespace hal
+}  // namespace iree
diff --git a/iree/hal/cuda/dynamic_symbols.h b/iree/hal/cuda/dynamic_symbols.h
new file mode 100644
index 0000000..9d2c40e
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.h
@@ -0,0 +1,52 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+#define IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+
+#include "iree/base/dynamic_library.h"
+#include "iree/base/status.h"
+#include "iree/hal/cuda/cuda_headers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+/// DynamicSymbols dynamically loads a subset of the CUDA driver API. It
+/// resolves every function listed in `dynamic_symbols_tables.h` and fails if
+/// any symbol is unavailable. Members are function pointers (null until
+/// loaded) whose signatures match the declarations in `cuda.h`.
+struct DynamicSymbols {
+  Status LoadSymbols();
+
+#define CU_PFN_DECL(cudaSymbolName) \
+  std::add_pointer<decltype(::cudaSymbolName)>::type cudaSymbolName = nullptr;
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+
+ private:
+  // Cuda Loader dynamic library.
+  std::unique_ptr<DynamicLibrary> loader_library_;
+};
+
+}  // namespace cuda
+}  // namespace hal
+}  // namespace iree
+
+#endif  // IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/cuda/dynamic_symbols_tables.h b/iree/hal/cuda/dynamic_symbols_tables.h
new file mode 100644
index 0000000..5adece6
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_tables.h
@@ -0,0 +1,90 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+CU_PFN_DECL(cuCtxCreate)
+CU_PFN_DECL(cuCtxDestroy)
+CU_PFN_DECL(cuCtxEnablePeerAccess)
+CU_PFN_DECL(cuCtxGetCurrent)
+CU_PFN_DECL(cuCtxGetDevice)
+CU_PFN_DECL(cuCtxGetSharedMemConfig)
+CU_PFN_DECL(cuCtxSetCurrent)
+CU_PFN_DECL(cuCtxSetSharedMemConfig)
+CU_PFN_DECL(cuCtxSynchronize)
+CU_PFN_DECL(cuDeviceCanAccessPeer)
+CU_PFN_DECL(cuDeviceGet)
+CU_PFN_DECL(cuDeviceGetAttribute)
+CU_PFN_DECL(cuDeviceGetCount)
+CU_PFN_DECL(cuDeviceGetName)
+CU_PFN_DECL(cuDeviceGetPCIBusId)
+CU_PFN_DECL(cuDevicePrimaryCtxGetState)
+CU_PFN_DECL(cuDevicePrimaryCtxRelease)
+CU_PFN_DECL(cuDevicePrimaryCtxRetain)
+CU_PFN_DECL(cuDevicePrimaryCtxSetFlags)
+CU_PFN_DECL(cuDeviceTotalMem)
+CU_PFN_DECL(cuDriverGetVersion)
+CU_PFN_DECL(cuEventCreate)
+CU_PFN_DECL(cuEventDestroy)
+CU_PFN_DECL(cuEventElapsedTime)
+CU_PFN_DECL(cuEventQuery)
+CU_PFN_DECL(cuEventRecord)
+CU_PFN_DECL(cuEventSynchronize)
+CU_PFN_DECL(cuFuncGetAttribute)
+CU_PFN_DECL(cuFuncSetCacheConfig)
+CU_PFN_DECL(cuGetErrorName)
+CU_PFN_DECL(cuGetErrorString)
+CU_PFN_DECL(cuGraphAddMemcpyNode)
+CU_PFN_DECL(cuGraphAddMemsetNode)
+CU_PFN_DECL(cuGraphAddKernelNode)
+CU_PFN_DECL(cuGraphCreate)
+CU_PFN_DECL(cuGraphDestroy)
+CU_PFN_DECL(cuGraphExecDestroy)
+CU_PFN_DECL(cuGraphGetNodes)
+CU_PFN_DECL(cuGraphInstantiate)
+CU_PFN_DECL(cuGraphLaunch)
+CU_PFN_DECL(cuInit)
+CU_PFN_DECL(cuLaunchKernel)
+CU_PFN_DECL(cuMemAlloc)
+CU_PFN_DECL(cuMemAllocManaged)
+CU_PFN_DECL(cuMemFree)
+CU_PFN_DECL(cuMemFreeHost)
+CU_PFN_DECL(cuMemGetAddressRange)
+CU_PFN_DECL(cuMemGetInfo)
+CU_PFN_DECL(cuMemHostAlloc)
+CU_PFN_DECL(cuMemHostGetDevicePointer)
+CU_PFN_DECL(cuMemHostRegister)
+CU_PFN_DECL(cuMemHostUnregister)
+CU_PFN_DECL(cuMemcpyDtoD)
+CU_PFN_DECL(cuMemcpyDtoDAsync)
+CU_PFN_DECL(cuMemcpyDtoH)
+CU_PFN_DECL(cuMemcpyDtoHAsync)
+CU_PFN_DECL(cuMemcpyHtoD)
+CU_PFN_DECL(cuMemcpyHtoDAsync)
+CU_PFN_DECL(cuMemsetD32)
+CU_PFN_DECL(cuMemsetD32Async)
+CU_PFN_DECL(cuMemsetD8)
+CU_PFN_DECL(cuMemsetD8Async)
+CU_PFN_DECL(cuModuleGetFunction)
+CU_PFN_DECL(cuModuleGetGlobal)
+CU_PFN_DECL(cuModuleLoadDataEx)
+CU_PFN_DECL(cuModuleLoadFatBinary)
+CU_PFN_DECL(cuModuleUnload)
+CU_PFN_DECL(cuOccupancyMaxActiveBlocksPerMultiprocessor)
+CU_PFN_DECL(cuOccupancyMaxPotentialBlockSize)
+CU_PFN_DECL(cuPointerGetAttribute)
+CU_PFN_DECL(cuStreamAddCallback)
+CU_PFN_DECL(cuStreamCreate)
+CU_PFN_DECL(cuStreamDestroy)
+CU_PFN_DECL(cuStreamQuery)
+CU_PFN_DECL(cuStreamSynchronize)
+CU_PFN_DECL(cuStreamWaitEvent)
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
new file mode 100644
index 0000000..6a7967c
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -0,0 +1,51 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+// Asserts that a CUDA driver API call returned CUDA_SUCCESS.
+#define CUDA_CHECK_ERRORS(expr)                 \
+  {                                             \
+    CUresult cuda_check_result = (expr);        \
+    ASSERT_EQ(CUDA_SUCCESS, cuda_check_result); \
+  }
+
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  DynamicSymbols symbols;
+  if (!symbols.LoadSymbols().ok()) {
+    IREE_LOG(WARNING) << "Symbols cannot be loaded, skipping test.";
+    GTEST_SKIP();
+  }
+
+  int device_count = 0;
+  CUDA_CHECK_ERRORS(symbols.cuInit(0));
+  CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+  if (device_count > 0) {
+    CUdevice device;
+    CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+  }
+}
+
+}  // namespace
+}  // namespace cuda
+}  // namespace hal
+}  // namespace iree