blob: e211adc37057b79ad820e872e705b9abab1646fa [file]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/vm/context.h"
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include "iree/base/atomics.h"
#include "iree/base/tracing.h"
// A VM context: owns a set of registered modules and the per-context state
// allocated for each of them. Reference counted via retain/release.
struct iree_vm_context {
  // Reference count; incremented by iree_vm_context_retain and decremented by
  // iree_vm_context_release, which destroys the context on the last release.
  iree_atomic_ref_count_t ref_count;
  // Instance the context was created from; retained for the context lifetime.
  iree_vm_instance_t* instance;
  // Allocator used for the context allocation, grown module lists, module
  // state allocation, and the VM stacks used for __init/__deinit calls.
  iree_allocator_t allocator;
  // Id assigned at creation from a process-wide atomic counter starting at 1.
  intptr_t context_id;
  // True when the context was created with a fixed module list
  // (module_count > 0). Static contexts keep their lists inline with the
  // context allocation and reject registering more modules later.
  bool is_static;
  // Registered modules in registration (init) order. modules[i] and
  // module_states[i] are parallel arrays; entries [0, count) are live and the
  // arrays have room for |capacity| entries.
  struct {
    iree_host_size_t count;
    iree_host_size_t capacity;
    iree_vm_module_t** modules;
    iree_vm_module_state_t** module_states;
  } list;
};
// Forward declaration: create_with_modules uses this to clean up on failure
// before the definition appears below.
static void iree_vm_context_destroy(iree_vm_context_t* context);
// Invokes the optional exported `() -> ()` function |function_name| of
// |module| on |stack|.
//
// A missing export is not an error: the not-found status is consumed and OK is
// returned. Any other lookup failure, or the status from begin_call, is
// propagated to the caller. The call descriptor is zero-filled so no arguments
// or results are passed.
static iree_status_t iree_vm_context_run_function(
    iree_vm_stack_t* stack, iree_vm_module_t* module,
    iree_string_view_t function_name) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_vm_function_call_t call;
  memset(&call, 0, sizeof(call));
  iree_status_t status = iree_vm_module_lookup_function_by_name(
      module, IREE_VM_FUNCTION_LINKAGE_EXPORT, function_name, &call.function);

  if (iree_status_is_ok(status)) {
    iree_vm_execution_result_t result;
    status = module->begin_call(module->self, stack, &call, &result);
    // TODO(benvanik): ensure completed synchronously.
  } else if (iree_status_is_not_found(status)) {
    // The function is optional; absence is fine.
    iree_status_ignore(status);
    status = iree_ok_status();
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// State resolver callback: finds the state that |state_resolver| (an
// iree_vm_context_t*) holds for |module|.
//
// Only modules counted in the context list (entries [0, count)) are searched.
// Returns IREE_STATUS_NOT_FOUND if |module| is not registered.
static iree_status_t iree_vm_context_query_module_state(
    void* state_resolver, iree_vm_module_t* module,
    iree_vm_module_state_t** out_module_state) {
  IREE_ASSERT_ARGUMENT(state_resolver);
  IREE_ASSERT_ARGUMENT(module);
  IREE_ASSERT_ARGUMENT(out_module_state);
  iree_vm_context_t* context = (iree_vm_context_t*)state_resolver;
  // NOTE: this is a linear scan, but given that the list of modules should be
  // N<4 this is faster than just about anything else we could do.
  // To future performance profilers: sorry when N>>4 :)
  // The index uses the same unsigned type as list.count. This avoids a
  // signed/unsigned comparison and any possibility of `int` overflow.
  for (iree_host_size_t i = 0; i < context->list.count; ++i) {
    if (context->list.modules[i] == module) {
      *out_module_state = context->list.module_states[i];
      return iree_ok_status();
    }
  }
  return iree_make_status(IREE_STATUS_NOT_FOUND);
}
static iree_status_t iree_vm_context_resolve_module_imports(
iree_vm_context_t* context, iree_vm_module_t* module,
iree_vm_module_state_t* module_state) {
IREE_TRACE_ZONE_BEGIN(z0);
// NOTE: this has some bad characteristics, but the number of modules and the
// number of imported functions should be relatively small (even if the number
// of exported functions for particular modules is large).
iree_vm_module_signature_t module_signature = module->signature(module->self);
for (int i = 0; i < module_signature.import_function_count; ++i) {
iree_string_view_t full_name;
iree_vm_function_signature_t expected_signature;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
module->get_function(module->self, IREE_VM_FUNCTION_LINKAGE_IMPORT, i,
/*out_function=*/NULL,
/*out_name=*/&full_name,
/*out_signature=*/&expected_signature));
// Resolve the function to the module that contains it and return the
// information.
iree_vm_function_t import_function;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_vm_context_resolve_function(context, full_name, &import_function));
// Query the function signature from the module that contains it; we don't
// use the signature from the module requesting the import as we want a
// single source of truth.
iree_vm_function_signature_t import_signature =
iree_vm_function_signature(&import_function);
// Simple check to confirm the signatures match. We still can't trust that
// the module using the import *actually* calls it with the right convention
// (so this is not a safety check!), but this will catch the 99% case of a
// signature changing out from under a module or using a module with a newer
// signature than that provided by the imported module.
//
// We allow modules to not define their cconv expectation as in a lot of
// cases where modules are all compiled into the same binary there's no
// value in performing the verification. Runtime checks during calls will
// fail with less awesome logging but that's the tradeoff.
if (expected_signature.calling_convention.size &&
!iree_string_view_equal(import_signature.calling_convention,
expected_signature.calling_convention)) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(
IREE_STATUS_INTERNAL,
"import function signature mismatch between %.*s "
"and source %.*s; expected %.*s but got %.*s",
(int)iree_vm_module_name(module).size,
iree_vm_module_name(module).data,
(int)iree_vm_module_name(import_function.module).size,
iree_vm_module_name(import_function.module).data,
(int)expected_signature.calling_convention.size,
expected_signature.calling_convention.data,
(int)import_signature.calling_convention.size,
import_signature.calling_convention.data);
}
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, module->resolve_import(module->self, module_state, i,
&import_function, &import_signature));
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
// Tears down the modules in list slots [start, end] (both inclusive), always
// walking from |end| back to |start| (reverse init order). The work is done in
// three passes:
//   1. run each module's optional __deinit export (errors are ignored);
//   2. free each module's state;
//   3. release each module reference.
// Slots without state or without a module, which occur after a partially
// failed registration, are skipped in the relevant pass. Freed slots are NULLed.
static void iree_vm_context_release_modules(iree_vm_context_t* context,
                                            iree_host_size_t start,
                                            iree_host_size_t end) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Pass 1: __deinit. `slot-- > start` visits end, end-1, ..., start.
  IREE_VM_INLINE_STACK_INITIALIZE(
      stack, iree_vm_context_state_resolver(context), context->allocator);
  for (iree_host_size_t slot = end + 1; slot-- > start;) {
    if (context->list.module_states[slot] == NULL) {
      continue;  // Never got state; nothing to deinitialize.
    }
    IREE_IGNORE_ERROR(iree_vm_context_run_function(
        stack, context->list.modules[slot],
        iree_make_cstring_view("__deinit")));
  }
  iree_vm_stack_deinitialize(stack);

  // Pass 2: free per-module state.
  for (iree_host_size_t slot = end + 1; slot-- > start;) {
    iree_vm_module_state_t* state = context->list.module_states[slot];
    if (state == NULL) continue;
    iree_vm_module_t* owner = context->list.modules[slot];
    owner->free_state(owner->self, state);
    context->list.module_states[slot] = NULL;
  }

  // Pass 3: drop module references. No import tables remain at this point.
  for (iree_host_size_t slot = end + 1; slot-- > start;) {
    if (context->list.modules[slot] == NULL) continue;
    iree_vm_module_release(context->list.modules[slot]);
    context->list.modules[slot] = NULL;
  }

  IREE_TRACE_ZONE_END(z0);
}
// Creates an empty, dynamic context. Modules may be added later with
// iree_vm_context_register_modules.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_vm_context_create(iree_vm_instance_t* instance, iree_allocator_t allocator,
                       iree_vm_context_t** out_context) {
  return iree_vm_context_create_with_modules(instance, /*modules=*/NULL,
                                             /*module_count=*/0, allocator,
                                             out_context);
}
// Creates a context and registers |module_count| |modules| with it.
//
// The module and module state lists are placed inline, directly after the
// context struct in a single allocation sized for |module_count|. When
// |module_count| > 0 the context is static: registering further modules later
// fails. With zero modules the context starts empty and grows its lists on
// demand.
//
// On any failure (allocation or module registration) the partially built
// context is destroyed, the status is returned, and |*out_context| stays NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_context_create_with_modules(
    iree_vm_instance_t* instance, iree_vm_module_t** modules,
    iree_host_size_t module_count, iree_allocator_t allocator,
    iree_vm_context_t** out_context) {
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_ASSERT_ARGUMENT(instance);
  IREE_ASSERT_ARGUMENT(out_context);
  *out_context = NULL;

  iree_host_size_t context_size =
      sizeof(iree_vm_context_t) + sizeof(iree_vm_module_t*) * module_count +
      sizeof(iree_vm_module_state_t*) * module_count;
  iree_vm_context_t* context = NULL;
  // The allocation can fail; propagate the failure instead of dereferencing a
  // NULL context below.
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(allocator, context_size, (void**)&context));

  iree_atomic_ref_count_init(&context->ref_count);
  context->instance = instance;
  iree_vm_instance_retain(context->instance);
  context->allocator = allocator;

  static iree_atomic_int32_t next_context_id = IREE_ATOMIC_VAR_INIT(1);
  context->context_id = iree_atomic_fetch_add_int32(&next_context_id, 1,
                                                    iree_memory_order_seq_cst);

  // Carve the two parallel lists out of the tail of the allocation.
  uint8_t* p = (uint8_t*)context + sizeof(iree_vm_context_t);
  context->list.modules = (iree_vm_module_t**)p;
  p += sizeof(iree_vm_module_t*) * module_count;
  context->list.module_states = (iree_vm_module_state_t**)p;
  context->list.count = 0;
  context->list.capacity = module_count;
  context->is_static = module_count > 0;

  iree_status_t register_status =
      iree_vm_context_register_modules(context, modules, module_count);
  if (!iree_status_is_ok(register_status)) {
    iree_vm_context_destroy(context);
    IREE_TRACE_ZONE_END(z0);
    return register_status;
  }

  *out_context = context;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Destroys |context|. The steps, in order:
