blob: 6d836762c8bebf058dadc132b17fbf96eaa6f33e [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/rocm/direct_command_buffer.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "experimental/rocm/dynamic_symbols.h"
#include "experimental/rocm/executable_layout.h"
#include "experimental/rocm/native_executable.h"
#include "experimental/rocm/rocm_buffer.h"
#include "experimental/rocm/status_util.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
// Command buffer implementation that maps directly onto ROCm (HIP) calls.
// Commands are issued on the calling thread without additional threading
// indirection.
typedef struct {
// Base HAL command buffer. It is the first member so that the
// iree_hal_command_buffer_t* handed out by _create can be cast back to this
// struct (see iree_hal_rocm_direct_command_buffer_cast).
iree_hal_command_buffer_t base;
// ROCm context wrapper; this file reads its host allocator and its dynamic
// symbol table from it.
iree_hal_rocm_context_wrapper_t* context;
// Block pool supplied at creation time. It is only stored by the code
// visible in this file.
iree_arena_block_pool_t* block_pool;
// Keep track of the current set of kernel arguments.
// Push constant values recorded by push_constants; they are copied into the
// kernel argument slots when a dispatch is issued.
int32_t push_constant[IREE_HAL_ROCM_MAX_PUSH_CONSTANT_COUNT];
// Kernel argument pointer array handed to hipModuleLaunchKernel. _create
// allocates IREE_HAL_ROCM_MAX_KERNEL_ARG pointers here followed by the same
// number of hipDeviceptr_t slots, and points each entry at its own slot.
void* current_descriptor[];
} iree_hal_rocm_direct_command_buffer_t;
// Capacity of the temporary array used to sort bindings in
// push_descriptor_set.
#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64
// Number of kernel argument slots allocated per command buffer. Kernel
// arguments contain both bindings and push constants.
#define IREE_HAL_ROCM_MAX_KERNEL_ARG 128
// Forward declaration of this implementation's vtable so that the functions
// below can reference its address; the table itself is defined further down,
// after the functions it points to.
static const iree_hal_command_buffer_vtable_t
iree_hal_rocm_direct_command_buffer_vtable;
// Converts a base HAL command buffer pointer into the ROCm direct command
// buffer that embeds it. The value is first run through IREE_HAL_ASSERT_TYPE
// against this implementation's vtable.
static iree_hal_rocm_direct_command_buffer_t*
iree_hal_rocm_direct_command_buffer_cast(
    iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value,
                       &iree_hal_rocm_direct_command_buffer_vtable);
  iree_hal_rocm_direct_command_buffer_t* rocm_command_buffer =
      (iree_hal_rocm_direct_command_buffer_t*)base_value;
  return rocm_command_buffer;
}
// Creates a ROCm direct command buffer.
//
// The object is allocated from |context|'s host allocator as one block laid
// out as: the struct itself, then IREE_HAL_ROCM_MAX_KERNEL_ARG argument
// pointers (the flexible |current_descriptor| array), then the same number of
// hipDeviceptr_t storage slots. Every argument pointer is initialized to
// point at its own storage slot.
//
// On success |out_command_buffer| receives the new command buffer. When the
// allocation fails the allocator's status is returned and
// |out_command_buffer| is not written.
iree_status_t iree_hal_rocm_direct_command_buffer_create(
    iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context,
    iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity,
    iree_arena_block_pool_t* block_pool,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(context);
  IREE_ASSERT_ARGUMENT(block_pool);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_rocm_direct_command_buffer_t* command_buffer = NULL;
  const size_t argument_pointers_size =
      IREE_HAL_ROCM_MAX_KERNEL_ARG * sizeof(void*);
  const size_t argument_storage_size =
      IREE_HAL_ROCM_MAX_KERNEL_ARG * sizeof(hipDeviceptr_t);
  const size_t total_size =
      sizeof(*command_buffer) + argument_pointers_size + argument_storage_size;

  iree_status_t status = iree_allocator_malloc(
      context->host_allocator, total_size, (void**)&command_buffer);
  if (!iree_status_is_ok(status)) {
    IREE_TRACE_ZONE_END(z0);
    return status;
  }

  iree_hal_command_buffer_initialize(
      device, mode, command_categories, queue_affinity,
      &iree_hal_rocm_direct_command_buffer_vtable, &command_buffer->base);
  command_buffer->context = context;
  command_buffer->block_pool = block_pool;

  // The storage slots live immediately after the pointer array.
  hipDeviceptr_t* argument_storage =
      (hipDeviceptr_t*)(command_buffer->current_descriptor +
                        IREE_HAL_ROCM_MAX_KERNEL_ARG);
  for (size_t slot = 0; slot < IREE_HAL_ROCM_MAX_KERNEL_ARG; ++slot) {
    command_buffer->current_descriptor[slot] = argument_storage + slot;
  }

  *out_command_buffer = &command_buffer->base;
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Releases the memory of a command buffer created by
// iree_hal_rocm_direct_command_buffer_create using the host allocator of the
// context it was created with.
static void iree_hal_rocm_direct_command_buffer_destroy(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_rocm_direct_command_buffer_t* self =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // The whole object (struct plus trailing argument arrays) is one block.
  iree_allocator_free(self->context->host_allocator, self);

  IREE_TRACE_ZONE_END(z0);
}
// Returns true when |command_buffer| can be dyn_cast to the ROCm direct
// command buffer implementation, i.e. the cast yields a non-NULL result.
bool iree_hal_rocm_direct_command_buffer_isa(
    iree_hal_command_buffer_t* command_buffer) {
  return iree_hal_command_buffer_dyn_cast(
             command_buffer, &iree_hal_rocm_direct_command_buffer_vtable) !=
         NULL;
}
// vtable dyn_cast hook: returns |command_buffer| when |vtable| is this
// implementation's vtable and NULL for any other vtable.
static void* iree_hal_rocm_direct_command_buffer_dyn_cast(
    iree_hal_command_buffer_t* command_buffer, const void* vtable) {
  if (vtable != &iree_hal_rocm_direct_command_buffer_vtable) {
    return NULL;
  }
  IREE_HAL_ASSERT_TYPE(command_buffer, vtable);
  return command_buffer;
}
// Begins recording. This implementation has no per-recording setup, so the
// call always succeeds.
static iree_status_t iree_hal_rocm_direct_command_buffer_begin(
    iree_hal_command_buffer_t* base_command_buffer) {
  (void)base_command_buffer;
  return iree_ok_status();
}
// Ends recording. This implementation has nothing to finalize, so the call
// always succeeds.
static iree_status_t iree_hal_rocm_direct_command_buffer_end(
    iree_hal_command_buffer_t* base_command_buffer) {
  (void)base_command_buffer;
  return iree_ok_status();
}
// Opens a debug group. Currently a no-op: all arguments are ignored.
static void iree_hal_rocm_direct_command_buffer_begin_debug_group(
    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
    iree_hal_label_color_t label_color,
    const iree_hal_label_location_t* location) {
  // TODO(benvanik): tracy event stack.
  (void)base_command_buffer;
  (void)label;
  (void)label_color;
  (void)location;
}
// Closes a debug group. Currently a no-op.
static void iree_hal_rocm_direct_command_buffer_end_debug_group(
    iree_hal_command_buffer_t* base_command_buffer) {
  // TODO(benvanik): tracy event stack.
  (void)base_command_buffer;
}
// Execution barrier hook. Not implemented yet: every argument is ignored and
// success is reported without issuing any ROCm call.
static iree_status_t iree_hal_rocm_direct_command_buffer_execution_barrier(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_hal_execution_barrier_flags_t flags,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  // TODO: Implement barrier
  (void)base_command_buffer;
  (void)source_stage_mask;
  (void)target_stage_mask;
  (void)flags;
  (void)memory_barrier_count;
  (void)memory_barriers;
  (void)buffer_barrier_count;
  (void)buffer_barriers;
  return iree_ok_status();
}
// Event signal hook. Not implemented yet: the arguments are ignored and
// success is reported without issuing any ROCm call.
static iree_status_t iree_hal_rocm_direct_command_buffer_signal_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  // TODO: Implement barrier
  (void)base_command_buffer;
  (void)event;
  (void)source_stage_mask;
  return iree_ok_status();
}
// Event reset hook. Not implemented yet: the arguments are ignored and
// success is reported without issuing any ROCm call.
static iree_status_t iree_hal_rocm_direct_command_buffer_reset_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  // TODO: Implement barrier
  (void)base_command_buffer;
  (void)event;
  (void)source_stage_mask;
  return iree_ok_status();
}
// Event wait hook. Not implemented yet: every argument is ignored and success
// is reported without issuing any ROCm call.
static iree_status_t iree_hal_rocm_direct_command_buffer_wait_events(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_host_size_t event_count, const iree_hal_event_t** events,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  // TODO: Implement barrier
  (void)base_command_buffer;
  (void)event_count;
  (void)events;
  (void)source_stage_mask;
  (void)target_stage_mask;
  (void)memory_barrier_count;
  (void)memory_barriers;
  (void)buffer_barrier_count;
  (void)buffer_barriers;
  return iree_ok_status();
}
// Discard hook. This implementation has nothing to do for a discarded buffer
// and always reports success.
static iree_status_t iree_hal_rocm_direct_command_buffer_discard_buffer(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
  (void)base_command_buffer;
  (void)buffer;
  return iree_ok_status();
}
// Expands a 1-, 2-, or 4-byte pattern into a 32-bit value by repeating it
// until all four bytes are filled. Any other |pattern_length| yields 0;
// callers are expected to have validated the length beforehand.
static uint32_t iree_hal_rocm_splat_pattern(const void* pattern,
                                            size_t pattern_length) {
  if (pattern_length == 1) {
    // Multiplying by 0x01010101 replicates the byte into all four lanes.
    const uint32_t byte_value = *(const uint8_t*)pattern;
    return byte_value * 0x01010101u;
  }
  if (pattern_length == 2) {
    // Multiplying by 0x00010001 replicates the halfword into both lanes.
    const uint32_t half_value = *(const uint16_t*)pattern;
    return half_value * 0x00010001u;
  }
  if (pattern_length == 4) {
    return *(const uint32_t*)pattern;
  }
  return 0;  // Already verified that this should not be possible.
}
// Fills |length| bytes of |target_buffer|, starting at |target_offset| within
// the buffer, using |pattern| (splatted to 32 bits) via hipMemsetAsync on the
// NULL stream. Returns the status produced by ROCM_RETURN_IF_ERROR on
// failure, OK otherwise.
static iree_status_t iree_hal_rocm_direct_command_buffer_fill_buffer(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
    iree_device_size_t length, const void* pattern,
    iree_host_size_t pattern_length) {
  iree_hal_rocm_direct_command_buffer_t* command_buffer =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);

  // Resolve the device address: base of the backing allocation plus the
  // buffer's own offset inside it plus the caller-provided offset.
  hipDeviceptr_t allocation_base = iree_hal_rocm_buffer_device_pointer(
      iree_hal_buffer_allocated_buffer(target_buffer));
  iree_device_size_t absolute_offset =
      target_offset + iree_hal_buffer_byte_offset(target_buffer);
  hipDeviceptr_t fill_start = allocation_base + absolute_offset;

  uint32_t splatted_pattern =
      iree_hal_rocm_splat_pattern(pattern, pattern_length);
  int fill_value = splatted_pattern;
  size_t fill_size = length;

  // NOTE(review): hipMemsetAsync is documented as a byte-wise fill, so 2- and
  // 4-byte patterns whose bytes differ may not be reproduced faithfully --
  // confirm, and consider hipMemsetD16Async/hipMemsetD32Async if so.
  // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to
  // access proper stream from command buffer
  ROCM_RETURN_IF_ERROR(command_buffer->context->syms,
                       hipMemsetAsync(fill_start, fill_value, fill_size, 0),
                       "hipMemsetAsync");
  return iree_ok_status();
}
// Host-to-device inline update hook. Not supported by this implementation
// yet: always returns an UNIMPLEMENTED status.
static iree_status_t iree_hal_rocm_direct_command_buffer_update_buffer(
    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
    iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  (void)base_command_buffer;
  (void)source_buffer;
  (void)source_offset;
  (void)target_buffer;
  (void)target_offset;
  (void)length;
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "need rocm implementation");
}
// Copies |length| bytes device-to-device from |source_buffer| at
// |source_offset| to |target_buffer| at |target_offset| using hipMemcpyAsync
// on the NULL stream.
//
// Both offsets are relative to their buffer, so each buffer's own byte offset
// inside its backing allocation is added before the device addresses are
// formed. Returns the status produced by ROCM_RETURN_IF_ERROR on failure, OK
// otherwise.
static iree_status_t iree_hal_rocm_direct_command_buffer_copy_buffer(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
    iree_device_size_t length) {
  iree_hal_rocm_direct_command_buffer_t* command_buffer =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
  hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer(
      iree_hal_buffer_allocated_buffer(target_buffer));
  target_offset += iree_hal_buffer_byte_offset(target_buffer);
  hipDeviceptr_t source_device_buffer = iree_hal_rocm_buffer_device_pointer(
      iree_hal_buffer_allocated_buffer(source_buffer));
  source_offset += iree_hal_buffer_byte_offset(source_buffer);
  // Apply the accumulated offsets; without this the copy would always start
  // at the beginning of both backing allocations.
  hipDeviceptr_t dst = target_device_buffer + target_offset;
  hipDeviceptr_t src = source_device_buffer + source_offset;
  // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to
  // access proper stream from command buffer
  ROCM_RETURN_IF_ERROR(
      command_buffer->context->syms,
      hipMemcpyAsync(dst, src, length, hipMemcpyDeviceToDevice, 0),
      "hipMemcpyAsync");
  return iree_ok_status();
}
// Records push constant values for subsequent dispatches.
//
// |offset| and |values_length| are in bytes; they are converted to 32-bit
// element indices/counts and the values are stored in the command buffer's
// |push_constant| array starting at offset / 4.
//
// Returns OUT_OF_RANGE (without modifying any state) when the requested range
// does not fit in the fixed-size |push_constant| array.
static iree_status_t iree_hal_rocm_direct_command_buffer_push_constants(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
    const void* values, iree_host_size_t values_length) {
  iree_hal_rocm_direct_command_buffer_t* command_buffer =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
  const iree_host_size_t capacity = sizeof(command_buffer->push_constant) /
                                    sizeof(command_buffer->push_constant[0]);
  const iree_host_size_t constant_base_index = offset / sizeof(int32_t);
  const iree_host_size_t constant_count = values_length / sizeof(int32_t);
  // Written as two comparisons so the check itself cannot overflow.
  if (constant_base_index > capacity ||
      constant_count > capacity - constant_base_index) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "push constant range (base %zu, count %zu) exceeds capacity %zu",
        constant_base_index, constant_count, capacity);
  }
  const uint32_t* source_values = (const uint32_t*)values;
  for (iree_host_size_t i = 0; i < constant_count; i++) {
    command_buffer->push_constant[constant_base_index + i] =
        (int32_t)source_values[i];
  }
  return iree_ok_status();
}
// Ties together a binding's binding index and its position in the |bindings|
// array it came from.
typedef struct {
  // Position of the entry in the original |bindings| array.
  uint32_t index;
  // Binding index used as the sort key.
  uint32_t binding;
} iree_hal_rocm_binding_mapping_t;

