/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Audio preprocessing: MFCC-style (log-mel spectrogram) feature extraction

#include "audio_prep/mfcc.h"

#include <math.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "audio_prep/util.h"

#ifdef MFCC_WITH_RVV
#include <riscv_vector.h>

// Fixed-point formats of the integer (RVV) path. A product of one spectrum
// value and one weight carries kSpectraFracBits + kWeightsFracBits
// fractional bits, which extract_mfcc divides back out.
static const uint8_t kSpectraFracBits = 7;  // STFT magnitudes scaled by 2^7
static const uint8_t kWeightsFracBits = 8;  // mel weights scaled by 2^8

// Dot product of two uint32 vectors of length n, computed with RVV (LMUL=8).
// All arithmetic wraps modulo 2^32, like plain uint32_t math.
// NOTE(review): callers must keep the fixed-point operands small enough that
// the true result fits in 32 bits; otherwise the result silently wraps.
// Returns 0 when n <= 0.
static uint32_t dot_product_rvv(const uint32_t* u, const uint32_t* w, int n) {
  uint32_t sum = 0;
  if (n <= 0) {
    // Avoid converting a negative int to a huge size_t element count.
    return sum;
  }

  size_t remaining = (size_t)n;
  while (remaining > 0) {
    size_t vl = __riscv_vsetvl_e32m8(remaining);
    vuint32m8_t vu = __riscv_vle32_v_u32m8(u, vl);
    vuint32m8_t vw = __riscv_vle32_v_u32m8(w, vl);
    vuint32m8_t vprod = __riscv_vmul(vu, vw, vl);

    // Only element 0 of the m1 accumulator feeds the reduction, so set just
    // that element (vl = 1) and seed it with the running sum.
    vuint32m1_t vacc = __riscv_vmv_s_x_u32m1(sum, 1);
    vacc = __riscv_vredsum(vprod, vacc, vl);
    sum = __riscv_vmv_x(vacc);

    u += vl;
    w += vl;
    remaining -= vl;
  }
  return sum;
}
#endif

// Module configuration: the caller-supplied MfccParams plus the sizes that
// set_mfcc_params() derives from them.
typedef struct {
  MfccParams params;     // copied from the caller in set_mfcc_params()
  int win_len;           // analysis window length, in samples
  int hop_len;           // step between consecutive frames, in samples
  int fft_order;         // log2 of fft_len
  int fft_len;           // FFT size: smallest power of two >= win_len
  int num_spectra_bins;  // fft_len / 2 + 1 magnitude bins per frame
} MfccConfig;

// Default configuration: 16000 samples/s, 0.025 s windows with a 0.010 s hop,
// 64 mel bands spanning 125 Hz .. 7500 Hz, 96 frames per call. The derived
// sizes are precomputed to agree with what set_mfcc_params() computes for
// these parameters: a 400-sample window, a 160-sample hop, a 512-point FFT
// (order 9), and 257 magnitude bins.
static MfccConfig config = {
    .params =
        {
            .num_frames = 96,
            .num_mel_bins = 64,
            .audio_samp_rate = 16000,
            .low_edge_hz = 125,
            .upper_edge_hz = 7500,
            .win_len_sec = 0.025,
            .hop_len_sec = 0.010,
            .log_floor = 0.01,
            .log_scaler = 20,
        },
    .win_len = 400,
    .hop_len = 160,
    .fft_order = 9,
    .fft_len = 512,
    .num_spectra_bins = 257,
};

// Install new MFCC parameters and recompute the derived sizes.
//
// in_params is copied into the module configuration; the caller keeps
// ownership. win_len and hop_len are the window/hop durations converted to
// samples (rounded to nearest). fft_len is the smallest power of two that is
// >= win_len, and num_spectra_bins = fft_len / 2 + 1.
void set_mfcc_params(MfccParams* in_params) {
  config.params = *in_params;
  config.win_len =
      (int)(config.params.audio_samp_rate * config.params.win_len_sec + 0.5);
  config.hop_len =
      (int)(config.params.audio_samp_rate * config.params.hop_len_sec + 0.5);

  // Integer search for the FFT order instead of ceilf(log2f(win_len)): the
  // float version converts -inf/NaN to int (undefined behavior) when
  // win_len <= 0, and relies on log2f being exact at powers of two.
  // The cap at 30 keeps 1 << order within a 32-bit int.
  int order = 0;
  while (order < 30 && (1 << order) < config.win_len) {
    ++order;
  }
  config.fft_order = order;
  config.fft_len = 1 << config.fft_order;
  config.num_spectra_bins = config.fft_len / 2 + 1;
}

// Convert frequencies to mel scale using HTK formula
static float hz_to_mel(float freq_hz) {
  const float kMelBreakFreqHz = 700.0;
  const float kMelHighFreqQ = 1127.0;
  return kMelHighFreqQ * logf(1.0 + (freq_hz / kMelBreakFreqHz));
}