//   1. deinitialize and release all registered modules;
//   2. free the module lists if they were separately allocated;
//   3. release the instance;
//   4. free the context allocation itself.
// A NULL |context| is ignored.
static void iree_vm_context_destroy(iree_vm_context_t* context) {
  if (context == NULL) {
    return;
  }
  IREE_TRACE_ZONE_BEGIN(z0);

  const iree_host_size_t live_count = context->list.count;
  if (live_count != 0) {
    iree_vm_context_release_modules(context, 0, live_count - 1);
  }

  // The lists live inline with the context allocation when the context is
  // static or has never grown (capacity == 0). Only otherwise were they
  // allocated separately and need freeing here.
  const bool lists_separately_allocated =
      !context->is_static && context->list.capacity != 0;
  if (lists_separately_allocated) {
    iree_allocator_free(context->allocator, context->list.modules);
    context->list.modules = NULL;
    iree_allocator_free(context->allocator, context->list.module_states);
    context->list.module_states = NULL;
  }

  iree_vm_instance_release(context->instance);
  context->instance = NULL;

  iree_allocator_free(context->allocator, context);
  IREE_TRACE_ZONE_END(z0);
}
// Adds a reference to |context|. NULL is ignored.
IREE_API_EXPORT void IREE_API_CALL
iree_vm_context_retain(iree_vm_context_t* context) {
  if (!context) return;
  iree_atomic_ref_count_inc(&context->ref_count);
}
// Drops a reference to |context| and destroys it when the last reference is
// released. NULL is ignored.
IREE_API_EXPORT void IREE_API_CALL
iree_vm_context_release(iree_vm_context_t* context) {
  if (!context) return;
  // A result of 1 from iree_atomic_ref_count_dec is treated as "this was the
  // final reference". It presumably returns the pre-decrement count; confirm
  // against iree/base/atomics.h.
  if (iree_atomic_ref_count_dec(&context->ref_count) != 1) return;
  iree_vm_context_destroy(context);
}
// Returns the id assigned to |context| at creation, or -1 for NULL.
IREE_API_EXPORT intptr_t IREE_API_CALL
iree_vm_context_id(const iree_vm_context_t* context) {
  return context ? context->context_id : -1;
}
// Returns a state resolver that looks up module state in |context|. All other
// resolver fields are zero-initialized. The const qualifier is dropped because
// the resolver stores an untyped `void* self`.
IREE_API_EXPORT iree_vm_state_resolver_t IREE_API_CALL
iree_vm_context_state_resolver(const iree_vm_context_t* context) {
  iree_vm_state_resolver_t resolver = {
      .self = (void*)context,
      .query_module_state = iree_vm_context_query_module_state,
  };
  return resolver;
}
// Looks up the state |context| holds for |module|. The lookup is the same one
// used by the context's state resolver, so the result is NOT_FOUND when
// |module| is not registered.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_vm_context_resolve_module_state(
    const iree_vm_context_t* context, iree_vm_module_t* module,
    iree_vm_module_state_t** out_module_state) {
  void* resolver_self = (void*)context;
  return iree_vm_context_query_module_state(resolver_self, module,
                                            out_module_state);
}
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_context_register_modules(
iree_vm_context_t* context, iree_vm_module_t** modules,
iree_host_size_t module_count) {
IREE_ASSERT_ARGUMENT(context);
if (!modules && module_count > 1) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"modules/module_count mismatch");
}
for (iree_host_size_t i = 0; i < module_count; ++i) {
if (!modules[i]) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"modules[%zu] is null", i);
}
}
IREE_TRACE_ZONE_BEGIN(z0);
// Try growing both our storage lists first, if needed.
if (context->list.count + module_count > context->list.capacity) {
if (context->is_static) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
"context was allocated as static and cannot "
"register modules after creation");
}
iree_host_size_t new_capacity = context->list.capacity + module_count;
if (new_capacity < context->list.capacity * 2) {
// TODO(benvanik): tune list growth for module count >> 4.
new_capacity = context->list.capacity * 2;
}
iree_vm_module_t** new_module_list;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(context->allocator,
sizeof(iree_vm_module_t*) * new_capacity,
(void**)&new_module_list));
iree_vm_module_state_t** new_module_state_list;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_allocator_malloc(context->allocator,
sizeof(iree_vm_module_state_t*) * new_capacity,
(void**)&new_module_state_list));
memcpy(new_module_list, context->list.modules,
sizeof(iree_vm_module_t*) * context->list.count);
memcpy(new_module_state_list, context->list.module_states,
sizeof(iree_vm_module_state_t*) * context->list.count);
// The existing memory is only dynamically allocated if it has been
// grown.
if (context->list.capacity > 0) {
iree_allocator_free(context->allocator, context->list.modules);
iree_allocator_free(context->allocator, context->list.module_states);
}
context->list.modules = new_module_list;
context->list.module_states = new_module_state_list;
context->list.capacity = new_capacity;
}
// VM stack used to call into module __init methods.
IREE_VM_INLINE_STACK_INITIALIZE(
stack, iree_vm_context_state_resolver(context), context->allocator);
// Retain all modules and allocate their state.
assert(context->list.capacity >= context->list.count + module_count);
iree_host_size_t original_count = context->list.count;
iree_status_t status = iree_ok_status();
iree_host_size_t i = 0;
for (i = 0; i < module_count; ++i) {
iree_vm_module_t* module = modules[i];
context->list.modules[original_count + i] = module;
context->list.module_states[original_count + i] = NULL;
iree_vm_module_retain(module);
// Allocate module state.
iree_vm_module_state_t* module_state = NULL;
status =
module->alloc_state(module->self, context->allocator, &module_state);
if (!iree_status_is_ok(status)) {
// Cleanup handled below.
break;
}
context->list.module_states[original_count + i] = module_state;
// Resolve imports for the modules.
status =
iree_vm_context_resolve_module_imports(context, module, module_state);
if (!iree_status_is_ok(status)) {
// Cleanup handled below.
break;
}
++context->list.count;
// Run module __init functions, if present.
// As initialization functions may reference imports we need to perform
// all of these after we have resolved the imports above.
status = iree_vm_context_run_function(stack, module,
iree_make_cstring_view("__init"));
if (!iree_status_is_ok(status)) {
// Cleanup handled below.
break;
}
}
iree_vm_stack_deinitialize(stack);
// Cleanup for failure cases during module initialization; we need to
// ensure we release any modules we'd already initialized.
if (!iree_status_is_ok(status)) {
iree_vm_context_release_modules(context, original_count,
original_count + i);
context->list.count = original_count;
}
IREE_TRACE_ZONE_END(z0);
return status;
}
// Resolves a fully-qualified "module.function" name to an exported function of
// a module registered with |context|.
//
// Registered modules are searched from most to least recently registered, so
// a later module shadows an earlier one with the same name. Only the first
// module whose name matches is consulted, and its lookup status is returned
// unchanged.
//
// Errors:
//   INVALID_ARGUMENT if |full_name| has no '.' separator.
//   NOT_FOUND if no registered module has the requested name.
// |out_function| is zeroed before anything else is done.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_context_resolve_function(
    const iree_vm_context_t* context, iree_string_view_t full_name,
    iree_vm_function_t* out_function) {
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_ASSERT_ARGUMENT(out_function);
  memset(out_function, 0, sizeof(*out_function));

  iree_string_view_t module_name;
  iree_string_view_t function_name;
  if (iree_string_view_split(full_name, '.', &module_name, &function_name) ==
      -1) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "import name not fully-qualified (module.func): '%.*s'",
        (int)full_name.size, full_name.data);
  }

  iree_vm_module_t* owner = NULL;
  for (iree_host_size_t remaining = context->list.count; remaining > 0;
       --remaining) {
    iree_vm_module_t* candidate = context->list.modules[remaining - 1];
    if (iree_string_view_equal(module_name, iree_vm_module_name(candidate))) {
      owner = candidate;
      break;
    }
  }

  if (owner == NULL) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "module '%.*s' required for import '%.*s' not "
                            "registered with the context",
                            (int)module_name.size, module_name.data,
                            (int)full_name.size, full_name.data);
  }

  iree_status_t status = iree_vm_module_lookup_function_by_name(
      owner, IREE_VM_FUNCTION_LINKAGE_EXPORT, function_name, out_function);
  IREE_TRACE_ZONE_END(z0);
  return status;
}