blob: 4c43060cbd27ed8b86acecd068a4e9e815ae74e2 [file] [log] [blame]
// Copyright 2023 Google LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#include "examples/tflm/soundstream/best_of_times_s16_decoded.h"
#include "examples/tflm/soundstream/best_of_times_s16_encoded.h"
#include "examples/tflm/soundstream/best_of_times_s16_wav.h"
#include "examples/tflm/soundstream/decoder.h"
#include "examples/tflm/soundstream/encoder.h"
namespace {
uint8_t
encoder_tensor_arena[kelvin::soundstream::encoder::kTensorArenaSizeBytes]
__attribute__((aligned(64)));
uint8_t
decoder_tensor_arena[kelvin::soundstream::decoder::kTensorArenaSizeBytes]
__attribute__((aligned(64)));
} // namespace
int main(int argc, char **argv) {
auto encoder = kelvin::soundstream::encoder::Setup(encoder_tensor_arena);
if (!encoder) {
MicroPrintf("Unable to construct encoder");
return -1;
}
auto decoder = kelvin::soundstream::decoder::Setup(decoder_tensor_arena);
if (!decoder) {
MicroPrintf("Unable to construct decoder");
return -1;
}
TfLiteTensor *encoder_input = encoder->interpreter->input(0);
TfLiteTensor *encoder_output = encoder->interpreter->output(0);
TfLiteTensor *decoder_input = decoder->interpreter->input(0);
TfLiteTensor *decoder_output = decoder->interpreter->output(0);
int invocation_count =
(g_best_of_times_s16_audio_data_size * sizeof(int16_t)) /
encoder_input->bytes;
for (int i = 0; i < invocation_count; ++i) {
MicroPrintf("Invocation %d of %d", i, invocation_count);
memcpy(encoder_input->data.uint8,
g_best_of_times_s16_audio_data +
((i * encoder_input->bytes) / sizeof(int16_t)),
encoder_input->bytes);
TfLiteStatus invoke_status = encoder->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke encoder");
return -1;
}
if (memcmp(encoder_output->data.uint8,
g_best_of_times_s16_encoded + (i * encoder_output->bytes),
encoder_output->bytes)) {
MicroPrintf("Encoder output mismatches reference");
return -1;
}
memcpy(decoder_input->data.uint8, encoder_output->data.uint8,
decoder_input->bytes);
invoke_status = decoder->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Failed to invoke decoder");
return -1;
}
if (memcmp(decoder_output->data.uint8,
g_best_of_times_s16_decoded + (i * decoder_output->bytes),
decoder_output->bytes)) {
MicroPrintf("Decoder output mismatches reference");
return -1;
}
}
return 0;
}