// qsort comparator ordering iree_hal_rocm_binding_mapping_t entries by
// ascending |binding|. Returns a negative value, zero, or a positive value
// when |a| sorts before, equal to, or after |b|; returning zero for equal
// keys is required for a consistent ordering.
static int compare_binding_index(const void* a, const void* b) {
  const iree_hal_rocm_binding_mapping_t* mapping_a =
      (const iree_hal_rocm_binding_mapping_t*)a;
  const iree_hal_rocm_binding_mapping_t* mapping_b =
      (const iree_hal_rocm_binding_mapping_t*)b;
  return (mapping_a->binding > mapping_b->binding) -
         (mapping_a->binding < mapping_b->binding);
}
// Records the device pointers of |bindings| into the kernel argument slots
// for descriptor set |set|.
//
// Convention with the compiler side: bindings map to kernel arguments. The
// bindings are compacted into a dense run of arguments ordered by binding
// index, starting at the base argument index that the executable layout
// reports for |set|.
//
// Returns OUT_OF_RANGE (before touching any state) when |binding_count|
// exceeds IREE_HAL_ROCM_MAX_BINDING_COUNT or when the bindings would not fit
// in the IREE_HAL_ROCM_MAX_KERNEL_ARG argument slots.
static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_layout_t* executable_layout, uint32_t set,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings) {
  iree_hal_rocm_direct_command_buffer_t* command_buffer =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);

  // Validate up front: the scratch array below is fixed-size and an assert
  // would vanish in release builds.
  if (binding_count > IREE_HAL_ROCM_MAX_BINDING_COUNT) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "binding count %zu larger than the max expected (%zu)", binding_count,
        (iree_host_size_t)IREE_HAL_ROCM_MAX_BINDING_COUNT);
  }
  iree_host_size_t base_binding =
      iree_hal_rocm_base_binding_index(executable_layout, set);
  if (base_binding > IREE_HAL_ROCM_MAX_KERNEL_ARG ||
      binding_count > IREE_HAL_ROCM_MAX_KERNEL_ARG - base_binding) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "bindings (base %zu, count %zu) exceed the %zu kernel argument slots",
        base_binding, binding_count,
        (iree_host_size_t)IREE_HAL_ROCM_MAX_KERNEL_ARG);
  }

  // Sort the bindings by binding index while remembering where each one came
  // from in |bindings|; the sorted position becomes the argument index.
  iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT];
  for (iree_host_size_t i = 0; i < binding_count; i++) {
    binding_used[i].index = (uint32_t)i;
    binding_used[i].binding = bindings[i].binding;
  }
  qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t),
        compare_binding_index);

  for (iree_host_size_t i = 0; i < binding_count; i++) {
    iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
    // Device address = allocation base + buffer's offset within the
    // allocation + offset requested by the binding.
    hipDeviceptr_t device_ptr =
        iree_hal_rocm_buffer_device_pointer(
            iree_hal_buffer_allocated_buffer(binding.buffer)) +
        iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
    *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) =
        device_ptr;
  }
  return iree_ok_status();
}
// Descriptor set binding hook. Not supported by this implementation yet:
// always returns an UNIMPLEMENTED status.
static iree_status_t iree_hal_rocm_direct_command_buffer_bind_descriptor_set(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_layout_t* executable_layout, uint32_t set,
    iree_hal_descriptor_set_t* descriptor_set,
    iree_host_size_t dynamic_offset_count,
    const iree_device_size_t* dynamic_offsets) {
  (void)base_command_buffer;
  (void)executable_layout;
  (void)set;
  (void)descriptor_set;
  (void)dynamic_offset_count;
  (void)dynamic_offsets;
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "need rocm implementation");
}
// Launches |entry_point| of |executable| with the given workgroup counts.
//
// The push constants recorded so far are first copied into their kernel
// argument slots (starting at the push constant index reported by the entry
// point's layout). The kernel is then launched through hipModuleLaunchKernel
// with the executable's block size, no dynamic shared memory, and the NULL
// stream. Returns the failing status if the block size query or the launch
// fails, OK otherwise.
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
  iree_hal_rocm_direct_command_buffer_t* command_buffer =
      iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
  iree_hal_executable_layout_t* layout =
      iree_hal_rocm_executable_get_layout(executable, entry_point);
  iree_host_size_t num_constants =
      iree_hal_rocm_executable_layout_num_constants(layout);
  iree_host_size_t constant_base_index =
      iree_hal_rocm_push_constant_index(layout);
  // Patch the push constants in the kernel arguments.
  for (iree_host_size_t i = 0; i < num_constants; i++) {
    *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
        command_buffer->push_constant[i];
  }
  int32_t block_size_x, block_size_y, block_size_z;
  IREE_RETURN_IF_ERROR(iree_hal_rocm_native_executable_block_size(
      executable, entry_point, &block_size_x, &block_size_y, &block_size_z));
  hipFunction_t func =
      iree_hal_rocm_native_executable_for_entry_point(executable, entry_point);
  // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to
  // access proper stream from command buffer
  ROCM_RETURN_IF_ERROR(
      command_buffer->context->syms,
      hipModuleLaunchKernel(func, workgroup_x, workgroup_y, workgroup_z,
                            block_size_x, block_size_y, block_size_z, 0, 0,
                            command_buffer->current_descriptor, NULL),
      "hipModuleLaunchKernel");
  return iree_ok_status();
}
// Indirect dispatch hook. Not supported by this implementation yet: always
// returns an UNIMPLEMENTED status.
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    iree_hal_buffer_t* workgroups_buffer,
    iree_device_size_t workgroups_offset) {
  (void)base_command_buffer;
  (void)executable;
  (void)entry_point;
  (void)workgroups_buffer;
  (void)workgroups_offset;
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "need rocm implementation");
}
// Dispatch table wiring the HAL command buffer interface to the ROCm direct
// implementations defined above. Its address also serves as the type tag
// checked by the cast/dyn_cast helpers in this file.
static const iree_hal_command_buffer_vtable_t
iree_hal_rocm_direct_command_buffer_vtable = {
.destroy = iree_hal_rocm_direct_command_buffer_destroy,
.dyn_cast = iree_hal_rocm_direct_command_buffer_dyn_cast,
.begin = iree_hal_rocm_direct_command_buffer_begin,
.end = iree_hal_rocm_direct_command_buffer_end,
.begin_debug_group =
iree_hal_rocm_direct_command_buffer_begin_debug_group,
.end_debug_group = iree_hal_rocm_direct_command_buffer_end_debug_group,
.execution_barrier =
iree_hal_rocm_direct_command_buffer_execution_barrier,
.signal_event = iree_hal_rocm_direct_command_buffer_signal_event,
.reset_event = iree_hal_rocm_direct_command_buffer_reset_event,
.wait_events = iree_hal_rocm_direct_command_buffer_wait_events,
.discard_buffer = iree_hal_rocm_direct_command_buffer_discard_buffer,
.fill_buffer = iree_hal_rocm_direct_command_buffer_fill_buffer,
.update_buffer = iree_hal_rocm_direct_command_buffer_update_buffer,
.copy_buffer = iree_hal_rocm_direct_command_buffer_copy_buffer,
.push_constants = iree_hal_rocm_direct_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_rocm_direct_command_buffer_push_descriptor_set,
.bind_descriptor_set =
iree_hal_rocm_direct_command_buffer_bind_descriptor_set,
.dispatch = iree_hal_rocm_direct_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_rocm_direct_command_buffer_dispatch_indirect,
};