blob: ae8fadc11398d5bd4a0ea2b6952215064bf25ad8 [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
// Naming convention for the test vectors below:
//   <tensor name>_<input size>x<batch size>x<batch count>
// 10 input batches, each of shape {batch_size, input_size} = {2, 2}.
const float input_data_2x2x10[] = {
0.12609188, -0.46347019, 0.35867718, 0.36897406,
0.14278367, -1.64410412, -0.57290924, 0.12729003,
0.49837467, 0.19278903, 0.17660543, 0.52949083,
-0.11186574, 0.13164264, -0.72674477, -0.5683046,
-0.68892461, 0.37783599, -0.63690937, 0.44483393,
-0.81299269, -0.86831826, -0.95760226, 1.82078898,
-1.45006323, -0.82251364, -1.65087092, -1.89238167,
0.03966608, -0.24936394, 2.06740379, -1.51439476,
0.11771342, -0.23761693, 0.31088525, -1.55601168,
-0.89477462, 1.67204106, -0.6230064, 0.29819036,
};
// Feature filter of shape {num_filters, input_size} = {8, 2}.
const float feature_weights_data_2x2x10[] = {
-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322,
0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296,
-0.36118156, -0.0976817, -0.36916667, 0.22197971};
// Time filter of shape {num_filters, memory_size} = {8, 10}.
const float time_weights_data_2x2x10[] = {
-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
-0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
-0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
-0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
-0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
-0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
-0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763};
// Initial activation state of shape {batch_size, memory_size * num_filters} =
// {2, 10 * 8} = {2, 80}; every value starts at zero. The kernel is given the
// state as a variable tensor, so these initial values must be copied into a
// mutable activation state buffer rather than passed directly.
const float initial_activation_state_data_2x2x10[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
// Bias values (8 zeros).
// NOTE(review): TestSVDF shapes the bias tensor as {num_units}. The golden
// output for this case is {2, 4}, so num_units is presumably 4 and only the
// first 4 entries would be read. Confirm against the call site before
// relying on the extra entries.
const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
// Expected results: 10 outputs (one per input batch), each of shape
// {batch_size, num_units} = {2, 4}.
const float golden_output_2x2x10[] = {
-0.044205, -0.013757, 0.050369, -0.018447,
0.073010, 0.025142, -0.021154, 0.013551,
-0.209613, -0.062421, 0.150209, -0.108334,
0.028256, -0.006950, -0.030885, 0.009603,
-0.076800, -0.037075, -0.087198, -0.155183,
0.091069, 0.098446, -0.016083, 0.106475,
-0.082123, -0.162238, -0.084434, -0.141074,
-0.029340, -0.090685, 0.053302, -0.030604,
-0.201440, 0.088424, 0.139877, 0.012416,
-0.113212, 0.103893, -0.100842, 0.122780,
-0.166632, -0.116705, 0.175298, -0.047163,
0.313077, -0.166485, -0.285860, 0.129069,
-0.625911, 0.046134, 0.138081, -0.129581,
-0.521455, -0.061579, 0.230289, 0.114963,
-0.216693, -0.161643, -0.179177, -0.052599,
-0.213239, 0.029502, 0.260858, 0.275045,
-0.213689, -0.323608, -0.285635, -0.317687,
-0.324092, -0.317972, -0.208450, -0.462504,
-0.255126, -0.218576, -0.041528, 0.179421,
-0.440583, 0.072127, -0.284136, 0.241570};
// Simulated real-world inputs, weights and expected outputs.
// A single input batch of shape {batch_size, input_size} = {1, 16}.
const float input_data_16x1x1[] = {
-0.488494, 2.023762, -2.233117, -0.488494, 3.559030, 9.490748,
-3.210106, -1.953977, -0.279140, 0.907204, 1.674838, 0.000000,
-0.279140, -0.628064, -0.069785, -0.628064,
};
// Feature filter of shape {num_filters, input_size} = {64, 16}.
const float feature_weights_data_16x1x1[] = {
0.173588, 0.173588, -0.024798, 0.193426, -0.099193, 0.044637, 0.183507,
0.183507, 0.044637, 0.198386, -0.069435, 0.084314, 0.312458, 0.024798,
0.173588, -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
-0.128951, 0.193426, 0.357095, -0.317418, -0.119032, -0.218225, -0.004960,
-0.386853, -0.133911, 0.252942, -0.019839, -0.024798, -0.054556, -0.069435,
-0.128951, 0.029758, -0.099193, -0.312458, -0.029758, 0.064475, 0.183507,
0.114072, -0.178547, -0.247982, -0.119032, 0.243023, -0.119032, -0.034718,
-0.178547, 0.019839, 0.128951, -0.223184, -0.009919, -0.213265, 0.168628,
-0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
-0.267821, -0.580279, -0.099193, 0.213265, 0.119032, -0.119032, -0.178547,
0.610037, 0.109112, 0.049596, -0.014879, -0.049596, -0.193426, 0.039677,
-0.148789, -0.114072, -0.158709, -0.158709, 0.094233, 0.099193, -0.114072,
0.104153, -0.123991, 0.198386, -0.173588, 0.089274, -0.247982, -0.054556,
0.123991, 0.183507, 0.114072, 0.188467, 0.302539, 0.044637, 0.039677,
-0.099193, 0.168628, -0.024798, -0.054556, -0.109112, 0.014879, -0.009919,
0.069435, -0.396772, -0.287660, -0.079354, -0.104153, 0.054556, 0.089274,
-0.099193, 0.114072, 0.034718, 0.119032, 0.282700, -0.119032, -0.505884,
-0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749, 0.128951,
0.143830, -0.188467, -0.183507, 0.104153, -0.024798, 0.193426, -0.287660,
0.168628, -0.009919, 0.119032, -0.024798, -0.099193, -0.203346, 0.099193,
0.084314, -0.168628, 0.123991, -0.148789, 0.114072, -0.029758, 0.228144,
-0.238063, 0.089274, -0.064475, 0.307498, -0.188467, -0.004960, -0.252942,
-0.173588, -0.158709, -0.044637, -0.009919, 0.312458, -0.262861, 0.059516,
0.158709, 0.069435, -0.282700, 0.074395, -0.322377, -0.183507, -0.123991,
-0.233104, 0.009919, 0.252942, -0.243023, 0.555481, -0.099193, -0.119032,
-0.441409, 0.148789, 0.084314, -0.168628, -0.183507, 0.188467, 0.024798,
-0.302539, 0.223184, 0.143830, -0.193426, -0.054556, -0.218225, -0.297579,
0.104153, 0.272781, -0.034718, 0.114072, -0.059516, 0.044637, 0.342216,
0.421570, 0.138870, -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
-0.128951, -0.044637, -0.173588, 0.302539, 0.079354, 0.049596, 0.133911,
-0.029758, -0.312458, -0.029758, 0.079354, 0.128951, 0.252942, 0.213265,
0.014879, 0.287660, 0.178547, 0.297579, 0.352135, 0.401732, 0.024798,
-0.277740, -0.411651, -0.069435, 0.342216, -0.158709, -0.104153, -0.009919,
0.223184, 0.228144, -0.019839, 0.059516, -0.104153, -0.510844, 0.029758,
-0.406691, 0.089274, 0.421570, 0.163668, -0.143830, -0.019839, -0.039677,
0.104153, -0.044637, -0.128951, 0.203346, 0.079354, -0.069435, 0.094233,
-0.138870, 0.466207, -0.163668, 0.049596, 0.029758, 0.267821, 0.029758,
-0.049596, 0.009919, 0.004960, -0.099193, 0.094233, -0.262861, 0.089274,
-0.302539, 0.332297, -0.307498, -0.014879, 0.168628, -0.094233, -0.272781,
0.034718, -0.133911, -0.228144, 0.094233, 0.257902, -0.228144, 0.153749,
-0.054556, -0.252942, 0.054556, 0.218225, -0.054556, 0.302539, 0.282700,
0.054556, -0.044637, -0.133911, 0.233104, -0.049596, 0.411651, 0.044637,
-0.297579, -0.029758, -0.114072, 0.114072, -0.580279, 0.079354, -0.024798,
-0.347175, -0.128951, -0.099193, 0.238063, -0.104153, -0.009919, 0.158709,
-0.034718, 0.123991, -0.163668, 0.059516, 0.342216, 0.009919, 0.064475,
-0.307498, -0.520763, -0.238063, 0.163668, 0.362054, 0.034718, -0.178547,
-0.104153, -0.257902, 0.322377, 0.054556, 0.148789, -0.178547, 0.084314,
0.004960, 0.257902, 0.029758, 0.079354, -0.223184, -0.193426, 0.282700,
0.000000, -0.019839, -0.114072, 0.491005, -0.193426, -0.029758, -0.243023,
0.009919, 0.089274, -0.277740, -0.089274, 0.104153, 0.337256, 0.138870,
-0.307498, -0.054556, 0.352135, 0.133911, -0.044637, 0.133911, -0.089274,
-0.357095, -0.272781, 0.069435, 0.059516, -0.109112, 0.148789, -0.044637,
-0.019839, -0.153749, 0.123991, -0.223184, 0.322377, 0.074395, -0.312458,
0.024798, -0.223184, 0.109112, -0.138870, 0.218225, -0.074395, -0.406691,
0.009919, -0.198386, -0.009919, 0.416611, 0.178547, 0.148789, 0.133911,
-0.004960, 0.069435, -0.054556, -0.044637, 0.297579, 0.059516, -0.456288,
-0.148789, -0.004960, 0.054556, 0.094233, -0.104153, 0.198386, -0.302539,
0.133911, 0.411651, 0.054556, 0.525723, -0.089274, 0.079354, 0.238063,
0.079354, -0.039677, 0.039677, 0.029758, 0.332297, -0.014879, -0.367014,
-0.143830, -0.123991, -0.064475, 0.014879, 0.173588, -0.168628, 0.386853,
0.009919, 0.173588, 0.163668, 0.123991, 0.163668, 0.198386, 0.203346,
-0.401732, -0.009919, 0.272781, -0.173588, 0.044637, 0.238063, 0.133911,
0.049596, 0.208305, -0.024798, 0.049596, -0.049596, 0.034718, -0.446368,
0.466207, -0.089274, -0.099193, -0.128951, -0.228144, 0.014879, -0.252942,
0.074395, -0.223184, -0.168628, -0.292619, 0.178547, 0.153749, -0.014879,
0.054556, 0.000000, 0.193426, 0.158709, 0.178547, -0.327337, -0.138870,
-0.114072, 0.168628, 0.297579, -0.109112, -0.029758, -0.029758, -0.416611,
0.059516, 0.000000, -0.168628, -0.322377, 0.238063, -0.128951, -0.029758,
0.500925, 0.292619, 0.123991, -0.099193, 0.074395, 0.317418, -0.148789,
0.064475, -0.104153, -0.044637, -0.094233, 0.188467, -0.044637, 0.213265,
-0.233104, -0.049596, 0.004960, -0.198386, 0.287660, -0.148789, -0.257902,
0.004960, -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
0.029758, -0.019839, -0.009919, -0.143830, -0.158709, 0.158709, -0.243023,
-0.039677, -0.297579, 0.069435, 0.049596, 0.302539, 0.059516, 0.074395,
-0.019839, 0.352135, -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
0.252942, -0.228144, -0.049596, 0.173588, 0.173588, -0.074395, -0.034718,
-0.292619, 0.362054, 0.183507, 0.243023, -0.203346, -0.044637, 0.054556,
0.059516, -0.158709, -0.158709, 0.000000, 0.327337, 0.119032, 0.034718,
-0.044637, -0.089274, 0.089274, -0.233104, 0.000000, -0.317418, 0.371974,
0.213265, 0.307498, -0.178547, -0.367014, 0.039677, -0.059516, 0.168628,
-0.014879, 0.143830, 0.123991, -0.084314, -0.332297, -0.416611, 0.183507,
0.109112, -0.039677, 0.014879, 0.292619, -0.213265, -0.054556, 0.004960,
0.123991, 0.119032, 0.000000, -0.332297, -0.312458, -0.198386, -0.213265,
0.119032, 0.322377, 0.168628, 0.104153, -0.262861, 0.327337, -0.049596,
-0.228144, -0.074395, 0.168628, 0.123991, 0.396772, 0.044637, 0.322377,
0.193426, 0.267821, -0.178547, 0.297579, 0.148789, -0.218225, -0.138870,
0.044637, 0.049596, 0.133911, 0.064475, 0.069435, 0.064475, -0.158709,
-0.044637, -0.173588, 0.267821, 0.327337, 0.079354, -0.228144, 0.029758,
0.014879, 0.198386, -0.109112, -0.133911, 0.431490, 0.099193, 0.421570,
0.233104, -0.054556, 0.054556, -0.317418, -0.133911, -0.123991, -0.287660,
0.342216, -0.049596, -0.153749, 0.228144, -0.213265, 0.262861, 0.406691,
-0.084314, -0.004960, 0.193426, 0.188467, -0.099193, -0.223184, 0.163668,
-0.257902, -0.153749, 0.441409, 0.099193, 0.128951, -0.089274, -0.208305,
-0.009919, -0.004960, -0.109112, 0.024798, -0.119032, 0.019839, 0.391812,
-0.024798, 0.198386, 0.327337, -0.505884, -0.099193, 0.510844, -0.148789,
0.094233, -0.153749, -0.039677, 0.352135, 0.272781, -0.228144, -0.287660,
-0.272781, 0.148789, 0.277740, 0.074395, 0.109112, -0.064475, 0.044637,
0.074395, -0.292619, 0.153749, -0.064475, -0.114072, 0.198386, -0.039677,
-0.128951, -0.004960, 0.257902, -0.228144, -0.094233, 0.064475, 0.014879,
0.188467, -0.416611, 0.099193, 0.362054, -0.208305, 0.198386, -0.079354,
0.009919, 0.119032, 0.332297, 0.243023, -0.168628, 0.158709, 0.039677,
0.143830, 0.277740, -0.168628, 0.009919, 0.099193, -0.004960, -0.257902,
-0.297579, 0.208305, -0.104153, 0.119032, 0.247982, 0.381893, -0.223184,
-0.367014, -0.327337, -0.168628, -0.094233, 0.208305, -0.019839, 0.183507,
0.084314, 0.133911, 0.109112, -0.148789, -0.183507, -0.411651, -0.024798,
-0.114072, -0.029758, -0.009919, 0.173588, -0.059516, -0.049596, 0.039677,
0.317418, 0.138870, -0.247982, -0.084314, 0.158709, 0.054556, -0.084314,
-0.049596, 0.074395, 0.019839, -0.282700, -0.119032, -0.262861, 0.163668,
-0.069435, -0.064475, -0.059516, 0.094233, 0.123991, -0.079354, -0.272781,
-0.267821, 0.233104, 0.114072, -0.218225, 0.540602, 0.089274, 0.262861,
0.079354, 0.267821, -0.119032, -0.109112, -0.128951, 0.128951, -0.044637,
-0.272781, 0.277740, 0.297579, -0.054556, -0.084314, -0.049596, 0.123991,
0.059516, 0.238063, -0.168628, -0.009919, 0.163668, -0.307498, 0.109112,
-0.064475, 0.218225, -0.168628, -0.004960, -0.168628, 0.119032, 0.094233,
-0.183507, -0.089274, -0.292619, -0.094233, 0.064475, -0.183507, -0.168628,
0.089274, 0.074395, -0.367014, -0.024798, -0.069435, 0.119032, -0.302539,
-0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
-0.183507, -0.238063, 0.163668, -0.332297, -0.148789, -0.391812, -0.024798,
-0.133911, -0.059516, -0.123991, 0.123991, -0.292619, -0.044637, 0.059516,
-0.069435, 0.049596, -0.069435, 0.034718, 0.158709, -0.347175, -0.044637,
0.352135, -0.347175, -0.282700, -0.054556, 0.307498, 0.029758, 0.357095,
-0.148789, 0.208305, -0.317418, 0.009919, 0.004960, -0.243023, 0.049596,
-0.099193, 0.213265, -0.342216, 0.158709, 0.123991, -0.332297, 0.386853,
-0.262861, -0.208305, 0.123991, -0.044637, 0.148789, 0.084314, -0.297579,
-0.307498, -0.163668, 0.337256, -0.014879, 0.074395, 0.178547, -0.004960,
-0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
-0.153749, 0.629876, 0.277740, 0.178547, -0.267821, -0.004960, 0.247982,
0.084314, -0.094233, 0.000000, -0.039677, 0.332297, 0.178547, 0.009919,
-0.213265, -0.208305, -0.044637, 0.019839, 0.218225, -0.297579, 0.014879,
-0.247982, -0.004960, -0.128951, 0.421570, -0.059516, 0.362054, -0.203346,
-0.143830, -0.099193, -0.024798, 0.094233, -0.123991, 0.163668, 0.109112,
-0.104153, -0.233104, 0.009919, -0.218225, 0.376933, 0.104153, -0.059516,
0.049596, -0.054556, 0.019839, -0.044637, -0.019839, 0.371974, -0.019839,
0.104153, 0.168628, -0.024798, -0.272781, -0.158709, 0.223184, 0.044637,
0.039677, -0.168628, -0.287660, -0.109112, 0.094233, -0.089274, -0.148789,
0.178547, -0.039677, -0.089274, -0.049596, -0.024798, 0.064475, -0.158709,
0.089274, 0.029758, -0.247982, 0.362054, 0.024798, -0.004960, -0.099193,
0.173588, -0.059516, 0.188467, -0.629876, 0.094233, 0.371974, 0.069435,
0.252942, -0.357095, -0.272781, -0.367014, 0.014879, -0.049596, -0.262861,
0.009919, -0.094233, -0.094233, 0.059516, 0.223184, 0.133911, 0.411651,
-0.044637, -0.044637, 0.109112, 0.228144, 0.386853, -0.233104, 0.069435,
0.228144, -0.302539, 0.029758, 0.089274, 0.044637, -0.238063, -0.138870,
-0.158709, -0.019839, 0.049596, 0.039677, 0.000000, -0.069435, 0.109112,
-0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911, 0.391812,
0.123991, -0.317418, 0.233104, -0.029758, -0.099193, -0.193426, 0.074395,
-0.009919, 0.252942, 0.322377, -0.530683, 0.208305, 0.252942, 0.203346,
-0.069435, -0.262861};
// Time filter of shape {num_filters, memory_size} = {64, 8}.
const float time_weights_data_16x1x1[] = {
-0.052026, 0.043107, 0.053512, 0.013378, 0.011892, -0.182834, -0.108511,
0.153105, 0.050539, -0.173915, 0.145672, 0.208103, -0.221481, 0.108511,
-0.496475, 0.181347, -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
-0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
-0.057972, -0.057972, -0.497962, 0.054999, 0.181347, 0.047566, -0.099592,
-0.111484, -0.130808, -0.071350, 0.380532, 0.010405, 0.041621, 0.052026,
0.022297, 0.081755, 0.098106, 0.099592, -0.584176, -0.023783, 0.062431,
-0.090674, -0.279453, -0.486070, -0.273507, 0.004459, -0.062431, 0.095133,
0.056485, 0.022297, -0.105538, -0.184320, 0.358235, 0.254183, 0.049053,
0.084728, 0.218508, 0.078782, -0.136754, -0.017837, -0.124862, -0.118916,
-0.001486, 0.043107, 0.254183, 0.087701, 0.261616, 0.309182, -0.404315,
-0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
-0.046080, -0.062431, 0.020810, 0.040134, -0.135267, -0.169456, -0.050539,
-0.576743, 0.034188, 0.075809, 0.101079, 0.136754, 0.083241, 0.077296,
-0.050539, 0.761064, -0.335938, -0.080268, 0.025270, 0.257156, 0.227427,
0.252697, 0.065404, 0.115943, 0.222968, -0.026756, -0.054999, 0.107025,
-0.093646, 0.041621, -0.092160, -0.474178, -0.016351, 0.004459, 0.049053,
0.019324, 0.019324, 0.074323, 0.038648, -0.613905, 0.182834, 0.075809,
0.028243, 0.019324, 0.010405, -0.011892, 0.001486, -0.492016, -0.224454,
-0.474178, -0.147159, 0.002973, 0.102565, 0.136754, -0.267561, -0.001486,
-0.095133, -0.040134, 0.066890, 0.074323, 0.104052, 0.532150, 0.090674,
0.072836, -0.053512, -0.004459, 0.020810, 0.046080, 0.062431, 0.477151,
0.133781, -0.029729, -0.026756, 0.031215, 0.156077, 0.096619, 0.251210,
0.352289, 0.657012, 0.047566, -0.014865, -0.072836, -0.016351, 0.008919,
-0.053512, 0.016351, 0.300263, 0.047566, 0.020810, 0.169456, 0.001486,
0.007432, 0.111484, 0.044594, -0.188779, -0.096619, 0.074323, -0.040134,
0.160537, 0.138240, 0.184320, 0.377559, -0.092160, -0.049053, 0.056485,
-0.032702, 0.001486, -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
0.026756, 0.163510, 0.172428, 0.127835, -0.199185, -0.218508, -0.057972,
-0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
0.248238, 0.191752, 0.221481, 0.173915, 0.173915, -0.208103, -0.077296,
0.384991, -0.313641, -0.313641, -0.147159, -0.090674, 0.035675, 0.059458,
-0.010405, 0.019324, 0.087701, 0.016351, 0.037161, 0.469719, -0.074323,
0.092160, 0.026756, 0.090674, 0.098106, 0.004459, -0.034188, 0.492016,
-0.367154, -0.093646, -0.063917, 0.041621, 0.017837, 0.026756, -0.062431,
-0.350803, 0.425125, 0.002973, 0.083241, 0.075809, 0.016351, 0.047566,
-0.185807, -0.107025, -0.098106, -0.144186, 0.255670, 0.020810, 0.105538,
0.029729, 0.129321, 0.156077, 0.141213, 0.334452, 0.147159, -0.066890,
0.035675, 0.115943, 0.240805, 0.328506, 0.162023, -0.237832, 0.218508,
0.233373, 0.214049, 0.099592, 0.026756, -0.322560, -0.236346, -0.166483,
0.225941, 0.109997, -0.147159, 0.147159, -0.266075, 0.111484, 0.078782,
-0.120403, 0.022297, -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
-0.023783, 0.016351, 0.026756, -0.013378, -0.069863, -0.112970, 0.013378,
0.086214, 0.014865, 0.352289, -0.240805, -0.135267, -0.114457, -0.472692,
0.334452, 0.095133, 0.047566, 0.130808, -0.068377, -0.007432, -0.130808,
-0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000, -0.028243,
0.029729, -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
-0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
-0.047566, -0.107025, 0.004459, 0.053512, 0.047566, -0.358235, -0.193239,
0.040134, -0.096619, -0.054999, 0.099592, 0.032702, 0.205130, -0.170942,
-0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
-0.074323, 0.078782, 0.607959, -0.437017, -0.164996, -0.166483, 0.043107,
-0.016351, 0.258643, 0.065404, -0.057972, 0.017837, 0.080268, 0.050539,
-0.013378, -0.215536, -0.524718, 0.260129, 0.040134, -0.002973, -0.046080,
0.020810, 0.025270, 0.145672, 0.515799, 0.233373, 0.011892, 0.139727,
0.126348, 0.065404, -0.007432, -0.008919, 0.035675, 0.083241, 0.040134,
-0.005946, 0.503907, -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
0.133781, -0.011892, 0.041621, 0.062431, -0.062431, -0.040134, -0.092160,
-0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161, -0.092160,
-0.056485, -0.041621, 0.112970, 0.248238, 0.438503, 0.258643, -0.013378,
0.004459, 0.043107, 0.040134, 0.017837, 0.101079, 0.264589, 0.212563,
0.014865, 0.285399, 0.153105, 0.170942, 0.358235, 0.334452, 0.086214,
0.132294, 0.098106, -0.001486, 0.107025, 0.200671, -0.026756, 0.344857,
0.227427, -0.041621, 0.098106, 0.063917, -0.093646, 0.130808, 0.285399,
-0.319587, 0.035675, -0.017837, -0.319587, 0.016351, -0.098106, -0.017837,
0.083241, 0.074323, -0.054999, 0.276480, 0.316614, -0.099592, -0.059458,
0.156077, -0.043107, 0.035675, 0.056485, -0.022297, 0.017837, -0.001486,
0.340398, 0.492016, 0.004459, 0.057972, -0.150132, -0.206617, -0.257156,
-0.248238, -0.080268, -0.164996, 0.352289, -0.054999, -0.056485, 0.010405,
-0.049053, -0.041621, -0.099592, 0.013378, -0.089187, 0.057972, -0.413234,
0.217022, 0.013378, -0.080268, -0.035675, 0.035675, 0.007432, 0.002973,
-0.469719, 0.141213, 0.136754, 0.153105, 0.130808, -0.104052, -0.508367,
-0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
0.484583};
// Bias of shape {num_units} = {64}.
const float bias_data_16x1x1[] = {
-0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
-0.378594, 0.178152, 0.400380, -0.301349, -0.240913, -0.159454, -0.158757,
-0.073665, 0.455906, -0.061232, 0.318907, -0.226993, -0.344644, 0.140316,
0.559608, 0.109774, 0.437391, 0.113849, -0.162068, 0.039572, 0.569472,
0.460205, 0.113459, 0.370469, 0.176811, 0.203063, -0.296975, -0.271655,
0.059862, -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
0.083172, 0.029040, -0.236288, -0.267054, -0.166627, 0.188319, -0.271391,
-0.222920, 0.106463, 0.263614, 0.384986, -0.125957, -0.095890, 0.363686,
-0.036990, -0.358884, -0.178254, 0.305596, 0.390088, -0.189437, 0.613409,
0.399639};
// Initial activation state: 512 values, i.e. 64 filters x memory size 8 for
// the single batch. TestSVDF shapes this tensor as
// {batch_size, memory_size * num_filters} = {1, 512}. These initial values
// must be copied into a mutable activation state buffer rather than passed
// directly.
const float initial_activation_state_data_16x1x1[] = {
-0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
-1.184425, -0.462094, -1.443421, 0.230736, -0.494701, -0.354955, -2.534061,
-4.277471, -4.218467, 0.403711, -0.248748, -0.330111, -0.467683, 0.549047,
0.733511, -0.230115, 0.793136, -1.126353, -0.984123, -0.081984, -0.222351,
0.692830, 0.517060, 1.367958, 2.118860, -0.116766, -0.826365, -2.402700,
-2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490, 0.085400,
-1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
1.109273, 1.693411, -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
0.253406, -0.357439, -1.593415, -0.879779, -1.111136, 1.821357, 2.471952,
1.236908, -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
-0.319242, 0.749349, 1.156476, 0.658670, 1.997437, 2.080663, 2.912618,
2.677224, 2.642442, 2.796163, -0.272349, -0.473273, 3.120063, 2.747097,
3.595510, 1.874150, 2.049919, 2.093396, -1.049959, 0.277939, -1.255541,
-1.052443, -1.810177, -0.883505, -0.538178, 0.524203, -1.017662, -0.269244,
0.039129, -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
0.102480, 0.349986, 0.405885, 1.287216, 0.756181, 0.319242, -0.641590,
-3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
-0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
0.039439, 0.672023, 1.408019, 1.362679, 1.467644, 1.006171, 0.310236,
-0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
0.051551, 0.232600, 0.088816, 2.570395, 0.704009, 2.465120, 3.010751,
2.139357, 0.630410, 1.006171, 1.545281, 1.486898, -1.162998, -2.344317,
-4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
-0.368930, -0.162726, 0.396879, 0.505570, 0.534451, 0.554947, 1.270447,
0.388805, 0.531967, -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
-1.164550, 0.609603, 3.293348, 2.600208, 1.454290, -1.034121, -1.760179,
-1.192500, -0.613951, 3.449553, 2.912618, 1.917937, 1.435968, 0.879158,
1.118279, 0.102791, -0.502465, -0.239121, -0.092853, 1.786265, 1.943091,
2.547104, 2.630641, 2.585302, 2.965411, -0.945615, -2.538720, -2.474126,
-1.088156, 0.056209, 0.864873, 0.170490, 0.457435, 0.545941, 0.752765,
1.569503, 1.129459, 0.662086, -0.527929, -0.810838, -1.662978, 1.285042,
1.653040, 4.130893, 2.961995, 4.147041, 3.256393, 3.881524, 2.522571,
-0.875431, -1.112378, 2.105817, 2.180970, 3.121926, 1.577577, 1.639376,
2.906407, -0.142230, 0.421101, 2.212335, 2.311399, 3.993321, 3.651719,
4.206666, 4.678387, -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
-1.197158, -0.949652, -0.282908, 0.320795, -1.543728, 1.290322, 1.788128,
3.957297, 3.205774, 2.892432, 2.297114, 0.138814, -0.139435, 0.936920,
0.344707, 0.723263, -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
-4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
3.858543, 2.499901, 1.087535, 0.290051, -0.026086, -0.880400, -2.602692,
-1.404292, 0.253096, -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
1.846200, -1.604284, -1.244671, -0.464888, 0.326385, 0.168006, -0.262723,
-0.744691, 0.953379, -0.407127, -0.349986, -1.154302, 0.831023, 1.590931,
2.538720, 2.063583, 3.697680, -0.752455, -1.293117, -1.330693, -1.869802,
-0.592523, 0.631652, 1.198089, -0.481347, 3.738983, 4.153252, 2.782499,
2.244321, 0.709289, 1.650245, 1.700865, 0.385078, 2.192460, 2.610456,
4.009780, 3.492719, 2.574743, 2.116687, 1.856138, 1.205853, 2.722563,
4.075305, 5.415935, 3.009198, 2.715421, 1.571056, 0.897170, -2.430339,
0.749970, 0.425760, -0.302783, 0.817359, 1.031636, 1.913589, 2.686229,
1.631923, -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
-1.805519, -0.486627, 0.909591, 2.082837, -1.473855, -2.456735, -3.851401,
-2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
-0.245642, -0.588486, 0.033850, 2.084700, 0.076084, 0.690035, 0.747797,
0.594697, -1.016109, -1.348083, -1.201195, -1.088466, 2.045571, 2.460772,
0.717984, 0.041613, -0.721711, 1.134738, 2.322269, 1.112378, -0.307441,
-0.581033, -0.868599, -0.018633, 0.856488, 0.919839, 0.303094, -0.433213,
0.811148, -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
-2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205, 0.622646,
0.145957, 0.359303, 1.012072, 0.774814, -0.400295, -1.484103, -2.007374,
-1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451, 0.087264,
-0.227320, -1.211753, -1.532859, -1.688753, 0.065215, 0.134777, 0.608051,
-0.393152, -0.214588, -0.635689, -1.499320, 0.069562, -1.555839, -2.633126,
-2.966032, -1.550870, -0.101549, 0.874189, 0.436318, 0.299367, 2.289972,
2.339659, 2.602071, 1.564535, 0.019254, -0.583207, -1.295912, -2.424749,
-1.221070, -1.175109, -0.577306, -0.102791, 1.877876, 2.568222, 2.173827,
3.131243, 2.637784, 2.088737, 3.679047, 3.218506, 2.483442, 1.650556,
1.363611, -0.027328, 1.486898, -0.721711, -3.684327, -3.006093, -3.777491,
-2.327548, -2.737470, -4.549510, -0.060867, 0.127635, 0.680408, 0.581344,
0.320174, -0.403090, -0.838166, 0.293777, -0.995613, -0.165521, -0.419859,
1.110515, 1.203679, 1.749931, 2.467294, 4.276539, 0.031055, -0.967664,
1.167035, 1.865144, 3.221923, 3.248630, 4.121266, 4.187723, 0.749039,
-1.571056, 0.785994, 1.568572, 3.759479, 3.588678, 4.116608, 3.864444,
-0.290051, -0.271107, 0.375140, 0.537556, 0.536314, 0.095959, 0.054656,
0.088816};
// Expected result: one output of shape {batch_size, num_units} = {1, 64}.
const float golden_output_16x1x1[] = {
-0.087914, 1.145864, -0.418088, -1.556392, -0.925298, 0.205252, 0.289119,
1.331180, -0.218010, 0.963057, -2.225886, 1.248478, 1.448983, 0.355467,
1.682174, 0.803739, 0.449738, 0.543566, 1.916269, -2.975136, 0.222774,
0.241589, -0.104216, 1.561748, 0.936818, -0.089907, -0.520117, -0.870353,
1.606074, 0.895770, 0.521297, -0.369994, -0.889351, -2.809309, 2.404628,
1.069754, -0.195456, -1.105652, 1.272715, -1.233177, 1.271416, -1.691805,
-1.058125, -0.716227, 0.052540, 1.262483, 0.540555, 1.735760, -0.539197,
-0.014367, -0.243002, 1.072254, 0.528985, -0.731151, -1.262649, 2.338702,
-0.603093, 0.970736, -3.567897, 0.035085, -0.201711, -0.550400, 1.545573,
-1.805005};
// Expected result with ReLU activation: one output of shape {1, 64}. Each
// entry equals the matching entry of golden_output_16x1x1 with negative
// values clamped to 0.
const float golden_output_relu_16x1x1[] = {
0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119,
1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467,
1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774,
0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000,
1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628,
1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000,
0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000,
0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702,
0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573,
0.000000};
// Runs the SVDF kernel once per input batch in `input_sequences_data`. After
// each run, it checks the output against the matching slice of
// `expected_output`.
//
// `tensors` holds `tensor_count` tensors. Indices 0-4 are wired as the kernel
// inputs and index 5 as its output (see the index arrays below). Before each
// Invoke(), tensor 0's buffer is overwritten from `input_sequences_data`.
// The copy is tensors[0].bytes bytes, which is assumed to equal
// batch_size * input_size * sizeof(T). `output_data` is expected to be the
// buffer backing the output tensor, as TestSVDF arranges.
//
// batch_size, num_units, input_size: used to slice the input sequence and the
//   golden output for each invocation.
// rank, activation: forwarded to the kernel via TfLiteSVDFParams.
// input_sequences_len: total number of values in `input_sequences_data`.
//   Trailing values that do not fill a whole batch are ignored.
// expected_output: golden values, batch_size * num_units per invocation.
// tolerance: maximum absolute difference accepted per output element.
template <typename T>
void ValidateSVDFGoldens(const int batch_size, const int num_units,
                         const int input_size, const int rank,
                         TfLiteTensor* tensors, const int tensor_count,
                         TfLiteFusedActivation activation,
                         const T* input_sequences_data,
                         const int input_sequences_len, T* output_data,
                         const T* expected_output, float tolerance = 1e-5f) {
  // Value-initialize so that any params fields not assigned below are zero
  // rather than indeterminate when the kernel reads the struct.
  TfLiteSVDFParams params = {};
  params.rank = rank;
  params.activation = activation;

  // IntArrayFromInts layout: {element count, elements...}.
  int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 5};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  const TFLMRegistration registration = Register_SVDF();
  micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
                             outputs_array, &params);

  TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
  TF_LITE_MICRO_EXPECT(runner.ValidateTempBufferDeallocated());
  // Abort early to make it clear init and prepare failed.
  if (init_and_prepare_status != kTfLiteOk) {
    return;
  }

  const int values_per_input = input_size * batch_size;
  const int values_per_output = batch_size * num_units;
  // Guard the division below: an empty input shape would otherwise divide by
  // zero instead of failing the test.
  TF_LITE_MICRO_EXPECT(values_per_input > 0);
  if (values_per_input <= 0) {
    return;
  }

  const int num_inputs = input_sequences_len / values_per_input;
  for (int i = 0; i < num_inputs; ++i) {
    const T* input_batch_start = input_sequences_data + i * values_per_input;
    memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
    TfLiteStatus status = runner.Invoke();
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
    // Only validate outputs when invoke has succeeded.
    if (status == kTfLiteOk) {
      const T* expected_batch = expected_output + i * values_per_output;
      for (int j = 0; j < values_per_output; ++j) {
        TF_LITE_MICRO_EXPECT_NEAR(expected_batch[j], output_data[j], tolerance);
      }
    }
  }
  TF_LITE_MICRO_EXPECT(runner.ValidateTempBufferDeallocated());
}
// Float SVDF test driver. It wraps the caller-provided buffers in the six
// tensors the kernel is wired to, in this order: input, feature weights, time
// weights, bias, activation state, output. It then hands them to
// ValidateSVDFGoldens together with the golden output and tolerance.
//
// Tensor shapes, with num_filters = num_units * rank:
//   input             {batch_size, input_size}
//   feature weights   {num_filters, input_size}
//   time weights      {num_filters, memory_size}
//   bias              {num_units}
//   activation state  {batch_size, memory_size * num_filters}, created as a
//                     variable tensor
//   output            {batch_size, num_units}
//
// NOTE(review): `scratch_data` is accepted but never referenced here.
void TestSVDF(const int batch_size, const int num_units, const int input_size,
              const int memory_size, const int rank,
              TfLiteFusedActivation activation, float* input_data,
              const float* feature_weights_data, const float* time_weights_data,
              float* activation_state_data, const float* bias_data,
              float* scratch_data, float* output_data,
              const float* input_sequences_data, int input_sequences_len,
              const float* expected_output, float tolerance = 1e-5f) {
  const int filter_count = num_units * rank;
  const int state_values_per_batch = memory_size * filter_count;

  // Shape arrays use the IntArrayFromInts layout: {num_dims, dims...}.
  int input_shape[] = {2, batch_size, input_size};
  int feature_weights_shape[] = {2, filter_count, input_size};
  int time_weights_shape[] = {2, filter_count, memory_size};
  int bias_shape[] = {1, num_units};
  int activation_state_shape[] = {2, batch_size, state_values_per_batch};
  int output_shape[] = {2, batch_size, num_units};

  TfLiteTensor tensors[] = {
      CreateTensor(input_data, IntArrayFromInts(input_shape)),
      CreateTensor(feature_weights_data,
                   IntArrayFromInts(feature_weights_shape)),
      CreateTensor(time_weights_data, IntArrayFromInts(time_weights_shape)),
      CreateTensor(bias_data, IntArrayFromInts(bias_shape)),
      CreateTensor(activation_state_data,
                   IntArrayFromInts(activation_state_shape),
                   /*is_variable=*/true),
      CreateTensor(output_data, IntArrayFromInts(output_shape)),
  };
  // 5 inputs + 1 output, derived from the initializer above.
  const int tensor_count =
      static_cast<int>(sizeof(tensors) / sizeof(tensors[0]));

  ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
                      tensor_count, activation, input_sequences_data,
                      input_sequences_len, output_data, expected_output,
                      tolerance);
}
// Quantized counterpart of TestSVDF.
//
// Argument layout:
//   <kernel metadata>
//   then, for each tensor in
//   {input, feature weights, time weights, bias, activation state, output}:
//     <float values> <quantized buffer> <quantization parameters>
//
// The float golden outputs and the float input sequence are quantized into
// the caller-provided `golden_output_quantized` / `input_sequences_quantized`
// buffers. The kernel output is then compared to the quantized golden with a
// tolerance of 1, i.e. one quantization step.
//
// Feature and time weights are quantized with zero point 0. The bias is
// quantized from `time_weights_scale` and `activation_state_scale`; the exact
// combination is CreateQuantizedBiasTensor's, presumably their product.
//
// NOTE(review): `activation_state_zero_point` is not forwarded. The activation
// state tensor is always built with zero point 0. The 2x2 caller passes 0
// explicitly; the 1x16 callers pass ZeroPointFromMinMax over a symmetric
// range. Confirm that this evaluates to 0 before relying on the parameter.
//
// Template parameter T sets the type of both time_weights and
// activation_state.
template <typename T>
inline void TestIntegerSVDF(
    const int batch_size, const int num_units, const int input_size,
    const int memory_size, const int rank, TfLiteFusedActivation activation,
    int8_t* input_quantized, float input_scale, int input_zero_point,
    const float* feature_weights_data, int8_t* feature_weights_quantized,
    const float feature_weights_scale, const float* time_weights_data,
    T* time_weights_quantized, float time_weights_scale, const float* bias_data,
    int32_t* bias_quantized, const float* initial_activation_state_data,
    T* activation_state_quantized, float activation_state_scale,
    int activation_state_zero_point, int8_t* output_data, float output_scale,
    int output_zero_point, const float* input_sequences_data,
    int8_t* input_sequences_quantized, const int input_sequences_len,
    const float* golden_output, int8_t* golden_output_quantized,
    int golden_output_len) {
  const int filter_count = num_units * rank;

  // Shape arrays use the IntArrayFromInts layout: {rank, dim0, dim1, ...}.
  int input_shape[] = {2, batch_size, input_size};
  int feature_weights_shape[] = {2, filter_count, input_size};
  int time_weights_shape[] = {2, filter_count, memory_size};
  int bias_shape[] = {1, num_units};
  int state_shape[] = {2, batch_size, memory_size * filter_count};
  int output_shape[] = {2, batch_size, num_units};

  TfLiteTensor tensors[] = {
      CreateQuantizedTensor(input_quantized, IntArrayFromInts(input_shape),
                            input_scale, input_zero_point),
      CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
                            IntArrayFromInts(feature_weights_shape),
                            feature_weights_scale, /*zero_point=*/0),
      CreateQuantizedTensor(time_weights_data, time_weights_quantized,
                            IntArrayFromInts(time_weights_shape),
                            time_weights_scale, /*zero_point=*/0),
      CreateQuantizedBiasTensor(bias_data, bias_quantized,
                                IntArrayFromInts(bias_shape),
                                time_weights_scale, activation_state_scale),
      CreateQuantizedTensor(initial_activation_state_data,
                            activation_state_quantized,
                            IntArrayFromInts(state_shape),
                            activation_state_scale, /*zero_point=*/0,
                            /*is_variable=*/true),
      CreateQuantizedTensor(output_data, IntArrayFromInts(output_shape),
                            output_scale, output_zero_point)};
  // 5 inputs + 1 output.
  constexpr int kTensorCount = sizeof(tensors) / sizeof(tensors[0]);

  // The kernel only ever sees quantized data, so quantize the goldens and the
  // input steps with the output / input quantization parameters respectively.
  tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
                   output_scale, output_zero_point);
  tflite::Quantize(input_sequences_data, input_sequences_quantized,
                   input_sequences_len, input_scale, input_zero_point);

  ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
                      kTensorCount, activation, input_sequences_quantized,
                      input_sequences_len, output_data, golden_output_quantized,
                      /*tolerance=*/1);
}
// Template parameter sets type of both time_weights and activation_state.
template <typename T>
void SvdfQuantized2x2Input2x4OutputShouldMatchGolden() {
constexpr int batch_size = 2;
constexpr int num_units = 4;
constexpr int input_size = 2;
constexpr int memory_size = 10;
constexpr int rank = 2;
constexpr int num_filters = num_units * rank;
const int input_size_dims_count = batch_size * input_size;
const int activation_state_dims_count =
batch_size * memory_size * num_filters;
const int output_dims_count = batch_size * num_units;
int8_t output_data[output_dims_count];
float input_scale = 2.5 / std::numeric_limits<int8_t>::max();
float feature_weights_scale = 1.0 / std::numeric_limits<int8_t>::max();
float time_weights_scale = 1.0 / std::numeric_limits<T>::max();
float activation_state_scale = 1.49 / std::numeric_limits<T>::max();
float output_scale = 1.0 / std::numeric_limits<int8_t>::max();
int input_zero_point = 0;
int output_zero_point = 0;
int8_t input_quantized[input_size_dims_count];
int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_2x2x10) /
sizeof(float)];
int8_t feature_weights_quantized
[sizeof(tflite::testing::feature_weights_data_2x2x10) / sizeof(float)];
T time_weights_quantized[sizeof(tflite::testing::time_weights_data_2x2x10) /
sizeof(float)];
T activation_state_quantized[activation_state_dims_count];
int32_t
bias_quantized[sizeof(tflite::testing::bias_data_2x2x10) / sizeof(float)];
int8_t golden_quantized[sizeof(tflite::testing::golden_output_2x2x10) /
sizeof(float)];
tflite::testing::TestIntegerSVDF(
batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
input_quantized, input_scale, input_zero_point,
tflite::testing::feature_weights_data_2x2x10, feature_weights_quantized,
feature_weights_scale, tflite::testing::time_weights_data_2x2x10,
time_weights_quantized, time_weights_scale,
tflite::testing::bias_data_2x2x10, bias_quantized,
tflite::testing::initial_activation_state_data_2x2x10,
activation_state_quantized, activation_state_scale, 0, output_data,
output_scale, output_zero_point, tflite::testing::input_data_2x2x10,
input_sequences_quantized,
sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
tflite::testing::golden_output_2x2x10, golden_quantized,
sizeof(tflite::testing::golden_output_2x2x10) / sizeof(float));
}
// Template parameter sets type of both time_weights and activation_state.
template <typename T>
void SvdfQuantized1x16Input64x1OutputShouldMatchGolden() {
constexpr int batch_size = 1;
constexpr int num_units = 64;
constexpr int input_size = 16;
constexpr int memory_size = 8;
constexpr int rank = 1;
constexpr int num_filters = num_units * rank;
constexpr int activation_state_dims_count =
batch_size * memory_size * num_filters;
constexpr int output_dims_count = batch_size * num_units;
constexpr int input_dims_count = batch_size * input_size;
int8_t output_data[output_dims_count];
float input_scale = 0.10075444;
float feature_weights_scale = 0.00649388;
float time_weights_scale = tflite::testing::ScaleFromMinMax<T>(-.81, .81);
float activation_state_scale =
tflite::testing::ScaleFromMinMax<T>(-17.73, 17.73);
int activation_state_zero_point =
tflite::testing::ZeroPointFromMinMax<T>(-17.73, 17.73);
float output_scale = 0.051445257;
int input_zero_point = 2;
int output_zero_point = 0;
int8_t input_quantized[input_dims_count];
int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
sizeof(float)];
int8_t feature_weights_quantized
[sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
T time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
sizeof(float)];
T activation_state_quantized[activation_state_dims_count];
int32_t
bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
int8_t golden_quantized[sizeof(tflite::testing::golden_output_16x1x1) /
sizeof(float)];
tflite::testing::TestIntegerSVDF(
batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
input_quantized, input_scale, input_zero_point,
tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
time_weights_quantized, time_weights_scale,
tflite::testing::bias_data_16x1x1, bias_quantized,
tflite::testing::initial_activation_state_data_16x1x1,
activation_state_quantized, activation_state_scale,
activation_state_zero_point, output_data, output_scale, output_zero_point,
tflite::testing::input_data_16x1x1, input_sequences_quantized,
sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
tflite::testing::golden_output_16x1x1, golden_quantized,
sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float));
}
template <typename T>
void SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden() {
constexpr int batch_size = 1;
constexpr int num_units = 64;
constexpr int input_size = 16;
constexpr int memory_size = 8;
constexpr int rank = 1;
constexpr int num_filters = num_units * rank;
constexpr int activation_state_dims_count =
batch_size * memory_size * num_filters;
constexpr int output_dims_count = batch_size * num_units;
constexpr int input_dims_count = batch_size * input_size;
int8_t output_data[output_dims_count];
float input_scale = 0.10075444;
float feature_weights_scale = 0.00649388;
float time_weights_scale = tflite::testing::ScaleFromMinMax<T>(-.81, .81);
float activation_state_scale =
tflite::testing::ScaleFromMinMax<T>(-17.73, 17.73);
int activation_state_zero_point =
tflite::testing::ZeroPointFromMinMax<T>(-17.73, 17.73);
float output_scale = 0.051445257;
int input_zero_point = 2;
int output_zero_point = -128;
int8_t input_quantized[input_dims_count];
int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
sizeof(float)];
int8_t feature_weights_quantized
[sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
T time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
sizeof(float)];
T activation_state_quantized[activation_state_dims_count];
int32_t
bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
int8_t golden_quantized[sizeof(tflite::testing::golden_output_relu_16x1x1) /
sizeof(float)];
tflite::testing::TestIntegerSVDF(
batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
input_quantized, input_scale, input_zero_point,
tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
time_weights_quantized, time_weights_scale,
tflite::testing::bias_data_16x1x1, bias_quantized,
tflite::testing::initial_activation_state_data_16x1x1,
activation_state_quantized, activation_state_scale,
activation_state_zero_point, output_data, output_scale, output_zero_point,
tflite::testing::input_data_16x1x1, input_sequences_quantized,
sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
tflite::testing::golden_output_relu_16x1x1, golden_quantized,
sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float));
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Kernel geometry: 2 inputs per step, batch of 2, 4 units, rank 2,
  // memory of 10 steps.
  constexpr int kBatchSize = 2;
  constexpr int kNumUnits = 4;
  constexpr int kInputSize = 2;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  constexpr int kNumFilters = kNumUnits * kRank;

  float input_buffer[kBatchSize * kInputSize];
  float scratch_buffer[kBatchSize * kNumFilters];
  float output_buffer[kBatchSize * kNumUnits];

  // The activation state is handed to the kernel as mutable storage, so seed
  // a local copy from the shared initial values.
  float state_buffer[kBatchSize * kMemorySize * kNumFilters];
  memcpy(state_buffer, tt::initial_activation_state_data_2x2x10,
         sizeof(tt::initial_activation_state_data_2x2x10));

  tt::TestSVDF(kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank,
               kTfLiteActNone, input_buffer, tt::feature_weights_data_2x2x10,
               tt::time_weights_data_2x2x10, state_buffer,
               tt::bias_data_2x2x10, scratch_buffer, output_buffer,
               tt::input_data_2x2x10,
               sizeof(tt::input_data_2x2x10) / sizeof(float),
               tt::golden_output_2x2x10);
}
// Full-int8 SVDF (T = int8_t for time weights and activation state) is
// compiled out on HEXAGON builds; per the original note, only reference
// kernels supported it at the time of writing. The int16 variant runs
// everywhere.
#ifndef HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGoldenInt8) {
  using tflite::testing::SvdfQuantized2x2Input2x4OutputShouldMatchGolden;
  SvdfQuantized2x2Input2x4OutputShouldMatchGolden<int8_t>();
}
#endif  // HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGoldenInt16) {
  using tflite::testing::SvdfQuantized2x2Input2x4OutputShouldMatchGolden;
  SvdfQuantized2x2Input2x4OutputShouldMatchGolden<int16_t>();
}
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;
  // Total number of floats in the input sequence. It is derived from the
  // fixture itself, as the quantized variants do, rather than assuming the
  // array holds exactly one step of `input_size` values. Every step present
  // in the fixture is therefore validated.
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  float input_data[input_dims_count];
  float output_data[output_dims_count];
  // Passed for TestSVDF's interface; not referenced by it.
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];
  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));
  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_sequences_len,
      tflite::testing::golden_output_16x1x1);
}
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;
  // Total number of floats in the input sequence. It is derived from the
  // fixture itself, as the quantized variants do, rather than assuming the
  // array holds exactly one step of `input_size` values. Every step present
  // in the fixture is therefore validated.
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  float input_data[input_dims_count];
  float output_data[output_dims_count];
  // Passed for TestSVDF's interface; not referenced by it.
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];
  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));
  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_sequences_len,
      tflite::testing::golden_output_relu_16x1x1);
}
// Full-int8 SVDF (T = int8_t for time weights and activation state) is
// compiled out on HEXAGON builds; per the original note, only reference
// kernels supported it at the time of writing.
#ifndef HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGoldenInt8) {
  using tflite::testing::SvdfQuantized1x16Input64x1OutputShouldMatchGolden;
  SvdfQuantized1x16Input64x1OutputShouldMatchGolden<int8_t>();
}
#endif  // HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGoldenInt16) {
  using tflite::testing::SvdfQuantized1x16Input64x1OutputShouldMatchGolden;
  SvdfQuantized1x16Input64x1OutputShouldMatchGolden<int16_t>();
}
// Same HEXAGON exclusion for the ReLU full-int8 variant.
#ifndef HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGoldenInt8) {
  using tflite::testing::SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden;
  SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden<int8_t>();
}
#endif  // HEXAGON
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGoldenInt16) {
  using tflite::testing::SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden;
  SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden<int16_t>();
}
TF_LITE_MICRO_TESTS_END