Cuda HAL backend step 1 (#4874)

Register the backend and add a stub for the device. Add support for
buffers and the allocator.
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
index b4ee8f5..1c4ea00 100644
--- a/iree/hal/cts/command_buffer_test.cc
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -33,6 +33,9 @@
 using ::testing::ContainerEq;
 
 class CommandBufferTest : public CtsTestBase {
+ public:
+  CommandBufferTest() { declareUnimplementedDriver("cuda"); }
+
  protected:
   static constexpr iree_device_size_t kBufferSize = 4096;
 };
@@ -96,26 +99,28 @@
   // Fill the device buffer with segments of different values so that we can
   // test both fill and offset/size.
   uint8_t val1 = 0x07;
-  iree_hal_command_buffer_fill_buffer(
+  IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
       command_buffer, device_buffer,
       /*target_offset=*/0, /*length=*/kBufferSize / 4, /*pattern=*/&val1,
-      /*pattern_length=*/sizeof(val1));
+      /*pattern_length=*/sizeof(val1)));
   std::memset(reference_buffer.data(), val1, kBufferSize / 4);
 
   uint8_t val2 = 0xbe;
-  iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
-                                      /*target_offset=*/kBufferSize / 4,
-                                      /*length=*/kBufferSize / 4,
-                                      /*pattern=*/&val2,
-                                      /*pattern_length=*/sizeof(val2));
+  IREE_ASSERT_OK(
+      iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+                                          /*target_offset=*/kBufferSize / 4,
+                                          /*length=*/kBufferSize / 4,
+                                          /*pattern=*/&val2,
+                                          /*pattern_length=*/sizeof(val2)));
   std::memset(reference_buffer.data() + kBufferSize / 4, val2, kBufferSize / 4);
 
   uint8_t val3 = 0x54;
-  iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
-                                      /*target_offset=*/kBufferSize / 2,
-                                      /*length=*/kBufferSize / 2,
-                                      /*pattern=*/&val3,
-                                      /*pattern_length=*/sizeof(val3));
+  IREE_ASSERT_OK(
+      iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+                                          /*target_offset=*/kBufferSize / 2,
+                                          /*length=*/kBufferSize / 2,
+                                          /*pattern=*/&val3,
+                                          /*pattern_length=*/sizeof(val3)));
   std::memset(reference_buffer.data() + kBufferSize / 2, val3, kBufferSize / 2);
 
   IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
index 09875ad..cf7d58c 100644
--- a/iree/hal/cts/cts_test_base.h
+++ b/iree/hal/cts/cts_test_base.h
@@ -33,6 +33,12 @@
  protected:
   virtual void SetUp() {
     const std::string& driver_name = GetParam();
+    if (driver_block_list_.find(driver_name) != driver_block_list_.end()) {
+      IREE_LOG(WARNING)
+          << "Skipping test as driver is explicitly disabled for this test";
+      GTEST_SKIP();
+      return;
+    }
 
     // Get driver with the given name and create its default device.
     // Skip drivers that are (gracefully) unavailable, fail if creation fails.
@@ -122,6 +128,14 @@
   iree_hal_driver_t* driver_ = nullptr;
   iree_hal_device_t* device_ = nullptr;
   iree_hal_allocator_t* device_allocator_ = nullptr;
+  // Allow skipping tests for driver under development.
+  void declareUnimplementedDriver(const std::string& driver_name) {
+    driver_block_list_.insert(driver_name);
+  }
+  // Allow skipping tests for unsupported features.
+  void declareUnavailableDriver(const std::string& driver_name) {
+    driver_block_list_.insert(driver_name);
+  }
 
  private:
   // Gets a HAL driver with the provided name, if available.
@@ -149,6 +163,7 @@
     }
     return status;
   }
