Add support for audio pre-processing (MFCC extraction)

This is the complete code change for adding audio pre-processing block (MFCC feature extraction).

The code was re-written in C based on http://google3/audio/dsp/mfcc/mfcc_mel.py, with necessary simplifications.

We also integrated audio pre-processing into daredevil, to allow for end-to-end simulations with .wav audio input.

Numerical correctness has been validated: extracted MFCC matches with
the reference (golden_whistle_spectrogram) and returned output is
correct (Whistling). Unit tests passed.

Profiling result: https://docs.google.com/document/d/1vAgRgZeVyOwIniiIp8RrmEvg3CB_p0PFgB2X9NLN4Fo
The pre-processing block adds a relatively constant number of extra instructions. So the percentage depends on the computational complexity of audio models (for daredevil: 65%)

Change-Id: I662914e68170e5e28384960ad345c3366f9fb627
diff --git a/build_tools/gen_mlmodel_input.py b/build_tools/gen_mlmodel_input.py
index cc9a352..213d508 100755
--- a/build_tools/gen_mlmodel_input.py
+++ b/build_tools/gen_mlmodel_input.py
@@ -21,6 +21,7 @@
 
 import numpy as np
 from PIL import Image
+from scipy.io import wavfile
 
 
 parser = argparse.ArgumentParser(
@@ -38,10 +39,12 @@
 args = parser.parse_args()
 
 
-def write_binary_file(file_path, input, is_quant):
+def write_binary_file(file_path, input, is_quant, is_audio):
     with open(file_path, "wb+") as file:
         for d in input:
-            if is_quant:
+            if is_audio:
+                file.write(struct.pack("<h", d))
+            elif is_quant:
                 file.write(struct.pack("<B", d))
             else:
                 file.write(struct.pack("<f", d))
@@ -53,10 +56,15 @@
     if len(input_shape) < 3:
         raise ValueError("Input shape < 3 dimensions")
     input_ext = os.path.splitext(input_name)[1]
+    is_audio = False
     if (not input_ext) or (input_ext == '.bin'):
         with open(input_name, mode='rb') as f:
             input = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32)
             input = input[:np.prod(input_shape)].reshape(np.prod(input_shape))
+    elif (input_ext == '.wav'):
+        is_audio = True
+        _, input = wavfile.read(input_name)
+        input = input[:np.prod(input_shape)].reshape(np.prod(input_shape))
     else:
         resized_img = Image.open(input_name).resize(
             (input_shape[1], input_shape[2]))
@@ -65,7 +73,7 @@
             low = np.min(float_input_range)
             high = np.max(float_input_range)
             input = (high - low) * input / 255.0 + low
-    write_binary_file(output_file, input, is_quant)
+    write_binary_file(output_file, input, is_quant, is_audio)
 
 
 if __name__ == '__main__':
