// blob: ca01a63bce47fbbea88ba8f0448000a3f3bc3ece [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_
#define CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"
namespace tflite {
// Minimal non-owning view over `size` contiguous elements starting at `data`
// (a poor man's std::span). We should consider using the Pigweed span instead.
template <class T>
class Span {
 public:
  // Wraps the `size` elements beginning at `data`; only the pointer and the
  // length are stored, no elements are copied.
  constexpr Span(T* data, size_t size) noexcept : data_{data}, size_{size} {}

  // Pointer to the first element of the view.
  constexpr T* data() const noexcept { return data_; }

  // Number of elements in the view.
  constexpr size_t size() const noexcept { return size_; }

  // Unchecked element access: `idx` is not validated against size().
  constexpr T& operator[](size_t idx) const noexcept { return data_[idx]; }

 private:
  T* data_;
  size_t size_;
};
// Plain aggregate bundling the data that describes one subgraph: two index
// lists, the node and tensor arrays, and an invoke function pointer. Every
// member except `invoke` is a Span, i.e. a pointer-plus-length view; the
// struct itself declares no constructor or destructor.
struct Subgraph {
// Subgraph input entries (presumably indices into `tensors` -- TODO confirm
// against the implementation of GetSubgraphInput).
Span<const size_t> inputs;
// Subgraph output entries (presumably indices into `tensors` -- TODO confirm
// against the implementation of GetSubgraphOutput).
Span<const size_t> outputs;
// View over the TfLiteNode array for this subgraph.
Span<TfLiteNode> nodes;
// View over the TfLiteEvalTensor array for this subgraph.
Span<TfLiteEvalTensor> tensors;
// Function pointer taking a TfLiteContext* and a span of nodes and returning
// a TfLiteStatus. Presumably runs the subgraph's operators when called from
// InvokeSubgraph -- TODO confirm.
TfLiteStatus (*invoke)(TfLiteContext*, Span<TfLiteNode>);
};
// Context object for code-generated models: it overrides the MicroContext
// interface and the MicroGraph interface in a single class, built from a
// TfLiteContext and a span of Subgraph descriptors.
//
// Note that MicroContext is a public base while MicroGraph, having no access
// specifier in a `class` declaration, is a private base.
//
// Only declarations appear here; all member functions are defined outside
// this header.
class MicroCodegenContext : public MicroContext, MicroGraph {
public:
// Constructs the context from `context` and `subgraphs`. Presumably these are
// stored in `context_` and `subgraphs_` -- TODO confirm in the definition.
MicroCodegenContext(TfLiteContext* context, Span<Subgraph> subgraphs);
~MicroCodegenContext() = default;
// MicroContext API
// The following members override virtual functions declared by MicroContext.
void* AllocatePersistentBuffer(size_t bytes) override;
TfLiteStatus RequestScratchBufferInArena(size_t bytes,
int* buffer_idx) override;
void* GetScratchBuffer(int buffer_idx) override;
TfLiteTensor* AllocateTempTfLiteTensor(int tensor_idx) override;
void DeallocateTempTfLiteTensor(TfLiteTensor* tensor) override;
uint8_t* AllocateTempBuffer(size_t size, size_t alignment) override;
void DeallocateTempBuffer(uint8_t* buffer) override;
TfLiteEvalTensor* GetEvalTensor(int tensor_idx) override;
TfLiteStatus set_external_context(void* external_context_payload) override;
void* external_context() override;
// Returns a MicroGraph reference. Presumably `*this`, since this class
// derives from MicroGraph -- TODO confirm in the definition.
MicroGraph& graph() override;
// MicroGraph API
// The following members override virtual functions declared by MicroGraph.
// They are declared public here even though the MicroGraph base is private.
TfLiteStatus InvokeSubgraph(int subgraph_idx) override;
size_t NumSubgraphInputs(int subgraph_idx) override;
TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int input_idx) override;
size_t NumSubgraphOutputs(int subgraph_idx) override;
TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx,
int output_idx) override;
int NumSubgraphs() override;
MicroResourceVariables* GetResourceVariables() override;
private:
// TfLiteContext pointer (presumably the constructor's `context` argument).
TfLiteContext* context_;
// View over the Subgraph descriptors (presumably the constructor's
// `subgraphs` argument); Span stores only a pointer and a length.
Span<Subgraph> subgraphs_;
// Subgraph index, initialized to 0. Presumably tracks which subgraph is
// being invoked -- TODO confirm in InvokeSubgraph.
size_t current_subgraph_idx_ = 0;
// Opaque external-context pointer, initially null; presumably written by
// set_external_context() and read by external_context() -- TODO confirm.
void* external_context_payload_ = nullptr;
// NOTE(review): macro defined outside this file; presumably removes or stubs
// the virtual deleting operator for static-memory builds -- confirm.
TF_LITE_REMOVE_VIRTUAL_DELETE
};
} // namespace tflite
#endif // CODEGEN_RUNTIME_MICRO_CODEGEN_CONTEXT_H_