blob: 162c947ea28f083dc572e6d41b384a4fd7969911 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/cuda2/native_executable.h"
#include <stddef.h>
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/pipeline_layout.h"
#include "iree/base/api.h"
// flatcc schemas:
#include "iree/base/internal/flatcc/parsing.h"
#include "iree/schemas/cuda_executable_def_reader.h"
#include "iree/schemas/cuda_executable_def_verifier.h"
typedef struct iree_hal_cuda2_native_executable_t {
// Abstract resource used for injecting reference counting and vtable;
// must be at offset 0.
iree_hal_resource_t resource;
iree_allocator_t host_allocator;
const iree_hal_cuda2_dynamic_symbols_t* symbols;
// The loaded CUDA module.
CUmodule cu_module;
iree_host_size_t entry_point_count;
// The list of entry point data pointers, pointing to trailing inline
// allocation after the end of this struct.
iree_hal_cuda2_kernel_params_t entry_points[];
} iree_hal_cuda2_native_executable_t;
// + Additional inline allocation for holding entry point information.
static const iree_hal_executable_vtable_t
iree_hal_cuda2_native_executable_vtable;
static iree_hal_cuda2_native_executable_t*
iree_hal_cuda2_native_executable_cast(iree_hal_executable_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_native_executable_vtable);
return (iree_hal_cuda2_native_executable_t*)base_value;
}
// Verifies the structure of the flatbuffer so that we can avoid doing so during
// runtime.
//
// There are still some conditions we must be aware of (such as omitted names on
// functions with internal linkage), however we shouldn't need to bounds check
// anything within the flatbuffer after this succeeds.
static iree_status_t iree_hal_cuda2_native_executable_flatbuffer_verify(
iree_const_byte_span_t flatbuffer_data) {
if (!flatbuffer_data.data) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"flatbuffer data is not present");
}
// Run flatcc generated verification. This ensures all pointers are in-bounds
// and that we can safely walk the file, but not that the actual contents of
// the flatbuffer meet our expectations.
int verify_ret = iree_hal_cuda_ExecutableDef_verify_as_root(
flatbuffer_data.data, flatbuffer_data.data_length);
if (verify_ret != flatcc_verify_ok) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"flatbuffer verification failed: %s",
flatcc_verify_error_string(verify_ret));
}
iree_hal_cuda_ExecutableDef_table_t executable_def =
iree_hal_cuda_ExecutableDef_as_root(flatbuffer_data.data);
flatbuffers_string_vec_t entry_points_vec =
iree_hal_cuda_ExecutableDef_entry_points_get(executable_def);
size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
for (size_t i = 0; i < entry_point_count; ++i) {
if (flatbuffers_string_len(
flatbuffers_string_vec_at(entry_points_vec, i)) == 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"executable entry point %zu has no name", i);
}
}
iree_hal_cuda_BlockSizeDef_vec_t block_sizes_vec =
iree_hal_cuda_ExecutableDef_block_sizes_get(executable_def);
size_t block_size_count = iree_hal_cuda_BlockSizeDef_vec_len(block_sizes_vec);
if (block_size_count == 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"no block sizes present");
}
if (entry_point_count != block_size_count) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"entry points (%zu) and block sizes (%zu) count mismatch",
entry_point_count, block_size_count);
}
flatbuffers_string_t ptx_image =
iree_hal_cuda_ExecutableDef_ptx_image_get(executable_def);
if (flatbuffers_string_len(ptx_image) == 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"no PTX image present");
}
return iree_ok_status();
}
iree_status_t iree_hal_cuda2_native_executable_create(
const iree_hal_cuda2_dynamic_symbols_t* symbols, CUdevice device,
const iree_hal_executable_params_t* executable_params,
iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
IREE_ASSERT_ARGUMENT(executable_params);
IREE_ASSERT_ARGUMENT(out_executable);
IREE_TRACE_ZONE_BEGIN(z0);
*out_executable = NULL;
iree_hal_cuda2_native_executable_t* executable = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_native_executable_flatbuffer_verify(
executable_params->executable_data));
iree_hal_cuda_ExecutableDef_table_t executable_def =
iree_hal_cuda_ExecutableDef_as_root(
executable_params->executable_data.data);
flatbuffers_string_t ptx_image =
iree_hal_cuda_ExecutableDef_ptx_image_get(executable_def);
flatbuffers_uint32_vec_t shared_memory_sizes =
iree_hal_cuda_ExecutableDef_shared_memory_size_get(executable_def);
flatbuffers_string_vec_t entry_points_vec =
iree_hal_cuda_ExecutableDef_entry_points_get(executable_def);
iree_hal_cuda_BlockSizeDef_vec_t block_sizes_vec =
iree_hal_cuda_ExecutableDef_block_sizes_get(executable_def);
iree_host_size_t entry_point_count =
flatbuffers_string_vec_len(entry_points_vec);
// Calculate the total number of characters across all entry point names. This
// is only required when tracing so that we can store copies of the names as
// the flatbuffer storing the strings may be released while the executable is
// still live.
iree_host_size_t total_entry_point_name_chars = 0;
IREE_TRACE({
for (iree_host_size_t i = 0; i < entry_point_count; i++) {
const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
total_entry_point_name_chars += flatbuffers_string_len(entry_name);
}
});
// Allocate storage for the kernel module.
iree_host_size_t total_size =
sizeof(*executable) +
entry_point_count * sizeof(executable->entry_points[0]) +
total_entry_point_name_chars;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_allocator_malloc(host_allocator, total_size, (void**)&executable));
IREE_TRACE(
char* string_table_buffer =
(char*)((char*)executable + sizeof(*executable) +
entry_point_count * sizeof(executable->entry_points[0])));
// Load the PTX image - this will fail if the device cannot handle the
// contents. We could check this prior to creating
CUmodule module = NULL;
iree_status_t status = IREE_CURESULT_TO_STATUS(
symbols, cuModuleLoadDataEx(&module, ptx_image, 0, NULL, NULL),
"cuModuleLoadDataEx");
// Query max optin shared memory per block - we'll use it to compare with
// kernel usages.
int32_t max_shared_memory = 0;
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(
symbols,
cuDeviceGetAttribute(
&max_shared_memory,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device),
"cuDeviceGetAttribute");
}
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_cuda2_native_executable_vtable,
&executable->resource);
executable->host_allocator = host_allocator;
executable->symbols = symbols;
executable->cu_module = module;
executable->entry_point_count = entry_point_count;
for (iree_host_size_t i = 0; i < entry_point_count; i++) {
// Lookup the function in the module; this should always succeed but we
// cannot trust that the input was generated by our compiler.
CUfunction function = NULL;
const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
status = IREE_CURESULT_TO_STATUS(
symbols,
cuModuleGetFunction(&function, executable->cu_module, entry_name),
"cuModuleGetFunction");
if (!iree_status_is_ok(status)) break;
if (!function) {
status = iree_make_status(IREE_STATUS_NOT_FOUND,
"exported module function '%s' not found",
entry_name);
break;
}
if (shared_memory_sizes[i] > max_shared_memory) {
status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"requested shared memory size of %d bytes "
"larger than allowed size of %d bytes",
shared_memory_sizes[i], max_shared_memory);
} else {
status = IREE_CURESULT_TO_STATUS(
symbols,
cuFuncSetAttribute(function,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_sizes[i]),
"cuFuncSetAttribute");
}
if (!iree_status_is_ok(status)) break;
// Package required parameters for kernel launches for each entry point.
iree_hal_cuda2_kernel_params_t* params = &executable->entry_points[i];
params->layout = executable_params->pipeline_layouts[i];
iree_hal_pipeline_layout_retain(params->layout);
params->function = function;
params->block_size[0] = block_sizes_vec[i].x;
params->block_size[1] = block_sizes_vec[i].y;
params->block_size[2] = block_sizes_vec[i].z;
params->shared_memory_size = shared_memory_sizes[i];
// Stash the entry point name in the string table for use when tracing.
IREE_TRACE({
iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name);
memcpy(string_table_buffer, entry_name, entry_name_length);
params->function_name =
iree_make_string_view(string_table_buffer, entry_name_length);
string_table_buffer += entry_name_length;
});
IREE_TRACE({
if (iree_hal_cuda_ExecutableDef_source_locations_is_present(
executable_def)) {
iree_hal_cuda_FileLineLocDef_vec_t source_locs_vec =
iree_hal_cuda_ExecutableDef_source_locations_get(executable_def);
iree_hal_cuda_FileLineLocDef_table_t source_loc =
iree_hal_cuda_FileLineLocDef_vec_at(source_locs_vec, i);
flatbuffers_string_t filename =
iree_hal_cuda_FileLineLocDef_filename_get(source_loc);
uint32_t line = iree_hal_cuda_FileLineLocDef_line_get(source_loc);
params->source_filename =
iree_make_string_view(filename, flatbuffers_string_len(filename));
params->source_line = line;
}
});
}
}
if (iree_status_is_ok(status)) {
*out_executable = (iree_hal_executable_t*)executable;
} else {
iree_hal_executable_destroy((iree_hal_executable_t*)executable);
}
IREE_TRACE_ZONE_END(z0);
return status;
}
static void iree_hal_cuda2_native_executable_destroy(
iree_hal_executable_t* base_executable) {
iree_hal_cuda2_native_executable_t* executable =
iree_hal_cuda2_native_executable_cast(base_executable);
iree_allocator_t host_allocator = executable->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) {
iree_hal_pipeline_layout_release(executable->entry_points[i].layout);
}
if (executable->cu_module) {
IREE_CUDA_IGNORE_ERROR(executable->symbols,
cuModuleUnload(executable->cu_module));
}
iree_allocator_free(host_allocator, executable);
IREE_TRACE_ZONE_END(z0);
}
iree_status_t iree_hal_cuda2_native_executable_entry_point_kernel_params(
iree_hal_executable_t* base_executable, int32_t entry_point,
iree_hal_cuda2_kernel_params_t* out_params) {
iree_hal_cuda2_native_executable_t* executable =
iree_hal_cuda2_native_executable_cast(base_executable);
if (entry_point >= executable->entry_point_count) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"entry point ordinal %d out of range; executable "
"only contains %" PRIhsz " entry points",
entry_point, executable->entry_point_count);
}
memcpy(out_params, &executable->entry_points[entry_point],
sizeof(*out_params));
return iree_ok_status();
}
static const iree_hal_executable_vtable_t
iree_hal_cuda2_native_executable_vtable = {
.destroy = iree_hal_cuda2_native_executable_destroy,
};