diff --git a/samples/audio_prep/CMakeLists.txt b/samples/audio_prep/CMakeLists.txt
new file mode 100644
index 0000000..1991ad3
--- /dev/null
+++ b/samples/audio_prep/CMakeLists.txt
@@ -0,0 +1,33 @@
+iree_cc_library(
+  NAME
+    util
+  HDRS
+    "util.h"
+  SRCS
+    "util.c"
+  DEPS
+    "m"
+)
+
+iree_cc_library(
+  NAME
+    mfcc
+  HDRS
+    "mfcc.h"
+  SRCS
+    "mfcc.c"
+  DEPS
+    ::util
+)
+
+iree_cc_binary(
+  NAME
+    mfcc_test
+  SRCS
+    "mfcc_test.cc"
+  DEPS
+    ::mfcc
+    pw_unit_test
+    pw_unit_test.main
+    pw_assert_basic
+)
diff --git a/samples/audio_prep/mfcc.c b/samples/audio_prep/mfcc.c
new file mode 100644
index 0000000..bc01d8b
--- /dev/null
+++ b/samples/audio_prep/mfcc.c
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Audio preprocessing: MLCC feature extraction
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "samples/audio_prep/mfcc.h"
+#include "samples/audio_prep/util.h"
+
+// config struct
+typedef struct {
+  MfccParams params;
+  int win_len;
+  int hop_len;
+  int fft_order;
+  int fft_len;
+  int num_spectra_bins;
+} MfccConfig;
+
+static MfccConfig config = {.params.num_frames = 96,
+                            .params.num_mel_bins = 64,
+                            .params.audio_samp_rate = 16000,
+                            .params.low_edge_hz = 125,
+                            .params.upper_edge_hz = 7500,
+                            .params.win_len_sec = 0.025,
+                            .params.hop_len_sec = 0.010,
+                            .params.log_floor = 0.01,
+                            .params.log_scaler = 20,
+                            .win_len = 400,
+                            .hop_len = 160,
+                            .fft_order = 9,
+                            .fft_len = 512,
+                            .num_spectra_bins = 257};
+
+// set mfcc parameters
+void set_mfcc_params(MfccParams* in_params) {
+  config.params = *in_params;
+  config.win_len =
+      (int)(config.params.audio_samp_rate * config.params.win_len_sec + 0.5);
+  config.hop_len =
+      (int)(config.params.audio_samp_rate * config.params.hop_len_sec + 0.5);
+  config.fft_order = ceilf(log2f(config.win_len));
+  config.fft_len = 1 << config.fft_order;  // 512
+  config.num_spectra_bins = config.fft_len / 2 + 1;
+}
+
+// Convert frequencies to mel scale using HTK formula
+static float hz_to_mel(float freq_hz) {
+  const float kMelBreakFreqHz = 700.0;
+  const float kMelHighFreqQ = 1127.0;
+  return kMelHighFreqQ * logf(1.0 + (freq_hz / kMelBreakFreqHz));
+}
+
+// Compute Hanning window coefficients
+static void hanning(float* window) {
+  for (int j = 0; j < config.win_len; j++) {
+    window[j] = 0.5 - 0.5 * cosf(2 * M_PI * j / config.win_len);
+  }
+}
+
+// Calculate short-time Fourier transform magnitude for one frame
+// output shape: num_spectra_bins
+static void stft_magnitude(float* in, float* window, float* out) {
+  float* frame = (float*)malloc(config.fft_len * sizeof(float));
+  memset(frame, 0, config.fft_len * sizeof(float));
+  memcpy(frame, in, config.win_len * sizeof(float));
+
+  // apply hanning window
+  for (int j = 0; j < config.win_len; j++) {
+    frame[j] *= window[j];
+  }
+
+  // real-valued FFT
+  rfft(frame, config.fft_order);
+
+  // compute STFT magnitude
+  out[0] = frame[0] > 0 ? frame[0] : -frame[0];
+  out[config.fft_len / 2] = frame[config.fft_len / 2] > 0
+                                ? frame[config.fft_len / 2]
+                                : -frame[config.fft_len / 2];
+  for (int j = 1; j < config.fft_len / 2; j++) {
+    out[j] = sqrtf(frame[j] * frame[j] +
+                   frame[config.fft_len - j] * frame[config.fft_len - j]);
+  }
+
+  free(frame);
+}
+
+// Return a matrix that can post-multiply spectrogram rows to make mel
+// output shape: params.num_mel_bins * num_spectra_bins
+static void spectra_to_mel_matrix(float* weights) {
+  MfccParams* params = &config.params;
+  float nyquist_hz = params->audio_samp_rate / 2;
+  float* spectra_bins = (float*)malloc(config.num_spectra_bins * sizeof(float));
+  linspace(spectra_bins, 0.0, nyquist_hz, config.num_spectra_bins);
+  for (int i = 0; i < config.num_spectra_bins; i++) {
+    spectra_bins[i] = hz_to_mel(spectra_bins[i]);
+  }
+
+  float* band_edges =
+      (float*)malloc((params->num_mel_bins + 2) * sizeof(float));
+  linspace(band_edges, hz_to_mel(params->low_edge_hz),
+           hz_to_mel(params->upper_edge_hz), params->num_mel_bins + 2);
+
+  float lower = 0.0, center = 0.0, upper = 0.0;
+  float lower_slope = 0.0, upper_slope = 0.0;
+  for (int i = 0; i < params->num_mel_bins; i++) {
+    // spectrogram DC bin
+    weights[i * config.num_spectra_bins] = 0.0;
+
+    lower = band_edges[i];
+    center = band_edges[i + 1];
+    upper = band_edges[i + 2];
+    for (int j = 1; j < config.num_spectra_bins; j++) {
+      lower_slope = (spectra_bins[j] - lower) / (center - lower);
+      upper_slope = (upper - spectra_bins[j]) / (upper - center);
+      float clamp = (lower_slope < upper_slope) ? lower_slope : upper_slope;
+      clamp = (clamp < 0) ? 0 : clamp;
+      weights[i * config.num_spectra_bins + j] = clamp;
+    }
+  }
+
+  free(band_edges);
+  free(spectra_bins);
+}
+
+// Convert waveform to a log magnitude mel-frequency spectrogram
+// input: audio samples (int16) with params.num_frames * hop_len samples
+// zero pre-padding win_len - hop_len samples
+// output shape: params.num_frames * params.num_mel_bins (uint8)
+void extract_mfcc(int16_t* in, uint8_t* out, int in_len) {
+  MfccParams* params = &config.params;
+  // Calculate a "periodic" Hann window
+  float* window = (float*)malloc(config.win_len * sizeof(float));
+  hanning(window);
+
+  // Compute weights
+  float* weights = (float*)malloc(params->num_mel_bins *
+                                  config.num_spectra_bins * sizeof(float));
+  spectra_to_mel_matrix(weights);
+
+  float* frame = (float*)malloc(config.win_len * sizeof(float));
+  memset(frame, 0, config.win_len * sizeof(float));
+  float* spectra = (float*)malloc(config.num_spectra_bins * sizeof(float));
+
+  for (int i = 0; i < params->num_frames; i++) {
+    // update buffer
+    for (int j = 0; j < config.win_len - config.hop_len; j++) {
+      frame[j] = frame[j + config.hop_len];
+    }
+    // feed in new samples
+    for (int j = 0; j < config.hop_len; j++) {
+      int idx = i * config.hop_len + j;
+      frame[config.win_len - config.hop_len + j] =
+          idx < in_len ? (float)in[idx] : 0.0;
+    }
+
+    // compute STFT magnitude
+    stft_magnitude(frame, window, spectra);
+
+    // compute MFCC
+    for (int j = 0; j < params->num_mel_bins; j++) {
+      float temp = dot_product(spectra, weights + j * config.num_spectra_bins,
+                               config.num_spectra_bins);
+      if (temp < params->log_floor) temp = params->log_floor;
+      temp = params->log_scaler * logf(temp);
+      temp = temp < 0.0 ? 0.0 : (temp > 255.0 ? 255.0 : temp);
+      out[i * params->num_mel_bins + j] = (uint8_t)temp;
+    }
+  }
+
+  free(window);
+  free(weights);
+  free(spectra);
+  free(frame);
+}
diff --git a/samples/audio_prep/mfcc.h b/samples/audio_prep/mfcc.h
new file mode 100644
index 0000000..97cbab6
--- /dev/null
+++ b/samples/audio_prep/mfcc.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SAMPLES_AUDIO_PREP_MFCC_H_
+#define SAMPLES_AUDIO_PREP_MFCC_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct {
+  int num_frames;
+  int num_mel_bins;
+  float audio_samp_rate;
+  float low_edge_hz;
+  float upper_edge_hz;
+  float win_len_sec;
+  float hop_len_sec;
+  float log_floor;
+  float log_scaler;
+} MfccParams;
+
+void set_mfcc_params(MfccParams* params);
+
+void extract_mfcc(int16_t* in, uint8_t* out, int in_len);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // SAMPLES_AUDIO_PREP_MFCC_H_
diff --git a/samples/audio_prep/mfcc_test.cc b/samples/audio_prep/mfcc_test.cc
new file mode 100644
index 0000000..659e90b
--- /dev/null
+++ b/samples/audio_prep/mfcc_test.cc
@@ -0,0 +1,120 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pw_unit_test/framework.h"
+#include "samples/audio_prep/mfcc.h"
+
+static MfccParams golden_params = {.num_frames = 3,
+                                   .num_mel_bins = 10,
+                                   .audio_samp_rate = 16000,
+                                   .low_edge_hz = 125,
+                                   .upper_edge_hz = 7600,
+                                   .win_len_sec = 0.000625,
+                                   .hop_len_sec = 0.0003125,
+                                   .log_floor = 0.01,
+                                   .log_scaler = 20};
+static int16_t golden_input[] = {7954,   10085, 8733,   10844,  29949,
+                                 -549,   20833, 30345,  18086,  11375,
+                                 -27309, 12323, -22891, -23360, 11958};
+static uint8_t golden_output[] = {0, 0, 191, 187, 174, 186, 179, 175, 179, 173,
+                                  0, 0, 209, 205, 181, 193, 182, 199, 209, 216,
+                                  0, 0, 198, 194, 192, 204, 196, 201, 212, 223};
+
+// Golden value is derived from
+// http://google3/knowledge/cerebra/parrot/port/android/platform/
+// inference/daredevil/audio_prep/spectrogram_processor_test.cc
+TEST(MfccTest, AgreesWithGoldenValues) {
+  // set parameters
+  set_mfcc_params(&golden_params);
+  int out_len = golden_params.num_frames * golden_params.num_mel_bins;
+  uint8_t* out = reinterpret_cast<uint8_t*>(malloc(out_len * sizeof(uint8_t)));
+  // extract MFCC
+  extract_mfcc(golden_input, out, sizeof(golden_input) / sizeof(int16_t));
+
+  for (int i = 0; i < out_len; i++) {
+    ASSERT_EQ(out[i], golden_output[i]);
+  }
+  free(out);
+}
+
+TEST(MfccTest, DcInputSaneResult) {
+  MfccParams params = {.num_frames = 10,
+                       .num_mel_bins = 64,
+                       .audio_samp_rate = 16000,
+                       .low_edge_hz = 125,
+                       .upper_edge_hz = 7500,
+                       .win_len_sec = 0.032,
+                       .hop_len_sec = 0.016,
+                       .log_floor = 0.01,
+                       .log_scaler = 20};
+
+  // set parameters
+  set_mfcc_params(&params);
+
+  int hop_len = static_cast<int>(params.audio_samp_rate * params.hop_len_sec);
+  int in_len = hop_len * params.num_frames;
+  int16_t* in = reinterpret_cast<int16_t*>(
+      malloc(params.num_frames * hop_len * sizeof(int16_t)));
+  int out_len = params.num_frames * params.num_mel_bins;
+  uint8_t* out = reinterpret_cast<uint8_t*>(malloc(out_len * sizeof(uint8_t)));
+
+  // DC Input
+  memset(in, 255, in_len * sizeof(int16_t));
+  // extract MFCC
+  extract_mfcc(in, out, in_len);
+  // ignore the 1st frame due to pre zero-padding
+  // expect zero outputs
+  for (int i = params.num_mel_bins; i < out_len; i++) {
+    ASSERT_EQ(out[i], 0);
+  }
+
+  free(in);
+  free(out);
+}
+
+TEST(MfccTest, NyquistFreqInputSaneResult) {
+  MfccParams params = {.num_frames = 15,
+                       .num_mel_bins = 32,
+                       .audio_samp_rate = 16000,
+                       .low_edge_hz = 125,
+                       .upper_edge_hz = 7500,
+                       .win_len_sec = 0.016,
+                       .hop_len_sec = 0.008,
+                       .log_floor = 0.01,
+                       .log_scaler = 20};
+
+  // set parameters
+  set_mfcc_params(&params);
+
+  int hop_len = static_cast<int>(params.audio_samp_rate * params.hop_len_sec);
+  int in_len = hop_len * params.num_frames;
+  int16_t* in = reinterpret_cast<int16_t*>(
+      malloc(params.num_frames * hop_len * sizeof(int16_t)));
+  int out_len = params.num_frames * params.num_mel_bins;
+  uint8_t* out = reinterpret_cast<uint8_t*>(malloc(out_len * sizeof(uint8_t)));
+
+  // High (Nyquest) frequency Input
+  memset(in, 255, in_len * sizeof(int16_t));
+  for (int i = 1; i < in_len; i += 2) in[i] *= -1;
+  // extract MFCC
+  extract_mfcc(in, out, in_len);
+  // ignore the 1st frame due to pre zero-padding
+  // expect zero outputs
+  for (int i = params.num_mel_bins; i < out_len; i++) {
+    ASSERT_EQ(out[i], 0);
+  }
+
+  free(in);
+  free(out);
+}
diff --git a/samples/audio_prep/mfcc_test.txt b/samples/audio_prep/mfcc_test.txt
new file mode 100644
index 0000000..6b595fd
--- /dev/null
+++ b/samples/audio_prep/mfcc_test.txt
@@ -0,0 +1,3 @@
+// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/audio_prep/mfcc_test 2>&1 | tee %t
+// RUN: cat %t | FileCheck %s
+// CHECK: {{PASSED}}
diff --git a/samples/audio_prep/util.c b/samples/audio_prep/util.c
new file mode 100644
index 0000000..0b3ffad
--- /dev/null
+++ b/samples/audio_prep/util.c
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+
+#include "samples/audio_prep/util.h"
+
+#ifndef M_SQRT2
+#define M_SQRT2 1.41421356237309504880
+#endif
+
+// Evenly linear spaced array
+void linspace(float* x, float start, float end, int n) {
+  float step = (end - start) / (n - 1);
+  for (int i = 0; i < n; i++) {
+    x[i] = start + i * step;
+  }
+}
+
+// Calculate the dot product of two vectors
+float dot_product(float* v, float* u, int n) {
+  float result = 0.0;
+  for (int i = 0; i < n; i++) {
+    result += v[i] * u[i];
+  }
+  return result;
+}
+
+/*---------------------------------------------------------------------------
+ * FUNCTION NAME: rfft
+ *
+ * PURPOSE:       Real valued, in-place split-radix FFT
+ *
+ * INPUT:
+ *   x            Pointer to input and output array
+ *   m            2^m = n is the Length of FFT
+ *
+ * OUTPUT         Output order
+ *                  Re(0), Re(1), ..., Re(n/2), Im(N/2-1), ..., Im(1)
+ *
+ * RETURN VALUE
+ *   none
+ *
+ * DESIGN REFERENCE:
+ *                IEEE Transactions on Acoustic, Speech, and Signal Processing,
+ *                Vol. ASSP-35. No. 6, June 1987, pp. 849-863.
+ *
+ *                Subroutine adapted from fortran routine pp. 858-859.
+ *                Note corrected printing errors on page 859:
+ *                    SS1 = SIN(A3) -> should be SS1 = SIN(A);
+ *                    CC3 = COS(3)  -> should be CC3 = COS(A3)
+ *
+ *---------------------------------------------------------------------------*/
+
+void rfft(float* x, int m) {
+  int n = 1 << m;
+  int j, i, k, is, id;
+  int i0, i1, i2, i3, i4, i5, i6, i7, i8;
+  int n2, n4, n8;
+  float xt, a0, e, a, a3;
+  float t1, t2, t3, t4, t5, t6;
+  float cc1, ss1, cc3, ss3;
+  float* r0;
+
+  /* Digit reverse counter */
+
+  j = 0;
+  r0 = x;
+
+  for (i = 0; i < n - 1; i++) {
+    if (i < j) {
+      xt = x[j];
+      x[j] = *r0;
+      *r0 = xt;
+    }
+    r0++;
+
+    k = n >> 1;
+
+    while (k <= j) {
+      j = j - k;
+      k >>= 1;
+    }
+    j += k;
+  }
+
+  /* Length two butterflies */
+  is = 0;
+  id = 4;
+
+  while (is < n - 1) {
+    for (i0 = is; i0 < n; i0 += id) {
+      i1 = i0 + 1;
+      a0 = x[i0];
+      x[i0] += x[i1];
+      x[i1] = a0 - x[i1];
+    }
+
+    is = (id << 1) - 2;
+    id <<= 2;
+  }
+
+  /* L shaped butterflies */
+  n2 = 2;
+  for (k = 1; k < m; k++) {
+    n2 <<= 1;
+    n4 = n2 >> 2;
+    n8 = n2 >> 3;
+    e = (M_PI * 2) / n2;
+    is = 0;
+    id = n2 << 1;
+    while (is < n) {
+      for (i = is; i <= n - 1; i += id) {
+        i1 = i;
+        i2 = i1 + n4;
+        i3 = i2 + n4;
+        i4 = i3 + n4;
+        t1 = x[i4] + x[i3];
+        x[i4] -= x[i3];
+        x[i3] = x[i1] - t1;
+        x[i1] += t1;
+
+        if (n4 != 1) {
+          i1 += n8;
+          i2 += n8;
+          i3 += n8;
+          i4 += n8;
+          t1 = (x[i3] + x[i4]) / M_SQRT2;
+          t2 = (x[i3] - x[i4]) / M_SQRT2;
+          x[i4] = x[i2] - t1;
+          x[i3] = -x[i2] - t1;
+          x[i2] = x[i1] - t2;
+          x[i1] = x[i1] + t2;
+        }
+      }
+      is = (id << 1) - n2;
+      id <<= 2;
+    }
+
+    for (j = 1; j < n8; j++) {
+      a = j * e;
+      a3 = 3 * a;
+      cc1 = cosf(a);
+      ss1 = sinf(a);
+      cc3 = cosf(a3);
+      ss3 = sinf(a3);
+
+      is = 0;
+      id = n2 << 1;
+
+      while (is < n) {
+        for (i = is; i <= n - 1; i += id) {
+          i1 = i + j;
+          i2 = i1 + n4;
+          i3 = i2 + n4;
+          i4 = i3 + n4;
+          i5 = i + n4 - j;
+          i6 = i5 + n4;
+          i7 = i6 + n4;
+          i8 = i7 + n4;
+          t1 = x[i3] * cc1 + x[i7] * ss1;
+          t2 = x[i7] * cc1 - x[i3] * ss1;
+          t3 = x[i4] * cc3 + x[i8] * ss3;
+          t4 = x[i8] * cc3 - x[i4] * ss3;
+          t5 = t1 + t3;
+          t6 = t2 + t4;
+          t3 = t1 - t3;
+          t4 = t2 - t4;
+          t2 = x[i6] + t6;
+          x[i3] = t6 - x[i6];
+          x[i8] = t2;
+          t2 = x[i2] - t3;
+          x[i7] = -x[i2] - t3;
+          x[i4] = t2;
+          t1 = x[i1] + t5;
+          x[i6] = x[i1] - t5;
+          x[i1] = t1;
+          t1 = x[i5] + t4;
+          x[i5] = x[i5] - t4;
+          x[i2] = t1;
+        }
+        is = (id << 1) - n2;
+        id <<= 2;
+      }
+    }
+  }
+}
diff --git a/samples/audio_prep/util.h b/samples/audio_prep/util.h
new file mode 100644
index 0000000..ade1610
--- /dev/null
+++ b/samples/audio_prep/util.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SAMPLES_AUDIO_PREP_UTIL_H_
+#define SAMPLES_AUDIO_PREP_UTIL_H_
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+// Evenly linear spaced array
+void linspace(float* x, float start, float end, int n);
+
+// Calculate the dot product of two vectors
+float dot_product(float* v, float* u, int n);
+
+/*---------------------------------------------------------------------------
+ * FUNCTION NAME: rfft
+ *
+ * PURPOSE:       Real valued, in-place split-radix FFT
+ *
+ * INPUT:
+ *   x            Pointer to input and output array
+ *   m            2^m = n is the Length of FFT
+ *
+ * OUTPUT         Output order
+ *                  Re(0), Re(1), ..., Re(n/2), Im(N/2-1), ..., Im(1)
+ *
+ * RETURN VALUE
+ *   none
+ *
+ *---------------------------------------------------------------------------*/
+void rfft(float* x, int m);
+
+#endif  // SAMPLES_AUDIO_PREP_UTIL_H_
diff --git a/samples/quant_model/CMakeLists.txt b/samples/quant_model/CMakeLists.txt
index 6fd6855..56d3b1e 100644
--- a/samples/quant_model/CMakeLists.txt
+++ b/samples/quant_model/CMakeLists.txt
@@ -165,9 +165,9 @@
   NAME
     daredevil_quant_input
   SHAPE
