blob: 9631b4c12a19ad3b2b9fb865e0cf9287aabac442 [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/lstm_eval.h"
namespace tflite {
// Deduce the size information (Batch (B), Time Steps (T), Input dimension (I),
// State dimension (S)) that defines the LSTM using the input and hidden state
// tensor
LstmSizeInfo CreateLstmSizeInfo(
const bool time_major, const TfLiteIntArray* input_tensor_shape,
const TfLiteIntArray* hidden_state_tensor_shape) {
LstmSizeInfo size_info;
size_info.time_major = time_major;
size_info.batch_size =
time_major ? input_tensor_shape->data[1] : input_tensor_shape->data[0];
size_info.time_steps =
time_major ? input_tensor_shape->data[0] : input_tensor_shape->data[1];
size_info.input_dimension = input_tensor_shape->data[2];
size_info.state_dimension = hidden_state_tensor_shape->data[1];
return size_info;
}
TfLiteStatus ValidateWeightTensorSize(TfLiteContext* context,
const TfLiteTensor* tensor, int dim1_size,
int dim2_size) {
TF_LITE_ENSURE_EQ(context, tensor->dims->size, 2);
TF_LITE_ENSURE_EQ(context, tensor->dims->data[0], dim1_size);
TF_LITE_ENSURE_EQ(context, tensor->dims->data[1], dim2_size);
return kTfLiteOk;
}
TfLiteStatus ValidateBiasTensorSize(TfLiteContext* context,
const TfLiteTensor* tensor, int size) {
TF_LITE_ENSURE_EQ(context, tensor->dims->size, 1);
TF_LITE_ENSURE_EQ(context, tensor->dims->data[0], size);
return kTfLiteOk;
}
// Go through every tensors and make sure their shape match the kernel
// configuration
TfLiteStatus ValidateTensorSize(TfLiteContext* context,
const LstmTensors& tensors,
const LstmSizeInfo& size_info) {
// Input FC weights
for (size_t i = 1; i < 5; i++) {
TF_LITE_ENSURE_OK(
context, ValidateWeightTensorSize(context, tensors.GetInternalTensor(i),
size_info.state_dimension,
size_info.input_dimension));
}
// Recurrent FC weights
for (size_t i = 5; i < 9; i++) {
TF_LITE_ENSURE_OK(
context, ValidateWeightTensorSize(context, tensors.GetInternalTensor(i),
size_info.state_dimension,
size_info.state_dimension));
}
// Biases
for (size_t i = 12; i < 16; i++) {
TF_LITE_ENSURE_OK(
context, ValidateBiasTensorSize(context, tensors.GetInternalTensor(i),
size_info.state_dimension));
}
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
// correct.
TF_LITE_ENSURE_EQ(context, NumElements(tensors.HiddenStateTensor()),
size_info.batch_size * size_info.state_dimension);
TF_LITE_ENSURE_EQ(context, NumElements(tensors.CellStateTensor()),
size_info.batch_size * size_info.state_dimension);
// Check the shape of output tensor against that of input tensor
TF_LITE_ENSURE_EQ(context, tensors.OutputTensor()->dims->size, 3);
TF_LITE_ENSURE_EQ(context,
tensors.GetInternalTensor(kLstmInputTensor)->dims->data[0],
tensors.OutputTensor()->dims->data[0]);
TF_LITE_ENSURE_EQ(context,
tensors.GetInternalTensor(kLstmInputTensor)->dims->data[1],
tensors.OutputTensor()->dims->data[1]);
TF_LITE_ENSURE_EQ(context, tensors.OutputTensor()->dims->data[2],
size_info.state_dimension);
return kTfLiteOk;
}
// Wrapper function to create gate parameters for the four internal LSTM gates
TfLiteStatus CreateGateParams(
TfLiteContext* context,
/*Input tensors*/
const TfLiteTensor* input, const TfLiteTensor* input_weight,
const TfLiteTensor* input_bias,
/*Hidden state tensors*/
const TfLiteTensor* hidden_state, const TfLiteTensor* hidden_state_weight,
const TfLiteTensor* hidden_state_bias,
/*Scale of the fc output (input to non-linear activation)*/
const float nonlinear_activation_input_scale, const TfLiteType cell_type,
tflite::GateParameters& gate_params) {
// A temp tflite tensor to represent the output of fc operation. Only the data
// type and quantization parameters are set since it is only used for
// parameter calculations
TfLiteTensor fc_output_temp;
fc_output_temp.type = cell_type;
fc_output_temp.params.scale = nonlinear_activation_input_scale;
fc_output_temp.params.zero_point = 0; // symmetrical quantized
// A temp fc opdata to reuse the helper function on creating fc parameters
tflite::OpDataFullyConnected fc_data_temp;
// TODO(b/265853320): due to the lack of precision for the float scale,
// scale_diff / output_scale <= 0.02 (potentially requires 1e-8 precision) can
// not be satisified for the bias. Here we rely on the correctiveness of the
// conversion process (set input_bias=nullptr to avoid checking) for
// tensor scales
TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
context, kTfLiteActNone, input->type, input, input_weight,
/*input_bias=*/nullptr, &fc_output_temp, &fc_data_temp));
gate_params.input_fc_params = FullyConnectedParamsQuantized(fc_data_temp);
double real_multiplier = 0.0;
GetQuantizedConvolutionMultipler(context, input, input_weight, nullptr,
&fc_output_temp, &real_multiplier);
TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
context, kTfLiteActNone, hidden_state->type, hidden_state,
hidden_state_weight, hidden_state_bias, &fc_output_temp, &fc_data_temp));
gate_params.recurrent_fc_params = FullyConnectedParamsQuantized(fc_data_temp);
return kTfLiteOk;
}
// Create parameters for element wise multiplication that happens in a) cell
// state update ; b) hidden state update
// Note that all the output of gates are symmetrically quantized so only scales
// are required for input. However, during the hidden state update phase, the
// output is the updated hidden state, which is asymmetrically quantized. Thus
// output may require zero point
tflite::ArithmeticParams CreateInterGateMulParams(const float input1_scale,
const float input2_scale,
const float output_scale,
const TfLiteType output_type,
const int output_zp) {
tflite::ArithmeticParams op_params = {};
if (output_type == kTfLiteInt16) {
op_params.quantized_activation_min = std::numeric_limits<int16_t>::min();
op_params.quantized_activation_max = std::numeric_limits<int16_t>::max();
} else if (output_type == kTfLiteInt8) {
op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
}
op_params.input1_offset = 0; // symmetric
op_params.input2_offset = 0; // symmetric
op_params.output_offset = output_zp;
const double input_product_scale =
static_cast<double>(input1_scale) * static_cast<double>(input2_scale);
double effective_scale =
input_product_scale / static_cast<double>(output_scale);
QuantizeMultiplier(effective_scale, &op_params.output_multiplier,
&op_params.output_shift);
return op_params;
}
// Create the additional information about the cell state, which include:
// cell_state_scale_power: used in integer nonlinear function (e.g., tanh)
// quantized_cell_clip: quantized cell clip range
CellStateInfo CreateLstmCellStateInfo(const float cell_state_scale,
const float cell_clip) {
CellStateInfo cell_state_info;
// cell_state_scale_power: 2^-cell_state_scale_power = cell state scale
int buffer;
tflite::CheckedLog2(cell_state_scale, &buffer);
cell_state_info.cell_state_scale_power = buffer;
// Cell state specifics
cell_state_info.cell_clip = cell_clip;
cell_state_info.quantized_cell_clip = static_cast<int16_t>(
std::min(std::max(static_cast<double>(cell_clip) /
static_cast<double>(cell_state_scale),
-32768.0),
32767.0));
return cell_state_info;
}
CellStateInfo CreateLstmCellStateInfoFloat(const float cell_clip) {
CellStateInfo cell_state_info;
cell_state_info.cell_clip = cell_clip;
cell_state_info.cell_state_scale_power = 0; // no quantization
cell_state_info.quantized_cell_clip = 0; // no quantization
return cell_state_info;
}
tflite::FullyConnectedParams CreateFCParamsFloat() {
FullyConnectedParams op_params;
CalculateActivationRange(kTfLiteActNone, &op_params.float_activation_min,
&op_params.float_activation_max);
return op_params;
}
tflite::GateParameters CreateGateParamsFloat() {
tflite::GateParameters gate_params = {};
gate_params.input_fc_params = CreateFCParamsFloat();
gate_params.recurrent_fc_params = CreateFCParamsFloat();
return gate_params;
}
tflite::ArithmeticParams CreateInterGateMulParamsFloat() {
tflite::ArithmeticParams op_params = {};
CalculateActivationRange(kTfLiteActNone, &op_params.float_activation_min,
&op_params.float_activation_max);
return op_params;
}
TfLiteStatus PrepareGateParametersFloat(TfLiteContext* context,
const LstmTensors& lstm_tensors,
OpDataLSTM* op_data_lstm) {
// Gate Parameters
op_data_lstm->forget_gate_parameters = CreateGateParamsFloat();
op_data_lstm->input_gate_parameters = CreateGateParamsFloat();
op_data_lstm->cell_gate_parameters = CreateGateParamsFloat();
op_data_lstm->output_gate_parameters = CreateGateParamsFloat();
// Inter gate multiplication parameters
op_data_lstm->inter_gate_parameters.forget_cell_mul_params =
CreateInterGateMulParamsFloat();
op_data_lstm->inter_gate_parameters.input_mul_params =
CreateInterGateMulParamsFloat();
op_data_lstm->inter_gate_parameters.output_mul_params =
CreateInterGateMulParamsFloat();
return kTfLiteOk;
}
TfLiteStatus PrepareGateParametersInteger(TfLiteContext* context,
const LstmTensors& lstm_tensors,
OpDataLSTM* op_data_lstm) {
float nonlinear_input_scale = 0.000244140625; // 2^-12 Q3.12 -> Q0.15
TF_LITE_ENSURE_OK(
context,
CreateGateParams(
context, lstm_tensors.GetInternalTensor(kLstmInputTensor),
lstm_tensors.GetInternalTensor(kLstmInputToForgetWeightsTensor),
lstm_tensors.GetInternalTensor(kLstmForgetGateBiasTensor),
lstm_tensors.GetInternalTensor(kLstmOutputStateTensor),
lstm_tensors.GetInternalTensor(kLstmRecurrentToForgetWeightsTensor),
/*hidden_state_bias=*/nullptr, nonlinear_input_scale, kTfLiteInt16,
op_data_lstm->forget_gate_parameters));
TF_LITE_ENSURE_OK(
context,
CreateGateParams(
context, lstm_tensors.GetInternalTensor(kLstmInputTensor),
lstm_tensors.GetInternalTensor(kLstmInputToInputWeightsTensor),
lstm_tensors.GetInternalTensor(kLstmInputGateBiasTensor),
lstm_tensors.GetInternalTensor(kLstmOutputStateTensor),
lstm_tensors.GetInternalTensor(kLstmRecurrentToInputWeightsTensor),
/*hidden_state_bias=*/nullptr, nonlinear_input_scale, kTfLiteInt16,
op_data_lstm->input_gate_parameters));
TF_LITE_ENSURE_OK(
context,
CreateGateParams(
context, lstm_tensors.GetInternalTensor(kLstmInputTensor),
lstm_tensors.GetInternalTensor(kLstmInputToCellWeightsTensor),
lstm_tensors.GetInternalTensor(kLstmCellGateBiasTensor),
lstm_tensors.GetInternalTensor(kLstmOutputStateTensor),
lstm_tensors.GetInternalTensor(kLstmRecurrentToCellWeightsTensor),
/*hidden_state_bias=*/nullptr, nonlinear_input_scale, kTfLiteInt16,
op_data_lstm->cell_gate_parameters));
TF_LITE_ENSURE_OK(
context,
CreateGateParams(
context, lstm_tensors.GetInternalTensor(kLstmInputTensor),
lstm_tensors.GetInternalTensor(kLstmInputToOutputWeightsTensor),
lstm_tensors.GetInternalTensor(kLstmOutputGateBiasTensor),
lstm_tensors.GetInternalTensor(kLstmOutputStateTensor),
lstm_tensors.GetInternalTensor(kLstmRecurrentToOutputWeightsTensor),
/*hidden_state_bias=*/nullptr, nonlinear_input_scale, kTfLiteInt16,
op_data_lstm->output_gate_parameters));
// Inter gate multiplication parameters
float nonlinear_output_scale = 0.000030517578125; // 2^-15 Q3.12 -> Q0.15
float cell_state_scale = lstm_tensors.CellStateTensor()->params.scale;
// forget gate output (nonlinear output) x cell state -> cell state
op_data_lstm->inter_gate_parameters.forget_cell_mul_params =
CreateInterGateMulParams(nonlinear_output_scale, cell_state_scale,
cell_state_scale, kTfLiteInt16);
// input gate output x cell gate output -> cell state
op_data_lstm->inter_gate_parameters.input_mul_params =
CreateInterGateMulParams(nonlinear_output_scale, nonlinear_output_scale,
cell_state_scale, kTfLiteInt16);
// tanh output x output gate output -> hidden state (potentially asymmetric)
op_data_lstm->inter_gate_parameters.output_mul_params =
CreateInterGateMulParams(
nonlinear_output_scale, nonlinear_output_scale,
lstm_tensors.HiddenStateTensor()->params.scale,
lstm_tensors.HiddenStateTensor()->type,
lstm_tensors.HiddenStateTensor()->params.zero_point);
return kTfLiteOk;
}
LSTMKernelContents CreateLSTMKernelContent(TfLiteContext* context,
TfLiteNode* node) {
LSTMKernelContents kernel_content;
// Point to correct tensors
for (size_t i = 0; i < 24; i++) {
kernel_content.internal_tensors[i] =
tflite::micro::GetMutableEvalInput(context, node, i);
}
// Output tensor
kernel_content.output_tensor = tflite::micro::GetEvalOutput(context, node, 0);
return kernel_content;
}
} // namespace tflite