blob: 328bbb15090a2841bc0c60e95057380d28410e27 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Audio preprocessing: MFCC feature extraction (log mel spectrogram)
#include "audio_prep/mfcc.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include "audio_prep/util.h"
#ifdef MFCC_WITH_RVV
#include <riscv_vector.h>
// Fixed-point formats used by the RVV build: mel weights are stored as
// (uint32_t)(weight * 2^kWeightsFracBits) and spectrum magnitudes as
// (uint32_t)(magnitude * 2^kSpectraFracBits), so a weight*magnitude product
// carries kWeightsFracBits + kSpectraFracBits fractional bits.
static const uint8_t kWeightsFracBits = 8;
static const uint8_t kSpectraFracBits = 7;

// Dot product of two uint32 vectors using RISC-V Vector (RVV) intrinsics.
//
// u, w: arrays of at least n elements; only read.
// n:    number of elements; n <= 0 yields 0.
// Returns sum(u[i] * w[i]) evaluated in 32-bit unsigned arithmetic, i.e.
// modulo 2^32.
//
// The running total is kept in element 0 of a single LMUL=1 register and is
// threaded through every vredsum, so it is initialised once and moved to a
// scalar register once, after the strip-mined loop.
//
// NOTE(review): neither the products nor the sum are widened, so large
// magnitude * weight values wrap around 2^32 silently. A widening
// multiply/reduction would be needed if that range matters.
static uint32_t dot_product_rvv(const uint32_t* u, const uint32_t* w, int n) {
  if (n <= 0) {
    return 0;
  }
  const size_t len = (size_t)n;
  // Only element 0 is meaningful: vredsum adds vs1[0] to the strip's sum.
  vuint32m1_t v_acc = __riscv_vmv_v_x_u32m1(0, 1);
  size_t vl;
  for (size_t i = 0; i < len; i += vl) {
    vl = __riscv_vsetvl_e32m8(len - i);
    vuint32m8_t vu = __riscv_vle32_v_u32m8(u + i, vl);
    vuint32m8_t vw = __riscv_vle32_v_u32m8(w + i, vl);
    vuint32m8_t vx = __riscv_vmul(vu, vw, vl);
    v_acc = __riscv_vredsum(vx, v_acc, vl);
  }
  return __riscv_vmv_x(v_acc);
}
#endif
// Module-wide MFCC configuration: the user-facing parameters together with
// the sample and FFT sizes derived from them.
typedef struct {
  MfccParams params;     // user-facing parameters (copied by set_mfcc_params)
  int win_len;           // analysis window length, in samples
  int hop_len;           // advance between consecutive windows, in samples
  int fft_order;         // log2 of fft_len
  int fft_len;           // FFT size: smallest power of two >= win_len
  int num_spectra_bins;  // magnitude bins per frame: fft_len / 2 + 1
} MfccConfig;

