// blob: c775fd8861740e560b36776222571a8ffbf2c0e0 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/modules/hal/loader/module.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/local/local_executable.h"
#include "iree/vm/api.h"
// Version 0.0 of the module interface.
#define IREE_HAL_LOADER_MODULE_VERSION_0_0 0x00000000u
// Alias for the newest version; this is the value reported in the module
// descriptor's |version| field.
#define IREE_HAL_LOADER_MODULE_VERSION_LATEST IREE_HAL_LOADER_MODULE_VERSION_0_0
//===----------------------------------------------------------------------===//
// Module type definitions
//===----------------------------------------------------------------------===//
// Shared (per-module, not per-context) storage. It lives in the same
// allocation as the base native module, iree_vm_native_module_size() bytes
// from the start; use IREE_HAL_LOADER_MODULE_CAST to reach it.
typedef struct iree_hal_loader_module_t {
// Allocator provided to iree_hal_loader_module_create.
iree_allocator_t host_allocator;
// Flags provided at creation; copied into each allocated module state.
iree_hal_loader_module_flags_t flags;
// TODO(benvanik): types.
// Number of valid entries in |loaders|.
iree_host_size_t loader_count;
// Executable loaders, retained at creation and released when the module is
// destroyed. Queried and tried in index order.
iree_hal_executable_loader_t* loaders[];
} iree_hal_loader_module_t;
// Returns the iree_hal_loader_module_t stored iree_vm_native_module_size()
// bytes past the start of the base native module pointed to by |module|.
//
// Expands to a single fully-parenthesized expression with no trailing
// semicolon so that it behaves like a function call wherever an expression is
// valid (initializers, arguments, member access).
#define IREE_HAL_LOADER_MODULE_CAST(module)         \
  ((iree_hal_loader_module_t*)((uint8_t*)(module) + \
                               iree_vm_native_module_size()))
// Per-context module state created by iree_hal_loader_module_alloc_state and
// destroyed by iree_hal_loader_module_free_state.
typedef struct iree_hal_loader_module_state_t {
// Allocator the state was allocated from; used again to free it.
iree_allocator_t host_allocator;
// Copy of the owning module's flags taken when the state was allocated.
iree_hal_loader_module_flags_t flags;
} iree_hal_loader_module_state_t;
// Module destroy callback: drops the reference that
// iree_hal_loader_module_create took on each registered loader, in
// registration order. No memory is freed here.
static void IREE_API_PTR iree_hal_loader_module_destroy(void* base_module) {
  iree_hal_loader_module_t* loader_module =
      IREE_HAL_LOADER_MODULE_CAST(base_module);
  iree_host_size_t index = 0;
  while (index < loader_module->loader_count) {
    iree_hal_executable_loader_release(loader_module->loaders[index]);
    ++index;
  }
}
// Module alloc_state callback: allocates a zeroed
// iree_hal_loader_module_state_t from |host_allocator| and seeds it with the
// allocator and the module's flags.
//
// On success |out_module_state| receives the new state; it is released with
// iree_hal_loader_module_free_state. On allocation failure the error is
// returned and |out_module_state| is left untouched.
static iree_status_t IREE_API_PTR
iree_hal_loader_module_alloc_state(void* self, iree_allocator_t host_allocator,
                                   iree_vm_module_state_t** out_module_state) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_loader_module_t* loader_module = IREE_HAL_LOADER_MODULE_CAST(self);

  iree_hal_loader_module_state_t* new_state = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, sizeof(*new_state),
                                (void**)&new_state));

  // Start from all-zero storage, then fill in the fields we track.
  memset(new_state, 0, sizeof(*new_state));
  new_state->flags = loader_module->flags;
  new_state->host_allocator = host_allocator;

  *out_module_state = (iree_vm_module_state_t*)new_state;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Module free_state callback: returns |module_state| to the allocator that