// Fill window[0 .. config.win_len - 1] with a periodic Hann window:
//   w[j] = 0.5 - 0.5 * cos(2 * pi * j / win_len)
// The denominator is win_len (not win_len - 1), which makes it periodic.
// window must have room for config.win_len floats.
static void hanning(float* window) {
  // Spelled out locally: M_PI is a POSIX/XSI extension, not ISO C, and is not
  // declared by <math.h> in strict -std=c99/-std=c11 modes. Same value as the
  // usual M_PI definition.
  const double kPi = 3.14159265358979323846;
  for (int j = 0; j < config.win_len; j++) {
    window[j] = 0.5 - 0.5 * cosf(2 * kPi * j / config.win_len);
  }
}

// Compute the STFT magnitude of one windowed, zero-padded frame.
//
// in:     config.win_len time-domain samples.
// window: config.win_len coefficients, multiplied element-wise into `in`.
// out:    config.fft_len / 2 + 1 magnitudes (equal to config.num_spectra_bins
//         as set by set_mfcc_params). In the RVV build they are uint32 fixed
//         point with kSpectraFracBits fractional bits, saturated at
//         UINT32_MAX.
//
// The frame is zero-padded from win_len up to config.fft_len before rfft().
// If the scratch buffer cannot be allocated, out is zero-filled.
#ifdef MFCC_WITH_RVV
static void stft_magnitude(float* in, float* window, uint32_t* out) {
#else
static void stft_magnitude(float* in, float* window, float* out) {
#endif
  const int half = config.fft_len / 2;

  // calloc zero-fills the padding region [win_len, fft_len).
  float* frame = (float*)calloc(config.fft_len, sizeof(float));
  if (frame == NULL) {
    // Out of memory: emit a defined all-zero spectrum instead of
    // dereferencing NULL.
    memset(out, 0, (size_t)(half + 1) * sizeof(*out));
    return;
  }

  // Copy the samples in and apply the Hann window in one pass.
  for (int j = 0; j < config.win_len; j++) {
    frame[j] = in[j] * window[j];
  }

  // Real-valued FFT, in place.
  rfft(frame, config.fft_order);

  // As read below, the packed output holds the real part of bin j in frame[j]
  // and its imaginary part in frame[fft_len - j]; bins 0 and fft_len / 2 have
  // only one stored value.
  for (int j = 0; j <= half; j++) {
    float mag;
    if (j == 0 || j == half) {
      mag = fabsf(frame[j]);
    } else {
      const float re = frame[j];
      const float im = frame[config.fft_len - j];
      mag = sqrtf(re * re + im * im);
    }
#ifdef MFCC_WITH_RVV
    // Converting an out-of-range float to uint32_t is undefined behavior, so
    // saturate explicitly. 4294967296.0f (2^32) is exactly representable.
    const float scaled = mag * (1 << kSpectraFracBits);
    out[j] = (scaled >= 4294967296.0f) ? UINT32_MAX : (uint32_t)scaled;
#else
    out[j] = mag;
#endif
  }

  free(frame);
}

// Build the mel filterbank matrix that post-multiplies a spectrogram row to
// produce mel-band energies.
//
// weights: row-major, params.num_mel_bins rows x config.num_spectra_bins
//          columns; entry [i][j] is the weight of STFT bin j in mel band i.
//          Band i is a triangle on the mel axis: 0 at band_edges[i], rising
//          to 1 at band_edges[i + 1], back to 0 at band_edges[i + 2].
//          Column 0 (the DC bin) is always 0. In the RVV build the weights
//          are uint32 fixed point with kWeightsFracBits fractional bits.
//
// If the scratch buffers cannot be allocated, weights is zero-filled.
#ifdef MFCC_WITH_RVV
static void spectra_to_mel_matrix(uint32_t* weights) {
#else
static void spectra_to_mel_matrix(float* weights) {
#endif
  MfccParams* params = &config.params;
  const int num_bins = config.num_spectra_bins;
  const int num_bands = params->num_mel_bins;

  float* spectra_bins = (float*)malloc(num_bins * sizeof(float));
  float* band_edges = (float*)malloc((num_bands + 2) * sizeof(float));
  if (spectra_bins == NULL || band_edges == NULL) {
    // Out of memory: leave a defined (all-zero) matrix rather than crash.
    if (num_bands > 0 && num_bins > 0) {
      memset(weights, 0,
             (size_t)num_bands * (size_t)num_bins * sizeof(*weights));
    }
    free(band_edges);
    free(spectra_bins);
    return;
  }

  // Divide by 2.0f rather than 2: if audio_samp_rate is an integer type (its
  // declaration is not visible here), an odd sample rate would otherwise lose
  // the half-Hz of its Nyquist frequency.
  float nyquist_hz = params->audio_samp_rate / 2.0f;

  // Mel position of every STFT bin from 0 Hz up to Nyquist.
  linspace(spectra_bins, 0.0, nyquist_hz, num_bins);
  for (int j = 0; j < num_bins; j++) {
    spectra_bins[j] = hz_to_mel(spectra_bins[j]);
  }

  // num_bands + 2 edges on the mel axis; band i uses edges i, i + 1, i + 2.
  linspace(band_edges, hz_to_mel(params->low_edge_hz),
           hz_to_mel(params->upper_edge_hz), num_bands + 2);

  for (int i = 0; i < num_bands; i++) {
    const float lower = band_edges[i];
    const float center = band_edges[i + 1];
    const float upper = band_edges[i + 2];
    const int row = i * num_bins;

    weights[row] = 0;  // the DC bin contributes to no band
    for (int j = 1; j < num_bins; j++) {
      const float lower_slope = (spectra_bins[j] - lower) / (center - lower);
      const float upper_slope = (upper - spectra_bins[j]) / (upper - center);
      float clamp = (lower_slope < upper_slope) ? lower_slope : upper_slope;
      clamp = (clamp < 0) ? 0 : clamp;
#ifdef MFCC_WITH_RVV
      weights[row + j] = (uint32_t)(clamp * (1 << kWeightsFracBits));
#else
      weights[row + j] = clamp;
#endif
    }
  }

  free(band_edges);
  free(spectra_bins);
}

// Convert a waveform into a log-magnitude mel-frequency spectrogram.
//
// in:     int16 audio samples. Frame i covers samples
//         [(i + 1) * hop_len - win_len, (i + 1) * hop_len); positions outside
//         [0, in_len) read as zero, so the first frame is pre-padded with
//         win_len - hop_len zeros.
// out:    params.num_frames * params.num_mel_bins bytes, one row per frame.
//         Each value is log_scaler * ln(max(mel_energy, log_floor)), clamped
//         to [0, 255] and truncated to uint8.
// in_len: number of valid samples in `in`.
//
// If a working buffer cannot be allocated, `out` is zero-filled instead.
void extract_mfcc(int16_t* in, uint8_t* out, int in_len) {
  MfccParams* params = &config.params;
  const int num_bins = config.num_spectra_bins;
  const int num_bands = params->num_mel_bins;

  float* window = (float*)malloc(config.win_len * sizeof(float));
  float* frame = (float*)malloc(config.win_len * sizeof(float));
#ifdef MFCC_WITH_RVV
  uint32_t* weights =
      (uint32_t*)malloc(num_bands * num_bins * sizeof(uint32_t));
  uint32_t* spectra = (uint32_t*)malloc(num_bins * sizeof(uint32_t));
#else
  float* weights = (float*)malloc(num_bands * num_bins * sizeof(float));
  float* spectra = (float*)malloc(num_bins * sizeof(float));
#endif

  if (window == NULL || frame == NULL || weights == NULL || spectra == NULL) {
    // Out of memory: give callers a defined (all-zero) result, not a crash.
    if (params->num_frames > 0 && num_bands > 0) {
      memset(out, 0, (size_t)params->num_frames * (size_t)num_bands);
    }
  } else {
    // The periodic Hann window and the mel filterbank are shared by all
    // frames.
    hanning(window);
    spectra_to_mel_matrix(weights);

    for (int i = 0; i < params->num_frames; i++) {
      // Gather frame i straight from the input. This yields the same samples
      // as sliding a win_len history buffer forward by hop_len per frame, but
      // never produces a negative copy length or index when hop_len > win_len.
      const int start = (i + 1) * config.hop_len - config.win_len;
      for (int j = 0; j < config.win_len; j++) {
        const int idx = start + j;
        frame[j] = (idx >= 0 && idx < in_len) ? (float)in[idx] : 0.0f;
      }

      stft_magnitude(frame, window, spectra);

      // Project the spectrum onto each mel band and compress it to a byte.
      uint8_t* out_row = out + i * num_bands;
      for (int j = 0; j < num_bands; j++) {
#ifdef MFCC_WITH_RVV
        uint32_t acc =
            dot_product_rvv(spectra, weights + j * num_bins, num_bins);
        // Remove the fixed-point scaling of both operands.
        float mel =
            (float)acc / (1 << (kSpectraFracBits + kWeightsFracBits));
#else
        float mel = dot_product(spectra, weights + j * num_bins, num_bins);
#endif
        if (mel < params->log_floor) {
          mel = params->log_floor;
        }
        float log_mel = params->log_scaler * logf(mel);
        log_mel = log_mel < 0.0 ? 0.0 : (log_mel > 255.0 ? 255.0 : log_mel);
        out_row[j] = (uint8_t)log_mel;
      }
    }
  }

  free(window);
  free(frame);
  free(weights);
  free(spectra);
}
