blob: 6aa939a4c684fad88ec8d0e70ebde2113f6b7453 [file] [log] [blame]
// Copyright 2023 Google LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#include "examples/tflm/soundstream/decoder.h"
#include "examples/tflm/soundstream/decoder_non_stream_q16x8_b64_io_int16_tflite.h"
#include "examples/tflm/soundstream/decoder_streaming_q16x8_b64_io_int16_tflite.h"
namespace kelvin::soundstream::decoder {
// Number of operators registered for the non-streaming decoder: the
// unconditional Add*() calls in DecoderImpl::CreateResolver(). Keep in sync.
constexpr unsigned kNonStreamingOpCount{11};
// Number of operators registered for the streaming decoder: the shared set
// above plus the five streaming-only Add*() calls. Keep in sync.
constexpr unsigned kStreamingOpCount{16};
// Variable count handed to tflite::MicroResourceVariables::Create().
// No principled upper bound is known for this value; 40 is an arbitrary pick.
constexpr unsigned kStreamingVariablesCount{40};
template <bool kStreaming>
class DecoderImpl : public Decoder {
public:
static Decoder* Setup(const uint8_t* model_data, uint8_t* tensor_arena,
size_t tensor_arena_size) {
auto* model = tflite::GetModel(model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
return nullptr;
}
DecoderImpl* d = new DecoderImpl(model, tensor_arena, tensor_arena_size);
TfLiteStatus allocate_status = d->interpreter()->AllocateTensors();
if (allocate_status != kTfLiteOk) {
MicroPrintf("Failed to allocate decoder's tensors");
return nullptr;
}
return d;
}
tflite::MicroInterpreter* interpreter() { return &interpreter_; }
private:
DecoderImpl(const tflite::Model* model, uint8_t* tensor_arena,
size_t tensor_arena_size)
: resolver_(CreateResolver()),
allocator_(tflite::RecordingMicroAllocator::Create(tensor_arena,
tensor_arena_size)),
variables_(tflite::MicroResourceVariables::Create(
allocator_.get(), kStreamingVariablesCount)),
interpreter_(model, resolver_, allocator_.get(), variables_.get()) {}
static constexpr int kOpCount =
kStreaming ? kStreamingOpCount : kStreamingOpCount;
static inline tflite::MicroMutableOpResolver<kOpCount> CreateResolver() {
tflite::MicroMutableOpResolver<kOpCount> resolver;
resolver.AddReshape();
resolver.AddPad();
resolver.AddConv2D();
resolver.AddLeakyRelu();
resolver.AddSplit();
resolver.AddTransposeConv();
resolver.AddStridedSlice();
resolver.AddConcatenation();
resolver.AddDepthwiseConv2D();
resolver.AddAdd();
resolver.AddQuantize();
if (kStreaming) {
resolver.AddCallOnce();
resolver.AddVarHandle();
resolver.AddReadVariable();
resolver.AddAssignVariable();
resolver.AddSub();
}
return resolver;
}
const tflite::MicroMutableOpResolver<kOpCount> resolver_;
// Created in the arena
std::unique_ptr<tflite::RecordingMicroAllocator> allocator_;
// Created in the arena
std::unique_ptr<tflite::MicroResourceVariables> variables_;
tflite::MicroInterpreter interpreter_;
};
// Two separate methods to construct streaming vs non-streaming, so that the
// compiler can eliminate one if it's unused. Perhaps with LTO we could combine
// them together.
std::unique_ptr<Decoder> SetupStreaming(uint8_t* tensor_arena,
size_t tensor_arena_size) {
return std::unique_ptr<Decoder>(DecoderImpl<true>::Setup(
g__decoder_streaming_q16x8_b64_io_int16_model_data, tensor_arena,
tensor_arena_size));
}
std::unique_ptr<Decoder> Setup(uint8_t* tensor_arena,
size_t tensor_arena_size) {
return std::unique_ptr<Decoder>(DecoderImpl<false>::Setup(
g__decoder_non_stream_q16x8_b64_io_int16_model_data, tensor_arena,
tensor_arena_size));
}
} // namespace kelvin::soundstream::decoder