| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "bindings/tflite/interpreter.h" |
| |
| #include "bindings/tflite/model.h" |
| #include "bindings/tflite/shim.h" |
| #include "bindings/tflite/tensor.h" |
| #include "iree/base/internal/call_once.h" |
| #include "iree/base/tracing.h" |
| #include "iree/hal/drivers/init.h" |
| #include "iree/modules/hal/module.h" |
| |
| //===----------------------------------------------------------------------===// |
| // HAL / driver support |
| //===----------------------------------------------------------------------===// |
| |
| static iree_once_flag _TfLiteInterpreterRegisterDriverFlag = |
| IREE_ONCE_FLAG_INIT; |
| static void _TfLiteInterpreterRegisterDrivers(void) { |
| IREE_IGNORE_ERROR(iree_hal_register_all_available_drivers( |
| iree_hal_driver_registry_default())); |
| } |
| |
// TODO(#3977): if a HAL device is already provided in the options, use that.
| static iree_status_t _TfLiteInterpreterPrepareHAL( |
| TfLiteInterpreter* interpreter) { |
| iree_call_once(&_TfLiteInterpreterRegisterDriverFlag, |
| _TfLiteInterpreterRegisterDrivers); |
| |
| iree_hal_driver_registry_t* driver_registry = |
| iree_hal_driver_registry_default(); |
| |
| iree_hal_driver_info_t* driver_infos = NULL; |
| iree_host_size_t driver_info_count = 0; |
| IREE_RETURN_IF_ERROR(iree_hal_driver_registry_enumerate( |
| driver_registry, interpreter->allocator, &driver_infos, |
| &driver_info_count)); |
| |
| // TODO(benvanik): figure out how we want to emulate device selection; may |
| // just say "whatever is first" on a query. |
| // iree_string_view_t driver_name = driver_infos[0].driver_name; |
| // NOTE: currently the sample file is compiled only with vmvx. |
| iree_string_view_t driver_name = iree_make_cstring_view("vmvx"); |
| |
| // TODO(benvanik): switch to iree_hal_driver_registry_try_create when |
| // implemented. |
| iree_status_t status = iree_hal_driver_registry_try_create_by_name( |
| driver_registry, driver_name, interpreter->allocator, |
| &interpreter->driver); |
| iree_allocator_free(interpreter->allocator, driver_infos); |
| IREE_RETURN_IF_ERROR(status, "failed to create driver '%.*s'", |
| (int)driver_name.size, driver_name.data); |
| |
| IREE_RETURN_IF_ERROR( |
| iree_hal_driver_create_default_device( |
| interpreter->driver, interpreter->allocator, &interpreter->device), |
| "failed creating the default device for driver '%.*s'", |
| (int)driver_name.size, driver_name.data); |
| |
| IREE_RETURN_IF_ERROR(iree_hal_module_create( |
| interpreter->device, interpreter->allocator, &interpreter->hal_module)); |
| |
| return iree_ok_status(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Model shape function query/mutation utilities |
| //===----------------------------------------------------------------------===// |
| |
| // On-stack storage for shape function invocations. |
// Avoids all allocations and allows reuse when walking the lists of inputs
// and outputs to call their shape functions.
| typedef struct { |
| // Inlined list for the !vm.list in the shape function arguments. |
| uint8_t |
| shape_list_storage[128 + sizeof(int32_t) * IREE_BINDINGS_TFLITE_MAX_RANK]; |
| iree_vm_list_t* shape_list; |
| |
| // Inlined list for the shape function arguments. |
| uint8_t arg_list_storage[128 + sizeof(uintptr_t) * 2]; |
| iree_vm_list_t* arg_list; |
| } _TfLiteInterpreterShapeFrame; |
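
// Typical usage (this mirrors _TfLiteInterpreterRefreshIOShapes below):
//   _TfLiteInterpreterShapeFrame frame;
//   IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameInitialize(&frame));
//   // ... WriteValue/Apply/ReadValue per input or output index ...
//   _TfLiteInterpreterShapeFrameDeinitialize(&frame);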
| |
| // Initializes an on-stack shape frame. Existing contents are discarded. |
| static iree_status_t _TfLiteInterpreterShapeFrameInitialize( |
| _TfLiteInterpreterShapeFrame* frame) { |
| // [int32...] storage for the shape dimension inputs/outputs. |
| iree_vm_type_def_t dim_type = |
| iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I32); |
| IREE_RETURN_IF_ERROR(iree_vm_list_initialize( |
| iree_make_byte_span(frame->shape_list_storage, |
| IREE_ARRAYSIZE(frame->shape_list_storage)), |
| &dim_type, IREE_BINDINGS_TFLITE_MAX_RANK, &frame->shape_list)); |
| |
| // (%index : i32, %shape : !vm.list<i32>) |
| IREE_RETURN_IF_ERROR(iree_vm_list_initialize( |
| iree_make_byte_span(frame->arg_list_storage, |
| IREE_ARRAYSIZE(frame->arg_list_storage)), |
| /*element_type=*/NULL, /*index*/ 1 + /*shape*/ 1, &frame->arg_list)); |
| IREE_RETURN_IF_ERROR(iree_vm_list_resize(frame->arg_list, 2)); |
| |
  // Arg 1 is always the shape list for all I/O, so set it once here.
| iree_vm_ref_t shape_list_ref = {0}; |
| IREE_RETURN_IF_ERROR(iree_vm_ref_wrap_assign( |
| frame->shape_list, iree_vm_list_type_id(), &shape_list_ref)); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_set_ref_retain(frame->arg_list, 1, &shape_list_ref)); |
| |
| return iree_ok_status(); |
| } |
| |
| // Deinitializes an on-stack shape frame. |
// Though this does not free the frame memory (it's on the stack, after all)
// it releases any resources that may be retained and must always be called.
| static void _TfLiteInterpreterShapeFrameDeinitialize( |
| _TfLiteInterpreterShapeFrame* frame) { |
| iree_vm_list_deinitialize(frame->arg_list); |
| iree_vm_list_deinitialize(frame->shape_list); |
| } |
| |
| // Reads the shape value in the frame storage from the prior application. |
| static iree_status_t _TfLiteInterpreterShapeFrameReadValue( |
| _TfLiteInterpreterShapeFrame* frame, int32_t* out_shape_rank, |
| int32_t* out_shape_dims) { |
| *out_shape_rank = (int32_t)iree_vm_list_size(frame->shape_list); |
| for (int32_t i = 0; i < *out_shape_rank; ++i) { |
| iree_vm_value_t dim; |
| IREE_RETURN_IF_ERROR(iree_vm_list_get_value_as( |
| frame->shape_list, i, IREE_VM_VALUE_TYPE_I32, &dim)); |
| out_shape_dims[i] = dim.i32; |
| } |
| return iree_ok_status(); |
| } |
| |
| // Writes the shape value to the current frame storage for future applications. |
| static iree_status_t _TfLiteInterpreterShapeFrameWriteValue( |
| _TfLiteInterpreterShapeFrame* frame, int32_t shape_rank, |
| const int32_t* shape_dims) { |
| IREE_RETURN_IF_ERROR(iree_vm_list_resize(frame->shape_list, shape_rank)); |
| for (int32_t i = 0; i < shape_rank; ++i) { |
| iree_vm_value_t dim = iree_vm_value_make_i32(shape_dims[i]); |
| IREE_RETURN_IF_ERROR(iree_vm_list_set_value(frame->shape_list, i, &dim)); |
| } |
| return iree_ok_status(); |
| } |
| |
| // Calls the |apply_fn| with the current shape frame state. |
| static iree_status_t _TfLiteInterpreterShapeFrameApply( |
| _TfLiteInterpreterShapeFrame* frame, TfLiteInterpreter* interpreter, |
| iree_vm_function_t apply_fn, int32_t index) { |
| // Populate shape_list with the shape dimensions for this particular output. |
| iree_vm_value_t index_value = iree_vm_value_make_i32(index); |
| IREE_IGNORE_ERROR(iree_vm_list_set_value(frame->arg_list, 0, &index_value)); |
| return iree_vm_invoke(interpreter->context, apply_fn, |
| /*policy=*/NULL, frame->arg_list, /*outputs=*/NULL, |
| interpreter->allocator); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Shape I/O queries |
| //===----------------------------------------------------------------------===// |
| |
| // Queries all input shapes from the module; some may still be dynamic (-1). |
| static iree_status_t _TfLiteInterpreterRefreshInputShapes( |
| TfLiteInterpreter* interpreter, _TfLiteInterpreterShapeFrame* frame) { |
| // NOTE: we could optimize this more by using iree_vm_invoke_within, but that |
| // shouldn't be needed (it's just stack pointer manipulation). |
| for (int32_t i = 0; i < interpreter->model->input_count; ++i) { |
| TfLiteTensor* tensor = &interpreter->input_tensors[i]; |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameApply( |
| frame, interpreter, interpreter->model->exports._query_input_shape, i)); |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameReadValue( |
| frame, &tensor->shape_rank, tensor->shape_dims)); |
| } |
| return iree_ok_status(); |
| } |
| |
// Queries all output shapes from the module, allowing it to use the current
// input shapes to compute the possibly dynamic values.
| static iree_status_t _TfLiteInterpreterRefreshOutputShapes( |
| TfLiteInterpreter* interpreter, _TfLiteInterpreterShapeFrame* frame) { |
| // NOTE: we could optimize this more by using iree_vm_invoke_within, but that |
| // shouldn't be needed (it's just stack pointer manipulation). |
| for (int32_t i = 0; i < interpreter->model->output_count; ++i) { |
| TfLiteTensor* tensor = &interpreter->output_tensors[i]; |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameApply( |
| frame, interpreter, interpreter->model->exports._query_output_shape, |
| i)); |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameReadValue( |
| frame, &tensor->shape_rank, tensor->shape_dims)); |
| } |
| return iree_ok_status(); |
| } |
| |
| // Refreshes both input and output tensor shapes by querying the module. |
| // This should be called after each shape change so that we can let the module |
| // run "shape propagation" and compute the new output shapes. |
| static iree_status_t _TfLiteInterpreterRefreshIOShapes( |
| TfLiteInterpreter* interpreter) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| _TfLiteInterpreterShapeFrame frame; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, _TfLiteInterpreterShapeFrameInitialize(&frame)); |
| |
| // Query all shapes. |
| iree_status_t status = iree_ok_status(); |
| if (iree_status_is_ok(status)) { |
| status = _TfLiteInterpreterRefreshInputShapes(interpreter, &frame); |
| } |
| if (iree_status_is_ok(status)) { |
| status = _TfLiteInterpreterRefreshOutputShapes(interpreter, &frame); |
| } |
| |
| _TfLiteInterpreterShapeFrameDeinitialize(&frame); |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Creation and static initialization |
| //===----------------------------------------------------------------------===// |
| |
| // Computes the storage requirement for the TfLiteInterpreter struct. |
| static iree_host_size_t _TfLiteInterpreterCalculateSize( |
| const TfLiteModel* model) { |
| iree_host_size_t total_size = |
| iree_host_align(sizeof(TfLiteInterpreter), iree_max_align_t); |
| |
| iree_vm_type_def_t buffer_view_type_def = |
| iree_vm_type_def_make_ref_type(iree_hal_buffer_type_id()); |
| total_size += |
| iree_vm_list_storage_size(&buffer_view_type_def, model->input_count); |
| total_size += |
| iree_vm_list_storage_size(&buffer_view_type_def, model->output_count); |
| total_size += sizeof(TfLiteTensor) * model->input_count; |
| total_size += sizeof(TfLiteTensor) * model->output_count; |
| |
| return total_size; |
| } |
| |
| // Allocates the interpreter slab and populates all internal pointers to the |
| // appropriate offsets. |
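// The slab layout is:
//   [TfLiteInterpreter]    (aligned to iree_max_align_t)
//   [input_list storage]   (iree_vm_list_storage_size for input_count)
//   [output_list storage]  (iree_vm_list_storage_size for output_count)
//   [input_tensors]        (TfLiteTensor x input_count)
//   [output_tensors]       (TfLiteTensor x output_count)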
| static iree_status_t _TfLiteInterpreterAllocate( |
| const TfLiteModel* model, TfLiteInterpreter** out_interpreter) { |
| iree_host_size_t interpreter_size = _TfLiteInterpreterCalculateSize(model); |
| TfLiteInterpreter* interpreter = NULL; |
| IREE_RETURN_IF_ERROR(iree_allocator_malloc(model->allocator, interpreter_size, |
| (void**)&interpreter)); |
| memset(interpreter, 0, interpreter_size); |
| interpreter->allocator = model->allocator; |
| _TfLiteInterpreterOptionsSetDefaults(&interpreter->options); |
| *out_interpreter = interpreter; |
| |
| interpreter->model = (TfLiteModel*)model; |
| _TfLiteModelRetain(interpreter->model); |
| |
| uint8_t* p = (uint8_t*)interpreter + |
| iree_host_align(sizeof(*interpreter), iree_max_align_t); |
| |
| iree_vm_type_def_t buffer_view_type_def = |
| iree_vm_type_def_make_ref_type(iree_hal_buffer_type_id()); |
| |
| iree_byte_span_t input_list_storage = iree_make_byte_span( |
| p, iree_vm_list_storage_size(&buffer_view_type_def, model->input_count)); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_initialize(input_list_storage, &buffer_view_type_def, |
| model->input_count, &interpreter->input_list)); |
| p += input_list_storage.data_length; |
| |
| iree_byte_span_t output_list_storage = iree_make_byte_span( |
| p, iree_vm_list_storage_size(&buffer_view_type_def, model->output_count)); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_initialize(output_list_storage, &buffer_view_type_def, |
| model->output_count, &interpreter->output_list)); |
| p += output_list_storage.data_length; |
| |
| interpreter->input_tensors = (TfLiteTensor*)p; |
| p += sizeof(TfLiteTensor) * model->input_count; |
| interpreter->output_tensors = (TfLiteTensor*)p; |
| // p += sizeof(TfLiteTensor) * model->output_count; |
| |
| return iree_ok_status(); |
| } |
| |
| // Populates the input and output tensor lists with static metadata from the |
| // model and prepares for allocation/invocation. |
| static iree_status_t _TfLiteInterpreterPopulateIO( |
| TfLiteInterpreter* interpreter) { |
| iree_vm_function_t main_fn = interpreter->model->exports._main; |
| iree_string_view_t io_names_attr = iree_vm_function_reflection_attr( |
| &main_fn, iree_make_cstring_view("tfl.io.names")); |
| iree_string_view_t io_types_attr = iree_vm_function_reflection_attr( |
| &main_fn, iree_make_cstring_view("tfl.io.types")); |
| iree_string_view_t io_quant_attr = iree_vm_function_reflection_attr( |
| &main_fn, iree_make_cstring_view("tfl.io.quant")); |
| |
| // Setup static tensor metadata. |
| for (iree_host_size_t i = 0; i < interpreter->model->input_count; ++i) { |
| TfLiteTensor* tensor = &interpreter->input_tensors[i]; |
| memset(tensor, 0, sizeof(*tensor)); |
| iree_string_view_t io_name_part = iree_string_view_empty(); |
| iree_string_view_split(io_names_attr, ';', &io_name_part, &io_names_attr); |
| iree_string_view_t io_type_part = iree_string_view_empty(); |
| iree_string_view_split(io_types_attr, ';', &io_type_part, &io_types_attr); |
| iree_string_view_t io_quant_part = iree_string_view_empty(); |
| iree_string_view_split(io_quant_attr, ';', &io_quant_part, &io_quant_attr); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseNameAttr(tensor, io_name_part, |
| interpreter->allocator)); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseTypeAttr(tensor, io_type_part)); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseQuantAttr(tensor, io_quant_part)); |
| } |
| for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) { |
| TfLiteTensor* tensor = &interpreter->output_tensors[i]; |
| memset(tensor, 0, sizeof(*tensor)); |
| iree_string_view_t io_name_part = iree_string_view_empty(); |
| iree_string_view_split(io_names_attr, ';', &io_name_part, &io_names_attr); |
| iree_string_view_t io_type_part = iree_string_view_empty(); |
| iree_string_view_split(io_types_attr, ';', &io_type_part, &io_types_attr); |
| iree_string_view_t io_quant_part = iree_string_view_empty(); |
| iree_string_view_split(io_quant_attr, ';', &io_quant_part, &io_quant_attr); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseNameAttr(tensor, io_name_part, |
| interpreter->allocator)); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseTypeAttr(tensor, io_type_part)); |
| IREE_RETURN_IF_ERROR(_TfLiteTensorParseQuantAttr(tensor, io_quant_part)); |
| } |
| |
| // Prepare the IO lists we use when calling into the model. |
| // The actual contents of these cannot be set until |
| // TfLiteInterpreterAllocateTensors has been called. |
| IREE_RETURN_IF_ERROR(iree_vm_list_reserve(interpreter->input_list, |
| interpreter->model->input_count)); |
| IREE_RETURN_IF_ERROR(iree_vm_list_reserve(interpreter->output_list, |
| interpreter->model->output_count)); |
| |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t _TfLiteInterpreterCreate( |
| const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options, |
| TfLiteInterpreter** out_interpreter) { |
| *out_interpreter = NULL; |
| |
| // We allocate a large majority of the interpreter structures as a single |
  // slab. There are still some allocations that we could prevent (like
  // internal VM stuff) but this at least covers half of it.
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterAllocate(model, out_interpreter)); |
| TfLiteInterpreter* interpreter = *out_interpreter; |
| |
| if (optional_options) { |
| memcpy(&interpreter->options, optional_options, |
| sizeof(interpreter->options)); |
| } |
| |
| interpreter->user_module = model->module; |
| iree_vm_module_retain(interpreter->user_module); |
| |
  // External contexts could possibly be used to emulate sharing this, but
  // really if a user is running with multiple models the tflite API is
  // insufficient.
| IREE_RETURN_IF_ERROR( |
| iree_vm_instance_create(interpreter->allocator, &interpreter->instance)); |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterPrepareHAL(interpreter)); |
| |
| // Context will contain both the user-provided bytecode and the HAL module. |
| // If we were to support custom ops we would also have a |
| // tflite_resolver_module that we would register to resolve tflite ops into |
| // IREE functions that will call custom ops through TfLiteRegistrations. |
| IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules( |
| interpreter->instance, interpreter->all_modules, |
| IREE_ARRAYSIZE(interpreter->all_modules), interpreter->allocator, |
| &interpreter->context)); |
| |
| // Setup all I/O tensors and buffer views. |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterPopulateIO(interpreter)); |
| |
| return iree_ok_status(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Core API |
| //===----------------------------------------------------------------------===// |
| |
| TFL_CAPI_EXPORT extern TfLiteInterpreter* TfLiteInterpreterCreate( |
| const TfLiteModel* model, |
| const TfLiteInterpreterOptions* optional_options) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| TfLiteInterpreter* interpreter = NULL; |
| iree_status_t status = |
| _TfLiteInterpreterCreate(model, optional_options, &interpreter); |
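  // NOTE: iree_status_consume_code frees any status storage and returns just
  // the status code; a failure message is not propagated beyond the trace
  // message below.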
| if (iree_status_is_ok(iree_status_consume_code(status))) { |
| IREE_TRACE_ZONE_APPEND_TEXT(z0, "num_threads=", strlen("num_threads=")); |
| IREE_TRACE_ZONE_APPEND_VALUE(z0, interpreter->options.num_threads); |
| } else { |
| IREE_TRACE_MESSAGE(ERROR, "failed interpreter creation"); |
| TfLiteInterpreterDelete(interpreter); |
| interpreter = NULL; |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return interpreter; |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteInterpreter* |
| TfLiteInterpreterCreateWithSelectedOps( |
| const TfLiteModel* model, const TfLiteInterpreterOptions* options) { |
| // No different from TfLiteInterpreterCreate: we don't have "ops" :) |
| return TfLiteInterpreterCreate(model, options); |
| } |
| |
| TFL_CAPI_EXPORT extern void TfLiteInterpreterDelete( |
    TfLiteInterpreter* interpreter) {
  // NULL-safe: the create path may pass a NULL interpreter on failure.
  if (!interpreter) return;
  IREE_TRACE_ZONE_BEGIN(z0);
| |
| for (iree_host_size_t i = 0; i < interpreter->model->input_count; ++i) { |
| _TfLiteTensorReset(&interpreter->input_tensors[i], interpreter->allocator); |
| } |
| for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) { |
| _TfLiteTensorReset(&interpreter->output_tensors[i], interpreter->allocator); |
| } |
| iree_vm_list_deinitialize(interpreter->input_list); |
| iree_vm_list_deinitialize(interpreter->output_list); |
| |
  iree_vm_context_release(interpreter->context);
  iree_vm_module_release(interpreter->hal_module);
  iree_hal_device_release(interpreter->device);
  iree_hal_driver_release(interpreter->driver);
  iree_vm_module_release(interpreter->user_module);
  iree_vm_instance_release(interpreter->instance);
| |
| _TfLiteModelRelease(interpreter->model); |
| iree_allocator_free(interpreter->allocator, interpreter); |
| |
| IREE_TRACE_ZONE_END(z0); |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors( |
| TfLiteInterpreter* interpreter) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| // The compiler emits a special function we can use to reset just variables. |
| // NOTE: the function is optional if the model had no variables. |
| iree_status_t status = iree_ok_status(); |
| iree_vm_function_t reset_variables_fn = |
| interpreter->model->exports._reset_variables; |
| if (!iree_vm_function_is_null(reset_variables_fn)) { |
| status = iree_vm_invoke(interpreter->context, reset_variables_fn, |
| /*policy=*/NULL, /*inputs=*/NULL, /*outputs=*/NULL, |
| interpreter->allocator); |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return _TfLiteStatusFromIREEStatus(status); |
| } |
| |
| TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetInputTensorCount( |
| const TfLiteInterpreter* interpreter) { |
| return interpreter->model->input_count; |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteTensor* TfLiteInterpreterGetInputTensor( |
| const TfLiteInterpreter* interpreter, int32_t input_index) { |
| if (input_index < 0 || input_index >= interpreter->model->input_count) { |
| return NULL; |
| } |
| return &interpreter->input_tensors[input_index]; |
| } |
| |
| static iree_status_t _TfLiteInterpreterResizeInputTensor( |
| TfLiteInterpreter* interpreter, int32_t input_index, const int* input_dims, |
| int32_t input_dims_size) { |
| if (input_index < 0 || input_index >= interpreter->model->input_count) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "input_index out of range (0 <= %d < %d)", |
| input_index, interpreter->model->input_count); |
| } |
| if (iree_vm_function_is_null( |
| interpreter->model->exports._resize_input_shape)) { |
| // TODO(#3975): check if this is a no-op success in tflite. |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "model has no dynamic shapes"); |
| } |
| |
| _TfLiteInterpreterShapeFrame frame; |
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterShapeFrameInitialize(&frame)); |
| |
  // Write the requested dimensions and poke the model to let it update its
  // internal shape. Note that the tflite API uses `int` while we use
  // `int32_t` internally; these match on all platforms we support.
  // TODO(#3975): return bool to allow model to say it failed.
  iree_status_t status = _TfLiteInterpreterShapeFrameWriteValue(
      &frame, input_dims_size, (const int32_t*)input_dims);
| if (iree_status_is_ok(status)) { |
| status = _TfLiteInterpreterShapeFrameApply( |
| &frame, interpreter, interpreter->model->exports._resize_input_shape, |
| input_index); |
| } |
| |
  // NOTE: the allocation may no longer match the requested shape. This is
  // just how the tflite API works, unfortunately; the tensor remains in an
  // indeterminate state until TfLiteInterpreterAllocateTensors is called.
| |
| _TfLiteInterpreterShapeFrameDeinitialize(&frame); |
| return status; |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResizeInputTensor( |
| TfLiteInterpreter* interpreter, int32_t input_index, const int* input_dims, |
| int32_t input_dims_size) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| iree_status_t status = _TfLiteInterpreterResizeInputTensor( |
| interpreter, input_index, input_dims, input_dims_size); |
| IREE_TRACE_ZONE_END(z0); |
| return _TfLiteStatusFromIREEStatus(status); |
| } |
| |
| static iree_status_t _TfLiteInterpreterAllocateTensors( |
| TfLiteInterpreter* interpreter) { |
| // NOTE: we could slab allocate like tflite does, but then if any single |
| // tensor has any single dimension that is resized the whole thing gets |
| // reallocated upon resize. That's no good. Instead, we realloc each tensor |
  // if its size has changed.
| |
  // Refresh all shapes from the model; this yields all of the
  // non-data-dependent output shapes.
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterRefreshIOShapes(interpreter)); |
| |
| // Drop all input tensors we hang on to in the input list. This way we aren't |
| // double-allocating during the resize. |
| IREE_RETURN_IF_ERROR(iree_vm_list_resize(interpreter->input_list, 0)); |
| |
| // Reallocate input tensors (if needed). |
| for (iree_host_size_t i = 0; i < interpreter->model->input_count; ++i) { |
| TfLiteTensor* tensor = &interpreter->input_tensors[i]; |
| IREE_RETURN_IF_ERROR(_TfLiteTensorReallocateIfNeeded( |
| tensor, iree_hal_device_allocator(interpreter->device), |
| interpreter->allocator)); |
| iree_vm_ref_t buffer_ref = iree_hal_buffer_retain_ref(tensor->buffer); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_push_ref_move(interpreter->input_list, &buffer_ref)); |
| } |
| |
| // TODO(benvanik): preallocate outputs when we support using them. |
| // We could stash the buffer views in interpreter->output_list. |
| // For now we just drop them all. |
| for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) { |
| _TfLiteTensorDiscardBuffer(&interpreter->output_tensors[i]); |
| } |
| |
| return iree_ok_status(); |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterAllocateTensors( |
| TfLiteInterpreter* interpreter) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_status_t status = _TfLiteInterpreterAllocateTensors(interpreter); |
| |
| #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION |
| iree_device_size_t total_input_size = 0; |
| for (iree_host_size_t i = 0; i < interpreter->model->input_count; ++i) { |
| total_input_size += |
| iree_hal_buffer_byte_length(interpreter->input_tensors[i].buffer); |
| } |
| IREE_TRACE_ZONE_APPEND_VALUE(z0, total_input_size); |
| #endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION |
| |
| IREE_TRACE_ZONE_END(z0); |
| return _TfLiteStatusFromIREEStatus(status); |
| } |
| |
| static iree_status_t _TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) { |
| // tflite models only have a single entry point and the IREE converter |
| // emits it as '_main'. |
| IREE_RETURN_IF_ERROR( |
| iree_vm_invoke(interpreter->context, interpreter->model->exports._main, |
| /*policy=*/NULL, interpreter->input_list, |
| interpreter->output_list, interpreter->allocator)); |
| |
| // Refresh output shapes. |
  // TODO(#3975): use the buffer view results directly or only refresh outputs.
| IREE_RETURN_IF_ERROR(_TfLiteInterpreterRefreshIOShapes(interpreter)); |
| |
| // Map the output buffers. |
| // NOTE: we could defer the mapping unless requested and ensure state buffers |
| // remain where they currently are for the next invocation. |
| for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) { |
| iree_hal_buffer_t* buffer = (iree_hal_buffer_t*)iree_vm_list_get_ref_deref( |
| interpreter->output_list, i, iree_hal_buffer_get_descriptor()); |
| TfLiteTensor* tensor = &interpreter->output_tensors[i]; |
| IREE_RETURN_IF_ERROR(_TfLiteTensorBind(tensor, buffer)); |
| } |
| |
| return iree_ok_status(); |
| } |
| |
| TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterInvoke( |
| TfLiteInterpreter* interpreter) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_status_t status = _TfLiteInterpreterInvoke(interpreter); |
| |
| #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION |
| iree_device_size_t total_output_size = 0; |
| for (iree_host_size_t i = 0; i < interpreter->model->output_count; ++i) { |
| total_output_size += |
| iree_hal_buffer_byte_length(interpreter->output_tensors[i].buffer); |
| } |
| IREE_TRACE_ZONE_APPEND_VALUE(z0, total_output_size); |
| #endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION |
| |
| IREE_TRACE_ZONE_END(z0); |
| return _TfLiteStatusFromIREEStatus(status); |
| } |
| |
| TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetOutputTensorCount( |
| const TfLiteInterpreter* interpreter) { |
| return interpreter->model->output_count; |
| } |
| |
| TFL_CAPI_EXPORT extern const TfLiteTensor* TfLiteInterpreterGetOutputTensor( |
| const TfLiteInterpreter* interpreter, int32_t output_index) { |
| if (output_index < 0 || output_index >= interpreter->model->output_count) { |
| return NULL; |
| } |
| return &interpreter->output_tensors[output_index]; |
| } |