blob: 73730737e1ae4d0aca30b4b72a4564ab4ebc52c9 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "runtime/bindings/tflite/model.h"
#include <stdio.h>
#include <string.h>
#include "iree/modules/hal/module.h"
#include "iree/vm/bytecode/module.h"
// Derives the number of tflite inputs and outputs of |signature| from its
// calling convention (cconv) string.
//
// The cconv is split into an argument fragment and a result fragment, and the
// length of each fragment is reported as the tensor count.
// NOTE: today we only pass 1:1 buffer views with what tflite does, so every
// buffer view is a single `r` in the cconv and the counts are just the number
// of characters. A richer cconv would presumably be miscounted here.
static iree_status_t _TfLiteModelCalculateFunctionIOCounts(
    const iree_vm_function_signature_t* signature, int32_t* out_input_count,
    int32_t* out_output_count) {
  iree_string_view_t input_cconv;
  iree_string_view_t output_cconv;
  IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments(
      signature, &input_cconv, &output_cconv));

  const int32_t input_count = (int32_t)input_cconv.size;
  const int32_t output_count = (int32_t)output_cconv.size;
  *out_input_count = input_count;
  *out_output_count = output_count;
  return iree_ok_status();
}
// Creates the VM instance and bytecode module for |model| from the given
// flatbuffer and resolves the exported functions used by the tflite bindings.
//
// The flatbuffer is handed to the module with iree_allocator_null() as its
// data allocator, so the module is not expected to free it: |flatbuffer_data|
// must stay valid for as long as |model->module| is alive.
//
// Required exports: `_tflite_main` and `_tflite_main_query_output_shape`.
// Optional exports (looked up best-effort; the corresponding function fields
// are left as-is, i.e. zeroed by the callers, when missing):
// `_tflite_main_query_input_shape`, `_tflite_main_resize_input_shape`,
// `_tflite_main_reset_variables`.
//
// On failure |model| may be partially initialized (instance and/or module
// already set); callers release it with TfLiteModelDelete.
static iree_status_t _TfLiteModelInitializeModule(const void* flatbuffer_data,
                                                  size_t flatbuffer_size,
                                                  iree_allocator_t allocator,
                                                  TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, allocator,
                                  &model->instance));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_module_register_all_types(model->instance));

  iree_const_byte_span_t flatbuffer_span =
      iree_make_const_byte_span(flatbuffer_data, flatbuffer_size);
  // Null allocator: the module references but does not own the flatbuffer.
  iree_allocator_t flatbuffer_allocator = iree_allocator_null();
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_bytecode_module_create(model->instance, flatbuffer_span,
                                     flatbuffer_allocator, allocator,
                                     &model->module),
      "error creating bytecode module");

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main"), &model->exports._main),
      "unable to find '_tflite_main' export in module, module must be compiled "
      "with tflite bindings support");

  // Get the input and output counts of the function; this is useful for being
  // able to preallocate storage when creating interpreters.
  // NOTE: this must end the trace zone on failure like every other early
  // return in this function.
  iree_vm_function_signature_t main_signature =
      iree_vm_function_signature(&model->exports._main);
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, _TfLiteModelCalculateFunctionIOCounts(
              &main_signature, &model->input_count, &model->output_count));

  // NOTE: the input shape query is optional because a model may legitimately
  // have no inputs (rare in the wild, common in tests).
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_query_input_shape"),
      &model->exports._query_input_shape));

  // NOTE: the input shape resizing function is only required if the model has
  // dynamic shapes.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_resize_input_shape"),
      &model->exports._resize_input_shape));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main_query_output_shape"),
          &model->exports._query_output_shape),
      "unable to find '_tflite_main_query_output_shape' export in module");

  // It's OK for this to fail; the model may not have variables.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_reset_variables"),
      &model->exports._reset_variables));

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// TFLite C API: creates a model from an in-memory compiled IREE module.
//
// |model_data| is passed through to the module loader without being copied,
// so it presumably must outlive the returned model. Returns NULL if the model
// struct cannot be allocated or the module fails to load (the load error is
// printed to stderr).
TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data,
                                                      size_t model_size) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  // Allocate only the model header; the flatbuffer itself is not copied.
  TfLiteModel* model = NULL;
  iree_status_t status =
      iree_allocator_malloc(allocator, sizeof(*model), (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_MESSAGE(ERROR, "failed model allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  memset(model, 0, sizeof(*model));
  model->allocator = allocator;
  iree_atomic_ref_count_init(&model->ref_count);

  status =
      _TfLiteModelInitializeModule(model_data, model_size, allocator, model);
  if (iree_status_is_ok(status)) {
    IREE_TRACE_ZONE_END(z0);
    return model;
  }

  // Report why loading failed, then drop the partially-initialized model.
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  TfLiteModelDelete(model);
  IREE_TRACE_ZONE_END(z0);
  return NULL;
}
// TFLite C API: loads a compiled IREE module from the file at |model_path|
// and creates a model from it.
//
// The file contents are read into the same allocation as the TfLiteModel
// struct, placed immediately after it, so the single free in
// _TfLiteModelRelease releases both the struct and the module data.
// Returns NULL if the file cannot be opened, sized, or read, if allocation
// fails, or if the module fails to load (the load error is printed to stderr).
TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreateFromFile(
    const char* model_path) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  // TODO(#3909): use file mapping C API.
  // Binary mode: the module is a flatbuffer, and text-mode newline translation
  // (e.g. on Windows) would corrupt it.
  FILE* file = fopen(model_path, "rb");
  if (!file) {
    IREE_TRACE_MESSAGE(ERROR, "failed to open model file");
    IREE_TRACE_MESSAGE_DYNAMIC(ERROR, model_path, strlen(model_path));
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  // Determine the file size. ftell returns -1L on failure, which must not be
  // converted to size_t (it would become an enormous allocation request).
  long file_length = -1;
  if (fseek(file, 0, SEEK_END) == 0) {
    file_length = ftell(file);
  }
  if (file_length < 0 || fseek(file, 0, SEEK_SET) != 0) {
    fclose(file);
    IREE_TRACE_MESSAGE(ERROR, "failed to query model file size");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  size_t file_size = (size_t)file_length;

  TfLiteModel* model = NULL;
  iree_status_t status = iree_allocator_malloc(
      allocator, sizeof(*model) + file_size, (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    fclose(file);  // The file must be closed on every exit path.
    IREE_TRACE_MESSAGE(ERROR, "failed model+data allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  memset(model, 0, sizeof(*model));
  iree_atomic_ref_count_init(&model->ref_count);
  model->allocator = allocator;

  // The file contents live directly after the struct in the same allocation;
  // offsetting by anything else would overlap the struct or run past the end
  // of the allocation.
  // NOTE(review): the bytecode loader may require a stricter alignment than
  // sizeof(TfLiteModel) provides — confirm.
  model->owned_model_data = (uint8_t*)model + sizeof(*model);
  size_t read_size = fread(model->owned_model_data, 1, file_size, file);
  fclose(file);
  if (read_size != file_size) {
    TfLiteModelDelete(model);
    IREE_TRACE_MESSAGE(ERROR, "failed model+data read");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  status = _TfLiteModelInitializeModule(model->owned_model_data, file_size,
                                        allocator, model);
  if (!iree_status_is_ok(status)) {
    // Report the load error like TfLiteModelCreate does instead of silently
    // discarding it.
    iree_status_fprint(stderr, status);
    iree_status_free(status);
    TfLiteModelDelete(model);
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  IREE_TRACE_ZONE_END(z0);
  return model;
}
// Takes an additional reference on |model|. Passing NULL is a no-op.
// Each retain must be balanced by a _TfLiteModelRelease.
void _TfLiteModelRetain(TfLiteModel* model) {
  if (model == NULL) {
    return;
  }
  iree_atomic_ref_count_inc(&model->ref_count);
}
// Drops one reference on |model| and destroys it when the last reference goes
// away. Passing NULL is a no-op.
void _TfLiteModelRelease(TfLiteModel* model) {
  if (model == NULL) {
    return;
  }
  // The decrement yields the pre-decrement count (hence the comparison against
  // 1): any other value means other holders remain.
  if (iree_atomic_ref_count_dec(&model->ref_count) != 1) {
    return;
  }

  IREE_TRACE_ZONE_BEGIN(z0);
  // Tear down in dependency order: the module before the instance it was
  // created in, and the model allocation last because it may also hold the
  // flatbuffer data the module references (see TfLiteModelCreateFromFile).
  iree_vm_module_release(model->module);
  iree_vm_instance_release(model->instance);
  iree_allocator_free(model->allocator, model);
  IREE_TRACE_ZONE_END(z0);
}
// TFLite C API: drops the caller's reference to |model|.
//
// The model's resources are released only once no other references, taken
// via _TfLiteModelRetain, remain. NULL is accepted and ignored.
TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z0);
  _TfLiteModelRelease(model);  // No-op for NULL.
  IREE_TRACE_ZONE_END(z0);
}