[HIP] Adds graph command buffer & descriptor set and pipeline layout (#15910)
Progress towards #15790
diff --git a/experimental/hip/CMakeLists.txt b/experimental/hip/CMakeLists.txt
index b94f73a..5efb844 100644
--- a/experimental/hip/CMakeLists.txt
+++ b/experimental/hip/CMakeLists.txt
@@ -25,6 +25,8 @@
"api.h"
SRCS
"api.h"
+ "graph_command_buffer.c"
+ "graph_command_buffer.h"
"hip_allocator.c"
"hip_allocator.h"
"hip_buffer.c"
@@ -38,6 +40,8 @@
"native_executable.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
+ "pipeline_layout.c"
+ "pipeline_layout.h"
INCLUDES
"${HIP_API_HEADERS_ROOT}"
DEPS
@@ -48,6 +52,8 @@
iree::base::internal::arena
iree::base::internal::flatcc::parsing
iree::hal
+ iree::hal::utils::collective_batch
+ iree::hal::utils::resource_set
iree::schemas::rocm_executable_def_c_fbs
COPTS
"-D__HIP_PLATFORM_HCC__=1"
diff --git a/experimental/hip/dynamic_symbol_tables.h b/experimental/hip/dynamic_symbol_tables.h
index 1563769..be36f53 100644
--- a/experimental/hip/dynamic_symbol_tables.h
+++ b/experimental/hip/dynamic_symbol_tables.h
@@ -15,6 +15,9 @@
IREE_HIP_PFN_DECL(hipDeviceGetUuid, hipUUID *, hipDevice_t)
IREE_HIP_PFN_DECL(hipDevicePrimaryCtxRelease, hipDevice_t)
IREE_HIP_PFN_DECL(hipDevicePrimaryCtxRetain, hipCtx_t *, hipDevice_t)
+IREE_HIP_PFN_DECL(hipDrvGraphAddMemcpyNode, hipGraphNode_t *, hipGraph_t,
+ const hipGraphNode_t *, size_t, const HIP_MEMCPY3D *,
+ hipCtx_t)
IREE_HIP_PFN_DECL(hipEventCreate, hipEvent_t *)
IREE_HIP_PFN_DECL(hipEventDestroy, hipEvent_t)
IREE_HIP_PFN_DECL(hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
@@ -30,6 +33,18 @@
// const char* instead of hipError_t so it uses a different macro.
IREE_HIP_PFN_STR_DECL(hipGetErrorName, hipError_t)
IREE_HIP_PFN_STR_DECL(hipGetErrorString, hipError_t)
+IREE_HIP_PFN_DECL(hipGraphAddEmptyNode, hipGraphNode_t *, hipGraph_t,
+ const hipGraphNode_t *, size_t)
+IREE_HIP_PFN_DECL(hipGraphAddKernelNode, hipGraphNode_t *, hipGraph_t,
+ const hipGraphNode_t *, size_t, const hipKernelNodeParams *)
+IREE_HIP_PFN_DECL(hipGraphAddMemsetNode, hipGraphNode_t *, hipGraph_t,
+ const hipGraphNode_t *, size_t, const hipMemsetParams *)
+IREE_HIP_PFN_DECL(hipGraphCreate, hipGraph_t *, unsigned int)
+IREE_HIP_PFN_DECL(hipGraphDestroy, hipGraph_t)
+IREE_HIP_PFN_DECL(hipGraphExecDestroy, hipGraphExec_t)
+IREE_HIP_PFN_DECL(hipGraphInstantiate, hipGraphExec_t *, hipGraph_t,
+ hipGraphNode_t *, char *, size_t)
+IREE_HIP_PFN_DECL(hipGraphLaunch, hipGraphExec_t, hipStream_t)
IREE_HIP_PFN_DECL(hipHostFree, void *)
IREE_HIP_PFN_DECL(hipHostGetDevicePointer, void **, void *, unsigned int)
IREE_HIP_PFN_DECL(hipHostMalloc, void **, size_t, unsigned int)
@@ -63,7 +78,6 @@
IREE_HIP_PFN_DECL(hipModuleLoadDataEx, hipModule_t *, const void *,
unsigned int, hipJitOption *, void **)
IREE_HIP_PFN_DECL(hipModuleUnload, hipModule_t)
-IREE_HIP_PFN_DECL(hipSetDevice, int)
IREE_HIP_PFN_DECL(hipStreamCreateWithFlags, hipStream_t *, unsigned int)
IREE_HIP_PFN_DECL(hipStreamDestroy, hipStream_t)
IREE_HIP_PFN_DECL(hipStreamSynchronize, hipStream_t)
diff --git a/experimental/hip/graph_command_buffer.c b/experimental/hip/graph_command_buffer.c
new file mode 100644
index 0000000..b3cb73c
--- /dev/null
+++ b/experimental/hip/graph_command_buffer.c
@@ -0,0 +1,683 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/graph_command_buffer.h"
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "experimental/hip/dynamic_symbols.h"
+#include "experimental/hip/hip_buffer.h"
+#include "experimental/hip/native_executable.h"
+#include "experimental/hip/pipeline_layout.h"
+#include "experimental/hip/status_util.h"
+#include "iree/base/api.h"
+#include "iree/hal/utils/resource_set.h"
+
+// The maximal number of HIP graph nodes that can run concurrently between
+// barriers.
+#define IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT 32
+
+// Command buffer implementation that directly records into HIP graphs.
+// The command buffer records the commands on the calling thread without
+// additional threading indirection.
+typedef struct iree_hal_hip_graph_command_buffer_t {
+ iree_hal_command_buffer_t base;
+ iree_allocator_t host_allocator;
+ const iree_hal_hip_dynamic_symbols_t* symbols;
+
+ // A resource set to maintain references to all resources used within the
+ // command buffer.
+ iree_hal_resource_set_t* resource_set;
+
+ // Staging arena used for host->device transfers.
+ // This is used for when we need HIP to be able to reference memory as it
+ // performs asynchronous operations.
+ iree_arena_allocator_t arena;
+
+ hipCtx_t hip_context;
+ // The HIP graph under construction.
+ hipGraph_t hip_graph;
+ hipGraphExec_t hip_exec;
+
+ // A node acting as a barrier for all commands added to the command buffer.
+ hipGraphNode_t hip_barrier_node;
+
+ // Nodes added to the command buffer after the last barrier.
+ hipGraphNode_t hip_graph_nodes[IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT];
+ iree_host_size_t graph_node_count;
+
+ int32_t push_constants[IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT];
+
+ // The current bound descriptor sets.
+ struct {
+ hipDeviceptr_t bindings[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT];
+ } descriptor_sets[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT];
+} iree_hal_hip_graph_command_buffer_t;
+
+static const iree_hal_command_buffer_vtable_t
+ iree_hal_hip_graph_command_buffer_vtable;
+
+static iree_hal_hip_graph_command_buffer_t*
+iree_hal_hip_graph_command_buffer_cast(iree_hal_command_buffer_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_graph_command_buffer_vtable);
+ return (iree_hal_hip_graph_command_buffer_t*)base_value;
+}
+
+iree_status_t iree_hal_hip_graph_command_buffer_create(
+ iree_hal_device_t* device,
+ const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipCtx_t context,
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_command_buffer_t** out_command_buffer) {
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_ASSERT_ARGUMENT(hip_symbols);
+ IREE_ASSERT_ARGUMENT(block_pool);
+ IREE_ASSERT_ARGUMENT(out_command_buffer);
+ *out_command_buffer = NULL;
+
+ if (binding_capacity > 0) {
+ // TODO(#10144): support indirect command buffers with binding tables.
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "indirect command buffers not yet implemented");
+ }
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_hip_graph_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer),
+ (void**)&command_buffer));
+
+ iree_hal_command_buffer_initialize(
+ device, mode, command_categories, queue_affinity, binding_capacity,
+ &iree_hal_hip_graph_command_buffer_vtable, &command_buffer->base);
+ command_buffer->host_allocator = host_allocator;
+ command_buffer->symbols = hip_symbols;
+ iree_arena_initialize(block_pool, &command_buffer->arena);
+ command_buffer->hip_context = context;
+ command_buffer->hip_graph = NULL;
+ command_buffer->hip_exec = NULL;
+ command_buffer->hip_barrier_node = NULL;
+ command_buffer->graph_node_count = 0;
+
+ iree_status_t status =
+ iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);
+
+ if (iree_status_is_ok(status)) {
+ *out_command_buffer = &command_buffer->base;
+ } else {
+ iree_hal_command_buffer_release(&command_buffer->base);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_hip_graph_command_buffer_destroy(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ iree_allocator_t host_allocator = command_buffer->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ if (command_buffer->hip_graph != NULL) {
+ IREE_HIP_IGNORE_ERROR(command_buffer->symbols,
+ hipGraphDestroy(command_buffer->hip_graph));
+ command_buffer->hip_graph = NULL;
+ }
+ if (command_buffer->hip_exec != NULL) {
+ IREE_HIP_IGNORE_ERROR(command_buffer->symbols,
+ hipGraphExecDestroy(command_buffer->hip_exec));
+ command_buffer->hip_exec = NULL;
+ }
+ command_buffer->hip_barrier_node = NULL;
+ command_buffer->graph_node_count = 0;
+
+ iree_hal_resource_set_free(command_buffer->resource_set);
+ iree_arena_deinitialize(&command_buffer->arena);
+ iree_allocator_free(host_allocator, command_buffer);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+bool iree_hal_hip_graph_command_buffer_isa(
+ iree_hal_command_buffer_t* command_buffer) {
+ return iree_hal_resource_is(&command_buffer->resource,
+ &iree_hal_hip_graph_command_buffer_vtable);
+}
+
+hipGraphExec_t iree_hal_hip_graph_command_buffer_handle(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ return command_buffer->hip_exec;
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_begin(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+
+ if (command_buffer->hip_graph != NULL) {
+ return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+ "command buffer cannot be re-recorded");
+ }
+
+ // Create a new empty graph to record into.
+ IREE_HIP_RETURN_IF_ERROR(
+ command_buffer->symbols,
+ hipGraphCreate(&command_buffer->hip_graph, /*flags=*/0),
+ "hipGraphCreate");
+
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_end(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_hip_graph_command_buffer_t* command_buffer =
+      iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+
+  // Reset state used during recording.
+  command_buffer->hip_barrier_node = NULL;
+  command_buffer->graph_node_count = 0;
+
+  // Compile the graph.
+  hipGraphNode_t error_node = NULL;
+  iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
+      command_buffer->symbols,
+      hipGraphInstantiate(&command_buffer->hip_exec, command_buffer->hip_graph,
+                          &error_node,
+                          /*logBuffer=*/NULL,
+                          /*bufferSize=*/0));
+  if (iree_status_is_ok(status)) {
+    // No longer need the source graph used for construction.
+    IREE_HIP_IGNORE_ERROR(command_buffer->symbols,
+                          hipGraphDestroy(command_buffer->hip_graph));
+    command_buffer->hip_graph = NULL;
+  }
+
+  iree_hal_resource_set_freeze(command_buffer->resource_set);
+
+  return status;  // propagate hipGraphInstantiate failures to the caller
+}
+
+static void iree_hal_hip_graph_command_buffer_begin_debug_group(
+ iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
+ iree_hal_label_color_t label_color,
+ const iree_hal_label_location_t* location) {
+ // TODO: tracy event stack.
+}
+
+static void iree_hal_hip_graph_command_buffer_end_debug_group(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ // TODO: tracy event stack.
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_execution_barrier(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_execution_stage_t source_stage_mask,
+ iree_hal_execution_stage_t target_stage_mask,
+ iree_hal_execution_barrier_flags_t flags,
+ iree_host_size_t memory_barrier_count,
+ const iree_hal_memory_barrier_t* memory_barriers,
+ iree_host_size_t buffer_barrier_count,
+ const iree_hal_buffer_barrier_t* buffer_barriers) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+  // With no new nodes since the last barrier the previous barrier (if any)
+  // still applies; emitting an empty node with zero deps would sever ordering.
+  if (command_buffer->graph_node_count == 0) { IREE_TRACE_ZONE_END(z0); return iree_ok_status(); }
+
+  // Use the last node as a barrier to avoid creating redundant empty nodes.
+  if (IREE_LIKELY(command_buffer->graph_node_count == 1)) {
+    command_buffer->hip_barrier_node = command_buffer->hip_graph_nodes[0];
+    command_buffer->graph_node_count = 0;
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->symbols,
+ hipGraphAddEmptyNode(
+ &command_buffer->hip_barrier_node, command_buffer->hip_graph,
+ command_buffer->hip_graph_nodes, command_buffer->graph_node_count),
+ "hipGraphAddEmptyNode");
+
+ command_buffer->graph_node_count = 0;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_signal_event(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+ iree_hal_execution_stage_t source_stage_mask) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_reset_event(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+ iree_hal_execution_stage_t source_stage_mask) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_wait_events(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_host_size_t event_count, const iree_hal_event_t** events,
+ iree_hal_execution_stage_t source_stage_mask,
+ iree_hal_execution_stage_t target_stage_mask,
+ iree_host_size_t memory_barrier_count,
+ const iree_hal_memory_barrier_t* memory_barriers,
+ iree_host_size_t buffer_barrier_count,
+ const iree_hal_buffer_barrier_t* buffer_barriers) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_discard_buffer(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
+ // We could mark the memory as invalidated so that if this is a managed buffer
+ // HIP does not try to copy it back to the host.
+ return iree_ok_status();
+}
+
+// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value.
+static uint32_t iree_hal_hip_splat_pattern(const void* pattern,
+ size_t pattern_length) {
+ switch (pattern_length) {
+ case 1: {
+ uint32_t pattern_value = *(const uint8_t*)(pattern);
+ return (pattern_value << 24) | (pattern_value << 16) |
+ (pattern_value << 8) | pattern_value;
+ }
+ case 2: {
+ uint32_t pattern_value = *(const uint16_t*)(pattern);
+ return (pattern_value << 16) | pattern_value;
+ }
+ case 4: {
+ uint32_t pattern_value = *(const uint32_t*)(pattern);
+ return pattern_value;
+ }
+ default:
+ return 0; // Already verified that this should not be possible.
+ }
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_fill_buffer(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+ iree_device_size_t length, const void* pattern,
+ iree_host_size_t pattern_length) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &target_buffer));
+
+ hipDeviceptr_t target_device_buffer = iree_hal_hip_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(target_buffer));
+ target_offset += iree_hal_buffer_byte_offset(target_buffer);
+ uint32_t pattern_4byte = iree_hal_hip_splat_pattern(pattern, pattern_length);
+ hipMemsetParams params = {
+ .dst = (uint8_t*)target_device_buffer + target_offset,
+ .elementSize = pattern_length,
+ .pitch = 0, // unused if height == 1
+ .width = length / pattern_length, // element count
+ .height = 1,
+ .value = pattern_4byte,
+ };
+
+  if (command_buffer->graph_node_count >= IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // end the zone opened above before early-return
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+ size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0;
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->symbols,
+ hipGraphAddMemsetNode(
+ &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++],
+ command_buffer->hip_graph, &command_buffer->hip_barrier_node,
+ dependency_count, ¶ms),
+ "hipGraphAddMemsetNode");
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_update_buffer(
+ iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
+ iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
+ iree_device_size_t target_offset, iree_device_size_t length) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Allocate scratch space in the arena for the data and copy it in.
+ // The update buffer API requires that the command buffer capture the host
+ // memory at the time the method is called in case the caller wants to reuse
+ // the memory. Because HIP memcpys are async if we didn't copy it's possible
+ // for the reused memory to change before the stream reaches the copy
+ // operation and get the wrong data.
+ uint8_t* storage = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_arena_allocate(&command_buffer->arena, length, (void**)&storage));
+ memcpy(storage, (const uint8_t*)source_buffer + source_offset, length);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &target_buffer));
+
+ hipDeviceptr_t target_device_buffer = iree_hal_hip_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(target_buffer));
+
+ HIP_MEMCPY3D params = {
+ .srcMemoryType = hipMemoryTypeHost,
+ .srcHost = storage,
+ .dstMemoryType = hipMemoryTypeDevice,
+ .dstDevice = target_device_buffer,
+ .dstXInBytes = iree_hal_buffer_byte_offset(target_buffer) + target_offset,
+ .WidthInBytes = length,
+ .Height = 1,
+ .Depth = 1,
+ };
+
+  if (command_buffer->graph_node_count >= IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // end the zone opened above before early-return
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+ size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0;
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->symbols,
+ hipDrvGraphAddMemcpyNode(
+ &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++],
+ command_buffer->hip_graph, &command_buffer->hip_barrier_node,
+ dependency_count, ¶ms, command_buffer->hip_context),
+ "hipDrvGraphAddMemcpyNode");
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_copy_buffer(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+ iree_device_size_t length) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer};
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers));
+
+ hipDeviceptr_t target_device_buffer = iree_hal_hip_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(target_buffer));
+ target_offset += iree_hal_buffer_byte_offset(target_buffer);
+ hipDeviceptr_t source_device_buffer = iree_hal_hip_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(source_buffer));
+ source_offset += iree_hal_buffer_byte_offset(source_buffer);
+
+ HIP_MEMCPY3D params = {
+ .srcMemoryType = hipMemoryTypeDevice,
+ .srcDevice = source_device_buffer,
+ .srcXInBytes = source_offset,
+ .dstMemoryType = hipMemoryTypeDevice,
+ .dstDevice = target_device_buffer,
+ .dstXInBytes = target_offset,
+ .WidthInBytes = length,
+ .Height = 1,
+ .Depth = 1,
+ };
+
+  if (command_buffer->graph_node_count >= IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // end the zone opened above before early-return
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+ size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0;
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->symbols,
+ hipDrvGraphAddMemcpyNode(
+ &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++],
+ command_buffer->hip_graph, &command_buffer->hip_barrier_node,
+ dependency_count, ¶ms, command_buffer->hip_context),
+ "hipDrvGraphAddMemcpyNode");
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_push_constants(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
+ const void* values, iree_host_size_t values_length) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+
+  // NOTE: '>' not '>=': a write ending exactly at sizeof(push_constants) is
+  // valid (e.g. offset 0 with the full constant storage length).
+  if (IREE_UNLIKELY(offset + values_length >
+                    sizeof(command_buffer->push_constants))) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "push constant range [%zu, %zu) out of range",
+                            offset, offset + values_length);
+  }
+
+ memcpy((uint8_t*)&command_buffer->push_constants + offset, values,
+ values_length);
+
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_push_descriptor_set(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_binding_t* bindings) {
+ if (binding_count > IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
+ return iree_make_status(
+ IREE_STATUS_RESOURCE_EXHAUSTED,
+ "exceeded available binding slots for push "
+ "descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. maximal %d",
+ set, binding_count, IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT);
+ }
+
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ hipDeviceptr_t* current_bindings =
+ command_buffer->descriptor_sets[set].bindings;
+ for (iree_host_size_t i = 0; i < binding_count; i++) {
+ const iree_hal_descriptor_set_binding_t* binding = &bindings[i];
+ hipDeviceptr_t device_ptr = NULL;
+ if (binding->buffer) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &binding->buffer));
+
+ hipDeviceptr_t device_buffer = iree_hal_hip_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(binding->buffer));
+ iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+ device_ptr = (uint8_t*)device_buffer + offset + binding->offset;
+ }
+
+ current_bindings[binding->binding] = device_ptr;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_dispatch(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ iree_hal_hip_graph_command_buffer_t* command_buffer =
+ iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Lookup kernel parameters used for side-channeling additional launch
+ // information from the compiler.
+ iree_hal_hip_kernel_info_t kernel_info;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_hip_native_executable_entry_point_kernel_info(
+ executable, entry_point, &kernel_info));
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &executable));
+ iree_hal_hip_dispatch_layout_t dispatch_params =
+ iree_hal_hip_pipeline_layout_dispatch_layout(kernel_info.layout);
+ // The total number of descriptors across all descriptor sets.
+ iree_host_size_t descriptor_count = dispatch_params.total_binding_count;
+ // The total number of push constants.
+ iree_host_size_t push_constant_count = dispatch_params.push_constant_count;
+ // We append push constants to the end of descriptors to form a linear chain
+ // of kernel arguments.
+ iree_host_size_t kernel_params_count = descriptor_count + push_constant_count;
+ iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+ iree_host_size_t total_size = kernel_params_length * 2;
+ uint8_t* storage_base = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_arena_allocate(&command_buffer->arena, total_size,
+ (void**)&storage_base));
+ void** params_ptr = (void**)storage_base;
+
+ // Set up kernel arguments to point to the payload slots.
+ hipDeviceptr_t* payload_ptr =
+ (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length);
+ for (size_t i = 0; i < kernel_params_count; i++) {
+ params_ptr[i] = &payload_ptr[i];
+ }
+
+ // Copy descriptors from all sets to the end of the current segment for later
+ // access.
+ iree_host_size_t set_count = dispatch_params.set_layout_count;
+ for (iree_host_size_t i = 0; i < set_count; ++i) {
+ // TODO: cache this information in the kernel info to avoid recomputation.
+ iree_host_size_t binding_count =
+ iree_hal_hip_descriptor_set_layout_binding_count(
+ iree_hal_hip_pipeline_layout_descriptor_set_layout(
+ kernel_info.layout, i));
+ iree_host_size_t index =
+ iree_hal_hip_pipeline_layout_base_binding_index(kernel_info.layout, i);
+ memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings,
+ binding_count * sizeof(hipDeviceptr_t));
+ }
+
+ // Append the push constants to the kernel arguments.
+ iree_host_size_t base_index = dispatch_params.push_constant_base_index;
+
+ // Each kernel parameter points to is a hipDeviceptr_t, which as the size of a
+ // pointer on the target machine. we are just storing a 32-bit value for the
+ // push constant here instead. So we must process one element each type, for
+ // 64-bit machines.
+ for (iree_host_size_t i = 0; i < push_constant_count; i++) {
+ *((uint32_t*)params_ptr[base_index + i]) =
+ command_buffer->push_constants[i];
+ }
+
+ hipKernelNodeParams params = {
+ .blockDim.x = kernel_info.block_size[0],
+ .blockDim.y = kernel_info.block_size[1],
+ .blockDim.z = kernel_info.block_size[2],
+ .gridDim.x = workgroup_x,
+ .gridDim.y = workgroup_y,
+ .gridDim.z = workgroup_z,
+ .func = kernel_info.function,
+ .kernelParams = params_ptr,
+ .sharedMemBytes = kernel_info.shared_memory_size,
+ };
+
+  if (command_buffer->graph_node_count >= IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // end the zone opened above before early-return
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+ size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0;
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->symbols,
+ hipGraphAddKernelNode(
+ &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++],
+ command_buffer->hip_graph, &command_buffer->hip_barrier_node,
+ dependency_count, ¶ms),
+ "hipGraphAddKernelNode");
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_t* workgroups_buffer,
+ iree_device_size_t workgroups_offset) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "indirect dispatch not yet implemented");
+}
+
+static iree_status_t iree_hal_hip_graph_command_buffer_execute_commands(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_command_buffer_t* base_commands,
+ iree_hal_buffer_binding_table_t binding_table) {
+ // TODO(#10144): support indirect command buffers by adding subgraph nodes and
+ // tracking the binding table for future hipGraphExecKernelNodeSetParams
+ // usage. Need to look into how to update the params of the subgraph nodes -
+ // is the graph exec the outer one and if so will it allow node handles from
+ // the subgraphs?
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "indirect command buffers not yet implemented");
+}
+
+static const iree_hal_command_buffer_vtable_t
+ iree_hal_hip_graph_command_buffer_vtable = {
+ .destroy = iree_hal_hip_graph_command_buffer_destroy,
+ .begin = iree_hal_hip_graph_command_buffer_begin,
+ .end = iree_hal_hip_graph_command_buffer_end,
+ .begin_debug_group =
+ iree_hal_hip_graph_command_buffer_begin_debug_group,
+ .end_debug_group = iree_hal_hip_graph_command_buffer_end_debug_group,
+ .execution_barrier =
+ iree_hal_hip_graph_command_buffer_execution_barrier,
+ .signal_event = iree_hal_hip_graph_command_buffer_signal_event,
+ .reset_event = iree_hal_hip_graph_command_buffer_reset_event,
+ .wait_events = iree_hal_hip_graph_command_buffer_wait_events,
+ .discard_buffer = iree_hal_hip_graph_command_buffer_discard_buffer,
+ .fill_buffer = iree_hal_hip_graph_command_buffer_fill_buffer,
+ .update_buffer = iree_hal_hip_graph_command_buffer_update_buffer,
+ .copy_buffer = iree_hal_hip_graph_command_buffer_copy_buffer,
+ .collective = iree_hal_hip_graph_command_buffer_collective,
+ .push_constants = iree_hal_hip_graph_command_buffer_push_constants,
+ .push_descriptor_set =
+ iree_hal_hip_graph_command_buffer_push_descriptor_set,
+ .dispatch = iree_hal_hip_graph_command_buffer_dispatch,
+ .dispatch_indirect =
+ iree_hal_hip_graph_command_buffer_dispatch_indirect,
+ .execute_commands = iree_hal_hip_graph_command_buffer_execute_commands,
+};
diff --git a/experimental/hip/graph_command_buffer.h b/experimental/hip/graph_command_buffer.h
new file mode 100644
index 0000000..35253ce
--- /dev/null
+++ b/experimental/hip/graph_command_buffer.h
@@ -0,0 +1,50 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HIP_GRAPH_COMMAND_BUFFER_H_
+#define IREE_EXPERIMENTAL_HIP_GRAPH_COMMAND_BUFFER_H_
+
+#include "experimental/hip/dynamic_symbols.h"
+#include "experimental/hip/hip_headers.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// NOTE: hipGraph API used in this module is marked as beta in the HIP
+// documentation, meaning, while this is feature complete it is still open to
+// changes and may have outstanding issues.
+
+typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
+
+// Creates a command buffer that records into a HIP graph.
+//
+// NOTE: the |block_pool| must remain live for the lifetime of the command
+// buffers that use it.
+iree_status_t iree_hal_hip_graph_command_buffer_create(
+ iree_hal_device_t* device,
+ const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipCtx_t context,
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns true if |command_buffer| is a HIP graph-based command buffer.
+bool iree_hal_hip_graph_command_buffer_isa(
+ iree_hal_command_buffer_t* command_buffer);
+
+// Returns the native HIP graph associated to the command buffer.
+hipGraphExec_t iree_hal_hip_graph_command_buffer_handle(
+ iree_hal_command_buffer_t* command_buffer);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_EXPERIMENTAL_HIP_GRAPH_COMMAND_BUFFER_H_
diff --git a/experimental/hip/hip_device.c b/experimental/hip/hip_device.c
index 552cac7..5cc7f12 100644
--- a/experimental/hip/hip_device.c
+++ b/experimental/hip/hip_device.c
@@ -11,15 +11,16 @@
#include <string.h>
#include "experimental/hip/dynamic_symbols.h"
+#include "experimental/hip/graph_command_buffer.h"
#include "experimental/hip/hip_allocator.h"
#include "experimental/hip/hip_buffer.h"
#include "experimental/hip/memory_pools.h"
#include "experimental/hip/nop_executable_cache.h"
+#include "experimental/hip/pipeline_layout.h"
#include "experimental/hip/status_util.h"
#include "iree/base/internal/arena.h"
#include "iree/base/internal/math.h"
#include "iree/base/tracing.h"
-#include "iree/hal/utils/deferred_command_buffer.h"
//===----------------------------------------------------------------------===//
// iree_hal_hip_device_t
@@ -322,8 +323,11 @@
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
iree_hal_command_buffer_t** out_command_buffer) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "command buffer not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+ return iree_hal_hip_graph_command_buffer_create(
+ base_device, device->hip_symbols, device->hip_context, mode,
+ command_categories, queue_affinity, binding_capacity, &device->block_pool,
+ device->host_allocator, out_command_buffer);
}
static iree_status_t iree_hal_hip_device_create_descriptor_set_layout(
@@ -332,8 +336,10 @@
iree_host_size_t binding_count,
const iree_hal_descriptor_set_layout_binding_t* bindings,
iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "descriptor set layout not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+ return iree_hal_hip_descriptor_set_layout_create(
+ flags, binding_count, bindings, device->host_allocator,
+ out_descriptor_set_layout);
}
static iree_status_t iree_hal_hip_device_create_event(
@@ -364,8 +370,10 @@
iree_host_size_t set_layout_count,
iree_hal_descriptor_set_layout_t* const* set_layouts,
iree_hal_pipeline_layout_t** out_pipeline_layout) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "pipeline layout not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+ return iree_hal_hip_pipeline_layout_create(
+ set_layout_count, set_layouts, push_constants, device->host_allocator,
+ out_pipeline_layout);
}
static iree_status_t iree_hal_hip_device_create_semaphore(
@@ -393,8 +401,36 @@
iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
iree_device_size_t allocation_size,
iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "queue alloca not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+
+ // NOTE: block on the semaphores here; we could avoid this by properly
+ // sequencing device work with semaphores. The HIP HAL is not currently
+ // asynchronous.
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
+ iree_infinite_timeout()));
+
+ // Allocate from the pool; likely to fail in cases of virtual memory
+ // exhaustion but the error may be deferred until a later synchronization.
+ // If pools are not supported we allocate a buffer as normal from whatever
+ // allocator is set on the device.
+ iree_status_t status = iree_ok_status();
+ if (device->supports_memory_pools) {
+ status = iree_hal_hip_memory_pools_allocate(
+ &device->memory_pools, device->hip_stream, pool, params,
+ allocation_size, out_buffer);
+ } else {
+ status = iree_hal_allocator_allocate_buffer(
+ iree_hal_device_allocator(base_device), params, allocation_size,
+ out_buffer);
+ }
+
+ // Only signal if not returning a synchronous error - synchronous failure
+ // indicates that the stream is unchanged (it's not really since we waited
+ // above, but we at least won't deadlock like this).
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_semaphore_list_signal(signal_semaphore_list);
+ }
+ return status;
}
// TODO: implement multiple streams; today we only have one and queue_affinity
@@ -406,8 +442,29 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_hal_buffer_t* buffer) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "queue dealloca not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+
+ // NOTE: block on the semaphores here; we could avoid this by properly
+ // sequencing device work with semaphores. The HIP HAL is not currently
+ // asynchronous.
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
+ iree_infinite_timeout()));
+
+ // Schedule the buffer deallocation if we got it from a pool and otherwise
+ // drop it on the floor and let it be freed when the buffer is released.
+ iree_status_t status = iree_ok_status();
+ if (device->supports_memory_pools) {
+ status = iree_hal_hip_memory_pools_deallocate(&device->memory_pools,
+ device->hip_stream, buffer);
+ }
+
+ // Only signal if not returning a synchronous error - synchronous failure
+ // indicates that the stream is unchanged (it's not really since we waited
+ // above, but we at least won't deadlock like this).
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_semaphore_list_signal(signal_semaphore_list);
+ }
+ return status;
}
static iree_status_t iree_hal_hip_device_queue_read(
@@ -438,8 +495,25 @@
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
iree_hal_command_buffer_t* const* command_buffers) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "queue execution not yet implmeneted");
+ iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
+
+ for (iree_host_size_t i = 0; i < command_buffer_count; i++) {
+ hipGraphExec_t exec =
+ iree_hal_hip_graph_command_buffer_handle(command_buffers[i]);
+ IREE_HIP_RETURN_IF_ERROR(device->hip_symbols,
+ hipGraphLaunch(exec, device->hip_stream),
+ "hipGraphLaunch");
+ }
+
+ // TODO(nithinsubbiah): implement semaphores - for now this conservatively
+ // synchronizes after every submit.
+ IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize");
+ IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, device->hip_symbols, hipStreamSynchronize(device->hip_stream),
+ "hipStreamSynchronize");
+ IREE_TRACE_ZONE_END(z0);
+
+ return iree_ok_status();
}
static iree_status_t iree_hal_hip_device_queue_flush(
diff --git a/experimental/hip/pipeline_layout.c b/experimental/hip/pipeline_layout.c
new file mode 100644
index 0000000..12cefb3
--- /dev/null
+++ b/experimental/hip/pipeline_layout.c
@@ -0,0 +1,248 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/hip/pipeline_layout.h"
+
+#include <stddef.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_hip_descriptor_set_layout_t {
+ // Abstract resource used for injecting reference counting and vtable;
+ // must be at offset 0.
+ iree_hal_resource_t resource;
+
+ // The host allocator used for creating this descriptor set layout struct.
+ iree_allocator_t host_allocator;
+
+ // The total number of bindings in this descriptor set.
+ iree_host_size_t binding_count;
+} iree_hal_hip_descriptor_set_layout_t;
+
+static const iree_hal_descriptor_set_layout_vtable_t
+ iree_hal_hip_descriptor_set_layout_vtable;
+
+static iree_hal_hip_descriptor_set_layout_t*
+iree_hal_hip_descriptor_set_layout_cast(
+ iree_hal_descriptor_set_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_descriptor_set_layout_vtable);
+ return (iree_hal_hip_descriptor_set_layout_t*)base_value;
+}
+
+static const iree_hal_hip_descriptor_set_layout_t*
+iree_hal_hip_descriptor_set_layout_const_cast(
+ const iree_hal_descriptor_set_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_descriptor_set_layout_vtable);
+ return (const iree_hal_hip_descriptor_set_layout_t*)base_value;
+}
+
+iree_status_t iree_hal_hip_descriptor_set_layout_create(
+ iree_hal_descriptor_set_layout_flags_t flags,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t* bindings,
+ iree_allocator_t host_allocator,
+ iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
+ IREE_ASSERT_ARGUMENT(!binding_count || bindings);
+ IREE_ASSERT_ARGUMENT(out_descriptor_set_layout);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ *out_descriptor_set_layout = NULL;
+
+ iree_hal_hip_descriptor_set_layout_t* descriptor_set_layout = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout),
+ (void**)&descriptor_set_layout));
+
+ iree_hal_resource_initialize(&iree_hal_hip_descriptor_set_layout_vtable,
+ &descriptor_set_layout->resource);
+ descriptor_set_layout->host_allocator = host_allocator;
+ descriptor_set_layout->binding_count = binding_count;
+ *out_descriptor_set_layout =
+ (iree_hal_descriptor_set_layout_t*)descriptor_set_layout;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_host_size_t iree_hal_hip_descriptor_set_layout_binding_count(
+ const iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) {
+ const iree_hal_hip_descriptor_set_layout_t* descriptor_set_layout =
+ iree_hal_hip_descriptor_set_layout_const_cast(base_descriptor_set_layout);
+ return descriptor_set_layout->binding_count;
+}
+
+static void iree_hal_hip_descriptor_set_layout_destroy(
+ iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) {
+ iree_hal_hip_descriptor_set_layout_t* descriptor_set_layout =
+ iree_hal_hip_descriptor_set_layout_cast(base_descriptor_set_layout);
+ iree_allocator_t host_allocator = descriptor_set_layout->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(host_allocator, descriptor_set_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_descriptor_set_layout_vtable_t
+ iree_hal_hip_descriptor_set_layout_vtable = {
+ .destroy = iree_hal_hip_descriptor_set_layout_destroy,
+};
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_pipeline_layout_t
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_hip_pipeline_layout_t {
+ // Abstract resource used for injecting reference counting and vtable;
+ // must be at offset 0.
+ iree_hal_resource_t resource;
+
+ // The host allocator used for creating this pipeline layout struct.
+ iree_allocator_t host_allocator;
+
+ // The kernel argument index for push constants.
+ // Note that push constants are placed after all normal descriptors.
+ iree_host_size_t push_constant_base_index;
+ iree_host_size_t push_constant_count;
+
+ iree_host_size_t set_layout_count;
+ // The list of descriptor set layout pointers, pointing to trailing inline
+ // allocation after the end of this struct.
+ struct {
+ iree_hal_descriptor_set_layout_t* set_layout;
+ // Base kernel argument index for this descriptor set.
+ iree_host_size_t base_index;
+ } set_layouts[];
+} iree_hal_hip_pipeline_layout_t;
+// + Additional inline allocation for holding all descriptor sets.
+
+static const iree_hal_pipeline_layout_vtable_t
+ iree_hal_hip_pipeline_layout_vtable;
+
+static iree_hal_hip_pipeline_layout_t* iree_hal_hip_pipeline_layout_cast(
+ iree_hal_pipeline_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_pipeline_layout_vtable);
+ return (iree_hal_hip_pipeline_layout_t*)base_value;
+}
+
+static const iree_hal_hip_pipeline_layout_t*
+iree_hal_hip_pipeline_layout_const_cast(
+ const iree_hal_pipeline_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_pipeline_layout_vtable);
+ return (const iree_hal_hip_pipeline_layout_t*)base_value;
+}
+
+iree_status_t iree_hal_hip_pipeline_layout_create(
+ iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t* const* set_layouts,
+ iree_host_size_t push_constant_count, iree_allocator_t host_allocator,
+ iree_hal_pipeline_layout_t** out_pipeline_layout) {
+ IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts);
+ IREE_ASSERT_ARGUMENT(out_pipeline_layout);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ *out_pipeline_layout = NULL;
+ if (push_constant_count > IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "push constant count %" PRIhsz " over the limit of %d",
+ push_constant_count, IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT);
+ }
+
+ // Currently the pipeline layout doesn't do anything.
+ // TODO: Handle creating the argument layout at that time, handling both
+ // push constants and buffers.
+ iree_hal_hip_pipeline_layout_t* pipeline_layout = NULL;
+ iree_host_size_t total_size =
+ sizeof(*pipeline_layout) +
+ set_layout_count * sizeof(*pipeline_layout->set_layouts);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(host_allocator, total_size,
+ (void**)&pipeline_layout));
+
+ iree_hal_resource_initialize(&iree_hal_hip_pipeline_layout_vtable,
+ &pipeline_layout->resource);
+ pipeline_layout->host_allocator = host_allocator;
+ pipeline_layout->set_layout_count = set_layout_count;
+ iree_host_size_t base_index = 0;
+ for (iree_host_size_t i = 0; i < set_layout_count; ++i) {
+ pipeline_layout->set_layouts[i].set_layout = set_layouts[i];
+ // Copy and retain all descriptor sets so we don't lose them.
+ iree_hal_descriptor_set_layout_retain(set_layouts[i]);
+ pipeline_layout->set_layouts[i].base_index = base_index;
+ base_index +=
+ iree_hal_hip_descriptor_set_layout_binding_count(set_layouts[i]);
+ }
+ pipeline_layout->push_constant_base_index = base_index;
+ pipeline_layout->push_constant_count = push_constant_count;
+ *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+static void iree_hal_hip_pipeline_layout_destroy(
+ iree_hal_pipeline_layout_t* base_pipeline_layout) {
+ iree_hal_hip_pipeline_layout_t* pipeline_layout =
+ iree_hal_hip_pipeline_layout_cast(base_pipeline_layout);
+ iree_allocator_t host_allocator = pipeline_layout->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ for (iree_host_size_t i = 0; i < pipeline_layout->set_layout_count; ++i) {
+ iree_hal_descriptor_set_layout_release(
+ pipeline_layout->set_layouts[i].set_layout);
+ }
+ iree_allocator_free(host_allocator, pipeline_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+const iree_hal_descriptor_set_layout_t*
+iree_hal_hip_pipeline_layout_descriptor_set_layout(
+ const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) {
+ const iree_hal_hip_pipeline_layout_t* pipeline_layout =
+ iree_hal_hip_pipeline_layout_const_cast(base_pipeline_layout);
+ if (set < pipeline_layout->set_layout_count) {
+ return pipeline_layout->set_layouts[set].set_layout;
+ }
+ return NULL;
+}
+
+iree_host_size_t iree_hal_hip_pipeline_layout_base_binding_index(
+ const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) {
+ const iree_hal_hip_pipeline_layout_t* pipeline_layout =
+ iree_hal_hip_pipeline_layout_const_cast(base_pipeline_layout);
+ return pipeline_layout->set_layouts[set].base_index;
+}
+
+static const iree_hal_pipeline_layout_vtable_t
+ iree_hal_hip_pipeline_layout_vtable = {
+ .destroy = iree_hal_hip_pipeline_layout_destroy,
+};
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_dispatch_layout_t
+//===----------------------------------------------------------------------===//
+
+iree_hal_hip_dispatch_layout_t iree_hal_hip_pipeline_layout_dispatch_layout(
+ const iree_hal_pipeline_layout_t* base_pipeline_layout) {
+ const iree_hal_hip_pipeline_layout_t* pipeline_layout =
+ iree_hal_hip_pipeline_layout_const_cast(base_pipeline_layout);
+ iree_hal_hip_dispatch_layout_t dispatch_params = {
+ .push_constant_base_index = pipeline_layout->push_constant_base_index,
+ .push_constant_count = pipeline_layout->push_constant_count,
+ .total_binding_count = pipeline_layout->push_constant_base_index,
+ .set_layout_count = pipeline_layout->set_layout_count,
+ };
+
+ return dispatch_params;
+}
diff --git a/experimental/hip/pipeline_layout.h b/experimental/hip/pipeline_layout.h
new file mode 100644
index 0000000..810b467
--- /dev/null
+++ b/experimental/hip/pipeline_layout.h
@@ -0,0 +1,92 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HIP_PIPELINE_LAYOUT_H_
+#define IREE_EXPERIMENTAL_HIP_PIPELINE_LAYOUT_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// The max number of bindings per descriptor set allowed in the HIP HAL
+// implementation.
+#define IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT 16
+
+// The max number of descriptor sets allowed in the HIP HAL implementation.
+//
+// This depends on the general descriptor set planning in IREE and should adjust
+// with it.
+#define IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT 4
+
+// The max number of push constants supported by the HIP HAL implementation.
+#define IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT 64
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates a descriptor set layout with the given |bindings|.
+//
+// Bindings in a descriptor set map to a list of consecutive kernel arguments in
+// HIP kernels.
+iree_status_t iree_hal_hip_descriptor_set_layout_create(
+ iree_hal_descriptor_set_layout_flags_t flags,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t* bindings,
+ iree_allocator_t host_allocator,
+ iree_hal_descriptor_set_layout_t** out_descriptor_set_layout);
+
+// Returns the binding count for the given descriptor set layout.
+iree_host_size_t iree_hal_hip_descriptor_set_layout_binding_count(
+ const iree_hal_descriptor_set_layout_t* descriptor_set_layout);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hip_pipeline_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates the pipeline layout with the given |set_layouts| and
+// |push_constant_count|.
+//
+// Bindings in the pipeline map to kernel arguments in HIP kernels, followed by
+// the kernel argument for the push constant data.
+iree_status_t iree_hal_hip_pipeline_layout_create(
+ iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t* const* set_layouts,
+ iree_host_size_t push_constant_count, iree_allocator_t host_allocator,
+ iree_hal_pipeline_layout_t** out_pipeline_layout);
+
+// Returns the total number of sets in the given |pipeline_layout|.
+iree_host_size_t iree_hal_hip_pipeline_layout_descriptor_set_count(
+ const iree_hal_pipeline_layout_t* pipeline_layout);
+
+// Returns the descriptor set layout of the given |set| in |pipeline_layout|.
+const iree_hal_descriptor_set_layout_t*
+iree_hal_hip_pipeline_layout_descriptor_set_layout(
+ const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set);
+
+// Returns the base kernel argument index for the given set.
+iree_host_size_t iree_hal_hip_pipeline_layout_base_binding_index(
+ const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set);
+
+typedef struct iree_hal_hip_dispatch_layout_t {
+ iree_host_size_t push_constant_base_index;
+ iree_host_size_t push_constant_count;
+ iree_host_size_t set_layout_count;
+ iree_host_size_t total_binding_count;
+} iree_hal_hip_dispatch_layout_t;
+
+// Returns dispatch layout parameters in a struct form for pipeline layout.
+iree_hal_hip_dispatch_layout_t iree_hal_hip_pipeline_layout_dispatch_layout(
+ const iree_hal_pipeline_layout_t* base_pipeline_layout);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_EXPERIMENTAL_HIP_PIPELINE_LAYOUT_H_