+  std::set<std::string> driver_block_list_;
 };
 
 struct GenerateTestName {
diff --git a/iree/hal/cts/descriptor_set_layout_test.cc b/iree/hal/cts/descriptor_set_layout_test.cc
index 14dd2b5..3b7df35 100644
--- a/iree/hal/cts/descriptor_set_layout_test.cc
+++ b/iree/hal/cts/descriptor_set_layout_test.cc
@@ -21,7 +21,10 @@
 namespace hal {
 namespace cts {
 
-class DescriptorSetLayoutTest : public CtsTestBase {};
+class DescriptorSetLayoutTest : public CtsTestBase {
+ public:
+  DescriptorSetLayoutTest() { declareUnimplementedDriver("cuda"); }
+};
 
 // Note: bindingCount == 0 is valid in VkDescriptorSetLayoutCreateInfo:
 // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSetLayoutCreateInfo.html
diff --git a/iree/hal/cts/event_test.cc b/iree/hal/cts/event_test.cc
index b728277..083be3a 100644
--- a/iree/hal/cts/event_test.cc
+++ b/iree/hal/cts/event_test.cc
@@ -21,7 +21,10 @@
 namespace hal {
 namespace cts {
 
-class EventTest : public CtsTestBase {};
+class EventTest : public CtsTestBase {
+ public:
+  EventTest() { declareUnimplementedDriver("cuda"); }
+};
 
 TEST_P(EventTest, Create) {
   iree_hal_event_t* event;
diff --git a/iree/hal/cts/executable_layout_test.cc b/iree/hal/cts/executable_layout_test.cc
index 1df73bd..f4a3374 100644
--- a/iree/hal/cts/executable_layout_test.cc
+++ b/iree/hal/cts/executable_layout_test.cc
@@ -21,7 +21,10 @@
 namespace hal {
 namespace cts {
 
-class ExecutableLayoutTest : public CtsTestBase {};
+class ExecutableLayoutTest : public CtsTestBase {
+ public:
+  ExecutableLayoutTest() { declareUnimplementedDriver("cuda"); }
+};
 
 TEST_P(ExecutableLayoutTest, CreateWithNoLayouts) {
   iree_hal_executable_layout_t* executable_layout;
diff --git a/iree/hal/cts/semaphore_submission_test.cc b/iree/hal/cts/semaphore_submission_test.cc
index 5f9c529..5cec41f 100644
--- a/iree/hal/cts/semaphore_submission_test.cc
+++ b/iree/hal/cts/semaphore_submission_test.cc
@@ -20,7 +20,11 @@
 namespace hal {
 namespace cts {
 
-class SemaphoreSubmissionTest : public CtsTestBase {};
+class SemaphoreSubmissionTest : public CtsTestBase {
+ public:
+  // Disable cuda backend for this test as semaphores are not implemented yet.
+  SemaphoreSubmissionTest() { declareUnavailableDriver("cuda"); }
+};
 
 TEST_P(SemaphoreSubmissionTest, SubmitWithNoCommandBuffers) {
   // No waits, one signal which we immediately wait on after submit.
diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc
index c774241..ac02fca 100644
--- a/iree/hal/cts/semaphore_test.cc
+++ b/iree/hal/cts/semaphore_test.cc
@@ -22,7 +22,11 @@
 namespace hal {
 namespace cts {
 
-class SemaphoreTest : public CtsTestBase {};
+class SemaphoreTest : public CtsTestBase {
+ public:
+  // Disable cuda backend for this test as semaphores are not implemented yet.
+  SemaphoreTest() { declareUnavailableDriver("cuda"); }
+};
 
 // Tests that a semaphore that is unused properly cleans itself up.
 TEST_P(SemaphoreTest, NoOp) {
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD
index 18fc69b..831fbae 100644
--- a/iree/hal/cuda/BUILD
+++ b/iree/hal/cuda/BUILD
@@ -29,6 +29,39 @@
 )
 
 cc_library(
+    name = "cuda",
+    srcs = [
+        "api.h",
+        "context_wrapper.h",
+        "cuda_allocator.c",
+        "cuda_allocator.h",
+        "cuda_buffer.c",
+        "cuda_buffer.h",
+        "cuda_device.c",
+        "cuda_device.h",
+        "cuda_driver.c",
+        "status_util.c",
+        "status_util.h",
+    ],
+    hdrs = [
+        "api.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":dynamic_symbols",
+        "//iree/base:api",
+        "//iree/base:core_headers",
+        "//iree/base:flatcc",
+        "//iree/base:logging",
+        "//iree/base:status",
+        "//iree/base:synchronization",
+        "//iree/base:tracing",
+        "//iree/base/internal",
+        "//iree/hal:api",
+    ],
+)
+
+cc_library(
     name = "dynamic_symbols",
     srcs = [
         "cuda_headers.h",
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
index f9e7d76..6df72d1 100644
--- a/iree/hal/cuda/CMakeLists.txt
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -8,6 +8,37 @@
 
 iree_cc_library(
   NAME
+    cuda
+  HDRS
+    "api.h"
+  SRCS
+    "api.h"
+    "context_wrapper.h"
+    "cuda_allocator.c"
+    "cuda_allocator.h"
+    "cuda_buffer.c"
+    "cuda_buffer.h"
+    "cuda_device.c"
+    "cuda_device.h"
+    "cuda_driver.c"
+    "status_util.c"
+    "status_util.h"
+  DEPS
+    ::dynamic_symbols
+    iree::base::api
+    iree::base::core_headers
+    iree::base::flatcc
+    iree::base::internal
+    iree::base::logging
+    iree::base::status
+    iree::base::synchronization
+    iree::base::tracing
+    iree::hal::api
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     dynamic_symbols
   HDRS
     "dynamic_symbols.h"
diff --git a/iree/hal/cuda/api.h b/iree/hal/cuda/api.h
new file mode 100644
index 0000000..3c219b7
--- /dev/null
+++ b/iree/hal/cuda/api.h
@@ -0,0 +1,55 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_HAL_CUDA_API_H_
+#define IREE_HAL_CUDA_API_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda_driver_t
+//===----------------------------------------------------------------------===//
+
+// CUDA driver creation options.
+typedef struct {
+  // Index of the default CUDA device to use within the list of available
+  // devices.
+  int default_device_index;
+} iree_hal_cuda_driver_options_t;
+
+IREE_API_EXPORT void IREE_API_CALL iree_hal_cuda_driver_options_initialize(
+    iree_hal_cuda_driver_options_t* out_options);
+
+// Creates a CUDA HAL driver that manages its own CUcontext.
+//
+// |out_driver| must be released by the caller (see |iree_hal_driver_release|).
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_cuda_driver_create(
+    iree_string_view_t identifier,
+    const iree_hal_cuda_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
+
+// TODO(thomasraoux): Support importing a CUcontext from app.
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_API_H_
diff --git a/iree/hal/cuda/context_wrapper.h b/iree/hal/cuda/context_wrapper.h
new file mode 100644
index 0000000..304657d
--- /dev/null
+++ b/iree/hal/cuda/context_wrapper.h
@@ -0,0 +1,30 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
+#define IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/cuda_headers.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+// Structure to wrap all objects constant within a context. This makes it
+// simpler to pass it to the different objects and saves memory.
+typedef struct {
+  CUcontext cu_context;
+  iree_allocator_t host_allocator;
+  iree_hal_cuda_dynamic_symbols_t* syms;
+} iree_hal_cuda_context_wrapper_t;
+
+#endif  // IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
diff --git a/iree/hal/cuda/cuda_allocator.c b/iree/hal/cuda/cuda_allocator.c
new file mode 100644
index 0000000..9995f2b
--- /dev/null
+++ b/iree/hal/cuda/cuda_allocator.c
@@ -0,0 +1,175 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/cuda_allocator.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/cuda_buffer.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct iree_hal_cuda_allocator_s {
+  iree_hal_resource_t resource;
+  iree_hal_cuda_context_wrapper_t* context;
+} iree_hal_cuda_allocator_t;
+
+extern const iree_hal_allocator_vtable_t iree_hal_cuda_allocator_vtable;
+
+static iree_hal_cuda_allocator_t* iree_hal_cuda_allocator_cast(
+    iree_hal_allocator_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_allocator_vtable);
+  return (iree_hal_cuda_allocator_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_allocator_create(
+    iree_hal_cuda_context_wrapper_t* context,
+    iree_hal_allocator_t** out_allocator) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_cuda_allocator_t* allocator = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      context->host_allocator, sizeof(*allocator), (void**)&allocator);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_cuda_allocator_vtable,
+                                 &allocator->resource);
+    allocator->context = context;
+    *out_allocator = (iree_hal_allocator_t*)allocator;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_cuda_allocator_destroy(
+    iree_hal_allocator_t* base_allocator) {
+  iree_hal_cuda_allocator_t* allocator =
+      iree_hal_cuda_allocator_cast(base_allocator);
+  iree_allocator_t host_allocator = allocator->context->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_allocator_free(host_allocator, allocator);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_allocator_t iree_hal_cuda_allocator_host_allocator(
+    const iree_hal_allocator_t* base_allocator) {
+  iree_hal_cuda_allocator_t* allocator =
+      (iree_hal_cuda_allocator_t*)base_allocator;
+  return allocator->context->host_allocator;
+}
+
+static iree_hal_buffer_compatibility_t
+iree_hal_cuda_allocator_query_buffer_compatibility(
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_buffer_usage_t allowed_usage,
+    iree_hal_buffer_usage_t intended_usage,
+    iree_device_size_t allocation_size) {
+  // TODO(benvanik): check to ensure the allocator can serve the memory type.
+
+  // Disallow usage not permitted by the buffer itself. Since we then use this
+  // to determine compatibility below we'll naturally set the right compat flags
+  // based on what's both allowed and intended.
+  intended_usage &= allowed_usage;
+
+  // All buffers can be allocated on the heap.
+  iree_hal_buffer_compatibility_t compatibility =
+      IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+
+  // Buffers can only be used on the queue if they are device visible.
+  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+    if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+      compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+    }
+    if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
+      compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
+    }
+  }
+
+  return compatibility;
+}
+
+static iree_status_t iree_hal_cuda_allocator_allocate_buffer(
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_cuda_allocator_t* allocator =
+      iree_hal_cuda_allocator_cast(base_allocator);
+  // Guard against the corner case where the requested buffer size is 0. The
+  // application is unlikely to do anything when requesting a 0-byte buffer; but
+  // it can happen in real world use cases. So we should at least not crash.
+  if (allocation_size == 0) allocation_size = 4;
+  iree_status_t status;
+  void* host_ptr = NULL;
+  CUdeviceptr device_ptr = 0;
+  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    unsigned int flags = CU_MEMHOSTALLOC_DEVICEMAP;
+    if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) {
+      flags |= CU_MEMHOSTALLOC_WRITECOMBINED;
+    }
+    status =
+        CU_RESULT_TO_STATUS(allocator->context->syms,
+                            cuMemHostAlloc(&host_ptr, allocation_size, flags));
+    if (iree_status_is_ok(status)) {
+      status = CU_RESULT_TO_STATUS(
+          allocator->context->syms,
+          cuMemHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0));
+    }
+  } else {
+    status = CU_RESULT_TO_STATUS(allocator->context->syms,
+                                 cuMemAlloc(&device_ptr, allocation_size));
+  }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda_buffer_wrap(
+        (iree_hal_allocator_t*)allocator, memory_type,
+        IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
+        /*byte_offset=*/0,
+        /*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_cuda_allocator_free(base_allocator, device_ptr, host_ptr,
+                                 memory_type);
+  }
+  return status;
+}
+
+void iree_hal_cuda_allocator_free(iree_hal_allocator_t* base_allocator,
+                                  CUdeviceptr device_ptr, void* host_ptr,
+                                  iree_hal_memory_type_t memory_type) {
+  iree_hal_cuda_allocator_t* allocator =
+      iree_hal_cuda_allocator_cast(base_allocator);
+  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFreeHost(host_ptr));
+  } else {
+    CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFree(device_ptr));
+  }
+}
+
+static iree_status_t iree_hal_cuda_allocator_wrap_buffer(
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data,
+    iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) {
+  return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                          "wrapping of external buffers not supported");
+}
+
+const iree_hal_allocator_vtable_t iree_hal_cuda_allocator_vtable = {
+    .destroy = iree_hal_cuda_allocator_destroy,
+    .host_allocator = iree_hal_cuda_allocator_host_allocator,
+    .query_buffer_compatibility =
+        iree_hal_cuda_allocator_query_buffer_compatibility,
+    .allocate_buffer = iree_hal_cuda_allocator_allocate_buffer,
+    .wrap_buffer = iree_hal_cuda_allocator_wrap_buffer,
+};
diff --git a/iree/hal/cuda/cuda_allocator.h b/iree/hal/cuda/cuda_allocator.h
new file mode 100644
index 0000000..fcc015b
--- /dev/null
+++ b/iree/hal/cuda/cuda_allocator.h
@@ -0,0 +1,40 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_ALLOCATOR_H_
+#define IREE_HAL_CUDA_ALLOCATOR_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/context_wrapper.h"
+#include "iree/hal/cuda/status_util.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Create a cuda allocator.
+iree_status_t iree_hal_cuda_allocator_create(
+    iree_hal_cuda_context_wrapper_t* context,
+    iree_hal_allocator_t** out_allocator);
+
+// Frees an allocation represented by the given device or host pointer.
+void iree_hal_cuda_allocator_free(iree_hal_allocator_t* allocator,
+                                  CUdeviceptr device_ptr, void* host_ptr,
+                                  iree_hal_memory_type_t memory_type);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_ALLOCATOR_H_
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c
new file mode 100644
index 0000000..55a2115
--- /dev/null
+++ b/iree/hal/cuda/cuda_buffer.c
@@ -0,0 +1,141 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/cuda_buffer.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/cuda_allocator.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct iree_hal_cuda_buffer_s {
+  iree_hal_buffer_t base;
+  void* host_ptr;
+  CUdeviceptr device_ptr;
+} iree_hal_cuda_buffer_t;
+
+extern const iree_hal_buffer_vtable_t iree_hal_cuda_buffer_vtable;
+
+static iree_hal_cuda_buffer_t* iree_hal_cuda_buffer_cast(
+    iree_hal_buffer_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_buffer_vtable);
+  return (iree_hal_cuda_buffer_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_buffer_wrap(
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    CUdeviceptr device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(allocator);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // The wrapper struct is allocated from the allocator's host allocator.
+  iree_hal_cuda_buffer_t* buffer = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
+                            sizeof(*buffer), (void**)&buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_cuda_buffer_vtable,
+                                 &buffer->base.resource);
+    buffer->base.allocator = allocator;
+    buffer->base.allocated_buffer = &buffer->base;
+    buffer->base.allocation_size = allocation_size;
+    buffer->base.byte_offset = byte_offset;
+    buffer->base.byte_length = byte_length;
+    buffer->base.memory_type = memory_type;
+    buffer->base.allowed_access = allowed_access;
+    buffer->base.allowed_usage = allowed_usage;
+    buffer->host_ptr = host_ptr;
+    buffer->device_ptr = device_ptr;
+    *out_buffer = &buffer->base;
+  }
+  // Return |status| so a failed host allocation is reported to the caller;
+  // |out_buffer| is only written on success.
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_cuda_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+  iree_allocator_t host_allocator =
+      iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_cuda_allocator_free(buffer->base.allocator, buffer->device_ptr,
+                               buffer->host_ptr, buffer->base.memory_type);
+  iree_allocator_free(host_allocator, buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_cuda_buffer_map_range(
+    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    void** out_data_ptr) {
+  iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+  // Only buffers whose memory type is host visible can be mapped.
+  if (!iree_all_bits_set(buffer->base.memory_type,
+                         IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "trying to map memory not host visible");
+  }
+  // Note: |data_ptr| already accounts for |local_byte_offset|.
+  uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset;
+  // If we mapped for discard scribble over the bytes. This is not a mandated
+  // behavior but it will make debugging issues easier. Alternatively for
+  // heap buffers we could reallocate them such that ASAN yells, but that
+  // would only work if the entire buffer was discarded.
+#ifndef NDEBUG
+  if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
+    memset(data_ptr, 0xCD, local_byte_length);
+  }
+#endif  // !NDEBUG
+  *out_data_ptr = data_ptr;
+  return iree_ok_status();
+}
+
+static void iree_hal_cuda_buffer_unmap_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length, void* data_ptr) {
+  // nothing to do.
+}
+
+static iree_status_t iree_hal_cuda_buffer_invalidate_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  // Nothing to do.
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_buffer_flush_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  // Nothing to do.
+  return iree_ok_status();
+}
+
+CUdeviceptr iree_hal_cuda_buffer_device_pointer(
+    iree_hal_buffer_t* base_buffer) {
+  iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+  return buffer->device_ptr;
+}
+
+const iree_hal_buffer_vtable_t iree_hal_cuda_buffer_vtable = {
+    .destroy = iree_hal_cuda_buffer_destroy,
+    .map_range = iree_hal_cuda_buffer_map_range,
+    .unmap_range = iree_hal_cuda_buffer_unmap_range,
+    .invalidate_range = iree_hal_cuda_buffer_invalidate_range,
+    .flush_range = iree_hal_cuda_buffer_flush_range,
+};
diff --git a/iree/hal/cuda/cuda_buffer.h b/iree/hal/cuda/cuda_buffer.h
new file mode 100644
index 0000000..b9dcd9f
--- /dev/null
+++ b/iree/hal/cuda/cuda_buffer.h
@@ -0,0 +1,42 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_BUFFER_H_
+#define IREE_HAL_CUDA_BUFFER_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/cuda_headers.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Wraps a cuda allocation in an iree_hal_buffer_t.
+iree_status_t iree_hal_cuda_buffer_wrap(
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    CUdeviceptr device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer);
+
+// Returns the cuda base pointer for the given |buffer|.
+// This is the entire allocated_buffer and must be offset by the buffer
+// byte_offset and byte_length when used.
+CUdeviceptr iree_hal_cuda_buffer_device_pointer(iree_hal_buffer_t* buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_BUFFER_H_
diff --git a/iree/hal/cuda/cuda_device.c b/iree/hal/cuda/cuda_device.c
new file mode 100644
index 0000000..0430fe8
--- /dev/null
+++ b/iree/hal/cuda/cuda_device.c
@@ -0,0 +1,260 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/cuda_device.h"
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/cuda_allocator.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+#include "iree/hal/cuda/status_util.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda_device_t
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  iree_hal_resource_t resource;
+  iree_string_view_t identifier;
+
+  // Optional driver that owns the CUDA symbols. We retain it for our lifetime
+  // to ensure the symbols remains valid.
+  iree_hal_driver_t* driver;
+
+  CUdevice device;
+
+  // TODO: support multiple streams.
+  CUstream stream;
+  iree_hal_cuda_context_wrapper_t context_wrapper;
+  iree_hal_allocator_t* device_allocator;
+
+} iree_hal_cuda_device_t;
+
+extern const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
+
+static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
+    iree_hal_device_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_device_vtable);
+  return (iree_hal_cuda_device_t*)base_value;
+}
+
+// Tears down the device: allocator first, then the CUDA stream and context it
+// owns, and the driver reference last because |syms| lives inside the driver.
+static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
+  iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+  iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
+  iree_hal_cuda_dynamic_symbols_t* syms = device->context_wrapper.syms;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // There should be no more buffers live that use the allocator.
+  iree_hal_allocator_release(device->device_allocator);
+  CUDA_IGNORE_ERROR(syms, cuStreamDestroy(device->stream));
+  CUDA_IGNORE_ERROR(syms, cuCtxDestroy(device->context_wrapper.cu_context));
+  iree_hal_driver_release(device->driver);
+  iree_allocator_free(host_allocator, device);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Builds the device object around |stream| and |context|. Ownership of both
+// transfers to the device only on success; on failure the caller keeps them.
+static iree_status_t iree_hal_cuda_device_create_internal(
+    iree_hal_driver_t* driver, iree_string_view_t identifier,
+    CUdevice cu_device, CUstream stream, CUcontext context,
+    iree_hal_cuda_dynamic_symbols_t* syms, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_cuda_device_t* device = NULL;
+  iree_host_size_t total_size = sizeof(*device) + identifier.size;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device));
+  memset(device, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_cuda_device_vtable, &device->resource);
+  device->driver = driver;
+  iree_hal_driver_retain(device->driver);
+  iree_string_view_append_to_buffer(identifier, &device->identifier,
+                                    (char*)device + sizeof(*device));
+  device->device = cu_device;
+  device->stream = stream;
+  device->context_wrapper.cu_context = context;
+  device->context_wrapper.host_allocator = host_allocator;
+  device->context_wrapper.syms = syms;
+  iree_status_t status = iree_hal_cuda_allocator_create(
+      &device->context_wrapper, &device->device_allocator);
+  if (iree_status_is_ok(status)) {
+    *out_device = (iree_hal_device_t*)device;
+  } else {
+    // Not iree_hal_device_release: destroy would free the caller's handles.
+    iree_hal_driver_release(device->driver);
+    iree_allocator_free(host_allocator, device);
+  }
+  return status;
+}
+
+// Creates a context and stream for |device| and wraps them in a HAL device.
+iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
+                                          iree_string_view_t identifier,
+                                          iree_hal_cuda_dynamic_symbols_t* syms,
+                                          CUdevice device,
+                                          iree_allocator_t host_allocator,
+                                          iree_hal_device_t** out_device) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  CUcontext context = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, CU_RESULT_TO_STATUS(syms, cuCtxCreate(&context, 0, device)));
+  CUstream stream = NULL;  // NULL until cuStreamCreate sets it.
+  iree_status_t status = CU_RESULT_TO_STATUS(
+      syms, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda_device_create_internal(driver, identifier, device,
+                                                  stream, context, syms,
+                                                  host_allocator, out_device);
+  }
+  if (!iree_status_is_ok(status)) {
+    // Ownership never reached the device, so free both handles here.
+    if (stream) syms->cuStreamDestroy(stream);
+    syms->cuCtxDestroy(context);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Returns the identifier string recorded on the device at creation time.
+static iree_string_view_t iree_hal_cuda_device_id(
+    iree_hal_device_t* base_device) {
+  return iree_hal_cuda_device_cast(base_device)->identifier;
+}
+
+// Returns the host allocator stored in the device's context wrapper.
+static iree_allocator_t iree_hal_cuda_device_host_allocator(
+    iree_hal_device_t* base_device) {
+  return iree_hal_cuda_device_cast(base_device)->context_wrapper.host_allocator;
+}
+
+// Returns the device allocator held by the device (no reference is added).
+static iree_hal_allocator_t* iree_hal_cuda_device_allocator(
+    iree_hal_device_t* base_device) {
+  return iree_hal_cuda_device_cast(base_device)->device_allocator;
+}
+
+static iree_status_t iree_hal_cuda_device_create_command_buffer(
+    iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_command_buffer_t** out_command_buffer) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_descriptor_set(
+    iree_hal_device_t* base_device,
+    iree_hal_descriptor_set_layout_t* set_layout,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_binding_t* bindings,
+    iree_hal_descriptor_set_t** out_descriptor_set) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "non-push descriptor sets still need work");
+}
+
+static iree_status_t iree_hal_cuda_device_create_descriptor_set_layout(
+    iree_hal_device_t* base_device,
+    iree_hal_descriptor_set_layout_usage_type_t usage_type,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_layout_binding_t* bindings,
+    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_event(
+    iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_executable_cache(
+    iree_hal_device_t* base_device, iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_executable_layout(
+    iree_hal_device_t* base_device, iree_host_size_t push_constants,
+    iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t** set_layouts,
+    iree_hal_executable_layout_t** out_executable_layout) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_semaphore(
+    iree_hal_device_t* base_device, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_queue_submit(
+    iree_hal_device_t* base_device,
+    iree_hal_command_category_t command_categories, uint64_t queue_affinity,
+    iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_semaphores_with_timeout(
+    iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
+    const iree_hal_semaphore_list_t* semaphore_list,
+    iree_duration_t timeout_ns) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "semaphore not implemented");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_semaphores_with_deadline(
+    iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "semaphore not implemented");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_idle_with_deadline(
+    iree_hal_device_t* base_device, iree_time_t deadline_ns) {
+  iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+  // Synchronizes the device's single stream; |deadline_ns| is not used yet.
+  // TODO(thomasraoux): CUDA doesn't support a deadline for wait, figure out how
+  // to handle it better.
+  CUDA_RETURN_IF_ERROR(device->context_wrapper.syms,
+                       cuStreamSynchronize(device->stream),
+                       "cuStreamSynchronize");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_device_wait_idle_with_timeout(
+    iree_hal_device_t* base_device, iree_duration_t timeout_ns) {
+  return iree_hal_cuda_device_wait_idle_with_deadline(
+      base_device, iree_relative_timeout_to_deadline_ns(timeout_ns));
+}
+
+const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = {
+    .destroy = iree_hal_cuda_device_destroy,
+    .id = iree_hal_cuda_device_id,
+    .host_allocator = iree_hal_cuda_device_host_allocator,
+    .device_allocator = iree_hal_cuda_device_allocator,
+    .create_command_buffer = iree_hal_cuda_device_create_command_buffer,
+    .create_descriptor_set = iree_hal_cuda_device_create_descriptor_set,
+    .create_descriptor_set_layout =
+        iree_hal_cuda_device_create_descriptor_set_layout,
+    .create_event = iree_hal_cuda_device_create_event,
+    .create_executable_cache = iree_hal_cuda_device_create_executable_cache,
+    .create_executable_layout = iree_hal_cuda_device_create_executable_layout,
+    .create_semaphore = iree_hal_cuda_device_create_semaphore,
+    .queue_submit = iree_hal_cuda_device_queue_submit,
+    .wait_semaphores_with_deadline =
+        iree_hal_cuda_device_wait_semaphores_with_deadline,
+    .wait_semaphores_with_timeout =
+        iree_hal_cuda_device_wait_semaphores_with_timeout,
+    .wait_idle_with_deadline = iree_hal_cuda_device_wait_idle_with_deadline,
+    .wait_idle_with_timeout = iree_hal_cuda_device_wait_idle_with_timeout,
+};
diff --git a/iree/hal/cuda/cuda_device.h b/iree/hal/cuda/cuda_device.h
new file mode 100644
index 0000000..5d7d4ab
--- /dev/null
+++ b/iree/hal/cuda/cuda_device.h
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CUDA_DEVICE_H_
+#define IREE_HAL_CUDA_CUDA_DEVICE_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a device that owns and manages its own CUcontext.
+iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
+                                          iree_string_view_t identifier,
+                                          iree_hal_cuda_dynamic_symbols_t* syms,
+                                          CUdevice device,
+                                          iree_allocator_t host_allocator,
+                                          iree_hal_device_t** out_device);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_CUDA_DEVICE_H_
diff --git a/iree/hal/cuda/cuda_driver.c b/iree/hal/cuda/cuda_driver.c
new file mode 100644
index 0000000..79f70b2
--- /dev/null
+++ b/iree/hal/cuda/cuda_driver.c
@@ -0,0 +1,210 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/cuda_device.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct {
+  iree_hal_resource_t resource;
+  iree_allocator_t host_allocator;
+  // Identifier used for the driver in the IREE driver registry.
+  // We allow overriding so that multiple CUDA versions can be exposed in the
+  // same process.
+  iree_string_view_t identifier;
+  int default_device_index;
+  // CUDA symbols.
+  iree_hal_cuda_dynamic_symbols_t syms;
+} iree_hal_cuda_driver_t;
+
+// Pick a fixed maximum length for device names.
+static const size_t IREE_MAX_CUDA_DEVICE_NAME_LENGTH = 100;
+
+extern const iree_hal_driver_vtable_t iree_hal_cuda_driver_vtable;
+
+static iree_hal_cuda_driver_t* iree_hal_cuda_driver_cast(
+    iree_hal_driver_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_driver_vtable);
+  return (iree_hal_cuda_driver_t*)base_value;
+}
+
+IREE_API_EXPORT void IREE_API_CALL iree_hal_cuda_driver_options_initialize(
+    iree_hal_cuda_driver_options_t* out_options) {
+  memset(out_options, 0, sizeof(*out_options));
+  out_options->default_device_index = 0;
+}
+
+static iree_status_t iree_hal_cuda_driver_create_internal(
+    iree_string_view_t identifier,
+    const iree_hal_cuda_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  iree_hal_cuda_driver_t* driver = NULL;
+  iree_host_size_t total_size = sizeof(*driver) + identifier.size;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+  memset(driver, 0, total_size);  // destroy then sees a NULL library handle
+  iree_hal_resource_initialize(&iree_hal_cuda_driver_vtable, &driver->resource);
+  driver->host_allocator = host_allocator;
+  iree_string_view_append_to_buffer(identifier, &driver->identifier,
+                                    (char*)driver + sizeof(*driver));
+  driver->default_device_index = options->default_device_index;
+  iree_status_t status = load_symbols(&driver->syms);
+  if (iree_status_is_ok(status)) {
+    *out_driver = (iree_hal_driver_t*)driver;
+  } else {
+    iree_hal_driver_release((iree_hal_driver_t*)driver);
+  }
+  return status;
+}
+
+// Unloads the CUDA symbols and then frees the driver's own storage.
+static void iree_hal_cuda_driver_destroy(iree_hal_driver_t* base_driver) {
+  iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Copy the allocator out first: it is stored in the block being freed.
+  iree_allocator_t host_allocator = driver->host_allocator;
+  unload_symbols(&driver->syms);
+  iree_allocator_free(host_allocator, driver);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_cuda_driver_create(
+    iree_string_view_t identifier,
+    const iree_hal_cuda_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(out_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_hal_cuda_driver_create_internal(
+      identifier, options, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Populates device information from the given CUDA physical device handle.
+// |out_device_info| must point to valid memory and additional data will be
+// appended to |buffer_ptr| and the new pointer is returned.
+static uint8_t* iree_hal_cuda_populate_device_info(
+    CUdevice device, iree_hal_cuda_dynamic_symbols_t* syms, uint8_t* buffer_ptr,
+    iree_hal_device_info_t* out_device_info) {
+  char device_name[IREE_MAX_CUDA_DEVICE_NAME_LENGTH];
+  // cuDeviceGetName errors are ignored; keep the name defined if it fails.
+  device_name[0] = '\0';
+  CUDA_IGNORE_ERROR(syms,
+                    cuDeviceGetName(device_name, sizeof(device_name), device));
+  memset(out_device_info, 0, sizeof(*out_device_info));
+  out_device_info->device_id = (iree_hal_device_id_t)device;
+  iree_string_view_t device_name_string = iree_make_cstring_view(device_name);
+  buffer_ptr += iree_string_view_append_to_buffer(
+      device_name_string, &out_device_info->name, (char*)buffer_ptr);
+  return buffer_ptr;
+}
+
+// Enumerates CUDA devices into one |host_allocator| block (info structs, then
+// name storage) that the caller receives through |out_device_infos|.
+static iree_status_t iree_hal_cuda_driver_query_available_devices(
+    iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
+    iree_hal_device_info_t** out_device_infos,
+    iree_host_size_t* out_device_info_count) {
+  iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+  // cuInit must precede every other driver API call; repeating it is fine.
+  CUDA_RETURN_IF_ERROR(&driver->syms, cuInit(0), "cuInit");
+  int device_count = 0;
+  CUDA_RETURN_IF_ERROR(&driver->syms, cuDeviceGetCount(&device_count),
+                       "cuDeviceGetCount");
+  iree_hal_device_info_t* device_infos = NULL;
+  iree_host_size_t total_size =
+      device_count * (sizeof(*device_infos) + IREE_MAX_CUDA_DEVICE_NAME_LENGTH);
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
+  if (iree_status_is_ok(status)) {
+    uint8_t* buffer_ptr = (uint8_t*)(device_infos + device_count);
+    for (int i = 0; i < device_count; ++i) {
+      CUdevice device;
+      // Assign the outer |status| (no shadowing) so a failure is reported.
+      status = CU_RESULT_TO_STATUS(&driver->syms, cuDeviceGet(&device, i),
+                                   "cuDeviceGet");
+      if (!iree_status_is_ok(status)) break;
+      buffer_ptr = iree_hal_cuda_populate_device_info(
+          device, &driver->syms, buffer_ptr, &device_infos[i]);
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_device_info_count = device_count;
+    *out_device_infos = device_infos;
+  } else {
+    iree_allocator_free(host_allocator, device_infos);
+  }
+  return status;
+}
+
+// Resolves |default_device_index| to a CUdevice; NOT_FOUND if out of range.
+static iree_status_t iree_hal_cuda_driver_select_default_device(
+    iree_hal_cuda_dynamic_symbols_t* syms, int default_device_index,
+    iree_allocator_t host_allocator, CUdevice* out_device) {
+  int device_count = 0;
+  CUDA_RETURN_IF_ERROR(syms, cuDeviceGetCount(&device_count),
+                       "cuDeviceGetCount");
+  // Negative indices are rejected here instead of being passed to cuDeviceGet.
+  if (default_device_index < 0 || default_device_index >= device_count) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "default device %d not found (of %d enumerated)",
+                            default_device_index, device_count);
+  }
+  CUdevice device;
+  CUDA_RETURN_IF_ERROR(syms, cuDeviceGet(&device, default_device_index),
+                       "cuDeviceGet");
+  *out_device = device;
+  return iree_ok_status();
+}
+
+// Initializes the CUDA driver API and creates a HAL device named "cuda" for
+// |device_id|, falling back to the driver's default index when it is 0.
+static iree_status_t iree_hal_cuda_driver_create_device(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+  iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, CU_RESULT_TO_STATUS(&driver->syms, cuInit(0), "cuInit"));
+  // Use the specified device (enumerated earlier), or the default index if 0.
+  // NOTE(review): enumeration reports the raw CUdevice as device_id, so a
+  // handle that equals 0 can't be requested explicitly - confirm intent.
+  CUdevice device = (CUdevice)device_id;
+  if (device == 0) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_cuda_driver_select_default_device(
+                &driver->syms, driver->default_device_index, host_allocator,
+                &device));
+  }
+  iree_string_view_t device_name = iree_make_cstring_view("cuda");
+  // Attempt to create the device.
+  iree_status_t status =
+      iree_hal_cuda_device_create(base_driver, device_name, &driver->syms,
+                                  device, host_allocator, out_device);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+const iree_hal_driver_vtable_t iree_hal_cuda_driver_vtable = {
+    .destroy = iree_hal_cuda_driver_destroy,
+    .query_available_devices = iree_hal_cuda_driver_query_available_devices,
+    .create_device = iree_hal_cuda_driver_create_device,
+};
diff --git a/iree/hal/cuda/dynamic_symbols.cc b/iree/hal/cuda/dynamic_symbols.cc
index 0927116..df13a4d 100644
--- a/iree/hal/cuda/dynamic_symbols.cc
+++ b/iree/hal/cuda/dynamic_symbols.cc
@@ -17,15 +17,12 @@
 #include <cstddef>
 
 #include "absl/types/span.h"
+#include "iree/base/dynamic_library.h"
 #include "iree/base/status.h"
 #include "iree/base/target_platform.h"
 #include "iree/base/tracing.h"
 
-namespace iree {
-namespace hal {
-namespace cuda {
-
-static const char* kCudaLoaderSearchNames[] = {
+static const char* kCUDALoaderSearchNames[] = {
 #if defined(IREE_PLATFORM_WINDOWS)
     "nvcuda.dll",
 #else
@@ -33,28 +30,31 @@
 #endif
 };
 
-Status DynamicSymbols::LoadSymbols() {
-  IREE_TRACE_SCOPE();
+extern "C" {
 
-  IREE_RETURN_IF_ERROR(DynamicLibrary::Load(
-      absl::MakeSpan(kCudaLoaderSearchNames), &loader_library_));
+iree_status_t load_symbols(iree_hal_cuda_dynamic_symbols_t* syms) {
+  std::unique_ptr<iree::DynamicLibrary> loader_library;
+  IREE_RETURN_IF_ERROR(iree::DynamicLibrary::Load(
+      absl::MakeSpan(kCUDALoaderSearchNames), &loader_library));
 
-#define CU_PFN_DECL(cudaSymbolName)                                         \
+#define CU_PFN_DECL(cudaSymbolName, ...)                                    \
   {                                                                         \
-    using FuncPtrT = std::add_pointer<decltype(::cudaSymbolName)>::type;    \
+    using FuncPtrT = decltype(syms->cudaSymbolName);                        \
     static const char* kName = #cudaSymbolName;                             \
-    cudaSymbolName = loader_library_->GetSymbol<FuncPtrT>(kName);           \
-    if (!cudaSymbolName) {                                                  \
+    syms->cudaSymbolName = loader_library->GetSymbol<FuncPtrT>(kName);      \
+    if (!syms->cudaSymbolName) {                                            \
       return iree_make_status(IREE_STATUS_UNAVAILABLE, "symbol not found"); \
     }                                                                       \
   }
 
 #include "dynamic_symbols_tables.h"
 #undef CU_PFN_DECL
-
-  return OkStatus();
+  syms->opaque_loader_library_ = (void*)loader_library.release();
+  return iree_ok_status();
 }
 
-}  // namespace cuda
-}  // namespace hal
-}  // namespace iree
+void unload_symbols(iree_hal_cuda_dynamic_symbols_t* syms) {
+  delete (iree::DynamicLibrary*)syms->opaque_loader_library_;
+}
+
+}  // extern "C"
\ No newline at end of file
diff --git a/iree/hal/cuda/dynamic_symbols.h b/iree/hal/cuda/dynamic_symbols.h
index 9d2c40e..436b32d 100644
--- a/iree/hal/cuda/dynamic_symbols.h
+++ b/iree/hal/cuda/dynamic_symbols.h
@@ -15,38 +15,29 @@
 #ifndef IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
 #define IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
 
-#include <cstdint>
-#include <functional>
-#include <memory>
-
-#include "iree/base/dynamic_library.h"
-#include "iree/base/status.h"
+#include "iree/base/api.h"
 #include "iree/hal/cuda/cuda_headers.h"
 
-namespace iree {
-namespace hal {
-namespace cuda {
-
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
 /// DyanmicSymbols allow loading dynamically a subset of CUDA driver API. It
 /// loads all the function declared in `dynamic_symbol_tables.def` and fail if
 /// any of the symbol is not available. The functions signatures are matching
 /// the declarations in `cuda.h`.
-struct DynamicSymbols {
-  Status LoadSymbols();
-
-#define CU_PFN_DECL(cudaSymbolName) \
-  std::add_pointer<decltype(::cudaSymbolName)>::type cudaSymbolName;
-
+typedef struct {
+#define CU_PFN_DECL(cudaSymbolName, ...) \
+  CUresult (*cudaSymbolName)(__VA_ARGS__);
 #include "dynamic_symbols_tables.h"
 #undef CU_PFN_DECL
+  void* opaque_loader_library_;
+} iree_hal_cuda_dynamic_symbols_t;
 
- private:
-  // Cuda Loader dynamic library.
-  std::unique_ptr<DynamicLibrary> loader_library_;
-};
+iree_status_t load_symbols(iree_hal_cuda_dynamic_symbols_t* syms);
+void unload_symbols(iree_hal_cuda_dynamic_symbols_t* syms);
 
-}  // namespace cuda
-}  // namespace hal
-}  // namespace iree
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
 
 #endif  // IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/cuda/dynamic_symbols_tables.h b/iree/hal/cuda/dynamic_symbols_tables.h
index 5adece6..1a9fbaa 100644
--- a/iree/hal/cuda/dynamic_symbols_tables.h
+++ b/iree/hal/cuda/dynamic_symbols_tables.h
@@ -12,79 +12,37 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-CU_PFN_DECL(cuCtxCreate)
-CU_PFN_DECL(cuCtxDestroy)
-CU_PFN_DECL(cuCtxEnablePeerAccess)
-CU_PFN_DECL(cuCtxGetCurrent)
-CU_PFN_DECL(cuCtxGetDevice)
-CU_PFN_DECL(cuCtxGetSharedMemConfig)
-CU_PFN_DECL(cuCtxSetCurrent)
-CU_PFN_DECL(cuCtxSetSharedMemConfig)
-CU_PFN_DECL(cuCtxSynchronize)
-CU_PFN_DECL(cuDeviceCanAccessPeer)
-CU_PFN_DECL(cuDeviceGet)
-CU_PFN_DECL(cuDeviceGetAttribute)
-CU_PFN_DECL(cuDeviceGetCount)
-CU_PFN_DECL(cuDeviceGetName)
-CU_PFN_DECL(cuDeviceGetPCIBusId)
-CU_PFN_DECL(cuDevicePrimaryCtxGetState)
-CU_PFN_DECL(cuDevicePrimaryCtxRelease)
-CU_PFN_DECL(cuDevicePrimaryCtxRetain)
-CU_PFN_DECL(cuDevicePrimaryCtxSetFlags)
-CU_PFN_DECL(cuDeviceTotalMem)
-CU_PFN_DECL(cuDriverGetVersion)
-CU_PFN_DECL(cuEventCreate)
-CU_PFN_DECL(cuEventDestroy)
-CU_PFN_DECL(cuEventElapsedTime)
-CU_PFN_DECL(cuEventQuery)
-CU_PFN_DECL(cuEventRecord)
-CU_PFN_DECL(cuEventSynchronize)
-CU_PFN_DECL(cuFuncGetAttribute)
-CU_PFN_DECL(cuFuncSetCacheConfig)
-CU_PFN_DECL(cuGetErrorName)
-CU_PFN_DECL(cuGetErrorString)
-CU_PFN_DECL(cuGraphAddMemcpyNode)
-CU_PFN_DECL(cuGraphAddMemsetNode)
-CU_PFN_DECL(cuGraphAddKernelNode)
-CU_PFN_DECL(cuGraphCreate)
-CU_PFN_DECL(cuGraphDestroy)
-CU_PFN_DECL(cuGraphExecDestroy)
-CU_PFN_DECL(cuGraphGetNodes)
-CU_PFN_DECL(cuGraphInstantiate)
-CU_PFN_DECL(cuGraphLaunch)
-CU_PFN_DECL(cuInit)
-CU_PFN_DECL(cuLaunchKernel)
-CU_PFN_DECL(cuMemAlloc)
-CU_PFN_DECL(cuMemAllocManaged)
-CU_PFN_DECL(cuMemFree)
-CU_PFN_DECL(cuMemFreeHost)
-CU_PFN_DECL(cuMemGetAddressRange)
-CU_PFN_DECL(cuMemGetInfo)
-CU_PFN_DECL(cuMemHostAlloc)
-CU_PFN_DECL(cuMemHostGetDevicePointer)
-CU_PFN_DECL(cuMemHostRegister)
-CU_PFN_DECL(cuMemHostUnregister)
-CU_PFN_DECL(cuMemcpyDtoD)
-CU_PFN_DECL(cuMemcpyDtoDAsync)
-CU_PFN_DECL(cuMemcpyDtoH)
-CU_PFN_DECL(cuMemcpyDtoHAsync)
-CU_PFN_DECL(cuMemcpyHtoD)
-CU_PFN_DECL(cuMemcpyHtoDAsync)
-CU_PFN_DECL(cuMemsetD32)
-CU_PFN_DECL(cuMemsetD32Async)
-CU_PFN_DECL(cuMemsetD8)
-CU_PFN_DECL(cuMemsetD8Async)
-CU_PFN_DECL(cuModuleGetFunction)
-CU_PFN_DECL(cuModuleGetGlobal)
-CU_PFN_DECL(cuModuleLoadDataEx)
-CU_PFN_DECL(cuModuleLoadFatBinary)
-CU_PFN_DECL(cuModuleUnload)
-CU_PFN_DECL(cuOccupancyMaxActiveBlocksPerMultiprocessor)
-CU_PFN_DECL(cuOccupancyMaxPotentialBlockSize)
-CU_PFN_DECL(cuPointerGetAttribute)
-CU_PFN_DECL(cuStreamAddCallback)
-CU_PFN_DECL(cuStreamCreate)
-CU_PFN_DECL(cuStreamDestroy)
-CU_PFN_DECL(cuStreamQuery)
-CU_PFN_DECL(cuStreamSynchronize)
-CU_PFN_DECL(cuStreamWaitEvent)
+CU_PFN_DECL(cuCtxCreate, CUcontext*, unsigned int, CUdevice)
+CU_PFN_DECL(cuCtxDestroy, CUcontext)
+CU_PFN_DECL(cuDeviceGet, CUdevice*, int)
+CU_PFN_DECL(cuDeviceGetCount, int*)
+CU_PFN_DECL(cuDeviceGetName, char*, int, CUdevice)
+CU_PFN_DECL(cuGetErrorName, CUresult, const char**)
+CU_PFN_DECL(cuGetErrorString, CUresult, const char**)
+CU_PFN_DECL(cuGraphAddMemcpyNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+            size_t, const CUDA_MEMCPY3D*, CUcontext)
+CU_PFN_DECL(cuGraphAddMemsetNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+            size_t, const CUDA_MEMSET_NODE_PARAMS*, CUcontext)
+CU_PFN_DECL(cuGraphAddKernelNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+            size_t, const CUDA_KERNEL_NODE_PARAMS*)
+CU_PFN_DECL(cuGraphCreate, CUgraph*, unsigned int)
+CU_PFN_DECL(cuGraphDestroy, CUgraph)
+CU_PFN_DECL(cuGraphExecDestroy, CUgraphExec)
+CU_PFN_DECL(cuGraphGetNodes, CUgraph, CUgraphNode*, size_t*)
+CU_PFN_DECL(cuGraphInstantiate, CUgraphExec*, CUgraph, CUgraphNode*, char*,
+            size_t)
+CU_PFN_DECL(cuGraphLaunch, CUgraphExec, CUstream)
+CU_PFN_DECL(cuInit, unsigned int)
+CU_PFN_DECL(cuMemAlloc, CUdeviceptr*, size_t)
+CU_PFN_DECL(cuMemFree, CUdeviceptr)
+CU_PFN_DECL(cuMemFreeHost, void*)
+CU_PFN_DECL(cuMemHostAlloc, void**, size_t, unsigned int)
+CU_PFN_DECL(cuMemHostGetDevicePointer, CUdeviceptr*, void*, unsigned int)
+CU_PFN_DECL(cuModuleGetFunction, CUfunction*, CUmodule, const char*)
+CU_PFN_DECL(cuModuleLoadDataEx, CUmodule*, const void*, unsigned int,
+            CUjit_option*, void**)
+CU_PFN_DECL(cuModuleUnload, CUmodule)
+CU_PFN_DECL(cuStreamCreate, CUstream*, unsigned int)
+CU_PFN_DECL(cuStreamDestroy, CUstream)
+CU_PFN_DECL(cuStreamSynchronize, CUstream)
+CU_PFN_DECL(cuStreamWaitEvent, CUstream, CUevent, unsigned int)
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
index 6a7967c..5b8681f 100644
--- a/iree/hal/cuda/dynamic_symbols_test.cc
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -29,9 +29,9 @@
   }
 
 TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
-  DynamicSymbols symbols;
-  Status status = symbols.LoadSymbols();
-  if (!status.ok()) {
+  iree_hal_cuda_dynamic_symbols_t symbols;
+  iree_status_t status = load_symbols(&symbols);
+  if (!iree_status_is_ok(status)) {
     IREE_LOG(WARNING) << "Symbols cannot be loaded, skipping test.";
     GTEST_SKIP();
   }
@@ -43,6 +43,7 @@
     CUdevice device;
     CUDE_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
   }
+  unload_symbols(&symbols);
 }
 
 }  // namespace
diff --git a/iree/hal/cuda/registration/BUILD b/iree/hal/cuda/registration/BUILD
new file mode 100644
index 0000000..58157f4
--- /dev/null
+++ b/iree/hal/cuda/registration/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_cmake_extra_content(
+    content = """
+if(${IREE_HAL_DRIVER_CUDA})
+""",
+    inline = True,
+)
+
+cc_library(
+    name = "registration",
+    srcs = ["driver_module.c"],
+    hdrs = ["driver_module.h"],
+    defines = [
+        "IREE_HAL_HAVE_CUDA_DRIVER_MODULE=1",  # Tested by #if in iree/hal/drivers/init.c.
+    ],
+    deps = [
+        "//iree/base:core_headers",
+        "//iree/base:status",
+        "//iree/base:tracing",
+        "//iree/hal:api",
+        "//iree/hal/cuda",
+    ],
+)
+
+iree_cmake_extra_content(
+    content = """
+endif()
+""",
+    inline = True,
+)
diff --git a/iree/hal/cuda/registration/CMakeLists.txt b/iree/hal/cuda/registration/CMakeLists.txt
new file mode 100644
index 0000000..93d8dac
--- /dev/null
+++ b/iree/hal/cuda/registration/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Autogenerated from iree/hal/cuda/registration/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_CUDA})
+
+iree_cc_library(
+  NAME
+    registration
+  HDRS
+    "driver_module.h"
+  SRCS
+    "driver_module.c"
+  DEPS
+    iree::base::core_headers
+    iree::base::status
+    iree::base::tracing
+    iree::hal::api
+    iree::hal::cuda
+  DEFINES
+    "IREE_HAL_HAVE_CUDA_DRIVER_MODULE=1"
+  PUBLIC
+)
+
+endif()
diff --git a/iree/hal/cuda/registration/driver_module.c b/iree/hal/cuda/registration/driver_module.c
new file mode 100644
index 0000000..5677cd1
--- /dev/null
+++ b/iree/hal/cuda/registration/driver_module.c
@@ -0,0 +1,72 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/registration/driver_module.h"
+
+#include <inttypes.h>
+
+#include "iree/base/status.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+
+#define IREE_HAL_CUDA_DRIVER_ID 0x43554441u  // CUDA
+
+// Lists the one "cuda" driver. NOTE: could query CUDA versions/features here.
+static iree_status_t iree_hal_cuda_driver_factory_enumerate(
+    void* self, const iree_hal_driver_info_t** out_driver_infos,
+    iree_host_size_t* out_driver_info_count) {
+  static const iree_hal_driver_info_t kDriverInfos[] = {{
+      .driver_id = IREE_HAL_CUDA_DRIVER_ID,
+      .driver_name = iree_string_view_literal("cuda"),
+      .full_name = iree_string_view_literal("CUDA (dynamic)"),
+  }};
+  *out_driver_infos = kDriverInfos;
+  *out_driver_info_count = IREE_ARRAYSIZE(kDriverInfos);
+  return iree_ok_status();
+}
+
+// Creates the "cuda" driver; any other |driver_id| yields UNAVAILABLE.
+static iree_status_t iree_hal_cuda_driver_factory_try_create(
+    void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+    iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(out_driver);
+  *out_driver = NULL;
+  if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "no driver with ID %016" PRIu64
+                            " is provided by this factory",
+                            driver_id);
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_cuda_driver_options_t options;
+  iree_hal_cuda_driver_options_initialize(&options);
+  // When more than one driver is exposed (different cuda versions, etc) each
+  // one can be given its own name here:
+  const iree_string_view_t driver_name = iree_make_cstring_view("cuda");
+  iree_status_t create_status =
+      iree_hal_cuda_driver_create(driver_name, &options, allocator, out_driver);
+  IREE_TRACE_ZONE_END(z0);
+  return create_status;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry) {
+  static const iree_hal_driver_factory_t factory = {  // Passed by address.
+      .self = NULL,  // Neither callback below reads |self|.
+      .enumerate = iree_hal_cuda_driver_factory_enumerate,
+      .try_create = iree_hal_cuda_driver_factory_try_create,
+  };
+  return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/iree/hal/cuda/registration/driver_module.h b/iree/hal/cuda/registration/driver_module.h
new file mode 100644
index 0000000..897594c
--- /dev/null
+++ b/iree/hal/cuda/registration/driver_module.h
@@ -0,0 +1,33 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Registers the CUDA HAL driver factory with |registry| so that the "cuda"
+// driver can be enumerated and created through it.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/cuda/status_util.c b/iree/hal/cuda/status_util.c
new file mode 100644
index 0000000..9a084a8
--- /dev/null
+++ b/iree/hal/cuda/status_util.c
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/status_util.h"
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+// Converts |result| to a status attributed to the caller's |file|:|line|.
+iree_status_t iree_hal_cuda_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
+    uint32_t line) {
+  if (IREE_LIKELY(result == CUDA_SUCCESS)) return iree_ok_status();
+
+  // Ask the driver to describe the error; fall back to placeholder text.
+  const char* error_name = NULL;
+  if (syms->cuGetErrorName(result, &error_name) != CUDA_SUCCESS) {
+    error_name = "UNKNOWN";
+  }
+  const char* error_string = NULL;
+  if (syms->cuGetErrorString(result, &error_string) != CUDA_SUCCESS) {
+    error_string = "Unknown error.";
+  }
+  // iree_make_status would record this helper's location, not the caller's.
+  return iree_status_allocate_f(IREE_STATUS_INTERNAL, file, line,
+                                "CUDA driver error '%s' (%d): %s", error_name,
+                                (int)result, error_string);
+}
diff --git a/iree/hal/cuda/status_util.h b/iree/hal/cuda/status_util.h
new file mode 100644
index 0000000..526b56f
--- /dev/null
+++ b/iree/hal/cuda/status_util.h
@@ -0,0 +1,60 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_STATUS_UTIL_H_
+#define IREE_HAL_CUDA_STATUS_UTIL_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Calls |expr| via |syms| and converts its CUresult to an iree_status_t.
+// Arguments after |expr| are accepted but unused.
+// Usage:
+//   iree_status_t status = CU_RESULT_TO_STATUS(syms, cuDoThing(...));
+#define CU_RESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_cuda_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR, but calls |expr| via |syms| and converts its CUresult
+// to an iree_status_t first; trailing arguments are forwarded unchanged.
+//
+// Usage:
+//   CUDA_RETURN_IF_ERROR(syms, cuDoThing(...), "message");
+#define CUDA_RETURN_IF_ERROR(syms, expr, ...)                                 \
+  IREE_RETURN_IF_ERROR(iree_hal_cuda_result_to_status((syms), ((syms)->expr), \
+                                                      __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_IGNORE_ERROR, but calls |expr| via |syms| and converts its CUresult to
+// an iree_status_t before handing it to IREE_IGNORE_ERROR.
+//
+// Usage:
+//   CUDA_IGNORE_ERROR(syms, cuDoThing(...));
+#define CUDA_IGNORE_ERROR(syms, expr)                                      \
+  IREE_IGNORE_ERROR(iree_hal_cuda_result_to_status((syms), ((syms)->expr), \
+                                                   __FILE__, __LINE__))
+
+// Converts |result| to an iree_status_t, naming the error via |syms|.
+iree_status_t iree_hal_cuda_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
+    uint32_t line);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_STATUS_UTIL_H_
diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt
index 305bcb5..1994aef 100644
--- a/iree/hal/drivers/CMakeLists.txt
+++ b/iree/hal/drivers/CMakeLists.txt
@@ -15,6 +15,9 @@
 # bazel_to_cmake: DO NOT EDIT (custom configuration vars)
 
 set(IREE_HAL_DRIVER_MODULES)
+if(${IREE_HAL_DRIVER_CUDA})
+  list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::cuda::registration)
+endif()
 if(${IREE_HAL_DRIVER_DYLIB})
   list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration)
 endif()
diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c
index bb36616..bbcd9b9 100644
--- a/iree/hal/drivers/init.c
+++ b/iree/hal/drivers/init.c
@@ -16,6 +16,10 @@
 
 #include "iree/base/tracing.h"
 
+#if defined(IREE_HAL_HAVE_CUDA_DRIVER_MODULE)
+#include "iree/hal/cuda/registration/driver_module.h"
+#endif  // IREE_HAL_HAVE_CUDA_DRIVER_MODULE
+
 #if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
 #include "iree/hal/dylib/registration/driver_module.h"
 #endif  // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
@@ -32,6 +36,11 @@
 iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {
   IREE_TRACE_ZONE_BEGIN(z0);
 
+#if defined(IREE_HAL_HAVE_CUDA_DRIVER_MODULE)
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_cuda_driver_module_register(registry));
+#endif  // IREE_HAL_HAVE_CUDA_DRIVER_MODULE
+
 #if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_dylib_driver_module_register(registry));