// blob: 18b3d378ed9a325d445cc51e64c2c397561dd291 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/drivers/hip/native_executable.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "iree/base/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/utils/executable_debug_info.h"

// flatcc schemas:
#include "iree/base/internal/flatcc/parsing.h"
#include "iree/schemas/executable_debug_info_reader.h"
#include "iree/schemas/executable_debug_info_verifier.h"
#include "iree/schemas/hip_executable_def_reader.h"
#include "iree/schemas/hip_executable_def_verifier.h"
// HIP implementation of a HAL executable.
//
// Instances are created by iree_hal_hip_native_executable_create as a single
// host allocation. The |exports| flexible array immediately follows this
// struct, the |modules| array follows the exports, and (when tracing is
// enabled) copies of per-export debug info follow the modules. Because of this
// layout |exports| must remain the last member.
typedef struct iree_hal_hip_native_executable_t {
  // Abstract resource used for injecting reference counting and vtable;
  // must be at offset 0.
  iree_hal_resource_t resource;
  // Allocator the executable storage was allocated from; used to free it.
  iree_allocator_t host_allocator;
  // HIP symbol table used for all HIP calls (module unload on destroy).
  // Not freed by the executable; NOTE(review): presumably must outlive it.
  const iree_hal_hip_dynamic_symbols_t* symbols;
  // Loaded HIP modules.
  iree_host_size_t module_count;
  // Points into the same allocation as this struct (after |exports|).
  hipModule_t* modules;
  // Exported kernels referencing the loaded modules.
  iree_host_size_t export_count;
  iree_hal_hip_kernel_params_t exports[];
} iree_hal_hip_native_executable_t;
// Declared here so the cast helper below can reference the vtable; the table
// itself is defined at the end of this file.
static const iree_hal_executable_vtable_t iree_hal_hip_native_executable_vtable;

