blob: 5fd2acac66e9b8a40f0794908b3ad90ac04ccc6e [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
namespace tensorflow {
namespace signal {
// Shape function for the SignalStacker op.
//
// Checks that input 0 is a rank-1 tensor whose only dimension equals the
// `num_channels` attribute, then sets:
//   output 0: rank-1 with
//             num_channels * (stacker_left_context + 1 + stacker_right_context)
//             elements,
//   output 1: a scalar.
//
// Returns the first error produced by an attribute lookup or a shape check,
// otherwise OK.
absl::Status StackerShape(InferenceContext* c) {
  int num_channels = 0;
  int stacker_left_context = 0;
  int stacker_right_context = 0;
  TF_RETURN_IF_ERROR(c->GetAttr<int>("num_channels", &num_channels));
  TF_RETURN_IF_ERROR(
      c->GetAttr<int>("stacker_left_context", &stacker_left_context));
  TF_RETURN_IF_ERROR(
      c->GetAttr<int>("stacker_right_context", &stacker_right_context));

  // Widen to 64 bits before doing any arithmetic. The attributes are only
  // bounded from below by the op registration, so summing and multiplying them
  // as `int` could overflow (undefined behaviour) and hand MakeDim a bogus
  // size. The product of an `int` and the widened sum always fits in int64_t.
  const int64_t output_frames =
      static_cast<int64_t>(stacker_left_context) + 1 + stacker_right_context;
  const int64_t stacked_size =
      static_cast<int64_t>(num_channels) * output_frames;

  // The input must be a vector holding exactly num_channels values.
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
  // WithValue needs an output handle; only the returned status is used here.
  shape_inference::DimensionHandle dim_in;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(out, 0), num_channels, &dim_in));

  // The stacked output keeps rank 1 but its length covers every stacked frame.
  TF_RETURN_IF_ERROR(c->ReplaceDim(out, 0, c->MakeDim(stacked_size), &out));
  c->set_output(0, out);
  c->set_output(1, c->Scalar());
  return absl::OkStatus();
}
// Registers the "SignalStacker" op. It takes one int16 input and produces an
// int16 output plus a boolean `output_valid` flag; the Doc text below holds the
// full user-facing contract. Shape inference is delegated to StackerShape().
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalStacker")
// Integer attributes; the "int >= N" form sets each one's minimum allowed
// value.
.Attr("num_channels: int >= 1")
.Attr("stacker_left_context: int >= 0")
.Attr("stacker_right_context: int >= 0")
.Attr("stacker_step: int >= 1")
// One int16 input, one int16 output and a boolean validity flag.
.Input("input: int16")
.Output("output: int16")
.Output("output_valid: bool")
// Forward shape inference to the StackerShape function defined above.
.SetShapeFn([](InferenceContext* c) { return StackerShape(c); })
// Op description, followed by one entry per input, output and attribute.
.Doc(R"doc(
Stack several input frames into a single stacked frame. On each invocation, it
generates a stacked frame that contains:
(stacker_left_context + 1 + stacker_right_context)
consecutive unstacked frames. The stacked frame which becomes the input to the
neural network. The stacker then moves forward in a step of one or more input
frames.
For example, assuming a squence of 10 input frames, where each input frame
is itself a vector of size num_channels:
[1 2 3 4 5 6 7 8 9 0], and the current frame is 9, the following configuration:
stacker_left_context=0
stacker_right_context=1
stacker_step=2
will produce 5 stacked frames:
[1,2] [3,4] [5,6] [7,8] [9,0].
input: A 1-D input frame of size num_channels
output: A 1-D output frame of size
(stacker_left_context + 1 + stacker_right_context) * num_channels
output_valid: A boolean scalar.
If true, the output is a valid output frame
If false, the output is an invalid output frame of all zeros
Once the stacker produces its first output frame, its output will
be valid every stacker_step input frames.
num_channels: the number of filter bank channels in each stacker input frame
stacker_left_context: The number of input frames to the left of the current
frame to include in the output frame.
stacker_right_context: The number of input frames to the right of the current
frame to include in the output frame.
stacker_step: The number of input frames to increment the stacker's 'current'
frame on each invocation.
)doc");
} // namespace signal
} // namespace tensorflow