blob: 53c0d7c4dac74a023c2b55892e99983e1e2f06ff [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/lstm_eval_test.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/lstm_eval.h"
#include "tensorflow/lite/micro/kernels/lstm_shared.h"
#include "tensorflow/lite/micro/kernels/testdata/lstm_test_data.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
// TODO(b/230666079) enable below tests for xtensa when the xtensa
// kernel is reconciled with reference kernel
#if !defined(XTENSA)
namespace {
// Test Settings
constexpr float kTestFloatTolerance = 1e-6f;
} // namespace
#endif // !defined(XTENSA)
TF_LITE_MICRO_TESTS_BEGIN
// TODO(b/230666079) enable below tests for xtensa when the xtensa
// kernel is reconciled with reference kernel
#if !defined(XTENSA)
TF_LITE_MICRO_TEST(CheckGateOutputFloat) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
float_node_contents = tflite::testing::Create2x3x2X2FloatNodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
// Forget gate
tflite::testing::TestCalculateLstmGateFloat<2, 2>(
float_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
float_node_contents.GetEvalTensor(
tflite::kLstmInputToForgetWeightsTensor),
float_node_contents.GetEvalTensor(tflite::kLstmForgetGateBiasTensor),
// Recurrent FC
float_node_contents.HiddenStateEvalTensor(),
float_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToForgetWeightsTensor),
nullptr, // bias fused to activation FC,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_forget_gate_output,
kTestFloatTolerance);
// Input gate
tflite::testing::TestCalculateLstmGateFloat<2, 2>(
float_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
float_node_contents.GetEvalTensor(tflite::kLstmInputToInputWeightsTensor),
float_node_contents.GetEvalTensor(tflite::kLstmInputGateBiasTensor),
// Recurrent FC
float_node_contents.HiddenStateEvalTensor(),
float_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToInputWeightsTensor),
nullptr, // bias fused to activation FC,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_input_gate_output,
kTestFloatTolerance);
// Output gate
tflite::testing::TestCalculateLstmGateFloat<2, 2>(
float_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
float_node_contents.GetEvalTensor(
tflite::kLstmInputToOutputWeightsTensor),
float_node_contents.GetEvalTensor(tflite::kLstmOutputGateBiasTensor),
// Recurrent FC
float_node_contents.HiddenStateEvalTensor(),
float_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToOutputWeightsTensor),
nullptr, // bias fused to activation FC,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_output_gate_output,
kTestFloatTolerance);
// Cell gate
tflite::testing::TestCalculateLstmGateFloat<2, 2>(
float_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
float_node_contents.GetEvalTensor(tflite::kLstmInputToCellWeightsTensor),
float_node_contents.GetEvalTensor(tflite::kLstmCellGateBiasTensor),
// Recurrent FC
float_node_contents.HiddenStateEvalTensor(),
float_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToCellWeightsTensor),
nullptr, // bias fused to activation FC,
// Result comparison
float_node_contents.BuiltinData().activation,
gate_output_data.expected_cell_gate_output, kTestFloatTolerance);
}
TF_LITE_MICRO_TEST(CheckGateOutputInt8) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
int8_node_contents = tflite::testing::Create2x3x2X2Int8NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
// Forget gate
// Quantization performs badly here due to integer overflow!!!
float tolerance = 1e-1f;
tflite::testing::TestCalculateLstmGateInteger<int8_t, int8_t, int32_t,
int16_t, 2, 2>(
int8_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmInputToForgetWeightsTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmForgetGateBiasTensor),
// Recurrent FC
int8_node_contents.HiddenStateEvalTensor(),
int8_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToForgetWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int8_node_contents.QuantizationSettings(),
int8_node_contents.QuantizationSettings().forget_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_forget_gate_output,
tolerance);
// Input gate
// Quantization performs badly here due to integer overflow!!!
tolerance = 1e-1f;
tflite::testing::TestCalculateLstmGateInteger<int8_t, int8_t, int32_t,
int16_t, 2, 2>(
int8_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmInputToInputWeightsTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmInputGateBiasTensor),
// Recurrent FC
int8_node_contents.HiddenStateEvalTensor(),
int8_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToInputWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int8_node_contents.QuantizationSettings(),
int8_node_contents.QuantizationSettings().input_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_input_gate_output,
tolerance);
// Output gate
tolerance = 1e-2f;
tflite::testing::TestCalculateLstmGateInteger<int8_t, int8_t, int32_t,
int16_t, 2, 2>(
int8_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmInputToOutputWeightsTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmOutputGateBiasTensor),
// Recurrent FC
int8_node_contents.HiddenStateEvalTensor(),
int8_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToOutputWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int8_node_contents.QuantizationSettings(),
int8_node_contents.QuantizationSettings().output_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_output_gate_output,
tolerance);
// Cell gate
tolerance = 1e-2f;
tflite::testing::TestCalculateLstmGateInteger<int8_t, int8_t, int32_t,
int16_t, 2, 2>(
int8_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmInputToCellWeightsTensor),
int8_node_contents.GetEvalTensor(tflite::kLstmCellGateBiasTensor),
// Recurrent FC
int8_node_contents.HiddenStateEvalTensor(),
int8_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToCellWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int8_node_contents.QuantizationSettings(),
int8_node_contents.QuantizationSettings().cell_gate,
// Result comparison
int8_node_contents.BuiltinData().activation,
gate_output_data.expected_cell_gate_output, tolerance);
}
TF_LITE_MICRO_TEST(CheckGateOutputInt16) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>
int16_node_contents = tflite::testing::Create2x3x2X2Int16NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
// Forget gate
// Quantization performs badly here due to integer overflow (from batch2)!!!
float tolerance = 1e-1f;
tflite::testing::TestCalculateLstmGateInteger<int16_t, int8_t, int64_t,
int16_t, 2, 2>(
int16_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int16_node_contents.GetEvalTensor(
tflite::kLstmInputToForgetWeightsTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmForgetGateBiasTensor),
// Recurrent FC
int16_node_contents.HiddenStateEvalTensor(),
int16_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToForgetWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int16_node_contents.QuantizationSettings(),
int16_node_contents.QuantizationSettings().forget_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_forget_gate_output,
tolerance);
// Input gate
// Quantization performs badly here due to integer overflow (from batch2)!!!
tolerance = 1e-1f;
tflite::testing::TestCalculateLstmGateInteger<int16_t, int8_t, int64_t,
int16_t, 2, 2>(
int16_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmInputToInputWeightsTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmInputGateBiasTensor),
// Recurrent FC
int16_node_contents.HiddenStateEvalTensor(),
int16_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToInputWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int16_node_contents.QuantizationSettings(),
int16_node_contents.QuantizationSettings().input_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_input_gate_output,
tolerance);
// Output gate
// Quantization scale (theoritical lowest range) is at range 1e-5
tolerance = 1e-4f;
tflite::testing::TestCalculateLstmGateInteger<int16_t, int8_t, int64_t,
int16_t, 2, 2>(
int16_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int16_node_contents.GetEvalTensor(
tflite::kLstmInputToOutputWeightsTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmOutputGateBiasTensor),
// Recurrent FC
int16_node_contents.HiddenStateEvalTensor(),
int16_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToOutputWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int16_node_contents.QuantizationSettings(),
int16_node_contents.QuantizationSettings().output_gate,
// Result comparison
kTfLiteActSigmoid, gate_output_data.expected_output_gate_output,
tolerance);
// Cell gate
tolerance = 1e-4f;
tflite::testing::TestCalculateLstmGateInteger<int16_t, int8_t, int64_t,
int16_t, 2, 2>(
int16_node_contents.GetEvalTensor(tflite::kLstmInputTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmInputToCellWeightsTensor),
int16_node_contents.GetEvalTensor(tflite::kLstmCellGateBiasTensor),
// Recurrent FC
int16_node_contents.HiddenStateEvalTensor(),
int16_node_contents.GetEvalTensor(
tflite::kLstmRecurrentToCellWeightsTensor),
nullptr, // bias fused to activation FC,
// Quantization settings
int16_node_contents.QuantizationSettings(),
int16_node_contents.QuantizationSettings().cell_gate,
// Result comparison
int16_node_contents.BuiltinData().activation,
gate_output_data.expected_cell_gate_output, tolerance);
}
TF_LITE_MICRO_TEST(CheckCellStateUpdateFloat) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
float_node_contents = tflite::testing::Create2x3x2X2FloatNodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
tflite::testing::TestUpdateLstmCellFloat(
gate_output_data, float_node_contents, kTestFloatTolerance);
}
TF_LITE_MICRO_TEST(CheckCellStateUpdateInt8) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
int8_node_contents = tflite::testing::Create2x3x2X2Int8NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
// Very high precision. The error is introduced by the
// quantization error of the clip value (~1e-5), but cannot actually reach
// the precision due to integer overflow of the elements
const float tolerance = 1e-3f;
tflite::testing::TestUpdateLstmCellInteger(gate_output_data,
int8_node_contents, tolerance);
}
TF_LITE_MICRO_TEST(CheckCellStateUpdateInt16) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>
int16_node_contents = tflite::testing::Create2x3x2X2Int16NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
// Very high precision. The error is introduced by the
// quantization error of the clip value (~1e-5), but cannot actually reach
// the precision due to integer overflow of the elements
const float tolerance = 1e-3f;
tflite::testing::TestUpdateLstmCellInteger(gate_output_data,
int16_node_contents, tolerance);
}
TF_LITE_MICRO_TEST(CheckHiddenStateUpdateFloat) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
float_node_contents = tflite::testing::Create2x3x2X2FloatNodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.expected_updated_cell);
tflite::testing::TestUpdateLstmHiddenFloat(
gate_output_data, float_node_contents, kTestFloatTolerance);
}
TF_LITE_MICRO_TEST(CheckHiddenStateUpdateInt8) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
int8_node_contents = tflite::testing::Create2x3x2X2Int8NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.expected_updated_cell);
// Theoritical error floor = quantization scale = 0.004705882165580988
const float tolerance = 1e-2;
tflite::testing::TestUpdateLstmHiddenInteger(gate_output_data,
int8_node_contents, tolerance);
}
TF_LITE_MICRO_TEST(CheckHiddenStateUpdateInt16) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>
int16_node_contents = tflite::testing::Create2x3x2X2Int16NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.expected_updated_cell);
const float tolerance = 1e-4;
tflite::testing::TestUpdateLstmHiddenInteger(gate_output_data,
int16_node_contents, tolerance);
}
TF_LITE_MICRO_TEST(CheckOneStepLSTMFloat) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
float_node_contents = tflite::testing::Create2x3x2X2FloatNodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
tflite::testing::TestLstmStepFloat(gate_output_data, kTestFloatTolerance,
kTestFloatTolerance, float_node_contents);
}
TF_LITE_MICRO_TEST(CheckOneStepLSTMInt8) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
int8_node_contents = tflite::testing::Create2x3x2X2Int8NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
const float hidden_state_tolerance = 1e-2;
// cell state degrade due to integer overflow
const float cell_state_tolerance = 1e-1;
tflite::testing::TestLstmStepInteger(gate_output_data, hidden_state_tolerance,
cell_state_tolerance,
int8_node_contents);
}
TF_LITE_MICRO_TEST(CheckOneStepLSTMInt16) {
const tflite::testing::GateOutputCheckData<4, 4> gate_output_data =
tflite::testing::Get2X2GateOutputCheckData();
tflite::testing::LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>
int16_node_contents = tflite::testing::Create2x3x2X2Int16NodeContents(
gate_output_data.input_data, gate_output_data.hidden_state,
gate_output_data.cell_state);
const float hidden_state_tolerance = 1e-3; // actually very close to 1e-4
// cell state degrade due to integer overflow
const float cell_state_tolerance = 1e-1;
tflite::testing::TestLstmStepInteger<int16_t, int8_t, int64_t, int16_t, 2, 3,
2, 2>(
gate_output_data, hidden_state_tolerance, cell_state_tolerance,
int16_node_contents);
}
TF_LITE_MICRO_TEST(TestLSTMEvalFloat) {
const tflite::testing::LstmEvalCheckData<12, 4, 12> kernel_eval_data =
tflite::testing::Get2X2LstmEvalCheckData();
tflite::testing::LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
float_node_contents = tflite::testing::Create2x3x2X2FloatNodeContents(
kernel_eval_data.input_data, kernel_eval_data.hidden_state);
tflite::testing::TestEvalLstmFloat(kernel_eval_data, kTestFloatTolerance,
kTestFloatTolerance, float_node_contents);
}
TF_LITE_MICRO_TEST(TestLSTMEvalInt8) {
const tflite::testing::LstmEvalCheckData<12, 4, 12> kernel_eval_data =
tflite::testing::Get2X2LstmEvalCheckData();
tflite::testing::LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
int8_node_contents = tflite::testing::Create2x3x2X2Int8NodeContents(
kernel_eval_data.input_data, kernel_eval_data.hidden_state);
const float hidden_state_tolerance = 1e-2;
// cell state degrade due to integer overflow
const float cell_state_tolerance = 1e-2;
tflite::testing::TestEvalLstmInteger(kernel_eval_data, hidden_state_tolerance,
cell_state_tolerance,
int8_node_contents);
}
TF_LITE_MICRO_TEST(TestLSTMEvalInt16) {
const tflite::testing::LstmEvalCheckData<12, 4, 12> kernel_eval_data =
tflite::testing::Get2X2LstmEvalCheckData();
tflite::testing::LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>
int16_node_contents = tflite::testing::Create2x3x2X2Int16NodeContents(
kernel_eval_data.input_data, kernel_eval_data.hidden_state);
const float hidden_state_tolerance = 1e-3; // actually very close to 1e-4
// cell state degrade due to integer overflow
const float cell_state_tolerance = 1e-2;
tflite::testing::TestEvalLstmInteger(kernel_eval_data, hidden_state_tolerance,
cell_state_tolerance,
int16_node_contents);
}
#endif // !defined(XTENSA)
TF_LITE_MICRO_TESTS_END