// Downcasts a generic HAL executable to the HIP native executable type.
// IREE_HAL_ASSERT_TYPE is given this file's vtable so that a mismatched
// resource can be caught (NOTE(review): presumably only in checked builds).
static iree_hal_hip_native_executable_t* iree_hal_hip_native_executable_cast(
    iree_hal_executable_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hip_native_executable_vtable);
  iree_hal_hip_native_executable_t* const native_executable =
      (iree_hal_hip_native_executable_t*)base_value;
  return native_executable;
}
// Device limits used to validate executable exports before they are loaded.
// Populated by iree_hal_hip_query_limits via hipDeviceGetAttribute.
typedef struct iree_hal_hip_limits_t {
  // Maximum block dimensions, in X, Y, Z order.
  uint32_t max_block_dims[3];
  // Maximum shared memory usable by a single block, in bytes.
  uint32_t max_block_shared_memory_size;
} iree_hal_hip_limits_t;
// Queries the block-dimension and per-block shared memory limits of |device|.
//
// |out_limits| is zeroed first and only filled in once every query has
// succeeded; on failure the HIP error is returned and |out_limits| stays zero.
static iree_status_t iree_hal_hip_query_limits(
    const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device,
    iree_hal_hip_limits_t* out_limits) {
  memset(out_limits, 0, sizeof(*out_limits));

  // hipDeviceGetAttribute writes through an int*. Query into signed locals
  // instead of passing pointers to the uint32_t fields, then convert below.
  int max_block_dim_x = 0;
  int max_block_dim_y = 0;
  int max_block_dim_z = 0;
  int max_shared_memory_per_block = 0;
  IREE_HIP_RETURN_IF_ERROR(
      symbols,
      hipDeviceGetAttribute(&max_block_dim_x, hipDeviceAttributeMaxBlockDimX,
                            device),
      "hipDeviceGetAttribute");
  IREE_HIP_RETURN_IF_ERROR(
      symbols,
      hipDeviceGetAttribute(&max_block_dim_y, hipDeviceAttributeMaxBlockDimY,
                            device),
      "hipDeviceGetAttribute");
  IREE_HIP_RETURN_IF_ERROR(
      symbols,
      hipDeviceGetAttribute(&max_block_dim_z, hipDeviceAttributeMaxBlockDimZ,
                            device),
      "hipDeviceGetAttribute");
  IREE_HIP_RETURN_IF_ERROR(
      symbols,
      hipDeviceGetAttribute(&max_shared_memory_per_block,
                            hipDeviceAttributeMaxSharedMemoryPerBlock, device),
      "hipDeviceGetAttribute");

  // Negative values are not expected from the driver. Clamp them to 0 so that
  // export validation fails closed, rather than wrapping to a huge unsigned
  // limit that would accept anything.
  out_limits->max_block_dims[0] =
      max_block_dim_x > 0 ? (uint32_t)max_block_dim_x : 0u;
  out_limits->max_block_dims[1] =
      max_block_dim_y > 0 ? (uint32_t)max_block_dim_y : 0u;
  out_limits->max_block_dims[2] =
      max_block_dim_z > 0 ? (uint32_t)max_block_dim_z : 0u;
  out_limits->max_block_shared_memory_size =
      max_shared_memory_per_block > 0 ? (uint32_t)max_shared_memory_per_block
                                      : 0u;
  return iree_ok_status();
}
// Verifies the structure of the flatbuffer so that we can avoid doing so during
// runtime.
//
// There are still some conditions we must be aware of (such as omitted names on
// functions with internal linkage), however we shouldn't need to bounds check
// anything within the flatbuffer after this succeeds.
//
// Beyond the flatcc structural verification this checks that:
//   * every module is present and has a non-empty HSACO image;
//   * every export is present, references an in-range module, has a non-empty
//     kernel name and block dims, and fits within |limits| (block dims and
//     per-block shared memory);
//   * constant and binding counts do not exceed the HIP HAL maximums;
//   * per-export debug info passes iree_hal_debug_verify_export_def.
//
// The checks performed here return IREE_STATUS_INVALID_ARGUMENT. Errors from
// the debug info verifier are propagated unchanged.
static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify(
    iree_const_byte_span_t flatbuffer_data,
    const iree_hal_hip_limits_t* limits) {
  if (!flatbuffer_data.data) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "flatbuffer data is not present");
  }

  // Run flatcc generated verification. This ensures all pointers are in-bounds
  // and that we can safely walk the file, but not that the actual contents of
  // the flatbuffer meet our expectations.
  int verify_ret = iree_hal_hip_ExecutableDef_verify_as_root(
      flatbuffer_data.data, flatbuffer_data.data_length);
  if (verify_ret != flatcc_verify_ok) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "flatbuffer verification failed: %s",
                            flatcc_verify_error_string(verify_ret));
  }

  iree_hal_hip_ExecutableDef_table_t executable_def =
      iree_hal_hip_ExecutableDef_as_root(flatbuffer_data.data);

  iree_hal_hip_ModuleDef_vec_t modules_vec =
      iree_hal_hip_ExecutableDef_modules_get(executable_def);
  iree_host_size_t module_count = iree_hal_hip_ModuleDef_vec_len(modules_vec);
  for (iree_host_size_t i = 0; i < module_count; ++i) {
    iree_hal_hip_ModuleDef_table_t module_def =
        iree_hal_hip_ModuleDef_vec_at(modules_vec, i);
    if (!module_def) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "modules[%" PRIhsz "] is NULL", i);
    }
    if (flatbuffers_string_len(
            iree_hal_hip_ModuleDef_hsaco_image_get(module_def)) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "modules[%" PRIhsz "] contents are empty", i);
    }
  }

  iree_hal_hip_ExportDef_vec_t exports_vec =
      iree_hal_hip_ExecutableDef_exports_get(executable_def);
  iree_host_size_t export_count = iree_hal_hip_ExportDef_vec_len(exports_vec);
  for (iree_host_size_t i = 0; i < export_count; ++i) {
    iree_hal_hip_ExportDef_table_t export_def =
        iree_hal_hip_ExportDef_vec_at(exports_vec, i);
    // Reject (rather than skip) NULL exports, matching the module check above.
    // Executable creation walks every export unconditionally and would
    // otherwise dereference a NULL table.
    if (!export_def) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "exports[%" PRIhsz "] is NULL", i);
    }

    uint32_t module_ordinal =
        iree_hal_hip_ExportDef_module_ordinal_get(export_def);
    if (module_ordinal >= module_count) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "exports[%" PRIhsz
                              "] module_ordinal %u is out of bounds %" PRIhsz,
                              i, module_ordinal, module_count);
    }

    if (flatbuffers_string_len(
            iree_hal_hip_ExportDef_kernel_name_get(export_def)) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "exports[%" PRIhsz "] name is empty", i);
    }

    if (!iree_hal_hip_ExportDef_block_dims_is_present(export_def)) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "exports[%" PRIhsz "] block dims are missing",
                              i);
    }
    const iree_hal_hip_BlockDims_t* block_dims =
        iree_hal_hip_ExportDef_block_dims_get(export_def);
    if (block_dims->x > limits->max_block_dims[0] ||
        block_dims->y > limits->max_block_dims[1] ||
        block_dims->z > limits->max_block_dims[2]) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "exports[%" PRIhsz
          "] block dims %ux%ux%u exceeds device maximum %ux%ux%u",
          i, block_dims->x, block_dims->y, block_dims->z,
          limits->max_block_dims[0], limits->max_block_dims[1],
          limits->max_block_dims[2]);
    }

    uint32_t block_shared_memory_size =
        iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def);
    if (block_shared_memory_size > limits->max_block_shared_memory_size) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "exports[%" PRIhsz
                              "] requires %uB of shared memory and "
                              "exceeds the device maximum of %uB per block",
                              i, block_shared_memory_size,
                              limits->max_block_shared_memory_size);
    }

    uint32_t constant_count =
        iree_hal_hip_ExportDef_constant_count_get(export_def);
    if (constant_count > IREE_HAL_HIP_MAX_DISPATCH_CONSTANT_COUNT) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "exports[%" PRIhsz "] constant_count %u exceeds maximum of %u", i,
          constant_count, IREE_HAL_HIP_MAX_DISPATCH_CONSTANT_COUNT);
    }

    iree_hal_hip_BindingBits_vec_t binding_flags_vec =
        iree_hal_hip_ExportDef_binding_flags_get(export_def);
    if (iree_hal_hip_BindingBits_vec_len(binding_flags_vec) >
        IREE_HAL_HIP_MAX_DISPATCH_BINDING_COUNT) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "exports[%" PRIhsz "] binding_flags count %zu exceeds maximum of %u",
          i, iree_hal_hip_BindingBits_vec_len(binding_flags_vec),
          IREE_HAL_HIP_MAX_DISPATCH_BINDING_COUNT);
    }

    IREE_RETURN_IF_ERROR(iree_hal_debug_verify_export_def(
        iree_hal_hip_ExportDef_debug_info_get(export_def)));
  }

  return iree_ok_status();
}
// Creates a HIP native executable from the flatbuffer-encoded
// iree_hal_hip_ExecutableDef in |executable_params->executable_data|.
//
// Steps:
//   1. Verify the flatbuffer against the limits queried from |device|.
//   2. Load every HSACO module with hipModuleLoadDataEx.
//   3. Resolve each export's kernel with hipModuleGetFunction.
//   4. Copy the launch parameters (block dims, shared memory size,
//      constant/binding counts) into the executable.
//
// The executable, its exports, its module handles and (when tracing) copies of
// the export debug info share one host allocation, laid out as:
//   [iree_hal_hip_native_executable_t][exports...][modules...][debug info...]
//
// On success |out_executable| receives the new executable. On failure the
// partially built executable is destroyed (unloading any modules already
// loaded) and |out_executable| remains NULL.
iree_status_t iree_hal_hip_native_executable_create(
    const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device,
    const iree_hal_executable_params_t* executable_params,
    iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
  IREE_ASSERT_ARGUMENT(symbols);
  IREE_ASSERT_ARGUMENT(executable_params);
  IREE_ASSERT_ARGUMENT(out_executable);
  IREE_TRACE_ZONE_BEGIN(z0);
  *out_executable = NULL;

  // TODO: move to the executable cache to avoid repeated queries.
  iree_hal_hip_limits_t limits = {0};
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_hip_query_limits(symbols, device, &limits));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_hip_native_executable_flatbuffer_verify(
              executable_params->executable_data, &limits));

  // The flatbuffer was verified above, so the accessors below can walk it
  // without additional bounds checks.
  iree_hal_hip_ExecutableDef_table_t executable_def =
      iree_hal_hip_ExecutableDef_as_root(
          executable_params->executable_data.data);
  iree_hal_hip_ModuleDef_vec_t modules_vec =
      iree_hal_hip_ExecutableDef_modules_get(executable_def);
  iree_host_size_t module_count = iree_hal_hip_ModuleDef_vec_len(modules_vec);
  iree_hal_hip_ExportDef_vec_t exports_vec =
      iree_hal_hip_ExecutableDef_exports_get(executable_def);
  iree_host_size_t export_count = iree_hal_hip_ExportDef_vec_len(exports_vec);

  // Calculate the total number of characters across all entry point names. This
  // is only required when tracing so that we can store copies of the names as
  // the flatbuffer storing the strings may be released while the executable is
  // still live.
  iree_host_size_t total_export_info_length = 0;
  IREE_TRACE({
    for (iree_host_size_t i = 0; i < export_count; ++i) {
      iree_hal_hip_ExportDef_table_t export_def =
          iree_hal_hip_ExportDef_vec_at(exports_vec, i);
      total_export_info_length += iree_hal_debug_calculate_export_info_size(
          iree_hal_hip_ExportDef_debug_info_get(export_def));
    }
  });

  // Allocate storage for the executable and its associated data structures.
  iree_hal_hip_native_executable_t* executable = NULL;
  const iree_host_size_t total_size =
      sizeof(*executable) + module_count * sizeof(executable->modules[0]) +
      export_count * sizeof(executable->exports[0]) + total_export_info_length;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_allocator_malloc(host_allocator, total_size, (void**)&executable));
  iree_hal_resource_initialize(&iree_hal_hip_native_executable_vtable,
                               &executable->resource);
  executable->host_allocator = host_allocator;
  executable->symbols = symbols;
  executable->module_count = module_count;
  // Module handles live directly after the trailing |exports| array.
  executable->modules =
      (hipModule_t*)((uint8_t*)executable + sizeof(*executable) +
                     export_count * sizeof(executable->exports[0]));
  executable->export_count = export_count;
  // Copied debug info (tracing only) lives directly after the module handles.
  IREE_TRACE(uint8_t* export_info_ptr =
                 ((uint8_t*)executable->modules +
                  module_count * sizeof(executable->modules[0])));

  // Publish any embedded source files to the tracing infrastructure.
  iree_hal_debug_publish_source_files(
      iree_hal_hip_ExecutableDef_source_files_get(executable_def));

  // Load each module first so that exports can reference them.
  iree_status_t status = iree_ok_status();
  for (iree_host_size_t i = 0; i < module_count; ++i) {
    iree_hal_hip_ModuleDef_table_t module_def =
        iree_hal_hip_ModuleDef_vec_at(modules_vec, i);

    // WARNING: HIP doesn't take an expected length here so we can't bound it.
    // It's likely that users could craft inputs that read beyond the extents of
    // the embedded binary.
    flatbuffers_string_t hsaco_image =
        iree_hal_hip_ModuleDef_hsaco_image_get(module_def);

    // TODO: pass hipJitOption values to get log info and other info back.
    // We pass the error buffer today but could use the info log to diagnose
    // performance warnings.
    char error_log[8192] = {0};
    hipJitOption jit_options[] = {
        hipJitOptionErrorLogBuffer,
        hipJitOptionErrorLogBufferSizeBytes,
    };
    // Integer option values are passed by value in the void* slots; go through
    // uintptr_t so the integer-to-pointer conversion matches pointer width.
    void* jit_option_values[] = {
        (void*)error_log,
        (void*)(uintptr_t)sizeof(error_log),
    };
    hipModule_t module = NULL;
    status = IREE_HIP_RESULT_TO_STATUS(
        symbols,
        hipModuleLoadDataEx(&module, hsaco_image, IREE_ARRAYSIZE(jit_options),
                            jit_options, jit_option_values),
        "hipModuleLoadDataEx");
    if (!iree_status_is_ok(status)) {
      status = iree_status_annotate(
          status,
          IREE_SV("mismatched target chip? missing/wrong bitcode directory?"));
      // Do not rely on the JIT to NUL-terminate a log that fills the buffer.
      error_log[sizeof(error_log) - 1] = '\0';
      if (error_log[0] != '\0') {
        status =
            iree_status_annotate(status, iree_make_cstring_view(error_log));
      }
      break;
    }

    executable->modules[i] = module;
  }

  if (iree_status_is_ok(status)) {
    for (iree_host_size_t i = 0; i < export_count; ++i) {
      iree_hal_hip_ExportDef_table_t export_def =
          iree_hal_hip_ExportDef_vec_at(exports_vec, i);

      // Lookup the function in the module; this should always succeed but
      // we cannot trust that the input was generated by our compiler.
      uint32_t module_ordinal =
          iree_hal_hip_ExportDef_module_ordinal_get(export_def);
      hipModule_t module = executable->modules[module_ordinal];
      flatbuffers_string_t kernel_name =
          iree_hal_hip_ExportDef_kernel_name_get(export_def);
      hipFunction_t function = NULL;
      status = IREE_HIP_RESULT_TO_STATUS(
          symbols, hipModuleGetFunction(&function, module, kernel_name),
          "hipModuleGetFunction");
      if (!iree_status_is_ok(status)) break;
      if (!function) {
        status = iree_make_status(IREE_STATUS_NOT_FOUND,
                                  "exports[%" PRIhsz
                                  "] kernel `%s` not found in modules[%u]",
                                  i, kernel_name, module_ordinal);
        break;
      }

      uint32_t block_shared_memory_size =
          iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def);
      status = IREE_HIP_RESULT_TO_STATUS(
          symbols,
          hipFuncSetAttribute(
              function,
              (hipFuncAttribute)
                  HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
              block_shared_memory_size),
          "hipFuncSetAttribute");
      if (!iree_status_is_ok(status)) break;

      // Package required parameters for kernel launches for each entry point.
      iree_hal_hip_kernel_params_t* kernel_info = &executable->exports[i];
      kernel_info->function = function;
      const iree_hal_hip_BlockDims_t* block_dims =
          iree_hal_hip_ExportDef_block_dims_get(export_def);
      kernel_info->block_dims[0] = block_dims->x;
      kernel_info->block_dims[1] = block_dims->y;
      kernel_info->block_dims[2] = block_dims->z;
      kernel_info->block_shared_memory_size = block_shared_memory_size;
      kernel_info->constant_count =
          iree_hal_hip_ExportDef_constant_count_get(export_def);
      iree_hal_hip_BindingBits_vec_t binding_flags_vec =
          iree_hal_hip_ExportDef_binding_flags_get(export_def);
      kernel_info->binding_count =
          iree_hal_hip_BindingBits_vec_len(binding_flags_vec);

      IREE_TRACE({
        iree_hal_debug_export_info_t* export_info =
            (iree_hal_debug_export_info_t*)export_info_ptr;
        export_info_ptr += iree_hal_debug_copy_export_info(
            iree_hal_hip_ExportDef_debug_info_get(export_def), export_info);
        kernel_info->debug_info.function_name = export_info->function_name;
        kernel_info->debug_info.source_filename = export_info->source_filename;
        kernel_info->debug_info.source_line = export_info->source_line;
      });
    }
  }

  if (iree_status_is_ok(status)) {
    *out_executable = (iree_hal_executable_t*)executable;
  } else {
    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Unloads every non-NULL HIP module held by the executable and then frees the
// executable's host allocation with the allocator it was created with.
// Module unload errors are ignored (best-effort teardown).
static void iree_hal_hip_native_executable_destroy(
    iree_hal_executable_t* base_executable) {
  iree_hal_hip_native_executable_t* executable =
      iree_hal_hip_native_executable_cast(base_executable);
  const iree_allocator_t allocator = executable->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  hipModule_t* const loaded_modules = executable->modules;
  for (iree_host_size_t module_index = 0;
       module_index < executable->module_count; ++module_index) {
    hipModule_t loaded_module = loaded_modules[module_index];
    // NULL slots (presumably modules never loaded during create) are skipped.
    if (!loaded_module) continue;
    IREE_HIP_IGNORE_ERROR(executable->symbols, hipModuleUnload(loaded_module));
  }

  iree_allocator_free(allocator, executable);
  IREE_TRACE_ZONE_END(z0);
}
// Looks up the kernel launch parameters for the export at |ordinal|.
//
// On success |out_params| points into the executable's own storage and remains
// valid while the executable is live. Returns IREE_STATUS_OUT_OF_RANGE if
// |ordinal| is negative or not less than the number of exports; |out_params|
// is set to NULL in that case.
iree_status_t iree_hal_hip_native_executable_lookup_kernel_params(
    iree_hal_executable_t* base_executable, int32_t ordinal,
    const iree_hal_hip_kernel_params_t** out_params) {
  IREE_ASSERT_ARGUMENT(out_params);
  *out_params = NULL;
  iree_hal_hip_native_executable_t* executable =
      iree_hal_hip_native_executable_cast(base_executable);
  // Reject negative ordinals explicitly rather than relying on the implicit
  // signed-to-unsigned conversion in a mixed-sign comparison.
  if (ordinal < 0 || (iree_host_size_t)ordinal >= executable->export_count) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "export ordinal %d out of range; executable contains %" PRIhsz
        " exports",
        ordinal, executable->export_count);
  }
  *out_params = &executable->exports[ordinal];
  return iree_ok_status();
}
// HAL executable vtable for HIP native executables. Destruction is the only
// operation this file implements.
static const iree_hal_executable_vtable_t
    iree_hal_hip_native_executable_vtable = {
        .destroy = &iree_hal_hip_native_executable_destroy,
};