/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/testdata/lstm_test_data.h"
#include <cstring>
namespace tflite {
namespace testing {
namespace {
// LSTM internal setting (e.g., nonlinear activation type)
// Only UnidirectionalLSTM is currently supported
constexpr TfLiteUnidirectionalSequenceLSTMParams kDefaultBuiltinData = {
/*.activation=*/kTfLiteActTanh,
/*.cell_clip=*/6,
/*.proj_clip=*/3,
/*.time_major=*/false,
/*.asymmetric_quantize_inputs=*/true,
/*.diagonal_recurrent_tensors=*/false};
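// Note: cell_clip = 6 is what produces the clamp to [-6, 6] seen in the
// expected cell state values in Get2X2GateOutputCheckData() below
// (e.g. -6.80625824 -> -6).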
} // namespace
GateOutputCheckData<4, 4> Get2X2GateOutputCheckData() {
GateOutputCheckData<4, 4> gate_data;
const float input_data[4] = {
0.2, 0.3, // batch1
-0.98, 0.62 // batch2
};
std::memcpy(gate_data.input_data, input_data, 4 * sizeof(float));
const float hidden_state[4] = {
-0.1, 0.2, // batch1
-0.3, 0.5 // batch2
};
std::memcpy(gate_data.hidden_state, hidden_state, 4 * sizeof(float));
const float cell_state[4] = {
-1.3, 6.2, // batch1
-7.3, 3.5 // batch2
};
std::memcpy(gate_data.cell_state, cell_state, 4 * sizeof(float));
// Use the forget gate parameters to test small gate outputs
// output = sigmoid(W_i*i+W_h*h+b) = sigmoid([[-10,-10],[-20,-20]][0.2, 0.3]
// +[[-10,-10],[-20,-20]][-0.1, 0.2]+[1,2]) = sigmoid([-5,-10]) =
// [6.69285092e-03, 4.53978687e-05] (Batch1)
// Similarly, we have [0.93086158 0.9945137 ] for batch 2
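// Batch 2 derivation: sigmoid([[-10,-10],[-20,-20]][-0.98, 0.62]
// +[[-10,-10],[-20,-20]][-0.3, 0.5]+[1,2]) = sigmoid([3.6-2+1, 7.2-4+2]) =
// sigmoid([2.6, 5.2]) = [0.93086158, 0.9945137]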
const float expected_forget_gate_output[4] = {6.69285092e-3f, 4.53978687e-5f,
0.93086158, 0.9945137};
std::memcpy(gate_data.expected_forget_gate_output,
expected_forget_gate_output, 4 * sizeof(float));
// Use the input gate parameters to test large gate outputs
// output = sigmoid(W_i*i+W_h*h+b) = sigmoid([[10,10],[20,20]][0.2, 0.3]
// +[[10,10],[20,20]][-0.1, 0.2]+[-1,-2]) = sigmoid([5,10]) =
// [0.99330715, 0.9999546]
// Similarly, we have [0.06913842 0.0054863 ] for batch 2
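// Batch 2 derivation: the pre-activations are the negatives of the forget
// gate's, so sigmoid([-2.6, -5.2]) = [0.06913842, 0.0054863]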
const float expected_input_gate_output[4] = {0.99330715, 0.9999546,
0.06913842, 0.0054863};
std::memcpy(gate_data.expected_input_gate_output, expected_input_gate_output,
4 * sizeof(float));
// Use the output gate parameters to test normal gate outputs
// output = sigmoid(W_i*i+W_h*h+b) = sigmoid([[1,1],[1,1]][0.2, 0.3]
// +[[1,1],[1,1]][-0.1, 0.2]+[0,0]) = sigmoid([0.6,0.6]) =
// [0.6456563062257954, 0.6456563062257954]
// Similarly, we have [0.46008512 0.46008512] for batch 2
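// Batch 2 derivation: sigmoid([[1,1],[1,1]][-0.98, 0.62]
// +[[1,1],[1,1]][-0.3, 0.5]+[0,0]) = sigmoid([-0.36+0.2, -0.36+0.2]) =
// sigmoid([-0.16, -0.16]) = [0.46008512, 0.46008512]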
const float expected_output_gate_output[4] = {
0.6456563062257954, 0.6456563062257954, 0.46008512, 0.46008512};
std::memcpy(gate_data.expected_output_gate_output,
expected_output_gate_output, 4 * sizeof(float));
// Use the cell(modulation) gate parameters to test the tanh output
// output = tanh(W_i*i+W_h*h+b) = tanh([[1,1],[1,1]][0.2, 0.3]
// +[[1,1],[1,1]][-0.1, 0.2]+[0,0]) = tanh([0.6,0.6]) =
// [0.5370495669980353, 0.5370495669980353]
// Similarly, we have [-0.1586485 -0.1586485] for batch 2
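// Batch 2 derivation: the pre-activation is the same [-0.16, -0.16] as for
// the output gate, so tanh([-0.16, -0.16]) = [-0.1586485, -0.1586485]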
const float expected_cell_gate_output[4] = {
0.5370495669980353, 0.5370495669980353, -0.1586485, -0.1586485};
std::memcpy(gate_data.expected_cell_gate_output, expected_cell_gate_output,
4 * sizeof(float));
// Cell = forget_gate*cell + input_gate*cell_gate
// Note -6.80625824 is clipped to -6
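// Worked out with the gate outputs above:
// batch1: [6.69285092e-3*(-1.3) + 0.99330715*0.53704957,
//          4.53978687e-5*6.2 + 0.9999546*0.53704957]
//       = [0.52475447, 0.53730665]
// batch2: [0.93086158*(-7.3) + 0.06913842*(-0.1586485),
//          0.9945137*3.5 + 0.0054863*(-0.1586485)]
//       = [-6.80625824, 3.47992756], with the first value clamped by
//         cell_clip = 6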
const float expected_updated_cell[4] = {0.52475447, 0.53730665, -6,
3.47992756};
std::memcpy(gate_data.expected_updated_cell, expected_updated_cell,
4 * sizeof(float));
// Use the updated cell state to update the hidden state
// tanh(expected_updated_cell) * expected_output_gate_output
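// Worked out: tanh([0.52475447, 0.53730665]) ~= [0.48136, 0.49095] and
// tanh([-6, 3.47992756]) ~= [-0.99998771, 0.99810]; multiplying elementwise
// by the expected output gate values [0.64565631, 0.64565631, 0.46008512,
// 0.46008512] gives the values below.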
const float expected_updated_hidden[4] = {0.31079388, 0.3169827, -0.46007947,
0.45921249};
std::memcpy(gate_data.expected_updated_hidden, expected_updated_hidden,
4 * sizeof(float));
return gate_data;
}
// TODO(b/253466487): document how the golden values are arrived at
LstmEvalCheckData<12, 4, 12> Get2X2LstmEvalCheckData() {
LstmEvalCheckData<12, 4, 12> eval_data;
const float input_data[12] = {
0.2, 0.3, 0.2, 0.3, 0.2, 0.3, // batch one
-0.98, 0.62, 0.01, 0.99, 0.49, -0.32 // batch two
};
std::memcpy(eval_data.input_data, input_data, 12 * sizeof(float));
// Initialize hidden state as zeros
const float hidden_state[4] = {};
std::memcpy(eval_data.hidden_state, hidden_state, 4 * sizeof(float));
// The expected model output after 3 time steps using the fixed input and
// parameters
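// For reference, the first time step for batch 1 can be reproduced from the
// single-step equations in Get2X2GateOutputCheckData(), assuming zero initial
// hidden and cell state and the gate parameters from
// Create2x3x2X2FloatNodeContents():
//   forget = sigmoid([-5+1, -10+2]) = [0.01798621, 0.00033535]
//   input = sigmoid([5-1, 10-2]) = [0.98201379, 0.99966465]
//   cell_gate = tanh([0.5, 0.5]) = [0.46211716, 0.46211716]
//   output = sigmoid([0.5, 0.5]) = [0.62245933, 0.62245933]
//   cell = forget*0 + input*cell_gate ~= [0.45380542, 0.46196219]
//   hidden = tanh(cell)*output ~= [0.26455893, 0.26870455]
// which matches the first two entries of expected_output below.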
const float expected_output[12] = {
0.26455893, 0.26870455, 0.47935803,
0.47937014, 0.58013272, 0.58013278, // batch1
-1.41184672e-3f, -1.43329117e-5f, 0.46887168,
0.46891281, 0.50054074, 0.50054148 // batch2
};
std::memcpy(eval_data.expected_output, expected_output, 12 * sizeof(float));
const float expected_hidden_state[4] = {
0.58013272, 0.58013278, // batch1
0.50054074, 0.50054148 // batch2
};
std::memcpy(eval_data.expected_hidden_state, expected_hidden_state,
4 * sizeof(float));
const float expected_cell_state[4] = {
0.89740515, 0.8974053, // batch1
0.80327607, 0.80327785 // batch2
};
std::memcpy(eval_data.expected_cell_state, expected_cell_state,
4 * sizeof(float));
return eval_data;
}
LstmNodeContent<float, float, float, float, 2, 3, 2, 2>
Create2x3x2X2FloatNodeContents(const float* input_data,
const float* hidden_state_data,
const float* cell_state_data) {
// Parameters for different gates
// negative large weights for forget gate to make it really forget
const GateData<float, float, 2, 2> forget_gate_data = {
/*.activation_weight=*/{-10, -10, -20, -20},
/*.recurrent_weight=*/{-10, -10, -20, -20},
/*.fused_bias=*/{1, 2},
/*activation_zp_folded_bias=*/{0, 0},
/*recurrent_zp_folded_bias=*/{0, 0}};
// positive large weights for input gate to make it really remember
const GateData<float, float, 2, 2> input_gate_data = {
/*.activation_weight=*/{10, 10, 20, 20},
/*.recurrent_weight=*/{10, 10, 20, 20},
/*.fused_bias=*/{-1, -2},
/*activation_zp_folded_bias=*/{0, 0},
/*recurrent_zp_folded_bias=*/{0, 0}};
// all ones to test the behavior of tanh at normal range (-1,1)
const GateData<float, float, 2, 2> cell_gate_data = {
/*.activation_weight=*/{1, 1, 1, 1},
/*.recurrent_weight=*/{1, 1, 1, 1},
/*.fused_bias=*/{0, 0},
/*activation_zp_folded_bias=*/{0, 0},
/*recurrent_zp_folded_bias=*/{0, 0}};
// all ones to test the behavior of sigmoid at normal range (-1, 1)
const GateData<float, float, 2, 2> output_gate_data = {
/*.activation_weight=*/{1, 1, 1, 1},
/*.recurrent_weight=*/{1, 1, 1, 1},
/*.fused_bias=*/{0, 0},
/*activation_zp_folded_bias=*/{0, 0},
/*recurrent_zp_folded_bias=*/{0, 0}};
LstmNodeContent<float, float, float, float, 2, 3, 2, 2> float_node_contents(
kDefaultBuiltinData, forget_gate_data, input_gate_data, cell_gate_data,
output_gate_data);
if (input_data != nullptr) {
float_node_contents.SetInputData(input_data);
}
if (hidden_state_data != nullptr) {
float_node_contents.SetHiddenStateData(hidden_state_data);
}
if (cell_state_data != nullptr) {
float_node_contents.SetCellStateData(cell_state_data);
}
return float_node_contents;
}
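// Note: the gate weights and biases above are exactly the ones assumed in the
// hand-computed gate outputs of Get2X2GateOutputCheckData(), and they appear
// to generate the three-step goldens in Get2X2LstmEvalCheckData().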
NodeQuantizationParameters Get2X2Int8LstmQuantizationSettings() {
NodeQuantizationParameters quantization_settings;
quantization_settings.activation_type = kTfLiteInt8;
quantization_settings.weight_type = kTfLiteInt8;
quantization_settings.cell_type = kTfLiteInt16;
quantization_settings.bias_type = kTfLiteInt32;
quantization_settings.nonlinear_activation_input_scale =
0.00024414062; // std::pow(2.0f, -12.0f)
quantization_settings.nonlinear_activation_output_scale =
0.00003051757; // std::pow(2.0f, -15.0f)
// state quantization parameters
quantization_settings.input = {/*scale=*/0.00784313725490196, /*zp=*/0,
/*symmetry=*/false};
quantization_settings.output = {/*scale=*/0.004705882165580988, /*zp=*/-21,
/*symmetry=*/false};
quantization_settings.hidden_state = {/*scale=*/0.004705882165580988,
/*zp=*/-21, /*symmetry=*/false};
quantization_settings.cell_state = {/*scale=*/0.00024414062, /*zp=*/0,
/*symmetry=*/true};
// gate quantization parameters
quantization_settings.forget_gate = {
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.0012351397251814111, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.input_gate = {
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.0012351397251814111, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.cell_gate = {
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/6.175698625907056e-5, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.output_gate = {
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/6.175698625907056e-5, /*zp=*/0, /*symmetry=*/true}};
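// The gate scales above are consistent with standard TFLite symmetric weight
// quantization and bias_scale = input_scale * weight_scale:
//   0.15748031496... = 20/127 (largest forget/input gate weight magnitude)
//   0.00787401574... = 1/127 (largest cell/output gate weight magnitude)
//   0.00123513972... = (2/255)*(20/127) and 6.1756986e-5 = (2/255)*(1/127),
// where 2/255 = 0.00784313725... is the input scale above.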
return quantization_settings;
}
NodeQuantizationParameters Get2X2Int16LstmQuantizationSettings() {
NodeQuantizationParameters quantization_settings;
quantization_settings.activation_type = kTfLiteInt16;
quantization_settings.weight_type = kTfLiteInt8;
quantization_settings.cell_type = kTfLiteInt16;
quantization_settings.bias_type = kTfLiteInt64;
quantization_settings.nonlinear_activation_input_scale =
0.00024414062; // std::pow(2.0f, -12.0f)
quantization_settings.nonlinear_activation_output_scale =
0.00003051757; // std::pow(2.0f, -15.0f)
// state quantization parameters
quantization_settings.input = {/*scale=*/3.0518044e-5, /*zp=*/0,
/*symmetry=*/false};
quantization_settings.output = {/*scale=*/1.8310826e-5, /*zp=*/-5461,
/*symmetry=*/false};
quantization_settings.hidden_state = {/*scale=*/1.8310826e-5, /*zp=*/-5461,
/*symmetry=*/false};
quantization_settings.cell_state = {/*scale=*/0.00024414062, /*zp=*/0,
/*symmetry=*/true};
// gate quantization parameters
quantization_settings.forget_gate = {
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/4.8059911474468205e-06, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.input_gate = {
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.15748031496062992, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/4.8059911474468205e-06, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.cell_gate = {
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/2.40299557372341e-07, /*zp=*/0, /*symmetry=*/true}};
quantization_settings.output_gate = {
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/0.007874015748031496, /*zp=*/0, /*symmetry=*/true},
{/*scale=*/2.40299557372341e-07, /*zp=*/0, /*symmetry=*/true}};
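// Same convention as the int8 settings above: weight scales of 20/127 and
// 1/127, and bias_scale = input_scale * weight_scale
// (3.0518044e-5 * 20/127 ~= 4.8059911e-6, 3.0518044e-5 * 1/127 ~= 2.4029956e-7).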
return quantization_settings;
}
LstmNodeContent<int8_t, int8_t, int32_t, int16_t, 2, 3, 2, 2>
Create2x3x2X2Int8NodeContents(const float* input_data,
const float* hidden_state,
const float* cell_state) {
auto float_node_content =
Create2x3x2X2FloatNodeContents(input_data, hidden_state, cell_state);
const auto quantization_settings = Get2X2Int8LstmQuantizationSettings();
return CreateIntegerNodeContents<int8_t, int8_t, int32_t, int16_t, 2, 3, 2,
2>(quantization_settings,
/*fold_zero_point=*/true,
float_node_content);
}
LstmNodeContent<int16_t, int8_t, int64_t, int16_t, 2, 3, 2, 2>
Create2x3x2X2Int16NodeContents(const float* input_data,
const float* hidden_state,
const float* cell_state) {
auto float_node_content =
Create2x3x2X2FloatNodeContents(input_data, hidden_state, cell_state);
const auto quantization_settings = Get2X2Int16LstmQuantizationSettings();
return CreateIntegerNodeContents<int16_t, int8_t, int64_t, int16_t, 2, 3, 2,
2>(quantization_settings,
/*fold_zero_point=*/false,
float_node_content);
}
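// Example usage (hypothetical; kernel tests are expected to combine the
// factories above with the golden data along these lines):
//   const auto eval_data = Get2X2LstmEvalCheckData();
//   auto int8_contents = Create2x3x2X2Int8NodeContents(
//       eval_data.input_data, /*hidden_state=*/nullptr,
//       /*cell_state=*/nullptr);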
} // namespace testing
} // namespace tflite