// was recorded in it by iree_hal_loader_module_alloc_state.
static void IREE_API_PTR iree_hal_loader_module_free_state(
    void* self, iree_vm_module_state_t* module_state) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_loader_module_state_t* loader_state =
      (iree_hal_loader_module_state_t*)module_state;
  // Copy the allocator out first: it is stored inside the block being freed.
  iree_allocator_t state_allocator = loader_state->host_allocator;
  iree_allocator_free(state_allocator, loader_state);

  IREE_TRACE_ZONE_END(z0);
}
// Module notify callback. No signal currently requires any work from this
// module, so every signal (known or unknown) is acknowledged with OK.
static iree_status_t IREE_API_PTR iree_hal_loader_module_notify(
    void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) {
  iree_status_t result = iree_ok_status();
  switch (signal) {
    case IREE_VM_SIGNAL_SUSPEND:
    case IREE_VM_SIGNAL_LOW_MEMORY:
    default:
      // Nothing to do.
      break;
  }
  return result;
}
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Casts a VM value to a C host size.
static iree_host_size_t iree_hal_cast_host_size(int64_t value) {
// TODO(benvanik): make this return status and check for overflow if host
// size is 32-bits.
return (iree_host_size_t)value;
}
// Casts a VM value to a HAL device size.
static iree_device_size_t iree_hal_cast_device_size(int64_t value) {
// TODO(benvanik): make this return status and check for overflow if device
// size is 32-bits.
return (iree_device_size_t)value;
}
//===----------------------------------------------------------------------===//
// Shared argument shims
//===----------------------------------------------------------------------===//
// Declares the exported function |function_name| through IREE_VM_ABI_EXPORT,
// fixing the module state type to iree_hal_loader_module_state_t.
#define IREE_HAL_ABI_EXPORT(function_name, arg_types, ret_types) \
IREE_VM_ABI_EXPORT(function_name, iree_hal_loader_module_state_t, arg_types, \
ret_types)
// Forwards to IREE_VM_ABI_FIXED_STRUCT. NOTE: |types| is accepted but is not
// passed on.
#define IREE_HAL_ABI_FIXED_STRUCT(name, types, body) \
IREE_VM_ABI_FIXED_STRUCT(name, body)
// Defines a shim through IREE_VM_ABI_DEFINE_SHIM with internal (static)
// linkage.
#define IREE_HAL_ABI_DEFINE_SHIM(arg_types, ret_types) \
static IREE_VM_ABI_DEFINE_SHIM(arg_types, ret_types)
//===----------------------------------------------------------------------===//
// iree_hal_executable_t
//===----------------------------------------------------------------------===//
// Exported query: reports whether any registered loader claims support for an
// executable format.
//
// Args:   r0 = buffer holding the executable format string.
// Rets:   i0 = 1 if at least one loader supports the format, otherwise 0.
// Fails if r0 cannot be dereferenced as a VM buffer.
IREE_HAL_ABI_EXPORT(iree_hal_loader_module_executable_query_support,  //
                    r, i) {
  iree_vm_buffer_t* format_buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r0, &format_buffer));
  iree_string_view_t format = iree_vm_buffer_as_string(format_buffer);

  iree_hal_loader_module_t* loader_module = IREE_HAL_LOADER_MODULE_CAST(module);

  // Walk the loaders in registration order and stop at the first match.
  bool is_supported = false;
  for (iree_host_size_t index = 0;
       !is_supported && index < loader_module->loader_count; ++index) {
    if (iree_hal_executable_loader_query_support(
            loader_module->loaders[index], /*caching_mode=*/0, format)) {
      is_supported = true;
    }
  }

  if (is_supported) {
    rets->i0 = 1;
  } else {
    rets->i0 = 0;
  }
  return iree_ok_status();
}
// Attempts to load the executable described by |executable_params| with each
// registered loader, in registration order.
//
// A loader is only tried if it reports support for the requested caching mode
// and format. A try that fails with IREE_STATUS_CANCELLED is treated as "this
// loader cannot handle this particular executable" and the search continues;
// any other outcome (success or a different error) ends the search and is
// returned as-is. Returns IREE_STATUS_NOT_FOUND if no loader accepted the
// executable. On success |out_executable| holds the loaded executable.
static iree_status_t iree_hal_loader_module_try_load(
    iree_hal_loader_module_t* loader_module,
    const iree_hal_executable_params_t* executable_params,
    iree_hal_executable_t** out_executable) {
  for (iree_host_size_t index = 0; index < loader_module->loader_count;
       ++index) {
    iree_hal_executable_loader_t* candidate = loader_module->loaders[index];

    const bool may_support = iree_hal_executable_loader_query_support(
        candidate, executable_params->caching_mode,
        executable_params->executable_format);
    if (may_support) {
      // The loader _may_ handle the executable; only the try can tell.
      iree_status_t load_status = iree_hal_executable_loader_try_load(
          candidate, executable_params, /*worker_capacity=*/1, out_executable);
      if (!iree_status_is_cancelled(load_status)) {
        // Either the executable loaded (OK) or a real error occurred; both
        // outcomes are final.
        return load_status;
      }
      // Cancelled: this loader declined; drop the status and keep looking.
      iree_status_ignore(load_status);
    }
    // Otherwise the loader definitely can't handle the format; skip it.
  }

  return iree_make_status(
      IREE_STATUS_NOT_FOUND,
      "no executable loader registered for the given executable format '%.*s'",
      (int)executable_params->executable_format.size,
      executable_params->executable_format.data);
}
IREE_HAL_ABI_EXPORT(iree_hal_loader_module_executable_load, //
rrr, r) {
iree_vm_buffer_t* executable_format = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_buffer_check_deref(args->r0, &executable_format));
iree_string_view_t executable_format_str =
iree_vm_buffer_as_string(executable_format);
iree_vm_buffer_t* executable_data = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &executable_data));
iree_host_size_t constant_count = 0;
const uint32_t* constants = NULL;
if (iree_vm_buffer_isa(args->r2)) {
iree_vm_buffer_t* constant_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_buffer_check_deref(args->r2, &constant_buffer));
if (constant_buffer->data.data_length % 4 != 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"constant buffer data must contain 4-byte "
"elements but data length is %" PRIhsz,
constant_buffer->data.data_length);
}
constant_count = constant_buffer->data.data_length / sizeof(uint32_t);
constants = (const uint32_t*)constant_buffer->data.data;
}
iree_hal_executable_params_t executable_params;
iree_hal_executable_params_initialize(&executable_params);
executable_params.caching_mode |=
executable_data->access == IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE
? IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA
: 0;
executable_params.executable_format = executable_format_str;
executable_params.executable_data = iree_make_const_byte_span(
executable_data->data.data, executable_data->data.data_length);
executable_params.pipeline_layout_count = 0;
executable_params.pipeline_layouts = NULL;
executable_params.constant_count = constant_count;
executable_params.constants = constants;
iree_hal_executable_t* executable = NULL;
iree_hal_loader_module_t* loader_module = IREE_HAL_LOADER_MODULE_CAST(module);
iree_status_t status = iree_hal_loader_module_try_load(
loader_module, &executable_params, &executable);
rets->r0 = iree_hal_executable_move_ref(executable);
return status;
}
// Unpacked arguments of the variadic dispatch call: a fixed-size prefix
// followed by two counted segments (push constants, then bindings).
// iree_vm_shim_dispatch_v fills this in; the segment pointers point directly
// into the caller's argument storage.
typedef struct {
// Fixed-size leading arguments, accessible either by field name or as the raw
// ABI struct |params|.
union {
struct {
// Reference to the executable to dispatch.
iree_vm_ref_t executable;
// Entry point passed to the inline dispatch call.
int32_t entry_point;
// Workgroup counts along X/Y/Z (forwarded as workgroup_count_x/y/z).
int32_t workgroup_x;
int32_t workgroup_y;
int32_t workgroup_z;
};
iree_vm_abi_riiii_t params;
};
// Number of 32-bit values in |push_constants|.
iree_vm_size_t push_constant_count;
const uint32_t* push_constants;
// Number of entries in |bindings|.
iree_vm_size_t binding_count;
// Each binding is a buffer ref (r0) plus two 64-bit values (i1, i2) that are
// passed as the second and third arguments of iree_vm_buffer_map_ro
// (presumably offset and length - confirm against that API).
const iree_vm_abi_rII_t* bindings;
} iree_hal_loader_dispatch_args_t;
// Dispatches one entry point of a local executable inline on the calling
// thread.
//
// |args| carries the executable ref, entry point, workgroup counts, push
// constants, and bindings (see iree_hal_loader_dispatch_args_t). Each binding
// buffer is mapped read-only to obtain a host pointer and length that are
// handed to the executable.
//
// Returns:
//   - the deref error if |args->executable| is not a HAL executable,
//   - IREE_STATUS_RESOURCE_EXHAUSTED if there are more than 32 bindings,
//   - any error from dereferencing or mapping a binding buffer,
//   - otherwise the status of the inline dispatch.
//
// |stack|, |module|, and |state| are unused here.
static iree_status_t iree_hal_loader_module_executable_dispatch(
    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
    iree_hal_loader_module_state_t* IREE_RESTRICT state,
    const iree_hal_loader_dispatch_args_t* IREE_RESTRICT args) {
  iree_hal_executable_t* executable = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_executable_check_deref(args->executable, &executable));

  // Convert the count to the host size type once and use that single value for
  // the limit check, the stack allocations, and the loop bound. This keeps all
  // three in agreement; in particular a count that would be negative in a
  // signed iree_vm_size_t becomes a huge unsigned value and is rejected here
  // rather than slipping past the limit and sizing the allocas.
  const iree_host_size_t binding_count = (iree_host_size_t)args->binding_count;
  if (binding_count > 32) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "too many bindings");
  }

  // Bounded by the check above, so stack allocation is safe.
  void** binding_ptrs = (void**)iree_alloca(binding_count * sizeof(void*));
  size_t* binding_lengths =
      (size_t*)iree_alloca(binding_count * sizeof(size_t));
  for (iree_host_size_t i = 0; i < binding_count; ++i) {
    iree_vm_buffer_t* buffer = NULL;
    IREE_RETURN_IF_ERROR(
        iree_vm_buffer_check_deref(args->bindings[i].r0, &buffer));
    // TODO(benvanik): this is a hack around not having the access permissions
    // currently modeled. This is only used for verification and early errors
    // and not intended to be a last-line defense against writes (you need an
    // MMU for that) so it's just subpar reporting.
    iree_const_byte_span_t span;
    IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(
        buffer, iree_hal_cast_host_size(args->bindings[i].i1),
        iree_hal_cast_host_size(args->bindings[i].i2), /*alignment=*/1, &span));
    // The mapping is read-only; const is cast away because the dispatch state
    // takes untyped binding pointers.
    binding_ptrs[i] = (void*)span.data;
    binding_lengths[i] = span.data_length;
  }

  // Workgroup size and concurrency are fixed at 1; only the counts come from
  // the caller.
  const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
      .workgroup_size_x = 1,
      .workgroup_size_y = 1,
      .workgroup_size_z = 1,
      .push_constant_count = args->push_constant_count,
      .workgroup_count_x = args->workgroup_x,
      .workgroup_count_y = args->workgroup_y,
      .workgroup_count_z = args->workgroup_z,
      .max_concurrency = 1,
      .binding_count = args->binding_count,
      .push_constants = args->push_constants,
      .binding_ptrs = binding_ptrs,
      .binding_lengths = binding_lengths,
  };

  // TODO(benvanik): environmental information.
  uint32_t processor_id = 0;
  iree_byte_span_t local_memory = iree_byte_span_empty();
  return iree_hal_local_executable_issue_dispatch_inline(
      (iree_hal_local_executable_t*)executable, args->entry_point,
      &dispatch_state, processor_id, local_memory);
}
// Shim that unpacks the variadic dispatch call from raw argument storage and
// forwards it to iree_hal_loader_module_executable_dispatch.
//
// Expected layout of |args_storage|:
//   iree_vm_abi_riiii_t   fixed arguments
//   iree_vm_size_t        push constant count
//   uint32_t[count]       push constants
//   iree_vm_size_t        binding count
//   iree_vm_abi_rII_t[n]  bindings
// Bytes beyond the last binding are tolerated.
//
// Every read is validated against the number of bytes remaining *before* it
// happens, so truncated storage or an oversized count never causes an
// out-of-bounds read. Counts are compared as iree_host_size_t; a count that
// would be negative in a signed iree_vm_size_t converts to a huge value and is
// rejected by the same check.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if the storage does not match the
// layout or if any result storage was provided (the call has no results);
// otherwise returns the dispatch status. |flags| and |target_fn| are unused.
static iree_status_t iree_vm_shim_dispatch_v(
    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
    void* IREE_RESTRICT module_state) {
  // TODO(benvanik): support multiple variadic segments in one call.
  // For now we inline what it would do in a very painful way.
  const iree_host_size_t params_size = sizeof(iree_vm_abi_riiii_t);
  const iree_host_size_t count_size = sizeof(iree_vm_size_t);

  iree_hal_loader_dispatch_args_t args;
  memset(&args, 0, sizeof(args));

  // The storage must hold at least the fixed arguments and both counts (the
  // zero-length case); nothing is read from it until that is established.
  bool args_ok = args_storage.data_length >= params_size + 2 * count_size;
  if (args_ok) {
    const uint8_t* cursor = args_storage.data;
    iree_host_size_t remaining = args_storage.data_length;

    args.params = *(const iree_vm_abi_riiii_t*)cursor;
    cursor += params_size;
    remaining -= params_size;

    args.push_constant_count = *(const iree_vm_size_t*)cursor;
    cursor += count_size;
    remaining -= count_size;

    // |remaining| still covers the binding count (guaranteed by the minimum
    // size check), so the subtraction below cannot wrap. The push constants
    // must fit in what is left after reserving room for that count.
    const iree_host_size_t push_constant_count =
        (iree_host_size_t)args.push_constant_count;
    args_ok = push_constant_count <=
              (remaining - count_size) / sizeof(args.push_constants[0]);
    if (args_ok) {
      const iree_host_size_t push_constants_size =
          push_constant_count * sizeof(args.push_constants[0]);
      args.push_constants = (const uint32_t*)cursor;
      cursor += push_constants_size;
      remaining -= push_constants_size;

      args.binding_count = *(const iree_vm_size_t*)cursor;
      cursor += count_size;
      remaining -= count_size;

      const iree_host_size_t binding_count =
          (iree_host_size_t)args.binding_count;
      args_ok = binding_count <= remaining / sizeof(args.bindings[0]);
      if (args_ok) {
        args.bindings = (const iree_vm_abi_rII_t*)cursor;
      }
    }
  }

  if (IREE_UNLIKELY(!args_ok || rets_storage.data_length > 0)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argument/result signature mismatch");
  }
  return iree_hal_loader_module_executable_dispatch(stack, module, module_state,
                                                    &args);
}
//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
// Function pointer table: one {shim, target} pair per EXPORT_FN entry in
// exports.inl, emitted in the order the entries appear there. The shim is
// selected by name as iree_vm_shim_<shim_arg_type>_<ret_types>.
// NOTE: this must match the ordering of the iree_hal_loader_module_exports_
// table.
static const iree_vm_native_function_ptr_t iree_hal_loader_module_funcs_[] = {
#define EXPORT_FN(name, target_fn, shim_arg_type, arg_types, ret_types) \
{ \
.shim = (iree_vm_native_function_shim_t) \
iree_vm_shim_##shim_arg_type##_##ret_types, \
.target = (iree_vm_native_function_target_t)(target_fn), \
},
#include "iree/modules/hal/loader/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
};
// Placeholder import table: the module has no imports (the descriptor reports
// import_count = 0), but the array is given one element so that there is a
// valid object to point at.
// NOTE: 0 length, but can't express that in C.
static const iree_vm_native_import_descriptor_t
iree_hal_loader_module_imports_[1];
// Export descriptor table: one entry per EXPORT_FN in exports.inl, in the same
// order as iree_hal_loader_module_funcs_. Each calling convention string is
// built as "0" + arg_types + "_" + ret_types; no attributes are attached.
static const iree_vm_native_export_descriptor_t
iree_hal_loader_module_exports_[] = {
#define EXPORT_FN(name, target_fn, shim_arg_type, arg_types, ret_types) \
{ \
.local_name = iree_string_view_literal(name), \
.calling_convention = \
iree_string_view_literal("0" #arg_types "_" #ret_types), \
.attr_count = 0, \
.attrs = NULL, \
},
#include "iree/modules/hal/loader/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
};
// Compile-time check that the function pointer table and the export table have
// the same number of entries (both are generated from exports.inl).
static_assert(IREE_ARRAYSIZE(iree_hal_loader_module_funcs_) ==
IREE_ARRAYSIZE(iree_hal_loader_module_exports_),
"function pointer table must be 1:1 with exports");
// Static descriptor for the "hal_loader" native module: no attributes, no
// dependencies, no imports, and the export/function tables defined above.
// Passed to iree_vm_native_module_initialize by iree_hal_loader_module_create.
static const iree_vm_native_module_descriptor_t
iree_hal_loader_module_descriptor_ = {
.name = iree_string_view_literal("hal_loader"),
.version = IREE_HAL_LOADER_MODULE_VERSION_LATEST,
.attr_count = 0,
.attrs = NULL,
.dependency_count = 0,
.dependencies = NULL,
.import_count = 0, // workaround for 0-length C struct
.imports = iree_hal_loader_module_imports_,
.export_count = IREE_ARRAYSIZE(iree_hal_loader_module_exports_),
.exports = iree_hal_loader_module_exports_,
.function_count = IREE_ARRAYSIZE(iree_hal_loader_module_funcs_),
.functions = iree_hal_loader_module_funcs_,
};
// Creates the HAL loader module with a fixed set of executable loaders.
//
// Parameters:
//   instance:       VM instance passed to the base native module.
//   flags:          Module flags; stored on the module and copied into every
//                   module state.
//   loader_count:   Number of entries in |loaders|.
//   loaders:        Executable loaders to register. Each one is retained by the
//                   module (and released when the module is destroyed); the
//                   caller keeps its own references. Loaders are queried and
//                   tried in the order given. May be NULL only when
//                   |loader_count| is 0.
//   host_allocator: Allocator used for the module storage.
//   out_module:     Receives the created module on success; set to NULL on
//                   entry so it is NULL on every failure path.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |loader_count| is so large that the
// module storage size would overflow, or any error from allocation or from
// iree_vm_native_module_initialize.
IREE_API_EXPORT iree_status_t iree_hal_loader_module_create(
    iree_vm_instance_t* instance, iree_hal_loader_module_flags_t flags,
    iree_host_size_t loader_count, iree_hal_executable_loader_t** loaders,
    iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(instance);
  IREE_ASSERT_ARGUMENT(!loader_count || loaders);
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = NULL;

  // Setup the interface with the functions we implement ourselves. Any function
  // we omit will be handled by the base native module.
  static const iree_vm_module_t interface = {
      .destroy = iree_hal_loader_module_destroy,
      .alloc_state = iree_hal_loader_module_alloc_state,
      .free_state = iree_hal_loader_module_free_state,
      .notify = iree_hal_loader_module_notify,
  };

  // The allocation holds the base native module, our module struct, and the
  // trailing loader pointer array. Reject counts whose array size would wrap
  // the size computation and lead to an undersized allocation.
  const iree_host_size_t header_size =
      iree_vm_native_module_size() + sizeof(iree_hal_loader_module_t);
  const iree_host_size_t max_loader_count =
      (~(iree_host_size_t)0 - header_size) /
      sizeof(iree_hal_executable_loader_t*);
  if (loader_count > max_loader_count) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "loader count %" PRIhsz
                            " exceeds the maximum module storage size",
                            loader_count);
  }
  const iree_host_size_t total_size =
      header_size + loader_count * sizeof(iree_hal_executable_loader_t*);

  // Allocate shared module state.
  iree_vm_module_t* base_module = NULL;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
  memset(base_module, 0, total_size);
  iree_status_t status = iree_vm_native_module_initialize(
      &interface, &iree_hal_loader_module_descriptor_, instance, host_allocator,
      base_module);
  if (!iree_status_is_ok(status)) {
    // No loaders have been retained yet, so freeing the storage is sufficient.
    iree_allocator_free(host_allocator, base_module);
    return status;
  }

  iree_hal_loader_module_t* module = IREE_HAL_LOADER_MODULE_CAST(base_module);
  module->host_allocator = host_allocator;
  module->flags = flags;
  module->loader_count = loader_count;
  for (iree_host_size_t i = 0; i < loader_count; ++i) {
    module->loaders[i] = loaders[i];
    iree_hal_executable_loader_retain(loaders[i]);
  }

  *out_module = base_module;
  return iree_ok_status();
}