blob: 858c823c1f32234696eddd45c988fcdd3899259a [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "codegen/runtime/micro_codegen_context.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/micro_log.h"
namespace tflite {
// Builds a codegen context over `context` and the table of generated
// subgraphs. Both arguments are stored as-is; no copying of the subgraph
// contents takes place beyond copying the span itself.
MicroCodegenContext::MicroCodegenContext(TfLiteContext* context,
                                         Span<Subgraph> subgraphs)
    : context_(context),
      subgraphs_(subgraphs) {}
// Scratch buffers are not supported yet, so every lookup yields nullptr
// regardless of `buffer_idx`.
void* MicroCodegenContext::GetScratchBuffer(int buffer_idx) {
  // TODO(rjascani): Implement scratch buffers
  static_cast<void>(buffer_idx);
  return nullptr;
}
// Returns the eval tensor at `tensor_idx` within the subgraph that is
// currently executing (tracked by current_subgraph_idx_). The index is only
// checked via TFLITE_DCHECK.
TfLiteEvalTensor* MicroCodegenContext::GetEvalTensor(int tensor_idx) {
  auto& active_tensors = subgraphs_[current_subgraph_idx_].tensors;
  TFLITE_DCHECK(static_cast<size_t>(tensor_idx) < active_tensors.size());
  return &active_tensors[tensor_idx];
}
// Stores an opaque external-context pointer exactly once.
//
// Returns kTfLiteError (after logging) when `external_context_payload` is
// nullptr or when a payload has already been set; otherwise records the
// pointer and returns kTfLiteOk.
TfLiteStatus MicroCodegenContext::set_external_context(
    void* external_context_payload) {
  if (external_context_payload == nullptr ||
      external_context_payload_ != nullptr) {
    // Pointers must be formatted with %p: handing a void* to a %x conversion
    // is a type mismatch (undefined behaviour) and truncates on 64-bit hosts.
    MicroPrintf(
        "Attempting to set external context to %p but it was %p already",
        external_context_payload, external_context_payload_);
    return kTfLiteError;
  }
  external_context_payload_ = external_context_payload;
  return kTfLiteOk;
}
// Returns the payload recorded by set_external_context(), or nullptr if none
// has been set.
void* MicroCodegenContext::external_context() { return external_context_payload_; }
// This context also acts as its own MicroGraph, so the graph accessor simply
// hands back *this.
MicroGraph& MicroCodegenContext::graph() {
  return *this;
}
// Persistent allocation is not allowed at Eval; any call aborts.
void* MicroCodegenContext::AllocatePersistentBuffer(size_t) {
  TFLITE_ABORT;
  return nullptr;  // Only present to satisfy the return type.
}
// Requesting arena scratch space is not allowed at Eval; any call aborts.
TfLiteStatus MicroCodegenContext::RequestScratchBufferInArena(size_t, int*) {
  TFLITE_ABORT;
  return kTfLiteError;  // Only present to satisfy the return type.
}
// Temporary TfLiteTensor allocation is not allowed at Eval; any call aborts.
TfLiteTensor* MicroCodegenContext::AllocateTempTfLiteTensor(int) {
  TFLITE_ABORT;
  return nullptr;  // Only present to satisfy the return type.
}
// Temporary TfLiteTensor deallocation is not allowed at Eval; any call
// aborts.
void MicroCodegenContext::DeallocateTempTfLiteTensor(TfLiteTensor*) {
  TFLITE_ABORT;
}
// Temporary buffer allocation is not allowed at Eval; any call aborts.
uint8_t* MicroCodegenContext::AllocateTempBuffer(size_t, size_t) {
  TFLITE_ABORT;
  return nullptr;  // Only present to satisfy the return type.
}
// Temporary buffer deallocation is not allowed at Eval; any call aborts.
void MicroCodegenContext::DeallocateTempBuffer(uint8_t*) {
  TFLITE_ABORT;
}
// Runs the subgraph at `subgraph_idx` through its generated invoke function.
//
// The requested subgraph is made current for the duration of the call (so
// GetEvalTensor() resolves against it), and the caller's current subgraph is
// restored afterwards. Returns kTfLiteError via TF_LITE_ENSURE if the index is
// out of range; otherwise returns whatever the subgraph's invoke returned.
TfLiteStatus MicroCodegenContext::InvokeSubgraph(int subgraph_idx) {
  TF_LITE_ENSURE(context_,
                 static_cast<size_t>(subgraph_idx) < subgraphs_.size());

  const size_t saved_subgraph_idx = current_subgraph_idx_;
  current_subgraph_idx_ = subgraph_idx;

  auto& target = subgraphs_[subgraph_idx];
  const TfLiteStatus result = target.invoke(context_, target.nodes);

  current_subgraph_idx_ = saved_subgraph_idx;
  return result;
}
// Number of entries in the input table of subgraph `subgraph_idx`
// (index checked only via TFLITE_DCHECK).
size_t MicroCodegenContext::NumSubgraphInputs(int subgraph_idx) {
  const size_t index = static_cast<size_t>(subgraph_idx);
  TFLITE_DCHECK(index < subgraphs_.size());
  return subgraphs_[index].inputs.size();
}
// Returns the eval tensor bound to input `input_idx` of subgraph
// `subgraph_idx`.
//
// The subgraph's `inputs` table maps the input position to an index in that
// subgraph's `tensors` array. All three indices (subgraph, input position and
// resolved tensor) are checked with TFLITE_DCHECK before the address is taken.
TfLiteEvalTensor* MicroCodegenContext::GetSubgraphInput(int subgraph_idx,
                                                        int input_idx) {
  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
  TFLITE_DCHECK(static_cast<size_t>(input_idx) <
                subgraphs_[subgraph_idx].inputs.size());
  const size_t tensor_idx = subgraphs_[subgraph_idx].inputs[input_idx];
  // The resolved index comes from a table, not from the caller; make sure it
  // stays inside the tensor array before forming a pointer into it.
  TFLITE_DCHECK(tensor_idx < subgraphs_[subgraph_idx].tensors.size());
  return &subgraphs_[subgraph_idx].tensors[tensor_idx];
}
// Number of entries in the output table of subgraph `subgraph_idx`
// (index checked only via TFLITE_DCHECK).
size_t MicroCodegenContext::NumSubgraphOutputs(int subgraph_idx) {
  const size_t index = static_cast<size_t>(subgraph_idx);
  TFLITE_DCHECK(index < subgraphs_.size());
  return subgraphs_[index].outputs.size();
}
// Returns the eval tensor bound to output `output_idx` of subgraph
// `subgraph_idx`.
//
// The subgraph's `outputs` table maps the output position to an index in that
// subgraph's `tensors` array. All three indices (subgraph, output position and
// resolved tensor) are checked with TFLITE_DCHECK before the address is taken.
TfLiteEvalTensor* MicroCodegenContext::GetSubgraphOutput(int subgraph_idx,
                                                         int output_idx) {
  TFLITE_DCHECK(static_cast<size_t>(subgraph_idx) < subgraphs_.size());
  TFLITE_DCHECK(static_cast<size_t>(output_idx) <
                subgraphs_[subgraph_idx].outputs.size());
  const size_t tensor_idx = subgraphs_[subgraph_idx].outputs[output_idx];
  // The resolved index comes from a table, not from the caller; make sure it
  // stays inside the tensor array before forming a pointer into it.
  TFLITE_DCHECK(tensor_idx < subgraphs_[subgraph_idx].tensors.size());
  return &subgraphs_[subgraph_idx].tensors[tensor_idx];
}
// Number of subgraphs this context was constructed with, narrowed to int to
// match the MicroGraph-style accessor signature.
int MicroCodegenContext::NumSubgraphs() {
  return static_cast<int>(subgraphs_.size());
}
// This context exposes no resource variables; callers always get nullptr.
MicroResourceVariables* MicroCodegenContext::GetResourceVariables() { return nullptr; }
} // namespace tflite