-    "1, 96, 64"
+    "1, 15360, 1"
   SRC
-    "$ENV{ROOTDIR}/ml/ml-models/test_data/golden_whistle_spectrogram"
+    "$ENV{ROOTDIR}/ml/ml-models/test_data/golden_whistle.wav"
   QUANT
 )
 
@@ -340,6 +340,7 @@
     ::daredevil_bytecode_module_static_c
     ::daredevil_quant_input_c
     iree::vm::bytecode_module
+    samples::audio_prep::mfcc
     samples::util::util
   LINKOPTS
     "LINKER:--defsym=__itcm_length__=1M"
@@ -355,6 +356,7 @@
     ::daredevil_c_module_static_c
     ::daredevil_c_module_static_emitc
     ::daredevil_quant_input_c
+    samples::audio_prep::mfcc
     samples::util::util
   LINKOPTS
     "LINKER:--defsym=__itcm_length__=1M"
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c
index df78077..5601ebe 100644
--- a/samples/quant_model/daredevil.c
+++ b/samples/quant_model/daredevil.c
@@ -23,6 +23,7 @@
 #include "daredevil.h"
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
+#include "samples/audio_prep/mfcc.h"
 #include "samples/util/util.h"
 
 // Compiled module embedded here to avoid file IO:
@@ -73,12 +74,16 @@
 
 iree_status_t load_input_data(const MlModel *model, void **buffer,
                               iree_const_byte_span_t **byte_span) {
+  iree_status_t result = alloc_input_buffer(model, buffer);
+  int16_t *input = (int16_t *)daredevil_quant_input;
+  uint8_t *output = (uint8_t *)buffer[0];
+  extract_mfcc(input, output, sizeof(daredevil_quant_input) / sizeof(int16_t));
+
   byte_span[0] = malloc(sizeof(iree_const_byte_span_t));
   *byte_span[0] = iree_make_const_byte_span(
-      daredevil_quant_input,
-      model->input_size_bytes[0] * model->input_length[0]);
+      buffer[0], model->input_size_bytes[0] * model->input_length[0]);
 
-  return iree_ok_status();
+  return result;
 }
 
 iree_status_t process_output(const MlModel *model,