blob: 48043e1d8ec2525053422ba54f1e3f1ed68797d4 [file] [log] [blame]
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use window op in python."""
import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
gen_window_op = util.load_custom_op('window_op.so')
def hann_window_weights(window_length, shift, dtype=np.int16):
arg = np.pi * 2 / window_length
index = np.arange(window_length)
weights = (0.5 - (0.5 * np.cos(arg * (index + 0.5))))
if dtype == np.int16:
weights = np.round(weights * (2**shift))
return weights.astype(dtype=dtype)
# We can calculate the result of sqrt(0.5-cos(x)/2) without sqrt as sin(x/2).
def square_root_hann_window_weights(window_length, shift, dtype=np.int16):
arg_half = np.pi / window_length
index = np.arange(window_length)
weights = np.sin(arg_half * (index + 0.5))
if dtype == np.int16:
weights = np.round(weights * (2**shift))
return weights.astype(dtype=dtype)
# In the so-called weighted overlap add (WOLA) method, a second window would
# be applied after the inverse FFT and prior to the final overlap-add to
# generate the output signal. This second window is known as a synthesis
# window and it is commonly chosen to be the same as the first window (which
# is applied before the FFT and is known as an analysis window). The pair
# of windows need to be normalized such that they together meet the constant
# WOLA (CWOLA) constraint. So if a signal goes through this procedure, it can
# be reconstructed with little distortion. For the square-root Hann window
# implemented above, the normalizing constant is given by
# sqrt((window_length / (2 * window_step)).
def square_root_hann_cwola_window_weights(window_length,
window_step,
shift,
dtype=np.int16):
arg_half = np.pi / window_length
norm = np.sqrt(window_length / (2.0 * window_step))
index = np.arange(window_length)
weights = np.sin(arg_half * (index + 0.5)) / norm
if dtype == np.int16:
weights = np.round(weights * (2**shift))
return weights.astype(dtype=dtype)
def _window_wrapper(window_fn, default_name):
"""Wrapper around gen_window_op.window*."""
def _window(input_tensor, weight_tensor, shift, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=np.int16)
input_dim_list = input_tensor.shape.as_list()
weight_tensor = tf.convert_to_tensor(weight_tensor)
weight_dim_list = weight_tensor.shape.as_list()
if input_dim_list[-1] != weight_dim_list[0]:
raise ValueError("Innermost input dimension must match weights size")
return window_fn(input_tensor, weight_tensor, shift=shift, name=name)
return _window
# TODO(b/286250473): change back name to "window" after name clash resolved
window = _window_wrapper(gen_window_op.signal_window, "signal_window")
tf.no_gradient("signal_window")