Merge "Add support for SoundStream "streaming" version"
diff --git a/examples/tflm/soundstream/BUILD b/examples/tflm/soundstream/BUILD
index b3a41f3..2cb2d2d 100644
--- a/examples/tflm/soundstream/BUILD
+++ b/examples/tflm/soundstream/BUILD
@@ -6,11 +6,15 @@
name = "soundstream",
srcs = [
"decoder.cc",
- "encoder.cc",
"decoder_non_stream_q16x8_b64_io_int16_tflite.cc",
- "encoder_non_stream_q16x8_b64_io_int16_tflite.cc",
"decoder_non_stream_q16x8_b64_io_int16_tflite.h",
+ "decoder_streaming_q16x8_b64_io_int16_tflite.cc",
+ "decoder_streaming_q16x8_b64_io_int16_tflite.h",
+ "encoder.cc",
+ "encoder_non_stream_q16x8_b64_io_int16_tflite.cc",
"encoder_non_stream_q16x8_b64_io_int16_tflite.h",
+ "encoder_streaming_q16x8_b64_io_int16_tflite.cc",
+ "encoder_streaming_q16x8_b64_io_int16_tflite.h",
],
hdrs = [
"decoder.h",
@@ -19,11 +23,12 @@
tags = ["manual"],
deps = [
"@tflite-micro//tensorflow/lite/micro:micro_framework",
+ "@tflite-micro//tensorflow/lite/micro:recording_allocators",
],
)
kelvin_binary(
- name = "soundstream_decoder",
+ name = "soundstream_decoder_non_streaming",
srcs = [
"soundstream_decoder.cc",
],
@@ -31,13 +36,24 @@
deps = [
":soundstream",
"//crt:crt_header",
- "@tflite-micro//tensorflow/lite/micro:micro_framework",
- "@tflite-micro//tensorflow/lite/micro:system_setup",
],
)
kelvin_binary(
- name = "soundstream_encoder",
+ name = "soundstream_decoder_streaming",
+ srcs = [
+ "soundstream_decoder.cc",
+ ],
+ copts = ["-DSTREAMING"],
+ tags = ["manual"],
+ deps = [
+ ":soundstream",
+ "//crt:crt_header",
+ ],
+)
+
+kelvin_binary(
+ name = "soundstream_encoder_non_streaming",
srcs = [
"soundstream_encoder.cc",
],
@@ -45,13 +61,24 @@
deps = [
":soundstream",
"//crt:crt_header",
- "@tflite-micro//tensorflow/lite/micro:micro_framework",
- "@tflite-micro//tensorflow/lite/micro:system_setup",
],
)
kelvin_binary(
- name = "soundstream_e2e",
+ name = "soundstream_encoder_streaming",
+ srcs = [
+ "soundstream_encoder.cc",
+ ],
+ copts = ["-DSTREAMING"],
+ tags = ["manual"],
+ deps = [
+ ":soundstream",
+ "//crt:crt_header",
+ ],
+)
+
+kelvin_binary(
+ name = "soundstream_e2e_non_streaming",
srcs = [
"best_of_times_s16_decoded.cc",
"best_of_times_s16_encoded.cc",
@@ -67,8 +94,27 @@
deps = [
":soundstream",
"//crt:crt_header",
- "@tflite-micro//tensorflow/lite/micro:micro_framework",
- "@tflite-micro//tensorflow/lite/micro:system_setup",
+ ],
+)
+
+kelvin_binary(
+ name = "soundstream_e2e_streaming",
+ srcs = [
+ "best_of_times_s16_decoded_streaming.cc",
+ "best_of_times_s16_encoded_streaming.cc",
+ "best_of_times_s16_wav.cc",
+ "soundstream_e2e.cc",
+ ],
+ hdrs = [
+ "best_of_times_s16_decoded_streaming.h",
+ "best_of_times_s16_encoded_streaming.h",
+ "best_of_times_s16_wav.h",
+ ],
+ copts = ["-DSTREAMING"],
+ tags = ["manual"],
+ deps = [
+ ":soundstream",
+ "//crt:crt_header",
],
)
@@ -101,6 +147,34 @@
)
generate_cc_arrays(
+ name = "decoder_streaming_q16x8_b64_io_int16_tflite_cc",
+ src = "@ml-models//:quant_models/_decoder_streaming_q16x8_b64_io_int16.tflite",
+ out = "decoder_streaming_q16x8_b64_io_int16_tflite.cc",
+ tags = ["manual"],
+)
+
+generate_cc_arrays(
+ name = "decoder_streaming_q16x8_b64_io_int16_tflite_h",
+ src = "@ml-models//:quant_models/_decoder_streaming_q16x8_b64_io_int16.tflite",
+ out = "decoder_streaming_q16x8_b64_io_int16_tflite.h",
+ tags = ["manual"],
+)
+
+generate_cc_arrays(
+ name = "encoder_streaming_q16x8_b64_io_int16_tflite_cc",
+ src = "@ml-models//:quant_models/_encoder_streaming_q16x8_b64_io_int16.tflite",
+ out = "encoder_streaming_q16x8_b64_io_int16_tflite.cc",
+ tags = ["manual"],
+)
+
+generate_cc_arrays(
+ name = "encoder_streaming_q16x8_b64_io_int16_tflite_h",
+ src = "@ml-models//:quant_models/_encoder_streaming_q16x8_b64_io_int16.tflite",
+ out = "encoder_streaming_q16x8_b64_io_int16_tflite.h",
+ tags = ["manual"],
+)
+
+generate_cc_arrays(
name = "best_of_times_s16_wav_cc",
src = "best_of_times_s16.wav",
out = "best_of_times_s16_wav.cc",
@@ -135,3 +209,27 @@
src = "best_of_times_s16_decoded.raw",
out = "best_of_times_s16_decoded.h",
)
+
+generate_cc_arrays(
+ name = "best_of_times_s16_encoded_streaming_cc",
+ src = "best_of_times_s16_encoded_streaming.raw",
+ out = "best_of_times_s16_encoded_streaming.cc",
+)
+
+generate_cc_arrays(
+ name = "best_of_times_s16_encoded_streaming_h",
+ src = "best_of_times_s16_encoded_streaming.raw",
+ out = "best_of_times_s16_encoded_streaming.h",
+)
+
+generate_cc_arrays(
+ name = "best_of_times_s16_decoded_streaming_cc",
+ src = "best_of_times_s16_decoded_streaming.raw",
+ out = "best_of_times_s16_decoded_streaming.cc",
+)
+
+generate_cc_arrays(
+ name = "best_of_times_s16_decoded_streaming_h",
+ src = "best_of_times_s16_decoded_streaming.raw",
+ out = "best_of_times_s16_decoded_streaming.h",
+)
diff --git a/examples/tflm/soundstream/best_of_times_s16_decoded_streaming.raw b/examples/tflm/soundstream/best_of_times_s16_decoded_streaming.raw
new file mode 100644
index 0000000..91f655e
--- /dev/null
+++ b/examples/tflm/soundstream/best_of_times_s16_decoded_streaming.raw
Binary files differ
diff --git a/examples/tflm/soundstream/best_of_times_s16_encoded_streaming.raw b/examples/tflm/soundstream/best_of_times_s16_encoded_streaming.raw
new file mode 100644
index 0000000..7956e3c
--- /dev/null
+++ b/examples/tflm/soundstream/best_of_times_s16_encoded_streaming.raw
Binary files differ
diff --git a/examples/tflm/soundstream/decoder.cc b/examples/tflm/soundstream/decoder.cc
index 2a0fc4c..6aa939a 100644
--- a/examples/tflm/soundstream/decoder.cc
+++ b/examples/tflm/soundstream/decoder.cc
@@ -1,37 +1,96 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
#include "examples/tflm/soundstream/decoder.h"
#include "examples/tflm/soundstream/decoder_non_stream_q16x8_b64_io_int16_tflite.h"
+#include "examples/tflm/soundstream/decoder_streaming_q16x8_b64_io_int16_tflite.h"
namespace kelvin::soundstream::decoder {
-std::optional<Decoder> Setup(uint8_t* tensor_arena) {
- auto* model =
- tflite::GetModel(g__decoder_non_stream_q16x8_b64_io_int16_model_data);
- if (model->version() != TFLITE_SCHEMA_VERSION) {
- return {};
+
+constexpr unsigned int kNonStreamingOpCount = 11;
+constexpr unsigned int kStreamingOpCount = 16;
+// Not sure how to get a good upper bound on this one, so arbitrarily chosen.
+constexpr unsigned int kStreamingVariablesCount = 40;
+
+// Concrete Decoder. The op resolver and the interpreter are members of this
+// object; the allocator and the resource variables are created inside the
+// caller-supplied tensor arena.
+//
+// kStreaming selects the operator set: when true, the CallOnce, VarHandle,
+// ReadVariable, AssignVariable and Sub kernels are registered on top of the
+// eleven kernels registered for both variants.
+template <bool kStreaming>
+class DecoderImpl : public Decoder {
+ public:
+  // Builds a decoder for the flatbuffer at `model_data`, using
+  // `tensor_arena` (of `tensor_arena_size` bytes) as TFLM working memory.
+  //
+  // Returns a heap-allocated decoder that the caller must take ownership of,
+  // or nullptr when the model's schema version differs from
+  // TFLITE_SCHEMA_VERSION or when AllocateTensors() fails.
+  static Decoder* Setup(const uint8_t* model_data, uint8_t* tensor_arena,
+                        size_t tensor_arena_size) {
+    const tflite::Model* model = tflite::GetModel(model_data);
+    if (model->version() != TFLITE_SCHEMA_VERSION) {
+      return nullptr;
+    }
+
+    // The constructor is private, so std::make_unique cannot be used here.
+    // Holding the object in a unique_ptr guarantees it is released on the
+    // failure path below instead of being leaked.
+    std::unique_ptr<DecoderImpl> d(
+        new DecoderImpl(model, tensor_arena, tensor_arena_size));
+
+    if (d->interpreter()->AllocateTensors() != kTfLiteOk) {
+      MicroPrintf("Failed to allocate decoder's tensors");
+      return nullptr;
+    }
+    return d.release();
}
+
+  // Returns the interpreter owned by this decoder.
+  tflite::MicroInterpreter* interpreter() override { return &interpreter_; }
- Decoder d;
- d.resolver = std::make_unique<tflite::MicroMutableOpResolver<11>>();
- d.resolver->AddReshape();
- d.resolver->AddPad();
- d.resolver->AddConv2D();
- d.resolver->AddLeakyRelu();
- d.resolver->AddSplit();
- d.resolver->AddTransposeConv();
- d.resolver->AddStridedSlice();
- d.resolver->AddConcatenation();
- d.resolver->AddDepthwiseConv2D();
- d.resolver->AddAdd();
- d.resolver->AddQuantize();
+
+ private:
+  // Members are initialised in declaration order: the allocator must exist
+  // before the resource variables (which are created through it), and all
+  // three must exist before the interpreter that refers to them.
+  DecoderImpl(const tflite::Model* model, uint8_t* tensor_arena,
+              size_t tensor_arena_size)
+      : resolver_(CreateResolver()),
+        allocator_(tflite::RecordingMicroAllocator::Create(tensor_arena,
+                                                            tensor_arena_size)),
+        variables_(tflite::MicroResourceVariables::Create(
+            allocator_, kStreamingVariablesCount)),
+        interpreter_(model, resolver_, allocator_, variables_) {}
- d.interpreter = std::make_unique<tflite::MicroInterpreter>(
- model, *d.resolver, tensor_arena, kTensorArenaSizeBytes);
-
- TfLiteStatus allocate_status = d.interpreter->AllocateTensors();
- if (allocate_status != kTfLiteOk) {
- MicroPrintf("Failed to allocate decoder's tensors");
- return {};
+
+  // Resolver capacity: exactly the number of kernels CreateResolver()
+  // registers for the selected variant.
+  static constexpr int kOpCount =
+      kStreaming ? kStreamingOpCount : kNonStreamingOpCount;
+
+  // Returns a resolver populated with the kernels for the selected variant.
+  static tflite::MicroMutableOpResolver<kOpCount> CreateResolver() {
+    tflite::MicroMutableOpResolver<kOpCount> resolver;
+    resolver.AddReshape();
+    resolver.AddPad();
+    resolver.AddConv2D();
+    resolver.AddLeakyRelu();
+    resolver.AddSplit();
+    resolver.AddTransposeConv();
+    resolver.AddStridedSlice();
+    resolver.AddConcatenation();
+    resolver.AddDepthwiseConv2D();
+    resolver.AddAdd();
+    resolver.AddQuantize();
+    if constexpr (kStreaming) {
+      resolver.AddCallOnce();
+      resolver.AddVarHandle();
+      resolver.AddReadVariable();
+      resolver.AddAssignVariable();
+      resolver.AddSub();
+    }
+    return resolver;
}
- return d;
+
+  const tflite::MicroMutableOpResolver<kOpCount> resolver_;
+  // Created in the arena, so this object must not delete it: non-owning.
+  tflite::RecordingMicroAllocator* allocator_;
+  // Created in the arena, so this object must not delete it: non-owning.
+  tflite::MicroResourceVariables* variables_;
+  tflite::MicroInterpreter interpreter_;
+};
+
+// The streaming and non-streaming factories are deliberately two distinct
+// functions, so that the compiler can eliminate whichever one a binary does
+// not use. With LTO they could perhaps be merged into one.
+
+// Creates a decoder for the streaming model. The result is empty when
+// DecoderImpl::Setup() returns nullptr.
+std::unique_ptr<Decoder> SetupStreaming(uint8_t* tensor_arena,
+                                        size_t tensor_arena_size) {
+  Decoder* const streaming_decoder = DecoderImpl<true>::Setup(
+      g__decoder_streaming_q16x8_b64_io_int16_model_data, tensor_arena,
+      tensor_arena_size);
+  return std::unique_ptr<Decoder>(streaming_decoder);
}
+
+// Creates a decoder for the non-streaming model. The result is empty when
+// DecoderImpl::Setup() returns nullptr.
+std::unique_ptr<Decoder> Setup(uint8_t* tensor_arena,
+                               size_t tensor_arena_size) {
+  Decoder* const non_streaming_decoder = DecoderImpl<false>::Setup(
+      g__decoder_non_stream_q16x8_b64_io_int16_model_data, tensor_arena,
+      tensor_arena_size);
+  return std::unique_ptr<Decoder>(non_streaming_decoder);
+}
+
} // namespace kelvin::soundstream::decoder
diff --git a/examples/tflm/soundstream/decoder.h b/examples/tflm/soundstream/decoder.h
index b21c93d..337f402 100644
--- a/examples/tflm/soundstream/decoder.h
+++ b/examples/tflm/soundstream/decoder.h
@@ -1,20 +1,31 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
#ifndef EXAMPLES_TFLM_SOUNDSTREAM_DECODER_H_
#define EXAMPLES_TFLM_SOUNDSTREAM_DECODER_H_
#include <cstddef>
#include <memory>
-#include <optional>
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/recording_micro_allocator.h"
namespace kelvin::soundstream::decoder {
+// RecordingMicroAllocator on desktop recorded 94512 bytes of allocation.
constexpr size_t kTensorArenaSizeBytes = 96 * 1024;
-struct Decoder {
- std::unique_ptr<tflite::MicroInterpreter> interpreter;
- std::unique_ptr<tflite::MicroMutableOpResolver<11>> resolver;
+// RecordingMicroAllocator on desktop recorded 143296 bytes of allocation.
+constexpr size_t kTensorArenaStreamingSizeBytes = 168 * 1024;
+
+// Abstract handle to a configured SoundStream decoder. Its only member is
+// the accessor for the TFLM interpreter that runs the model.
+//
+// NOTE(review): Setup() and SetupStreaming() return this type through
+// std::unique_ptr<Decoder>, yet no virtual destructor is declared here, so
+// destroying an implementation through that pointer is undefined behaviour.
+// Adding `virtual ~Decoder() = default;` looks like the fix, but confirm
+// first that the implementation's destructor is safe to run - TODO confirm.
+class Decoder {
+ public:
+ // Returns the interpreter for this decoder.
+ virtual tflite::MicroInterpreter* interpreter() = 0;
};
-std::optional<Decoder> Setup(uint8_t* tensor_arena);
+
+std::unique_ptr<Decoder> Setup(uint8_t* tensor_arena, size_t tensor_arena_size);
+std::unique_ptr<Decoder> SetupStreaming(uint8_t* tensor_arena,
+ size_t tensor_arena_size);
} // namespace kelvin::soundstream::decoder
#endif // EXAMPLES_TFLM_SOUNDSTREAM_DECODER_H_
diff --git a/examples/tflm/soundstream/encoder.cc b/examples/tflm/soundstream/encoder.cc
index d31d517..6d3c985 100644
--- a/examples/tflm/soundstream/encoder.cc
+++ b/examples/tflm/soundstream/encoder.cc
@@ -1,32 +1,95 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
#include "examples/tflm/soundstream/encoder.h"
#include "examples/tflm/soundstream/encoder_non_stream_q16x8_b64_io_int16_tflite.h"
+#include "examples/tflm/soundstream/encoder_streaming_q16x8_b64_io_int16_tflite.h"
+#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/recording_micro_allocator.h"
namespace kelvin::soundstream::encoder {
-std::optional<Encoder> Setup(uint8_t* tensor_arena) {
- auto* model =
- tflite::GetModel(g__encoder_non_stream_q16x8_b64_io_int16_model_data);
- if (model->version() != TFLITE_SCHEMA_VERSION) {
- return {};
+
+constexpr unsigned int kNonStreamingOpCount = 6;
+constexpr unsigned int kStreamingOpCount = 13;
+// Not sure how to get a good upper bound on this one, so arbitrarily chosen.
+constexpr unsigned int kStreamingVariablesCount = 40;
+
+// Concrete Encoder. The op resolver and the interpreter are members of this
+// object; the allocator and the resource variables are created inside the
+// caller-supplied tensor arena.
+//
+// kStreaming selects the operator set: when true, the CallOnce, VarHandle,
+// ReadVariable, Concatenation, StridedSlice, AssignVariable and Quantize
+// kernels are registered on top of the six kernels registered for both
+// variants.
+template <bool kStreaming>
+class EncoderImpl : public Encoder {
+ public:
+  // Builds an encoder for the flatbuffer at `model_data`, using
+  // `tensor_arena` (of `tensor_arena_size` bytes) as TFLM working memory.
+  //
+  // Returns a heap-allocated encoder that the caller must take ownership of,
+  // or nullptr when the model's schema version differs from
+  // TFLITE_SCHEMA_VERSION or when AllocateTensors() fails.
+  static Encoder* Setup(const uint8_t* model_data, uint8_t* tensor_arena,
+                        size_t tensor_arena_size) {
+    const tflite::Model* model = tflite::GetModel(model_data);
+    if (model->version() != TFLITE_SCHEMA_VERSION) {
+      return nullptr;
+    }
+
+    // The constructor is private, so std::make_unique cannot be used here.
+    // Holding the object in a unique_ptr guarantees it is released on the
+    // failure path below instead of being leaked.
+    std::unique_ptr<EncoderImpl> e(
+        new EncoderImpl(model, tensor_arena, tensor_arena_size));
+
+    if (e->interpreter()->AllocateTensors() != kTfLiteOk) {
+      MicroPrintf("Failed to allocate encoder's tensors");
+      return nullptr;
+    }
+    return e.release();
}
+
+  // Returns the interpreter owned by this encoder.
+  tflite::MicroInterpreter* interpreter() override { return &interpreter_; }
- Encoder e;
- e.resolver = std::make_unique<tflite::MicroMutableOpResolver<6>>();
- e.resolver->AddReshape();
- e.resolver->AddPad();
- e.resolver->AddConv2D();
- e.resolver->AddLeakyRelu();
- e.resolver->AddDepthwiseConv2D();
- e.resolver->AddAdd();
+
+ private:
+  // Members are initialised in declaration order: the allocator must exist
+  // before the resource variables (which are created through it), and all
+  // three must exist before the interpreter that refers to them.
+  EncoderImpl(const tflite::Model* model, uint8_t* tensor_arena,
+              size_t tensor_arena_size)
+      : resolver_(CreateResolver()),
+        allocator_(tflite::RecordingMicroAllocator::Create(tensor_arena,
+                                                            tensor_arena_size)),
+        variables_(tflite::MicroResourceVariables::Create(
+            allocator_, kStreamingVariablesCount)),
+        interpreter_(model, resolver_, allocator_, variables_) {}
- e.interpreter = std::make_unique<tflite::MicroInterpreter>(
- model, *e.resolver, tensor_arena, kTensorArenaSizeBytes);
-
- TfLiteStatus allocate_status = e.interpreter->AllocateTensors();
- if (allocate_status != kTfLiteOk) {
- MicroPrintf("Failed to allocate encoder's tensors");
- return {};
+
+  // Resolver capacity: exactly the number of kernels CreateResolver()
+  // registers for the selected variant.
+  static constexpr int kOpCount =
+      kStreaming ? kStreamingOpCount : kNonStreamingOpCount;
+
+  // Returns a resolver populated with the kernels for the selected variant.
+  static tflite::MicroMutableOpResolver<kOpCount> CreateResolver() {
+    tflite::MicroMutableOpResolver<kOpCount> resolver;
+    resolver.AddReshape();
+    resolver.AddPad();
+    resolver.AddConv2D();
+    resolver.AddLeakyRelu();
+    resolver.AddDepthwiseConv2D();
+    resolver.AddAdd();
+    if constexpr (kStreaming) {
+      resolver.AddCallOnce();
+      resolver.AddVarHandle();
+      resolver.AddReadVariable();
+      resolver.AddConcatenation();
+      resolver.AddStridedSlice();
+      resolver.AddAssignVariable();
+      resolver.AddQuantize();
+    }
+    return resolver;
}
- return e;
+
+  const tflite::MicroMutableOpResolver<kOpCount> resolver_;
+  // Created in the arena, so this object must not delete it: non-owning.
+  tflite::RecordingMicroAllocator* allocator_;
+  // Created in the arena, so this object must not delete it: non-owning.
+  tflite::MicroResourceVariables* variables_;
+  tflite::MicroInterpreter interpreter_;
+};
+
+// The streaming and non-streaming factories are deliberately two distinct
+// functions, so that the compiler can eliminate whichever one a binary does
+// not use. With LTO they could perhaps be merged into one.
+
+// Creates an encoder for the streaming model. The result is empty when
+// EncoderImpl::Setup() returns nullptr.
+std::unique_ptr<Encoder> SetupStreaming(uint8_t* tensor_arena,
+                                        size_t tensor_arena_size) {
+  Encoder* const streaming_encoder = EncoderImpl<true>::Setup(
+      g__encoder_streaming_q16x8_b64_io_int16_model_data, tensor_arena,
+      tensor_arena_size);
+  return std::unique_ptr<Encoder>(streaming_encoder);
}
+
+// Creates an encoder for the non-streaming model. The result is empty when
+// EncoderImpl::Setup() returns nullptr.
+std::unique_ptr<Encoder> Setup(uint8_t* tensor_arena,
+                               size_t tensor_arena_size) {
+  Encoder* const non_streaming_encoder = EncoderImpl<false>::Setup(
+      g__encoder_non_stream_q16x8_b64_io_int16_model_data, tensor_arena,
+      tensor_arena_size);
+  return std::unique_ptr<Encoder>(non_streaming_encoder);
+}
+
} // namespace kelvin::soundstream::encoder
diff --git a/examples/tflm/soundstream/encoder.h b/examples/tflm/soundstream/encoder.h
index 3cb74f9..267d553 100644
--- a/examples/tflm/soundstream/encoder.h
+++ b/examples/tflm/soundstream/encoder.h
@@ -1,20 +1,28 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
#ifndef EXAMPLES_TFLM_SOUNDSTREAM_ENCODER_H_
#define EXAMPLES_TFLM_SOUNDSTREAM_ENCODER_H_
#include <cstddef>
#include <memory>
-#include <optional>
#include "tensorflow/lite/micro/micro_interpreter.h"
-#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
namespace kelvin::soundstream::encoder {
+// RecordingMicroAllocator on desktop recorded 90064 bytes of allocation.
constexpr size_t kTensorArenaSizeBytes = 96 * 1024;
+// RecordingMicroAllocator on desktop recorded 147328 bytes of allocation.
+constexpr size_t kTensorArenaStreamingSizeBytes = 168 * 1024;
+
+// Abstract handle to a configured SoundStream encoder. Its only member is
+// the accessor for the TFLM interpreter that runs the model.
+//
+// NOTE(review): Setup() and SetupStreaming() return this type through
+// std::unique_ptr<Encoder>, yet no virtual destructor is declared here, so
+// destroying an implementation through that pointer is undefined behaviour.
+// Adding `virtual ~Encoder() = default;` looks like the fix, but confirm
+// first that the implementation's destructor is safe to run - TODO confirm.
struct Encoder {
- std::unique_ptr<tflite::MicroInterpreter> interpreter;
- std::unique_ptr<tflite::MicroMutableOpResolver<6>> resolver;
+ // Returns the interpreter for this encoder.
+ virtual tflite::MicroInterpreter* interpreter() = 0;
};
-std::optional<Encoder> Setup(uint8_t* tensor_arena);
+
+std::unique_ptr<Encoder> Setup(uint8_t* tensor_arena, size_t tensor_arena_size);
+std::unique_ptr<Encoder> SetupStreaming(uint8_t* tensor_arena,
+ size_t tensor_arena_size);
} // namespace kelvin::soundstream::encoder
#endif // EXAMPLES_TFLM_SOUNDSTREAM_ENCODER_H_
diff --git a/examples/tflm/soundstream/soundstream_decoder.cc b/examples/tflm/soundstream/soundstream_decoder.cc
index 77f0f94..9a5a739 100644
--- a/examples/tflm/soundstream/soundstream_decoder.cc
+++ b/examples/tflm/soundstream/soundstream_decoder.cc
@@ -16,23 +16,34 @@
};
namespace {
-uint8_t
- decoder_tensor_arena[kelvin::soundstream::decoder::kTensorArenaSizeBytes]
- __attribute__((aligned(64)));
+#if defined(STREAMING)
+constexpr size_t tensor_arena_size =
+ kelvin::soundstream::decoder::kTensorArenaStreamingSizeBytes;
+#else
+constexpr size_t tensor_arena_size =
+ kelvin::soundstream::decoder::kTensorArenaSizeBytes;
+#endif
+uint8_t decoder_tensor_arena[tensor_arena_size] __attribute__((aligned(64)));
} // namespace
int main(int argc, char **argv) {
- auto decoder = kelvin::soundstream::decoder::Setup(decoder_tensor_arena);
+#if defined(STREAMING)
+ auto decoder = kelvin::soundstream::decoder::SetupStreaming(
+ decoder_tensor_arena, tensor_arena_size);
+#else
+ auto decoder = kelvin::soundstream::decoder::Setup(decoder_tensor_arena,
+ tensor_arena_size);
+#endif
if (!decoder) {
MicroPrintf("Unable to construct decoder");
return -1;
}
- TfLiteTensor *decoder_input = decoder->interpreter->input(0);
- TfLiteTensor *decoder_output = decoder->interpreter->output(0);
+ TfLiteTensor *decoder_input = decoder->interpreter()->input(0);
+ TfLiteTensor *decoder_output = decoder->interpreter()->output(0);
memset(decoder_input->data.uint8, 0, decoder_input->bytes);
- TfLiteStatus invoke_status = decoder->interpreter->Invoke();
+ TfLiteStatus invoke_status = decoder->interpreter()->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke decoder");
return -1;
diff --git a/examples/tflm/soundstream/soundstream_e2e.cc b/examples/tflm/soundstream/soundstream_e2e.cc
index 4c43060..9ba8913 100644
--- a/examples/tflm/soundstream/soundstream_e2e.cc
+++ b/examples/tflm/soundstream/soundstream_e2e.cc
@@ -2,38 +2,69 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
-#include "examples/tflm/soundstream/best_of_times_s16_decoded.h"
-#include "examples/tflm/soundstream/best_of_times_s16_encoded.h"
#include "examples/tflm/soundstream/best_of_times_s16_wav.h"
#include "examples/tflm/soundstream/decoder.h"
#include "examples/tflm/soundstream/encoder.h"
+#if defined(STREAMING)
+#include "examples/tflm/soundstream/best_of_times_s16_decoded_streaming.h"
+#include "examples/tflm/soundstream/best_of_times_s16_encoded_streaming.h"
+const unsigned char *reference_decoded = g_best_of_times_s16_decoded_streaming;
+const unsigned char *reference_encoded = g_best_of_times_s16_encoded_streaming;
+#else
+#include "examples/tflm/soundstream/best_of_times_s16_decoded.h"
+#include "examples/tflm/soundstream/best_of_times_s16_encoded.h"
+const unsigned char *reference_decoded = g_best_of_times_s16_decoded;
+const unsigned char *reference_encoded = g_best_of_times_s16_encoded;
+#endif
+
namespace {
-uint8_t
- encoder_tensor_arena[kelvin::soundstream::encoder::kTensorArenaSizeBytes]
+#if defined(STREAMING)
+constexpr size_t decoder_tensor_arena_size =
+ kelvin::soundstream::decoder::kTensorArenaStreamingSizeBytes;
+constexpr size_t encoder_tensor_arena_size =
+ kelvin::soundstream::encoder::kTensorArenaStreamingSizeBytes;
+#else
+constexpr size_t decoder_tensor_arena_size =
+ kelvin::soundstream::decoder::kTensorArenaSizeBytes;
+constexpr size_t encoder_tensor_arena_size =
+ kelvin::soundstream::encoder::kTensorArenaSizeBytes;
+#endif
+uint8_t encoder_tensor_arena[encoder_tensor_arena_size]
__attribute__((aligned(64)));
-uint8_t
- decoder_tensor_arena[kelvin::soundstream::decoder::kTensorArenaSizeBytes]
+uint8_t decoder_tensor_arena[decoder_tensor_arena_size]
__attribute__((aligned(64)));
} // namespace
int main(int argc, char **argv) {
- auto encoder = kelvin::soundstream::encoder::Setup(encoder_tensor_arena);
+#if defined(STREAMING)
+ auto encoder = kelvin::soundstream::encoder::SetupStreaming(
+ encoder_tensor_arena, encoder_tensor_arena_size);
+#else
+ auto encoder = kelvin::soundstream::encoder::Setup(encoder_tensor_arena,
+ encoder_tensor_arena_size);
+#endif
if (!encoder) {
MicroPrintf("Unable to construct encoder");
return -1;
}
- auto decoder = kelvin::soundstream::decoder::Setup(decoder_tensor_arena);
+#if defined(STREAMING)
+ auto decoder = kelvin::soundstream::decoder::SetupStreaming(
+ decoder_tensor_arena, decoder_tensor_arena_size);
+#else
+ auto decoder = kelvin::soundstream::decoder::Setup(decoder_tensor_arena,
+ decoder_tensor_arena_size);
+#endif
if (!decoder) {
MicroPrintf("Unable to construct decoder");
return -1;
}
- TfLiteTensor *encoder_input = encoder->interpreter->input(0);
- TfLiteTensor *encoder_output = encoder->interpreter->output(0);
- TfLiteTensor *decoder_input = decoder->interpreter->input(0);
- TfLiteTensor *decoder_output = decoder->interpreter->output(0);
+ TfLiteTensor *encoder_input = encoder->interpreter()->input(0);
+ TfLiteTensor *encoder_output = encoder->interpreter()->output(0);
+ TfLiteTensor *decoder_input = decoder->interpreter()->input(0);
+ TfLiteTensor *decoder_output = decoder->interpreter()->output(0);
int invocation_count =
(g_best_of_times_s16_audio_data_size * sizeof(int16_t)) /
@@ -44,13 +75,13 @@
g_best_of_times_s16_audio_data +
((i * encoder_input->bytes) / sizeof(int16_t)),
encoder_input->bytes);
- TfLiteStatus invoke_status = encoder->interpreter->Invoke();
+ TfLiteStatus invoke_status = encoder->interpreter()->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke encoder");
return -1;
}
if (memcmp(encoder_output->data.uint8,
- g_best_of_times_s16_encoded + (i * encoder_output->bytes),
+ reference_encoded + (i * encoder_output->bytes),
encoder_output->bytes)) {
MicroPrintf("Encoder output mismatches reference");
return -1;
@@ -58,13 +89,13 @@
memcpy(decoder_input->data.uint8, encoder_output->data.uint8,
decoder_input->bytes);
- invoke_status = decoder->interpreter->Invoke();
+ invoke_status = decoder->interpreter()->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke decoder");
return -1;
}
if (memcmp(decoder_output->data.uint8,
- g_best_of_times_s16_decoded + (i * decoder_output->bytes),
+ reference_decoded + (i * decoder_output->bytes),
decoder_output->bytes)) {
MicroPrintf("Decoder output mismatches reference");
return -1;
diff --git a/examples/tflm/soundstream/soundstream_encoder.cc b/examples/tflm/soundstream/soundstream_encoder.cc
index bc7fc25..6c60748 100644
--- a/examples/tflm/soundstream/soundstream_encoder.cc
+++ b/examples/tflm/soundstream/soundstream_encoder.cc
@@ -16,23 +16,34 @@
};
namespace {
-uint8_t
- encoder_tensor_arena[kelvin::soundstream::encoder::kTensorArenaSizeBytes]
- __attribute__((aligned(64)));
+#if defined(STREAMING)
+constexpr size_t tensor_arena_size =
+ kelvin::soundstream::encoder::kTensorArenaStreamingSizeBytes;
+#else
+constexpr size_t tensor_arena_size =
+ kelvin::soundstream::encoder::kTensorArenaSizeBytes;
+#endif
+uint8_t tensor_arena[tensor_arena_size] __attribute__((aligned(64)));
} // namespace
int main(int argc, char **argv) {
- auto encoder = kelvin::soundstream::encoder::Setup(encoder_tensor_arena);
+#if defined(STREAMING)
+ auto encoder = kelvin::soundstream::encoder::SetupStreaming(
+ tensor_arena, tensor_arena_size);
+#else
+ auto encoder =
+ kelvin::soundstream::encoder::Setup(tensor_arena, tensor_arena_size);
+#endif
if (!encoder) {
MicroPrintf("Unable to construct encoder");
return -1;
}
- TfLiteTensor *encoder_input = encoder->interpreter->input(0);
- TfLiteTensor *encoder_output = encoder->interpreter->output(0);
+ TfLiteTensor *encoder_input = encoder->interpreter()->input(0);
+ TfLiteTensor *encoder_output = encoder->interpreter()->output(0);
memset(encoder_input->data.uint8, 0, encoder_input->bytes);
- TfLiteStatus invoke_status = encoder->interpreter->Invoke();
+ TfLiteStatus invoke_status = encoder->interpreter()->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke encoder");
return -1;