Adds GetCurrentOperatorIndex to micro_interpreter (#2605)
In this way an AOT can add metadata that could then be accessed with the couple GetCurrentSubgraphIndex, GetCurrentOperatorIndex.
These can be accessed from tflite::GetMicroContext(context)->graph()
Refers to #2593
BUG=#2593
diff --git a/tensorflow/lite/micro/micro_interpreter_graph.cc b/tensorflow/lite/micro/micro_interpreter_graph.cc
index 0d18fe7..7f096ae 100644
--- a/tensorflow/lite/micro/micro_interpreter_graph.cc
+++ b/tensorflow/lite/micro/micro_interpreter_graph.cc
@@ -44,6 +44,7 @@
model_(model),
allocator_(allocator),
current_subgraph_index_(0),
+ current_operator_index_(0),
resource_variables_(resource_variables) {
if (model != nullptr) {
subgraphs_ = model->subgraphs();
@@ -54,17 +55,21 @@
TfLiteStatus MicroInterpreterGraph::InitSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
+ uint32_t previous_operator_idx = current_operator_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
- for (size_t i = 0; i < operators_size; ++i) {
- TfLiteNode* node =
- &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
- const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
- .node_and_registrations[i]
- .registration;
+ for (current_operator_index_ = 0; current_operator_index_ < operators_size;
+ ++current_operator_index_) {
+ TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .node);
+ const TFLMRegistration* registration =
+ subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .registration;
size_t init_data_size;
const char* init_data;
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
@@ -81,52 +86,62 @@
}
}
current_subgraph_index_ = previous_subgraph_idx;
+ current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
-
+ uint32_t previous_operator_idx = current_operator_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
- for (size_t i = 0; i < operators_size; ++i) {
- TfLiteNode* node =
- &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
- const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
- .node_and_registrations[i]
- .registration;
+ for (current_operator_index_ = 0; current_operator_index_ < operators_size;
+ ++current_operator_index_) {
+ TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .node);
+ const TFLMRegistration* registration =
+ subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .registration;
if (registration->prepare != nullptr) {
TfLiteStatus prepare_status = registration->prepare(context_, node);
if (prepare_status != kTfLiteOk) {
MicroPrintf("Node %s (number %df) failed to prepare with status %d",
- OpNameFromRegistration(registration), i, prepare_status);
+ OpNameFromRegistration(registration),
+ current_operator_index_, prepare_status);
return kTfLiteError;
}
}
- allocator_->FinishPrepareNodeAllocations(/*node_id=*/i);
+ allocator_->FinishPrepareNodeAllocations(
+ /*node_id=*/current_operator_index_);
}
}
current_subgraph_index_ = previous_subgraph_idx;
-
+ current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::ResetSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
+ uint32_t previous_operator_idx = current_operator_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
- for (size_t i = 0; i < operators_size; ++i) {
- TfLiteNode* node =
- &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
- const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
- .node_and_registrations[i]
- .registration;
+ for (current_operator_index_ = 0; current_operator_index_ < operators_size;
+ ++current_operator_index_) {
+ TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .node);
+ const TFLMRegistration* registration =
+ subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->reset != nullptr) {
@@ -135,23 +150,28 @@
}
}
current_subgraph_index_ = previous_subgraph_idx;
+ current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::FreeSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
+ uint32_t previous_operator_idx = current_operator_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
- for (size_t i = 0; i < operators_size; ++i) {
- TfLiteNode* node =
- &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
- const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
- .node_and_registrations[i]
- .registration;
+ for (current_operator_index_ = 0; current_operator_index_ < operators_size;
+ ++current_operator_index_) {
+ TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .node);
+ const TFLMRegistration* registration =
+ subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->free != nullptr) {
@@ -160,12 +180,14 @@
}
}
current_subgraph_index_ = previous_subgraph_idx;
+ current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
int previous_subgraph_idx = current_subgraph_index_;
+ uint32_t previous_operator_idx = current_operator_index_;
current_subgraph_index_ = subgraph_idx;
if (static_cast<size_t>(subgraph_idx) >= subgraphs_->size()) {
@@ -174,12 +196,15 @@
return kTfLiteError;
}
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
- for (size_t i = 0; i < operators_size; ++i) {
- TfLiteNode* node =
- &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
- const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
- .node_and_registrations[i]
- .registration;
+ for (current_operator_index_ = 0; current_operator_index_ < operators_size;
+ ++current_operator_index_) {
+ TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .node);
+ const TFLMRegistration* registration =
+ subgraph_allocations_[subgraph_idx]
+ .node_and_registrations[current_operator_index_]
+ .registration;
// This ifdef is needed (even though ScopedMicroProfiler itself is a no-op with
// -DTF_LITE_STRIP_ERROR_STRINGS) because the function OpNameFromRegistration is
@@ -201,13 +226,15 @@
if (invoke_status == kTfLiteError) {
MicroPrintf("Node %s (number %d) failed to invoke with status %d",
- OpNameFromRegistration(registration), i, invoke_status);
+ OpNameFromRegistration(registration), current_operator_index_,
+ invoke_status);
return kTfLiteError;
} else if (invoke_status != kTfLiteOk) {
return invoke_status;
}
}
current_subgraph_index_ = previous_subgraph_idx;
+ current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}
diff --git a/tensorflow/lite/micro/micro_interpreter_graph.h b/tensorflow/lite/micro/micro_interpreter_graph.h
index 5c2121a..7ae0794 100644
--- a/tensorflow/lite/micro/micro_interpreter_graph.h
+++ b/tensorflow/lite/micro/micro_interpreter_graph.h
@@ -86,7 +86,12 @@
// to be the subgraph of that operator.
int GetCurrentSubgraphIndex() { return current_subgraph_index_; }
- // Gets the list of alloctions for each subgraph. This is the source of truth
+ // Get the current operator index inside a subgraph.
+ // The couple GetCurrentSubgraphIndex GetCurrentSubgraphIndex creates a unique
+ // identifier of the operator inside the subgraph
+ int GetCurrentOperatorIndex() { return current_operator_index_; }
+
+ // Gets the list of allocations for each subgraph. This is the source of truth
// for all per-subgraph allocation data.
SubgraphAllocations* GetAllocations() { return subgraph_allocations_; }
@@ -99,6 +104,7 @@
MicroAllocator* allocator_;
SubgraphAllocations* subgraph_allocations_ = nullptr;
int current_subgraph_index_;
+ uint32_t current_operator_index_;
MicroResourceVariables* resource_variables_;
const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs_;