// Defaults: 96 frames of 64 mel bins from 16000 Hz audio, with 0.025 s
// windows (400 samples) advanced by 0.010 s (160 samples) and zero-padded to
// a 512-point FFT.
static MfccConfig config = {
    .params =
        {
            .num_frames = 96,
            .num_mel_bins = 64,
            .audio_samp_rate = 16000,
            .low_edge_hz = 125,
            .upper_edge_hz = 7500,
            .win_len_sec = 0.025,
            .hop_len_sec = 0.010,
            .log_floor = 0.01,
            .log_scaler = 20,
        },
    .win_len = 400,
    .hop_len = 160,
    .fft_order = 9,
    .fft_len = 512,
    .num_spectra_bins = 257,
};
// Install new MFCC parameters and recompute the derived frame/FFT sizes.
//
// in_params: parameters to copy into the module-wide configuration. A NULL
//            pointer leaves the current configuration untouched.
//
// win_len and hop_len are the window/hop durations multiplied by the sample
// rate and rounded half-up (for non-negative products). fft_len is the
// smallest power of two >= win_len (512 for a 400-sample window), fft_order
// is its log2, and num_spectra_bins = fft_len / 2 + 1.
//
// NOTE(review): hop_len is not validated against win_len here.
void set_mfcc_params(MfccParams* in_params) {
  if (in_params == NULL) {
    return;
  }
  config.params = *in_params;
  config.win_len =
      (int)(config.params.audio_samp_rate * config.params.win_len_sec + 0.5);
  config.hop_len =
      (int)(config.params.audio_samp_rate * config.params.hop_len_sec + 0.5);
  // Smallest power of two >= win_len, found with integer arithmetic. This is
  // ceil(log2(win_len)) for win_len >= 1, but it avoids two problems with
  // ceilf(log2f(...)):
  //  - float rounding, which undershoots just above large powers of two;
  //  - the undefined float->int conversion of -inf when win_len <= 0.
  // The order is capped at 30 so that 1 << order stays within int.
  int order = 0;
  while (order < 30 && (1 << order) < config.win_len) {
    order++;
  }
  config.fft_order = order;
  config.fft_len = 1 << order;
  config.num_spectra_bins = config.fft_len / 2 + 1;
}
// Convert frequencies to mel scale using HTK formula
static float hz_to_mel(float freq_hz) {
const float kMelBreakFreqHz = 700.0;
const float kMelHighFreqQ = 1127.0;
return kMelHighFreqQ * logf(1.0 + (freq_hz / kMelBreakFreqHz));
}
// Fill window[0 .. config.win_len - 1] with a periodic Hann window:
//   window[j] = 0.5 - 0.5 * cos(2 * pi * j / win_len)
// Dividing by win_len (rather than win_len - 1, the symmetric form) gives
// one full period of a raised cosine starting at window[0] = 0.
// window must have room for at least config.win_len floats.
static void hanning(float* window) {
  // M_PI is a POSIX/XSI extension, not ISO C, and is not defined by e.g.
  // glibc under -std=c11. Spell pi out with the conventional M_PI literal so
  // the coefficients are unchanged and the file builds in strict mode.
  const double kPi = 3.14159265358979323846;
  for (int j = 0; j < config.win_len; j++) {
    window[j] = 0.5 - 0.5 * cosf(2 * kPi * j / config.win_len);
  }
}
// Magnitude spectrum of one analysis frame.
//
// in:     config.win_len samples of the current frame.
// window: config.win_len window coefficients, applied sample-wise.
// out:    receives config.fft_len / 2 + 1 magnitudes (num_spectra_bins).
//         In the RVV build each magnitude is stored in fixed point as
//         (uint32_t)(magnitude * 2^kSpectraFracBits), saturated to
//         0xFFFFFFFF when the scaled value does not fit.
//
// The windowed frame is zero-padded to config.fft_len and passed to rfft().
// Bin j's real part is then read from frame[j] and its imaginary part from
// frame[fft_len - j], with bins 0 and fft_len / 2 taken as purely real.
// This presumably matches rfft's packed output layout (TODO confirm against
// audio_prep/util).
// If the scratch buffer cannot be allocated, out is filled with zeros.
#ifdef MFCC_WITH_RVV
static void stft_magnitude(float* in, float* window, uint32_t* out) {
#else
static void stft_magnitude(float* in, float* window, float* out) {
#endif
  const int fft_len = config.fft_len;
  const int half = fft_len / 2;
  // calloc supplies the zero padding between win_len and fft_len.
  float* frame = calloc((size_t)fft_len, sizeof *frame);
  if (frame == NULL) {
    for (int j = 0; j <= half; j++) {
      out[j] = 0;
    }
    return;
  }
  // Copy the samples in and apply the window in a single pass.
  for (int j = 0; j < config.win_len; j++) {
    frame[j] = in[j] * window[j];
  }
  rfft(frame, config.fft_order);
  for (int j = 0; j <= half; j++) {
    float mag;
    if (j == 0 || j == half) {
      mag = fabsf(frame[j]);
    } else {
      const float re = frame[j];
      const float im = frame[fft_len - j];
      mag = sqrtf(re * re + im * im);
    }
#ifdef MFCC_WITH_RVV
    const float scaled = mag * (1 << kSpectraFracBits);
    // Converting a float outside the uint32 range (or NaN) to uint32_t is
    // undefined behavior, so saturate instead.
    out[j] = scaled < 4294967296.0f ? (uint32_t)scaled : 0xFFFFFFFFu;
#else
    out[j] = mag;
#endif
  }
  free(frame);
}
// Build the mel filterbank that maps one magnitude spectrum to mel bands.
//
// weights: receives params.num_mel_bins * num_spectra_bins entries, stored
//          row-major. Row i holds the weight of every spectrum bin for mel
//          band i, so band i's energy is the dot product of the spectrum
//          with row i. In the RVV build each weight is stored in fixed point
//          as (uint32_t)(weight * 2^kWeightsFracBits).
//
// Spectrum bin centres are spaced evenly from 0 Hz to Nyquist and mapped to
// mel with hz_to_mel(). The band edges are num_mel_bins + 2 points spaced
// evenly in mel between low_edge_hz and upper_edge_hz. Band i is a triangle:
// it rises from edge i to a peak of 1 at edge i+1 and falls back to 0 at
// edge i+2. Column 0 (the DC bin) is always zero.
// If scratch memory cannot be allocated, every weight is set to zero.
#ifdef MFCC_WITH_RVV
static void spectra_to_mel_matrix(uint32_t* weights) {
#else
static void spectra_to_mel_matrix(float* weights) {
#endif
  MfccParams* params = &config.params;
  const int num_bins = config.num_spectra_bins;
  const int num_mel = params->num_mel_bins;
  // Divide by 2.0f: if audio_samp_rate is an integer type, "/ 2" would
  // truncate odd rates (e.g. 11025 -> 5512 instead of 5512.5).
  const float nyquist_hz = params->audio_samp_rate / 2.0f;
  float* spectra_bins = malloc((size_t)num_bins * sizeof *spectra_bins);
  float* band_edges = malloc((size_t)(num_mel + 2) * sizeof *band_edges);
  if (spectra_bins == NULL || band_edges == NULL) {
    for (int k = 0; k < num_mel * num_bins; k++) {
      weights[k] = 0;
    }
    free(band_edges);
    free(spectra_bins);
    return;
  }
  // Centre frequency of each spectrum bin, in mel.
  linspace(spectra_bins, 0.0, nyquist_hz, num_bins);
  for (int j = 0; j < num_bins; j++) {
    spectra_bins[j] = hz_to_mel(spectra_bins[j]);
  }
  linspace(band_edges, hz_to_mel(params->low_edge_hz),
           hz_to_mel(params->upper_edge_hz), num_mel + 2);
  for (int i = 0; i < num_mel; i++) {
    const float lower = band_edges[i];
    const float center = band_edges[i + 1];
    const float upper = band_edges[i + 2];
    weights[i * num_bins] = 0;  // DC bin never contributes
    for (int j = 1; j < num_bins; j++) {
      const float lower_slope = (spectra_bins[j] - lower) / (center - lower);
      const float upper_slope = (upper - spectra_bins[j]) / (upper - center);
      float w = (lower_slope < upper_slope) ? lower_slope : upper_slope;
      if (w < 0) {
        w = 0;
      }
#ifdef MFCC_WITH_RVV
      weights[i * num_bins + j] = (uint32_t)(w * (1 << kWeightsFracBits));
#else
      weights[i * num_bins + j] = w;
#endif
    }
  }
  free(band_edges);
  free(spectra_bins);
}
// Convert a waveform to a log-magnitude mel-frequency spectrogram.
//
// in:     int16 audio samples. Frame i consumes samples
//         [i * hop_len, (i + 1) * hop_len); samples at or beyond in_len are
//         read as zero.
// out:    params.num_frames * params.num_mel_bins uint8 values, frame-major:
//         out[i * num_mel_bins + j] is mel band j of frame i. Each value is
//         log_scaler * ln(max(band_energy, log_floor)), clamped to [0, 255]
//         and truncated.
// in_len: number of valid samples in `in`.
//
// The sliding analysis buffer starts as zeros, so the first frame is
// implicitly pre-padded with win_len - hop_len zero samples.
// If the configured hop is negative or longer than the window, or a working
// buffer cannot be allocated, `out` is filled with zeros instead.
void extract_mfcc(int16_t* in, uint8_t* out, int in_len) {
  MfccParams* params = &config.params;
  const int win_len = config.win_len;
  const int hop_len = config.hop_len;
  const int num_bins = config.num_spectra_bins;
  const int num_mel = params->num_mel_bins;
  const size_t out_len = (params->num_frames > 0 && num_mel > 0)
                             ? (size_t)params->num_frames * (size_t)num_mel
                             : 0;

  float* window = malloc((size_t)win_len * sizeof *window);
#ifdef MFCC_WITH_RVV
  uint32_t* weights =
      malloc((size_t)num_mel * (size_t)num_bins * sizeof *weights);
  uint32_t* spectra = malloc((size_t)num_bins * sizeof *spectra);
#else
  float* weights = malloc((size_t)num_mel * (size_t)num_bins * sizeof *weights);
  float* spectra = malloc((size_t)num_bins * sizeof *spectra);
#endif
  // Sliding analysis buffer. calloc provides the initial zero pre-padding.
  float* frame = calloc((size_t)win_len, sizeof *frame);

  if (hop_len < 0 || hop_len > win_len || window == NULL || weights == NULL ||
      spectra == NULL || frame == NULL) {
    // Invalid hop/window geometry would make the buffer shift below read and
    // write out of bounds; allocation failure would dereference NULL.
    if (out_len > 0) {
      memset(out, 0, out_len);
    }
  } else {
    hanning(window);  // "periodic" Hann window
    spectra_to_mel_matrix(weights);
    const int keep = win_len - hop_len;  // samples shared by adjacent frames
    for (int i = 0; i < params->num_frames; i++) {
      // Drop the oldest hop of samples and append the next one.
      memmove(frame, frame + hop_len, (size_t)keep * sizeof *frame);
      for (int j = 0; j < hop_len; j++) {
        const int idx = i * hop_len + j;
        frame[keep + j] = idx < in_len ? (float)in[idx] : 0.0f;
      }
      stft_magnitude(frame, window, spectra);
      uint8_t* out_row = out + (size_t)i * (size_t)num_mel;
      for (int j = 0; j < num_mel; j++) {
#ifdef MFCC_WITH_RVV
        const uint32_t acc = dot_product_rvv(
            spectra, weights + (size_t)j * num_bins, num_bins);
        // Remove the fixed-point scaling of both operands.
        float energy =
            (float)acc / (1 << (kSpectraFracBits + kWeightsFracBits));
#else
        float energy =
            dot_product(spectra, weights + (size_t)j * num_bins, num_bins);
#endif
        if (energy < params->log_floor) {
          energy = params->log_floor;
        }
        float scaled = params->log_scaler * logf(energy);
        scaled = scaled < 0.0 ? 0.0 : (scaled > 255.0 ? 255.0 : scaled);
        out_row[j] = (uint8_t)scaled;
      }
    }
  }
  free(window);
  free(weights);
  free(spectra);
  free(frame);
}