Initial Adding ROCM HAL Backend to Experimental (#5943)
Initial pass to integrate ROCm into IREE so that we can codegen for and run on AMD GPUs, following steps similar to thomasraoux's CUDA backend. Since ROCm does not have graphs or command buffers by default, we implement ROCm's command buffer using the stream API on the default stream. Tested and passes most CTS tests except:
semaphore_submission_test + semaphore_test-> some functionalities not implemented for rocm backend yet
command_buffer_test -> CommandBufferTest.CopySubBuffer
In the next patch:
-Complete semaphore functionality
-Squash CommandBuffer bugs
diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt
new file mode 100644
index 0000000..0e3f824
--- /dev/null
+++ b/experimental/rocm/CMakeLists.txt
@@ -0,0 +1,93 @@
+# Copyright 2021 Google LLC
+
+if(NOT ${IREE_BUILD_EXPERIMENTAL_ROCM})
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ rocm
+ HDRS
+ "api.h"
+ SRCS
+ "api.h"
+ "context_wrapper.h"
+ "rocm_allocator.c"
+ "rocm_allocator.h"
+ "rocm_buffer.c"
+ "rocm_buffer.h"
+ "rocm_device.c"
+ "rocm_device.h"
+ "rocm_driver.c"
+ "rocm_event.c"
+ "rocm_event.h"
+ "descriptor_set_layout.c"
+ "descriptor_set_layout.h"
+ "event_semaphore.c"
+ "event_semaphore.h"
+ "executable_layout.c"
+ "executable_layout.h"
+ "direct_command_buffer.c"
+ "direct_command_buffer.h"
+ "native_executable.c"
+ "native_executable.h"
+ "nop_executable_cache.c"
+ "nop_executable_cache.h"
+ "status_util.c"
+ "status_util.h"
+ INCLUDES
+ "${CMAKE_CURRENT_LIST_DIR}/../.."
+ "${PROJECT_BINARY_DIR}"
+ DEPS
+ ::dynamic_symbols
+ iree::base
+ iree::base::core_headers
+ iree::base::internal
+ iree::base::internal::flatcc
+ iree::base::internal::synchronization
+ iree::base::logging
+ iree::base::status
+ iree::base::tracing
+ iree::hal
+ iree::schemas::rocm_executable_def_c_fbs
+ PUBLIC
+)
+
+add_definitions(-D__HIP_PLATFORM_HCC__)
+
+iree_cc_library(
+ NAME
+ dynamic_symbols
+ HDRS
+ "dynamic_symbols.h"
+ TEXTUAL_HDRS
+ "dynamic_symbol_tables.h"
+ SRCS
+ "rocm_headers.h"
+ "dynamic_symbols.c"
+ INCLUDES
+ "${CMAKE_CURRENT_LIST_DIR}/../.."
+ DEPS
+ rocm_headers
+ iree::base::core_headers
+ iree::base::internal::dynamic_library
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ dynamic_symbols_test
+ SRCS
+ "dynamic_symbols_test.cc"
+ DEPS
+ ::dynamic_symbols
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "driver=rocm"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/experimental/rocm/api.h b/experimental/rocm/api.h
new file mode 100644
index 0000000..ec38b00
--- /dev/null
+++ b/experimental/rocm/api.h
@@ -0,0 +1,55 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_HAL_ROCM_API_H_
+#define IREE_HAL_ROCM_API_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_rocm_driver_t
+//===----------------------------------------------------------------------===//
+
+// ROCM driver creation options.
+typedef struct {
+ // Index of the default ROCM device to use within the list of available
+ // devices.
+ int default_device_index;
+} iree_hal_rocm_driver_options_t;
+
+IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize(
+ iree_hal_rocm_driver_options_t *out_options);
+
+// Creates a ROCM HAL driver that manages its own hipCtx_t.
+//
+// |out_driver| must be released by the caller (see |iree_hal_driver_release|).
+IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create(
+ iree_string_view_t identifier,
+ const iree_hal_rocm_driver_options_t *options,
+ iree_allocator_t host_allocator, iree_hal_driver_t **out_driver);
+
+// TODO(thomasraoux): Support importing a hipCtx_t from app.
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_API_H_
diff --git a/experimental/rocm/context_wrapper.h b/experimental/rocm/context_wrapper.h
new file mode 100644
index 0000000..6637923
--- /dev/null
+++ b/experimental/rocm/context_wrapper.h
@@ -0,0 +1,30 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_CONTEXT_WRAPPER_H_
+#define IREE_HAL_ROCM_CONTEXT_WRAPPER_H_
+
+#include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/hal/api.h"
+
+// Structure to wrap all objects constant within a context. This makes it
+// simpler to pass it to the different objects and saves memory.
+typedef struct {
+ hipCtx_t rocm_context;
+ iree_allocator_t host_allocator;
+ iree_hal_rocm_dynamic_symbols_t *syms;
+} iree_hal_rocm_context_wrapper_t;
+
+#endif // IREE_HAL_ROCM_CONTEXT_WRAPPER_H_
diff --git a/experimental/rocm/descriptor_set_layout.c b/experimental/rocm/descriptor_set_layout.c
new file mode 100644
index 0000000..59cd8cc
--- /dev/null
+++ b/experimental/rocm/descriptor_set_layout.c
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/descriptor_set_layout.h"
+
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_hal_rocm_context_wrapper_t *context;
+} iree_hal_rocm_descriptor_set_layout_t;
+
+extern const iree_hal_descriptor_set_layout_vtable_t
+ iree_hal_rocm_descriptor_set_layout_vtable;
+
+static iree_hal_rocm_descriptor_set_layout_t *
+iree_hal_rocm_descriptor_set_layout_cast(
+ iree_hal_descriptor_set_layout_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_descriptor_set_layout_vtable);
+ return (iree_hal_rocm_descriptor_set_layout_t *)base_value;
+}
+
+iree_status_t iree_hal_rocm_descriptor_set_layout_create(
+ iree_hal_rocm_context_wrapper_t *context,
+ iree_hal_descriptor_set_layout_usage_type_t usage_type,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t *bindings,
+ iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) {
+ IREE_ASSERT_ARGUMENT(context);
+ IREE_ASSERT_ARGUMENT(!binding_count || bindings);
+ IREE_ASSERT_ARGUMENT(out_descriptor_set_layout);
+ *out_descriptor_set_layout = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout = NULL;
+ iree_status_t status = iree_allocator_malloc(context->host_allocator,
+ sizeof(*descriptor_set_layout),
+ (void **)&descriptor_set_layout);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_rocm_descriptor_set_layout_vtable,
+ &descriptor_set_layout->resource);
+ descriptor_set_layout->context = context;
+ *out_descriptor_set_layout =
+ (iree_hal_descriptor_set_layout_t *)descriptor_set_layout;
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_rocm_descriptor_set_layout_destroy(
+ iree_hal_descriptor_set_layout_t *base_descriptor_set_layout) {
+ iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout =
+ iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout);
+ iree_allocator_t host_allocator =
+ descriptor_set_layout->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, descriptor_set_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+const iree_hal_descriptor_set_layout_vtable_t
+ iree_hal_rocm_descriptor_set_layout_vtable = {
+ .destroy = iree_hal_rocm_descriptor_set_layout_destroy,
+};
diff --git a/experimental/rocm/descriptor_set_layout.h b/experimental/rocm/descriptor_set_layout.h
new file mode 100644
index 0000000..c658d3e
--- /dev/null
+++ b/experimental/rocm/descriptor_set_layout.h
@@ -0,0 +1,36 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_DESCRIPTOR_SET_LAYOUT_H_
+#define IREE_HAL_ROCM_DESCRIPTOR_SET_LAYOUT_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+iree_status_t iree_hal_rocm_descriptor_set_layout_create(
+ iree_hal_rocm_context_wrapper_t *context,
+ iree_hal_descriptor_set_layout_usage_type_t usage_type,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t *bindings,
+ iree_hal_descriptor_set_layout_t **out_descriptor_set_layout);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_DESCRIPTOR_SET_LAYOUT_H_
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
new file mode 100644
index 0000000..24f03c2
--- /dev/null
+++ b/experimental/rocm/direct_command_buffer.c
@@ -0,0 +1,345 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/direct_command_buffer.h"
+
+#include "experimental/rocm/native_executable.h"
+#include "experimental/rocm/rocm_buffer.h"
+#include "experimental/rocm/rocm_event.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// Command buffer implementation that maps directly onto ROCm/HIP stream API
+// calls. This records the commands on the calling thread without additional
+// threading indirection.
+
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_hal_rocm_context_wrapper_t *context;
+ iree_hal_command_buffer_mode_t mode;
+ iree_hal_command_category_t allowed_categories;
+ iree_hal_queue_affinity_t queue_affinity;
+ size_t total_size;
+ // Keep track of the current set of kernel arguments.
+ void *current_descriptor[];
+} iree_hal_rocm_direct_command_buffer_t;
+
+static const size_t max_binding_count = 64;
+
+extern const iree_hal_command_buffer_vtable_t
+ iree_hal_rocm_direct_command_buffer_vtable;
+
+static iree_hal_rocm_direct_command_buffer_t *
+iree_hal_rocm_direct_command_buffer_cast(
+ iree_hal_command_buffer_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_direct_command_buffer_vtable);
+ return (iree_hal_rocm_direct_command_buffer_t *)base_value;
+}
+
+iree_status_t iree_hal_rocm_direct_command_buffer_allocate(
+ iree_hal_rocm_context_wrapper_t *context,
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_command_buffer_t **out_command_buffer) {
+ IREE_ASSERT_ARGUMENT(context);
+ IREE_ASSERT_ARGUMENT(out_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_rocm_direct_command_buffer_t *command_buffer = NULL;
+ size_t total_size = sizeof(*command_buffer) +
+ max_binding_count * sizeof(void *) +
+ max_binding_count * sizeof(hipDeviceptr_t);
+ iree_status_t status = iree_allocator_malloc(
+ context->host_allocator, total_size, (void **)&command_buffer);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_rocm_direct_command_buffer_vtable,
+ &command_buffer->resource);
+ command_buffer->context = context;
+ command_buffer->mode = mode;
+ command_buffer->allowed_categories = command_categories;
+ command_buffer->queue_affinity = queue_affinity;
+ hipDeviceptr_t *device_ptrs =
+ (hipDeviceptr_t *)(command_buffer->current_descriptor +
+ max_binding_count);
+ for (size_t i = 0; i < max_binding_count; i++) {
+ command_buffer->current_descriptor[i] = &device_ptrs[i];
+ }
+ command_buffer->total_size = total_size;
+
+ *out_command_buffer = (iree_hal_command_buffer_t *)command_buffer;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_rocm_direct_command_buffer_destroy(
+ iree_hal_command_buffer_t *base_command_buffer) {
+ iree_hal_rocm_direct_command_buffer_t *command_buffer =
+ iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(command_buffer->context->host_allocator, command_buffer);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_hal_command_buffer_mode_t iree_hal_rocm_direct_command_buffer_mode(
+ const iree_hal_command_buffer_t *base_command_buffer) {
+ const iree_hal_rocm_direct_command_buffer_t *command_buffer =
+ (const iree_hal_rocm_direct_command_buffer_t *)(base_command_buffer);
+ return command_buffer->mode;
+}
+
+static iree_hal_command_category_t
+iree_hal_rocm_direct_command_buffer_allowed_categories(
+ const iree_hal_command_buffer_t *base_command_buffer) {
+ const iree_hal_rocm_direct_command_buffer_t *command_buffer =
+ (const iree_hal_rocm_direct_command_buffer_t *)(base_command_buffer);
+ return command_buffer->allowed_categories;
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_begin(
+ iree_hal_command_buffer_t *base_command_buffer) {
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_end(
+ iree_hal_command_buffer_t *base_command_buffer) {
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_execution_barrier(
+ iree_hal_command_buffer_t *base_command_buffer,
+ iree_hal_execution_stage_t source_stage_mask,
+ iree_hal_execution_stage_t target_stage_mask,
+ iree_hal_execution_barrier_flags_t flags,
+ iree_host_size_t memory_barrier_count,
+ const iree_hal_memory_barrier_t *memory_barriers,
+ iree_host_size_t buffer_barrier_count,
+ const iree_hal_buffer_barrier_t *buffer_barriers) {
+ // TODO: Implement barrier
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_signal_event(
+ iree_hal_command_buffer_t *base_command_buffer, iree_hal_event_t *event,
+ iree_hal_execution_stage_t source_stage_mask) {
+ // TODO: Implement barrier
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_reset_event(
+ iree_hal_command_buffer_t *base_command_buffer, iree_hal_event_t *event,
+ iree_hal_execution_stage_t source_stage_mask) {
+ // TODO: Implement barrier
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_wait_events(
+ iree_hal_command_buffer_t *base_command_buffer,
+ iree_host_size_t event_count, const iree_hal_event_t **events,
+ iree_hal_execution_stage_t source_stage_mask,
+ iree_hal_execution_stage_t target_stage_mask,
+ iree_host_size_t memory_barrier_count,
+ const iree_hal_memory_barrier_t *memory_barriers,
+ iree_host_size_t buffer_barrier_count,
+ const iree_hal_buffer_barrier_t *buffer_barriers) {
+ // TODO: Implement barrier
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_discard_buffer(
+ iree_hal_command_buffer_t *base_command_buffer, iree_hal_buffer_t *buffer) {
+ // nothing to do.
+ return iree_ok_status();
+}
+
+// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value.
+static uint32_t iree_hal_rocm_splat_pattern(const void *pattern,
+ size_t pattern_length) {
+ switch (pattern_length) {
+ case 1: {
+ uint32_t pattern_value = *(const uint8_t *)(pattern);
+ return (pattern_value << 24) | (pattern_value << 16) |
+ (pattern_value << 8) | pattern_value;
+ }
+ case 2: {
+ uint32_t pattern_value = *(const uint16_t *)(pattern);
+ return (pattern_value << 16) | pattern_value;
+ }
+ case 4: {
+ uint32_t pattern_value = *(const uint32_t *)(pattern);
+ return pattern_value;
+ }
+ default:
+ return 0; // Already verified that this should not be possible.
+ }
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t *base_command_buffer,
+    iree_hal_buffer_t *target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, const void *pattern,
+    iree_host_size_t pattern_length) {
+  iree_hal_rocm_direct_command_buffer_t *command_buffer =
+      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+  target_offset += iree_hal_buffer_byte_offset(target_buffer);
+  hipDeviceptr_t dst = iree_hal_rocm_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(target_buffer));
+  uint32_t dword_pattern = iree_hal_rocm_splat_pattern(pattern, pattern_length);
+  // hipMemsetAsync replicates one byte; reject patterns it cannot express.
+  if (dword_pattern != (dword_pattern & 0xFFu) * 0x01010101u) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-byte fill pattern");
+  }
+  // TODO(raikonenfnu): NULL stream for now; use the command buffer's stream.
+  ROCM_RETURN_IF_ERROR(
+      command_buffer->context->syms,
+      hipMemsetAsync(dst + target_offset, (int)dword_pattern, length, 0),
+      "hipMemsetAsync");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_update_buffer(
+ iree_hal_command_buffer_t *base_command_buffer, const void *source_buffer,
+ iree_host_size_t source_offset, iree_hal_buffer_t *target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "need rocm implementation");
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t *base_command_buffer,
+    iree_hal_buffer_t *source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t *target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length) {
+  iree_hal_rocm_direct_command_buffer_t *command_buffer =
+      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+  // Each offset is the caller's offset plus the buffer's own byte offset.
+  hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(target_buffer));
+  target_offset += iree_hal_buffer_byte_offset(target_buffer);
+  hipDeviceptr_t source_device_buffer = iree_hal_rocm_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(source_buffer));
+  source_offset += iree_hal_buffer_byte_offset(source_buffer);
+  // TODO(raikonenfnu): NULL stream for now; use the command buffer's stream.
+  ROCM_RETURN_IF_ERROR(
+      command_buffer->context->syms,
+      hipMemcpyAsync(target_device_buffer + target_offset,
+                     source_device_buffer + source_offset, length,
+                     hipMemcpyDeviceToDevice, 0),
+      "hipMemcpyAsync");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_push_constants(
+ iree_hal_command_buffer_t *base_command_buffer,
+ iree_hal_executable_layout_t *executable_layout, iree_host_size_t offset,
+ const void *values, iree_host_size_t values_length) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "need rocm implementation");
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
+    iree_hal_command_buffer_t *base_command_buffer,
+    iree_hal_executable_layout_t *executable_layout, uint32_t set,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_binding_t *bindings) {
+  iree_hal_rocm_direct_command_buffer_t *command_buffer =
+      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+  for (iree_host_size_t i = 0; i < binding_count; i++) {
+    const iree_hal_descriptor_set_binding_t *b = &bindings[i];
+    if (b->binding >= max_binding_count) {  // Enforced in release builds too.
+      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                              "binding index larger than the max expected.");
+    }
+    hipDeviceptr_t *slot = command_buffer->current_descriptor[b->binding];
+    *slot = iree_hal_rocm_buffer_device_pointer(
+                iree_hal_buffer_allocated_buffer(b->buffer)) +
+            iree_hal_buffer_byte_offset(b->buffer) + b->offset;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_bind_descriptor_set(
+ iree_hal_command_buffer_t *base_command_buffer,
+ iree_hal_executable_layout_t *executable_layout, uint32_t set,
+ iree_hal_descriptor_set_t *descriptor_set,
+ iree_host_size_t dynamic_offset_count,
+ const iree_device_size_t *dynamic_offsets) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "need rocm implementation");
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch(
+    iree_hal_command_buffer_t *base_command_buffer,
+    iree_hal_executable_t *executable, int32_t entry_point,
+    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+  iree_hal_rocm_direct_command_buffer_t *command_buffer =
+      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+
+  // Workgroup counts are passed as the grid dimensions; the block dimensions
+  // are queried from the executable for this entry point.
+  int32_t block_size_x, block_size_y, block_size_z;
+  IREE_RETURN_IF_ERROR(iree_hal_rocm_native_executable_block_size(
+      executable, entry_point, &block_size_x, &block_size_y, &block_size_z));
+  hipFunction_t func =
+      iree_hal_rocm_native_executable_for_entry_point(executable, entry_point);
+  // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to
+  // access proper stream from command buffer
+  ROCM_RETURN_IF_ERROR(
+      command_buffer->context->syms,
+      hipModuleLaunchKernel(func, workgroup_x, workgroup_y, workgroup_z,
+                            block_size_x, block_size_y, block_size_z, 0, 0,
+                            command_buffer->current_descriptor, NULL),
+      "hipModuleLaunchKernel");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect(
+ iree_hal_command_buffer_t *base_command_buffer,
+ iree_hal_executable_t *executable, int32_t entry_point,
+ iree_hal_buffer_t *workgroups_buffer,
+ iree_device_size_t workgroups_offset) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "need rocm implementation");
+}
+
+const iree_hal_command_buffer_vtable_t
+ iree_hal_rocm_direct_command_buffer_vtable = {
+ .destroy = iree_hal_rocm_direct_command_buffer_destroy,
+ .mode = iree_hal_rocm_direct_command_buffer_mode,
+ .allowed_categories =
+ iree_hal_rocm_direct_command_buffer_allowed_categories,
+ .begin = iree_hal_rocm_direct_command_buffer_begin,
+ .end = iree_hal_rocm_direct_command_buffer_end,
+ .execution_barrier =
+ iree_hal_rocm_direct_command_buffer_execution_barrier,
+ .signal_event = iree_hal_rocm_direct_command_buffer_signal_event,
+ .reset_event = iree_hal_rocm_direct_command_buffer_reset_event,
+ .wait_events = iree_hal_rocm_direct_command_buffer_wait_events,
+ .discard_buffer = iree_hal_rocm_direct_command_buffer_discard_buffer,
+ .fill_buffer = iree_hal_rocm_direct_command_buffer_fill_buffer,
+ .update_buffer = iree_hal_rocm_direct_command_buffer_update_buffer,
+ .copy_buffer = iree_hal_rocm_direct_command_buffer_copy_buffer,
+ .push_constants = iree_hal_rocm_direct_command_buffer_push_constants,
+ .push_descriptor_set =
+ iree_hal_rocm_direct_command_buffer_push_descriptor_set,
+ .bind_descriptor_set =
+ iree_hal_rocm_direct_command_buffer_bind_descriptor_set,
+ .dispatch = iree_hal_rocm_direct_command_buffer_dispatch,
+ .dispatch_indirect =
+ iree_hal_rocm_direct_command_buffer_dispatch_indirect,
+};
diff --git a/experimental/rocm/direct_command_buffer.h b/experimental/rocm/direct_command_buffer.h
new file mode 100644
index 0000000..41ffb9c
--- /dev/null
+++ b/experimental/rocm/direct_command_buffer.h
@@ -0,0 +1,51 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_DIRECT_COMMAND_BUFFER_H_
+#define IREE_HAL_ROCM_DIRECT_COMMAND_BUFFER_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// ROCM Kernel Information Structure
+typedef struct {
+ hipFunction_t func;
+ unsigned int gridDimX;
+ unsigned int gridDimY;
+ unsigned int gridDimZ;
+ unsigned int blockDimX;
+ unsigned int blockDimY;
+ unsigned int blockDimZ;
+ void **kernelParams;
+} hip_launch_params;
+
+// Creates a rocm direct command buffer.
+iree_status_t iree_hal_rocm_direct_command_buffer_allocate(
+ iree_hal_rocm_context_wrapper_t *context,
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_command_buffer_t **out_command_buffer);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_DIRECT_COMMAND_BUFFER_H_
diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h
new file mode 100644
index 0000000..f67aa30
--- /dev/null
+++ b/experimental/rocm/dynamic_symbol_tables.h
@@ -0,0 +1,53 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+RC_PFN_DECL(hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
+RC_PFN_DECL(hipCtxDestroy, hipCtx_t)
+RC_PFN_DECL(hipDeviceGet, hipDevice_t *, int) // No direct, need to modify
+RC_PFN_DECL(hipGetDeviceCount, int *)
+RC_PFN_DECL(hipDeviceGetName, char *, int,
+ hipDevice_t) // No direct, need to modify
+RC_PFN_STR_DECL(
+ hipGetErrorName,
+ hipError_t) // Unlike other functions hipGetErrorName(hipError_t) return
+ // const char* instead of hipError_t so it uses a different
+ // macro
+RC_PFN_STR_DECL(
+ hipGetErrorString,
+ hipError_t) // Unlike other functions hipGetErrorName(hipError_t) return
+ // const char* instead of hipError_t so it uses a different
+ // macro
+RC_PFN_DECL(hipInit, unsigned int)
+RC_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int,
+ unsigned int, unsigned int, unsigned int, unsigned int,
+ unsigned int, hipStream_t, void **, void **)
+RC_PFN_DECL(hipMemset, void *, int, size_t)
+RC_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t)
+RC_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind)
+RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind,
+ hipStream_t)
+RC_PFN_DECL(hipMalloc, void **, size_t)
+RC_PFN_DECL(hipFree, void *)
+RC_PFN_DECL(hipHostFree, void *)
+RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
+RC_PFN_DECL(hipHostGetDevicePointer, void **, void *, unsigned int)
+RC_PFN_DECL(hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
+RC_PFN_DECL(hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int,
+ hipJitOption *, void **)
+RC_PFN_DECL(hipModuleLoadData, hipModule_t *, const void *)
+RC_PFN_DECL(hipModuleUnload, hipModule_t)
+RC_PFN_DECL(hipStreamCreateWithFlags, hipStream_t *, unsigned int)
+RC_PFN_DECL(hipStreamDestroy, hipStream_t)
+RC_PFN_DECL(hipStreamSynchronize, hipStream_t)
+RC_PFN_DECL(hipStreamWaitEvent, hipStream_t, hipEvent_t, unsigned int)
diff --git a/experimental/rocm/dynamic_symbols.c b/experimental/rocm/dynamic_symbols.c
new file mode 100644
index 0000000..b198c87
--- /dev/null
+++ b/experimental/rocm/dynamic_symbols.c
@@ -0,0 +1,75 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/dynamic_symbols.h"
+
+#include <stddef.h>
+
+#include "iree/base/internal/dynamic_library.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+static const char *kROCMLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+ "amdhip64.dll",
+#else
+ "libamdhip64.so",
+#endif
+};
+
+static iree_status_t iree_hal_rocm_dynamic_symbols_resolve_all(
+ iree_hal_rocm_dynamic_symbols_t *syms) {
+#define RC_PFN_DECL(rocmSymbolName, ...) \
+ { \
+ static const char *kName = #rocmSymbolName; \
+ IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+ syms->loader_library, kName, (void **)&syms->rocmSymbolName)); \
+ }
+#define RC_PFN_STR_DECL(rocmSymbolName, ...) RC_PFN_DECL(rocmSymbolName, ...)
+#include "experimental/rocm/dynamic_symbol_tables.h"
+#undef RC_PFN_DECL
+#undef RC_PFN_STR_DECL
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_rocm_dynamic_symbols_initialize(
+    iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  memset(out_syms, 0, sizeof(*out_syms));
+  iree_status_t status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(kROCMLoaderSearchNames), kROCMLoaderSearchNames,
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, allocator, &out_syms->loader_library);
+  if (iree_status_is_not_found(status)) {
+    iree_status_ignore(status);  // Replaced by the UNAVAILABLE status below.
+    status = iree_make_status(  // No early return: the trace zone must end.
+        IREE_STATUS_UNAVAILABLE,
+        "ROCM/HIP runtime library not available; ensure installed and on path");
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_rocm_dynamic_symbols_resolve_all(out_syms);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_rocm_dynamic_symbols_deinitialize(out_syms);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_rocm_dynamic_symbols_deinitialize(
+ iree_hal_rocm_dynamic_symbols_t *syms) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_dynamic_library_release(syms->loader_library);
+ memset(syms, 0, sizeof(*syms));
+ IREE_TRACE_ZONE_END(z0);
+}
diff --git a/experimental/rocm/dynamic_symbols.h b/experimental/rocm/dynamic_symbols.h
new file mode 100644
index 0000000..e5f5891
--- /dev/null
+++ b/experimental/rocm/dynamic_symbols.h
@@ -0,0 +1,58 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_DYNAMIC_SYMBOLS_H_
+#define IREE_HAL_ROCM_DYNAMIC_SYMBOLS_H_
+
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/dynamic_library.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Dynamically loaded subset of the ROCM/HIP API. One function pointer field is
+// declared for every entry of `dynamic_symbol_tables.h`, and initialization
+// fails if any of those symbols cannot be resolved. Each pointer takes its
+// argument list from the corresponding table entry.
+// NOTE(review): the signatures are presumably meant to match the HIP runtime
+// API header — confirm the exact header name.
+typedef struct {
+ iree_dynamic_library_t *loader_library;
+
+// Entries declared with RC_PFN_DECL return hipError_t.
+#define RC_PFN_DECL(rocmSymbolName, ...) \
+ hipError_t (*rocmSymbolName)(__VA_ARGS__);
+// Entries declared with RC_PFN_STR_DECL return a C string.
+#define RC_PFN_STR_DECL(rocmSymbolName, ...) \
+ const char *(*rocmSymbolName)(__VA_ARGS__);
+#include "experimental/rocm/dynamic_symbol_tables.h"
+#undef RC_PFN_DECL
+#undef RC_PFN_STR_DECL
+} iree_hal_rocm_dynamic_symbols_t;
+
+// Initializes |out_syms| in-place with dynamically loaded ROCM symbols.
+// Returns UNAVAILABLE if the ROCM/HIP runtime library cannot be found.
+// On success iree_hal_rocm_dynamic_symbols_deinitialize must be used to
+// release the library resources.
+iree_status_t iree_hal_rocm_dynamic_symbols_initialize(
+ iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms);
+
+// Deinitializes |syms| by releasing the backing library and zeroing the
+// structure. All function pointers are invalidated. They _may_ still work if
+// there are other reasons the library remains loaded, so be careful.
+void iree_hal_rocm_dynamic_symbols_deinitialize(
+ iree_hal_rocm_dynamic_symbols_t *syms);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_DYNAMIC_SYMBOLS_H_
diff --git a/experimental/rocm/dynamic_symbols_test.cc b/experimental/rocm/dynamic_symbols_test.cc
new file mode 100644
index 0000000..541d04c
--- /dev/null
+++ b/experimental/rocm/dynamic_symbols_test.cc
@@ -0,0 +1,54 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/dynamic_symbols.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace hal {
+namespace rocm {
+namespace {
+
+// Evaluates |expr| once and asserts (ASSERT_EQ) that it returned hipSuccess.
+#define ROCM_CHECK_ERRORS(expr) \
+ { \
+ hipError_t status = expr; \
+ ASSERT_EQ(hipSuccess, status); \
+ }
+
+// Checks that the HIP runtime can be loaded through the dynamic symbol table
+// and that a few basic entry points are callable. The test is skipped when the
+// symbols cannot be loaded (e.g. no ROCM installation on the machine).
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  iree_hal_rocm_dynamic_symbols_t symbols;
+  iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(
+      iree_allocator_system(), &symbols);
+  if (!iree_status_is_ok(status)) {
+    // Release the failed status before leaving the test so it is not leaked.
+    iree_status_ignore(status);
+    IREE_LOG(WARNING) << "Symbols cannot be loaded, skipping test.";
+    GTEST_SKIP();
+  }
+
+  int device_count = 0;
+  ROCM_CHECK_ERRORS(symbols.hipInit(0));
+  ROCM_CHECK_ERRORS(symbols.hipGetDeviceCount(&device_count));
+  if (device_count > 0) {
+    // Only touch a device when the machine actually has one.
+    hipDevice_t device;
+    ROCM_CHECK_ERRORS(symbols.hipDeviceGet(&device, /*ordinal=*/0));
+  }
+
+  iree_hal_rocm_dynamic_symbols_deinitialize(&symbols);
+}
+
+} // namespace
+} // namespace rocm
+} // namespace hal
+} // namespace iree
diff --git a/experimental/rocm/event_semaphore.c b/experimental/rocm/event_semaphore.c
new file mode 100644
index 0000000..0cdfa02
--- /dev/null
+++ b/experimental/rocm/event_semaphore.c
@@ -0,0 +1,99 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/event_semaphore.h"
+
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// Placeholder semaphore for the ROCM backend. It only records the creating
+// context and the initial value; no device synchronization object is held.
+typedef struct {
+ iree_hal_resource_t resource;
+ // Context supplied at creation; used to reach the host allocator.
+ iree_hal_rocm_context_wrapper_t *context;
+ // Value passed at creation. Not read anywhere else in this file.
+ uint64_t initial_value;
+} iree_hal_rocm_semaphore_t;
+
+// Forward declaration; the vtable is defined at the bottom of this file.
+extern const iree_hal_semaphore_vtable_t iree_hal_rocm_semaphore_vtable;
+
+// Downcasts a generic HAL semaphore to the ROCM implementation after checking
+// (via IREE_HAL_ASSERT_TYPE) that it carries the ROCM semaphore vtable.
+static iree_hal_rocm_semaphore_t *iree_hal_rocm_semaphore_cast(
+ iree_hal_semaphore_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_semaphore_vtable);
+ return (iree_hal_rocm_semaphore_t *)base_value;
+}
+
+// Creates a semaphore bound to |context| and returns it in |out_semaphore|.
+//
+// The object is allocated from |context->host_allocator|. |initial_value| is
+// only recorded on the object. On allocation failure the allocator status is
+// returned and |*out_semaphore| is NULL.
+iree_status_t iree_hal_rocm_semaphore_create(
+    iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value,
+    iree_hal_semaphore_t **out_semaphore) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_ASSERT_ARGUMENT(out_semaphore);
+  // Give the out-param a defined value on every path, matching the other
+  // create functions of this backend.
+  *out_semaphore = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_rocm_semaphore_t *semaphore = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      context->host_allocator, sizeof(*semaphore), (void **)&semaphore);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_rocm_semaphore_vtable,
+                                 &semaphore->resource);
+    semaphore->context = context;
+    semaphore->initial_value = initial_value;
+    *out_semaphore = (iree_hal_semaphore_t *)semaphore;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Frees the semaphore storage using the host allocator of its context.
+static void iree_hal_rocm_semaphore_destroy(
+ iree_hal_semaphore_t *base_semaphore) {
+ iree_hal_rocm_semaphore_t *semaphore =
+ iree_hal_rocm_semaphore_cast(base_semaphore);
+ // Copy the allocator out first: it is reached through the object being freed.
+ iree_allocator_t host_allocator = semaphore->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, semaphore);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Not implemented yet: always writes 0 to |*out_value| and returns an
+// UNIMPLEMENTED status.
+static iree_status_t iree_hal_rocm_semaphore_query(
+ iree_hal_semaphore_t *base_semaphore, uint64_t *out_value) {
+ // TODO: Support semaphores completely.
+ *out_value = 0;
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not impemented on rocm");
+}
+
+// Stub: ignores |new_value| and reports success without changing any state.
+static iree_status_t iree_hal_rocm_semaphore_signal(
+ iree_hal_semaphore_t *base_semaphore, uint64_t new_value) {
+ // TODO: Support semaphores completely. Return OK currently as everything is
+ // synchronized for each submit to allow things to run.
+ return iree_ok_status();
+}
+
+// Stub: does nothing with the failure.
+// NOTE(review): |status| is neither stored nor released here — confirm whether
+// this function is expected to take ownership of it.
+static void iree_hal_rocm_semaphore_fail(iree_hal_semaphore_t *base_semaphore,
+ iree_status_t status) {}
+
+// Stub: ignores |value| and |timeout| and reports success immediately.
+static iree_status_t iree_hal_rocm_semaphore_wait(
+ iree_hal_semaphore_t *base_semaphore, uint64_t value,
+ iree_timeout_t timeout) {
+ // TODO: Support semaphores completely. Return OK currently as everything is
+ // synchronized for each submit to allow things to run.
+ return iree_ok_status();
+}
+
+// Dispatch table wiring the HAL semaphore interface to the functions above.
+const iree_hal_semaphore_vtable_t iree_hal_rocm_semaphore_vtable = {
+ .destroy = iree_hal_rocm_semaphore_destroy,
+ .query = iree_hal_rocm_semaphore_query,
+ .signal = iree_hal_rocm_semaphore_signal,
+ .fail = iree_hal_rocm_semaphore_fail,
+ .wait = iree_hal_rocm_semaphore_wait,
+};
diff --git a/experimental/rocm/event_semaphore.h b/experimental/rocm/event_semaphore.h
new file mode 100644
index 0000000..952e3e5
--- /dev/null
+++ b/experimental/rocm/event_semaphore.h
@@ -0,0 +1,35 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_SEMAPHORE_H_
+#define IREE_HAL_ROCM_SEMAPHORE_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a ROCM semaphore bound to |context| with the given |initial_value|
+// and returns it in |out_semaphore|.
+iree_status_t iree_hal_rocm_semaphore_create(
+ iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value,
+ iree_hal_semaphore_t **out_semaphore);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_SEMAPHORE_H_
diff --git a/experimental/rocm/executable_layout.c b/experimental/rocm/executable_layout.c
new file mode 100644
index 0000000..e7c666f
--- /dev/null
+++ b/experimental/rocm/executable_layout.c
@@ -0,0 +1,88 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/executable_layout.h"
+
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// Executable layout for the ROCM backend: a context pointer plus a trailing
+// array of retained descriptor set layouts.
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_hal_rocm_context_wrapper_t *context;
+ // Number of valid entries in |set_layouts|.
+ iree_host_size_t set_layout_count;
+ // Flexible array member; storage is allocated together with the struct.
+ iree_hal_descriptor_set_layout_t *set_layouts[];
+} iree_hal_rocm_executable_layout_t;
+
+// Forward declaration; the vtable is defined at the bottom of this file.
+extern const iree_hal_executable_layout_vtable_t
+ iree_hal_rocm_executable_layout_vtable;
+
+// Downcasts a generic executable layout to the ROCM implementation after
+// checking (via IREE_HAL_ASSERT_TYPE) that it carries the ROCM vtable.
+static iree_hal_rocm_executable_layout_t *iree_hal_rocm_executable_layout_cast(
+ iree_hal_executable_layout_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_executable_layout_vtable);
+ return (iree_hal_rocm_executable_layout_t *)base_value;
+}
+
+// Builds an executable layout that holds a reference to each of the
+// |set_layout_count| descriptor set layouts in |set_layouts|.
+//
+// The object and its trailing layout array are allocated in one block from
+// |context->host_allocator|. |push_constant_count| is accepted but not used
+// yet. On failure |*out_executable_layout| stays NULL.
+iree_status_t iree_hal_rocm_executable_layout_create(
+    iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t **set_layouts,
+    iree_host_size_t push_constant_count,
+    iree_hal_executable_layout_t **out_executable_layout) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts);
+  IREE_ASSERT_ARGUMENT(out_executable_layout);
+  *out_executable_layout = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Currently the executable layout doesn't do anything.
+  // TODO: Handle creating the argument layout at that time, handling both push
+  // constants and buffers.
+  iree_hal_rocm_executable_layout_t *layout = NULL;
+  iree_host_size_t layout_size =
+      sizeof(*layout) + set_layout_count * sizeof(layout->set_layouts[0]);
+  iree_status_t status = iree_allocator_malloc(context->host_allocator,
+                                               layout_size, (void **)&layout);
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  iree_hal_resource_initialize(&iree_hal_rocm_executable_layout_vtable,
+                               &layout->resource);
+  layout->context = context;
+  layout->set_layout_count = set_layout_count;
+  // Keep every set layout alive for as long as this object exists.
+  for (iree_host_size_t index = 0; index < set_layout_count; ++index) {
+    iree_hal_descriptor_set_layout_t *set_layout = set_layouts[index];
+    layout->set_layouts[index] = set_layout;
+    iree_hal_descriptor_set_layout_retain(set_layout);
+  }
+  *out_executable_layout = (iree_hal_executable_layout_t *)layout;
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Releases every descriptor set layout held by the executable layout and then
+// frees the object with the host allocator of its context.
+static void iree_hal_rocm_executable_layout_destroy(
+ iree_hal_executable_layout_t *base_executable_layout) {
+ iree_hal_rocm_executable_layout_t *executable_layout =
+ iree_hal_rocm_executable_layout_cast(base_executable_layout);
+ // Copy the allocator out first: it is reached through the object being freed.
+ iree_allocator_t host_allocator = executable_layout->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Drop the references taken in the create function.
+ for (iree_host_size_t i = 0; i < executable_layout->set_layout_count; ++i) {
+ iree_hal_descriptor_set_layout_release(executable_layout->set_layouts[i]);
+ }
+ iree_allocator_free(host_allocator, executable_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Dispatch table for ROCM executable layouts; only destruction is customized.
+const iree_hal_executable_layout_vtable_t
+ iree_hal_rocm_executable_layout_vtable = {
+ .destroy = iree_hal_rocm_executable_layout_destroy,
+};
diff --git a/experimental/rocm/executable_layout.h b/experimental/rocm/executable_layout.h
new file mode 100644
index 0000000..8c36713
--- /dev/null
+++ b/experimental/rocm/executable_layout.h
@@ -0,0 +1,36 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_EXECUTABLE_LAYOUT_H_
+#define IREE_HAL_ROCM_EXECUTABLE_LAYOUT_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates an executable layout that retains the |set_layout_count| descriptor
+// set layouts in |set_layouts| and returns it in |out_executable_layout|.
+// |push_constant_count| is currently accepted but not used.
+iree_status_t iree_hal_rocm_executable_layout_create(
+ iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t **set_layouts,
+ iree_host_size_t push_constant_count,
+ iree_hal_executable_layout_t **out_executable_layout);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_EXECUTABLE_LAYOUT_H_
diff --git a/experimental/rocm/native_executable.c b/experimental/rocm/native_executable.c
new file mode 100644
index 0000000..3448228
--- /dev/null
+++ b/experimental/rocm/native_executable.c
@@ -0,0 +1,136 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/native_executable.h"
+
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// flatcc schemas:
+#include "iree/base/internal/flatcc.h"
+#include "iree/schemas/rocm_executable_def_reader.h"
+#include "iree/schemas/rocm_executable_def_verifier.h"
+
+// One kernel entry point: the resolved HIP function and the block size read
+// from the executable flatbuffer for that entry.
+typedef struct {
+ hipFunction_t rocm_function;
+ uint32_t block_size_x;
+ uint32_t block_size_y;
+ uint32_t block_size_z;
+} iree_hal_rocm_native_executable_function_t;
+
+// ROCM executable: a loaded HIP module plus a trailing table of its entry
+// points.
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_hal_rocm_context_wrapper_t *context;
+ // Intended to hold the number of entries in |entry_functions|.
+ // NOTE(review): confirm the create function populates this field.
+ iree_host_size_t entry_count;
+ hipModule_t module;
+ // Flexible array member; storage is allocated together with the struct.
+ iree_hal_rocm_native_executable_function_t entry_functions[];
+} iree_hal_rocm_native_executable_t;
+
+// Forward declaration; the vtable is defined at the bottom of this file.
+extern const iree_hal_executable_vtable_t
+ iree_hal_rocm_native_executable_vtable;
+
+// Downcasts a generic HAL executable to the ROCM implementation after checking
+// (via IREE_HAL_ASSERT_TYPE) that it carries the ROCM executable vtable.
+static iree_hal_rocm_native_executable_t *iree_hal_rocm_native_executable_cast(
+ iree_hal_executable_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_native_executable_vtable);
+ return (iree_hal_rocm_native_executable_t *)base_value;
+}
+
+// Loads |hsaco_image| as a HIP module and resolves each name in
+// |entry_points_vec| into |executable->entry_functions|, recording the block
+// size at the same index of |block_sizes_vec|.
+// |executable->entry_count| must already be set by the caller.
+static iree_status_t iree_hal_rocm_native_executable_load_entry_points(
+    iree_hal_rocm_context_wrapper_t *context, flatbuffers_string_t hsaco_image,
+    flatbuffers_string_vec_t entry_points_vec,
+    iree_ROCMBlockSizeDef_vec_t block_sizes_vec,
+    iree_hal_rocm_native_executable_t *executable) {
+  hipModule_t module = NULL;
+  ROCM_RETURN_IF_ERROR(context->syms,
+                       hipModuleLoadDataEx(&module, hsaco_image, 0, NULL, NULL),
+                       "hipModuleLoadDataEx");
+  executable->module = module;
+
+  for (iree_host_size_t i = 0; i < executable->entry_count; i++) {
+    hipFunction_t function = NULL;
+    const char *entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
+    ROCM_RETURN_IF_ERROR(context->syms,
+                         hipModuleGetFunction(&function, module, entry_name),
+                         "hipModuleGetFunction");
+    // NOTE(review): assumes |block_sizes_vec| has at least as many elements as
+    // |entry_points_vec|; the flatbuffer is not verified yet.
+    executable->entry_functions[i].rocm_function = function;
+    executable->entry_functions[i].block_size_x = block_sizes_vec[i].x;
+    executable->entry_functions[i].block_size_y = block_sizes_vec[i].y;
+    executable->entry_functions[i].block_size_z = block_sizes_vec[i].z;
+  }
+  return iree_ok_status();
+}
+
+// Creates an executable from the ROCM executable flatbuffer stored in
+// |executable_spec->executable_data| and returns it in |out_executable|.
+//
+// The object is allocated from |context->host_allocator| with room for one
+// entry per entry point named in the flatbuffer. If the allocation, the module
+// load or any function lookup fails, the allocation is freed, the failing
+// status is returned and |*out_executable| stays NULL.
+iree_status_t iree_hal_rocm_native_executable_create(
+    iree_hal_rocm_context_wrapper_t *context,
+    const iree_hal_executable_spec_t *executable_spec,
+    iree_hal_executable_t **out_executable) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_ASSERT_ARGUMENT(executable_spec);
+  IREE_ASSERT_ARGUMENT(out_executable);
+  *out_executable = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // TODO: Verify the flat buffer.
+  iree_ROCMExecutableDef_table_t executable_def =
+      iree_ROCMExecutableDef_as_root(executable_spec->executable_data.data);
+
+  flatbuffers_string_t hsaco_image =
+      iree_ROCMExecutableDef_hsaco_image_get(executable_def);
+  flatbuffers_string_vec_t entry_points_vec =
+      iree_ROCMExecutableDef_entry_points_get(executable_def);
+  iree_ROCMBlockSizeDef_vec_t block_sizes_vec =
+      iree_ROCMExecutableDef_block_sizes_get(executable_def);
+  iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec);
+
+  iree_hal_rocm_native_executable_t *executable = NULL;
+  iree_host_size_t total_size =
+      sizeof(*executable) +
+      entry_count * sizeof(iree_hal_rocm_native_executable_function_t);
+  iree_status_t status = iree_allocator_malloc(
+      context->host_allocator, total_size, (void **)&executable);
+  if (iree_status_is_ok(status)) {
+    // Only touch |executable| once the allocation is known to have succeeded.
+    iree_hal_resource_initialize(&iree_hal_rocm_native_executable_vtable,
+                                 &executable->resource);
+    executable->context = context;
+    executable->entry_count = entry_count;
+    executable->module = NULL;
+    status = iree_hal_rocm_native_executable_load_entry_points(
+        context, hsaco_image, entry_points_vec, block_sizes_vec, executable);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_executable = (iree_hal_executable_t *)executable;
+  } else if (executable) {
+    // NOTE(review): a module that loaded before a later lookup failed is not
+    // unloaded here — confirm whether the HIP module needs an explicit unload.
+    iree_allocator_free(context->host_allocator, executable);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Returns the HIP function stored for |entry_point|.
+// |entry_point| is used as a direct array index with no bounds check.
+hipFunction_t iree_hal_rocm_native_executable_for_entry_point(
+ iree_hal_executable_t *base_executable, int32_t entry_point) {
+ iree_hal_rocm_native_executable_t *executable =
+ iree_hal_rocm_native_executable_cast(base_executable);
+ return executable->entry_functions[entry_point].rocm_function;
+}
+
+// Writes the block size stored for |entry_point| to |x|, |y| and |z| and
+// returns OK. |entry_point| is used as a direct array index with no bounds
+// check.
+iree_status_t iree_hal_rocm_native_executable_block_size(
+ iree_hal_executable_t *base_executable, int32_t entry_point, uint32_t *x,
+ uint32_t *y, uint32_t *z) {
+ iree_hal_rocm_native_executable_t *executable =
+ iree_hal_rocm_native_executable_cast(base_executable);
+ *x = executable->entry_functions[entry_point].block_size_x;
+ *y = executable->entry_functions[entry_point].block_size_y;
+ *z = executable->entry_functions[entry_point].block_size_z;
+ return iree_ok_status();
+}
+
+// Frees the executable storage using the host allocator of its context.
+// NOTE(review): |executable->module| is not unloaded here — confirm whether
+// the HIP module needs an explicit unload.
+static void iree_hal_rocm_native_executable_destroy(
+ iree_hal_executable_t *base_executable) {
+ iree_hal_rocm_native_executable_t *executable =
+ iree_hal_rocm_native_executable_cast(base_executable);
+ // Copy the allocator out first: it is reached through the object being freed.
+ iree_allocator_t host_allocator = executable->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, executable);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Dispatch table for ROCM executables; only destruction is customized.
+const iree_hal_executable_vtable_t iree_hal_rocm_native_executable_vtable = {
+ .destroy = iree_hal_rocm_native_executable_destroy,
+};
diff --git a/experimental/rocm/native_executable.h b/experimental/rocm/native_executable.h
new file mode 100644
index 0000000..d1ff352
--- /dev/null
+++ b/experimental/rocm/native_executable.h
@@ -0,0 +1,45 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_NATIVE_EXECUTABLE_H_
+#define IREE_HAL_ROCM_NATIVE_EXECUTABLE_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates an executable from a HSACO module. The module may contain several
+// kernels that can be extracted along with the associated block size.
+iree_status_t iree_hal_rocm_native_executable_create(
+ iree_hal_rocm_context_wrapper_t *context,
+ const iree_hal_executable_spec_t *executable_spec,
+ iree_hal_executable_t **out_executable);
+
+// Returns the HIP function of the given |entry_point| within the executable.
+// |entry_point| is not range-checked.
+hipFunction_t iree_hal_rocm_native_executable_for_entry_point(
+ iree_hal_executable_t *executable, int32_t entry_point);
+
+// Returns the block size of the given |entry_point| within the executable
+// through |x|, |y| and |z|. |entry_point| is not range-checked.
+iree_status_t iree_hal_rocm_native_executable_block_size(
+ iree_hal_executable_t *executable, int32_t entry_point, uint32_t *x,
+ uint32_t *y, uint32_t *z);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_NATIVE_EXECUTABLE_H_
diff --git a/experimental/rocm/nop_executable_cache.c b/experimental/rocm/nop_executable_cache.c
new file mode 100644
index 0000000..e225bab
--- /dev/null
+++ b/experimental/rocm/nop_executable_cache.c
@@ -0,0 +1,94 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/nop_executable_cache.h"
+
+#include "experimental/rocm/native_executable.h"
+#include "iree/base/tracing.h"
+
+// Executable cache that stores nothing; it only remembers the context used to
+// create executables on demand.
+typedef struct {
+ iree_hal_resource_t resource;
+ iree_hal_rocm_context_wrapper_t *context;
+} iree_hal_rocm_nop_executable_cache_t;
+
+// Forward declaration; the vtable is defined at the bottom of this file.
+extern const iree_hal_executable_cache_vtable_t
+ iree_hal_rocm_nop_executable_cache_vtable;
+
+// Downcasts a generic executable cache to the ROCM no-op implementation after
+// checking (via IREE_HAL_ASSERT_TYPE) that it carries the matching vtable.
+static iree_hal_rocm_nop_executable_cache_t *
+iree_hal_rocm_nop_executable_cache_cast(
+ iree_hal_executable_cache_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_nop_executable_cache_vtable);
+ return (iree_hal_rocm_nop_executable_cache_t *)base_value;
+}
+
+// Creates a no-op executable cache bound to |context| and returns it in
+// |out_executable_cache|.
+//
+// The object is allocated from |context->host_allocator|. |identifier| is
+// accepted but not used. On allocation failure the allocator status is
+// returned and |*out_executable_cache| stays NULL.
+iree_status_t iree_hal_rocm_nop_executable_cache_create(
+    iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier,
+    iree_hal_executable_cache_t **out_executable_cache) {
+  // |context| is dereferenced below, so validate it like the other create
+  // functions of this backend do.
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_ASSERT_ARGUMENT(out_executable_cache);
+  *out_executable_cache = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_rocm_nop_executable_cache_t *executable_cache = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(context->host_allocator, sizeof(*executable_cache),
+                            (void **)&executable_cache);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_rocm_nop_executable_cache_vtable,
+                                 &executable_cache->resource);
+    executable_cache->context = context;
+
+    *out_executable_cache = (iree_hal_executable_cache_t *)executable_cache;
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Frees the cache object using the host allocator of its context.
+static void iree_hal_rocm_nop_executable_cache_destroy(
+ iree_hal_executable_cache_t *base_executable_cache) {
+ iree_hal_rocm_nop_executable_cache_t *executable_cache =
+ iree_hal_rocm_nop_executable_cache_cast(base_executable_cache);
+ // Copy the allocator out first: it is reached through the object being freed.
+ iree_allocator_t host_allocator = executable_cache->context->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, executable_cache);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Returns true only when |executable_format| is exactly "PTXE"; the cache and
+// caching mode arguments are ignored.
+// NOTE(review): "PTXE" looks like a format tag carried over from another
+// backend — confirm it matches what the ROCM compiler target emits.
+static bool iree_hal_rocm_nop_executable_cache_can_prepare_format(
+ iree_hal_executable_cache_t *base_executable_cache,
+ iree_hal_executable_caching_mode_t caching_mode,
+ iree_string_view_t executable_format) {
+ return iree_string_view_equal(executable_format,
+ iree_make_cstring_view("PTXE"));
+}
+
+// Creates a new native executable from |executable_spec| on every call; no
+// lookup or storage of previously prepared executables takes place.
+static iree_status_t iree_hal_rocm_nop_executable_cache_prepare_executable(
+ iree_hal_executable_cache_t *base_executable_cache,
+ const iree_hal_executable_spec_t *executable_spec,
+ iree_hal_executable_t **out_executable) {
+ iree_hal_rocm_nop_executable_cache_t *executable_cache =
+ iree_hal_rocm_nop_executable_cache_cast(base_executable_cache);
+ return iree_hal_rocm_native_executable_create(
+ executable_cache->context, executable_spec, out_executable);
+}
+
+// Dispatch table wiring the HAL executable cache interface to the functions
+// above.
+const iree_hal_executable_cache_vtable_t
+ iree_hal_rocm_nop_executable_cache_vtable = {
+ .destroy = iree_hal_rocm_nop_executable_cache_destroy,
+ .can_prepare_format =
+ iree_hal_rocm_nop_executable_cache_can_prepare_format,
+ .prepare_executable =
+ iree_hal_rocm_nop_executable_cache_prepare_executable,
+};
diff --git a/experimental/rocm/nop_executable_cache.h b/experimental/rocm/nop_executable_cache.h
new file mode 100644
index 0000000..72af2ca
--- /dev/null
+++ b/experimental/rocm/nop_executable_cache.h
@@ -0,0 +1,36 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_NOP_EXECUTABLE_CACHE_H_
+#define IREE_HAL_ROCM_NOP_EXECUTABLE_CACHE_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a no-op executable cache that does not cache at all: every prepare
+// request builds a new executable. This is useful to isolate pipeline caching
+// behavior and verify compilation behavior.
+// |identifier| is currently unused by the implementation.
+iree_status_t iree_hal_rocm_nop_executable_cache_create(
+ iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier,
+ iree_hal_executable_cache_t **out_executable_cache);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_NOP_EXECUTABLE_CACHE_H_
diff --git a/experimental/rocm/registration/CMakeLists.txt b/experimental/rocm/registration/CMakeLists.txt
new file mode 100644
index 0000000..fd6e66a
--- /dev/null
+++ b/experimental/rocm/registration/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Copyright 2021 Google LLC
+
+iree_add_all_subdirs()
+
+# The driver registration library is only defined when the experimental ROCM
+# backend is enabled.
+if(${IREE_BUILD_EXPERIMENTAL_ROCM})
+
+# Library that registers the ROCM HAL driver factory (driver_module.c).
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.c"
+ DEPS
+ iree::base::core_headers
+ iree::base::status
+ iree::base::tracing
+ iree::hal
+ experimental::rocm
+ INCLUDES
+ "${CMAKE_CURRENT_LIST_DIR}/../../.."
+ DEFINES
+ "IREE_BUILD_EXPERIMENTAL_ROCM=1"
+ PUBLIC
+)
+
+endif()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/experimental/rocm/registration/driver_module.c b/experimental/rocm/registration/driver_module.c
new file mode 100644
index 0000000..0ff0bb9
--- /dev/null
+++ b/experimental/rocm/registration/driver_module.c
@@ -0,0 +1,71 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/registration/driver_module.h"
+
+#include <inttypes.h>
+
+#include "experimental/rocm/api.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+// Driver ID reported and accepted by this factory. The leading bytes are the
+// ASCII codes of "ROCM" (0x52 0x4f 0x43 0x4d), followed by one 0x0a byte.
+// NOTE(review): the trailing 0x0a looks like a stray newline — confirm it is
+// intended.
+#define IREE_HAL_ROCM_DRIVER_ID 0x524f434d0au // ROCM
+
+// Reports the single "rocm" driver provided by this factory. The returned
+// array has static storage duration; |self| is unused.
+static iree_status_t iree_hal_rocm_driver_factory_enumerate(
+ void *self, const iree_hal_driver_info_t **out_driver_infos,
+ iree_host_size_t *out_driver_info_count) {
+ // NOTE: we could query supported ROCM versions or featuresets here.
+ static const iree_hal_driver_info_t driver_infos[1] = {{
+ .driver_id = IREE_HAL_ROCM_DRIVER_ID,
+ .driver_name = iree_string_view_literal("rocm"),
+ .full_name = iree_string_view_literal("ROCM (dynamic)"),
+ }};
+ *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
+ *out_driver_infos = driver_infos;
+ return iree_ok_status();
+}
+
+// Creates the ROCM driver when |driver_id| is IREE_HAL_ROCM_DRIVER_ID.
+//
+// |*out_driver| is reset to NULL first. Any other |driver_id| yields an
+// UNAVAILABLE status. The driver is created with default-initialized options
+// and the identifier "rocm"; |self| is unused.
+static iree_status_t iree_hal_rocm_driver_factory_try_create(
+    void *self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+    iree_hal_driver_t **out_driver) {
+  IREE_ASSERT_ARGUMENT(out_driver);
+  *out_driver = NULL;
+  if (driver_id != IREE_HAL_ROCM_DRIVER_ID) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "no driver with ID %016" PRIu64
+                            " is provided by this factory",
+                            driver_id);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_rocm_driver_options_t options;
+  iree_hal_rocm_driver_options_initialize(&options);
+  // Only one driver is exposed for now; if more are added (different rocm
+  // versions, etc) they can be given distinct identifiers here.
+  iree_status_t create_status = iree_hal_rocm_driver_create(
+      iree_make_cstring_view("rocm"), &options, allocator, out_driver);
+  IREE_TRACE_ZONE_END(z0);
+  return create_status;
+}
+
+// Registers the ROCM driver factory with |registry| and returns the status of
+// that registration. The factory object has static storage duration.
+IREE_API_EXPORT iree_status_t
+iree_hal_rocm_driver_module_register(iree_hal_driver_registry_t *registry) {
+ static const iree_hal_driver_factory_t factory = {
+ .self = NULL,
+ .enumerate = iree_hal_rocm_driver_factory_enumerate,
+ .try_create = iree_hal_rocm_driver_factory_try_create,
+ };
+ return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/experimental/rocm/registration/driver_module.h b/experimental/rocm/registration/driver_module.h
new file mode 100644
index 0000000..376a21d
--- /dev/null
+++ b/experimental/rocm/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_ROCM_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Registers the ROCM HAL driver factory with |registry|.
+IREE_API_EXPORT iree_status_t
+iree_hal_rocm_driver_module_register(iree_hal_driver_registry_t *registry);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_REGISTRATION_DRIVER_MODULE_H_
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
new file mode 100644
index 0000000..9b24628
--- /dev/null
+++ b/experimental/rocm/rocm_allocator.c
@@ -0,0 +1,175 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/rocm_allocator.h"
+
+#include "experimental/rocm/rocm_buffer.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+typedef struct iree_hal_rocm_allocator_s {
+ iree_hal_resource_t resource;  // Initialized with the allocator vtable on create.
+ iree_hal_rocm_context_wrapper_t *context;  // Stored as passed to create; never freed here.
+} iree_hal_rocm_allocator_t;
+
+extern const iree_hal_allocator_vtable_t iree_hal_rocm_allocator_vtable;
+
+static iree_hal_rocm_allocator_t *iree_hal_rocm_allocator_cast(
+ iree_hal_allocator_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_allocator_vtable);  // Type check against the ROCM allocator vtable.
+ return (iree_hal_rocm_allocator_t *)base_value;  // Downcast from the base HAL allocator type.
+}
+
+// Creates a ROCM allocator; |context| is stored as-is (pointer, not a copy).
+iree_status_t iree_hal_rocm_allocator_create(
+    iree_hal_rocm_context_wrapper_t *context,
+    iree_hal_allocator_t **out_allocator) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_rocm_allocator_t *self = NULL;
+  iree_status_t result = iree_allocator_malloc(context->host_allocator,
+                                               sizeof(*self), (void **)&self);
+  if (iree_status_is_ok(result)) {
+    iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable,
+                                 &self->resource);
+    self->context = context;
+    *out_allocator = (iree_hal_allocator_t *)self;  // Only written on success.
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return result;
+}
+
+// Returns the memory of |base_allocator| to its context's host allocator.
+static void iree_hal_rocm_allocator_destroy(
+    iree_hal_allocator_t *base_allocator) {
+  iree_hal_rocm_allocator_t *self =
+      iree_hal_rocm_allocator_cast(base_allocator);
+  iree_allocator_t host_allocator = self->context->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Only the struct is freed; |context| is left untouched.
+  iree_allocator_free(host_allocator, self);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Returns the context's host allocator; casts directly, without a type check.
+static iree_allocator_t iree_hal_rocm_allocator_host_allocator(
+    const iree_hal_allocator_t *base_allocator) {
+  return ((const iree_hal_rocm_allocator_t *)base_allocator)
+      ->context->host_allocator;
+}
+
+// Reports how a buffer of |memory_type| may be used with this allocator.
+// |base_allocator| and |allocation_size| are currently not consulted.
+static iree_hal_buffer_compatibility_t
+iree_hal_rocm_allocator_query_buffer_compatibility(
+    iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_buffer_usage_t allowed_usage,
+    iree_hal_buffer_usage_t intended_usage,
+    iree_device_size_t allocation_size) {
+  // TODO(benvanik): check to ensure the allocator can serve the memory type.
+
+  // Every request is reported as allocatable.
+  iree_hal_buffer_compatibility_t result =
+      IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+
+  // Queue usage is only reported for device-visible memory.
+  if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+    return result;
+  }
+
+  // Only usage that is both allowed and intended contributes queue bits.
+  const iree_hal_buffer_usage_t usable = intended_usage & allowed_usage;
+  if (iree_all_bits_set(usable, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+    result |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+  }
+  if (iree_all_bits_set(usable, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
+    result |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
+  }
+  return result;
+}
+
+// Allocates at least 4 bytes and wraps them in a HAL buffer. Host-visible
+// requests get mapped host memory; everything else goes through hipMalloc.
+static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
+    iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
+    iree_hal_buffer_t **out_buffer) {
+  iree_hal_rocm_allocator_t *self =
+      iree_hal_rocm_allocator_cast(base_allocator);
+  // 0-byte requests do occur in practice; allocate 4 bytes for them instead.
+  if (allocation_size == 0) allocation_size = 4;
+  iree_status_t status;
+  void *host_ptr = NULL;
+  hipDeviceptr_t device_ptr = 0;
+  if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    status = ROCM_RESULT_TO_STATUS(self->context->syms,
+                                   hipMalloc(&device_ptr, allocation_size));
+  } else {
+    // Write-combined unless the caller asked for host-cached memory.
+    unsigned int flags =
+        iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)
+            ? hipHostMallocMapped
+            : (hipHostMallocMapped | hipHostMallocWriteCombined);
+    status = ROCM_RESULT_TO_STATUS(
+        self->context->syms,
+        hipMemAllocHost(&host_ptr, allocation_size, flags));
+    if (iree_status_is_ok(status)) {
+      status = ROCM_RESULT_TO_STATUS(
+          self->context->syms,
+          hipHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0));
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_rocm_buffer_wrap(
+        (iree_hal_allocator_t *)self, memory_type,
+        IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
+        /*byte_offset=*/0,
+        /*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_rocm_allocator_free(base_allocator, device_ptr, host_ptr,
+                                 memory_type);
+  }
+  return status;
+}
+
+// Frees |host_ptr| (host-visible types) or |device_ptr|; errors are ignored.
+void iree_hal_rocm_allocator_free(iree_hal_allocator_t *base_allocator,
+                                  hipDeviceptr_t device_ptr, void *host_ptr,
+                                  iree_hal_memory_type_t memory_type) {
+  iree_hal_rocm_allocator_t *allocator =
+      iree_hal_rocm_allocator_cast(base_allocator);
+  if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    ROCM_IGNORE_ERROR(allocator->context->syms, hipFree(device_ptr));
+    return;
+  }
+  ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
+}
+
+static iree_status_t iree_hal_rocm_allocator_wrap_buffer(
+ iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data,
+ iree_allocator_t data_allocator, iree_hal_buffer_t **out_buffer) {
+ return iree_make_status(IREE_STATUS_UNAVAILABLE,  // Unconditional: no argument is read.
+ "wrapping of external buffers not supported");
+}
+
+const iree_hal_allocator_vtable_t iree_hal_rocm_allocator_vtable = {  // Installed on each allocator by iree_hal_rocm_allocator_create.
+ .destroy = iree_hal_rocm_allocator_destroy,
+ .host_allocator = iree_hal_rocm_allocator_host_allocator,
+ .query_buffer_compatibility =
+ iree_hal_rocm_allocator_query_buffer_compatibility,
+ .allocate_buffer = iree_hal_rocm_allocator_allocate_buffer,
+ .wrap_buffer = iree_hal_rocm_allocator_wrap_buffer,  // Always returns UNAVAILABLE.
+};
diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h
new file mode 100644
index 0000000..e802684
--- /dev/null
+++ b/experimental/rocm/rocm_allocator.h
@@ -0,0 +1,40 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_ALLOCATOR_H_
+#define IREE_HAL_ROCM_ALLOCATOR_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Create a ROCM allocator.
+iree_status_t iree_hal_rocm_allocator_create(
+ iree_hal_rocm_context_wrapper_t *context,
+ iree_hal_allocator_t **out_allocator);
+
+// Frees the allocation represented by the given device or host pointer.
+void iree_hal_rocm_allocator_free(iree_hal_allocator_t *allocator,
+ hipDeviceptr_t device_ptr, void *host_ptr,
+ iree_hal_memory_type_t memory_type);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_ALLOCATOR_H_
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c
new file mode 100644
index 0000000..afb7810
--- /dev/null
+++ b/experimental/rocm/rocm_buffer.c
@@ -0,0 +1,140 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/rocm_buffer.h"
+
+#include "experimental/rocm/rocm_allocator.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+typedef struct iree_hal_rocm_buffer_s {
+ iree_hal_buffer_t base;  // Base HAL buffer fields; filled in by iree_hal_rocm_buffer_wrap.
+ void *host_ptr;  // Host address used by map_range; stored as passed to wrap.
+ hipDeviceptr_t device_ptr;  // Returned by iree_hal_rocm_buffer_device_pointer.
+} iree_hal_rocm_buffer_t;
+
+extern const iree_hal_buffer_vtable_t iree_hal_rocm_buffer_vtable;
+
+static iree_hal_rocm_buffer_t *iree_hal_rocm_buffer_cast(
+ iree_hal_buffer_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_buffer_vtable);  // Type check against the ROCM buffer vtable.
+ return (iree_hal_rocm_buffer_t *)base_value;  // Downcast from the base HAL buffer type.
+}
+
+// Wraps an existing ROCM allocation in a newly allocated iree_hal_buffer_t.
+// |device_ptr| and |host_ptr| are stored as-is. Returns the allocation status.
+iree_status_t iree_hal_rocm_buffer_wrap(
+    iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer) {
+  IREE_ASSERT_ARGUMENT(allocator);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_rocm_buffer_t *buffer = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
+                            sizeof(*buffer), (void **)&buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_rocm_buffer_vtable,
+                                 &buffer->base.resource);
+    buffer->base.allocator = allocator;
+    buffer->base.allocated_buffer = &buffer->base;
+    buffer->base.allocation_size = allocation_size;
+    buffer->base.byte_offset = byte_offset;
+    buffer->base.byte_length = byte_length;
+    buffer->base.memory_type = memory_type;
+    buffer->base.allowed_access = allowed_access;
+    buffer->base.allowed_usage = allowed_usage;
+    buffer->host_ptr = host_ptr;
+    buffer->device_ptr = device_ptr;
+    *out_buffer = &buffer->base;
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;  // Propagate malloc failure; |*out_buffer| is unset then.
+}
+
+// Frees the wrapped ROCM allocation, then the buffer struct itself.
+static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t *base_buffer) {
+  iree_hal_rocm_buffer_t *self = iree_hal_rocm_buffer_cast(base_buffer);
+  iree_allocator_t host_allocator =
+      iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Device/host memory goes first: freeing it reads fields stored in |self|.
+  iree_hal_rocm_allocator_free(self->base.allocator, self->device_ptr,
+                               self->host_ptr, self->base.memory_type);
+  iree_allocator_free(host_allocator, self);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Maps a host-visible buffer by returning |host_ptr| + |local_byte_offset|.
+// Returns IREE_STATUS_INTERNAL when the buffer memory is not host visible.
+static iree_status_t iree_hal_rocm_buffer_map_range(
+    iree_hal_buffer_t *base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    void **out_data_ptr) {
+  iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer);
+  if (!iree_all_bits_set(buffer->base.memory_type,
+                         IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "trying to map memory not host visible");
+  }
+  uint8_t *data_ptr = (uint8_t *)(buffer->host_ptr) + local_byte_offset;
+  // If we mapped for discard scribble over the bytes. This is not a mandated
+  // behavior but it will make debugging issues easier. |data_ptr| already
+  // includes |local_byte_offset|, so the fill must start at |data_ptr| itself
+  // or it would land outside the mapped range.
+#ifndef NDEBUG
+  if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
+    memset(data_ptr, 0xCD, local_byte_length);
+  }
+#endif  // !NDEBUG
+  *out_data_ptr = data_ptr;
+  return iree_ok_status();
+}
+
+static void iree_hal_rocm_buffer_unmap_range(
+ iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length, void *data_ptr) {
+ // Nothing to do: map_range only computes a pointer and keeps no state.
+}
+
+static iree_status_t iree_hal_rocm_buffer_invalidate_range(
+ iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length) {
+ // Nothing to do: no work is performed here and OK is always returned.
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_buffer_flush_range(
+ iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
+ iree_device_size_t local_byte_length) {
+ // Nothing to do: no work is performed here and OK is always returned.
+ return iree_ok_status();
+}
+
+// Returns the stored device pointer as-is; byte_offset is not applied here.
+void **iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *base_buffer) {
+  return iree_hal_rocm_buffer_cast(base_buffer)->device_ptr;
+}
+
+const iree_hal_buffer_vtable_t iree_hal_rocm_buffer_vtable = {  // Installed on each buffer by iree_hal_rocm_buffer_wrap.
+ .destroy = iree_hal_rocm_buffer_destroy,
+ .map_range = iree_hal_rocm_buffer_map_range,
+ .unmap_range = iree_hal_rocm_buffer_unmap_range,  // No-op.
+ .invalidate_range = iree_hal_rocm_buffer_invalidate_range,  // No-op, returns OK.
+ .flush_range = iree_hal_rocm_buffer_flush_range,  // No-op, returns OK.
+};
diff --git a/experimental/rocm/rocm_buffer.h b/experimental/rocm/rocm_buffer.h
new file mode 100644
index 0000000..f8a2a1c
--- /dev/null
+++ b/experimental/rocm/rocm_buffer.h
@@ -0,0 +1,42 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_BUFFER_H_
+#define IREE_HAL_ROCM_BUFFER_H_
+
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Wraps a rocm allocation in an iree_hal_buffer_t.
+iree_status_t iree_hal_rocm_buffer_wrap(
+ iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type,
+ iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+ iree_device_size_t byte_offset, iree_device_size_t byte_length,
+ hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer);
+
+// Returns the rocm base pointer for the given |buffer|.
+// This is the entire allocated_buffer and must be offset by the buffer
+// byte_offset and byte_length when used.
+void **iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *buffer);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_BUFFER_H_
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
new file mode 100644
index 0000000..a110fd3
--- /dev/null
+++ b/experimental/rocm/rocm_device.c
@@ -0,0 +1,298 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/rocm_device.h"
+
+#include "experimental/rocm/api.h"
+#include "experimental/rocm/descriptor_set_layout.h"
+#include "experimental/rocm/direct_command_buffer.h"
+#include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/event_semaphore.h"
+#include "experimental/rocm/executable_layout.h"
+#include "experimental/rocm/nop_executable_cache.h"
+#include "experimental/rocm/rocm_allocator.h"
+#include "experimental/rocm/rocm_event.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_rocm_device_t
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+ iree_hal_resource_t resource;  // Initialized with the device vtable in create_internal.
+ iree_string_view_t identifier;  // Set by iree_string_view_append_to_buffer into storage after this struct.
+
+ // Optional driver that owns the ROCM symbols. We retain it for our lifetime
+ // to ensure the symbols remain valid.
+ iree_hal_driver_t *driver;
+
+ hipDevice_t device;  // The device handle passed to iree_hal_rocm_device_create.
+
+ // TODO: support multiple streams.
+ hipStream_t stream;  // Destroyed in iree_hal_rocm_device_destroy.
+ iree_hal_rocm_context_wrapper_t context_wrapper;  // Holds rocm_context, host_allocator and syms.
+ iree_hal_allocator_t *device_allocator;  // Created in create_internal; released on destroy.
+
+} iree_hal_rocm_device_t;
+
+extern const iree_hal_device_vtable_t iree_hal_rocm_device_vtable;
+
+static iree_hal_rocm_device_t *iree_hal_rocm_device_cast(
+ iree_hal_device_t *base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_device_vtable);  // Type check against the ROCM device vtable.
+ return (iree_hal_rocm_device_t *)base_value;  // Downcast from the base HAL device type.
+}
+
+// Tears down a device: allocator, stream, driver reference, then the struct.
+// NOTE(review): context_wrapper.rocm_context is not destroyed here; confirm
+// whether the hipCtx made in iree_hal_rocm_device_create is leaked.
+static void iree_hal_rocm_device_destroy(iree_hal_device_t *base_device) {
+  iree_hal_rocm_device_t *self = iree_hal_rocm_device_cast(base_device);
+  iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // There should be no more buffers live that use the allocator.
+  iree_hal_allocator_release(self->device_allocator);
+  ROCM_IGNORE_ERROR(self->context_wrapper.syms,
+                    hipStreamDestroy(self->stream));
+
+  // The driver owns the ROCM symbols, so release it after their last use.
+  iree_hal_driver_release(self->driver);
+  iree_allocator_free(host_allocator, self);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Builds the device struct. |stream|/|context| stay caller-owned on failure.
+static iree_status_t iree_hal_rocm_device_create_internal(
+    iree_hal_driver_t *driver, iree_string_view_t identifier,
+    hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context,
+    iree_hal_rocm_dynamic_symbols_t *syms, iree_allocator_t host_allocator,
+    iree_hal_device_t **out_device) {
+  iree_hal_rocm_device_t *device = NULL;
+  iree_host_size_t total_size = sizeof(*device) + identifier.size;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void **)&device));
+  memset(device, 0, total_size);
+  iree_hal_resource_initialize(&iree_hal_rocm_device_vtable, &device->resource);
+  device->driver = driver;
+  iree_string_view_append_to_buffer(  // Characters are stored after the struct.
+      identifier, &device->identifier, (char *)device + sizeof(*device));
+  device->device = rocm_device;
+  device->stream = stream;
+  device->context_wrapper.rocm_context = context;
+  device->context_wrapper.host_allocator = host_allocator;
+  device->context_wrapper.syms = syms;
+  iree_status_t status = iree_hal_rocm_allocator_create(
+      &device->context_wrapper, &device->device_allocator);
+  if (iree_status_is_ok(status)) {
+    iree_hal_driver_retain(device->driver);  // Balanced by the destroy callback.
+    *out_device = (iree_hal_device_t *)device;
+  } else {  // Plain free: destroy would also destroy the caller's stream.
+    iree_allocator_free(host_allocator, device);
+  }
+  return status;
+}
+
+// Creates a hipCtx and a non-blocking stream, then the HAL device using them.
+iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver,
+                                          iree_string_view_t identifier,
+                                          iree_hal_rocm_dynamic_symbols_t *syms,
+                                          hipDevice_t device,
+                                          iree_allocator_t host_allocator,
+                                          iree_hal_device_t **out_device) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  hipCtx_t context = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, ROCM_RESULT_TO_STATUS(syms, hipCtxCreate(&context, 0, device)));
+  hipStream_t stream = NULL;  // Must start NULL for the cleanup check below.
+  iree_status_t status = ROCM_RESULT_TO_STATUS(
+      syms, hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_rocm_device_create_internal(driver, identifier, device,
+                                                  stream, context, syms,
+                                                  host_allocator, out_device);
+  }
+  if (!iree_status_is_ok(status)) {  // Undo whatever was created above.
+    if (stream) {
+      syms->hipStreamDestroy(stream);
+    }
+    syms->hipCtxDestroy(context);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Returns the identifier stored on the device at creation.
+static iree_string_view_t iree_hal_rocm_device_id(
+    iree_hal_device_t *base_device) {
+  return iree_hal_rocm_device_cast(base_device)->identifier;
+}
+
+// Returns the host allocator recorded in the device's context wrapper.
+static iree_allocator_t iree_hal_rocm_device_host_allocator(
+    iree_hal_device_t *base_device) {
+  return iree_hal_rocm_device_cast(base_device)->context_wrapper.host_allocator;
+}
+
+// Returns the device allocator created alongside the device (not retained).
+static iree_hal_allocator_t *iree_hal_rocm_device_allocator(
+    iree_hal_device_t *base_device) {
+  return iree_hal_rocm_device_cast(base_device)->device_allocator;
+}
+
+static iree_status_t iree_hal_rocm_device_query_i32(
+    iree_hal_device_t *base_device, iree_string_view_t key,
+    int32_t *out_value) {
+  // No keys are supported yet: zero |out_value| and report NOT_FOUND for |key|.
+  *out_value = 0;
+  return iree_make_status(IREE_STATUS_NOT_FOUND,
+                          "unknown device configuration key value '%.*s'",
+                          (int)key.size, key.data);
+}
+
+// Allocates a direct command buffer using this device's context wrapper.
+static iree_status_t iree_hal_rocm_device_create_command_buffer(
+    iree_hal_device_t *base_device, iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_command_buffer_t **out_command_buffer) {
+  return iree_hal_rocm_direct_command_buffer_allocate(
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper, mode,
+      command_categories, queue_affinity, out_command_buffer);
+}
+
+static iree_status_t iree_hal_rocm_device_create_descriptor_set(
+ iree_hal_device_t *base_device,
+ iree_hal_descriptor_set_layout_t *set_layout,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_binding_t *bindings,
+ iree_hal_descriptor_set_t **out_descriptor_set) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,  // Unconditional: no argument is read.
+ "non-push descriptor sets still need work");
+}
+
+// Creates a descriptor set layout using this device's context wrapper.
+static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout(
+    iree_hal_device_t *base_device,
+    iree_hal_descriptor_set_layout_usage_type_t usage_type,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_layout_binding_t *bindings,
+    iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) {
+  return iree_hal_rocm_descriptor_set_layout_create(
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper, usage_type,
+      binding_count, bindings, out_descriptor_set_layout);
+}
+
+static iree_status_t iree_hal_rocm_device_create_event(
+    iree_hal_device_t *base_device, iree_hal_event_t **out_event) {
+  return iree_hal_rocm_event_create(  // Uses the device's context wrapper.
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper, out_event);
+}
+
+static iree_status_t iree_hal_rocm_device_create_executable_cache(
+    iree_hal_device_t *base_device, iree_string_view_t identifier,
+    iree_hal_executable_cache_t **out_executable_cache) {
+  return iree_hal_rocm_nop_executable_cache_create(  // The "nop" cache variant.
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper, identifier,
+      out_executable_cache);
+}
+
+// Creates an executable layout; note |push_constants| is forwarded last.
+static iree_status_t iree_hal_rocm_device_create_executable_layout(
+    iree_hal_device_t *base_device, iree_host_size_t push_constants,
+    iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t **set_layouts,
+    iree_hal_executable_layout_t **out_executable_layout) {
+  return iree_hal_rocm_executable_layout_create(
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper,
+      set_layout_count, set_layouts, push_constants, out_executable_layout);
+}
+
+static iree_status_t iree_hal_rocm_device_create_semaphore(
+    iree_hal_device_t *base_device, uint64_t initial_value,
+    iree_hal_semaphore_t **out_semaphore) {
+  return iree_hal_rocm_semaphore_create(  // Uses the device's context wrapper.
+      &iree_hal_rocm_device_cast(base_device)->context_wrapper, initial_value,
+      out_semaphore);
+}
+
+static iree_status_t iree_hal_rocm_device_queue_submit(
+ iree_hal_device_t *base_device,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count,
+ const iree_hal_submission_batch_t *batches) {
+ iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+ // NOTE: |batches| is not read here; the only action is a sync of stream 0.
+ // TODO(raikonenfnu): Once semaphore is implemented wait for semaphores
+ // TODO(thomasraoux): Conservatively synchronize after every submit until we
+ // support semaphores.
+ // TODO(raikonenfnu): sync device->stream once cmd buffers stop using stream 0.
+ ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0),
+ "hipStreamSynchronize");
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_rocm_device_submit_and_wait(
+ iree_hal_device_t *base_device,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count,
+ const iree_hal_submission_batch_t *batches,
+ iree_hal_semaphore_t *wait_semaphore, uint64_t wait_value,
+ iree_timeout_t timeout) {
+ // Submit; an error from the submission is returned without waiting.
+ IREE_RETURN_IF_ERROR(iree_hal_rocm_device_queue_submit(
+ base_device, command_categories, queue_affinity, batch_count, batches));
+
+ // ...and wait on |wait_semaphore| via iree_hal_semaphore_wait.
+ return iree_hal_semaphore_wait(wait_semaphore, wait_value, timeout);
+}
+
+static iree_status_t iree_hal_rocm_device_wait_semaphores(
+ iree_hal_device_t *base_device, iree_hal_wait_mode_t wait_mode,
+ const iree_hal_semaphore_list_t *semaphore_list, iree_timeout_t timeout) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,  // Unconditional: no argument is read.
+ "semaphore not implemented");
+}
+
+static iree_status_t iree_hal_rocm_device_wait_idle(
+ iree_hal_device_t *base_device, iree_timeout_t timeout) {
+ iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+ // Synchronizes device->stream; |timeout| is not consulted.
+ // TODO(thomasraoux): HIP doesn't support a deadline for wait, figure out how
+ // to handle it better.
+ ROCM_RETURN_IF_ERROR(device->context_wrapper.syms,
+ hipStreamSynchronize(device->stream),
+ "hipStreamSynchronize");
+ return iree_ok_status();
+}
+
+const iree_hal_device_vtable_t iree_hal_rocm_device_vtable = {  // Installed on each device by create_internal.
+ .destroy = iree_hal_rocm_device_destroy,
+ .id = iree_hal_rocm_device_id,
+ .host_allocator = iree_hal_rocm_device_host_allocator,
+ .device_allocator = iree_hal_rocm_device_allocator,
+ .query_i32 = iree_hal_rocm_device_query_i32,  // Always NOT_FOUND.
+ .create_command_buffer = iree_hal_rocm_device_create_command_buffer,
+ .create_descriptor_set = iree_hal_rocm_device_create_descriptor_set,  // Always UNIMPLEMENTED.
+ .create_descriptor_set_layout =
+ iree_hal_rocm_device_create_descriptor_set_layout,
+ .create_event = iree_hal_rocm_device_create_event,
+ .create_executable_cache = iree_hal_rocm_device_create_executable_cache,
+ .create_executable_layout = iree_hal_rocm_device_create_executable_layout,
+ .create_semaphore = iree_hal_rocm_device_create_semaphore,
+ .queue_submit = iree_hal_rocm_device_queue_submit,
+ .submit_and_wait = iree_hal_rocm_device_submit_and_wait,
+ .wait_semaphores = iree_hal_rocm_device_wait_semaphores,  // Always UNIMPLEMENTED.
+ .wait_idle = iree_hal_rocm_device_wait_idle,
+};
diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h
new file mode 100644
index 0000000..e3504c6
--- /dev/null
+++ b/experimental/rocm/rocm_device.h
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_ROCM_DEVICE_H_
+#define IREE_HAL_ROCM_ROCM_DEVICE_H_
+
+#include "experimental/rocm/api.h"
+#include "experimental/rocm/dynamic_symbols.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a device that owns and manages its own hipContext.
+iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver,
+ iree_string_view_t identifier,
+ iree_hal_rocm_dynamic_symbols_t *syms,
+ hipDevice_t device,
+ iree_allocator_t host_allocator,
+ iree_hal_device_t **out_device);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_ROCM_DEVICE_H_
diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c
new file mode 100644
index 0000000..219c6a3
--- /dev/null
+++ b/experimental/rocm/rocm_driver.c
@@ -0,0 +1,211 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/api.h"
+#include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/rocm_device.h"
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// State of the ROCM HAL driver. Allocated as a single block by
+// iree_hal_rocm_driver_create_internal with the identifier characters stored
+// in the bytes that follow the struct.
+typedef struct {
+ // Base HAL resource; initialized with iree_hal_rocm_driver_vtable.
+ iree_hal_resource_t resource;
+ // Allocator the driver was allocated from; used again to free it on destroy.
+ iree_allocator_t host_allocator;
+ // Identifier used for the driver in the IREE driver registry.
+ // We allow overriding so that multiple ROCM versions can be exposed in the
+ // same process.
+ iree_string_view_t identifier;
+ // Index of the device to create when create_device is called with device
+ // id 0; copied from the driver options at creation time.
+ int default_device_index;
+ // ROCM symbols.
+ iree_hal_rocm_dynamic_symbols_t syms;
+} iree_hal_rocm_driver_t;
+
+// Fixed maximum length (in bytes) reserved for each device name.
+#define IREE_MAX_ROCM_DEVICE_NAME_LENGTH 100
+
+// Defined at the bottom of this file; declared here so the cast helper can
+// type-check against it.
+extern const iree_hal_driver_vtable_t iree_hal_rocm_driver_vtable;
+
+// Converts a generic HAL driver pointer into the ROCM driver it refers to.
+// The pointer is checked against the ROCM driver vtable with
+// IREE_HAL_ASSERT_TYPE before being reinterpreted.
+static iree_hal_rocm_driver_t *iree_hal_rocm_driver_cast(
+    iree_hal_driver_t *base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_driver_vtable);
+  iree_hal_rocm_driver_t *rocm_driver = (iree_hal_rocm_driver_t *)base_value;
+  return rocm_driver;
+}
+
+// Resets |out_options| to the driver defaults: every byte is cleared and the
+// default device index is set to 0.
+IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize(
+    iree_hal_rocm_driver_options_t *out_options) {
+  // Clear the whole struct first so every field starts out as zero.
+  memset(out_options, 0, sizeof(iree_hal_rocm_driver_options_t));
+  // Explicitly select the first device as the default.
+  out_options->default_device_index = 0;
+}
+
+// Allocates a ROCM driver, copies |identifier| and the relevant |options|
+// into it and loads the ROCM dynamic symbols.
+//
+// The driver struct and the identifier characters share one allocation made
+// from |host_allocator|. On success the driver is returned through
+// |out_driver|; if symbol loading fails the driver is released and the
+// failure status is returned.
+static iree_status_t iree_hal_rocm_driver_create_internal(
+    iree_string_view_t identifier,
+    const iree_hal_rocm_driver_options_t *options,
+    iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) {
+  iree_hal_rocm_driver_t *rocm_driver = NULL;
+  // Reserve room after the struct for the identifier characters.
+  const iree_host_size_t struct_size = sizeof(*rocm_driver);
+  const iree_host_size_t allocation_size = struct_size + identifier.size;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, allocation_size,
+                                             (void **)&rocm_driver));
+
+  iree_hal_resource_initialize(&iree_hal_rocm_driver_vtable,
+                               &rocm_driver->resource);
+  rocm_driver->host_allocator = host_allocator;
+
+  // The identifier storage begins immediately after the struct.
+  char *identifier_storage = (char *)rocm_driver + struct_size;
+  iree_string_view_append_to_buffer(identifier, &rocm_driver->identifier,
+                                    identifier_storage);
+  rocm_driver->default_device_index = options->default_device_index;
+
+  iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(
+      host_allocator, &rocm_driver->syms);
+  if (!iree_status_is_ok(status)) {
+    // Give up the only reference so the partially created driver is cleaned up.
+    iree_hal_driver_release((iree_hal_driver_t *)rocm_driver);
+    return status;
+  }
+
+  *out_driver = (iree_hal_driver_t *)rocm_driver;
+  return status;
+}
+
+// Destroys the driver: unloads the ROCM dynamic symbols and frees the driver
+// allocation.
+static void iree_hal_rocm_driver_destroy(iree_hal_driver_t *base_driver) {
+  iree_hal_rocm_driver_t *rocm_driver = iree_hal_rocm_driver_cast(base_driver);
+  // Copy the allocator out of the struct; the struct itself is freed below.
+  iree_allocator_t allocator = rocm_driver->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_rocm_dynamic_symbols_deinitialize(&rocm_driver->syms);
+  iree_allocator_free(allocator, rocm_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Public entry point for creating a ROCM driver.
+//
+// Asserts that |options| and |out_driver| are provided and then forwards to
+// iree_hal_rocm_driver_create_internal inside a trace zone. Returns the
+// status produced by the internal creation routine.
+IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create(
+    iree_string_view_t identifier,
+    const iree_hal_rocm_driver_options_t *options,
+    iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) {
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(out_driver);
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t create_status = iree_hal_rocm_driver_create_internal(
+      identifier, options, host_allocator, out_driver);
+  IREE_TRACE_ZONE_END(z0);
+
+  return create_status;
+}
+
+// Populates device information from the given ROCM physical device handle.
+// |out_device_info| must point to valid memory and additional data will be
+// appended to |buffer_ptr| and the new pointer is returned.
+//
+// The device name is queried with hipDeviceGetName and any error from that
+// call is ignored; in that case the reported name is the empty string. The
+// caller must provide at least IREE_MAX_ROCM_DEVICE_NAME_LENGTH bytes at
+// |buffer_ptr|.
+static uint8_t *iree_hal_rocm_populate_device_info(
+    hipDevice_t device, iree_hal_rocm_dynamic_symbols_t *syms,
+    uint8_t *buffer_ptr, iree_hal_device_info_t *out_device_info) {
+  // Zero-initialize so the buffer holds a valid empty string if the (ignored)
+  // name query fails without writing anything.
+  char device_name[IREE_MAX_ROCM_DEVICE_NAME_LENGTH] = {0};
+  ROCM_IGNORE_ERROR(syms,
+                    hipDeviceGetName(device_name, sizeof(device_name), device));
+  // Guarantee termination even if the runtime filled the entire buffer.
+  device_name[sizeof(device_name) - 1] = '\0';
+
+  memset(out_device_info, 0, sizeof(*out_device_info));
+  out_device_info->device_id = (iree_hal_device_id_t)device;
+
+  iree_string_view_t device_name_string =
+      iree_make_string_view(device_name, strlen(device_name));
+  buffer_ptr += iree_string_view_append_to_buffer(
+      device_name_string, &out_device_info->name, (char *)buffer_ptr);
+  return buffer_ptr;
+}
+
+// Enumerates the ROCM devices visible to the driver.
+//
+// On success |out_device_infos| receives one allocation made from
+// |host_allocator| that holds |out_device_info_count| info structs followed by
+// the storage for their name strings. If the device count query, the
+// allocation, or any per-device hipDeviceGet call fails, the allocation is
+// freed, the outputs are left untouched and the failure status is returned.
+static iree_status_t iree_hal_rocm_driver_query_available_devices(
+    iree_hal_driver_t *base_driver, iree_allocator_t host_allocator,
+    iree_hal_device_info_t **out_device_infos,
+    iree_host_size_t *out_device_info_count) {
+  iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver);
+  // Query the number of available ROCM devices.
+  int device_count = 0;
+  ROCM_RETURN_IF_ERROR(&driver->syms, hipGetDeviceCount(&device_count),
+                       "hipGetDeviceCount");
+
+  // Allocate the return infos: the info structs come first, followed by one
+  // fixed-size name slot per device.
+  iree_hal_device_info_t *device_infos = NULL;
+  iree_host_size_t infos_size = device_count * sizeof(iree_hal_device_info_t);
+  iree_host_size_t total_size =
+      infos_size +
+      device_count * (IREE_MAX_ROCM_DEVICE_NAME_LENGTH * sizeof(char));
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, total_size, (void **)&device_infos);
+  if (iree_status_is_ok(status)) {
+    uint8_t *buffer_ptr = (uint8_t *)device_infos + infos_size;
+    for (int i = 0; i < device_count; ++i) {
+      hipDevice_t device;
+      // Assign to the function-level status (no shadowing) so a failure here
+      // is reported to the caller instead of being silently dropped.
+      status = ROCM_RESULT_TO_STATUS(&driver->syms, hipDeviceGet(&device, i),
+                                     "hipDeviceGet");
+      if (!iree_status_is_ok(status)) break;
+      buffer_ptr = iree_hal_rocm_populate_device_info(
+          device, &driver->syms, buffer_ptr, &device_infos[i]);
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_device_info_count = device_count;
+    *out_device_infos = device_infos;
+  } else {
+    iree_allocator_free(host_allocator, device_infos);
+  }
+  return status;
+}
+
+// Resolves |default_device_index| to a HIP device handle.
+//
+// Returns NOT_FOUND when no devices are enumerated or when the index is
+// negative or not smaller than the enumerated device count. Failures of the
+// underlying hipGetDeviceCount/hipDeviceGet calls are returned as converted
+// statuses. |out_device| is only written on success. |host_allocator| is
+// currently unused.
+static iree_status_t iree_hal_rocm_driver_select_default_device(
+    iree_hal_rocm_dynamic_symbols_t *syms, int default_device_index,
+    iree_allocator_t host_allocator, hipDevice_t *out_device) {
+  (void)host_allocator;
+
+  int device_count = 0;
+  ROCM_RETURN_IF_ERROR(syms, hipGetDeviceCount(&device_count),
+                       "hipGetDeviceCount");
+
+  // Reject out-of-range indices (including negative ones) before handing them
+  // to the runtime. A count of zero is covered by the upper-bound check.
+  if (default_device_index < 0 || default_device_index >= device_count) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "default device %d not found (of %d enumerated)",
+                            default_device_index, device_count);
+  }
+
+  hipDevice_t device;
+  ROCM_RETURN_IF_ERROR(syms, hipDeviceGet(&device, default_device_index),
+                       "hipDeviceGet");
+  *out_device = device;
+  return iree_ok_status();
+}
+
+// Creates a HAL device for |device_id| after initializing the HIP runtime.
+//
+// A |device_id| of 0 selects the driver's default device (the index given in
+// the driver options); any other value is used directly as the HIP device
+// handle. The created device is named "rocm" and returned via |out_device|.
+// NOTE(review): 0 doubles as the "unspecified" marker here; if 0 is also a
+// valid HIP device handle it cannot be requested explicitly when the default
+// index is non-zero - confirm.
+static iree_status_t iree_hal_rocm_driver_create_device(
+    iree_hal_driver_t *base_driver, iree_hal_device_id_t device_id,
+    iree_allocator_t host_allocator, iree_hal_device_t **out_device) {
+  iree_hal_rocm_driver_t *rocm_driver = iree_hal_rocm_driver_cast(base_driver);
+  iree_hal_rocm_dynamic_symbols_t *syms = &rocm_driver->syms;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The runtime must be initialized before any device can be looked up.
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, ROCM_RESULT_TO_STATUS(syms, hipInit(0), "hipInit"));
+
+  // Prefer the caller-provided device; fall back to the configured default.
+  hipDevice_t selected_device = (hipDevice_t)device_id;
+  if (selected_device == 0) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_rocm_driver_select_default_device(
+                syms, rocm_driver->default_device_index, host_allocator,
+                &selected_device));
+  }
+
+  // Attempt to create the device.
+  iree_status_t create_status = iree_hal_rocm_device_create(
+      base_driver, iree_make_cstring_view("rocm"), syms, selected_device,
+      host_allocator, out_device);
+
+  IREE_TRACE_ZONE_END(z0);
+  return create_status;
+}
+
+// Dispatch table binding the generic HAL driver interface to the ROCM
+// implementations defined in this file. Its address is also what
+// iree_hal_rocm_driver_cast checks driver pointers against.
+const iree_hal_driver_vtable_t iree_hal_rocm_driver_vtable = {
+ .destroy = iree_hal_rocm_driver_destroy,
+ .query_available_devices = iree_hal_rocm_driver_query_available_devices,
+ .create_device = iree_hal_rocm_driver_create_device,
+};
diff --git a/experimental/rocm/rocm_event.c b/experimental/rocm/rocm_event.c
new file mode 100644
index 0000000..a496ba4
--- /dev/null
+++ b/experimental/rocm/rocm_event.c
@@ -0,0 +1,67 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/rocm_event.h"
+
+#include "experimental/rocm/status_util.h"
+#include "iree/base/tracing.h"
+
+// Dummy events for now, don't do anything.
+// The struct only carries what is needed to create and destroy the object.
+typedef struct {
+ // Base HAL resource; initialized with iree_hal_rocm_event_vtable.
+ iree_hal_resource_t resource;
+ // Context the event was created with; its host_allocator is used to allocate
+ // and free the event. Stored as a plain pointer.
+ iree_hal_rocm_context_wrapper_t *context_wrapper;
+} iree_hal_rocm_event_t;
+
+// Defined at the bottom of this file; declared here so the cast helper can
+// type-check against it.
+extern const iree_hal_event_vtable_t iree_hal_rocm_event_vtable;
+
+// Converts a generic HAL event pointer into the ROCM event it refers to.
+// The pointer is checked against the ROCM event vtable with
+// IREE_HAL_ASSERT_TYPE before being reinterpreted.
+static iree_hal_rocm_event_t *iree_hal_rocm_event_cast(
+    iree_hal_event_t *base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_event_vtable);
+  iree_hal_rocm_event_t *rocm_event = (iree_hal_rocm_event_t *)base_value;
+  return rocm_event;
+}
+
+// Allocates a placeholder event from the context's host allocator.
+//
+// |out_event| is reset to NULL up front and only set once the allocation has
+// succeeded; the allocation status is returned either way.
+iree_status_t iree_hal_rocm_event_create(
+    iree_hal_rocm_context_wrapper_t *context_wrapper,
+    iree_hal_event_t **out_event) {
+  IREE_ASSERT_ARGUMENT(context_wrapper);
+  IREE_ASSERT_ARGUMENT(out_event);
+  *out_event = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_rocm_event_t *rocm_event = NULL;
+  iree_status_t alloc_status =
+      iree_allocator_malloc(context_wrapper->host_allocator,
+                            sizeof(iree_hal_rocm_event_t), (void **)&rocm_event);
+  if (!iree_status_is_ok(alloc_status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return alloc_status;
+  }
+
+  // Remember the context so destroy can find the allocator again.
+  iree_hal_resource_initialize(&iree_hal_rocm_event_vtable,
+                               &rocm_event->resource);
+  rocm_event->context_wrapper = context_wrapper;
+  *out_event = (iree_hal_event_t *)rocm_event;
+
+  IREE_TRACE_ZONE_END(z0);
+  return alloc_status;
+}
+
+// Frees the event using the host allocator of the context it was created
+// with.
+static void iree_hal_rocm_event_destroy(iree_hal_event_t *base_event) {
+  iree_hal_rocm_event_t *rocm_event = iree_hal_rocm_event_cast(base_event);
+  // Fetch the allocator before the event memory is released.
+  iree_allocator_t allocator = rocm_event->context_wrapper->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_allocator_free(allocator, rocm_event);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Dispatch table for ROCM events; only destruction is provided. Its address
+// is also what iree_hal_rocm_event_cast checks event pointers against.
+const iree_hal_event_vtable_t iree_hal_rocm_event_vtable = {
+ .destroy = iree_hal_rocm_event_destroy,
+};
diff --git a/experimental/rocm/rocm_event.h b/experimental/rocm/rocm_event.h
new file mode 100644
index 0000000..d97d023
--- /dev/null
+++ b/experimental/rocm/rocm_event.h
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_EVENT_H_
+#define IREE_HAL_ROCM_EVENT_H_
+
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/rocm_headers.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a dummy event object. No ROCM-side resource is created: the event
+// is a host allocation (made from the host allocator of |context_wrapper|)
+// that only remembers the context it was created with.
+// NOTE(review): how signaling and waiting on these events should map onto
+// ROCM synchronization is not implemented yet - confirm before relying on
+// event semantics.
+iree_status_t iree_hal_rocm_event_create(
+ iree_hal_rocm_context_wrapper_t *context_wrapper,
+ iree_hal_event_t **out_event);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_EVENT_H_
diff --git a/experimental/rocm/rocm_headers.h b/experimental/rocm/rocm_headers.h
new file mode 100644
index 0000000..866d0ac
--- /dev/null
+++ b/experimental/rocm/rocm_headers.h
@@ -0,0 +1,20 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_ROCM_HEADERS_H_
+#define IREE_HAL_ROCM_ROCM_HEADERS_H_
+
+#include "hip/hip_runtime.h"
+
+#endif // IREE_HAL_ROCM_ROCM_HEADERS_H_
diff --git a/experimental/rocm/status_util.c b/experimental/rocm/status_util.c
new file mode 100644
index 0000000..31304e5
--- /dev/null
+++ b/experimental/rocm/status_util.c
@@ -0,0 +1,38 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "experimental/rocm/status_util.h"
+
+#include "experimental/rocm/dynamic_symbols.h"
+
+// Converts a hipError_t into an iree_status_t.
+//
+// Returns OK for hipSuccess. Any other |result| becomes an INTERNAL status
+// whose message carries the error name, numeric value and description
+// obtained through |syms|. For hipErrorUnknown, or when the runtime returns a
+// NULL string, generic placeholder text is used instead.
+// |file| and |line| are accepted for the caller's source location but are
+// currently unused.
+iree_status_t iree_hal_rocm_result_to_status(
+    iree_hal_rocm_dynamic_symbols_t *syms, hipError_t result, const char *file,
+    uint32_t line) {
+  (void)file;
+  (void)line;
+
+  if (IREE_LIKELY(result == hipSuccess)) {
+    return iree_ok_status();
+  }
+
+  // Never hand a NULL pointer to the "%s" conversions below.
+  const char *error_name = syms->hipGetErrorName(result);
+  if (result == hipErrorUnknown || !error_name) {
+    error_name = "UNKNOWN";
+  }
+
+  const char *error_string = syms->hipGetErrorString(result);
+  if (result == hipErrorUnknown || !error_string) {
+    error_string = "Unknown error.";
+  }
+
+  // Cast the enum so the argument type matches the "%d" conversion.
+  return iree_make_status(IREE_STATUS_INTERNAL,
+                          "rocm driver error '%s' (%d): %s", error_name,
+                          (int)result, error_string);
+}
diff --git a/experimental/rocm/status_util.h b/experimental/rocm/status_util.h
new file mode 100644
index 0000000..6a46462
--- /dev/null
+++ b/experimental/rocm/status_util.h
@@ -0,0 +1,60 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ROCM_STATUS_UTIL_H_
+#define IREE_HAL_ROCM_STATUS_UTIL_H_
+
+#include "experimental/rocm/dynamic_symbols.h"
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Converts a hipError_t to an iree_status_t.
+//
+// |syms| is a pointer to the ROCM symbol table and |expr| is a call to one of
+// its members; the macro evaluates it as (syms)->expr. Any trailing arguments
+// are accepted but not used.
+//
+// Usage:
+// iree_status_t status = ROCM_RESULT_TO_STATUS(syms, hipDoThing(...));
+#define ROCM_RESULT_TO_STATUS(syms, expr, ...) \
+ iree_hal_rocm_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR but implicitly converts the hipError_t return value to
+// a Status.
+//
+// |syms| is a pointer to the ROCM symbol table and |expr| is a call to one of
+// its members, evaluated as (syms)->expr. Trailing arguments are forwarded to
+// IREE_RETURN_IF_ERROR.
+//
+// Usage:
+// ROCM_RETURN_IF_ERROR(syms, hipDoThing(...), "message");
+#define ROCM_RETURN_IF_ERROR(syms, expr, ...) \
+ IREE_RETURN_IF_ERROR(iree_hal_rocm_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__), \
+ __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the hipError_t return value to a
+// Status.
+//
+// |syms| is a pointer to the ROCM symbol table and |expr| is a call to one of
+// its members, evaluated as (syms)->expr.
+//
+// Usage:
+// ROCM_IGNORE_ERROR(syms, hipDoThing(...));
+#define ROCM_IGNORE_ERROR(syms, expr) \
+ IREE_IGNORE_ERROR(iree_hal_rocm_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__))
+
+// Converts a hipError_t to a Status object.
+//
+// |syms| provides the ROCM entry points used to look up the error name and
+// description. |file| and |line| are the source location of the failing call
+// as supplied by the ROCM_* macros above (__FILE__ and __LINE__).
+iree_status_t iree_hal_rocm_result_to_status(
+ iree_hal_rocm_dynamic_symbols_t *syms, hipError_t result, const char *file,
+ uint32_t line);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_ROCM_STATUS_UTIL_H_