blob: 5fc3b4ca98888f76b44580766045c3ff34655e5e [file]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_GRAPH_H_
#define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_GRAPH_H_
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/micro/micro_resource_variable.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// Abstracts the details of interacting with the tflite::Model.
//
// Provides methods to access, initialize, prepare, invoke and free any
// subgraph in the tflite::Model.
class MicroInterpreterGraph : public MicroGraph {
 public:
  // The lifetime of the context, model, allocator and resource_variables must
  // be at least as long as that of the graph object, since this class may
  // need to access them at any time. If resource_variables is a nullptr,
  // GetResourceVariables will return a nullptr.
  MicroInterpreterGraph(TfLiteContext* context, const Model* model,
                        MicroAllocator* allocator,
                        MicroResourceVariables* resource_variables);
  virtual ~MicroInterpreterGraph();

  // Sets up builtin data and calls TFLMRegistration->Init for every
  // operator in every subgraph in the model.
  virtual TfLiteStatus InitSubgraphs();

  // Calls TFLMRegistration->Prepare for every operator in every subgraph
  // in the model.
  virtual TfLiteStatus PrepareSubgraphs();

  // Calls TFLMRegistration->Reset for every operator in every subgraph in
  // the model.
  virtual TfLiteStatus ResetSubgraphs();

  // Calls TFLMRegistration->Free for every operator in every subgraph in
  // the model.
  virtual TfLiteStatus FreeSubgraphs();

  // Calls TFLMRegistration->Invoke for every operator in a single subgraph
  // in the model.
  // @param subgraph_idx index of the subgraph to invoke.
  virtual TfLiteStatus InvokeSubgraph(int subgraph_idx);

  // Zeros out all variable tensors in all subgraphs in the model.
  virtual TfLiteStatus ResetVariableTensors();

  // Number of tensor inputs to a specified subgraph in the model.
  virtual size_t NumSubgraphInputs(int subgraph_idx);

  // Get the specified input tensor of a specified subgraph in the model.
  virtual TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int input_idx);

  // Number of tensor outputs from a specified subgraph in the model.
  virtual size_t NumSubgraphOutputs(int subgraph_idx);

  // Get the specified output tensor of a specified subgraph in the model.
  virtual TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx, int output_idx);

  // Number of subgraphs in the model.
  virtual int NumSubgraphs();

  // Hook to pass in subgraph allocations tracked within the interpreter,
  // allowing MicroInterpreterGraph to init / prepare / invoke subgraphs in the
  // model.
  void SetSubgraphAllocations(SubgraphAllocations* subgraph_allocations);

  // Get the current subgraph index. Within an operator, this is guaranteed
  // to be the subgraph of that operator.
  int GetCurrentSubgraphIndex() { return current_subgraph_index_; }

  // Get the current operator index inside a subgraph.
  // Together, GetCurrentSubgraphIndex and GetCurrentOperatorIndex form a
  // unique identifier of the operator within the model.
  // The index is stored as uint32_t; the conversion to int is made explicit
  // here (operator indices are assumed to fit in an int).
  int GetCurrentOperatorIndex() {
    return static_cast<int>(current_operator_index_);
  }

  // Gets the list of allocations for each subgraph. This is the source of truth
  // for all per-subgraph allocation data.
  SubgraphAllocations* GetAllocations() { return subgraph_allocations_; }

  // Get the resource variables for this TFLM graph. Returns nullptr if the
  // graph was constructed with a nullptr resource_variables.
  MicroResourceVariables* GetResourceVariables() { return resource_variables_; }

 private:
  // Default member initializers give every field a defined value even if a
  // constructor does not set it; constructor mem-initializers override them.
  TfLiteContext* context_ = nullptr;
  const Model* model_ = nullptr;
  MicroAllocator* allocator_ = nullptr;
  // Per-subgraph allocation data, presumably supplied through
  // SetSubgraphAllocations().
  SubgraphAllocations* subgraph_allocations_ = nullptr;
  // Subgraph / operator currently being processed (see the getters above).
  int current_subgraph_index_ = 0;
  uint32_t current_operator_index_ = 0;
  MicroResourceVariables* resource_variables_ = nullptr;
  // Initialized as nullptr to prevent any possible issues related to
  // accessing uninitialized memory.
  const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs_ =
      nullptr;

  TF_LITE_REMOVE_VIRTUAL_DELETE
};
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_GRAPH_H_