| /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/lite/c/builtin_op_data.h" |
| #include "tensorflow/lite/c/common.h" |
| #include "tensorflow/lite/micro/kernels/kernel_runner.h" |
| #include "tensorflow/lite/micro/test_helpers.h" |
| #include "tensorflow/lite/micro/testing/micro_test.h" |
| |
| namespace tflite { |
| namespace testing { |
| namespace { |
| |
// Naming convention: <tensor name>_<input size>x<batch size>x<batch count>.
| |
| // 10 inputs each with shape {2, 2}. |
| const float input_data_2x2x10[] = { |
| 0.12609188, -0.46347019, 0.35867718, 0.36897406, |
| |
| 0.14278367, -1.64410412, -0.57290924, 0.12729003, |
| |
| 0.49837467, 0.19278903, 0.17660543, 0.52949083, |
| |
| -0.11186574, 0.13164264, -0.72674477, -0.5683046, |
| |
| -0.68892461, 0.37783599, -0.63690937, 0.44483393, |
| |
| -0.81299269, -0.86831826, -0.95760226, 1.82078898, |
| |
| -1.45006323, -0.82251364, -1.65087092, -1.89238167, |
| |
| 0.03966608, -0.24936394, 2.06740379, -1.51439476, |
| |
| 0.11771342, -0.23761693, 0.31088525, -1.55601168, |
| |
| -0.89477462, 1.67204106, -0.6230064, 0.29819036, |
| }; |
| |
| // Feature filter of shape {8, 2}. |
| const float feature_weights_data_2x2x10[] = { |
| -0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322, |
| 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296, |
| -0.36118156, -0.0976817, -0.36916667, 0.22197971}; |
| |
| // Time filter of shape {8, 10}. |
| const float time_weights_data_2x2x10[] = { |
| -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, |
| 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, |
| |
| 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, |
| -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, |
| |
| -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, |
| 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, |
| |
| -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, |
| -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, |
| |
| -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, |
| 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, |
| |
| -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, |
| 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, |
| |
| -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, |
| -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, |
| |
| 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, |
| 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}; |
| |
| // Activation state with shape {2, 80}. These initial values must be copied into |
| // a mutable activation state tensor. |
| |
| const float initial_activation_state_data_2x2x10[] = { |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
| |
// Bias with shape {8} (one element per filter). All zeros, so the golden
// outputs depend only on the inputs and the feature/time weights.
const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
| |
| // 10 outputs each of shape {2, 4} |
| const float golden_output_2x2x10[] = { |
| -0.044205, -0.013757, 0.050369, -0.018447, |
| 0.073010, 0.025142, -0.021154, 0.013551, |
| |
| -0.209613, -0.062421, 0.150209, -0.108334, |
| 0.028256, -0.006950, -0.030885, 0.009603, |
| |
| -0.076800, -0.037075, -0.087198, -0.155183, |
| 0.091069, 0.098446, -0.016083, 0.106475, |
| |
| -0.082123, -0.162238, -0.084434, -0.141074, |
| -0.029340, -0.090685, 0.053302, -0.030604, |
| |
| -0.201440, 0.088424, 0.139877, 0.012416, |
| -0.113212, 0.103893, -0.100842, 0.122780, |
| |
| -0.166632, -0.116705, 0.175298, -0.047163, |
| 0.313077, -0.166485, -0.285860, 0.129069, |
| |
| -0.625911, 0.046134, 0.138081, -0.129581, |
| -0.521455, -0.061579, 0.230289, 0.114963, |
| |
| -0.216693, -0.161643, -0.179177, -0.052599, |
| -0.213239, 0.029502, 0.260858, 0.275045, |
| |
| -0.213689, -0.323608, -0.285635, -0.317687, |
| -0.324092, -0.317972, -0.208450, -0.462504, |
| |
| -0.255126, -0.218576, -0.041528, 0.179421, |
| -0.440583, 0.072127, -0.284136, 0.241570}; |
| |
| // Simulated real-world inputs, weights and expected outputs. |
| |
// Input of shape {1, 16}: a single batch holding one 16-element feature frame.
const float input_data_16x1x1[] = {
    -0.488494, 2.023762,  -2.233117, -0.488494, 3.559030, 9.490748,
    -3.210106, -1.953977, -0.279140, 0.907204,  1.674838, 0.000000,
    -0.279140, -0.628064, -0.069785, -0.628064,
};
| |
| // Feature filter of shape {64, 16}. |
| const float feature_weights_data_16x1x1[] = { |
| 0.173588, 0.173588, -0.024798, 0.193426, -0.099193, 0.044637, 0.183507, |
| 0.183507, 0.044637, 0.198386, -0.069435, 0.084314, 0.312458, 0.024798, |
| 0.173588, -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395, |
| -0.128951, 0.193426, 0.357095, -0.317418, -0.119032, -0.218225, -0.004960, |
| -0.386853, -0.133911, 0.252942, -0.019839, -0.024798, -0.054556, -0.069435, |
| -0.128951, 0.029758, -0.099193, -0.312458, -0.029758, 0.064475, 0.183507, |
| 0.114072, -0.178547, -0.247982, -0.119032, 0.243023, -0.119032, -0.034718, |
| -0.178547, 0.019839, 0.128951, -0.223184, -0.009919, -0.213265, 0.168628, |
| -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475, |
| -0.267821, -0.580279, -0.099193, 0.213265, 0.119032, -0.119032, -0.178547, |
| 0.610037, 0.109112, 0.049596, -0.014879, -0.049596, -0.193426, 0.039677, |
| -0.148789, -0.114072, -0.158709, -0.158709, 0.094233, 0.099193, -0.114072, |
| 0.104153, -0.123991, 0.198386, -0.173588, 0.089274, -0.247982, -0.054556, |
| 0.123991, 0.183507, 0.114072, 0.188467, 0.302539, 0.044637, 0.039677, |
| -0.099193, 0.168628, -0.024798, -0.054556, -0.109112, 0.014879, -0.009919, |
| 0.069435, -0.396772, -0.287660, -0.079354, -0.104153, 0.054556, 0.089274, |
| -0.099193, 0.114072, 0.034718, 0.119032, 0.282700, -0.119032, -0.505884, |
| -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749, 0.128951, |
| 0.143830, -0.188467, -0.183507, 0.104153, -0.024798, 0.193426, -0.287660, |
| 0.168628, -0.009919, 0.119032, -0.024798, -0.099193, -0.203346, 0.099193, |
| 0.084314, -0.168628, 0.123991, -0.148789, 0.114072, -0.029758, 0.228144, |
| -0.238063, 0.089274, -0.064475, 0.307498, -0.188467, -0.004960, -0.252942, |
| -0.173588, -0.158709, -0.044637, -0.009919, 0.312458, -0.262861, 0.059516, |
| 0.158709, 0.069435, -0.282700, 0.074395, -0.322377, -0.183507, -0.123991, |
| -0.233104, 0.009919, 0.252942, -0.243023, 0.555481, -0.099193, -0.119032, |
| -0.441409, 0.148789, 0.084314, -0.168628, -0.183507, 0.188467, 0.024798, |
| -0.302539, 0.223184, 0.143830, -0.193426, -0.054556, -0.218225, -0.297579, |
| 0.104153, 0.272781, -0.034718, 0.114072, -0.059516, 0.044637, 0.342216, |
| 0.421570, 0.138870, -0.024798, -0.039677, -0.163668, -0.034718, 0.396772, |
| -0.128951, -0.044637, -0.173588, 0.302539, 0.079354, 0.049596, 0.133911, |
| -0.029758, -0.312458, -0.029758, 0.079354, 0.128951, 0.252942, 0.213265, |
| 0.014879, 0.287660, 0.178547, 0.297579, 0.352135, 0.401732, 0.024798, |
| -0.277740, -0.411651, -0.069435, 0.342216, -0.158709, -0.104153, -0.009919, |
| 0.223184, 0.228144, -0.019839, 0.059516, -0.104153, -0.510844, 0.029758, |
| -0.406691, 0.089274, 0.421570, 0.163668, -0.143830, -0.019839, -0.039677, |
| 0.104153, -0.044637, -0.128951, 0.203346, 0.079354, -0.069435, 0.094233, |
| -0.138870, 0.466207, -0.163668, 0.049596, 0.029758, 0.267821, 0.029758, |
| -0.049596, 0.009919, 0.004960, -0.099193, 0.094233, -0.262861, 0.089274, |
| -0.302539, 0.332297, -0.307498, -0.014879, 0.168628, -0.094233, -0.272781, |
| 0.034718, -0.133911, -0.228144, 0.094233, 0.257902, -0.228144, 0.153749, |
| -0.054556, -0.252942, 0.054556, 0.218225, -0.054556, 0.302539, 0.282700, |
| 0.054556, -0.044637, -0.133911, 0.233104, -0.049596, 0.411651, 0.044637, |
| -0.297579, -0.029758, -0.114072, 0.114072, -0.580279, 0.079354, -0.024798, |
| -0.347175, -0.128951, -0.099193, 0.238063, -0.104153, -0.009919, 0.158709, |
| -0.034718, 0.123991, -0.163668, 0.059516, 0.342216, 0.009919, 0.064475, |
| -0.307498, -0.520763, -0.238063, 0.163668, 0.362054, 0.034718, -0.178547, |
| -0.104153, -0.257902, 0.322377, 0.054556, 0.148789, -0.178547, 0.084314, |
| 0.004960, 0.257902, 0.029758, 0.079354, -0.223184, -0.193426, 0.282700, |
| 0.000000, -0.019839, -0.114072, 0.491005, -0.193426, -0.029758, -0.243023, |
| 0.009919, 0.089274, -0.277740, -0.089274, 0.104153, 0.337256, 0.138870, |
| -0.307498, -0.054556, 0.352135, 0.133911, -0.044637, 0.133911, -0.089274, |
| -0.357095, -0.272781, 0.069435, 0.059516, -0.109112, 0.148789, -0.044637, |
| -0.019839, -0.153749, 0.123991, -0.223184, 0.322377, 0.074395, -0.312458, |
| 0.024798, -0.223184, 0.109112, -0.138870, 0.218225, -0.074395, -0.406691, |
| 0.009919, -0.198386, -0.009919, 0.416611, 0.178547, 0.148789, 0.133911, |
| -0.004960, 0.069435, -0.054556, -0.044637, 0.297579, 0.059516, -0.456288, |
| -0.148789, -0.004960, 0.054556, 0.094233, -0.104153, 0.198386, -0.302539, |
| 0.133911, 0.411651, 0.054556, 0.525723, -0.089274, 0.079354, 0.238063, |
| 0.079354, -0.039677, 0.039677, 0.029758, 0.332297, -0.014879, -0.367014, |
| -0.143830, -0.123991, -0.064475, 0.014879, 0.173588, -0.168628, 0.386853, |
| 0.009919, 0.173588, 0.163668, 0.123991, 0.163668, 0.198386, 0.203346, |
| -0.401732, -0.009919, 0.272781, -0.173588, 0.044637, 0.238063, 0.133911, |
| 0.049596, 0.208305, -0.024798, 0.049596, -0.049596, 0.034718, -0.446368, |
| 0.466207, -0.089274, -0.099193, -0.128951, -0.228144, 0.014879, -0.252942, |
| 0.074395, -0.223184, -0.168628, -0.292619, 0.178547, 0.153749, -0.014879, |
| 0.054556, 0.000000, 0.193426, 0.158709, 0.178547, -0.327337, -0.138870, |
| -0.114072, 0.168628, 0.297579, -0.109112, -0.029758, -0.029758, -0.416611, |
| 0.059516, 0.000000, -0.168628, -0.322377, 0.238063, -0.128951, -0.029758, |
| 0.500925, 0.292619, 0.123991, -0.099193, 0.074395, 0.317418, -0.148789, |
| 0.064475, -0.104153, -0.044637, -0.094233, 0.188467, -0.044637, 0.213265, |
| -0.233104, -0.049596, 0.004960, -0.198386, 0.287660, -0.148789, -0.257902, |
| 0.004960, -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233, |
| 0.029758, -0.019839, -0.009919, -0.143830, -0.158709, 0.158709, -0.243023, |
| -0.039677, -0.297579, 0.069435, 0.049596, 0.302539, 0.059516, 0.074395, |
| -0.019839, 0.352135, -0.019839, -0.138870, -0.178547, -0.243023, 0.233104, |
| 0.252942, -0.228144, -0.049596, 0.173588, 0.173588, -0.074395, -0.034718, |
| -0.292619, 0.362054, 0.183507, 0.243023, -0.203346, -0.044637, 0.054556, |
| 0.059516, -0.158709, -0.158709, 0.000000, 0.327337, 0.119032, 0.034718, |
| -0.044637, -0.089274, 0.089274, -0.233104, 0.000000, -0.317418, 0.371974, |
| 0.213265, 0.307498, -0.178547, -0.367014, 0.039677, -0.059516, 0.168628, |
| -0.014879, 0.143830, 0.123991, -0.084314, -0.332297, -0.416611, 0.183507, |
| 0.109112, -0.039677, 0.014879, 0.292619, -0.213265, -0.054556, 0.004960, |
| 0.123991, 0.119032, 0.000000, -0.332297, -0.312458, -0.198386, -0.213265, |
| 0.119032, 0.322377, 0.168628, 0.104153, -0.262861, 0.327337, -0.049596, |
| -0.228144, -0.074395, 0.168628, 0.123991, 0.396772, 0.044637, 0.322377, |
| 0.193426, 0.267821, -0.178547, 0.297579, 0.148789, -0.218225, -0.138870, |
| 0.044637, 0.049596, 0.133911, 0.064475, 0.069435, 0.064475, -0.158709, |
| -0.044637, -0.173588, 0.267821, 0.327337, 0.079354, -0.228144, 0.029758, |
| 0.014879, 0.198386, -0.109112, -0.133911, 0.431490, 0.099193, 0.421570, |
| 0.233104, -0.054556, 0.054556, -0.317418, -0.133911, -0.123991, -0.287660, |
| 0.342216, -0.049596, -0.153749, 0.228144, -0.213265, 0.262861, 0.406691, |
| -0.084314, -0.004960, 0.193426, 0.188467, -0.099193, -0.223184, 0.163668, |
| -0.257902, -0.153749, 0.441409, 0.099193, 0.128951, -0.089274, -0.208305, |
| -0.009919, -0.004960, -0.109112, 0.024798, -0.119032, 0.019839, 0.391812, |
| -0.024798, 0.198386, 0.327337, -0.505884, -0.099193, 0.510844, -0.148789, |
| 0.094233, -0.153749, -0.039677, 0.352135, 0.272781, -0.228144, -0.287660, |
| -0.272781, 0.148789, 0.277740, 0.074395, 0.109112, -0.064475, 0.044637, |
| 0.074395, -0.292619, 0.153749, -0.064475, -0.114072, 0.198386, -0.039677, |
| -0.128951, -0.004960, 0.257902, -0.228144, -0.094233, 0.064475, 0.014879, |
| 0.188467, -0.416611, 0.099193, 0.362054, -0.208305, 0.198386, -0.079354, |
| 0.009919, 0.119032, 0.332297, 0.243023, -0.168628, 0.158709, 0.039677, |
| 0.143830, 0.277740, -0.168628, 0.009919, 0.099193, -0.004960, -0.257902, |
| -0.297579, 0.208305, -0.104153, 0.119032, 0.247982, 0.381893, -0.223184, |
| -0.367014, -0.327337, -0.168628, -0.094233, 0.208305, -0.019839, 0.183507, |
| 0.084314, 0.133911, 0.109112, -0.148789, -0.183507, -0.411651, -0.024798, |
| -0.114072, -0.029758, -0.009919, 0.173588, -0.059516, -0.049596, 0.039677, |
| 0.317418, 0.138870, -0.247982, -0.084314, 0.158709, 0.054556, -0.084314, |
| -0.049596, 0.074395, 0.019839, -0.282700, -0.119032, -0.262861, 0.163668, |
| -0.069435, -0.064475, -0.059516, 0.094233, 0.123991, -0.079354, -0.272781, |
| -0.267821, 0.233104, 0.114072, -0.218225, 0.540602, 0.089274, 0.262861, |
| 0.079354, 0.267821, -0.119032, -0.109112, -0.128951, 0.128951, -0.044637, |
| -0.272781, 0.277740, 0.297579, -0.054556, -0.084314, -0.049596, 0.123991, |
| 0.059516, 0.238063, -0.168628, -0.009919, 0.163668, -0.307498, 0.109112, |
| -0.064475, 0.218225, -0.168628, -0.004960, -0.168628, 0.119032, 0.094233, |
| -0.183507, -0.089274, -0.292619, -0.094233, 0.064475, -0.183507, -0.168628, |
| 0.089274, 0.074395, -0.367014, -0.024798, -0.069435, 0.119032, -0.302539, |
| -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879, |
| -0.183507, -0.238063, 0.163668, -0.332297, -0.148789, -0.391812, -0.024798, |
| -0.133911, -0.059516, -0.123991, 0.123991, -0.292619, -0.044637, 0.059516, |
| -0.069435, 0.049596, -0.069435, 0.034718, 0.158709, -0.347175, -0.044637, |
| 0.352135, -0.347175, -0.282700, -0.054556, 0.307498, 0.029758, 0.357095, |
| -0.148789, 0.208305, -0.317418, 0.009919, 0.004960, -0.243023, 0.049596, |
| -0.099193, 0.213265, -0.342216, 0.158709, 0.123991, -0.332297, 0.386853, |
| -0.262861, -0.208305, 0.123991, -0.044637, 0.148789, 0.084314, -0.297579, |
| -0.307498, -0.163668, 0.337256, -0.014879, 0.074395, 0.178547, -0.004960, |
| -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032, |
| -0.153749, 0.629876, 0.277740, 0.178547, -0.267821, -0.004960, 0.247982, |
| 0.084314, -0.094233, 0.000000, -0.039677, 0.332297, 0.178547, 0.009919, |
| -0.213265, -0.208305, -0.044637, 0.019839, 0.218225, -0.297579, 0.014879, |
| -0.247982, -0.004960, -0.128951, 0.421570, -0.059516, 0.362054, -0.203346, |
| -0.143830, -0.099193, -0.024798, 0.094233, -0.123991, 0.163668, 0.109112, |
| -0.104153, -0.233104, 0.009919, -0.218225, 0.376933, 0.104153, -0.059516, |
| 0.049596, -0.054556, 0.019839, -0.044637, -0.019839, 0.371974, -0.019839, |
| 0.104153, 0.168628, -0.024798, -0.272781, -0.158709, 0.223184, 0.044637, |
| 0.039677, -0.168628, -0.287660, -0.109112, 0.094233, -0.089274, -0.148789, |
| 0.178547, -0.039677, -0.089274, -0.049596, -0.024798, 0.064475, -0.158709, |
| 0.089274, 0.029758, -0.247982, 0.362054, 0.024798, -0.004960, -0.099193, |
| 0.173588, -0.059516, 0.188467, -0.629876, 0.094233, 0.371974, 0.069435, |
| 0.252942, -0.357095, -0.272781, -0.367014, 0.014879, -0.049596, -0.262861, |
| 0.009919, -0.094233, -0.094233, 0.059516, 0.223184, 0.133911, 0.411651, |
| -0.044637, -0.044637, 0.109112, 0.228144, 0.386853, -0.233104, 0.069435, |
| 0.228144, -0.302539, 0.029758, 0.089274, 0.044637, -0.238063, -0.138870, |
| -0.158709, -0.019839, 0.049596, 0.039677, 0.000000, -0.069435, 0.109112, |
| -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911, 0.391812, |
| 0.123991, -0.317418, 0.233104, -0.029758, -0.099193, -0.193426, 0.074395, |
| -0.009919, 0.252942, 0.322377, -0.530683, 0.208305, 0.252942, 0.203346, |
| -0.069435, -0.262861}; |
| |
| // Time filter of shape {64, 8}. |
| const float time_weights_data_16x1x1[] = { |
| -0.052026, 0.043107, 0.053512, 0.013378, 0.011892, -0.182834, -0.108511, |
| 0.153105, 0.050539, -0.173915, 0.145672, 0.208103, -0.221481, 0.108511, |
| -0.496475, 0.181347, -0.016351, -0.132294, -0.234859, -0.243778, 0.028243, |
| -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243, |
| -0.057972, -0.057972, -0.497962, 0.054999, 0.181347, 0.047566, -0.099592, |
| -0.111484, -0.130808, -0.071350, 0.380532, 0.010405, 0.041621, 0.052026, |
| 0.022297, 0.081755, 0.098106, 0.099592, -0.584176, -0.023783, 0.062431, |
| -0.090674, -0.279453, -0.486070, -0.273507, 0.004459, -0.062431, 0.095133, |
| 0.056485, 0.022297, -0.105538, -0.184320, 0.358235, 0.254183, 0.049053, |
| 0.084728, 0.218508, 0.078782, -0.136754, -0.017837, -0.124862, -0.118916, |
| -0.001486, 0.043107, 0.254183, 0.087701, 0.261616, 0.309182, -0.404315, |
| -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053, |
| -0.046080, -0.062431, 0.020810, 0.040134, -0.135267, -0.169456, -0.050539, |
| -0.576743, 0.034188, 0.075809, 0.101079, 0.136754, 0.083241, 0.077296, |
| -0.050539, 0.761064, -0.335938, -0.080268, 0.025270, 0.257156, 0.227427, |
| 0.252697, 0.065404, 0.115943, 0.222968, -0.026756, -0.054999, 0.107025, |
| -0.093646, 0.041621, -0.092160, -0.474178, -0.016351, 0.004459, 0.049053, |
| 0.019324, 0.019324, 0.074323, 0.038648, -0.613905, 0.182834, 0.075809, |
| 0.028243, 0.019324, 0.010405, -0.011892, 0.001486, -0.492016, -0.224454, |
| -0.474178, -0.147159, 0.002973, 0.102565, 0.136754, -0.267561, -0.001486, |
| -0.095133, -0.040134, 0.066890, 0.074323, 0.104052, 0.532150, 0.090674, |
| 0.072836, -0.053512, -0.004459, 0.020810, 0.046080, 0.062431, 0.477151, |
| 0.133781, -0.029729, -0.026756, 0.031215, 0.156077, 0.096619, 0.251210, |
| 0.352289, 0.657012, 0.047566, -0.014865, -0.072836, -0.016351, 0.008919, |
| -0.053512, 0.016351, 0.300263, 0.047566, 0.020810, 0.169456, 0.001486, |
| 0.007432, 0.111484, 0.044594, -0.188779, -0.096619, 0.074323, -0.040134, |
| 0.160537, 0.138240, 0.184320, 0.377559, -0.092160, -0.049053, 0.056485, |
| -0.032702, 0.001486, -0.083241, -0.472692, -0.114457, -0.117430, -0.075809, |
| 0.026756, 0.163510, 0.172428, 0.127835, -0.199185, -0.218508, -0.057972, |
| -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728, |
| 0.248238, 0.191752, 0.221481, 0.173915, 0.173915, -0.208103, -0.077296, |
| 0.384991, -0.313641, -0.313641, -0.147159, -0.090674, 0.035675, 0.059458, |
| -0.010405, 0.019324, 0.087701, 0.016351, 0.037161, 0.469719, -0.074323, |
| 0.092160, 0.026756, 0.090674, 0.098106, 0.004459, -0.034188, 0.492016, |
| -0.367154, -0.093646, -0.063917, 0.041621, 0.017837, 0.026756, -0.062431, |
| -0.350803, 0.425125, 0.002973, 0.083241, 0.075809, 0.016351, 0.047566, |
| -0.185807, -0.107025, -0.098106, -0.144186, 0.255670, 0.020810, 0.105538, |
| 0.029729, 0.129321, 0.156077, 0.141213, 0.334452, 0.147159, -0.066890, |
| 0.035675, 0.115943, 0.240805, 0.328506, 0.162023, -0.237832, 0.218508, |
| 0.233373, 0.214049, 0.099592, 0.026756, -0.322560, -0.236346, -0.166483, |
| 0.225941, 0.109997, -0.147159, 0.147159, -0.266075, 0.111484, 0.078782, |
| -0.120403, 0.022297, -0.075809, -0.148645, -0.251210, -0.176888, -0.044594, |
| -0.023783, 0.016351, 0.026756, -0.013378, -0.069863, -0.112970, 0.013378, |
| 0.086214, 0.014865, 0.352289, -0.240805, -0.135267, -0.114457, -0.472692, |
| 0.334452, 0.095133, 0.047566, 0.130808, -0.068377, -0.007432, -0.130808, |
| -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000, -0.028243, |
| 0.029729, -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456, |
| -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621, |
| -0.047566, -0.107025, 0.004459, 0.053512, 0.047566, -0.358235, -0.193239, |
| 0.040134, -0.096619, -0.054999, 0.099592, 0.032702, 0.205130, -0.170942, |
| -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646, |
| -0.074323, 0.078782, 0.607959, -0.437017, -0.164996, -0.166483, 0.043107, |
| -0.016351, 0.258643, 0.065404, -0.057972, 0.017837, 0.080268, 0.050539, |
| -0.013378, -0.215536, -0.524718, 0.260129, 0.040134, -0.002973, -0.046080, |
| 0.020810, 0.025270, 0.145672, 0.515799, 0.233373, 0.011892, 0.139727, |
| 0.126348, 0.065404, -0.007432, -0.008919, 0.035675, 0.083241, 0.040134, |
| -0.005946, 0.503907, -0.490529, -0.181347, -0.092160, -0.038648, 0.019324, |
| 0.133781, -0.011892, 0.041621, 0.062431, -0.062431, -0.040134, -0.092160, |
| -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161, -0.092160, |
| -0.056485, -0.041621, 0.112970, 0.248238, 0.438503, 0.258643, -0.013378, |
| 0.004459, 0.043107, 0.040134, 0.017837, 0.101079, 0.264589, 0.212563, |
| 0.014865, 0.285399, 0.153105, 0.170942, 0.358235, 0.334452, 0.086214, |
| 0.132294, 0.098106, -0.001486, 0.107025, 0.200671, -0.026756, 0.344857, |
| 0.227427, -0.041621, 0.098106, 0.063917, -0.093646, 0.130808, 0.285399, |
| -0.319587, 0.035675, -0.017837, -0.319587, 0.016351, -0.098106, -0.017837, |
| 0.083241, 0.074323, -0.054999, 0.276480, 0.316614, -0.099592, -0.059458, |
| 0.156077, -0.043107, 0.035675, 0.056485, -0.022297, 0.017837, -0.001486, |
| 0.340398, 0.492016, 0.004459, 0.057972, -0.150132, -0.206617, -0.257156, |
| -0.248238, -0.080268, -0.164996, 0.352289, -0.054999, -0.056485, 0.010405, |
| -0.049053, -0.041621, -0.099592, 0.013378, -0.089187, 0.057972, -0.413234, |
| 0.217022, 0.013378, -0.080268, -0.035675, 0.035675, 0.007432, 0.002973, |
| -0.469719, 0.141213, 0.136754, 0.153105, 0.130808, -0.104052, -0.508367, |
| -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808, |
| 0.484583}; |
| |
// Bias of shape {64}: one element per filter, added to the unit outputs.
const float bias_data_16x1x1[] = {
    -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
    -0.378594, 0.178152,  0.400380,  -0.301349, -0.240913, -0.159454, -0.158757,
    -0.073665, 0.455906,  -0.061232, 0.318907,  -0.226993, -0.344644, 0.140316,
    0.559608,  0.109774,  0.437391,  0.113849,  -0.162068, 0.039572,  0.569472,
    0.460205,  0.113459,  0.370469,  0.176811,  0.203063,  -0.296975, -0.271655,
    0.059862,  -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
    0.083172,  0.029040,  -0.236288, -0.267054, -0.166627, 0.188319,  -0.271391,
    -0.222920, 0.106463,  0.263614,  0.384986,  -0.125957, -0.095890, 0.363686,
    -0.036990, -0.358884, -0.178254, 0.305596,  0.390088,  -0.189437, 0.613409,
    0.399639};
| |
// Activation state laid out as 64 filters x 8 memory slots; note the tensor
// itself is created with shape {batch_size, memory_size * num_filters}
// (i.e. {1, 512} here). These initial values must be copied into
// a mutable activation state tensor.
| const float initial_activation_state_data_16x1x1[] = { |
| -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757, |
| -1.184425, -0.462094, -1.443421, 0.230736, -0.494701, -0.354955, -2.534061, |
| -4.277471, -4.218467, 0.403711, -0.248748, -0.330111, -0.467683, 0.549047, |
| 0.733511, -0.230115, 0.793136, -1.126353, -0.984123, -0.081984, -0.222351, |
| 0.692830, 0.517060, 1.367958, 2.118860, -0.116766, -0.826365, -2.402700, |
| -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490, 0.085400, |
| -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469, |
| 1.109273, 1.693411, -0.082605, -0.069252, -1.225107, -1.330693, -1.411435, |
| 0.253406, -0.357439, -1.593415, -0.879779, -1.111136, 1.821357, 2.471952, |
| 1.236908, -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166, |
| -0.319242, 0.749349, 1.156476, 0.658670, 1.997437, 2.080663, 2.912618, |
| 2.677224, 2.642442, 2.796163, -0.272349, -0.473273, 3.120063, 2.747097, |
| 3.595510, 1.874150, 2.049919, 2.093396, -1.049959, 0.277939, -1.255541, |
| -1.052443, -1.810177, -0.883505, -0.538178, 0.524203, -1.017662, -0.269244, |
| 0.039129, -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959, |
| 0.102480, 0.349986, 0.405885, 1.287216, 0.756181, 0.319242, -0.641590, |
| -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839, |
| -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467, |
| 0.039439, 0.672023, 1.408019, 1.362679, 1.467644, 1.006171, 0.310236, |
| -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467, |
| 0.051551, 0.232600, 0.088816, 2.570395, 0.704009, 2.465120, 3.010751, |
| 2.139357, 0.630410, 1.006171, 1.545281, 1.486898, -1.162998, -2.344317, |
| -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205, |
| -0.368930, -0.162726, 0.396879, 0.505570, 0.534451, 0.554947, 1.270447, |
| 0.388805, 0.531967, -1.243119, -0.671713, -1.214859, -0.238189, 0.016459, |
| -1.164550, 0.609603, 3.293348, 2.600208, 1.454290, -1.034121, -1.760179, |
| -1.192500, -0.613951, 3.449553, 2.912618, 1.917937, 1.435968, 0.879158, |
| 1.118279, 0.102791, -0.502465, -0.239121, -0.092853, 1.786265, 1.943091, |
| 2.547104, 2.630641, 2.585302, 2.965411, -0.945615, -2.538720, -2.474126, |
| -1.088156, 0.056209, 0.864873, 0.170490, 0.457435, 0.545941, 0.752765, |
| 1.569503, 1.129459, 0.662086, -0.527929, -0.810838, -1.662978, 1.285042, |
| 1.653040, 4.130893, 2.961995, 4.147041, 3.256393, 3.881524, 2.522571, |
| -0.875431, -1.112378, 2.105817, 2.180970, 3.121926, 1.577577, 1.639376, |
| 2.906407, -0.142230, 0.421101, 2.212335, 2.311399, 3.993321, 3.651719, |
| 4.206666, 4.678387, -1.304917, -1.130701, -2.543067, -2.500212, -2.197118, |
| -1.197158, -0.949652, -0.282908, 0.320795, -1.543728, 1.290322, 1.788128, |
| 3.957297, 3.205774, 2.892432, 2.297114, 0.138814, -0.139435, 0.936920, |
| 0.344707, 0.723263, -1.772290, -3.138385, -2.287177, -2.405806, -1.859864, |
| -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342, |
| 3.858543, 2.499901, 1.087535, 0.290051, -0.026086, -0.880400, -2.602692, |
| -1.404292, 0.253096, -0.665502, -1.443421, -0.925119, -0.096580, 1.115484, |
| 1.846200, -1.604284, -1.244671, -0.464888, 0.326385, 0.168006, -0.262723, |
| -0.744691, 0.953379, -0.407127, -0.349986, -1.154302, 0.831023, 1.590931, |
| 2.538720, 2.063583, 3.697680, -0.752455, -1.293117, -1.330693, -1.869802, |
| -0.592523, 0.631652, 1.198089, -0.481347, 3.738983, 4.153252, 2.782499, |
| 2.244321, 0.709289, 1.650245, 1.700865, 0.385078, 2.192460, 2.610456, |
| 4.009780, 3.492719, 2.574743, 2.116687, 1.856138, 1.205853, 2.722563, |
| 4.075305, 5.415935, 3.009198, 2.715421, 1.571056, 0.897170, -2.430339, |
| 0.749970, 0.425760, -0.302783, 0.817359, 1.031636, 1.913589, 2.686229, |
| 1.631923, -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843, |
| -1.805519, -0.486627, 0.909591, 2.082837, -1.473855, -2.456735, -3.851401, |
| -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984, |
| -0.245642, -0.588486, 0.033850, 2.084700, 0.076084, 0.690035, 0.747797, |
| 0.594697, -1.016109, -1.348083, -1.201195, -1.088466, 2.045571, 2.460772, |
| 0.717984, 0.041613, -0.721711, 1.134738, 2.322269, 1.112378, -0.307441, |
| -0.581033, -0.868599, -0.018633, 0.856488, 0.919839, 0.303094, -0.433213, |
| 0.811148, -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038, |
| -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205, 0.622646, |
| 0.145957, 0.359303, 1.012072, 0.774814, -0.400295, -1.484103, -2.007374, |
| -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451, 0.087264, |
| -0.227320, -1.211753, -1.532859, -1.688753, 0.065215, 0.134777, 0.608051, |
| -0.393152, -0.214588, -0.635689, -1.499320, 0.069562, -1.555839, -2.633126, |
| -2.966032, -1.550870, -0.101549, 0.874189, 0.436318, 0.299367, 2.289972, |
| 2.339659, 2.602071, 1.564535, 0.019254, -0.583207, -1.295912, -2.424749, |
| -1.221070, -1.175109, -0.577306, -0.102791, 1.877876, 2.568222, 2.173827, |
| 3.131243, 2.637784, 2.088737, 3.679047, 3.218506, 2.483442, 1.650556, |
| 1.363611, -0.027328, 1.486898, -0.721711, -3.684327, -3.006093, -3.777491, |
| -2.327548, -2.737470, -4.549510, -0.060867, 0.127635, 0.680408, 0.581344, |
| 0.320174, -0.403090, -0.838166, 0.293777, -0.995613, -0.165521, -0.419859, |
| 1.110515, 1.203679, 1.749931, 2.467294, 4.276539, 0.031055, -0.967664, |
| 1.167035, 1.865144, 3.221923, 3.248630, 4.121266, 4.187723, 0.749039, |
| -1.571056, 0.785994, 1.568572, 3.759479, 3.588678, 4.116608, 3.864444, |
| -0.290051, -0.271107, 0.375140, 0.537556, 0.536314, 0.095959, 0.054656, |
| 0.088816}; |
| |
| // One output with shape {1, 64} |
| const float golden_output_16x1x1[] = { |
| -0.087914, 1.145864, -0.418088, -1.556392, -0.925298, 0.205252, 0.289119, |
| 1.331180, -0.218010, 0.963057, -2.225886, 1.248478, 1.448983, 0.355467, |
| 1.682174, 0.803739, 0.449738, 0.543566, 1.916269, -2.975136, 0.222774, |
| 0.241589, -0.104216, 1.561748, 0.936818, -0.089907, -0.520117, -0.870353, |
| 1.606074, 0.895770, 0.521297, -0.369994, -0.889351, -2.809309, 2.404628, |
| 1.069754, -0.195456, -1.105652, 1.272715, -1.233177, 1.271416, -1.691805, |
| -1.058125, -0.716227, 0.052540, 1.262483, 0.540555, 1.735760, -0.539197, |
| -0.014367, -0.243002, 1.072254, 0.528985, -0.731151, -1.262649, 2.338702, |
| -0.603093, 0.970736, -3.567897, 0.035085, -0.201711, -0.550400, 1.545573, |
| -1.805005}; |
| |
// One output with shape {1, 64}: golden_output_16x1x1 with kTfLiteActRelu
// applied (all negative values clamped to zero).
| const float golden_output_relu_16x1x1[] = { |
| 0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119, |
| 1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467, |
| 1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774, |
| 0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000, |
| 1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628, |
| 1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000, |
| 0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000, |
| 0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702, |
| 0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573, |
| 0.000000}; |
| |
| template <typename T> |
| void ValidateSVDFGoldens(const int batch_size, const int num_units, |
| const int input_size, const int rank, |
| TfLiteTensor* tensors, const int tensor_count, |
| TfLiteFusedActivation activation, |
| const T* input_sequences_data, |
| const int input_sequences_len, T* output_data, |
| const T* expected_output, float tolerance = 1e-5f) { |
| TfLiteSVDFParams params; |
| params.rank = rank; |
| params.activation = activation; |
| |
| int inputs_array_data[] = {5, 0, 1, 2, 3, 4}; |
| TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); |
| |
| int outputs_array_data[] = {1, 5}; |
| TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); |
| |
| const TFLMRegistration registration = Register_SVDF(); |
| micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array, |
| outputs_array, ¶ms); |
| |
| TfLiteStatus init_and_prepare_status = runner.InitAndPrepare(); |
| TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status); |
| TF_LITE_MICRO_EXPECT(runner.ValidateTempBufferDeallocated()); |
| |
| // Abort early to make it clear init and prepare failed. |
| if (init_and_prepare_status != kTfLiteOk) { |
| return; |
| } |
| |
| int num_inputs = input_sequences_len / (input_size * batch_size); |
| |
| for (int i = 0; i < num_inputs; ++i) { |
| const T* input_batch_start = |
| input_sequences_data + i * input_size * batch_size; |
| |
| memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes); |
| TfLiteStatus status = runner.Invoke(); |
| TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status); |
| |
| // Only validate outputs when invoke has succeeded. |
| if (status == kTfLiteOk) { |
| int output_idx = 0; |
| int golden_idx = i * batch_size * num_units; |
| for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) { |
| TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx], |
| tolerance); |
| output_idx++; |
| } |
| } |
| } |
| TF_LITE_MICRO_EXPECT(runner.ValidateTempBufferDeallocated()); |
| } |
| |
| void TestSVDF(const int batch_size, const int num_units, const int input_size, |
| const int memory_size, const int rank, |
| TfLiteFusedActivation activation, float* input_data, |
| const float* feature_weights_data, const float* time_weights_data, |
| float* activation_state_data, const float* bias_data, |
| float* scratch_data, float* output_data, |
| const float* input_sequences_data, int input_sequences_len, |
| const float* expected_output, float tolerance = 1e-5f) { |
| const int num_filters = num_units * rank; |
| |
| int input_dims_arg[] = {2, batch_size, input_size}; |
| TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg); |
| |
| int feature_weights_dims_args[] = {2, num_filters, input_size}; |
| TfLiteIntArray* feature_weights_dims = |
| IntArrayFromInts(feature_weights_dims_args); |
| |
| int time_weights_dims_args[] = {2, num_filters, memory_size}; |
| TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args); |
| |
| int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters}; |
| TfLiteIntArray* activation_state_dims = |
| IntArrayFromInts(activation_state_dims_args); |
| |
| int bias_dims_args[] = {1, num_units}; |
| TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_args); |
| |
| int output_dims_args[] = {2, batch_size, num_units}; |
| TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args); |
| |
| const int tensor_count = 6; // 5 inputs, 1 output |
| TfLiteTensor tensors[] = { |
| CreateTensor(input_data, input_dims), |
| CreateTensor(feature_weights_data, feature_weights_dims), |
| CreateTensor(time_weights_data, time_weights_dims), |
| CreateTensor(bias_data, bias_dims), |
| CreateTensor(activation_state_data, activation_state_dims, |
| /*is_variable=*/true), |
| CreateTensor(output_data, output_dims), |
| }; |
| |
| ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors, |
| tensor_count, activation, input_sequences_data, |
| input_sequences_len, output_data, expected_output, |
| tolerance); |
| } |
| |
| // The pattern to this method's arguments is: |
| // <kernel metadata> |
| // for each tensor in |
| // {input, feature weights, time weights, bias, activation state, output}: |
| // <tensor float values> <tensor quantized buffer> <tensor quantization data> |
| // |
| // Template parameter sets type of both time_weights and activation_state. |
| template <typename T> |
| inline void TestIntegerSVDF( |
| const int batch_size, const int num_units, const int input_size, |
| const int memory_size, const int rank, TfLiteFusedActivation activation, |
| int8_t* input_quantized, float input_scale, int input_zero_point, |
| const float* feature_weights_data, int8_t* feature_weights_quantized, |
| const float feature_weights_scale, const float* time_weights_data, |
| T* time_weights_quantized, float time_weights_scale, const float* bias_data, |
| int32_t* bias_quantized, const float* initial_activation_state_data, |
| T* activation_state_quantized, float activation_state_scale, |
| int activation_state_zero_point, int8_t* output_data, float output_scale, |
| int output_zero_point, const float* input_sequences_data, |
| int8_t* input_sequences_quantized, const int input_sequences_len, |
| const float* golden_output, int8_t* golden_output_quantized, |
| int golden_output_len) { |
| const int num_filters = num_units * rank; |
| |
| int input_dims_arg[] = {2, batch_size, input_size}; |
| TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg); |
| |
| int feature_weights_dims_args[] = {2, num_filters, input_size}; |
| TfLiteIntArray* feature_weights_dims = |
| IntArrayFromInts(feature_weights_dims_args); |
| |
| int time_weights_dims_args[] = {2, num_filters, memory_size}; |
| TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args); |
| |
| int bias_dims_data[] = {1, num_units}; |
| TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); |
| |
| int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters}; |
| TfLiteIntArray* activation_state_dims = |
| IntArrayFromInts(activation_state_dims_args); |
| |
| int output_dims_args[] = {2, batch_size, num_units}; |
| TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args); |
| |
| const int tensor_count = 6; // 5 inputs, 1 output |
| |
| TfLiteTensor tensors[] = { |
| CreateQuantizedTensor(input_quantized, input_dims, input_scale, |
| input_zero_point), |
| CreateQuantizedTensor(feature_weights_data, feature_weights_quantized, |
| feature_weights_dims, feature_weights_scale, 0), |
| CreateQuantizedTensor(time_weights_data, time_weights_quantized, |
| time_weights_dims, time_weights_scale, 0), |
| CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims, |
| time_weights_scale, activation_state_scale), |
| CreateQuantizedTensor(initial_activation_state_data, |
| activation_state_quantized, activation_state_dims, |
| activation_state_scale, 0, |
| /*is_variable=*/true), |
| CreateQuantizedTensor(output_data, output_dims, output_scale, |
| output_zero_point)}; |
| |
| tflite::Quantize(golden_output, golden_output_quantized, golden_output_len, |
| output_scale, output_zero_point); |
| tflite::Quantize(input_sequences_data, input_sequences_quantized, |
| input_sequences_len, input_scale, input_zero_point); |
| |
| ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors, |
| tensor_count, activation, input_sequences_quantized, |
| input_sequences_len, output_data, golden_output_quantized, |
| /*tolerance*/ 1); |
| } |
| |
| // Template parameter sets type of both time_weights and activation_state. |
| template <typename T> |
| void SvdfQuantized2x2Input2x4OutputShouldMatchGolden() { |
| constexpr int batch_size = 2; |
| constexpr int num_units = 4; |
| constexpr int input_size = 2; |
| constexpr int memory_size = 10; |
| constexpr int rank = 2; |
| constexpr int num_filters = num_units * rank; |
| |
| const int input_size_dims_count = batch_size * input_size; |
| |
| const int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| |
| const int output_dims_count = batch_size * num_units; |
| int8_t output_data[output_dims_count]; |
| |
| float input_scale = 2.5 / std::numeric_limits<int8_t>::max(); |
| float feature_weights_scale = 1.0 / std::numeric_limits<int8_t>::max(); |
| float time_weights_scale = 1.0 / std::numeric_limits<T>::max(); |
| float activation_state_scale = 1.49 / std::numeric_limits<T>::max(); |
| float output_scale = 1.0 / std::numeric_limits<int8_t>::max(); |
| |
| int input_zero_point = 0; |
| int output_zero_point = 0; |
| |
| int8_t input_quantized[input_size_dims_count]; |
| int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_2x2x10) / |
| sizeof(float)]; |
| int8_t feature_weights_quantized |
| [sizeof(tflite::testing::feature_weights_data_2x2x10) / sizeof(float)]; |
| T time_weights_quantized[sizeof(tflite::testing::time_weights_data_2x2x10) / |
| sizeof(float)]; |
| T activation_state_quantized[activation_state_dims_count]; |
| int32_t |
| bias_quantized[sizeof(tflite::testing::bias_data_2x2x10) / sizeof(float)]; |
| int8_t golden_quantized[sizeof(tflite::testing::golden_output_2x2x10) / |
| sizeof(float)]; |
| |
| tflite::testing::TestIntegerSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu, |
| input_quantized, input_scale, input_zero_point, |
| tflite::testing::feature_weights_data_2x2x10, feature_weights_quantized, |
| feature_weights_scale, tflite::testing::time_weights_data_2x2x10, |
| time_weights_quantized, time_weights_scale, |
| tflite::testing::bias_data_2x2x10, bias_quantized, |
| tflite::testing::initial_activation_state_data_2x2x10, |
| activation_state_quantized, activation_state_scale, 0, output_data, |
| output_scale, output_zero_point, tflite::testing::input_data_2x2x10, |
| input_sequences_quantized, |
| sizeof(tflite::testing::input_data_2x2x10) / sizeof(float), |
| tflite::testing::golden_output_2x2x10, golden_quantized, |
| sizeof(tflite::testing::golden_output_2x2x10) / sizeof(float)); |
| } |
| |
| // Template parameter sets type of both time_weights and activation_state. |
| template <typename T> |
| void SvdfQuantized1x16Input64x1OutputShouldMatchGolden() { |
| constexpr int batch_size = 1; |
| constexpr int num_units = 64; |
| constexpr int input_size = 16; |
| constexpr int memory_size = 8; |
| constexpr int rank = 1; |
| constexpr int num_filters = num_units * rank; |
| constexpr int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| constexpr int output_dims_count = batch_size * num_units; |
| constexpr int input_dims_count = batch_size * input_size; |
| |
| int8_t output_data[output_dims_count]; |
| |
| float input_scale = 0.10075444; |
| float feature_weights_scale = 0.00649388; |
| float time_weights_scale = tflite::testing::ScaleFromMinMax<T>(-.81, .81); |
| float activation_state_scale = |
| tflite::testing::ScaleFromMinMax<T>(-17.73, 17.73); |
| int activation_state_zero_point = |
| tflite::testing::ZeroPointFromMinMax<T>(-17.73, 17.73); |
| float output_scale = 0.051445257; |
| |
| int input_zero_point = 2; |
| int output_zero_point = 0; |
| |
| int8_t input_quantized[input_dims_count]; |
| int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) / |
| sizeof(float)]; |
| int8_t feature_weights_quantized |
| [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)]; |
| T time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) / |
| sizeof(float)]; |
| T activation_state_quantized[activation_state_dims_count]; |
| int32_t |
| bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)]; |
| int8_t golden_quantized[sizeof(tflite::testing::golden_output_16x1x1) / |
| sizeof(float)]; |
| |
| tflite::testing::TestIntegerSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone, |
| input_quantized, input_scale, input_zero_point, |
| tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized, |
| feature_weights_scale, tflite::testing::time_weights_data_16x1x1, |
| time_weights_quantized, time_weights_scale, |
| tflite::testing::bias_data_16x1x1, bias_quantized, |
| tflite::testing::initial_activation_state_data_16x1x1, |
| activation_state_quantized, activation_state_scale, |
| activation_state_zero_point, output_data, output_scale, output_zero_point, |
| tflite::testing::input_data_16x1x1, input_sequences_quantized, |
| sizeof(tflite::testing::input_data_16x1x1) / sizeof(float), |
| tflite::testing::golden_output_16x1x1, golden_quantized, |
| sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float)); |
| } |
| |
| template <typename T> |
| void SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden() { |
| constexpr int batch_size = 1; |
| constexpr int num_units = 64; |
| constexpr int input_size = 16; |
| constexpr int memory_size = 8; |
| constexpr int rank = 1; |
| constexpr int num_filters = num_units * rank; |
| constexpr int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| constexpr int output_dims_count = batch_size * num_units; |
| constexpr int input_dims_count = batch_size * input_size; |
| |
| int8_t output_data[output_dims_count]; |
| |
| float input_scale = 0.10075444; |
| float feature_weights_scale = 0.00649388; |
| float time_weights_scale = tflite::testing::ScaleFromMinMax<T>(-.81, .81); |
| float activation_state_scale = |
| tflite::testing::ScaleFromMinMax<T>(-17.73, 17.73); |
| int activation_state_zero_point = |
| tflite::testing::ZeroPointFromMinMax<T>(-17.73, 17.73); |
| float output_scale = 0.051445257; |
| |
| int input_zero_point = 2; |
| int output_zero_point = -128; |
| |
| int8_t input_quantized[input_dims_count]; |
| int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) / |
| sizeof(float)]; |
| int8_t feature_weights_quantized |
| [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)]; |
| T time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) / |
| sizeof(float)]; |
| T activation_state_quantized[activation_state_dims_count]; |
| int32_t |
| bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)]; |
| int8_t golden_quantized[sizeof(tflite::testing::golden_output_relu_16x1x1) / |
| sizeof(float)]; |
| |
| tflite::testing::TestIntegerSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu, |
| input_quantized, input_scale, input_zero_point, |
| tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized, |
| feature_weights_scale, tflite::testing::time_weights_data_16x1x1, |
| time_weights_quantized, time_weights_scale, |
| tflite::testing::bias_data_16x1x1, bias_quantized, |
| tflite::testing::initial_activation_state_data_16x1x1, |
| activation_state_quantized, activation_state_scale, |
| activation_state_zero_point, output_data, output_scale, output_zero_point, |
| tflite::testing::input_data_16x1x1, input_sequences_quantized, |
| sizeof(tflite::testing::input_data_16x1x1) / sizeof(float), |
| tflite::testing::golden_output_relu_16x1x1, golden_quantized, |
| sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float)); |
| } |
| |
| } // namespace |
| } // namespace testing |
| } // namespace tflite |
| |
| TF_LITE_MICRO_TESTS_BEGIN |
| |
| TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) { |
| constexpr int batch_size = 2; |
| constexpr int num_units = 4; |
| constexpr int input_size = 2; |
| constexpr int memory_size = 10; |
| constexpr int rank = 2; |
| constexpr int num_filters = num_units * rank; |
| |
| const int input_size_dims_count = batch_size * input_size; |
| float input_data[input_size_dims_count]; |
| |
| const int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| float activation_state_data[activation_state_dims_count]; |
| |
| memcpy(activation_state_data, |
| tflite::testing::initial_activation_state_data_2x2x10, |
| sizeof(tflite::testing::initial_activation_state_data_2x2x10)); |
| |
| const int scratch_dims_count = batch_size * num_filters; |
| float scratch_data[scratch_dims_count]; |
| |
| const int output_dims_count = batch_size * num_units; |
| float output_data[output_dims_count]; |
| |
| tflite::testing::TestSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone, |
| input_data, tflite::testing::feature_weights_data_2x2x10, |
| tflite::testing::time_weights_data_2x2x10, activation_state_data, |
| tflite::testing::bias_data_2x2x10, scratch_data, output_data, |
| tflite::testing::input_data_2x2x10, |
| sizeof(tflite::testing::input_data_2x2x10) / sizeof(float), |
| tflite::testing::golden_output_2x2x10); |
| } |
| |
| // Only reference kernels support full int8 svdf currently. |
| #if !defined(HEXAGON) |
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGoldenInt8) {
  // int8_t time weights / activation state; see the templated helper above.
  tflite::testing::SvdfQuantized2x2Input2x4OutputShouldMatchGolden<int8_t>();
}
| #endif |
| |
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGoldenInt16) {
  // int16_t time weights / activation state; see the templated helper above.
  tflite::testing::SvdfQuantized2x2Input2x4OutputShouldMatchGolden<int16_t>();
}
| |
| TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) { |
| constexpr int batch_size = 1; |
| constexpr int num_units = 64; |
| constexpr int input_size = 16; |
| constexpr int memory_size = 8; |
| constexpr int rank = 1; |
| constexpr int num_filters = num_units * rank; |
| constexpr int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| constexpr int output_dims_count = batch_size * num_units; |
| constexpr int input_dims_count = batch_size * input_size; |
| |
| float input_data[input_dims_count]; |
| float output_data[output_dims_count]; |
| float scratch_buffer[batch_size * num_filters]; |
| float activation_state_data_mutable[activation_state_dims_count]; |
| |
| // Initialize activation state to starting values. |
| memcpy(activation_state_data_mutable, |
| tflite::testing::initial_activation_state_data_16x1x1, |
| sizeof(tflite::testing::initial_activation_state_data_16x1x1)); |
| |
| tflite::testing::TestSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone, |
| input_data, tflite::testing::feature_weights_data_16x1x1, |
| tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable, |
| tflite::testing::bias_data_16x1x1, scratch_buffer, output_data, |
| tflite::testing::input_data_16x1x1, input_size, |
| tflite::testing::golden_output_16x1x1); |
| } |
| |
| TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) { |
| constexpr int batch_size = 1; |
| constexpr int num_units = 64; |
| constexpr int input_size = 16; |
| constexpr int memory_size = 8; |
| constexpr int rank = 1; |
| constexpr int num_filters = num_units * rank; |
| constexpr int activation_state_dims_count = |
| batch_size * memory_size * num_filters; |
| constexpr int output_dims_count = batch_size * num_units; |
| constexpr int input_dims_count = batch_size * input_size; |
| |
| float input_data[input_dims_count]; |
| float output_data[output_dims_count]; |
| float scratch_buffer[batch_size * num_filters]; |
| float activation_state_data_mutable[activation_state_dims_count]; |
| |
| // Initialize activation state to starting values. |
| memcpy(activation_state_data_mutable, |
| tflite::testing::initial_activation_state_data_16x1x1, |
| sizeof(tflite::testing::initial_activation_state_data_16x1x1)); |
| |
| tflite::testing::TestSVDF( |
| batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu, |
| input_data, tflite::testing::feature_weights_data_16x1x1, |
| tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable, |
| tflite::testing::bias_data_16x1x1, scratch_buffer, output_data, |
| tflite::testing::input_data_16x1x1, input_size, |
| tflite::testing::golden_output_relu_16x1x1); |
| } |
| |
| // Only reference kernels support full int8 svdf currently. |
| #if !defined(HEXAGON) |
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGoldenInt8) {
  // int8_t time weights / activation state; see the templated helper above.
  tflite::testing::SvdfQuantized1x16Input64x1OutputShouldMatchGolden<int8_t>();
}
| #endif |
| |
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGoldenInt16) {
  // int16_t time weights / activation state; see the templated helper above.
  tflite::testing::SvdfQuantized1x16Input64x1OutputShouldMatchGolden<int16_t>();
}
| |
| // Only reference kernels support full int8 svdf currently. |
| #if !defined(HEXAGON) |
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGoldenInt8) {
  // int8_t time weights / activation state, ReLU output activation.
  tflite::testing::SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden<
      int8_t>();
}
| #endif |
| |
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGoldenInt16) {
  // int16_t time weights / activation state, ReLU output activation.
  tflite::testing::SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden<
      int16_t>();
}
| |
| TF_LITE_MICRO_TESTS_END |