blob: 0d18fe7399fdb9cc0b8eacdc96de7bf854d983f5 [file]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_interpreter_graph.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
const char* OpNameFromRegistration(const TFLMRegistration* registration) {
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
return registration->custom_name;
} else {
return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
}
}
} // namespace
MicroInterpreterGraph::MicroInterpreterGraph(
TfLiteContext* context, const Model* model, MicroAllocator* allocator,
MicroResourceVariables* resource_variables)
: context_(context),
model_(model),
allocator_(allocator),
current_subgraph_index_(0),
resource_variables_(resource_variables) {
if (model != nullptr) {
subgraphs_ = model->subgraphs();
}
}
MicroInterpreterGraph::~MicroInterpreterGraph() {}
TfLiteStatus MicroInterpreterGraph::InitSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
size_t init_data_size;
const char* init_data;
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
init_data = reinterpret_cast<const char*>(node->custom_initial_data);
init_data_size = node->custom_initial_data_size;
} else {
init_data = reinterpret_cast<const char*>(node->builtin_data);
init_data_size = 0;
}
if (registration->init) {
node->user_data =
registration->init(context_, init_data, init_data_size);
}
}
}
current_subgraph_index_ = previous_subgraph_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
if (registration->prepare != nullptr) {
TfLiteStatus prepare_status = registration->prepare(context_, node);
if (prepare_status != kTfLiteOk) {
MicroPrintf("Node %s (number %df) failed to prepare with status %d",
OpNameFromRegistration(registration), i, prepare_status);
return kTfLiteError;
}
}
allocator_->FinishPrepareNodeAllocations(/*node_id=*/i);
}
}
current_subgraph_index_ = previous_subgraph_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::ResetSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->reset != nullptr) {
registration->reset(context_, node->user_data);
}
}
}
current_subgraph_index_ = previous_subgraph_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::FreeSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->free != nullptr) {
registration->free(context_, node->user_data);
}
}
}
current_subgraph_index_ = previous_subgraph_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
int previous_subgraph_idx = current_subgraph_index_;
current_subgraph_index_ = subgraph_idx;
if (static_cast<size_t>(subgraph_idx) >= subgraphs_->size()) {
MicroPrintf("Accessing subgraph %d but only %d subgraphs found",
subgraph_idx, subgraphs_->size());
return kTfLiteError;
}
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
// This ifdef is needed (even though ScopedMicroProfiler itself is a no-op with
// -DTF_LITE_STRIP_ERROR_STRINGS) because the function OpNameFromRegistration is
// only defined for builds with the error strings.
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
ScopedMicroProfiler scoped_profiler(
OpNameFromRegistration(registration),
reinterpret_cast<MicroProfilerInterface*>(context_->profiler));
#endif
TFLITE_DCHECK(registration->invoke);
TfLiteStatus invoke_status = registration->invoke(context_, node);
// All TfLiteTensor structs used in the kernel are allocated from temp
// memory in the allocator. This creates a chain of allocations in the
// temp section. The call below resets the chain of allocations to
// prepare for the next call.
allocator_->ResetTempAllocations();
if (invoke_status == kTfLiteError) {
MicroPrintf("Node %s (number %d) failed to invoke with status %d",
OpNameFromRegistration(registration), i, invoke_status);
return kTfLiteError;
} else if (invoke_status != kTfLiteOk) {
return invoke_status;
}
}
current_subgraph_index_ = previous_subgraph_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::ResetVariableTensors() {
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
const SubGraph* subgraph = (*subgraphs_)[subgraph_idx];
for (size_t i = 0; i < subgraph->tensors()->size(); ++i) {
auto* tensor = subgraph->tensors()->Get(i);
if (tensor->is_variable()) {
size_t buffer_size;
TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(
&subgraph_allocations_[subgraph_idx].tensors[i], &buffer_size));
int value = 0;
if (tensor->type() == tflite::TensorType_INT8) {
value = tensor->quantization()->zero_point()->Get(0);
}
memset(subgraph_allocations_[subgraph_idx].tensors[i].data.raw, value,
buffer_size);
}
}
}
if (resource_variables_ != nullptr) {
resource_variables_->ResetAll();
}
return kTfLiteOk;
}
int MicroInterpreterGraph::NumSubgraphs() {
return model_->subgraphs()->size();
}
void MicroInterpreterGraph::SetSubgraphAllocations(
SubgraphAllocations* subgraph_allocations) {
subgraph_allocations_ = subgraph_allocations;
}
size_t MicroInterpreterGraph::NumSubgraphInputs(int subgraph_idx) {
return model_->subgraphs()->Get(subgraph_idx)->inputs()->size();
}
TfLiteEvalTensor* MicroInterpreterGraph::GetSubgraphInput(int subgraph_idx,
int input_idx) {
int tensor_idx =
model_->subgraphs()->Get(subgraph_idx)->inputs()->Get(input_idx);
return &subgraph_allocations_[subgraph_idx].tensors[tensor_idx];
}
size_t MicroInterpreterGraph::NumSubgraphOutputs(int subgraph_idx) {
return model_->subgraphs()->Get(subgraph_idx)->outputs() == nullptr
? 0
: model_->subgraphs()->Get(subgraph_idx)->outputs()->size();
}
TfLiteEvalTensor* MicroInterpreterGraph::GetSubgraphOutput(int subgraph_idx,
int output_idx) {
int tensor_idx =
model_->subgraphs()->Get(subgraph_idx)->outputs()->Get(output_idx);
return &subgraph_allocations_[subgraph_idx].tensors[tensor_idx];
}
} // namespace tflite