// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/modules/hal/module.h"
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "iree/modules/hal/utils/buffer_diagnostics.h"
//===----------------------------------------------------------------------===//
// Limits imposed by the module (and not the HAL)
//===----------------------------------------------------------------------===//
// Limit the number of bindings we pass down through the HAL. This can be tuned
// in the future but right now guards the stack from blowing up during calls.
#define IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT ((iree_host_size_t)32)
// Limit the number of bindings in a binding table that we allocate on the stack
// while marshaling from the VM. Counts over this amount will result in heap
// allocations to avoid blowing the native stack. In most programs we expect
// at most a dozen buffers but programs with individually stored parameters may
// need hundreds or even thousands. Yuck.
#define IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT \
((iree_host_size_t)64)
//===----------------------------------------------------------------------===//
// Module type definitions
//===----------------------------------------------------------------------===//
#define IREE_HAL_MODULE_VERSION_0_3 0x00000003u
#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_3
typedef struct iree_hal_module_t {
iree_allocator_t host_allocator;
iree_hal_module_flags_t flags;
iree_host_size_t device_count;
iree_hal_device_t* devices[];
} iree_hal_module_t;
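// Resolves the HAL module structure that trails the VM native module storage.
// The module is allocated as a single block: the native module header comes
// first followed by the iree_hal_module_t fields (and the device list).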
#define IREE_HAL_MODULE_CAST(module) \
(iree_hal_module_t*)((uint8_t*)(module) + iree_vm_native_module_size())
static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
for (iree_host_size_t i = 0; i < module->device_count; ++i) {
iree_hal_device_release(module->devices[i]);
}
}
typedef struct iree_hal_module_state_t {
iree_allocator_t host_allocator;
// Flags controlling HAL module behavior passed in from the hosting
// application. All instantiations of a module share the same flags.
iree_hal_module_flags_t flags;
// Total number of devices available to the module.
iree_host_size_t device_count;
// Devices referencing the storage in the parent module.
// Unretained as the parent module must remain live longer than any module
// state allocated from it and we can rely on it to keep the devices retained.
iree_hal_device_t** devices;
// TODO(benvanik): add iree_loop_t to module constructor.
// Status of the nested loop we run for executable creation today. We should
// instead be taking a loop upon creation and scheduling work against that.
iree_status_t loop_status;
// Shared executable cache for each device used to cache all executables
// created in the context. We could have multiple to allow for modules to
// create distinct sets of executables like ones for training vs inference in
// the same model or allow these to be injected so that multiple loaded
// contexts share the caches.
iree_hal_executable_cache_t* executable_caches[];
} iree_hal_module_state_t;
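// Allocates per-context module state. One executable cache is created per
// registered device and stored in the trailing flexible array, parallel to
// the device list so lookups can share the same ordinal.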
static iree_status_t IREE_API_PTR
iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator,
iree_vm_module_state_t** out_module_state) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(self);
iree_hal_module_state_t* state = NULL;
iree_host_size_t total_size =
sizeof(*state) +
module->device_count * sizeof(state->executable_caches[0]);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, total_size, (void**)&state));
memset(state, 0, total_size);
state->host_allocator = host_allocator;
state->flags = module->flags;
state->device_count = module->device_count;
state->devices = module->devices;
state->loop_status = iree_ok_status();
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < state->device_count; ++i) {
status = iree_hal_executable_cache_create(
state->devices[i], iree_string_view_empty(),
iree_loop_inline(&state->loop_status), &state->executable_caches[i]);
if (!iree_status_is_ok(status)) break;
}
if (iree_status_is_ok(status)) {
*out_module_state = (iree_vm_module_state_t*)state;
} else {
for (iree_host_size_t i = 0; i < state->device_count; ++i) {
iree_hal_executable_cache_release(state->executable_caches[i]);
}
iree_allocator_free(host_allocator, state);
}
IREE_TRACE_ZONE_END(z0);
return status;
}
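// Frees per-context module state, releasing the per-device executable caches
// created in iree_hal_module_alloc_state.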
static void IREE_API_PTR
iree_hal_module_free_state(void* self, iree_vm_module_state_t* module_state) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
for (iree_host_size_t i = 0; i < state->device_count; ++i) {
iree_hal_executable_cache_release(state->executable_caches[i]);
}
iree_status_ignore(state->loop_status);
iree_allocator_free(state->host_allocator, state);
IREE_TRACE_ZONE_END(z0);
}
// Returns an unretained reference to the executable cache for the given device.
// If the same device is registered multiple times the first cache is returned.
static iree_status_t iree_hal_module_state_lookup_executable_cache(
iree_hal_module_state_t* state, iree_hal_device_t* device,
iree_hal_executable_cache_t** out_executable_cache) {
IREE_ASSERT_ARGUMENT(state);
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_executable_cache);
*out_executable_cache = NULL;
for (iree_host_size_t i = 0; i < state->device_count; ++i) {
if (state->devices[i] == device) {
*out_executable_cache = state->executable_caches[i];
return iree_ok_status();
}
}
return iree_make_status(
IREE_STATUS_NOT_FOUND,
"no executable cache for the given device found; possibly a device not "
"registered with the HAL module");
}
static iree_status_t IREE_API_PTR iree_hal_module_notify(
void* self, iree_vm_module_state_t* module_state, iree_vm_signal_t signal) {
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
switch (signal) {
case IREE_VM_SIGNAL_SUSPEND:
case IREE_VM_SIGNAL_LOW_MEMORY: {
for (iree_host_size_t i = 0; i < state->device_count; ++i) {
IREE_RETURN_IF_ERROR(iree_hal_device_trim(state->devices[i]));
}
return iree_ok_status();
}
default: {
// Ignored today but if we started managing device power down we could
// use this to wake them back up again.
return iree_ok_status();
}
}
}
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Casts a VM value to a C host size.
static iree_host_size_t iree_hal_cast_host_size(int64_t value) {
// TODO(benvanik): make this return status and check for overflow if host
// size is 32-bits.
return (iree_host_size_t)value;
}
// Casts a VM value to a HAL device size.
static iree_device_size_t iree_hal_cast_device_size(int64_t value) {
// TODO(benvanik): make this return status and check for overflow if device
// size is 32-bits.
return (iree_device_size_t)value;
}
//===----------------------------------------------------------------------===//
// Experimental APIs
//===----------------------------------------------------------------------===//
// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
// using these APIs are not forward compatible.
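// iree_io_file_handle_release_callback_t fn: releases the VM buffer backing a
// wrapped file handle once the handle is destroyed.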
static void iree_hal_module_file_buffer_release(
void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data;
iree_vm_buffer_release(backing_buffer);
}
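// NOTE: the IREE_VM_ABI_EXPORT signature strings below use the VM ABI shim
// shorthand: 'r' = ref, 'i' = i32, 'I' = i64, 'v' = none, and 'C'...'D'
// encloses a variadic span. For example the "rIirIIi, r" below reads as
// (ref, i64, i32, ref, i64, i64, i32) -> (ref).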
IREE_VM_ABI_EXPORT(iree_hal_module_ex_file_from_memory, //
iree_hal_module_state_t, //
rIirIIi, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_memory_access_t access = (iree_hal_memory_access_t)args->i2;
iree_vm_buffer_t* buffer = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r3, &buffer));
iree_host_size_t offset = iree_hal_cast_host_size(args->i4);
iree_host_size_t length = iree_hal_cast_host_size(args->i5);
uint32_t flags = (uint32_t)args->i6;
// Only allow read-only access right now while experimental.
// The contents here are almost always from mapped file memory today.
if (iree_any_bit_set(access, ~IREE_HAL_MEMORY_ACCESS_READ)) {
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
"only read-only memory can be accessed via a file handle (today)");
}
// Verify the provided range and get the host pointer.
iree_const_byte_span_t span = iree_const_byte_span_empty();
IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(buffer, offset, length, 1, &span));
// Retain the buffer until the file is destroyed.
iree_io_file_handle_release_callback_t release_callback = {
.fn = iree_hal_module_file_buffer_release,
.user_data = buffer,
};
iree_vm_buffer_retain(buffer);
// Wrap the memory in a file handle.
iree_io_file_handle_t* handle = NULL;
iree_status_t status = iree_io_file_handle_wrap_host_allocation(
IREE_IO_FILE_ACCESS_READ,
iree_make_byte_span((void*)span.data, span.data_length), release_callback,
iree_hal_device_host_allocator(device), &handle);
if (!iree_status_is_ok(status)) {
iree_vm_buffer_release(buffer);
}
// Attempt to import the memory as a file.
// Memory files are always supported (even if via emulation) so this should
// always succeed.
iree_hal_file_t* file = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_file_import(device, queue_affinity, access, handle, flags,
&file);
}
iree_io_file_handle_release(handle);
rets->r0 = iree_hal_file_move_ref(file);
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_allocator_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate, //
iree_hal_module_state_t, //
rIiiI, r) {
iree_hal_allocator_t* allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i2;
iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i3;
iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i4);
const iree_hal_buffer_params_t params = {
.type = memory_types,
.usage = buffer_usage,
.queue_affinity = queue_affinity,
};
iree_hal_buffer_t* buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
allocator, params, allocation_size, &buffer),
"failed to allocate buffer of length %" PRIdsz,
allocation_size);
rets->r0 = iree_hal_buffer_move_ref(buffer);
return iree_ok_status();
}
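// iree_hal_buffer_release_callback_t fn: releases the VM buffer backing an
// imported HAL buffer once the HAL buffer is destroyed.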
static void iree_hal_module_imported_buffer_release(void* user_data,
iree_hal_buffer_t* buffer) {
iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data;
iree_vm_buffer_release(backing_buffer);
}
IREE_VM_ABI_EXPORT(iree_hal_module_allocator_import, //
iree_hal_module_state_t, //
riIiirII, r) {
iree_hal_allocator_t* allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
bool is_try = args->i1 != 0;
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i2;
iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i3;
iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i4;
iree_vm_buffer_t* source = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &source));
iree_device_size_t offset = iree_hal_cast_device_size(args->i6);
iree_device_size_t length = iree_hal_cast_device_size(args->i7);
iree_host_size_t buffer_length = source->data.data_length;
if (length == -1) {
length = buffer_length;
}
if (length < 0 || offset < 0 || offset > buffer_length ||
offset + length > buffer_length) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"byte range out of bounds (requested %" PRIdsz
"-%" PRIdsz " of available %" PRIhsz ")",
offset, (offset + length - 1), buffer_length);
}
iree_hal_memory_access_t allowed_access = IREE_HAL_MEMORY_ACCESS_READ;
if (!iree_all_bits_set(source->access, IREE_VM_BUFFER_ACCESS_MUTABLE)) {
// Source buffer is read-only; require that the access request matches.
if (!iree_all_bits_set(buffer_usage,
IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE)) {
return iree_make_status(IREE_STATUS_PERMISSION_DENIED,
"source buffer is immutable and can only be "
"imported for constant usage");
}
// NOTE: if we wanted to lock things down for when there's no MMU to ensure
// that the loaded program doesn't touch the memory then we could just fail
// the request - the program will then perform an alloc+copy and can do
// whatever it wants with the memory.
} else {
// Source buffer is mutable; allow in-place writes.
if (!iree_all_bits_set(buffer_usage,
IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE)) {
allowed_access |= IREE_HAL_MEMORY_ACCESS_WRITE;
}
}
// Try mapping - note that this may fail if the target device cannot map the
// memory into the given type (for example, mapping a host buffer into
// device-local memory is only going to work on unified memory systems).
const iree_hal_buffer_params_t params = {
.type = memory_types,
.usage = buffer_usage,
.access = allowed_access,
.queue_affinity = queue_affinity,
};
iree_hal_external_buffer_t external_buffer = {
.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION,
.flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE,
.size = length,
.handle.host_allocation.ptr = source->data.data + offset,
};
iree_hal_buffer_release_callback_t release_callback = {
.fn = iree_hal_module_imported_buffer_release,
.user_data = source,
};
iree_hal_buffer_t* buffer = NULL;
iree_status_t status = iree_hal_allocator_import_buffer(
allocator, params, &external_buffer, release_callback, &buffer);
if (iree_status_is_ok(status)) {
// Import succeeded - retain the source buffer that'll be released by
// iree_hal_module_imported_buffer_release when the mapping is no longer used.
iree_vm_buffer_retain(source);
rets->r0 = iree_hal_buffer_move_ref(buffer);
return iree_ok_status();
}
// Failed to import - if this was a try then don't fail and just rely on the
// result being nullptr to indicate to the caller that things failed.
memset(&rets->r0, 0, sizeof(rets->r0));
if (is_try) {
IREE_TRACE_MESSAGE(WARNING, "try import failed");
iree_status_ignore(status);
return iree_ok_status();
}
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_assert, //
iree_hal_module_state_t, //
rrrIii, v) {
IREE_RETURN_IF_ERROR(iree_hal_modules_buffer_assert(
args->r0, args->r1, iree_hal_cast_device_size(args->i3),
(iree_hal_memory_type_t)args->i4, (iree_hal_buffer_usage_t)args->i5));
// TODO(benvanik): assert that the buffer view is accessible from the
// target device. This needs some iree_hal_allocator_* methods for checking
// whether the external buffer can be used. To start we just compare if the
// allocators are identical.
iree_hal_allocator_t* allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r2, &allocator));
(void)allocator;
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_subspan, //
iree_hal_module_state_t, //
rII, r) {
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1);
iree_device_size_t length = iree_hal_cast_device_size(args->i2);
iree_hal_buffer_t* subspan_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_subspan(source_buffer, source_offset, length,
&subspan_buffer),
"invalid subspan of an existing buffer (source_offset=%" PRIdsz
", length=%" PRIdsz ")",
source_offset, length);
rets->r0 = iree_hal_buffer_move_ref(subspan_buffer);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_length, //
iree_hal_module_state_t, //
r, I) {
iree_hal_buffer_t* buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer));
rets->i0 = (int64_t)iree_hal_buffer_byte_length(buffer);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_load, //
iree_hal_module_state_t, //
rIi, i) {
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1);
iree_vm_size_t length = (iree_vm_size_t)args->i2;
uint32_t target_buffer = 0;
if (length > sizeof(target_buffer)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"load length byte count %d exceeds max", length);
}
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_read(source_buffer, source_offset,
&target_buffer, length));
rets->i0 = target_buffer;
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_store, //
iree_hal_module_state_t, //
irIi, v) {
int32_t value = args->i0;
iree_hal_buffer_t* target_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i2);
iree_vm_size_t length = (iree_vm_size_t)args->i3;
if (length > sizeof(value)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"store length byte count %d exceeds max", length);
} else if (target_offset + length >
iree_hal_buffer_byte_length(target_buffer)) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"store out of bounds (target_offset=%" PRIdsz
", length=%d into max %" PRIdsz ")",
target_offset, length,
iree_hal_buffer_byte_length(target_buffer));
}
return iree_hal_buffer_map_write(target_buffer, target_offset, &value,
length);
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_create, //
iree_hal_module_state_t, //
rIIiiCID, r) {
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1);
iree_device_size_t source_length = iree_hal_cast_device_size(args->i2);
iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i3;
iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i4;
iree_host_size_t shape_rank = 0;
iree_hal_dim_t* shape_dims = NULL;
// TODO(benvanik): avoid the cast/alloca if not required.
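// Marshals the variadic i64 shape dimensions into an iree_hal_dim_t array on
// the native stack; the 128 bounds the stack allocation size.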
IREE_VM_ABI_VLA_STACK_CAST(args, a5_count, a5, iree_hal_dim_t, 128,
&shape_rank, &shape_dims);
iree_hal_buffer_t* subspan_buffer = NULL;
if (source_offset != 0 ||
source_length != iree_hal_buffer_byte_length(source_buffer)) {
IREE_RETURN_IF_ERROR(
iree_hal_buffer_subspan(source_buffer, source_offset, source_length,
&subspan_buffer),
"invalid subspan of an existing buffer (source_offset=%" PRIdsz
", length=%" PRIdsz ")",
source_offset, source_length);
}
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
subspan_buffer ? subspan_buffer : source_buffer, shape_rank, shape_dims,
element_type, encoding_type, state->host_allocator, &buffer_view));
iree_hal_buffer_release(subspan_buffer);
rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_assert, //
iree_hal_module_state_t, //
rriiCID, v) {
iree_host_size_t expected_shape_rank = 0;
iree_hal_dim_t* expected_shape_dims = NULL;
// TODO(benvanik): avoid the cast/alloca if not required.
IREE_VM_ABI_VLA_STACK_CAST(args, a4_count, a4, iree_hal_dim_t, 128,
&expected_shape_rank, &expected_shape_dims);
return iree_hal_modules_buffer_view_assert(
args->r0, args->r1, (iree_hal_element_type_t)args->i2,
(iree_hal_encoding_type_t)args->i3, expected_shape_rank,
expected_shape_dims);
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_buffer, //
iree_hal_module_state_t, //
r, r) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
rets->r0 =
iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(buffer_view));
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_element_type, //
iree_hal_module_state_t, //
r, i) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
rets->i0 = (uint32_t)iree_hal_buffer_view_element_type(buffer_view);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_encoding_type, //
iree_hal_module_state_t, //
r, i) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
rets->i0 = (uint32_t)iree_hal_buffer_view_encoding_type(buffer_view);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_rank, //
iree_hal_module_state_t, //
r, i) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_rank(buffer_view);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_dim, //
iree_hal_module_state_t, //
ri, I) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
iree_vm_size_t index = (iree_vm_size_t)args->i1;
rets->i0 = (int64_t)iree_hal_buffer_view_shape_dim(buffer_view, index);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_trace, //
iree_hal_module_state_t, //
rCrD, v) {
return iree_hal_modules_buffer_view_trace(args->r0, args->a1_count, args->a1,
state->host_allocator);
}
//===----------------------------------------------------------------------===//
// iree_hal_channel_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_channel_create, //
iree_hal_module_state_t, //
rIirrii, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
uint32_t flags = args->i2;
iree_vm_buffer_t* id = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref_or_null(args->r3, &id));
iree_vm_buffer_t* group = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref_or_null(args->r4, &group));
iree_string_view_t group_str = iree_vm_buffer_as_string(group);
int32_t rank = args->i5;
int32_t count = args->i6;
iree_hal_channel_params_t params = {
.flags = flags,
.id = iree_vm_buffer_const_contents(id), // may be null
.group = group_str, // may be null
.rank = rank,
.count = count,
};
iree_hal_channel_t* channel = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_channel_create(device, queue_affinity, params, &channel));
rets->r0 = iree_hal_channel_move_ref(channel);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_channel_split, //
iree_hal_module_state_t, //
riii, r) {
iree_hal_channel_t* base_channel = NULL;
IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r0, &base_channel));
int32_t color = args->i1;
int32_t key = args->i2;
int32_t flags = args->i3;
iree_hal_channel_t* split_channel = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_channel_split(base_channel, color, key, flags, &split_channel));
rets->r0 = iree_hal_channel_move_ref(split_channel);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_channel_rank_and_count, //
iree_hal_module_state_t, //
r, ii) {
iree_hal_channel_t* channel = NULL;
IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r0, &channel));
int32_t rank = 0;
int32_t count = 0;
iree_hal_channel_query_rank_and_count(channel, &rank, &count);
rets->i0 = rank;
rets->i1 = count;
return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_create, //
iree_hal_module_state_t, //
riiIi, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_command_buffer_mode_t modes =
(iree_hal_command_buffer_mode_t)args->i1;
iree_hal_command_category_t command_categories =
(iree_hal_command_category_t)args->i2;
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i3;
iree_host_size_t binding_capacity = (iree_host_size_t)args->i4;
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
device, modes, command_categories, queue_affinity, binding_capacity,
&command_buffer));
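// Begin recording immediately so the caller can record commands; recording is
// ended by command_buffer.finalize.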
iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
if (iree_status_is_ok(status)) {
rets->r0 = iree_hal_command_buffer_move_ref(command_buffer);
} else {
iree_hal_command_buffer_release(command_buffer);
}
return status;
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_finalize, //
iree_hal_module_state_t, //
r, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
return iree_hal_command_buffer_end(command_buffer);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_begin_debug_group, //
iree_hal_module_state_t, //
rr, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_vm_buffer_t* label = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &label));
iree_string_view_t label_str = iree_vm_buffer_as_string(label);
// TODO(benvanik): query from VM.
iree_hal_label_location_t location = {
.file = iree_string_view_empty(),
.line = 0,
};
iree_hal_command_buffer_begin_debug_group(
command_buffer, label_str, iree_hal_label_color_unspecified(), &location);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_end_debug_group, //
iree_hal_module_state_t, //
r, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_command_buffer_end_debug_group(command_buffer);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_execution_barrier, //
iree_hal_module_state_t, //
riii, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_execution_stage_t source_stage_mask =
(iree_hal_execution_stage_t)args->i1;
iree_hal_execution_stage_t target_stage_mask =
(iree_hal_execution_stage_t)args->i2;
iree_hal_execution_barrier_flags_t flags =
(iree_hal_execution_barrier_flags_t)args->i3;
// TODO(benvanik): decode barriers.
iree_hal_memory_barrier_t global_barrier;
global_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE;
global_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ;
return iree_hal_command_buffer_execution_barrier(
command_buffer, source_stage_mask, target_stage_mask, flags, 1,
&global_barrier, 0, NULL);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, //
iree_hal_module_state_t, //
rrIIiii, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i2);
iree_device_size_t length = iree_hal_cast_device_size(args->i3);
uint32_t target_buffer_slot = (uint32_t)args->i4;
iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
target_buffer_slot, target_offset, length);
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r1, &target_ref.buffer));
uint32_t pattern = (uint32_t)args->i5;
uint32_t pattern_length = (uint32_t)args->i6;
return iree_hal_command_buffer_fill_buffer(command_buffer, target_ref,
&pattern, pattern_length);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, //
iree_hal_module_state_t, //
rrIrIIi, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_vm_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer));
iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2);
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4);
iree_device_size_t length = iree_hal_cast_device_size(args->i5);
uint32_t target_buffer_slot = (uint32_t)args->i6;
iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
target_buffer_slot, target_offset, length);
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r3, &target_ref.buffer));
iree_const_byte_span_t source_span = iree_const_byte_span_empty();
IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(
source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span));
return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data,
/*source_offset=*/0, target_ref);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, //
iree_hal_module_state_t, //
riirIrII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
uint32_t source_buffer_slot = (uint32_t)args->i1;
uint32_t target_buffer_slot = (uint32_t)args->i2;
iree_device_size_t source_offset = iree_hal_cast_device_size(args->i4);
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i6);
iree_device_size_t length = iree_hal_cast_device_size(args->i7);
iree_hal_buffer_ref_t source_ref = iree_hal_make_indirect_buffer_ref(
source_buffer_slot, source_offset, length);
iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
target_buffer_slot, target_offset, length);
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r3, &source_ref.buffer));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r5, &target_ref.buffer));
return iree_hal_command_buffer_copy_buffer(command_buffer, source_ref,
target_ref);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_collective, //
iree_hal_module_state_t, //
rriiiirrIIIII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_channel_t* channel = NULL;
IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r1, &channel));
iree_hal_collective_op_t op = {.packed = args->i2};
uint32_t param = args->i3;
uint32_t send_buffer_slot = (uint32_t)args->i4;
uint32_t recv_buffer_slot = (uint32_t)args->i5;
iree_hal_buffer_ref_t send_ref = iree_hal_make_indirect_buffer_ref(
send_buffer_slot, iree_hal_cast_device_size(args->i8),
iree_hal_cast_device_size(args->i9));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r6, &send_ref.buffer));
iree_hal_buffer_ref_t recv_ref = iree_hal_make_indirect_buffer_ref(
recv_buffer_slot, iree_hal_cast_device_size(args->i10),
iree_hal_cast_device_size(args->i11));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r7, &recv_ref.buffer));
iree_device_size_t element_count = iree_hal_cast_device_size(args->i12);
return iree_hal_command_buffer_collective(command_buffer, channel, op, param,
send_ref, recv_ref, element_count);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_constants, //
iree_hal_module_state_t, //
rriCiD, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_pipeline_layout_t* pipeline_layout = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_pipeline_layout_check_deref(args->r1, &pipeline_layout));
iree_vm_size_t offset = (iree_vm_size_t)args->i2;
iree_host_size_t value_count = args->a3_count;
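// NOTE: the variadic i32 elements are stored contiguously and can be
// reinterpreted directly as an array of uint32_t push constant values.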
const uint32_t* values = (const uint32_t*)&args->a3[0].i0;
return iree_hal_command_buffer_push_constants(
command_buffer, pipeline_layout, offset * sizeof(uint32_t), values,
value_count * sizeof(uint32_t));
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_descriptor_set, //
iree_hal_module_state_t, //
rriCiirIID, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_pipeline_layout_t* pipeline_layout = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_pipeline_layout_check_deref(args->r1, &pipeline_layout));
iree_vm_size_t set = args->i2;
iree_host_size_t binding_count = args->a3_count;
if (IREE_UNLIKELY(binding_count >
IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
return iree_make_status(
IREE_STATUS_OUT_OF_RANGE, "binding count %" PRIhsz " > %" PRIhsz,
binding_count, IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
}
iree_hal_buffer_ref_t* bindings = (iree_hal_buffer_ref_t*)iree_alloca(
binding_count * sizeof(iree_hal_buffer_ref_t));
for (iree_host_size_t i = 0; i < binding_count; ++i) {
bindings[i].ordinal = (uint32_t)args->a3[i].i0;
bindings[i].buffer_slot = (uint32_t)args->a3[i].i1;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
args->a3[i].r2, &bindings[i].buffer));
bindings[i].offset = iree_hal_cast_device_size(args->a3[i].i3);
bindings[i].length = iree_hal_cast_device_size(args->a3[i].i4);
}
return iree_hal_command_buffer_push_descriptor_set(
command_buffer, pipeline_layout, set, binding_count, bindings);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, //
iree_hal_module_state_t, //
rriiiiI, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_executable_t* executable = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
uint32_t entry_point = (uint32_t)args->i2;
uint32_t workgroup_x = (uint32_t)args->i3;
uint32_t workgroup_y = (uint32_t)args->i4;
uint32_t workgroup_z = (uint32_t)args->i5;
iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
return iree_hal_command_buffer_dispatch(command_buffer, executable,
entry_point, workgroup_x, workgroup_y,
workgroup_z, flags);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, //
iree_hal_module_state_t, //
rriirII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_executable_t* executable = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
uint32_t entry_point = (uint32_t)args->i2;
uint32_t workgroups_buffer_slot = (uint32_t)args->i3;
iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i5);
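// The workgroup count is read at execution time as three consecutive uint32_t
// values (x, y, z) at the given offset in the workgroups buffer.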
iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref(
workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer));
iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
return iree_hal_command_buffer_dispatch_indirect(
command_buffer, executable, entry_point, workgroups_ref, flags);
}
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_descriptor_set_layout_create, //
iree_hal_module_state_t, //
riCiiiD, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_descriptor_set_layout_flags_t flags =
(iree_hal_descriptor_set_layout_flags_t)args->i1;
iree_host_size_t binding_count = args->a2_count;
if (IREE_UNLIKELY(binding_count >
IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
return iree_make_status(
IREE_STATUS_OUT_OF_RANGE, "binding count %" PRIhsz " > %" PRIhsz,
binding_count, IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
}
iree_hal_descriptor_set_layout_binding_t* bindings =
(iree_hal_descriptor_set_layout_binding_t*)iree_alloca(
binding_count * sizeof(iree_hal_descriptor_set_layout_binding_t));
for (iree_host_size_t i = 0; i < binding_count; ++i) {
bindings[i].binding = (uint32_t)args->a2[i].i0;
bindings[i].type = (iree_hal_descriptor_type_t)args->a2[i].i1;
bindings[i].flags = (iree_hal_descriptor_flags_t)args->a2[i].i2;
}
iree_hal_descriptor_set_layout_t* descriptor_set_layout = NULL;
IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create(
device, flags, binding_count, bindings, &descriptor_set_layout));
rets->r0 = iree_hal_descriptor_set_layout_move_ref(descriptor_set_layout);
return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// iree_hal_device_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_device_allocator, //
iree_hal_module_state_t, //
r, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
rets->r0 = iree_hal_allocator_retain_ref(iree_hal_device_allocator(device));
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_query_i64, //
iree_hal_module_state_t, //
rrr, iI) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_vm_buffer_t* category = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &category));
iree_string_view_t category_str = iree_vm_buffer_as_string(category);
iree_vm_buffer_t* key = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r2, &key));
iree_string_view_t key_str = iree_vm_buffer_as_string(key);
int64_t value = 0;
iree_status_t query_status =
iree_hal_device_query_i64(device, category_str, key_str, &value);
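// Queries are allowed to fail: i0 receives 1 if the key was found and i1 the
// value (0 otherwise) so programs can probe optional keys without erroring.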
rets->i0 = iree_status_consume_code(query_status) == IREE_STATUS_OK ? 1 : 0;
rets->i1 = value;
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_alloca, //
iree_hal_module_state_t, //
rIrriiiI, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
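// NOTE: the wait/signal fences may be null refs; a null fence resolves to an
// empty semaphore list (no waits/signals).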
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_hal_allocator_pool_t pool = (iree_hal_allocator_pool_t)args->i4;
iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i5;
iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i6;
iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i7);
const iree_hal_buffer_params_t params = {
.type = memory_types,
.usage = buffer_usage,
};
iree_hal_buffer_t* buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), pool, params,
allocation_size, &buffer));
rets->r0 = iree_hal_buffer_move_ref(buffer);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_dealloca, //
iree_hal_module_state_t, //
rIrrr, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_hal_buffer_t* buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &buffer));
return iree_hal_device_queue_dealloca(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), buffer);
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_read, //
iree_hal_module_state_t, //
rIrrrIrIIi, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_hal_file_t* source_file = NULL;
IREE_RETURN_IF_ERROR(iree_hal_file_check_deref(args->r4, &source_file));
uint64_t source_offset = (uint64_t)args->i5;
iree_hal_buffer_t* target_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r6, &target_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i7);
iree_device_size_t length = iree_hal_cast_device_size(args->i8);
uint32_t flags = (uint32_t)args->i9;
return iree_hal_device_queue_read(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), source_file, source_offset,
target_buffer, target_offset, length, flags);
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_write, //
iree_hal_module_state_t, //
rIrrrIrIIi, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &source_buffer));
iree_device_size_t source_offset = iree_hal_cast_device_size(args->i5);
iree_hal_file_t* target_file = NULL;
IREE_RETURN_IF_ERROR(iree_hal_file_check_deref(args->r6, &target_file));
uint64_t target_offset = (uint64_t)args->i7;
iree_device_size_t length = iree_hal_cast_device_size(args->i8);
uint32_t flags = (uint32_t)args->i9;
return iree_hal_device_queue_write(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), source_buffer, source_offset,
target_file, target_offset, length, flags);
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_execute, //
iree_hal_module_state_t, //
rIrrCrD, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_host_size_t command_buffer_count = 0;
iree_hal_command_buffer_t** command_buffers = NULL;
IREE_VM_ABI_VLA_STACK_DEREF(args, a4_count, a4, iree_hal_command_buffer, 32,
&command_buffer_count, &command_buffers);
return iree_hal_device_queue_execute(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), command_buffer_count,
command_buffers, /*binding_tables=*/NULL);
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_execute_indirect, //
iree_hal_module_state_t, //
rIrrrCrIID, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r4, &command_buffer));
// Allocate temporary storage for the binding table in order to marshal VM
// refs and 64-bit offsets/lengths into the types required by the HAL C API.
iree_host_size_t binding_count = args->a5_count;
iree_hal_buffer_binding_t* bindings = NULL;
if (binding_count > IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT) {
// Heap allocate when using a large number of bindings to avoid blowing the
// native stack. Note that we have to free it before returning from the
// function.
IREE_RETURN_IF_ERROR(iree_allocator_malloc_uninitialized(
state->host_allocator, binding_count * sizeof(*bindings),
(void**)&bindings));
} else {
// Stack allocate when using a small number of bindings (common).
bindings = (iree_hal_buffer_binding_t*)iree_alloca(binding_count *
sizeof(*bindings));
}
// Ensure all buffers are valid (may be NULL) and build the binding table.
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < binding_count; ++i) {
status = iree_hal_buffer_check_deref_or_null(args->a5[i].r0,
&bindings[i].buffer);
if (!iree_status_is_ok(status)) break;
bindings[i].offset = iree_hal_cast_device_size(args->a5[i].i1);
bindings[i].length = iree_hal_cast_device_size(args->a5[i].i2);
}
// Schedule execution with the binding table - it will be copied by the device
// and need not live longer than the call.
if (iree_status_is_ok(status)) {
iree_hal_buffer_binding_table_t binding_table = {
.count = binding_count,
.bindings = bindings,
};
status = iree_hal_device_queue_execute(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer,
&binding_table);
}
// If we had to heap-allocate the binding table storage it must be freed
// before returning to the VM.
if (binding_count > IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT) {
iree_allocator_free(state->host_allocator, bindings);
}
return status;
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_flush, //
iree_hal_module_state_t, //
rI, v) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_queue_affinity_t queue_affinity =
(iree_hal_queue_affinity_t)args->i1;
return iree_hal_device_queue_flush(device, queue_affinity);
}
//===----------------------------------------------------------------------===//
// iree_hal_device_t management
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_devices_count, //
iree_hal_module_state_t, //
v, i) {
rets->i0 = (int32_t)state->device_count;
return iree_ok_status();
}
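// Returns a retained reference to the device with ordinal i0 or a null ref if
// the ordinal is out of range.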
IREE_VM_ABI_EXPORT(iree_hal_module_devices_get, //
iree_hal_module_state_t, //
i, r) {
if (args->i0 < state->device_count) {
rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]);
} else {
rets->r0 = iree_vm_ref_null();
}
return iree_ok_status();
}
//===--------------------------------------------------------------------===//
// iree_hal_executable_t
//===--------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_executable_create, //
iree_hal_module_state_t, //
rrrrCrD, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_vm_buffer_t* executable_format = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_buffer_check_deref(args->r1, &executable_format));
iree_string_view_t executable_format_str =
iree_vm_buffer_as_string(executable_format);
iree_vm_buffer_t* executable_data = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r2, &executable_data));
iree_host_size_t constant_count = 0;
const uint32_t* constants = NULL;
if (iree_vm_buffer_isa(args->r3)) {
iree_vm_buffer_t* constant_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_buffer_check_deref(args->r3, &constant_buffer));
if (constant_buffer->data.data_length % 4 != 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"constant buffer data must contain 4-byte "
"elements but data length is %" PRIhsz,
constant_buffer->data.data_length);
}
constant_count = constant_buffer->data.data_length / sizeof(uint32_t);
constants = (const uint32_t*)constant_buffer->data.data;
}
iree_hal_executable_cache_t* executable_cache = NULL;
IREE_RETURN_IF_ERROR(iree_hal_module_state_lookup_executable_cache(
state, device, &executable_cache));
iree_host_size_t pipeline_layout_count = args->a4_count;
iree_hal_pipeline_layout_t** pipeline_layouts = NULL;
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(state->host_allocator,
pipeline_layout_count * sizeof(pipeline_layouts[0]),
(void**)&pipeline_layouts));
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < pipeline_layout_count; ++i) {
status = iree_hal_pipeline_layout_check_deref(args->a4[i].r0,
&pipeline_layouts[i]);
if (!iree_status_is_ok(status)) break;
}
iree_hal_executable_t* executable = NULL;
if (iree_status_is_ok(status)) {
iree_hal_executable_params_t executable_params;
iree_hal_executable_params_initialize(&executable_params);
executable_params.caching_mode |=
executable_data->access == IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE
? IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA
: 0;
executable_params.executable_format = executable_format_str;
executable_params.executable_data = iree_make_const_byte_span(
executable_data->data.data, executable_data->data.data_length);
executable_params.pipeline_layout_count = pipeline_layout_count;
executable_params.pipeline_layouts = pipeline_layouts;
executable_params.constant_count = constant_count;
executable_params.constants = constants;
status = iree_hal_executable_cache_prepare_executable(
executable_cache, &executable_params, &executable);
}
iree_allocator_free(state->host_allocator, pipeline_layouts);
rets->r0 = iree_hal_executable_move_ref(executable);
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_fence_t
//===----------------------------------------------------------------------===//
IREE_VM_ABI_EXPORT(iree_hal_module_fence_create, //
iree_hal_module_state_t, //
ri, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
uint32_t fence_flags = args->i1;
(void)fence_flags;
// TODO(benvanik): hide semaphores from the API.
// This should be reworked to just create the fence.
iree_hal_semaphore_t* semaphore = NULL;
IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
// Create fence with room for our single semaphore.
iree_hal_fence_t* fence = NULL;
iree_status_t status =
iree_hal_fence_create(1, state->host_allocator, &fence);
if (iree_status_is_ok(status)) {
status = iree_hal_fence_insert(fence, semaphore, 1ull);
}
iree_hal_semaphore_release(semaphore);
if (iree_status_is_ok(status)) {
rets->r0 = iree_hal_fence_move_ref(fence);
} else {
iree_hal_fence_release(fence);
}
return status;
}
IREE_VM_ABI_EXPORT(iree_hal_module_fence_join, //
iree_hal_module_state_t, //
CrD, r) {
// NOTE: this is an inlined version of iree_hal_fence_join that avoids the
// need for mapping VM types to HAL types via temporary stack/heap storage.
// This lets us avoid allocations/stack exhaustion in pathological cases of
// hundreds of fences (say, one per input argument in stateless programs with
// hundreds/thousands of inputs).
// Find the maximum required timepoint capacity by scanning the fence list.
// This ensures all fences passed in are actually fences _or_ are NULL so
// the subsequent scan below only needs to check for NULL cases.
iree_host_size_t total_timepoint_capacity = 0;
for (iree_host_size_t i = 0; i < args->a0_count; ++i) {
iree_hal_fence_t* fence = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_fence_check_deref_or_null(args->a0[i].r0, &fence));
if (fence) {
total_timepoint_capacity += iree_hal_fence_timepoint_count(fence);
}
}
// If all fences were empty then we no-op by returning a NULL fence
// (immediately signaled).
if (!total_timepoint_capacity) {
rets->r0 = iree_vm_ref_null();
return iree_ok_status();
}
// Create the fence with the maximum capacity. Hopefully there is some
// deduplication.
iree_hal_fence_t* joined_fence = NULL;
IREE_RETURN_IF_ERROR(iree_hal_fence_create(
total_timepoint_capacity, state->host_allocator, &joined_fence));
// Insert all timepoints from all fences. This is slow in cases where there
// are a lot of unique fences.
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < args->a0_count; ++i) {
// NOTE: only possible because we checked above and know this is NULL or an
// iree_hal_fence_t.
iree_hal_fence_t* fence = (iree_hal_fence_t*)args->a0[i].r0.ptr;
if (!fence) continue;
iree_hal_semaphore_list_t source_list =
iree_hal_fence_semaphore_list(fence);
for (iree_host_size_t j = 0; j < source_list.count; ++j) {
status = iree_hal_fence_insert(joined_fence, source_list.semaphores[j],
source_list.payload_values[j]);
if (!iree_status_is_ok(status)) break;
}
if (!iree_status_is_ok(status)) break;
}
if (iree_status_is_ok(status)) {
rets->r0 = iree_hal_fence_move_ref(joined_fence);
} else {
iree_hal_fence_release(joined_fence);
}
return status;
}
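// Returns the fence status as an iree_status_code_t in i0: IREE_STATUS_OK (0)
// once signaled, IREE_STATUS_DEFERRED while still pending, or the failure code
// if the fence has failed.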
IREE_VM_ABI_EXPORT(iree_hal_module_fence_query, //
iree_hal_module_state_t, //
r, i) {
iree_hal_fence_t* fence = NULL;
IREE_RETURN_IF_ERROR(iree_hal_fence_check_deref(args->r0, &fence));
iree_status_t query_status = iree_hal_fence_query(fence);
rets->i0 = iree_status_consume_code(query_status);
iree_status_ignore(query_status);
return iree_ok_status();
}
IREE_VM_ABI_EXPORT(iree_hal_module_fence_signal, //
iree_hal_module_state_t, //
r, v) {
iree_hal_fence_t* fence = NULL;
IREE_RETURN_IF_ERROR(iree_hal_fence_check_deref(args->r0, &fence));
return iree_hal_fence_signal(fence);
}
IREE_VM_ABI_EXPORT(iree_hal_module_fence_fail, //
iree_hal_module_state_t, //
ri, v) {
iree_hal_fence_t* fence = NULL;
IREE_RETURN_IF_ERROR(iree_hal_fence_check_deref(args->r0, &fence));
iree_status_code_t status_code =
(iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK);
iree_hal_fence_fail(fence, iree_make_status(status_code));
return iree_ok_status();
}
// Removes entries in |fences| if they have been reached.
// Returns failure if one or more fences have failed.
static iree_status_t iree_hal_module_fence_elide_reached(
iree_host_size_t* fence_count, iree_hal_fence_t** fences) {
iree_host_size_t new_count = *fence_count;
for (iree_host_size_t i = 0; i < new_count;) {
iree_status_t status = iree_hal_fence_query(fences[i]);
if (iree_status_is_ok(status)) {
// Has been reached; shift the list down.
memmove(&fences[i], &fences[i + 1],
(new_count - i - 1) * sizeof(iree_hal_fence_t*));
fences[new_count - 1] = NULL;
--new_count;
} else if (iree_status_is_deferred(status)) {
// Still waiting.
iree_status_ignore(status);
++i; // next
} else {
// Failed; propagate failure.
*fence_count = new_count;
return status;
}
}
*fence_count = new_count;
return iree_ok_status();
}
// Enters a wait frame for all timepoints in all |fences|.
// Returns an |out_wait_status| of OK if all fences have been reached or
// IREE_STATUS_DEFERRED if one or more fences are still pending and a wait
// frame was entered.
static iree_status_t iree_hal_module_fence_await_begin(
iree_vm_stack_t* stack, iree_host_size_t fence_count,
iree_hal_fence_t** fences, iree_timeout_t timeout, iree_zone_id_t zone_id,
iree_status_t* out_wait_status) {
// To avoid additional allocations when waiting on multiple fences we enter
// the wait frame with the maximum required wait source capacity and perform
// a simple deduplication when building the list. Ideally this helps get us
// onto the fast paths for single-semaphore waits. The common case is a single
// fence, in which case the dedup merge loop below is skipped entirely.
iree_host_size_t total_timepoint_capacity = 0;
for (iree_host_size_t i = 0; i < fence_count; ++i) {
total_timepoint_capacity += iree_hal_fence_timepoint_count(fences[i]);
}
// Fast-path for no semaphores (empty/immediate fences).
if (total_timepoint_capacity == 0) {
*out_wait_status = iree_ok_status();
IREE_TRACE_ZONE_END(zone_id);
return iree_ok_status();
}
// Reserve storage as if all timepoints from all fences were unique.
iree_vm_wait_frame_t* wait_frame = NULL;
IREE_RETURN_IF_ERROR(iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL,
total_timepoint_capacity,
timeout, zone_id, &wait_frame));
// Insert the first set of timepoints - they're already deduplicated.
iree_host_size_t unique_timepoint_count = 0;
if (fence_count >= 1) {
iree_hal_semaphore_list_t semaphore_list =
iree_hal_fence_semaphore_list(fences[0]);
for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
iree_wait_source_t wait_source = iree_hal_semaphore_await(
semaphore_list.semaphores[i], semaphore_list.payload_values[i]);
wait_frame->wait_sources[unique_timepoint_count++] = wait_source;
}
}
// TODO(benvanik): simplify this; it may not be worth the complexity. We'll
// need more real workloads using multi-fence joins to see how useful this is.
// Insert remaining fence timepoints by performing merging as we go.
for (iree_host_size_t i = 1; i < fence_count; ++i) {
iree_hal_semaphore_list_t semaphore_list =
iree_hal_fence_semaphore_list(fences[i]);
for (iree_host_size_t j = 0; j < semaphore_list.count; ++j) {
// O(n^2) set insertion - relying on this being rare and the total count
// being low. The savings of a small linear scan here relative to an
// additional syscall are always worth it but we may want to go further.
iree_wait_source_t wait_source = iree_hal_semaphore_await(
semaphore_list.semaphores[j], semaphore_list.payload_values[j]);
bool found_existing = false;
for (iree_host_size_t k = 0; k < unique_timepoint_count; ++k) {
if (wait_frame->wait_sources[k].ctl == wait_source.ctl &&
wait_frame->wait_sources[k].self == wait_source.self) {
// Found existing; use max of both.
wait_frame->wait_sources[k].data =
iree_max(wait_frame->wait_sources[k].data, wait_source.data);
found_existing = true;
break;
}
}
if (!found_existing) {
wait_frame->wait_sources[unique_timepoint_count++] = wait_source;
}
}
}
// Update frame with the actual number of timepoints in the wait operation.
wait_frame->count = unique_timepoint_count;
*out_wait_status = iree_status_from_code(IREE_STATUS_DEFERRED);
return iree_ok_status();
}
// Program counter (PC) values for the iree_hal_module_fence_await coroutine.
enum iree_hal_module_fence_await_pc_e {
// Initial entry point that will try to either wait inline or yield to the
// scheduler with a wait-all operation.
IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN = 0,
// Resume entry point after the scheduler wait has resolved (successfully or
// otherwise).
IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME,
};
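// Typical asynchronous flow:
//   BEGIN:  elide reached fences -> enter wait frame -> return DEFERRED
//   (the scheduler services the wait frame's wait sources)
//   RESUME: leave the wait frame -> map the wait status to the i32 result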
IREE_VM_ABI_EXPORT(iree_hal_module_fence_await, //
iree_hal_module_state_t, //
iCrD, i) {
// On entry we either perform the wait inline or begin a coroutine yield
// operation. After resuming we check whether the fences have been reached and
// propagate the result.
iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack);
iree_zone_id_t zone_id = 0;
iree_status_t wait_status = iree_ok_status();
if (current_frame->pc == IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN) {
uint32_t timeout_millis = (uint32_t)args->i0;
iree_host_size_t fence_count = 0;
iree_hal_fence_t** fences = NULL;
IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL(args, a1_count, a1, iree_hal_fence, 32,
&fence_count, &fences);
IREE_TRACE_ZONE_BEGIN(z0);
zone_id = z0;
// Capture an absolute timeout so that the user-perceived deadline remains
// fixed regardless of how long setup and any resumes take.
iree_timeout_t timeout = timeout_millis == UINT32_MAX
? iree_infinite_timeout()
: iree_make_timeout_ms(timeout_millis);
iree_convert_timeout_to_absolute(&timeout);
// Remove any fences that have been reached and check for failure.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
zone_id, iree_hal_module_fence_elide_reached(&fence_count, fences));
// If all fences have been reached we can exit early as if we waited
// successfully.
if (fence_count > 0) {
if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) {
// Block the native thread until all fences are reached or the deadline is
// exceeded.
for (iree_host_size_t i = 0; i < fence_count; ++i) {
wait_status = iree_hal_fence_wait(fences[i], timeout);
if (!iree_status_is_ok(wait_status)) break;
}
} else {
current_frame->pc = IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
zone_id,
iree_hal_module_fence_await_begin(stack, fence_count, fences,
timeout, zone_id, &wait_status));
if (iree_status_is_deferred(wait_status)) {
zone_id = 0; // ownership transferred to wait frame
}
}
}
} else {
// Resume by leaving the wait frame and storing the result.
iree_vm_wait_result_t wait_result;
IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result));
wait_status = wait_result.status;
IREE_TRACE(zone_id = wait_result.trace_zone);
}
iree_status_t status = iree_ok_status();
if (iree_status_is_ok(wait_status)) {
// Successful wait.
rets->i0 = 0;
} else if (iree_status_is_deferred(wait_status)) {
// Yielding; resume required.
// NOTE: the zone is not ended here as its ownership was transferred to the
// wait frame; it is ended when the wait resumes.
status = wait_status;
} else if (iree_status_is_deadline_exceeded(wait_status)) {
// Propagate deadline exceeded back to the VM.
rets->i0 = (int32_t)iree_status_consume_code(wait_status);
iree_status_ignore(wait_status);
} else {
// Fail the invocation.
status = wait_status;
}
IREE_TRACE({
if (zone_id) IREE_TRACE_ZONE_END(zone_id);
});
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_pipeline_layout_t
//===----------------------------------------------------------------------===//
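// Creates a pipeline layout from a device, a push constant count, and a
// variadic list of descriptor set layouts (riCrD -> r: r0 = device, i1 =
// push constant count, a2 = set layout refs; returns the layout ref).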
IREE_VM_ABI_EXPORT(iree_hal_module_pipeline_layout_create, //
iree_hal_module_state_t, //
riCrD, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
int32_t push_constants = (int32_t)args->i1;
iree_host_size_t set_layout_count = 0;
iree_hal_descriptor_set_layout_t** set_layouts = NULL;
IREE_VM_ABI_VLA_STACK_DEREF(args, a2_count, a2,
iree_hal_descriptor_set_layout, 32,
&set_layout_count, &set_layouts);
iree_hal_pipeline_layout_t* pipeline_layout = NULL;
IREE_RETURN_IF_ERROR(iree_hal_pipeline_layout_create(
device, push_constants, set_layout_count, set_layouts, &pipeline_layout));
rets->r0 = iree_hal_pipeline_layout_move_ref(pipeline_layout);
return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
// NOTE: this must match the ordering of the iree_hal_module_exports_ table.
static const iree_vm_native_function_ptr_t iree_hal_module_funcs_[] = {
#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
{ \
.shim = (iree_vm_native_function_shim_t) \
iree_vm_shim_##arg_types##_##ret_types, \
.target = (iree_vm_native_function_target_t)(target_fn), \
},
#include "iree/modules/hal/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
};
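// For reference, an exports.inl entry declared as (assumed shape)
//   EXPORT_FN("fence.await", iree_hal_module_fence_await, iCrD, i)
// expands here to roughly:
//   {
//     .shim = (iree_vm_native_function_shim_t)iree_vm_shim_iCrD_i,
//     .target =
//         (iree_vm_native_function_target_t)(iree_hal_module_fence_await),
//   },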
// NOTE: logically zero-length, but C cannot express a zero-length array so a
// single placeholder entry is declared.
static const iree_vm_native_import_descriptor_t iree_hal_module_imports_[1];
static const iree_vm_native_export_descriptor_t iree_hal_module_exports_[] = {
#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
{ \
.local_name = iree_string_view_literal(name), \
.calling_convention = \
iree_string_view_literal("0" #arg_types "_" #ret_types), \
.attr_count = 0, \
.attrs = NULL, \
},
#include "iree/modules/hal/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
};
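// The same assumed fence.await entry produces an export with local_name
// "fence.await" and calling_convention "0iCrD_i" ("0" version prefix,
// argument types, '_', then result types).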
static_assert(IREE_ARRAYSIZE(iree_hal_module_funcs_) ==
IREE_ARRAYSIZE(iree_hal_module_exports_),
"function pointer table must be 1:1 with exports");
static const iree_vm_native_module_descriptor_t iree_hal_module_descriptor_ = {
.name = iree_string_view_literal("hal"),
.version = IREE_HAL_MODULE_VERSION_LATEST,
.attr_count = 0,
.attrs = NULL,
.dependency_count = 0,
.dependencies = NULL,
.import_count = 0, // workaround for the 0-length array above
.imports = iree_hal_module_imports_,
.export_count = IREE_ARRAYSIZE(iree_hal_module_exports_),
.exports = iree_hal_module_exports_,
.function_count = IREE_ARRAYSIZE(iree_hal_module_funcs_),
.functions = iree_hal_module_funcs_,
};
IREE_API_EXPORT iree_status_t iree_hal_module_create(
iree_vm_instance_t* instance, iree_host_size_t device_count,
iree_hal_device_t** devices, iree_hal_module_flags_t flags,
iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
IREE_ASSERT_ARGUMENT(instance);
IREE_ASSERT_ARGUMENT(device_count);
IREE_ASSERT_ARGUMENT(devices);
IREE_ASSERT_ARGUMENT(out_module);
*out_module = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
// Set up the interface with the functions we implement ourselves. Any
// function we omit will be handled by the base native module.
static const iree_vm_module_t interface = {
.destroy = iree_hal_module_destroy,
.alloc_state = iree_hal_module_alloc_state,
.free_state = iree_hal_module_free_state,
.notify = iree_hal_module_notify,
};
// Allocate shared module state.
iree_host_size_t total_size = iree_vm_native_module_size() +
sizeof(iree_hal_module_t) +
device_count * sizeof(iree_hal_device_t*);
iree_vm_module_t* base_module = NULL;
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(host_allocator, total_size, (void**)&base_module));
memset(base_module, 0, total_size);
iree_status_t status =
iree_vm_native_module_initialize(&interface, &iree_hal_module_descriptor_,
instance, host_allocator, base_module);
if (!iree_status_is_ok(status)) {
iree_allocator_free(host_allocator, base_module);
IREE_TRACE_ZONE_END(z0);
return status;
}
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
module->host_allocator = host_allocator;
// TODO(benvanik): fix vm yield with result storage; until then waits are
// forced to be synchronous.
module->flags = flags | IREE_HAL_MODULE_FLAG_SYNCHRONOUS;
module->device_count = device_count;
for (iree_host_size_t i = 0; i < device_count; ++i) {
module->devices[i] = devices[i];
iree_hal_device_retain(module->devices[i]);
}
*out_module = base_module;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
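// Example usage (a minimal sketch; error handling abbreviated and the context
// creation call assumed from iree/vm/context.h):
//   iree_vm_module_t* hal_module = NULL;
//   IREE_CHECK_OK(iree_hal_module_create(instance, /*device_count=*/1,
//                                        &device, IREE_HAL_MODULE_FLAG_NONE,
//                                        host_allocator, &hal_module));
//   iree_vm_module_t* modules[] = {hal_module, user_module};
//   IREE_CHECK_OK(iree_vm_context_create_with_modules(
//       instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), modules,
//       host_allocator, &context));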
IREE_API_EXPORT iree_host_size_t
iree_hal_module_state_device_count(iree_vm_module_state_t* module_state) {
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
return state->device_count;
}
IREE_API_EXPORT iree_hal_device_t* iree_hal_module_state_device_get(
iree_vm_module_state_t* module_state, iree_host_size_t index) {
iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
return index < state->device_count ? state->devices[index] : NULL;
}
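// Example (a sketch; assumes |module_state| was resolved for the HAL module,
// e.g. via iree_vm_context_resolve_module_state):
//   for (iree_host_size_t i = 0;
//        i < iree_hal_module_state_device_count(module_state); ++i) {
//     iree_hal_device_t* device =
//         iree_hal_module_state_device_get(module_state, i);
//     // ... use device ...
//   }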