Cuda HAL backend step 1 (#4874)
register the backend and add stub for the device. Add support for
buffers and allocator.
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
index b4ee8f5..1c4ea00 100644
--- a/iree/hal/cts/command_buffer_test.cc
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -33,6 +33,9 @@
using ::testing::ContainerEq;
class CommandBufferTest : public CtsTestBase {
+ public:
+ CommandBufferTest() { declareUnimplementedDriver("cuda"); }
+
protected:
static constexpr iree_device_size_t kBufferSize = 4096;
};
@@ -96,26 +99,28 @@
// Fill the device buffer with segments of different values so that we can
// test both fill and offset/size.
uint8_t val1 = 0x07;
- iree_hal_command_buffer_fill_buffer(
+ IREE_ASSERT_OK(iree_hal_command_buffer_fill_buffer(
command_buffer, device_buffer,
/*target_offset=*/0, /*length=*/kBufferSize / 4, /*pattern=*/&val1,
- /*pattern_length=*/sizeof(val1));
+ /*pattern_length=*/sizeof(val1)));
std::memset(reference_buffer.data(), val1, kBufferSize / 4);
uint8_t val2 = 0xbe;
- iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
- /*target_offset=*/kBufferSize / 4,
- /*length=*/kBufferSize / 4,
- /*pattern=*/&val2,
- /*pattern_length=*/sizeof(val2));
+ IREE_ASSERT_OK(
+ iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+ /*target_offset=*/kBufferSize / 4,
+ /*length=*/kBufferSize / 4,
+ /*pattern=*/&val2,
+ /*pattern_length=*/sizeof(val2)));
std::memset(reference_buffer.data() + kBufferSize / 4, val2, kBufferSize / 4);
uint8_t val3 = 0x54;
- iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
- /*target_offset=*/kBufferSize / 2,
- /*length=*/kBufferSize / 2,
- /*pattern=*/&val3,
- /*pattern_length=*/sizeof(val3));
+ IREE_ASSERT_OK(
+ iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+ /*target_offset=*/kBufferSize / 2,
+ /*length=*/kBufferSize / 2,
+ /*pattern=*/&val3,
+ /*pattern_length=*/sizeof(val3)));
std::memset(reference_buffer.data() + kBufferSize / 2, val3, kBufferSize / 2);
IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
index 09875ad..cf7d58c 100644
--- a/iree/hal/cts/cts_test_base.h
+++ b/iree/hal/cts/cts_test_base.h
@@ -33,6 +33,12 @@
protected:
virtual void SetUp() {
const std::string& driver_name = GetParam();
+ if (driver_block_list_.find(driver_name) != driver_block_list_.end()) {
+ IREE_LOG(WARNING)
+ << "Skipping test as driver is explicitly disabled for this test";
+ GTEST_SKIP();
+ return;
+ }
// Get driver with the given name and create its default device.
// Skip drivers that are (gracefully) unavailable, fail if creation fails.
@@ -122,6 +128,14 @@
iree_hal_driver_t* driver_ = nullptr;
iree_hal_device_t* device_ = nullptr;
iree_hal_allocator_t* device_allocator_ = nullptr;
+ // Allow skipping tests for driver under development.
+ void declareUnimplementedDriver(const std::string& driver_name) {
+ driver_block_list_.insert(driver_name);
+ }
+ // Allow skipping tests for unsupported features.
+ void declareUnavailableDriver(const std::string& driver_name) {
+ driver_block_list_.insert(driver_name);
+ }
private:
// Gets a HAL driver with the provided name, if available.
@@ -149,6 +163,7 @@
}
return status;
}
+ std::set<std::string> driver_block_list_;
};
struct GenerateTestName {
diff --git a/iree/hal/cts/descriptor_set_layout_test.cc b/iree/hal/cts/descriptor_set_layout_test.cc
index 14dd2b5..3b7df35 100644
--- a/iree/hal/cts/descriptor_set_layout_test.cc
+++ b/iree/hal/cts/descriptor_set_layout_test.cc
@@ -21,7 +21,10 @@
namespace hal {
namespace cts {
-class DescriptorSetLayoutTest : public CtsTestBase {};
+class DescriptorSetLayoutTest : public CtsTestBase {
+ public:
+ DescriptorSetLayoutTest() { declareUnimplementedDriver("cuda"); }
+};
// Note: bindingCount == 0 is valid in VkDescriptorSetLayoutCreateInfo:
// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSetLayoutCreateInfo.html
diff --git a/iree/hal/cts/event_test.cc b/iree/hal/cts/event_test.cc
index b728277..083be3a 100644
--- a/iree/hal/cts/event_test.cc
+++ b/iree/hal/cts/event_test.cc
@@ -21,7 +21,10 @@
namespace hal {
namespace cts {
-class EventTest : public CtsTestBase {};
+class EventTest : public CtsTestBase {
+ public:
+ EventTest() { declareUnimplementedDriver("cuda"); }
+};
TEST_P(EventTest, Create) {
iree_hal_event_t* event;
diff --git a/iree/hal/cts/executable_layout_test.cc b/iree/hal/cts/executable_layout_test.cc
index 1df73bd..f4a3374 100644
--- a/iree/hal/cts/executable_layout_test.cc
+++ b/iree/hal/cts/executable_layout_test.cc
@@ -21,7 +21,10 @@
namespace hal {
namespace cts {
-class ExecutableLayoutTest : public CtsTestBase {};
+class ExecutableLayoutTest : public CtsTestBase {
+ public:
+ ExecutableLayoutTest() { declareUnimplementedDriver("cuda"); }
+};
TEST_P(ExecutableLayoutTest, CreateWithNoLayouts) {
iree_hal_executable_layout_t* executable_layout;
diff --git a/iree/hal/cts/semaphore_submission_test.cc b/iree/hal/cts/semaphore_submission_test.cc
index 5f9c529..5cec41f 100644
--- a/iree/hal/cts/semaphore_submission_test.cc
+++ b/iree/hal/cts/semaphore_submission_test.cc
@@ -20,7 +20,11 @@
namespace hal {
namespace cts {
-class SemaphoreSubmissionTest : public CtsTestBase {};
+class SemaphoreSubmissionTest : public CtsTestBase {
+ public:
+ // Disable cuda backend for this test as semaphores are not implemented yet.
+ SemaphoreSubmissionTest() { declareUnavailableDriver("cuda"); }
+};
TEST_P(SemaphoreSubmissionTest, SubmitWithNoCommandBuffers) {
// No waits, one signal which we immediately wait on after submit.
diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc
index c774241..ac02fca 100644
--- a/iree/hal/cts/semaphore_test.cc
+++ b/iree/hal/cts/semaphore_test.cc
@@ -22,7 +22,11 @@
namespace hal {
namespace cts {
-class SemaphoreTest : public CtsTestBase {};
+class SemaphoreTest : public CtsTestBase {
+ public:
+ // Disable cuda backend for this test as semaphores are not implemented yet.
+ SemaphoreTest() { declareUnavailableDriver("cuda"); }
+};
// Tests that a semaphore that is unused properly cleans itself up.
TEST_P(SemaphoreTest, NoOp) {
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD
index 18fc69b..831fbae 100644
--- a/iree/hal/cuda/BUILD
+++ b/iree/hal/cuda/BUILD
@@ -29,6 +29,39 @@
)
cc_library(
+ name = "cuda",
+ srcs = [
+ "api.h",
+ "context_wrapper.h",
+ "cuda_allocator.c",
+ "cuda_allocator.h",
+ "cuda_buffer.c",
+ "cuda_buffer.h",
+ "cuda_device.c",
+ "cuda_device.h",
+ "cuda_driver.c",
+ "status_util.c",
+ "status_util.h",
+ ],
+ hdrs = [
+ "api.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":dynamic_symbols",
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
+ "//iree/base:logging",
+ "//iree/base:status",
+ "//iree/base:synchronization",
+ "//iree/base:tracing",
+ "//iree/base/internal",
+ "//iree/hal:api",
+ ],
+)
+
+cc_library(
name = "dynamic_symbols",
srcs = [
"cuda_headers.h",
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
index f9e7d76..6df72d1 100644
--- a/iree/hal/cuda/CMakeLists.txt
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -8,6 +8,37 @@
iree_cc_library(
NAME
+ cuda
+ HDRS
+ "api.h"
+ SRCS
+ "api.h"
+ "context_wrapper.h"
+ "cuda_allocator.c"
+ "cuda_allocator.h"
+ "cuda_buffer.c"
+ "cuda_buffer.h"
+ "cuda_device.c"
+ "cuda_device.h"
+ "cuda_driver.c"
+ "status_util.c"
+ "status_util.h"
+ DEPS
+ ::dynamic_symbols
+ iree::base::api
+ iree::base::core_headers
+ iree::base::flatcc
+ iree::base::internal
+ iree::base::logging
+ iree::base::status
+ iree::base::synchronization
+ iree::base::tracing
+ iree::hal::api
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
dynamic_symbols
HDRS
"dynamic_symbols.h"
diff --git a/iree/hal/cuda/api.h b/iree/hal/cuda/api.h
new file mode 100644
index 0000000..3c219b7
--- /dev/null
+++ b/iree/hal/cuda/api.h
@@ -0,0 +1,55 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_HAL_CUDA_API_H_
+#define IREE_HAL_CUDA_API_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda_driver_t
+//===----------------------------------------------------------------------===//
+
+// CUDA driver creation options.
+typedef struct {
+ // Index of the default CUDA device to use within the list of available
+ // devices.
+ int default_device_index;
+} iree_hal_cuda_driver_options_t;
+
+IREE_API_EXPORT void IREE_API_CALL iree_hal_cuda_driver_options_initialize(
+ iree_hal_cuda_driver_options_t* out_options);
+
+// Creates a CUDA HAL driver that manage its own CUcontext.
+//
+// |out_driver| must be released by the caller (see |iree_hal_driver_release|).
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_cuda_driver_create(
+ iree_string_view_t identifier,
+ const iree_hal_cuda_driver_options_t* options,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
+
+// TODO(thomasraoux): Support importing a CUcontext from app.
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_API_H_
diff --git a/iree/hal/cuda/context_wrapper.h b/iree/hal/cuda/context_wrapper.h
new file mode 100644
index 0000000..304657d
--- /dev/null
+++ b/iree/hal/cuda/context_wrapper.h
@@ -0,0 +1,30 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
+#define IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/cuda_headers.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+// Structure to wrap all objects constant within a context. This makes it
+// simpler to pass it to the different objects and saves memory.
+typedef struct {
+ CUcontext cu_context;
+ iree_allocator_t host_allocator;
+ iree_hal_cuda_dynamic_symbols_t* syms;
+} iree_hal_cuda_context_wrapper_t;
+
+#endif // IREE_HAL_CUDA_CONTEXT_WRAPPER_H_
diff --git a/iree/hal/cuda/cuda_allocator.c b/iree/hal/cuda/cuda_allocator.c
new file mode 100644
index 0000000..9995f2b
--- /dev/null
+++ b/iree/hal/cuda/cuda_allocator.c
@@ -0,0 +1,175 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations ufnder the License.
+
+#include "iree/hal/cuda/cuda_allocator.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/cuda_buffer.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct iree_hal_cuda_allocator_s {
+ iree_hal_resource_t resource;
+ iree_hal_cuda_context_wrapper_t* context;
+} iree_hal_cuda_allocator_t;
+
+extern const iree_hal_allocator_vtable_t iree_hal_cuda_allocator_vtable;
+
+static iree_hal_cuda_allocator_t* iree_hal_cuda_allocator_cast(
+ iree_hal_allocator_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_allocator_vtable);
+ return (iree_hal_cuda_allocator_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_allocator_create(
+ iree_hal_cuda_context_wrapper_t* context,
+ iree_hal_allocator_t** out_allocator) {
+ IREE_ASSERT_ARGUMENT(context);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_cuda_allocator_t* allocator = NULL;
+ iree_status_t status = iree_allocator_malloc(
+ context->host_allocator, sizeof(*allocator), (void**)&allocator);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_cuda_allocator_vtable,
+ &allocator->resource);
+ allocator->context = context;
+ *out_allocator = (iree_hal_allocator_t*)allocator;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_cuda_allocator_destroy(
+ iree_hal_allocator_t* base_allocator) {
+ iree_hal_cuda_allocator_t* allocator =
+ iree_hal_cuda_allocator_cast(base_allocator);
+ iree_allocator_t host_allocator = allocator->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, allocator);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_allocator_t iree_hal_cuda_allocator_host_allocator(
+ const iree_hal_allocator_t* base_allocator) {
+ iree_hal_cuda_allocator_t* allocator =
+ (iree_hal_cuda_allocator_t*)base_allocator;
+ return allocator->context->host_allocator;
+}
+
+static iree_hal_buffer_compatibility_t
+iree_hal_cuda_allocator_query_buffer_compatibility(
+ iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_buffer_usage_t allowed_usage,
+ iree_hal_buffer_usage_t intended_usage,
+ iree_device_size_t allocation_size) {
+ // TODO(benvanik): check to ensure the allocator can serve the memory type.
+
+ // Disallow usage not permitted by the buffer itself. Since we then use this
+ // to determine compatibility below we'll naturally set the right compat flags
+ // based on what's both allowed and intended.
+ intended_usage &= allowed_usage;
+
+ // All buffers can be allocated on the heap.
+ iree_hal_buffer_compatibility_t compatibility =
+ IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+
+ // Buffers can only be used on the queue if they are device visible.
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+ if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+ }
+ if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
+ }
+ }
+
+ return compatibility;
+}
+
+static iree_status_t iree_hal_cuda_allocator_allocate_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
+ iree_hal_buffer_t** out_buffer) {
+ iree_hal_cuda_allocator_t* allocator =
+ iree_hal_cuda_allocator_cast(base_allocator);
+ // Guard against the corner case where the requested buffer size is 0. The
+ // application is unlikely to do anything when requesting a 0-byte buffer; but
+ // it can happen in real world use cases. So we should at least not crash.
+ if (allocation_size == 0) allocation_size = 4;
+ iree_status_t status;
+ void* host_ptr = NULL;
+ CUdeviceptr device_ptr = 0;
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ unsigned int flags = CU_MEMHOSTALLOC_DEVICEMAP;
+ if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) {
+ flags |= CU_MEMHOSTALLOC_WRITECOMBINED;
+ }
+ status =
+ CU_RESULT_TO_STATUS(allocator->context->syms,
+ cuMemHostAlloc(&host_ptr, allocation_size, flags));
+ if (iree_status_is_ok(status)) {
+ status = CU_RESULT_TO_STATUS(
+ allocator->context->syms,
+ cuMemHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0));
+ }
+ } else {
+ status = CU_RESULT_TO_STATUS(allocator->context->syms,
+ cuMemAlloc(&device_ptr, allocation_size));
+ }
+
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_cuda_buffer_wrap(
+ (iree_hal_allocator_t*)allocator, memory_type,
+ IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
+ /*byte_offset=*/0,
+ /*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
+ }
+ if (!iree_status_is_ok(status)) {
+ iree_hal_cuda_allocator_free(base_allocator, device_ptr, host_ptr,
+ memory_type);
+ }
+ return status;
+}
+
+void iree_hal_cuda_allocator_free(iree_hal_allocator_t* base_allocator,
+ CUdeviceptr device_ptr, void* host_ptr,
+ iree_hal_memory_type_t memory_type) {
+ iree_hal_cuda_allocator_t* allocator =
+ iree_hal_cuda_allocator_cast(base_allocator);
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFreeHost(host_ptr));
+ } else {
+ CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFree(device_ptr));
+ }
+}
+
+static iree_status_t iree_hal_cuda_allocator_wrap_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data,
+ iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) {
+ return iree_make_status(IREE_STATUS_UNAVAILABLE,
+ "wrapping of external buffers not supported");
+}
+
+const iree_hal_allocator_vtable_t iree_hal_cuda_allocator_vtable = {
+ .destroy = iree_hal_cuda_allocator_destroy,
+ .host_allocator = iree_hal_cuda_allocator_host_allocator,
+ .query_buffer_compatibility =
+ iree_hal_cuda_allocator_query_buffer_compatibility,
+ .allocate_buffer = iree_hal_cuda_allocator_allocate_buffer,
+ .wrap_buffer = iree_hal_cuda_allocator_wrap_buffer,
+};
diff --git a/iree/hal/cuda/cuda_allocator.h b/iree/hal/cuda/cuda_allocator.h
new file mode 100644
index 0000000..fcc015b
--- /dev/null
+++ b/iree/hal/cuda/cuda_allocator.h
@@ -0,0 +1,40 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_ALLOCATOR_H_
+#define IREE_HAL_CUDA_ALLOCATOR_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/context_wrapper.h"
+#include "iree/hal/cuda/status_util.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Create a cuda allocator.
+iree_status_t iree_hal_cuda_allocator_create(
+ iree_hal_cuda_context_wrapper_t* context,
+ iree_hal_allocator_t** out_allocator);
+
+// Free an allocation represent by the given device or host pointer.
+void iree_hal_cuda_allocator_free(iree_hal_allocator_t* allocator,
+ CUdeviceptr device_ptr, void* host_ptr,
+ iree_hal_memory_type_t memory_type);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_ALLOCATOR_H_
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c
new file mode 100644
index 0000000..55a2115
--- /dev/null
+++ b/iree/hal/cuda/cuda_buffer.c
@@ -0,0 +1,141 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/cuda_buffer.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/cuda_allocator.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct iree_hal_cuda_buffer_s {
+ iree_hal_buffer_t base;
+ void* host_ptr;
+ CUdeviceptr device_ptr;
+} iree_hal_cuda_buffer_t;
+
+extern const iree_hal_buffer_vtable_t iree_hal_cuda_buffer_vtable;
+
+static iree_hal_cuda_buffer_t* iree_hal_cuda_buffer_cast(
+ iree_hal_buffer_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_buffer_vtable);
+ return (iree_hal_cuda_buffer_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_buffer_wrap(
+ iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+ iree_device_size_t byte_offset, iree_device_size_t byte_length,
+ CUdeviceptr device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer) {
+ IREE_ASSERT_ARGUMENT(allocator);
+ IREE_ASSERT_ARGUMENT(out_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda_buffer_t* buffer = NULL;
+ iree_status_t status =
+ iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
+ sizeof(*buffer), (void**)&buffer);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_cuda_buffer_vtable,
+ &buffer->base.resource);
+ buffer->base.allocator = allocator;
+ buffer->base.allocated_buffer = &buffer->base;
+ buffer->base.allocation_size = allocation_size;
+ buffer->base.byte_offset = byte_offset;
+ buffer->base.byte_length = byte_length;
+ buffer->base.memory_type = memory_type;
+ buffer->base.allowed_access = allowed_access;
+ buffer->base.allowed_usage = allowed_usage;
+ buffer->host_ptr = host_ptr;
+ buffer->device_ptr = device_ptr;
+ *out_buffer = &buffer->base;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void iree_hal_cuda_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+ iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cuda_allocator_free(buffer->base.allocator, buffer->device_ptr,
+ buffer->host_ptr, buffer->base.memory_type);
+ iree_allocator_free(host_allocator, buffer);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_cuda_buffer_map_range(
+ iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
+ iree_hal_memory_access_t memory_access,
+ iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+ void** out_data_ptr) {
+ iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+
+ if (!iree_all_bits_set(buffer->base.memory_type,
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ return iree_make_status(IREE_STATUS_INTERNAL,
+ "trying to map memory not host visible");
+ }
+
+ uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset;
+ // If we mapped for discard scribble over the bytes. This is not a mandated
+ // behavior but it will make debugging issues easier. Alternatively for
+ // heap buffers we could reallocate them such that ASAN yells, but that
+ // would only work if the entire buffer was discarded.
+#ifndef NDEBUG
+ if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
+ memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
+ }
+#endif // !NDEBUG
+ *out_data_ptr = data_ptr;
+ return iree_ok_status();
+}
+
+static void iree_hal_cuda_buffer_unmap_range(
+ iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length, void* data_ptr) {
+ // nothing to do.
+}
+
+static iree_status_t iree_hal_cuda_buffer_invalidate_range(
+ iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length) {
+ // Nothing to do.
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_buffer_flush_range(
+ iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length) {
+ // Nothing to do.
+ return iree_ok_status();
+}
+
+CUdeviceptr iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_t* base_buffer) {
+ iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+ return buffer->device_ptr;
+}
+
+const iree_hal_buffer_vtable_t iree_hal_cuda_buffer_vtable = {
+ .destroy = iree_hal_cuda_buffer_destroy,
+ .map_range = iree_hal_cuda_buffer_map_range,
+ .unmap_range = iree_hal_cuda_buffer_unmap_range,
+ .invalidate_range = iree_hal_cuda_buffer_invalidate_range,
+ .flush_range = iree_hal_cuda_buffer_flush_range,
+};
diff --git a/iree/hal/cuda/cuda_buffer.h b/iree/hal/cuda/cuda_buffer.h
new file mode 100644
index 0000000..b9dcd9f
--- /dev/null
+++ b/iree/hal/cuda/cuda_buffer.h
@@ -0,0 +1,42 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_BUFFER_H_
+#define IREE_HAL_CUDA_BUFFER_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/cuda_headers.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Wraps a cuda allocation in an iree_hal_buffer_t.
+iree_status_t iree_hal_cuda_buffer_wrap(
+ iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+ iree_device_size_t byte_offset, iree_device_size_t byte_length,
+ CUdeviceptr device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer);
+
+// Returns the cuda base pointer for the given |buffer|.
+// This is the entire allocated_buffer and must be offset by the buffer
+// byte_offset and byte_length when used.
+CUdeviceptr iree_hal_cuda_buffer_device_pointer(iree_hal_buffer_t* buffer);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_BUFFER_H_
diff --git a/iree/hal/cuda/cuda_device.c b/iree/hal/cuda/cuda_device.c
new file mode 100644
index 0000000..0430fe8
--- /dev/null
+++ b/iree/hal/cuda/cuda_device.c
@@ -0,0 +1,260 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/cuda_device.h"
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/cuda_allocator.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+#include "iree/hal/cuda/status_util.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_cuda_device_t
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_string_view_t identifier;
+
+ // Optional driver that owns the CUDA symbols. We retain it for our lifetime
+ // to ensure the symbols remains valid.
+ iree_hal_driver_t* driver;
+
+ CUdevice device;
+
+ // TODO: support multiple streams.
+ CUstream stream;
+ iree_hal_cuda_context_wrapper_t context_wrapper;
+ iree_hal_allocator_t* device_allocator;
+
+} iree_hal_cuda_device_t;
+
+extern const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
+
+static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
+ iree_hal_device_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_device_vtable);
+ return (iree_hal_cuda_device_t*)base_value;
+}
+
+static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // There should be no more buffers live that use the allocator.
+ iree_hal_allocator_release(device->device_allocator);
+ CUDA_IGNORE_ERROR(device->context_wrapper.syms,
+ cuStreamDestroy(device->stream));
+
+ // Finally, destroy the device.
+ iree_hal_driver_release(device->driver);
+
+ iree_allocator_free(host_allocator, device);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_cuda_device_create_internal(
+ iree_hal_driver_t* driver, iree_string_view_t identifier,
+ CUdevice cu_device, CUstream stream, CUcontext context,
+ iree_hal_cuda_dynamic_symbols_t* syms, iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
+ iree_hal_cuda_device_t* device = NULL;
+ iree_host_size_t total_size = sizeof(*device) + identifier.size;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(host_allocator, total_size, (void**)&device));
+ memset(device, 0, total_size);
+ iree_hal_resource_initialize(&iree_hal_cuda_device_vtable, &device->resource);
+ device->driver = driver;
+ iree_hal_driver_retain(device->driver);
+ uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device);
+ buffer_ptr += iree_string_view_append_to_buffer(
+ identifier, &device->identifier, (char*)buffer_ptr);
+ device->device = cu_device;
+ device->stream = stream;
+ device->context_wrapper.cu_context = context;
+ device->context_wrapper.host_allocator = host_allocator;
+ device->context_wrapper.syms = syms;
+ iree_status_t status = iree_hal_cuda_allocator_create(
+ &device->context_wrapper, &device->device_allocator);
+ if (iree_status_is_ok(status)) {
+ *out_device = (iree_hal_device_t*)device;
+ } else {
+ iree_hal_device_release((iree_hal_device_t*)device);
+ }
+ return status;
+}
+
+iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
+ iree_string_view_t identifier,
+ iree_hal_cuda_dynamic_symbols_t* syms,
+ CUdevice device,
+ iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ CUcontext context;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, CU_RESULT_TO_STATUS(syms, cuCtxCreate(&context, 0, device)));
+ CUstream stream;
+ iree_status_t status = CU_RESULT_TO_STATUS(
+ syms, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_cuda_device_create_internal(driver, identifier, device,
+ stream, context, syms,
+ host_allocator, out_device);
+ }
+ if (!iree_status_is_ok(status)) {
+ if (stream) {
+ syms->cuStreamDestroy(stream);
+ }
+ syms->cuCtxDestroy(context);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static iree_string_view_t iree_hal_cuda_device_id(
+ iree_hal_device_t* base_device) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ return device->identifier;
+}
+
+static iree_allocator_t iree_hal_cuda_device_host_allocator(
+ iree_hal_device_t* base_device) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ return device->context_wrapper.host_allocator;
+}
+
+static iree_hal_allocator_t* iree_hal_cuda_device_allocator(
+ iree_hal_device_t* base_device) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ return device->device_allocator;
+}
+
+static iree_status_t iree_hal_cuda_device_create_command_buffer(
+ iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_hal_command_buffer_t** out_command_buffer) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_descriptor_set(
+ iree_hal_device_t* base_device,
+ iree_hal_descriptor_set_layout_t* set_layout,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_binding_t* bindings,
+ iree_hal_descriptor_set_t** out_descriptor_set) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "non-push descriptor sets still need work");
+}
+
+static iree_status_t iree_hal_cuda_device_create_descriptor_set_layout(
+ iree_hal_device_t* base_device,
+ iree_hal_descriptor_set_layout_usage_type_t usage_type,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t* bindings,
+ iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_event(
+ iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_executable_cache(
+ iree_hal_device_t* base_device, iree_string_view_t identifier,
+ iree_hal_executable_cache_t** out_executable_cache) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_executable_layout(
+ iree_hal_device_t* base_device, iree_host_size_t push_constants,
+ iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t** set_layouts,
+ iree_hal_executable_layout_t** out_executable_layout) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_create_semaphore(
+ iree_hal_device_t* base_device, uint64_t initial_value,
+ iree_hal_semaphore_t** out_semaphore) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_queue_submit(
+ iree_hal_device_t* base_device,
+ iree_hal_command_category_t command_categories, uint64_t queue_affinity,
+ iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on CUDA");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_semaphores_with_timeout(
+ iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
+ const iree_hal_semaphore_list_t* semaphore_list,
+ iree_duration_t timeout_ns) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "semaphore not implemented");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_semaphores_with_deadline(
+ iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
+ const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "semaphore not implemented");
+}
+
+static iree_status_t iree_hal_cuda_device_wait_idle_with_deadline(
+ iree_hal_device_t* base_device, iree_time_t deadline_ns) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ // Wait until the stream is done.
+ // TODO(thomasraoux): CUDA doesn't support a deadline for wait, figure out how
+ // to handle it better.
+ CUDA_RETURN_IF_ERROR(device->context_wrapper.syms,
+ cuStreamSynchronize(device->stream),
+ "cuStreamSynchronize");
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_device_wait_idle_with_timeout(
+ iree_hal_device_t* base_device, iree_duration_t timeout_ns) {
+ return iree_hal_cuda_device_wait_idle_with_deadline(
+ base_device, iree_relative_timeout_to_deadline_ns(timeout_ns));
+}
+
+const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = {
+ .destroy = iree_hal_cuda_device_destroy,
+ .id = iree_hal_cuda_device_id,
+ .host_allocator = iree_hal_cuda_device_host_allocator,
+ .device_allocator = iree_hal_cuda_device_allocator,
+ .create_command_buffer = iree_hal_cuda_device_create_command_buffer,
+ .create_descriptor_set = iree_hal_cuda_device_create_descriptor_set,
+ .create_descriptor_set_layout =
+ iree_hal_cuda_device_create_descriptor_set_layout,
+ .create_event = iree_hal_cuda_device_create_event,
+ .create_executable_cache = iree_hal_cuda_device_create_executable_cache,
+ .create_executable_layout = iree_hal_cuda_device_create_executable_layout,
+ .create_semaphore = iree_hal_cuda_device_create_semaphore,
+ .queue_submit = iree_hal_cuda_device_queue_submit,
+ .wait_semaphores_with_deadline =
+ iree_hal_cuda_device_wait_semaphores_with_deadline,
+ .wait_semaphores_with_timeout =
+ iree_hal_cuda_device_wait_semaphores_with_timeout,
+ .wait_idle_with_deadline = iree_hal_cuda_device_wait_idle_with_deadline,
+ .wait_idle_with_timeout = iree_hal_cuda_device_wait_idle_with_timeout,
+};
diff --git a/iree/hal/cuda/cuda_device.h b/iree/hal/cuda/cuda_device.h
new file mode 100644
index 0000000..5d7d4ab
--- /dev/null
+++ b/iree/hal/cuda/cuda_device.h
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CUDA_DEVICE_H_
+#define IREE_HAL_CUDA_CUDA_DEVICE_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a device that owns and manages its own CUcontext.
+iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
+ iree_string_view_t identifier,
+ iree_hal_cuda_dynamic_symbols_t* syms,
+ CUdevice device,
+ iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_CUDA_DEVICE_H_
diff --git a/iree/hal/cuda/cuda_driver.c b/iree/hal/cuda/cuda_driver.c
new file mode 100644
index 0000000..79f70b2
--- /dev/null
+++ b/iree/hal/cuda/cuda_driver.c
@@ -0,0 +1,210 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+#include "iree/hal/cuda/cuda_device.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+#include "iree/hal/cuda/status_util.h"
+
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_allocator_t host_allocator;
+ // Identifier used for the driver in the IREE driver registry.
+ // We allow overriding so that multiple CUDA versions can be exposed in the
+ // same process.
+ iree_string_view_t identifier;
+ int default_device_index;
+ // CUDA symbols.
+ iree_hal_cuda_dynamic_symbols_t syms;
+} iree_hal_cuda_driver_t;
+
+// Pick a fixed lenght size for device names.
+static const size_t IREE_MAX_CUDA_DEVICE_NAME_LENGTH = 100;
+
+extern const iree_hal_driver_vtable_t iree_hal_cuda_driver_vtable;
+
+static iree_hal_cuda_driver_t* iree_hal_cuda_driver_cast(
+ iree_hal_driver_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_driver_vtable);
+ return (iree_hal_cuda_driver_t*)base_value;
+}
+
+IREE_API_EXPORT void IREE_API_CALL iree_hal_cuda_driver_options_initialize(
+ iree_hal_cuda_driver_options_t* out_options) {
+ memset(out_options, 0, sizeof(*out_options));
+ out_options->default_device_index = 0;
+}
+
+static iree_status_t iree_hal_cuda_driver_create_internal(
+ iree_string_view_t identifier,
+ const iree_hal_cuda_driver_options_t* options,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+ iree_hal_cuda_driver_t* driver = NULL;
+ iree_host_size_t total_size = sizeof(*driver) + identifier.size;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+ iree_hal_resource_initialize(&iree_hal_cuda_driver_vtable, &driver->resource);
+ driver->host_allocator = host_allocator;
+ iree_string_view_append_to_buffer(
+ identifier, &driver->identifier,
+ (char*)driver + total_size - identifier.size);
+ driver->default_device_index = options->default_device_index;
+ iree_status_t status = load_symbols(&driver->syms);
+ if (iree_status_is_ok(status)) {
+ *out_driver = (iree_hal_driver_t*)driver;
+ } else {
+ iree_hal_driver_release((iree_hal_driver_t*)driver);
+ }
+ return status;
+}
+
+static void iree_hal_cuda_driver_destroy(iree_hal_driver_t* base_driver) {
+ iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+ iree_allocator_t host_allocator = driver->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ unload_symbols(&driver->syms);
+ iree_allocator_free(host_allocator, driver);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_cuda_driver_create(
+ iree_string_view_t identifier,
+ const iree_hal_cuda_driver_options_t* options,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+ IREE_ASSERT_ARGUMENT(options);
+ IREE_ASSERT_ARGUMENT(out_driver);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_status_t status = iree_hal_cuda_driver_create_internal(
+ identifier, options, host_allocator, out_driver);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Populates device information from the given CUDA physical device handle.
+// |out_device_info| must point to valid memory and additional data will be
+// appended to |buffer_ptr| and the new pointer is returned.
+static uint8_t* iree_hal_cuda_populate_device_info(
+ CUdevice device, iree_hal_cuda_dynamic_symbols_t* syms, uint8_t* buffer_ptr,
+ iree_hal_device_info_t* out_device_info) {
+ char device_name[IREE_MAX_CUDA_DEVICE_NAME_LENGTH];
+ CUDA_IGNORE_ERROR(syms,
+ cuDeviceGetName(device_name, sizeof(device_name), device));
+ memset(out_device_info, 0, sizeof(*out_device_info));
+ out_device_info->device_id = (iree_hal_device_id_t)device;
+
+ iree_string_view_t device_name_string =
+ iree_make_string_view(device_name, strlen(device_name));
+ buffer_ptr += iree_string_view_append_to_buffer(
+ device_name_string, &out_device_info->name, (char*)buffer_ptr);
+ return buffer_ptr;
+}
+
+static iree_status_t iree_hal_cuda_driver_query_available_devices(
+ iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
+ iree_hal_device_info_t** out_device_infos,
+ iree_host_size_t* out_device_info_count) {
+ iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+ // Query the number of available CUDA devices.
+ int device_count = 0;
+ CUDA_RETURN_IF_ERROR(&driver->syms, cuDeviceGetCount(&device_count),
+ "cuDeviceGetCount");
+
+ // Allocate the return infos and populate with the devices.
+ iree_hal_device_info_t* device_infos = NULL;
+ iree_host_size_t total_size = device_count * sizeof(iree_hal_device_info_t);
+ for (iree_host_size_t i = 0; i < device_count; ++i) {
+ total_size += IREE_MAX_CUDA_DEVICE_NAME_LENGTH * sizeof(char);
+ }
+ iree_status_t status =
+ iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
+ if (iree_status_is_ok(status)) {
+ uint8_t* buffer_ptr =
+ (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
+ for (iree_host_size_t i = 0; i < device_count; ++i) {
+ CUdevice device;
+ iree_status_t status = CU_RESULT_TO_STATUS(
+ &driver->syms, cuDeviceGet(&device, i), "cuDeviceGet");
+ if (!iree_status_is_ok(status)) break;
+ buffer_ptr = iree_hal_cuda_populate_device_info(
+ device, &driver->syms, buffer_ptr, &device_infos[i]);
+ }
+ }
+ if (iree_status_is_ok(status)) {
+ *out_device_info_count = device_count;
+ *out_device_infos = device_infos;
+ } else {
+ iree_allocator_free(host_allocator, device_infos);
+ }
+ return status;
+}
+
+static iree_status_t iree_hal_cuda_driver_select_default_device(
+ iree_hal_cuda_dynamic_symbols_t* syms, int default_device_index,
+ iree_allocator_t host_allocator, CUdevice* out_device) {
+ int device_count = 0;
+ CUDA_RETURN_IF_ERROR(syms, cuDeviceGetCount(&device_count),
+ "cuDeviceGetCount");
+ iree_status_t status = iree_ok_status();
+ if (device_count == 0 || default_device_index >= device_count) {
+ status = iree_make_status(IREE_STATUS_NOT_FOUND,
+ "default device %d not found (of %d enumerated)",
+ default_device_index, device_count);
+ } else {
+ CUdevice device;
+ CUDA_RETURN_IF_ERROR(syms, cuDeviceGet(&device, default_device_index),
+ "cuDeviceGet");
+ *out_device = device;
+ }
+ return status;
+}
+
+static iree_status_t iree_hal_cuda_driver_create_device(
+ iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+ iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+ iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, CU_RESULT_TO_STATUS(&driver->syms, cuInit(0), "cuInit"));
+ // Use either the specified device (enumerated earlier) or whatever default
+ // one was specified when the driver was created.
+ CUdevice device = (CUdevice)device_id;
+ if (device == 0) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda_driver_select_default_device(
+ &driver->syms, driver->default_device_index, host_allocator,
+ &device));
+ }
+
+ iree_string_view_t device_name = iree_make_cstring_view("cuda");
+
+ // Attempt to create the device.
+ iree_status_t status =
+ iree_hal_cuda_device_create(base_driver, device_name, &driver->syms,
+ device, host_allocator, out_device);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+const iree_hal_driver_vtable_t iree_hal_cuda_driver_vtable = {
+ .destroy = iree_hal_cuda_driver_destroy,
+ .query_available_devices = iree_hal_cuda_driver_query_available_devices,
+ .create_device = iree_hal_cuda_driver_create_device,
+};
diff --git a/iree/hal/cuda/dynamic_symbols.cc b/iree/hal/cuda/dynamic_symbols.cc
index 0927116..df13a4d 100644
--- a/iree/hal/cuda/dynamic_symbols.cc
+++ b/iree/hal/cuda/dynamic_symbols.cc
@@ -17,15 +17,12 @@
#include <cstddef>
#include "absl/types/span.h"
+#include "iree/base/dynamic_library.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
-namespace iree {
-namespace hal {
-namespace cuda {
-
-static const char* kCudaLoaderSearchNames[] = {
+static const char* kCUDALoaderSearchNames[] = {
#if defined(IREE_PLATFORM_WINDOWS)
"nvcuda.dll",
#else
@@ -33,28 +30,31 @@
#endif
};
-Status DynamicSymbols::LoadSymbols() {
- IREE_TRACE_SCOPE();
+extern "C" {
- IREE_RETURN_IF_ERROR(DynamicLibrary::Load(
- absl::MakeSpan(kCudaLoaderSearchNames), &loader_library_));
+iree_status_t load_symbols(iree_hal_cuda_dynamic_symbols_t* syms) {
+ std::unique_ptr<iree::DynamicLibrary> loader_library;
+ IREE_RETURN_IF_ERROR(iree::DynamicLibrary::Load(
+ absl::MakeSpan(kCUDALoaderSearchNames), &loader_library));
-#define CU_PFN_DECL(cudaSymbolName) \
+#define CU_PFN_DECL(cudaSymbolName, ...) \
{ \
- using FuncPtrT = std::add_pointer<decltype(::cudaSymbolName)>::type; \
+ using FuncPtrT = decltype(syms->cudaSymbolName); \
static const char* kName = #cudaSymbolName; \
- cudaSymbolName = loader_library_->GetSymbol<FuncPtrT>(kName); \
- if (!cudaSymbolName) { \
+ syms->cudaSymbolName = loader_library->GetSymbol<FuncPtrT>(kName); \
+ if (!syms->cudaSymbolName) { \
return iree_make_status(IREE_STATUS_UNAVAILABLE, "symbol not found"); \
} \
}
#include "dynamic_symbols_tables.h"
#undef CU_PFN_DECL
-
- return OkStatus();
+ syms->opaque_loader_library_ = (void*)loader_library.release();
+ return iree_ok_status();
}
-} // namespace cuda
-} // namespace hal
-} // namespace iree
+void unload_symbols(iree_hal_cuda_dynamic_symbols_t* syms) {
+ delete (iree::DynamicLibrary*)syms->opaque_loader_library_;
+}
+
+} // extern "C"
\ No newline at end of file
diff --git a/iree/hal/cuda/dynamic_symbols.h b/iree/hal/cuda/dynamic_symbols.h
index 9d2c40e..436b32d 100644
--- a/iree/hal/cuda/dynamic_symbols.h
+++ b/iree/hal/cuda/dynamic_symbols.h
@@ -15,38 +15,29 @@
#ifndef IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
#define IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
-#include <cstdint>
-#include <functional>
-#include <memory>
-
-#include "iree/base/dynamic_library.h"
-#include "iree/base/status.h"
+#include "iree/base/api.h"
#include "iree/hal/cuda/cuda_headers.h"
-namespace iree {
-namespace hal {
-namespace cuda {
-
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
/// DyanmicSymbols allow loading dynamically a subset of CUDA driver API. It
/// loads all the function declared in `dynamic_symbol_tables.def` and fail if
/// any of the symbol is not available. The functions signatures are matching
/// the declarations in `cuda.h`.
-struct DynamicSymbols {
- Status LoadSymbols();
-
-#define CU_PFN_DECL(cudaSymbolName) \
- std::add_pointer<decltype(::cudaSymbolName)>::type cudaSymbolName;
-
+typedef struct {
+#define CU_PFN_DECL(cudaSymbolName, ...) \
+ CUresult (*cudaSymbolName)(__VA_ARGS__);
#include "dynamic_symbols_tables.h"
#undef CU_PFN_DECL
+ void* opaque_loader_library_;
+} iree_hal_cuda_dynamic_symbols_t;
- private:
- // Cuda Loader dynamic library.
- std::unique_ptr<DynamicLibrary> loader_library_;
-};
+iree_status_t load_symbols(iree_hal_cuda_dynamic_symbols_t* syms);
+void unload_symbols(iree_hal_cuda_dynamic_symbols_t* syms);
-} // namespace cuda
-} // namespace hal
-} // namespace iree
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
#endif // IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/cuda/dynamic_symbols_tables.h b/iree/hal/cuda/dynamic_symbols_tables.h
index 5adece6..1a9fbaa 100644
--- a/iree/hal/cuda/dynamic_symbols_tables.h
+++ b/iree/hal/cuda/dynamic_symbols_tables.h
@@ -12,79 +12,37 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-CU_PFN_DECL(cuCtxCreate)
-CU_PFN_DECL(cuCtxDestroy)
-CU_PFN_DECL(cuCtxEnablePeerAccess)
-CU_PFN_DECL(cuCtxGetCurrent)
-CU_PFN_DECL(cuCtxGetDevice)
-CU_PFN_DECL(cuCtxGetSharedMemConfig)
-CU_PFN_DECL(cuCtxSetCurrent)
-CU_PFN_DECL(cuCtxSetSharedMemConfig)
-CU_PFN_DECL(cuCtxSynchronize)
-CU_PFN_DECL(cuDeviceCanAccessPeer)
-CU_PFN_DECL(cuDeviceGet)
-CU_PFN_DECL(cuDeviceGetAttribute)
-CU_PFN_DECL(cuDeviceGetCount)
-CU_PFN_DECL(cuDeviceGetName)
-CU_PFN_DECL(cuDeviceGetPCIBusId)
-CU_PFN_DECL(cuDevicePrimaryCtxGetState)
-CU_PFN_DECL(cuDevicePrimaryCtxRelease)
-CU_PFN_DECL(cuDevicePrimaryCtxRetain)
-CU_PFN_DECL(cuDevicePrimaryCtxSetFlags)
-CU_PFN_DECL(cuDeviceTotalMem)
-CU_PFN_DECL(cuDriverGetVersion)
-CU_PFN_DECL(cuEventCreate)
-CU_PFN_DECL(cuEventDestroy)
-CU_PFN_DECL(cuEventElapsedTime)
-CU_PFN_DECL(cuEventQuery)
-CU_PFN_DECL(cuEventRecord)
-CU_PFN_DECL(cuEventSynchronize)
-CU_PFN_DECL(cuFuncGetAttribute)
-CU_PFN_DECL(cuFuncSetCacheConfig)
-CU_PFN_DECL(cuGetErrorName)
-CU_PFN_DECL(cuGetErrorString)
-CU_PFN_DECL(cuGraphAddMemcpyNode)
-CU_PFN_DECL(cuGraphAddMemsetNode)
-CU_PFN_DECL(cuGraphAddKernelNode)
-CU_PFN_DECL(cuGraphCreate)
-CU_PFN_DECL(cuGraphDestroy)
-CU_PFN_DECL(cuGraphExecDestroy)
-CU_PFN_DECL(cuGraphGetNodes)
-CU_PFN_DECL(cuGraphInstantiate)
-CU_PFN_DECL(cuGraphLaunch)
-CU_PFN_DECL(cuInit)
-CU_PFN_DECL(cuLaunchKernel)
-CU_PFN_DECL(cuMemAlloc)
-CU_PFN_DECL(cuMemAllocManaged)
-CU_PFN_DECL(cuMemFree)
-CU_PFN_DECL(cuMemFreeHost)
-CU_PFN_DECL(cuMemGetAddressRange)
-CU_PFN_DECL(cuMemGetInfo)
-CU_PFN_DECL(cuMemHostAlloc)
-CU_PFN_DECL(cuMemHostGetDevicePointer)
-CU_PFN_DECL(cuMemHostRegister)
-CU_PFN_DECL(cuMemHostUnregister)
-CU_PFN_DECL(cuMemcpyDtoD)
-CU_PFN_DECL(cuMemcpyDtoDAsync)
-CU_PFN_DECL(cuMemcpyDtoH)
-CU_PFN_DECL(cuMemcpyDtoHAsync)
-CU_PFN_DECL(cuMemcpyHtoD)
-CU_PFN_DECL(cuMemcpyHtoDAsync)
-CU_PFN_DECL(cuMemsetD32)
-CU_PFN_DECL(cuMemsetD32Async)
-CU_PFN_DECL(cuMemsetD8)
-CU_PFN_DECL(cuMemsetD8Async)
-CU_PFN_DECL(cuModuleGetFunction)
-CU_PFN_DECL(cuModuleGetGlobal)
-CU_PFN_DECL(cuModuleLoadDataEx)
-CU_PFN_DECL(cuModuleLoadFatBinary)
-CU_PFN_DECL(cuModuleUnload)
-CU_PFN_DECL(cuOccupancyMaxActiveBlocksPerMultiprocessor)
-CU_PFN_DECL(cuOccupancyMaxPotentialBlockSize)
-CU_PFN_DECL(cuPointerGetAttribute)
-CU_PFN_DECL(cuStreamAddCallback)
-CU_PFN_DECL(cuStreamCreate)
-CU_PFN_DECL(cuStreamDestroy)
-CU_PFN_DECL(cuStreamQuery)
-CU_PFN_DECL(cuStreamSynchronize)
-CU_PFN_DECL(cuStreamWaitEvent)
+CU_PFN_DECL(cuCtxCreate, CUcontext*, unsigned int, CUdevice)
+CU_PFN_DECL(cuCtxDestroy, CUcontext)
+CU_PFN_DECL(cuDeviceGet, CUdevice*, int)
+CU_PFN_DECL(cuDeviceGetCount, int*)
+CU_PFN_DECL(cuDeviceGetName, char*, int, CUdevice)
+CU_PFN_DECL(cuGetErrorName, CUresult, const char**)
+CU_PFN_DECL(cuGetErrorString, CUresult, const char**)
+CU_PFN_DECL(cuGraphAddMemcpyNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+ size_t, const CUDA_MEMCPY3D*, CUcontext)
+CU_PFN_DECL(cuGraphAddMemsetNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+ size_t, const CUDA_MEMSET_NODE_PARAMS*, CUcontext)
+CU_PFN_DECL(cuGraphAddKernelNode, CUgraphNode*, CUgraph, const CUgraphNode*,
+ size_t, const CUDA_KERNEL_NODE_PARAMS*)
+CU_PFN_DECL(cuGraphCreate, CUgraph*, unsigned int)
+CU_PFN_DECL(cuGraphDestroy, CUgraph)
+CU_PFN_DECL(cuGraphExecDestroy, CUgraphExec)
+CU_PFN_DECL(cuGraphGetNodes, CUgraph, CUgraphNode*, size_t*)
+CU_PFN_DECL(cuGraphInstantiate, CUgraphExec*, CUgraph, CUgraphNode*, char*,
+ size_t)
+CU_PFN_DECL(cuGraphLaunch, CUgraphExec, CUstream)
+CU_PFN_DECL(cuInit, unsigned int)
+CU_PFN_DECL(cuMemAlloc, CUdeviceptr*, size_t)
+CU_PFN_DECL(cuMemFree, CUdeviceptr)
+CU_PFN_DECL(cuMemFreeHost, void*)
+CU_PFN_DECL(cuMemHostAlloc, void**, size_t, unsigned int)
+CU_PFN_DECL(cuMemHostGetDevicePointer, CUdeviceptr*, void*, unsigned int)
+CU_PFN_DECL(cuModuleGetFunction, CUfunction*, CUmodule, const char*)
+CU_PFN_DECL(cuModuleLoadDataEx, CUmodule*, const void*, unsigned int,
+ CUjit_option*, void**)
+CU_PFN_DECL(cuModuleUnload, CUmodule)
+CU_PFN_DECL(cuStreamCreate, CUstream*, unsigned int)
+CU_PFN_DECL(cuStreamDestroy, CUstream)
+CU_PFN_DECL(cuStreamSynchronize, CUstream)
+CU_PFN_DECL(cuStreamWaitEvent, CUstream, CUevent, unsigned int)
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
index 6a7967c..5b8681f 100644
--- a/iree/hal/cuda/dynamic_symbols_test.cc
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -29,9 +29,9 @@
}
TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
- DynamicSymbols symbols;
- Status status = symbols.LoadSymbols();
- if (!status.ok()) {
+ iree_hal_cuda_dynamic_symbols_t symbols;
+ iree_status_t status = load_symbols(&symbols);
+ if (!iree_status_is_ok(status)) {
IREE_LOG(WARNING) << "Symbols cannot be loaded, skipping test.";
GTEST_SKIP();
}
@@ -43,6 +43,7 @@
CUdevice device;
CUDE_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
}
+ unload_symbols(&symbols);
}
} // namespace
diff --git a/iree/hal/cuda/registration/BUILD b/iree/hal/cuda/registration/BUILD
new file mode 100644
index 0000000..58157f4
--- /dev/null
+++ b/iree/hal/cuda/registration/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_CUDA})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.c"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_CUDA_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:core_headers",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal:api",
+ "//iree/hal/cuda",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/cuda/registration/CMakeLists.txt b/iree/hal/cuda/registration/CMakeLists.txt
new file mode 100644
index 0000000..93d8dac
--- /dev/null
+++ b/iree/hal/cuda/registration/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Autogenerated from iree/hal/cuda/registration/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_CUDA})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.c"
+ DEPS
+ iree::base::core_headers
+ iree::base::status
+ iree::base::tracing
+ iree::hal::api
+ iree::hal::cuda
+ DEFINES
+ "IREE_HAL_HAVE_CUDA_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/cuda/registration/driver_module.c b/iree/hal/cuda/registration/driver_module.c
new file mode 100644
index 0000000..5677cd1
--- /dev/null
+++ b/iree/hal/cuda/registration/driver_module.c
@@ -0,0 +1,72 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/registration/driver_module.h"
+
+#include <inttypes.h>
+
+#include "iree/base/status.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/api.h"
+
+#define IREE_HAL_CUDA_DRIVER_ID 0x43554441u // CUDA
+
+static iree_status_t iree_hal_cuda_driver_factory_enumerate(
+ void* self, const iree_hal_driver_info_t** out_driver_infos,
+ iree_host_size_t* out_driver_info_count) {
+ // NOTE: we could query supported cuda versions or featuresets here.
+ static const iree_hal_driver_info_t driver_infos[1] = {{
+ .driver_id = IREE_HAL_CUDA_DRIVER_ID,
+ .driver_name = iree_string_view_literal("cuda"),
+ .full_name = iree_string_view_literal("CUDA (dynamic)"),
+ }};
+ *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
+ *out_driver_infos = driver_infos;
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_driver_factory_try_create(
+ void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+ iree_hal_driver_t** out_driver) {
+ IREE_ASSERT_ARGUMENT(out_driver);
+ *out_driver = NULL;
+ if (driver_id != IREE_HAL_CUDA_DRIVER_ID) {
+ return iree_make_status(IREE_STATUS_UNAVAILABLE,
+ "no driver with ID %016" PRIu64
+ " is provided by this factory",
+ driver_id);
+ }
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // When we expose more than one driver (different cuda versions, etc) we
+ // can name them here:
+ iree_string_view_t identifier = iree_make_cstring_view("cuda");
+
+ iree_hal_cuda_driver_options_t driver_options;
+ iree_hal_cuda_driver_options_initialize(&driver_options);
+ iree_status_t status = iree_hal_cuda_driver_create(
+ identifier, &driver_options, allocator, out_driver);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry) {
+ static const iree_hal_driver_factory_t factory = {
+ .self = NULL,
+ .enumerate = iree_hal_cuda_driver_factory_enumerate,
+ .try_create = iree_hal_cuda_driver_factory_try_create,
+ };
+ return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/iree/hal/cuda/registration/driver_module.h b/iree/hal/cuda/registration/driver_module.h
new file mode 100644
index 0000000..897594c
--- /dev/null
+++ b/iree/hal/cuda/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/cuda/status_util.c b/iree/hal/cuda/status_util.c
new file mode 100644
index 0000000..9a084a8
--- /dev/null
+++ b/iree/hal/cuda/status_util.c
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/status_util.h"
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+iree_status_t iree_hal_cuda_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
+ uint32_t line) {
+ if (IREE_LIKELY(result == CUDA_SUCCESS)) {
+ return iree_ok_status();
+ }
+
+ const char* error_name = NULL;
+ if (syms->cuGetErrorName(result, &error_name) != CUDA_SUCCESS) {
+ error_name = "UNKNOWN";
+ }
+
+ const char* error_string = NULL;
+ if (syms->cuGetErrorString(result, &error_string) != CUDA_SUCCESS) {
+ error_string = "Unknown error.";
+ }
+ return iree_make_status(IREE_STATUS_INTERNAL,
+ "CUDA driver error '%s' (%d): %s", error_name, result,
+ error_string);
+}
diff --git a/iree/hal/cuda/status_util.h b/iree/hal/cuda/status_util.h
new file mode 100644
index 0000000..526b56f
--- /dev/null
+++ b/iree/hal/cuda/status_util.h
@@ -0,0 +1,60 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_STATUS_UTIL_H_
+#define IREE_HAL_CUDA_STATUS_UTIL_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Converts a CUresult to an iree_status_t.
+//
+// Usage:
+// iree_status_t status = CU_RESULT_TO_STATUS(cuDoThing(...));
+#define CU_RESULT_TO_STATUS(syms, expr, ...) \
+ iree_hal_cuda_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR but implicitly converts the CUresult return value to
+// a Status.
+//
+// Usage:
+// CUDA_RETURN_IF_ERROR(cuDoThing(...), "message");
+#define CUDA_RETURN_IF_ERROR(syms, expr, ...) \
+ IREE_RETURN_IF_ERROR(iree_hal_cuda_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__), \
+ __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the CUresult return value to a
+// ::util::Status and checks that it is OkStatus.
+//
+// Usage:
+// CUDA_IGNORE_ERROR(cuDoThing(...));
+#define CUDA_IGNORE_ERROR(syms, expr) \
+ IREE_IGNORE_ERROR(iree_hal_cuda_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__))
+
+// Converts a CUresult to a Status object.
+iree_status_t iree_hal_cuda_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
+ uint32_t line);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CUDA_STATUS_UTIL_H_
diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt
index 305bcb5..1994aef 100644
--- a/iree/hal/drivers/CMakeLists.txt
+++ b/iree/hal/drivers/CMakeLists.txt
@@ -15,6 +15,9 @@
# bazel_to_cmake: DO NOT EDIT (custom configuration vars)
set(IREE_HAL_DRIVER_MODULES)
+if(${IREE_HAL_DRIVER_CUDA})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::cuda::registration)
+endif()
if(${IREE_HAL_DRIVER_DYLIB})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration)
endif()
diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c
index bb36616..bbcd9b9 100644
--- a/iree/hal/drivers/init.c
+++ b/iree/hal/drivers/init.c
@@ -16,6 +16,10 @@
#include "iree/base/tracing.h"
+#if defined(IREE_HAL_HAVE_CUDA_DRIVER_MODULE)
+#include "iree/hal/cuda/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_CUDA_DRIVER_MODULE
+
#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
#include "iree/hal/dylib/registration/driver_module.h"
#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
@@ -32,6 +36,11 @@
iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {
IREE_TRACE_ZONE_BEGIN(z0);
+#if defined(IREE_HAL_HAVE_CUDA_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda_driver_module_register(registry));
+#endif // IREE_HAL_HAVE_CUDA_DRIVER_MODULE
+
#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_dylib_driver_module_register(registry));