// blob: 35d33f5b6d8008b9dfb32752ff6250fddc95b80b [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/rocm/native_executable.h"
#include <stddef.h>
#include "experimental/rocm/dynamic_symbols.h"
#include "experimental/rocm/pipeline_layout.h"
#include "experimental/rocm/status_util.h"
#include "iree/base/api.h"
// flatcc schemas:
#include "iree/base/internal/flatcc/parsing.h"
#include "iree/schemas/rocm_executable_def_reader.h"
#include "iree/schemas/rocm_executable_def_verifier.h"
// Per-entry-point state stored in the trailing array of the executable.
typedef struct iree_hal_rocm_entry_point_t {
// Launch parameters (function handle, pipeline layout, block size and shared
// memory size) filled in when the executable is created.
iree_hal_rocm_kernel_params_t kernel_params;
// Entry point name. Built as a view over the name string read from the
// executable flatbuffer.
// NOTE(review): presumably not a copy, so it would only stay valid while the
// flatbuffer data does - confirm.
iree_string_view_t name;
// Optional source location tables read from the flatbuffer. These members are
// wrapped in IREE_TRACE and are only accessed from tracing-enabled code paths.
IREE_TRACE(iree_hal_rocm_FileLineLocDef_table_t source_location;)
IREE_TRACE(iree_hal_rocm_StageLocationDef_vec_t stage_locations;)
} iree_hal_rocm_entry_point_t;
// A loaded HIP module plus the per-entry-point information needed to launch
// its kernels. Allocated as a single block with |entry_points| trailing it.
typedef struct iree_hal_rocm_native_executable_t {
// HAL resource header; initialized with the executable vtable on creation.
iree_hal_resource_t resource;
// Context wrapper supplying the HIP symbol table, device and host allocator.
iree_hal_rocm_context_wrapper_t* context;
// NOTE(review): nothing visible in this file assigns this field; the layouts
// retained at creation are stored in entry_points[i].kernel_params.layout.
// It looks vestigial - confirm before relying on it.
iree_hal_pipeline_layout_t** pipeline_layouts;
// Number of entry points; written by create only after the module has loaded.
// NOTE(review): duplicates |entry_point_count| - consider merging the two.
iree_host_size_t entry_count;
// Module handle produced by hipModuleLoadDataEx and unloaded on destroy.
hipModule_t module;
// Number of elements in |entry_points|; written right after allocation.
iree_host_size_t entry_point_count;
// Flexible array with one element per entry point.
iree_hal_rocm_entry_point_t entry_points[];
} iree_hal_rocm_native_executable_t;
// Forward declaration of the vtable so that the cast helper and the create
// function can reference it; the definition is at the bottom of the file.
static const iree_hal_executable_vtable_t
iree_hal_rocm_native_executable_vtable;
// Downcasts a generic HAL executable to the ROCm implementation after
// asserting that it carries this file's vtable.
static iree_hal_rocm_native_executable_t* iree_hal_rocm_native_executable_cast(
    iree_hal_executable_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_native_executable_vtable);
  iree_hal_rocm_native_executable_t* rocm_executable =
      (iree_hal_rocm_native_executable_t*)base_value;
  return rocm_executable;
}
iree_status_t iree_hal_rocm_native_executable_create(
iree_hal_rocm_context_wrapper_t* context,
const iree_hal_executable_params_t* executable_params,
iree_hal_executable_t** out_executable) {
IREE_ASSERT_ARGUMENT(context);
IREE_ASSERT_ARGUMENT(executable_params);
IREE_ASSERT_ARGUMENT(out_executable);
*out_executable = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_rocm_native_executable_t* executable = NULL;
// TODO: Verify the flat buffer.
iree_hal_rocm_ExecutableDef_table_t executable_def =
iree_hal_rocm_ExecutableDef_as_root(
executable_params->executable_data.data);
// Create the kernel module.
flatbuffers_string_t hsaco_image =
iree_hal_rocm_ExecutableDef_hsaco_image_get(executable_def);
flatbuffers_string_vec_t entry_points_vec =
iree_hal_rocm_ExecutableDef_entry_points_get(executable_def);
iree_hal_rocm_BlockSizeDef_vec_t block_sizes_vec =
iree_hal_rocm_ExecutableDef_block_sizes_get(executable_def);
flatbuffers_uint32_vec_t shared_memory_sizes =
iree_hal_rocm_ExecutableDef_shared_memory_sizes_get(executable_def);
iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec);
iree_host_size_t total_size =
sizeof(*executable) + entry_count * sizeof(executable->entry_points[0]);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(context->host_allocator, total_size,
(void**)&executable));
iree_hal_resource_initialize(&iree_hal_rocm_native_executable_vtable,
&executable->resource);
executable->context = context;
executable->entry_point_count = entry_count;
iree_status_t status = ROCM_RESULT_TO_STATUS(
context->syms,
hipModuleLoadDataEx(&executable->module, hsaco_image, 0, NULL, NULL),
"hipModuleLoadDataEx");
if (!iree_status_is_ok(status)) {
status = iree_status_annotate(
status,
IREE_SV("mismatched target chip? missing/wrong bitcode directory?"));
}
// Query allowed max shared memory.
int32_t max_shared_mem = 0;
if (iree_status_is_ok(status)) {
status = ROCM_RESULT_TO_STATUS(
context->syms,
hipDeviceGetAttribute(&max_shared_mem,
hipDeviceAttributeMaxSharedMemoryPerBlock,
context->rocm_device),
"hipDeviceGetAttribute");
}
if (iree_status_is_ok(status)) {
executable->entry_count = entry_count;
for (iree_host_size_t i = 0; i < entry_count; i++) {
if (iree_status_is_ok(status)) {
hipFunction_t function = NULL;
flatbuffers_string_t entry_name =
flatbuffers_string_vec_at(entry_points_vec, i);
status = ROCM_RESULT_TO_STATUS(
context->syms,
hipModuleGetFunction(&function, executable->module, entry_name),
"hipModuleGetFunction");
if (!iree_status_is_ok(status)) break;
if (!function) {
status = iree_make_status(IREE_STATUS_NOT_FOUND,
"exported module function %s not found",
entry_name);
break;
}
if (shared_memory_sizes[i] > max_shared_mem) {
status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"function '%s' requested shared memory "
"size of %d larger than allowed size of %d",
entry_name, shared_memory_sizes[i],
max_shared_mem);
} else if (shared_memory_sizes[i] != 0) {
status = ROCM_RESULT_TO_STATUS(
context->syms,
hipFuncSetAttribute(
function,
(hipFuncAttribute)
HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_sizes[i]),
"hipFuncSetAttribute");
}
// Package required parameters for kernel launches for each entry point.
iree_hal_rocm_entry_point_t* entry_point_info =
&executable->entry_points[i];
iree_hal_rocm_kernel_params_t* params =
&entry_point_info->kernel_params;
params->layout = executable_params->pipeline_layouts[i];
iree_hal_pipeline_layout_retain(params->layout);
params->function = function;
params->block_size[0] = block_sizes_vec[i].x;
params->block_size[1] = block_sizes_vec[i].y;
params->block_size[2] = block_sizes_vec[i].z;
params->shared_memory_size = shared_memory_sizes[i];
entry_point_info->name = iree_make_string_view(
entry_name, flatbuffers_string_len(entry_name));
}
}
#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
if (iree_status_is_ok(status)) {
if (iree_hal_rocm_ExecutableDef_source_locations_is_present(
executable_def)) {
iree_hal_rocm_FileLineLocDef_vec_t source_locations_vec =
iree_hal_rocm_ExecutableDef_source_locations_get(executable_def);
for (iree_host_size_t i = 0; i < entry_count; ++i) {
executable->entry_points[i].source_location =
iree_hal_rocm_FileLineLocDef_vec_at(source_locations_vec, i);
}
}
if (iree_hal_rocm_ExecutableDef_stage_locations_is_present(
executable_def)) {
iree_hal_rocm_StageLocationsDef_vec_t stage_locations_vec =
iree_hal_rocm_ExecutableDef_stage_locations_get(executable_def);
for (iree_host_size_t i = 0; i < entry_count; ++i) {
iree_hal_rocm_StageLocationsDef_table_t stage_locations =
iree_hal_rocm_StageLocationsDef_vec_at(stage_locations_vec, i);
executable->entry_points[i].stage_locations =
iree_hal_rocm_StageLocationsDef_locations_get(stage_locations);
}
}
// Publish any embedded source files to the tracing infrastructure.
if (iree_hal_rocm_ExecutableDef_source_files_is_present(executable_def)) {
iree_hal_rocm_SourceFileDef_vec_t source_files_vec =
iree_hal_rocm_ExecutableDef_source_files_get(executable_def);
for (iree_host_size_t i = 0;
i < iree_hal_rocm_SourceFileDef_vec_len(source_files_vec); ++i) {
iree_hal_rocm_SourceFileDef_table_t source_file =
iree_hal_rocm_SourceFileDef_vec_at(source_files_vec, i);
flatbuffers_string_t path =
iree_hal_rocm_SourceFileDef_path_get(source_file);
flatbuffers_uint8_vec_t content =
iree_hal_rocm_SourceFileDef_content_get(source_file);
IREE_TRACE_PUBLISH_SOURCE_FILE(path, flatbuffers_string_len(path),
content,
flatbuffers_uint8_vec_len(content));
}
}
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
}
if (iree_status_is_ok(status)) {
*out_executable = (iree_hal_executable_t*)executable;
} else {
if (executable) {
iree_hal_executable_destroy((iree_hal_executable_t*)executable);
}
}
IREE_TRACE_ZONE_END(z0);
return status;
}
// Returns the HIP function resolved for |entry_point|, or NULL when the
// ordinal is negative or not below the number of entry points in the
// executable.
hipFunction_t iree_hal_rocm_native_executable_for_entry_point(
    iree_hal_executable_t* base_executable, int32_t entry_point) {
  iree_hal_rocm_native_executable_t* executable =
      iree_hal_rocm_native_executable_cast(base_executable);
  // Guard the array access: an unchecked ordinal would read past the end of
  // the trailing entry point table.
  if (entry_point < 0 ||
      (iree_host_size_t)entry_point >= executable->entry_point_count) {
    return NULL;
  }
  return executable->entry_points[entry_point].kernel_params.function;
}
// Destroys the executable: unloads the HIP module, drops the pipeline layout
// references taken when the executable was created and frees the allocation.
// Failures from hipModuleUnload are printed to stderr and otherwise ignored
// since a destroy callback has no way to report them.
static void iree_hal_rocm_native_executable_destroy(
    iree_hal_executable_t* base_executable) {
  iree_hal_rocm_native_executable_t* executable =
      iree_hal_rocm_native_executable_cast(base_executable);
  iree_allocator_t host_allocator = executable->context->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  if (executable->module) {
    iree_status_t status = ROCM_RESULT_TO_STATUS(
        executable->context->syms, hipModuleUnload(executable->module),
        "hipModuleUnload");
    if (!iree_status_is_ok(status)) {
      fprintf(stderr, "Failed unloading ROCm module: ");
      iree_status_fprint(stderr, status);
      iree_status_free(status);
    }
  }

  // Creation retains one pipeline layout per populated entry point and stores
  // it in the entry's kernel params, so that is where the reference must be
  // dropped. Entries that were never populated (creation failed part-way) are
  // expected to hold a NULL layout and are skipped.
  // NOTE(review): this relies on the allocation being zero-initialized, the
  // same assumption the |module| and |pipeline_layouts| checks here make.
  for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) {
    if (executable->entry_points[i].kernel_params.layout) {
      iree_hal_pipeline_layout_release(
          executable->entry_points[i].kernel_params.layout);
    }
  }

  // NOTE(review): nothing visible in this file assigns |pipeline_layouts|, so
  // this loop is not expected to run; kept for anything that may populate it.
  if (executable->pipeline_layouts) {
    for (iree_host_size_t i = 0; i < executable->entry_count; ++i) {
      if (executable->pipeline_layouts[i]) {
        iree_hal_pipeline_layout_release(executable->pipeline_layouts[i]);
      }
    }
  }

  iree_allocator_free(host_allocator, executable);
  IREE_TRACE_ZONE_END(z0);
}
// Copies the launch parameters recorded for |entry_point| into |out_params|.
// Returns IREE_STATUS_OUT_OF_RANGE (leaving |out_params| untouched) when the
// ordinal is negative or not below the number of entry points.
iree_status_t iree_hal_rocm_native_executable_entry_point_kernel_params(
    iree_hal_executable_t* base_executable, int32_t entry_point,
    iree_hal_rocm_kernel_params_t* out_params) {
  iree_hal_rocm_native_executable_t* executable =
      iree_hal_rocm_native_executable_cast(base_executable);
  // Negative ordinals are rejected explicitly rather than through the
  // implicit signed-to-unsigned conversion of the comparison.
  if (entry_point < 0 ||
      (iree_host_size_t)entry_point >= executable->entry_point_count) {
    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                            "invalid entry point ordinal %d", entry_point);
  }
  // Copy the kernel params member by name so the result does not depend on
  // where that member sits inside the entry point struct.
  *out_params = executable->entry_points[entry_point].kernel_params;
  return iree_ok_status();
}
// Fills |out_source_location| with the function name, file name and line for
// the entry point at |entry_ordinal|.
//
// The output is zeroed first and left that way when the ordinal is out of
// range. When no file/line information is available (tracing instrumentation
// disabled or no location recorded) the file name falls back to the function
// name and the line to 0.
void iree_hal_rocm_native_executable_entry_point_source_location(
    iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal,
    iree_hal_rocm_source_location_t* out_source_location) {
  iree_hal_rocm_native_executable_t* executable =
      iree_hal_rocm_native_executable_cast(base_executable);
  memset(out_source_location, 0, sizeof(*out_source_location));
  if (entry_ordinal >= executable->entry_point_count) return;

  const iree_hal_rocm_entry_point_t* entry =
      &executable->entry_points[entry_ordinal];

  // Start from the fallback; the tracing path below overrides it when a real
  // location is found.
  out_source_location->func_name = entry->name;
  out_source_location->file_name = entry->name;
  out_source_location->line = 0;

#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
  iree_hal_rocm_FileLineLocDef_table_t chosen_location =
      entry->source_location;
  // TODO(benvanik): a way to select what location is chosen. For now we
  // just pick the first one.
  if (entry->stage_locations &&
      iree_hal_rocm_StageLocationDef_vec_len(entry->stage_locations) > 0) {
    iree_hal_rocm_StageLocationDef_table_t first_stage =
        iree_hal_rocm_StageLocationDef_vec_at(entry->stage_locations, 0);
    chosen_location = iree_hal_rocm_StageLocationDef_location_get(first_stage);
  }
  if (chosen_location) {
    flatbuffers_string_t filename =
        iree_hal_rocm_FileLineLocDef_filename_get(chosen_location);
    out_source_location->file_name =
        iree_make_string_view(filename, flatbuffers_string_len(filename));
    out_source_location->line =
        iree_hal_rocm_FileLineLocDef_line_get(chosen_location);
  }
#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
}
// Vtable binding the generic HAL executable interface to this implementation;
// only the destroy callback is set here.
static const iree_hal_executable_vtable_t
iree_hal_rocm_native_executable_vtable = {
.destroy = iree_hal_rocm_native_executable_destroy,
};