Use scratch memory instead of perm buffer for the RFFT kernel's work area. (#2542)
BUG=333876877
diff --git a/signal/micro/kernels/rfft.cc b/signal/micro/kernels/rfft.cc
index 2ee8e7a..c9472b0 100644
--- a/signal/micro/kernels/rfft.cc
+++ b/signal/micro/kernels/rfft.cc
@@ -48,6 +48,7 @@
int32_t output_length;
TfLiteType fft_type;
T* work_area;
+ int scratch_buffer_index;
int8_t* state;
};
@@ -65,9 +66,6 @@
params->fft_length = fbw.ElementAsInt32(kFftLengthIndex);
params->fft_type = typeToTfLiteType<T>();
- params->work_area = static_cast<T*>(context->AllocatePersistentBuffer(
- context, params->fft_length * sizeof(T)));
-
size_t state_size = (*get_needed_memory_func)(params->fft_length);
params->state = static_cast<int8_t*>(
context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t)));
@@ -103,6 +101,8 @@
params->output_length =
output_shape.Dims(output_shape.DimensionsCount() - 1) / 2;
+ context->RequestScratchBufferInArena(context, params->fft_length * sizeof(T),
+ ¶ms->scratch_buffer_index);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
@@ -122,15 +122,17 @@
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
Complex<T>* output_data = tflite::micro::GetTensorData<Complex<T>>(output);
+ T* work_area = static_cast<T*>(
+ context->GetScratchBuffer(context, params->scratch_buffer_index));
+
for (int input_idx = 0, output_idx = 0; input_idx < params->input_size;
input_idx += params->input_length, output_idx += params->output_length) {
- memcpy(params->work_area, &input_data[input_idx],
- sizeof(T) * params->input_length);
+ memcpy(work_area, &input_data[input_idx], sizeof(T) * params->input_length);
// Zero pad input to FFT length
- memset(¶ms->work_area[params->input_length], 0,
+ memset(&work_area[params->input_length], 0,
sizeof(T) * (params->fft_length - params->input_length));
- (*apply_func)(params->state, params->work_area, &output_data[output_idx]);
+ (*apply_func)(params->state, work_area, &output_data[output_idx]);
}
return kTfLiteOk;
}