Xtensa LSTM: (#2150)
Enabled LSTM kernel support for the XTENSA target.
Updated the xtensa_downloads script to use the latest HiFi NN Libraries.
The 8x16 unit test cases have a non-zero zero_point for the 16-bit output:
https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/kernels/testdata/lstm_test_data.cc#L255C1-L258C61
A default run of all the 8x16 unit test cases currently results in FAIL; this is due to the non-zero output offset value.
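For context, here is a minimal sketch (not part of this patch; `RequantizeToInt16`, `effective_scale`, and `output_zero_point` are illustrative names only) of where a non-zero output zero_point enters the 16-bit requantization, and why a kernel path that assumes a zero offset diverges from the reference outputs:

```c++
#include <algorithm>
#include <cstdint>

// Hypothetical helper, not taken from the TFLM code base: requantize one
// int32 accumulator into the int16 output space with an explicit zero_point.
int16_t RequantizeToInt16(int32_t acc, float effective_scale,
                          int32_t output_zero_point) {
  // Scale the accumulator, then apply the output offset. A kernel that
  // hard-codes output_zero_point == 0 skips this addition, so every output
  // element ends up shifted by the offset; this is the kind of mismatch the
  // 8x16 test cases with a non-zero zero_point expose.
  const int32_t scaled = static_cast<int32_t>(acc * effective_scale);
  const int32_t shifted = scaled + output_zero_point;
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, shifted)));
}
```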
BUG=#1867
diff --git a/tensorflow/lite/micro/kernels/lstm_eval_test.cc b/tensorflow/lite/micro/kernels/lstm_eval_test.cc
index 53c0d7c..eaba2c4 100644
--- a/tensorflow/lite/micro/kernels/lstm_eval_test.cc
+++ b/tensorflow/lite/micro/kernels/lstm_eval_test.cc
@@ -454,6 +454,6 @@
cell_state_tolerance,
int16_node_contents);
}
-
#endif // !defined(XTENSA)
+
TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
index c85e56f..ea11afc 100644
--- a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
@@ -150,7 +150,6 @@
TF_LITE_MICRO_TESTS_BEGIN
// TODO(b/230666079) enable below tests for xtensa when the xtensa
// kernel is reconciled with reference kernel
-#if !defined(XTENSA)
TF_LITE_MICRO_TEST(TestUnidirectionalLSTMFloat) {
const tflite::testing::LstmEvalCheckData<12, 4, 12> kernel_eval_data =
tflite::testing::Get2X2LstmEvalCheckData();
@@ -193,5 +192,4 @@
kernel_eval_data, hidden_state_tolerance, cell_state_tolerance,
int16_node_contents);
}
-#endif // !defined(XTENSA)
TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
index 9065388..af5bad7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,1204 +14,478 @@
==============================================================================*/
#include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
-#include <math.h>
-#include <string.h>
+#include <limits>
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <vector>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/compatibility.h"
-#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
+#include "tensorflow/lite/kernels/internal/reference/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/mul.h"
+#include "tensorflow/lite/kernels/internal/reference/tanh.h"
+#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
-namespace {
-// Calculates a single LSTM gate, int8x8_16 version.
-// Implements the same functionality as CalculateLstmGateFloat.
-void CalculateLstmGateInteger8x8_16(
- // Input and weights
- const int8_t* input, const int8_t* input_to_gate_weights,
- const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
- const int32_t input_to_gate_scale_b,
- // Output state and weights
- const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
- const int32_t* recurrent_to_gate_bias,
- const int32_t recurrent_to_gate_scale_a,
- const int32_t recurrent_to_gate_scale_b,
- // Cell state and weights
- const int16_t* cell_state, const int16_t* cell_to_gate_weights,
- const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
- // Layer normalization parameters (layer norm LSTM)
- const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
- const int32_t layer_norm_input_scale_a,
- const int32_t layer_norm_input_scale_b,
- const int32_t layer_norm_variance_guard,
- // Array sizes
- const int n_batch, const int n_input, const int n_output, const int n_cell,
- const TfLiteFusedActivation activation,
- // Output
- int16_t* gate,
- // Parameters for performance optimizations
- // CpuBackendContext* context,
- // Scratch arrays
- int32_t* scratch5) {
- const bool use_peephole = (cell_to_gate_weights != nullptr);
- const bool use_layer_norm = (layer_norm_coefficients != nullptr);
-
- // Initialize scratch buffers with zeros. Note that unlike float and hybrid
- // versions, bias is only used in layer normalization.
- std::fill_n(gate, n_batch * n_cell, 0);
-#if !defined(HIFI5)
- // For each batch and cell: compute input_weight * input.
- tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
- input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
- input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate, NULL);
-#else
- {
- xa_nn_matXvec_acc_batch_sym8sx8_asym16s(
- gate, input_to_gate_weights, input, input_to_gate_bias, n_cell, n_input,
- n_input, input_to_gate_scale_a, input_to_gate_scale_b, 0, n_batch);
+LstmTensors::LstmTensors(TfLiteContext* context, TfLiteNode* node) {
+ micro_context_ = GetMicroContext(context);
+ // 24 internal tensors. see lstm_shared.h for tensor names
+ for (size_t i = 0; i < 24; i++) {
+ internal_tensors_[i] = micro_context_->AllocateTempInputTensor(node, i);
}
-#endif // !defined(HIFI5)
-// Note: no aux_input.
-
-// For each batch and cell: compute recurrent_weight * output_state.
-#if !defined(HIFI5)
- tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
- output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
- recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
- n_cell, 0, scratch5, gate, NULL);
-#else
- {
- xa_nn_matXvec_acc_batch_sym8sx8_asym16s(
- gate, recurrent_to_gate_weights, output_state, recurrent_to_gate_bias,
- n_cell, n_output, n_output, recurrent_to_gate_scale_a,
- recurrent_to_gate_scale_b, 0, n_batch);
- }
-#endif // !defined(HIFI5)
- // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
- if (use_peephole) {
- tensor_utils::PortableVectorBatchVectorCwiseProductAccumulate(
- cell_to_gate_weights, n_output, cell_state, n_batch,
- cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
- }
- // Do layer normalization (if layer norm LSTM)
- if (use_layer_norm) {
- tensor_utils::PortableApplyLayerNorm(
- gate, layer_norm_coefficients, layer_norm_bias,
- layer_norm_input_scale_a, layer_norm_input_scale_b,
- layer_norm_variance_guard, n_batch, n_cell, gate);
- }
- // Apply activation
- switch (activation) {
- case kTfLiteActSigmoid:
-#if !defined(HIFI5)
- tensor_utils::PortableApplySigmoid(gate, n_batch, n_cell, gate);
-#else
- xa_nn_vec_sigmoid_16_16(gate, gate, n_batch * n_cell);
-#endif // !defined(HIFI5)
- break;
- case kTfLiteActTanh:
-#if !defined(HIFI5)
- tensor_utils::PortableApplyTanh(3, gate, n_batch, n_cell, gate);
-#else
- xa_nn_vec_tanh_16_16(gate, gate, 3, n_batch * n_cell);
-#endif // !defined(HIFI5)
- break;
- default:
- // Only Sigmoid or Tanh is used.
- TFLITE_ASSERT_FALSE;
- }
+ output_tensor_ =
+ micro_context_->AllocateTempOutputTensor(node, kLstmOutputTensor);
}
-// Updates the LSTM cell state, used by both integer LSTM versions.
-// Also see UpdateLstmCellFloat.
-//
-// Parameters:
-// - n_batch, n_cell: sizes of vectors
-// - cell_state: input/output vector, size n_batch*n_cell
-// - cell_state_scale: scaling factor of cell state.
-// - input_gate: input vector, size n_batch*n_cell.
-// - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
-// - cell_gate: input vector, size n_batch*n_cell.
-// - use_cifg: use 1-forget_gate instead of input_gate.
-// - clip: if > 0, clip the resulting cell state to [-clip, +clip].
-void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
- int32_t cell_state_scale, const int16_t* input_gate,
- int16_t* forget_gate, const int16_t* cell_gate,
- bool use_cifg, int16_t clip) {
-#if !defined(HIFI5)
- // Use the forget_gate array as scratch, as input_gate array is not allocated
- // in CIFG case. (Be careful not to write to the scratch before reading the
- // forget gate data.)
- int16_t* scratch = forget_gate;
-
- tensor_utils::PortableCwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
- cell_state);
- if (use_cifg) {
- tensor_utils::PortableSub1Vector(forget_gate, n_batch * n_cell, scratch);
- tensor_utils::PortableCwiseMul(scratch, cell_gate, n_batch, n_cell,
- 30 + cell_state_scale, scratch);
- } else {
- tensor_utils::PortableCwiseMul(input_gate, cell_gate, n_batch, n_cell,
- 30 + cell_state_scale, scratch);
- }
- tensor_utils::PortableCwiseAdd(cell_state, scratch, n_batch, n_cell,
- cell_state);
-
- if (clip > 0) {
- tensor_utils::PortableCwiseClipping(cell_state, n_batch * n_cell, clip);
- }
-#else
- if (use_cifg) {
- calc_cell_state_with_cifg(cell_state, forget_gate, cell_gate, 15,
- 30 + cell_state_scale, clip, n_batch * n_cell);
- } else {
- calc_cell_state_without_cifg(cell_state, forget_gate, cell_gate, input_gate,
- 15, 30 + cell_state_scale, clip,
- n_batch * n_cell);
- }
-
-#endif // !defined(HIFI5)
-}
-
-// Calculates the output state tensor of an LSTM step. See Float and hybrid
-// versions as well.
-//
-// Parameters:
-// - n_batch: batches: the number of distinct vectors in each array.
-// - n_cell, n_output: sizes of vectors.
-// - cell_state, output_gate: input vectors, size n_batch*n_cell.
-// - cell_state_scale: scaling of cell_state.
-// - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
-// - hidden_zp: zero_point for cell_state.*output_gate
-// - projection_weights, proj_scale_[a|b], projection_bias:
-// constant inputs, describing projection matrix and bias.
-// - output_state_zp: zero point of output_state. (Input, calibrated value.)
-// - quantized_proj_clip: if > 0, clip the output of the projection.
-// - output_state: output vector, size n_batch*n_output. Must be contiguous.
-// - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
-// - scratch0: scratch area of size n_batch*n_cell
-// - scratch1: scratch area of size n_batch*n_cell
-// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
-void CalculateLstmOutputInteger8x8_16(
- int n_batch, int n_cell, int n_output, const int16_t* cell_state,
- int32_t cell_state_scale, const int16_t* output_gate,
- int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
- const int8_t* projection_weights, int32_t proj_scale_a,
- int32_t proj_scale_b, const int32_t* projection_bias,
- int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
- int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
-// Note: unlike float/hybrid, the activation is always Tanh.
-#if !defined(HIFI5)
- tensor_utils::PortableApplyTanh(15 + cell_state_scale, cell_state, n_batch,
- n_cell, scratch0);
-#else
- xa_nn_vec_tanh_16_16(scratch0, cell_state, (15 + cell_state_scale),
- n_batch * n_cell);
-#endif // !defined(HIFI5)
-
-#if !defined(HIFI5)
- tensor_utils::PortableCwiseMul(output_gate, scratch0, hidden_scale_a,
- hidden_scale_b, n_batch, n_cell, hidden_zp,
- scratch1);
-#else
- xa_nn_elm_mul_16x16_asym8s(scratch1, output_gate, scratch0, hidden_scale_a,
- hidden_scale_b, hidden_zp, n_batch * n_cell);
-#endif // !defined(HIFI5)
-
- const bool use_projection = (projection_weights != nullptr);
-
- if (use_projection) {
- // Note: no bias like in float/hybrid
- std::fill_n(output_state, n_batch * n_output, 0);
- tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
- scratch1, projection_bias, projection_weights, proj_scale_a,
- proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
- output_state, NULL);
- if (quantized_proj_clip > 0) {
- tensor_utils::PortableCwiseClipping(output_state, n_batch * n_output,
- quantized_proj_clip);
- }
- } else {
- std::copy_n(scratch1, n_batch * n_output, output_state);
- }
-}
-
-// Calculates a single LSTM gate, int8x8_8 version.
-// Implements the same functionality as CalculateLstmGateFloat.
-void CalculateLstmGateInteger8x8_8(
- // Inputs and weights
- const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
- const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
- const int32_t input_times_weights_scale_a,
- const int32_t input_times_weights_scale_b,
- const int32_t input_times_weights_zp,
- // Output state and weights
- const int8_t* output_state, const int32_t output_state_zp,
- const int8_t* recurrent_to_gate_weight,
- const int32_t recurrent_to_gate_scale_a,
- const int32_t recurrent_to_gate_scale_b,
- const int32_t output_state_times_weights_scale_a,
- const int32_t output_state_times_weights_scale_b,
- const int32_t output_state_times_weights_zp,
- // Layer normalization parameters (layer norm LSTM)
- const int16_t* layer_norm_gate_weight,
- const int32_t layer_norm_gate_scale_a,
- const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
- // Array sizes
- const int n_batch, const int n_input, const int n_output, const int n_cell,
- const TfLiteFusedActivation activation,
- // Output
- int16_t* gate,
- // Scratch arrays, both sized n_batch*n_cell
- int8_t* scratch0, int8_t* scratch1) {
- // Multiply input * input_weights => scratch0
- tensor_utils::PortableMatrixBatchVectorMultiply(
- input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
- input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
- input_times_weights_zp);
- // Multiply output_state * recurrent_weights => scratch1
- tensor_utils::PortableMatrixBatchVectorMultiply(
- output_state, output_state_zp, recurrent_to_gate_weight,
- recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
- n_cell, scratch1, output_state_times_weights_zp);
- // Add scratch0 + scratch1 => gate
- tensor_utils::PortableTwoGateSaturatingAdd(
- scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
- input_times_weights_scale_a, input_times_weights_scale_b,
- output_state_times_weights_scale_a, output_state_times_weights_scale_b,
- n_batch, n_cell, gate);
- // Apply layer normalization.
- tensor_utils::PortableApplyLayerNormFloat(
- gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
- layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
- // Apply activation.
- switch (activation) {
- case kTfLiteActSigmoid:
- tensor_utils::PortableApplySigmoidFloat(gate, n_batch, n_cell, gate);
- break;
- case kTfLiteActTanh:
- tensor_utils::PortableApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
- break;
- default:
- // Only Sigmoid or Tanh is used.
- TFLITE_ASSERT_FALSE;
- }
-}
-
-// Calculates the output state tensor of an LSTM step. See Float and hybrid
-// versions as well.
-//
-// Parameters:
-// - n_batch: batches: the number of distinct vectors in each array.
-// - n_cell, n_output: sizes of vectors.
-// - cell_state, output_gate: input vectors, size n_batch*n_cell.
-// - projection_weights, proj_scale_[a|b], projection_bias:
-// constant inputs, describing projection matrix and bias.
-// - output_state_zp: zero point of the output state.
-// - quantized_proj_clip: if > 0, clip the output of the projection.
-// - output_state: output vector, size n_batch*n_output. Must be contiguous.
-// - scratch: scratch area of size n_batch*n_cell
-void CalculateLstmOutputInteger8x8_8(
- int n_batch, int n_cell, int n_output, const int16_t* cell_state,
- const int16_t* output_gate, const int8_t* projection_weights,
- int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
- int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
- int16_t* scratch) {
- // Note: unlike float/hybrid, the activation is always Tanh.
- tensor_utils::PortableApplyTanhFloat(cell_state, n_batch, n_cell, -15,
- scratch);
- tensor_utils::PortableCwiseMul(output_gate, scratch, n_batch, n_cell,
- 15 + 15 - 15, scratch);
- // Note: no bias like in float/hybrid
- tensor_utils::PortableMatrixBatchVectorMultiply(
- scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
- n_batch, n_cell, n_output, output_state_zp, output_state);
- if (quantized_proj_clip > 0) {
- tensor_utils::PortableCwiseClipping(output_state, n_batch * n_output,
- (int8_t)quantized_proj_clip);
- }
-}
-
-// Fully quantized lstm kernel for 16 bit gate matmul output.
-//
-// Input tensor of size n_batch * n_input:
-// input_ptr
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weight_ptr - optional
-// input_to_forget_weight_ptr - optional
-// input_to_cell_weight_ptr - optional
-// input_to_output_weight_ptr - optional
-//
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weight_ptr - optional
-// recurrent_to_forget_weights_ptr
-// recurrent_to_cell_weights_ptr
-// recurrent_to_input_weights_ptr
-//
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_cell_weights - optional
-// cell_to_output_weights - optional
-//
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weight_ptr - optional
-//
-// Weight scales (scalars) for each of the weights above.
-// effective_input_to_input_scale_a - optional
-// effective_input_to_input_scale_b - optional
-// effective_input_to_forget_scale_a
-// effective_input_to_forget_scale_b
-// effective_input_to_cell_scale_a
-// effective_input_to_cell_scale_b
-// effective_input_to_output_scale_a
-// effective_input_to_output_scale_b
-// effective_recurrent_to_input_scale_a - optional
-// effective_recurrent_to_input_scale_b - optional
-// effective_recurrent_to_forget_scale_a
-// effective_recurrent_to_forget_scale_b
-// effective_recurrent_to_cell_scale_a
-// effective_recurrent_to_cell_scale_b
-// effective_recurrent_to_output_scale_a
-// effective_recurrent_to_output_scale_b
-// effective_proj_scale_a - optional
-// effective_proj_scale_b - optional
-//
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_gate_bias_ptr
-// output_gate_bias_ptr
-//
-// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
-// layer_norm_input_weight_ptr - optional
-// layer_norm_forget_weight_ptr - optional
-// layer_norm_cell_weight_ptr - optional
-// layer_norm_output_weight_ptr - optional
-//
-// Layer norm scales of size 'n_cell'.
-// layer_norm_input_scale_a - optional
-// layer_norm_input_scale_b - optional
-// layer_norm_forget_scale_a - optional
-// layer_norm_forget_scale_b - optional
-// layer_norm_cell_scale_a - optional
-// layer_norm_cell_scale_b - optional
-// layer_norm_output_scale_a - optional
-// layer_norm_output_scale_b - optional
-//
-// Scalar values:
-// quantized_cell_clip: quantized clip value for cell.
-// quantized_proj_clip: quantized clip value for projection.
-// cell_state_scale: the power of two scale for cell state.
-//
-// Zero points:
-// output_state_zp: zero point of output state
-// hidden_zp: zero point for hidden state.
-//
-// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
-// n_batch.
-// scratch0
-// scratch1
-// scratch2
-// scratch3
-// scratch4
-// scratch5: this scratch buffer is created purely for optimizing the
-// MatrixBatchVectorMultiplyAccumulate.
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr - size 'n_batch * n_output'
-// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
-inline void LstmStepInteger8x8_16(
- const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
- int32_t effective_input_to_input_scale_a,
- int32_t effective_input_to_input_scale_b,
- const int8_t* input_to_forget_weight_ptr,
- int32_t effective_input_to_forget_scale_a,
- int32_t effective_input_to_forget_scale_b,
- const int8_t* input_to_cell_weight_ptr,
- int32_t effective_input_to_cell_scale_a,
- int32_t effective_input_to_cell_scale_b,
- const int8_t* input_to_output_weight_ptr,
- int32_t effective_input_to_output_scale_a,
- int32_t effective_input_to_output_scale_b,
- const int8_t* recurrent_to_input_weight_ptr,
- int32_t effective_recurrent_to_input_scale_a,
- int32_t effective_recurrent_to_input_scale_b,
- const int8_t* recurrent_to_forget_weight_ptr,
- int32_t effective_recurrent_to_forget_scale_a,
- int32_t effective_recurrent_to_forget_scale_b,
- const int8_t* recurrent_to_cell_weight_ptr,
- int32_t effective_recurrent_to_cell_scale_a,
- int32_t effective_recurrent_to_cell_scale_b,
- const int8_t* recurrent_to_output_weight_ptr,
- int32_t effective_recurrent_to_output_scale_a,
- int32_t effective_recurrent_to_output_scale_b,
- const int16_t* cell_to_input_weight_ptr,
- int32_t effective_cell_to_input_scale_a,
- int32_t effective_cell_to_input_scale_b,
- const int16_t* cell_to_forget_weight_ptr,
- int32_t effective_cell_to_forget_scale_a,
- int32_t effective_cell_to_forget_scale_b,
- const int16_t* cell_to_output_weight_ptr,
- int32_t effective_cell_to_output_scale_a,
- int32_t effective_cell_to_output_scale_b,
- const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
- int32_t effective_proj_scale_b, int32_t hidden_zp,
- int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
- const int16_t* layer_norm_input_weight_ptr,
- int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
- const int16_t* layer_norm_forget_weight_ptr,
- int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
- const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
- int32_t layer_norm_cell_scale_b,
- const int16_t* layer_norm_output_weight_ptr,
- int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
- const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
- const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
- int16_t quantized_cell_clip, int8_t quantized_proj_clip,
- int32_t cell_state_scale, int32_t input_variance_guard,
- int32_t forget_variance_guard, int32_t cell_variance_guard,
- int32_t output_variance_guard,
- const int32_t* input_to_forget_effective_bias,
- const int32_t* recurrent_to_forget_effective_bias,
- const int32_t* input_to_cell_effective_bias,
- const int32_t* recurrent_to_cell_effective_bias,
- const int32_t* input_to_output_effective_bias,
- const int32_t* recurrent_to_output_effective_bias,
- const int32_t* input_to_input_effective_bias,
- const int32_t* recurrent_to_input_effective_bias,
- const int32_t* projection_effective_bias, int n_batch, int n_cell,
- int n_input, int n_output, int8_t* output_state_ptr,
- int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
- int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
- int8_t* scratch4, int32_t* scratch5) {
- // ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
- // Make named scratch buffers for the different gates.
- int16_t* input_gate_scratch = scratch0;
- int16_t* forget_gate_scratch = scratch1;
- int16_t* cell_gate_scratch = scratch2;
- int16_t* output_gate_scratch = scratch3;
-
- // Since we have already checked that weights are all there or none, we
- // can check the existence of only one to the get the condition.
- const bool use_cifg = (input_to_input_weight_ptr == nullptr);
-
- // Check for nullptrs.
- TFLITE_DCHECK(input_to_forget_effective_bias);
- TFLITE_DCHECK(recurrent_to_forget_effective_bias);
- TFLITE_DCHECK(input_to_cell_effective_bias);
- TFLITE_DCHECK(recurrent_to_cell_effective_bias);
- TFLITE_DCHECK(input_to_output_effective_bias);
- TFLITE_DCHECK(recurrent_to_output_effective_bias);
- if (!use_cifg) {
- TFLITE_DCHECK(input_to_input_effective_bias);
- TFLITE_DCHECK(recurrent_to_input_effective_bias);
- }
- const bool use_projection = (projection_weight_ptr != nullptr);
- if (use_projection) {
- TFLITE_DCHECK(projection_effective_bias);
- }
- if (!use_cifg) {
- // Calculate the input gate. (If not CIFG.)
- CalculateLstmGateInteger8x8_16(
- input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
- effective_input_to_input_scale_a, effective_input_to_input_scale_b,
- output_state_ptr, recurrent_to_input_weight_ptr,
- recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
- effective_recurrent_to_input_scale_b, cell_state_ptr,
- cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
- effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
- input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
- input_variance_guard, n_batch, n_input, n_output, n_cell,
- kTfLiteActSigmoid, input_gate_scratch, scratch5);
- }
- // Calculate the forget gate.
- CalculateLstmGateInteger8x8_16(
- input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
- effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
- output_state_ptr, recurrent_to_forget_weight_ptr,
- recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
- effective_recurrent_to_forget_scale_b, cell_state_ptr,
- cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
- effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
- forget_gate_bias_ptr, layer_norm_forget_scale_a,
- layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
- n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
- // Calculate the cell update gate.
- CalculateLstmGateInteger8x8_16(
- input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
- effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
- output_state_ptr, recurrent_to_cell_weight_ptr,
- recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
- effective_recurrent_to_cell_scale_b, cell_state_ptr,
- /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
- /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
- cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
- cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
- cell_gate_scratch, scratch5);
- // Update the cell state.
- UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
- input_gate_scratch, forget_gate_scratch,
- cell_gate_scratch, use_cifg, quantized_cell_clip);
- // Calculate the output gate.
- CalculateLstmGateInteger8x8_16(
- input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
- effective_input_to_output_scale_a, effective_input_to_output_scale_b,
- output_state_ptr, recurrent_to_output_weight_ptr,
- recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
- effective_recurrent_to_output_scale_b, cell_state_ptr,
- cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
- effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
- output_gate_bias_ptr, layer_norm_output_scale_a,
- layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
- n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
- // Update the output state.
- CalculateLstmOutputInteger8x8_16(
- n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
- output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
- hidden_zp, projection_weight_ptr, effective_proj_scale_a,
- effective_proj_scale_b, projection_effective_bias, output_state_zp,
- quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
- // Copy output state to the output. Note that unlike float or hybrid, output
- // is always contiguous.
- std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
-}
-
-// Fully quantized lstm kernel for 8 bit gate matmul output.
-//
-// Input tensor of size n_batch * n_input:
-// input_ptr
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weight_ptr - optional
-// input_to_forget_weight_ptr - optional
-// input_to_cell_weight_ptr - optional
-// input_to_output_weight_ptr - optional
-//
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weight_ptr - optional
-// recurrent_to_forget_weights_ptr
-// recurrent_to_cell_weights_ptr
-// recurrent_to_input_weights_ptr
-//
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_cell_weights - optional
-// cell_to_output_weights - optional
-//
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weight_ptr - optional
-//
-// Weight scales (scalars) for each of the weights above.
-// effective_input_to_input_scale_a - optional
-// effective_input_to_input_scale_b - optional
-// effective_input_to_forget_scale_a
-// effective_input_to_forget_scale_b
-// effective_input_to_cell_scale_a
-// effective_input_to_cell_scale_b
-// effective_input_to_output_scale_a
-// effective_input_to_output_scale_b
-// effective_recurrent_to_input_scale_a - optional
-// effective_recurrent_to_input_scale_b - optional
-// effective_recurrent_to_forget_scale_a
-// effective_recurrent_to_forget_scale_b
-// effective_recurrent_to_cell_scale_a
-// effective_recurrent_to_cell_scale_b
-// effective_recurrent_to_output_scale_a
-// effective_recurrent_to_output_scale_b
-// effective_proj_scale_a - optional
-// effective_proj_scale_b - optional
-//
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_gate_bias_ptr
-// output_gate_bias_ptr
-//
-// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
-// layer_norm_input_weight_ptr - optional
-// layer_norm_forget_weight_ptr - optional
-// layer_norm_cell_weight_ptr - optional
-// layer_norm_output_weight_ptr - optional
-//
-// Layer norm scales of size 'n_cell'.
-// layer_norm_input_scale_a - optional
-// layer_norm_input_scale_b - optional
-// layer_norm_forget_scale_a - optional
-// layer_norm_forget_scale_b - optional
-// layer_norm_cell_scale_a - optional
-// layer_norm_cell_scale_b - optional
-// layer_norm_output_scale_a - optional
-// layer_norm_output_scale_b - optional
-//
-// Scalar values:
-// quantized_cell_clip: quantized clip value for cell.
-// quantized_proj_clip: quantized clip value for projection.
-// cell_state_scale: the power of two scale for cell state.
-//
-// Zero points:
-// output_state_zp: zero point of output state.
-// hidden_zp: zero point for hidden state.
-//
-// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
-// n_batch.
-// scratch0
-// scratch1
-// scratch2
-// scratch3
-// scratch4
-// scratch5
-// scratch6
-// scratch7
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr - size 'n_batch * n_output'
-// TODO(b/148688698): Move zero point calculation into Prepare().
-// TODO(b/159947023): scratch5 is unused, remove.
-inline void LstmStepInteger8x8_8(
- const int8_t* input_ptr, int32_t input_zp,
- const int8_t* input_to_input_weight_ptr,
- int32_t effective_input_to_input_scale_a,
- int32_t effective_input_to_input_scale_b,
- const int8_t* input_to_forget_weight_ptr,
- int32_t effective_input_to_forget_scale_a,
- int32_t effective_input_to_forget_scale_b,
- const int8_t* input_to_cell_weight_ptr,
- int32_t effective_input_to_cell_scale_a,
- int32_t effective_input_to_cell_scale_b,
- const int8_t* input_to_output_weight_ptr,
- int32_t effective_input_to_output_scale_a,
- int32_t effective_input_to_output_scale_b,
- const int8_t* recurrent_to_input_weight_ptr,
- int32_t effective_recurrent_to_input_scale_a,
- int32_t effective_recurrent_to_input_scale_b,
- const int8_t* recurrent_to_forget_weight_ptr,
- int32_t effective_recurrent_to_forget_scale_a,
- int32_t effective_recurrent_to_forget_scale_b,
- const int8_t* recurrent_to_cell_weight_ptr,
- int32_t effective_recurrent_to_cell_scale_a,
- int32_t effective_recurrent_to_cell_scale_b,
- const int8_t* recurrent_to_output_weight_ptr,
- int32_t effective_recurrent_to_output_scale_a,
- int32_t effective_recurrent_to_output_scale_b,
- const int8_t* cell_to_input_weight_ptr,
- int32_t effective_cell_to_input_scale_a,
- int32_t effective_cell_to_input_scale_b,
- const int8_t* cell_to_forget_weight_ptr,
- int32_t effective_cell_to_forget_scale_a,
- int32_t effective_cell_to_forget_scale_b,
- const int8_t* cell_to_output_weight_ptr,
- int32_t effective_cell_to_output_scale_a,
- int32_t effective_cell_to_output_scale_b,
- const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
- int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
- int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
- const int16_t* layer_norm_forget_weight_ptr,
- int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
- const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
- int32_t layer_norm_cell_scale_b,
- const int16_t* layer_norm_output_weight_ptr,
- int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
- const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
- const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
- const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
- const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
- const int32_t* intermediate_zp, int16_t quantized_cell_clip,
- int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
- int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
- int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
- int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
- int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
- int16_t* scratch7) {
- // TODO(b/159066113): scratch5 is unused, remove.
-
- // ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
- // Make named scratch buffers for the different gates.
- int16_t* forget_gate_scratch = scratch2;
- int16_t* cell_gate_scratch = scratch3;
- int16_t* output_gate_scratch = scratch4;
- // no-CIFG is not supported here
-
- // Calculate the forget gate.
- CalculateLstmGateInteger8x8_8(
- input_ptr, input_zp, input_to_forget_weight_ptr,
- effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
- intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
- output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
- effective_recurrent_to_forget_scale_a,
- effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
- intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
- layer_norm_forget_scale_a, layer_norm_forget_scale_b,
- forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
- kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
- // Calculate the cell update gate.
- CalculateLstmGateInteger8x8_8(
- input_ptr, input_zp, input_to_cell_weight_ptr,
- effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
- intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
- output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
- effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
- intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
- layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
- layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
- n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
- // Update the cell state.
- UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
- /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
- forget_gate_scratch, cell_gate_scratch,
- /*use_cifg=*/true, quantized_cell_clip);
- // Calculate the output gate.
- CalculateLstmGateInteger8x8_8(
- input_ptr, input_zp, input_to_output_weight_ptr,
- effective_input_to_output_scale_a, effective_input_to_output_scale_b,
- intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
- output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
- effective_recurrent_to_output_scale_a,
- effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
- intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
- layer_norm_output_scale_a, layer_norm_output_scale_b,
- output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
- kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
- // Update the output state.
- CalculateLstmOutputInteger8x8_8(
- n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
- projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
- projection_bias_ptr, output_state_zp, quantized_proj_clip,
- output_state_ptr, scratch2);
- // Copy output state to the output. Note that unlike float or hybrid, output
- // is always contiguous.
- std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
-}
-
-} // namespace
-
-// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
-TfLiteStatus EvalInteger8x8_16(
- TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* input_to_input_weights,
- const TfLiteEvalTensor* input_to_forget_weights,
- const TfLiteEvalTensor* input_to_cell_weights,
- const TfLiteEvalTensor* input_to_output_weights,
- const TfLiteEvalTensor* recurrent_to_input_weights,
- const TfLiteEvalTensor* recurrent_to_forget_weights,
- const TfLiteEvalTensor* recurrent_to_cell_weights,
- const TfLiteEvalTensor* recurrent_to_output_weights,
- const TfLiteEvalTensor* cell_to_input_weights,
- const TfLiteEvalTensor* cell_to_forget_weights,
- const TfLiteEvalTensor* cell_to_output_weights,
- const TfLiteEvalTensor* input_layer_norm_coefficients,
- const TfLiteEvalTensor* forget_layer_norm_coefficients,
- const TfLiteEvalTensor* cell_layer_norm_coefficients,
- const TfLiteEvalTensor* output_layer_norm_coefficients,
- const TfLiteEvalTensor* input_gate_bias,
- const TfLiteEvalTensor* forget_gate_bias,
- const TfLiteEvalTensor* cell_gate_bias,
- const TfLiteEvalTensor* output_gate_bias,
- const TfLiteEvalTensor* projection_weights,
- const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
- bool forward_sequence, bool time_major,
- const lstm_eval::IntegerLstmParameter* integer_lstm_param,
- TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
- TfLiteEvalTensor* output, TfLiteEvalTensor* scratch0,
- TfLiteEvalTensor* scratch1, TfLiteEvalTensor* scratch2,
- TfLiteEvalTensor* scratch3, TfLiteEvalTensor* scratch4,
- TfLiteEvalTensor* scratch5) {
- TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
- const int n_input = input->dims->data[input->dims->size - 1];
- int max_time, n_batch;
- if (input->dims->size == 2) {
- max_time = 1;
- n_batch = input->dims->data[0];
- } else {
- max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
- n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
- }
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Activation zero point
- // TODO@is data.output_zero_point equal to output_state->params.zero_point
- // int output_state_zp = output_state->params.zero_point;
- int output_state_zp = 0;
-
- // Get params for time/batch/sequence.
- const int output_batch_leading_dim =
- output->dims->data[output->dims->size - 1];
-
- if (time_major) {
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output_batch_leading_dim;
- for (int t = 0; t < max_time; t++) {
- const int t_rel = t;
- int8_t* output_ptr =
- tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
- const int8_t* input_ptr =
- tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
- LstmStepInteger8x8_16(
- input_ptr,
- tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
- integer_lstm_param->effective_input_to_input_scale_a,
- integer_lstm_param->effective_input_to_input_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
- integer_lstm_param->effective_input_to_forget_scale_a,
- integer_lstm_param->effective_input_to_forget_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
- integer_lstm_param->effective_input_to_cell_scale_a,
- integer_lstm_param->effective_input_to_cell_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
- integer_lstm_param->effective_input_to_output_scale_a,
- integer_lstm_param->effective_input_to_output_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
- integer_lstm_param->effective_recurrent_to_input_scale_a,
- integer_lstm_param->effective_recurrent_to_input_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
- integer_lstm_param->effective_recurrent_to_forget_scale_a,
- integer_lstm_param->effective_recurrent_to_forget_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
- integer_lstm_param->effective_recurrent_to_cell_scale_a,
- integer_lstm_param->effective_recurrent_to_cell_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
- integer_lstm_param->effective_recurrent_to_output_scale_a,
- integer_lstm_param->effective_recurrent_to_output_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
- integer_lstm_param->effective_cell_to_input_scale_a,
- integer_lstm_param->effective_cell_to_input_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
- integer_lstm_param->effective_cell_to_forget_scale_a,
- integer_lstm_param->effective_cell_to_forget_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
- integer_lstm_param->effective_cell_to_output_scale_a,
- integer_lstm_param->effective_cell_to_output_scale_b,
- tflite::micro::GetTensorData<int8_t>(projection_weights),
- integer_lstm_param->effective_proj_scale_a,
- integer_lstm_param->effective_proj_scale_b,
- integer_lstm_param->hidden_zp,
- integer_lstm_param->effective_hidden_scale_a,
- integer_lstm_param->effective_hidden_scale_b,
- tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
- integer_lstm_param->layer_norm_input_scale_a,
- integer_lstm_param->layer_norm_input_scale_b,
- tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
- integer_lstm_param->layer_norm_forget_scale_a,
- integer_lstm_param->layer_norm_forget_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
- integer_lstm_param->layer_norm_cell_scale_a,
- integer_lstm_param->layer_norm_cell_scale_b,
- tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
- integer_lstm_param->layer_norm_output_scale_a,
- integer_lstm_param->layer_norm_output_scale_b,
- tflite::micro::GetTensorData<int32_t>(input_gate_bias),
- tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
- tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
- tflite::micro::GetTensorData<int32_t>(output_gate_bias),
- integer_lstm_param->quantized_cell_clip,
- integer_lstm_param->quantized_proj_clip,
- integer_lstm_param->cell_scale,
- integer_lstm_param->input_variance_guard,
- integer_lstm_param->forget_variance_guard,
- integer_lstm_param->cell_variance_guard,
- integer_lstm_param->output_variance_guard,
- integer_lstm_param->input_to_forget_effective_bias.get(),
- integer_lstm_param->recurrent_to_forget_effective_bias.get(),
- integer_lstm_param->input_to_cell_effective_bias.get(),
- integer_lstm_param->recurrent_to_cell_effective_bias.get(),
- integer_lstm_param->input_to_output_effective_bias.get(),
- integer_lstm_param->recurrent_to_output_effective_bias.get(),
- integer_lstm_param->input_to_input_effective_bias.get(),
- integer_lstm_param->recurrent_to_input_effective_bias.get(),
- integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
- n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
- output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
- output_ptr, (int16_t*)(scratch0), (int16_t*)(scratch1),
- (int16_t*)(scratch2), (int16_t*)(scratch3), (int8_t*)(scratch4),
- (int32_t*)(scratch5));
- }
- } else {
- for (int b = 0; b < n_batch; b++) {
- const int input_step = n_input;
- const int output_step = output_batch_leading_dim;
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step
- // backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const int time_offset = b * max_time + t_rel;
- const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
- time_offset * input_step;
- int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
- time_offset * output_step;
-
- // Offset the {output,cell}_state pointers to the right batch.
- int8_t* output_state_ptr =
- tflite::micro::GetTensorData<int8_t>(output_state) +
- b * output_batch_leading_dim;
- int16_t* cell_state_ptr =
- tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
-
- LstmStepInteger8x8_16(
- input_ptr,
- tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
- integer_lstm_param->effective_input_to_input_scale_a,
- integer_lstm_param->effective_input_to_input_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
- integer_lstm_param->effective_input_to_forget_scale_a,
- integer_lstm_param->effective_input_to_forget_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
- integer_lstm_param->effective_input_to_cell_scale_a,
- integer_lstm_param->effective_input_to_cell_scale_b,
- tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
- integer_lstm_param->effective_input_to_output_scale_a,
- integer_lstm_param->effective_input_to_output_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
- integer_lstm_param->effective_recurrent_to_input_scale_a,
- integer_lstm_param->effective_recurrent_to_input_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
- integer_lstm_param->effective_recurrent_to_forget_scale_a,
- integer_lstm_param->effective_recurrent_to_forget_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
- integer_lstm_param->effective_recurrent_to_cell_scale_a,
- integer_lstm_param->effective_recurrent_to_cell_scale_b,
- tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
- integer_lstm_param->effective_recurrent_to_output_scale_a,
- integer_lstm_param->effective_recurrent_to_output_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
- integer_lstm_param->effective_cell_to_input_scale_a,
- integer_lstm_param->effective_cell_to_input_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
- integer_lstm_param->effective_cell_to_forget_scale_a,
- integer_lstm_param->effective_cell_to_forget_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
- integer_lstm_param->effective_cell_to_output_scale_a,
- integer_lstm_param->effective_cell_to_output_scale_b,
- tflite::micro::GetTensorData<int8_t>(projection_weights),
- integer_lstm_param->effective_proj_scale_a,
- integer_lstm_param->effective_proj_scale_b,
- integer_lstm_param->hidden_zp,
- integer_lstm_param->effective_hidden_scale_a,
- integer_lstm_param->effective_hidden_scale_b,
- tflite::micro::GetTensorData<int16_t>(
- input_layer_norm_coefficients),
- integer_lstm_param->layer_norm_input_scale_a,
- integer_lstm_param->layer_norm_input_scale_b,
- tflite::micro::GetTensorData<int16_t>(
- forget_layer_norm_coefficients),
- integer_lstm_param->layer_norm_forget_scale_a,
- integer_lstm_param->layer_norm_forget_scale_b,
- tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
- integer_lstm_param->layer_norm_cell_scale_a,
- integer_lstm_param->layer_norm_cell_scale_b,
- tflite::micro::GetTensorData<int16_t>(
- output_layer_norm_coefficients),
- integer_lstm_param->layer_norm_output_scale_a,
- integer_lstm_param->layer_norm_output_scale_b,
- tflite::micro::GetTensorData<int32_t>(input_gate_bias),
- tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
- tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
- tflite::micro::GetTensorData<int32_t>(output_gate_bias),
- integer_lstm_param->quantized_cell_clip,
- integer_lstm_param->quantized_proj_clip,
- integer_lstm_param->cell_scale,
- integer_lstm_param->input_variance_guard,
- integer_lstm_param->forget_variance_guard,
- integer_lstm_param->cell_variance_guard,
- integer_lstm_param->output_variance_guard,
- integer_lstm_param->input_to_forget_effective_bias.get(),
- integer_lstm_param->recurrent_to_forget_effective_bias.get(),
- integer_lstm_param->input_to_cell_effective_bias.get(),
- integer_lstm_param->recurrent_to_cell_effective_bias.get(),
- integer_lstm_param->input_to_output_effective_bias.get(),
- integer_lstm_param->recurrent_to_output_effective_bias.get(),
- integer_lstm_param->input_to_input_effective_bias.get(),
- integer_lstm_param->recurrent_to_input_effective_bias.get(),
- integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
- n_cell, n_input, n_output, output_state_ptr, output_state_zp,
- cell_state_ptr, output_ptr, (int16_t*)(scratch0),
- (int16_t*)(scratch1), (int16_t*)(scratch2), (int16_t*)(scratch3),
- (int8_t*)(scratch4), (int32_t*)(scratch5));
- }
+LstmTensors::~LstmTensors() {
+ for (size_t i = 0; i < 24; i++) {
+ if (internal_tensors_[i] != nullptr) {
+ micro_context_->DeallocateTempTfLiteTensor(internal_tensors_[i]);
}
}
+ micro_context_->DeallocateTempTfLiteTensor(output_tensor_);
+}
+// Verify the LSTM internal tensor properties (e.g., type checks)
+// Input/output/states/fc weights tensors are required for kernel evaluation.
+// The state tensors should be variables. Variants of the standard LSTM
+// are not supported here, therefore their corresponding tensors should be
+// invalid
+TfLiteStatus LstmTensors::ValidateTensorStatus(TfLiteContext* context) const {
+ // Verify certain tensor properties
+ // input tensor
+ TF_LITE_ENSURE(context, internal_tensors_[kLstmInputTensor] != nullptr);
+ // hidden state
+ TF_LITE_ENSURE(context, internal_tensors_[kLstmOutputStateTensor] != nullptr);
+ TF_LITE_ENSURE(context,
+ internal_tensors_[kLstmOutputStateTensor]->is_variable);
+ // hidden state becomes input so they must have the same type
+ TF_LITE_ENSURE_EQ(context, internal_tensors_[kLstmOutputStateTensor]->type,
+ internal_tensors_[kLstmInputTensor]->type);
+ // cell state
+ TF_LITE_ENSURE(context, internal_tensors_[kLstmCellStateTensor] != nullptr);
+ TF_LITE_ENSURE(context, internal_tensors_[kLstmCellStateTensor]->is_variable);
+ // output
+ TF_LITE_ENSURE(context, output_tensor_ != nullptr);
+ // output type is the same as the input type (activations)
+ TF_LITE_ENSURE_EQ(context, output_tensor_->type,
+ internal_tensors_[kLstmInputTensor]->type);
+
+ // weight tensors (1-9, see lstm_shared for index definition)
+ const auto weight_type =
+ internal_tensors_[kLstmInputToForgetWeightsTensor]->type;
+ for (size_t i = 1; i < 9; i++) {
+ TF_LITE_ENSURE(context, internal_tensors_[i] != nullptr);
+ TF_LITE_ENSURE_EQ(context, internal_tensors_[i]->type, weight_type);
+ }
+
+ // bias tensors (12-15, see lstm_shared for index definition)
+ const auto bias_type = internal_tensors_[kLstmForgetGateBiasTensor]->type;
+ for (size_t i = 12; i < 16; i++) {
+ TF_LITE_ENSURE(context, internal_tensors_[i] != nullptr);
+ TF_LITE_ENSURE_EQ(context, internal_tensors_[i]->type, bias_type);
+ }
+ // Tensors from LSTM variants are invalid
+ // No peephole
+ for (size_t i = 9; i < 12; i++) {
+ TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+ }
+ // No projection
+ for (size_t i = 16; i < 18; i++) {
+ TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+ }
+ // No internal layer norm
+ for (size_t i = 20; i < 24; i++) {
+ TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+ }
return kTfLiteOk;
}
-TfLiteStatus EvalInteger8x8_8(
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* input_to_input_weights,
- const TfLiteEvalTensor* input_to_forget_weights,
- const TfLiteEvalTensor* input_to_cell_weights,
- const TfLiteEvalTensor* input_to_output_weights,
- const TfLiteEvalTensor* recurrent_to_input_weights,
- const TfLiteEvalTensor* recurrent_to_forget_weights,
- const TfLiteEvalTensor* recurrent_to_cell_weights,
- const TfLiteEvalTensor* recurrent_to_output_weights,
- const TfLiteEvalTensor* cell_to_input_weights,
- const TfLiteEvalTensor* cell_to_forget_weights,
- const TfLiteEvalTensor* cell_to_output_weights,
- const TfLiteEvalTensor* input_layer_norm_coefficients,
- const TfLiteEvalTensor* forget_layer_norm_coefficients,
- const TfLiteEvalTensor* cell_layer_norm_coefficients,
- const TfLiteEvalTensor* output_layer_norm_coefficients,
- const TfLiteEvalTensor* input_gate_bias,
- const TfLiteEvalTensor* forget_gate_bias,
- const TfLiteEvalTensor* cell_gate_bias,
- const TfLiteEvalTensor* output_gate_bias,
- const TfLiteEvalTensor* projection_weights,
- const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
- TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
- TfLiteEvalTensor* output,
- const lstm_eval::IntegerLstmParameter* integer_lstm_param,
- TfLiteEvalTensor* scratch0, TfLiteEvalTensor* scratch1,
- TfLiteEvalTensor* scratch2, TfLiteEvalTensor* scratch3,
- TfLiteEvalTensor* scratch4, TfLiteEvalTensor* scratch5,
- TfLiteEvalTensor* scratch6, TfLiteEvalTensor* scratch7) {
- TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
- const int n_input = input->dims->data[input->dims->size - 1];
- int max_time, n_batch;
- if (input->dims->size == 2) {
- max_time = 1;
- n_batch = input->dims->data[0];
- } else {
- max_time = input->dims->data[0];
- n_batch = input->dims->data[1];
+namespace lstm_internal {
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
+const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
+#endif
+
+void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int16_t* output) {
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ int32_t sum = input_1[index] + input_2[index];
+ const int32_t sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
+ output[index] = static_cast<int16_t>(sum_clamped);
+ }
}
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
- //@TODO input zero point and output zeropoint
- // const int32_t input_zp = input->params.zero_point;
- /// const int32_t output_state_zp = output_state->params.zero_point;
- const int32_t input_zp = 0;
- const int32_t output_state_zp = 0;
-
- // Get params for time/batch/sequence.
- const int output_batch_leading_dim =
- output->dims->data[output->dims->size - 1];
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output_batch_leading_dim;
-
- for (int t = 0; t < max_time; t++) {
- const int t_rel = t;
- int8_t* output_ptr =
- tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
- // Input can be int8 asymmetric or int16 symmetric.
- const int8_t* input_ptr =
- tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
- lstm_eval::LstmStepInteger8x8_8(
- input_ptr, input_zp,
-
- tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
- integer_lstm_param->effective_input_to_input_scale_a,
- integer_lstm_param->effective_input_to_input_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
- integer_lstm_param->effective_input_to_forget_scale_a,
- integer_lstm_param->effective_input_to_forget_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
- integer_lstm_param->effective_input_to_cell_scale_a,
- integer_lstm_param->effective_input_to_cell_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
- integer_lstm_param->effective_input_to_output_scale_a,
- integer_lstm_param->effective_input_to_output_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
- integer_lstm_param->effective_recurrent_to_input_scale_a,
- integer_lstm_param->effective_recurrent_to_input_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
- integer_lstm_param->effective_recurrent_to_forget_scale_a,
- integer_lstm_param->effective_recurrent_to_forget_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
- integer_lstm_param->effective_recurrent_to_cell_scale_a,
- integer_lstm_param->effective_recurrent_to_cell_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
- integer_lstm_param->effective_recurrent_to_output_scale_a,
- integer_lstm_param->effective_recurrent_to_output_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
- integer_lstm_param->effective_cell_to_input_scale_a,
- integer_lstm_param->effective_cell_to_input_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
- integer_lstm_param->effective_cell_to_forget_scale_a,
- integer_lstm_param->effective_cell_to_forget_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
- integer_lstm_param->effective_cell_to_output_scale_a,
- integer_lstm_param->effective_cell_to_output_scale_b,
-
- tflite::micro::GetTensorData<int8_t>(projection_weights),
- integer_lstm_param->effective_proj_scale_a,
- integer_lstm_param->effective_proj_scale_b,
-
- tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
- integer_lstm_param->layer_norm_input_scale_a,
- integer_lstm_param->layer_norm_input_scale_b,
-
- tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
- integer_lstm_param->layer_norm_forget_scale_a,
- integer_lstm_param->layer_norm_forget_scale_b,
-
- tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
- integer_lstm_param->layer_norm_cell_scale_a,
- integer_lstm_param->layer_norm_cell_scale_b,
-
- tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
- integer_lstm_param->layer_norm_output_scale_a,
- integer_lstm_param->layer_norm_output_scale_b,
-
- tflite::micro::GetTensorData<int32_t>(input_gate_bias),
- tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
- tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
- tflite::micro::GetTensorData<int32_t>(output_gate_bias),
- tflite::micro::GetTensorData<int32_t>(projection_bias),
-
- params, integer_lstm_param->intermediate_scale_a,
- integer_lstm_param->intermediate_scale_b,
- integer_lstm_param->intermediate_zp,
- integer_lstm_param->quantized_cell_clip,
- integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
- n_output, output_batch_leading_dim,
- tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
- tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr,
- tflite::micro::GetTensorData<int8_t>(scratch0),
- tflite::micro::GetTensorData<int8_t>(scratch1),
- tflite::micro::GetTensorData<int16_t>(scratch2),
- tflite::micro::GetTensorData<int16_t>(scratch3),
- tflite::micro::GetTensorData<int16_t>(scratch4),
- tflite::micro::GetTensorData<int16_t>(scratch5),
- tflite::micro::GetTensorData<int16_t>(scratch6),
- tflite::micro::GetTensorData<int16_t>(scratch7));
- }
-
- return kTfLiteOk;
+#else
+ WORD32 err;
+ err = xa_nn_elm_add_16x16_16(output, input_1, input_2, n_batch * n_input);
+#endif
}
-} // namespace lstm_eval
-} // namespace micro
-} // namespace ops
+void AddElementWise(const float* input_1, const float* input_2, int n_batch,
+ int n_input, float* output) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ for (int i = 0; i < n_input; ++i) {
+ const int index = batch * n_input + i;
+ output[index] = input_1[index] + input_2[index];
+ }
+ }
+}
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(const RuntimeShape& data_shape, int16_t* data) {
+ reference_integer_ops::Logistic(
+ 0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
+ data_shape.FlatSize() /*NumElements(input->dims)*/,
+ data /* tflite::micro::GetTensorData<int16_t>(input) */,
+ data /*tflite::micro::GetTensorData<int16_t>(output) */);
+}
+
+void Sigmoid(const RuntimeShape& data_shape, float* data) {
+ reference_ops::Logistic(data_shape, data, data_shape, data);
+}
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+ int16_t* input_data, const RuntimeShape& output_data_shape,
+ int16_t* output_data) {
+ int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
+ int32_t input_multiplier = 0;
+ if (tanh_input_left_shift < 0) /* handling negative shift value */
+ {
+ tanh_input_left_shift = -tanh_input_left_shift;
+ input_multiplier = 3;
+ }
+ reference_integer_ops::Tanh(input_multiplier, tanh_input_left_shift,
+ input_data_shape, input_data, output_data_shape,
+ output_data);
+}
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+ float* input_data, const RuntimeShape& output_data_shape,
+ float* output_data) {
+ reference_ops::Tanh(input_data_shape, input_data, output_data_shape,
+ output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const int16_t* input1_data, const int16_t* input2_data,
+ int8_t* output_data) {
+ return reference_integer_ops::MulElementwise(
+ shape.FlatSize(), params, input1_data, input2_data, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const int16_t* input1_data, const int16_t* input2_data,
+ int16_t* output_data) {
+ return reference_integer_ops::MulElementwise(
+ shape.FlatSize(), params, input1_data, input2_data, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const float* input1_data, const float* input2_data,
+ float* output_data) {
+ return reference_ops::Mul(params, shape, input1_data, shape, input2_data,
+ shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const int8_t* input_data,
+ const RuntimeShape& filter_shape, const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const int32_t* bias_data,
+ const RuntimeShape& output_shape, int16_t* output_data) {
+ return tflite::reference_integer_ops::FullyConnected(
+ params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const int16_t* input_data,
+ const RuntimeShape& filter_shape, const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const int64_t* bias_data,
+ const RuntimeShape& output_shape, int16_t* output_data) {
+ return tflite::reference_integer_ops::FullyConnected(
+ params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& filter_shape, const float* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ return tflite::reference_ops::FullyConnected(
+ params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data);
+}
+#else // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(int16_t* data, int32_t data_size) {
+ WORD32 err;
+ err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, data_size);
+}
+
+void Sigmoid(float* data, int32_t data_size) {
+ int data_dims[2] = {1, data_size};
+ RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
+ reference_ops::Logistic(data_shape, data, data_shape, data);
+}
+
+void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
+ int16_t* output_data, int32_t data_size) {
+ int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
+ int32_t input_multiplier = 0;
+ if (tanh_input_left_shift < 0) /* handling negative shift value */
+ {
+ tanh_input_left_shift = -tanh_input_left_shift;
+#if (defined(USE_HIFI_ACT_TIE) && \
+ (defined(AE_TANH16X4X2) || defined(AE_TANH16X4)))
+ input_multiplier = 1;
+#else
+ input_multiplier = 3;
+#endif
+ }
+ WORD32 err;
+ err = xa_nn_vec_tanh_sym16s_sym16s(output_data, input_data, input_multiplier,
+ tanh_input_left_shift, data_size);
+}
+
+void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
+ int32_t data_size) {
+ int data_dims[2] = {1, data_size};
+ RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
+ reference_ops::Tanh(data_shape, input_data, data_shape, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+ const int16_t* input2_data, int8_t* output_data, int32_t data_size) {
+ WORD32 err;
+ err = xa_nn_elm_mul_sym16sxsym16s_asym8s(
+ output_data, params.output_offset, params.output_shift,
+ params.output_multiplier, params.quantized_activation_min,
+ params.quantized_activation_max, input1_data, input2_data, data_size);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+ const int16_t* input2_data, int16_t* output_data, int32_t data_size) {
+ int dims_4D[4] = {1, 1, 1, data_size};
+ WORD32 err;
+ err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(
+ output_data, dims_4D, params.output_shift, params.output_multiplier,
+ params.quantized_activation_min, params.quantized_activation_max,
+ input1_data, dims_4D, input2_data, dims_4D);
+ return;
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const float* input1_data,
+ const float* input2_data, float* output_data, int32_t data_size) {
+ int dims_2D[2] = {1, data_size};
+ RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(dims_2D));
+ return reference_ops::Mul(params, data_shape, input1_data, data_shape,
+ input2_data, data_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+ const int8_t* input_data, const int8_t* filter_data,
+ const int32_t* bias_data, int16_t* output_data,
+ const int num_batches, const int output_depth,
+ const int accum_depth) {
+ WORD32 err;
+#pragma loop_count min = 1
+ for (int b = 0; b < num_batches; b++) {
+ err = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
+ output_data + b * output_depth, filter_data,
+ input_data + b * accum_depth, bias_data, output_depth, accum_depth,
+ accum_depth, 1, params.input_offset, params.output_multiplier,
+ params.output_shift);
+ }
+ return;
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+ const int16_t* input_data, const int8_t* filter_data,
+ const int64_t* bias_data, int16_t* output_data,
+ const int num_batches, const int output_depth,
+ const int accum_depth) {
+ WORD32 err;
+
+ err = xa_nn_matmul_sym8sxsym16s_sym16s(
+ output_data, filter_data, input_data, bias_data, output_depth,
+ accum_depth, accum_depth, num_batches, accum_depth, output_depth, 1,
+ params.input_offset, params.output_multiplier, params.output_shift,
+ params.output_offset);
+ return;
+}
+
+void FullyConnected(const FullyConnectedParams& params, const float* input_data,
+ const float* filter_data, const float* bias_data,
+ float* output_data, const int num_batches,
+ const int output_depth, const int accum_depth) {
+  int input_dims[2] = {num_batches, accum_depth};
+ RuntimeShape input_shape(2, reinterpret_cast<const int32_t*>(input_dims));
+ RuntimeShape bias_shape(1, bias_data == NULL ? 0 : output_depth);
+ int filter_dims[2] = {output_depth, accum_depth};
+ RuntimeShape filter_shape(2, reinterpret_cast<const int32_t*>(filter_dims));
+ int output_dims[2] = {num_batches, output_depth};
+ RuntimeShape output_shape(2, reinterpret_cast<const int32_t*>(output_dims));
+ return tflite::reference_ops::FullyConnected(
+ params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data);
+}
+#endif // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+ int16_t* vector) {
+ for (int i = 0; i < v_size; i++) {
+ vector[i] =
+ std::max(std::min(cell_state_info.quantized_cell_clip, vector[i]),
+ static_cast<int16_t>(-cell_state_info.quantized_cell_clip));
+ }
+}
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+ float* vector) {
+ for (int i = 0; i < v_size; i++) {
+ vector[i] = std::max(std::min(cell_state_info.cell_clip, vector[i]),
+ -cell_state_info.cell_clip);
+ }
+}
+
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+void UpdateLstmCell(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ // Gate outputs
+ int16_t* forget_gate_output,
+ const int16_t* input_gate_output,
+ const int16_t* cell_gate_output,
+ // Mul parameters
+ const ArithmeticParams& forget_cell_mul_params,
+ const ArithmeticParams& input_mul_params,
+ const CellStateInfo& cell_state_info, int16_t* buffer) {
+ auto cell_state_shape = step_info.StateShape();
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(step_info.CellStateOffset() + cell_state_shape.FlatSize(),
+ tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+ WORD32 err;
+  // The effective 0.5 multiplier is folded into the shift values passed below
+ err = xa_nn_lstm_cell_state_update_16(
+ tflite::micro::GetTensorData<int16_t>(cell_state) +
+ step_info.CellStateOffset(),
+ forget_gate_output, cell_gate_output, input_gate_output,
+ forget_cell_mul_params.output_shift - 1,
+ input_mul_params.output_shift - 1, cell_state_info.quantized_cell_clip,
+ cell_state_shape.FlatSize());
+}
+
+void UpdateLstmCell(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ // Gate outputs
+ float* forget_gate_output, const float* input_gate_output,
+ const float* cell_gate_output,
+ // Mul parameters
+ const ArithmeticParams& forget_cell_mul_params,
+ const ArithmeticParams& input_mul_params,
+ const CellStateInfo& cell_state_info, float* buffer) {
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(
+ step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+ auto cell_state_shape = step_info.StateShape();
+ // Forget Gate x Cell State
+ Mul(forget_cell_mul_params, forget_gate_output,
+ tflite::micro::GetTensorData<float>(cell_state) +
+ step_info.CellStateOffset(),
+ tflite::micro::GetTensorData<float>(cell_state) +
+ step_info.CellStateOffset(),
+ cell_state_shape.FlatSize());
+ // Input Gate x Cell Gate
+ Mul(input_mul_params, input_gate_output, cell_gate_output, buffer,
+ cell_state_shape.FlatSize());
+
+ // Update the cell state
+ AddElementWise(tflite::micro::GetTensorData<float>(cell_state) +
+ step_info.CellStateOffset(),
+ buffer,
+ /*n_batch=*/cell_state_shape.DimsData()[0],
+ /*n_state=*/cell_state_shape.DimsData()[1],
+ tflite::micro::GetTensorData<float>(cell_state) +
+ step_info.CellStateOffset());
+
+ if (cell_state_info.cell_clip > 0) {
+ Clipping(cell_state_shape.FlatSize(), cell_state_info,
+ tflite::micro::GetTensorData<float>(cell_state) +
+ step_info.CellStateOffset());
+ }
+}
+#endif // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+
+// Increment the data offsets so the single-time-step invocation can access
+// the input/output tensor data for the current time step
+void LstmStepManager::UpdateTime() {
+ current_time_ += 1;
+ TFLITE_DCHECK_LE(current_time_, size_info_.time_steps);
+  // default: one batch per inference
+ int input_step = size_info_.input_dimension;
+ int output_step = size_info_.state_dimension;
+ // time major: batch inference
+ if (size_info_.time_major) {
+ input_step = input_step * size_info_.batch_size;
+ output_step = output_step * size_info_.batch_size;
+ }
+
+ input_offset_ += input_step;
+ output_offset_ += output_step;
+}
+
+// Increment the data offsets so the single-time-step invocation can access
+// the hidden/cell state tensor data for the current batch (for single-batch
+// inference only)
+void LstmStepManager::UpdateBatch() {
+ current_batch_ += 1;
+ TFLITE_DCHECK_LE(current_batch_, size_info_.batch_size);
+ // batch inference for time major: no action needed
+ if (size_info_.time_major) {
+ return;
+ }
+  // otherwise: single-batch inference, move on to the next batch
+ hidden_state_offset_ += size_info_.state_dimension;
+ cell_state_offset_ += size_info_.state_dimension;
+}
+
+// Input shape for a single-time-step LSTM invocation.
+// Covers multiple batches when the input is time_major
+RuntimeShape LstmStepManager::InputShape() const {
+ int batch_size = 1;
+ if (size_info_.time_major) {
+ batch_size = size_info_.batch_size;
+ }
+ const int dims[2] = {batch_size, size_info_.input_dimension};
+ const int32_t* dims_data = reinterpret_cast<const int32_t*>(dims);
+ return RuntimeShape(2, dims_data);
+}
+
+// State shape (both hidden and cell) for a single-time-step LSTM invocation.
+// Covers multiple batches when the input is time_major
+RuntimeShape LstmStepManager::StateShape() const {
+ int batch_size = 1;
+ if (size_info_.time_major) {
+ batch_size = size_info_.batch_size;
+ }
+ const int dims[2] = {batch_size, size_info_.state_dimension};
+ const int32_t* dims_data = reinterpret_cast<const int32_t*>(dims);
+ return RuntimeShape(2, dims_data);
+}
+
+} // namespace lstm_internal
} // namespace tflite
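A note on the LstmStepManager bookkeeping introduced above: UpdateTime() advances the input/output offsets by one time step of data (scaled by batch_size for time-major inputs), while UpdateBatch() advances only the hidden/cell state offsets, and only for batch-major inputs (ResetTime() just rewinds the time counter, so the input/output offsets keep running across batches). A minimal standalone sketch of that arithmetic, using hypothetical sizes and plain ints rather than the class itself:

// Mirrors LstmStepManager offset updates for a batch-major (time_major = false)
// input; sizes are hypothetical: batch = 2, time_steps = 3, input_dim = 4,
// state_dim = 5.
#include <cstdio>

int main() {
  const int batch_size = 2, time_steps = 3, input_dim = 4, state_dim = 5;
  int input_offset = 0, output_offset = 0;  // advance every time step
  int hidden_offset = 0, cell_offset = 0;   // advance once per batch
  for (int b = 0; b < batch_size; ++b) {
    for (int t = 0; t < time_steps; ++t) {
      std::printf("b=%d t=%d in=%d out=%d hidden=%d cell=%d\n", b, t,
                  input_offset, output_offset, hidden_offset, cell_offset);
      // UpdateTime(): one single-batch step of input/output data.
      input_offset += input_dim;
      output_offset += state_dim;
    }
    // UpdateBatch() + ResetTime(): move to the next batch's state slice;
    // the running input/output offsets already point at the next batch.
    hidden_offset += state_dim;
    cell_offset += state_dim;
  }
  return 0;
}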
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
index 5dd746a..0ba5e22 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,205 +12,813 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
-#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
+// Functions to perform integer evaluation for a standard LSTM (e.g., as
+// defined in the Keras LSTM layer; no peephole, etc.). Currently used by the
+// 16-bit activation case only.
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
+#include <algorithm>
#include <cstdint>
-#include <memory>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
-#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/lstm_shared.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
+#include "tensorflow/lite/micro/micro_log.h"
namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
-#if defined(HIFI5)
-void calc_cell_state_without_cifg(int16_t* cell_state,
- const int16_t* forget_gate,
- const int16_t* cell_gate,
- const int16_t* input_gate, int shift1,
- int shift2, int clip, int num_elms);
+// Interface to access all the TempTfLiteTensors of the LSTM kernel during the
+// preparation phase. Can only be constructed through the constructor to avoid
+// memory leaks. All TempTfLiteTensors are deallocated through the destructor.
+class LstmTensors {
+ public:
+ LstmTensors(const LstmTensors& other) = delete;
+ LstmTensors& operator=(const LstmTensors& other) = delete;
-void calc_cell_state_with_cifg(int16_t* cell_state, const int16_t* forget_gate,
- const int16_t* cell_gate, int shift1, int shift2,
- int clip, int num_elms);
+ LstmTensors(TfLiteContext* context, TfLiteNode* node);
+ ~LstmTensors();
-void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1,
- const int16_t* input_2, int32_t multiplier,
- int32_t shift, int32_t zero_point,
- int num_elms);
-#endif // defined(HIFI5)
+  // Verify the LSTM internal tensor properties (e.g., type checks).
+  // Input/output/state/FC-weight tensors are required for kernel evaluation.
+  // The state tensors should be variables. Variants of the standard LSTM are
+  // not supported here, so their corresponding tensors should be invalid.
+ TfLiteStatus ValidateTensorStatus(TfLiteContext* context) const;
-// Pamameters for integer LSTM.
-// Consider split this into two Integer Parameters if more fields are added.
-struct IntegerLstmParameter {
- int32_t effective_input_to_input_scale_a;
- int effective_input_to_input_scale_b;
- int32_t effective_recurrent_to_input_scale_a;
- int effective_recurrent_to_input_scale_b;
- int32_t effective_cell_to_input_scale_a;
- int effective_cell_to_input_scale_b;
- int32_t effective_input_to_forget_scale_a;
- int effective_input_to_forget_scale_b;
- int32_t effective_recurrent_to_forget_scale_a;
- int effective_recurrent_to_forget_scale_b;
- int32_t effective_cell_to_forget_scale_a;
- int effective_cell_to_forget_scale_b;
- int32_t effective_input_to_cell_scale_a;
- int effective_input_to_cell_scale_b;
- int32_t effective_recurrent_to_cell_scale_a;
- int effective_recurrent_to_cell_scale_b;
- int32_t effective_input_to_output_scale_a;
- int effective_input_to_output_scale_b;
- int32_t effective_recurrent_to_output_scale_a;
- int effective_recurrent_to_output_scale_b;
- int32_t effective_cell_to_output_scale_a;
- int effective_cell_to_output_scale_b;
- int32_t effective_proj_scale_a;
- int effective_proj_scale_b;
- int32_t effective_hidden_scale_a;
- int effective_hidden_scale_b;
- int32_t layer_norm_input_scale_a;
- int layer_norm_input_scale_b;
- int32_t layer_norm_forget_scale_a;
- int layer_norm_forget_scale_b;
- int32_t layer_norm_cell_scale_a;
- int layer_norm_cell_scale_b;
- int32_t layer_norm_output_scale_a;
- int layer_norm_output_scale_b;
- // Quantized clip value for cell and projection. Zero value means no clipping.
- int16_t quantized_cell_clip;
- int8_t quantized_proj_clip;
- int32_t hidden_zp;
- int32_t cell_scale;
+  // Internal tensors; see lstm_shared.h for tensor names
+ const TfLiteTensor* GetInternalTensor(const int tensor_index) const {
+ return internal_tensors_[tensor_index];
+ }
- int32_t input_variance_guard;
- int32_t forget_variance_guard;
- int32_t cell_variance_guard;
- int32_t output_variance_guard;
+ const TfLiteTensor* HiddenStateTensor() const {
+ return internal_tensors_[kLstmOutputStateTensor];
+ }
+ const TfLiteTensor* CellStateTensor() const {
+ return internal_tensors_[kLstmCellStateTensor];
+ }
+ const TfLiteTensor* OutputTensor() const { return output_tensor_; }
- // Pre-calculate bias + zero_point * weight.
- // Unabled to use temporary tensors since those are used in Prepare() and
- // scratch buffer is only allocated after Preapre().
- std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
- std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
- std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
- std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
- std::unique_ptr<int32_t[]> input_to_output_effective_bias;
- std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
- std::unique_ptr<int32_t[]> input_to_input_effective_bias;
- std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
- std::unique_ptr<int32_t[]> projection_effective_bias;
-
- // Scale and zero point for intermediate tensors.
- // Used only in the 8x8_8 case.
- int32_t intermediate_scale_a[8];
- int32_t intermediate_scale_b[8];
- int32_t intermediate_zp[12];
+ private:
+ // see lstm_shared.h for tensor names
+ MicroContext* micro_context_;
+ TfLiteTensor* internal_tensors_[24];
+ TfLiteTensor* output_tensor_;
};
-TfLiteStatus EvalFloat(const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* input_to_input_weights,
- const TfLiteEvalTensor* input_to_forget_weights,
- const TfLiteEvalTensor* input_to_cell_weights,
- const TfLiteEvalTensor* input_to_output_weights,
- const TfLiteEvalTensor* recurrent_to_input_weights,
- const TfLiteEvalTensor* recurrent_to_forget_weights,
- const TfLiteEvalTensor* recurrent_to_cell_weights,
- const TfLiteEvalTensor* recurrent_to_output_weights,
- const TfLiteEvalTensor* cell_to_input_weights,
- const TfLiteEvalTensor* cell_to_forget_weights,
- const TfLiteEvalTensor* cell_to_output_weights,
- const TfLiteEvalTensor* input_layer_norm_coefficients,
- const TfLiteEvalTensor* forget_layer_norm_coefficients,
- const TfLiteEvalTensor* cell_layer_norm_coefficients,
- const TfLiteEvalTensor* output_layer_norm_coefficients,
- const TfLiteEvalTensor* aux_input,
- const TfLiteEvalTensor* aux_input_to_input_weights,
- const TfLiteEvalTensor* aux_input_to_forget_weights,
- const TfLiteEvalTensor* aux_input_to_cell_weights,
- const TfLiteEvalTensor* aux_input_to_output_weights,
- const TfLiteEvalTensor* input_gate_bias,
- const TfLiteEvalTensor* forget_gate_bias,
- const TfLiteEvalTensor* cell_gate_bias,
- const TfLiteEvalTensor* output_gate_bias,
- const TfLiteEvalTensor* projection_weights,
- const TfLiteEvalTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
- bool time_major, int output_offset,
- TfLiteEvalTensor* scratch_buffer,
- TfLiteEvalTensor* output_state,
- TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
+// Deduce the size information (Batch (B), Time Steps (T), Input dimension (I),
+// State dimension (S)) that defines the LSTM, using the shapes of the input
+// and hidden state tensors.
+LstmSizeInfo CreateLstmSizeInfo(
+ const bool time_major, const TfLiteIntArray* input_tensor_shape,
+ const TfLiteIntArray* hidden_state_tensor_shape);
-TfLiteStatus EvalInteger8x8_16(
- TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* input_to_input_weights,
- const TfLiteEvalTensor* input_to_forget_weights,
- const TfLiteEvalTensor* input_to_cell_weights,
- const TfLiteEvalTensor* input_to_output_weights,
- const TfLiteEvalTensor* recurrent_to_input_weights,
- const TfLiteEvalTensor* recurrent_to_forget_weights,
- const TfLiteEvalTensor* recurrent_to_cell_weights,
- const TfLiteEvalTensor* recurrent_to_output_weights,
- const TfLiteEvalTensor* cell_to_input_weights,
- const TfLiteEvalTensor* cell_to_forget_weights,
- const TfLiteEvalTensor* cell_to_output_weights,
- const TfLiteEvalTensor* input_layer_norm_coefficients,
- const TfLiteEvalTensor* forget_layer_norm_coefficients,
- const TfLiteEvalTensor* cell_layer_norm_coefficients,
- const TfLiteEvalTensor* output_layer_norm_coefficients,
- const TfLiteEvalTensor* input_gate_bias,
- const TfLiteEvalTensor* forget_gate_bias,
- const TfLiteEvalTensor* cell_gate_bias,
- const TfLiteEvalTensor* output_gate_bias,
- const TfLiteEvalTensor* projection_weights,
- const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
- bool forward_sequence, bool time_major,
- const lstm_eval::IntegerLstmParameter* integer_lstm_param,
- TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
- TfLiteEvalTensor* output, TfLiteEvalTensor* scratch0,
- TfLiteEvalTensor* scratch1, TfLiteEvalTensor* scratch2,
- TfLiteEvalTensor* scratch3, TfLiteEvalTensor* scratch4,
- TfLiteEvalTensor* scratch5);
+TfLiteStatus ValidateWeightTensorSize(TfLiteContext* context,
+ const TfLiteTensor* tensor, int dim1_size,
+ int dim2_size);
-TfLiteStatus EvalInteger8x8_8(
- const TfLiteEvalTensor* input,
- const TfLiteEvalTensor* input_to_input_weights,
- const TfLiteEvalTensor* input_to_forget_weights,
- const TfLiteEvalTensor* input_to_cell_weights,
- const TfLiteEvalTensor* input_to_output_weights,
- const TfLiteEvalTensor* recurrent_to_input_weights,
- const TfLiteEvalTensor* recurrent_to_forget_weights,
- const TfLiteEvalTensor* recurrent_to_cell_weights,
- const TfLiteEvalTensor* recurrent_to_output_weights,
- const TfLiteEvalTensor* cell_to_input_weights,
- const TfLiteEvalTensor* cell_to_forget_weights,
- const TfLiteEvalTensor* cell_to_output_weights,
- const TfLiteEvalTensor* input_layer_norm_coefficients,
- const TfLiteEvalTensor* forget_layer_norm_coefficients,
- const TfLiteEvalTensor* cell_layer_norm_coefficients,
- const TfLiteEvalTensor* output_layer_norm_coefficients,
- const TfLiteEvalTensor* input_gate_bias,
- const TfLiteEvalTensor* forget_gate_bias,
- const TfLiteEvalTensor* cell_gate_bias,
- const TfLiteEvalTensor* output_gate_bias,
- const TfLiteEvalTensor* projection_weights,
- const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
- TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
- TfLiteEvalTensor* output,
- const lstm_eval::IntegerLstmParameter* integer_lstm_param,
- TfLiteEvalTensor* scratch0, TfLiteEvalTensor* scratch1,
- TfLiteEvalTensor* scratch2, TfLiteEvalTensor* scratch3,
- TfLiteEvalTensor* scratch4, TfLiteEvalTensor* scratch5,
- TfLiteEvalTensor* scratch6, TfLiteEvalTensor* scratch7);
+TfLiteStatus ValidateBiasTensorSize(TfLiteContext* context,
+ const TfLiteTensor* tensor, int size);
-} // namespace lstm_eval
-} // namespace micro
-} // namespace ops
+// Go through every tensor and make sure its shape matches the kernel
+// configuration.
+TfLiteStatus ValidateTensorSize(TfLiteContext* context,
+ const LstmTensors& tensors,
+ const LstmSizeInfo& size_info);
+
+// Wrapper function to create gate parameters for the four internal LSTM gates
+TfLiteStatus CreateGateParams(
+ TfLiteContext* context,
+ /*Input tensors*/
+ const TfLiteTensor* input, const TfLiteTensor* input_weight,
+ const TfLiteTensor* input_bias,
+ /*Hidden state tensors*/
+ const TfLiteTensor* hidden_state, const TfLiteTensor* hidden_state_weight,
+ const TfLiteTensor* hidden_state_bias,
+ /*Scale of the fc output (input to non-linear activation)*/
+ const float nonlinear_activation_input_scale, const TfLiteType cell_type,
+ const tflite::GateParameters& gate_params);
+
+// Create parameters for the element-wise multiplications that happen in
+// a) the cell state update and b) the hidden state update.
+// Note that all gate outputs are symmetrically quantized, so only scales are
+// required for the inputs. However, during the hidden state update phase, the
+// output is the updated hidden state, which is asymmetrically quantized; thus
+// the output may require a zero point.
+tflite::ArithmeticParams CreateInterGateMulParams(const float input1_scale,
+ const float input2_scale,
+ const float output_scale,
+ const TfLiteType output_type,
+ const int output_zp = 0);
+
+// Create the additional information about the cell state, which includes:
+// cell_state_scale_power: used in the integer nonlinear functions (e.g., tanh)
+// quantized_cell_clip: quantized cell clip range
+CellStateInfo CreateLstmCellStateInfo(const float cell_state_scale,
+ const float cell_clip);
+
+CellStateInfo CreateLstmCellStateInfoFloat(const float cell_clip);
+tflite::FullyConnectedParams CreateFCParamsFloat();
+
+tflite::GateParameters CreateGateParamsFloat();
+
+tflite::ArithmeticParams CreateInterGateMulParamsFloat();
+
+TfLiteStatus PrepareGateParametersFloat(TfLiteContext* context,
+ const LstmTensors& lstm_tensors,
+ OpDataLSTM* op_data_lstm);
+
+TfLiteStatus PrepareGateParametersInteger(TfLiteContext* context,
+ const LstmTensors& lstm_tensors,
+ OpDataLSTM* op_data_lstm);
+
+LSTMKernelContents CreateLSTMKernelContent(TfLiteContext* context,
+ TfLiteNode* node);
+
+template <typename CellType>
+LSTMBuffers<CellType> CreateLSTMBuffers(TfLiteContext* context,
+ const int* buffer_indices) {
+ LSTMBuffers<CellType> buffers;
+ buffers.buffer0 = reinterpret_cast<CellType*>(
+ context->GetScratchBuffer(context, buffer_indices[0]));
+ buffers.buffer1 = reinterpret_cast<CellType*>(
+ context->GetScratchBuffer(context, buffer_indices[1]));
+ buffers.buffer2 = reinterpret_cast<CellType*>(
+ context->GetScratchBuffer(context, buffer_indices[2]));
+ buffers.buffer3 = reinterpret_cast<CellType*>(
+ context->GetScratchBuffer(context, buffer_indices[3]));
+ return buffers;
+}
+
+// Since LSTM includes multiple intermediate stages, introduce the internal
+// namespace to expose them for testing
+namespace lstm_internal {
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(const RuntimeShape& data_shape, int16_t* data);
+
+void Sigmoid(const RuntimeShape& data_shape, float* data);
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+ int16_t* input_data, const RuntimeShape& output_data_shape,
+ int16_t* output_data);
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+ float* input_data, const RuntimeShape& output_data_shape,
+ float* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const int16_t* input1_data, const int16_t* input2_data,
+ int8_t* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const int16_t* input1_data, const int16_t* input2_data,
+ int16_t* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+ const float* input1_data, const float* input2_data,
+ float* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const int8_t* input_data,
+ const RuntimeShape& filter_shape, const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const int32_t* bias_data,
+ const RuntimeShape& output_shape, int16_t* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const int16_t* input_data,
+ const RuntimeShape& filter_shape, const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const int64_t* bias_data,
+ const RuntimeShape& output_shape, int16_t* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& filter_shape, const float* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data);
+#else // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(int16_t* data, int32_t data_size);
+
+void Sigmoid(float* data, int32_t data_size);
+
+void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
+ int16_t* output_data, int32_t data_size);
+
+void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
+ int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+ const int16_t* input2_data, int8_t* output_data, int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+ const int16_t* input2_data, int16_t* output_data, int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const float* input1_data,
+ const float* input2_data, float* output_data, int32_t data_size);
+
+void FullyConnected(const FullyConnectedParams& params,
+ const int8_t* input_data, const int8_t* filter_data,
+ const int32_t* bias_data, int16_t* output_data,
+ const int num_batches, const int output_depth,
+ const int accum_depth);
+
+void FullyConnected(const FullyConnectedParams& params,
+ const int16_t* input_data, const int8_t* filter_data,
+ const int64_t* bias_data, int16_t* output_data,
+ const int num_batches, const int output_depth,
+ const int accum_depth);
+
+void FullyConnected(const FullyConnectedParams& params, const float* input_data,
+ const float* filter_data, const float* bias_data,
+ float* output_data, const int num_batches,
+ const int output_depth, const int accum_depth);
+#endif // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
+ int n_input, int16_t* output);
+
+void AddElementWise(const float* input_1, const float* input_2, int n_batch,
+ int n_input, float* output);
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+ int16_t* vector);
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+ float* vector);
+
+// Manages the slice position (offset), slice length (sliced tensor shape),
+// and update rules for input/output/hidden state/cell state tensors at each
+// time step.
+class LstmStepManager {
+ public:
+ LstmStepManager() = delete;
+ // Does not take any ownership, and all pointers must refer to valid objects
+ // that outlive the one constructed.
+ explicit LstmStepManager(const LstmSizeInfo* size_info)
+ : size_info_(*size_info) {}
+
+ void UpdateTime();
+ void UpdateBatch();
+
+ void ResetTime() { current_time_ = 0; }
+ RuntimeShape InputShape() const;
+ RuntimeShape StateShape() const;
+
+ int InputOffset() const { return input_offset_; }
+ int OutputOffset() const { return output_offset_; }
+ int HiddenStateOffset() const { return hidden_state_offset_; }
+ int CellStateOffset() const { return cell_state_offset_; }
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+ int time_major() const { return size_info_.time_major; }
+
+ int batch_size() const { return size_info_.batch_size; }
+
+ int input_dimension() const { return size_info_.input_dimension; }
+
+ int state_dimension() const { return size_info_.state_dimension; }
+#endif
+
+ private:
+ int current_time_ = 0;
+ int current_batch_ = 0;
+ int input_offset_ = 0;
+ int output_offset_ = 0;
+ int hidden_state_offset_ = 0;
+ int cell_state_offset_ = 0;
+  // size_info_ comes from OpDataLSTM, which resides in the memory arena
+  // (guaranteed to outlast LstmStepManager, which resides on the stack)
+ const LstmSizeInfo& size_info_;
+};
+
+// Calculates a single LSTM gate.
+// Implements the following formula:
+// gate = activate(FC(input) + FC(recurrent))
+// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+ typename BiasType>
+void CalculateLstmGate(
+ const LstmStepManager& step_info, const GateParameters& gate_params,
+ // Input FC
+ const TfLiteEvalTensor* input, const TfLiteEvalTensor* input_weight,
+ const TfLiteEvalTensor* input_bias,
+ // Recurrent FC
+ const TfLiteEvalTensor* recurrent, const TfLiteEvalTensor* recurrent_weight,
+ const TfLiteEvalTensor* recurrent_bias,
+ // Output
+ CellType* gate_output,
+ // Scratch arrays
+ CellType* fc_output_buffer, const TfLiteFusedActivation activation) {
+ const auto gate_output_shape = step_info.StateShape();
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(step_info.InputOffset() + step_info.InputShape().FlatSize(),
+ tflite::micro::GetTensorShape(input).FlatSize());
+ TFLITE_DCHECK_LE(
+ step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(recurrent).FlatSize());
+
+ // Input FC
+ FullyConnected(gate_params.input_fc_params, step_info.InputShape(),
+ tflite::micro::GetTensorData<ActivationType>(input) +
+ step_info.InputOffset(),
+ micro::GetTensorShape(input_weight),
+ tflite::micro::GetTensorData<WeightType>(input_weight),
+ tflite::micro::GetTensorShape(input_bias),
+ tflite::micro::GetOptionalTensorData<BiasType>(input_bias),
+ gate_output_shape, gate_output);
+
+ // Recurrent FC
+ FullyConnected(gate_params.recurrent_fc_params, step_info.StateShape(),
+ tflite::micro::GetTensorData<ActivationType>(recurrent) +
+ step_info.HiddenStateOffset(),
+ tflite::micro::GetTensorShape(recurrent_weight),
+ tflite::micro::GetTensorData<WeightType>(recurrent_weight),
+ tflite::micro::GetTensorShape(recurrent_bias),
+ tflite::micro::GetOptionalTensorData<BiasType>(recurrent_bias),
+ gate_output_shape, fc_output_buffer);
+
+ AddElementWise(gate_output, fc_output_buffer,
+ /*n_batch=*/gate_output_shape.DimsData()[0],
+ /*n_state=*/gate_output_shape.DimsData()[1], gate_output);
+ // Apply activation
+ switch (activation) {
+ case kTfLiteActSigmoid:
+ Sigmoid(gate_output_shape, gate_output);
+ break;
+ case kTfLiteActTanh: {
+ // Set the scale power to -12 to avoid shift
+ Tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output,
+ gate_output_shape, gate_output);
+ } break;
+ default:
+ // Only Sigmoid or Tanh is used.
+ TFLITE_ASSERT_FALSE;
+ }
+}
+
+// Update the cell state using the outputs of the forget gate, input gate, and
+// cell gate. Formula:
+//   updated_cell_state = forget_gate_output * cell_state +
+//                        input_gate_output * cell_gate_output
+// where * denotes element-wise multiplication.
+template <typename CellType>
+void UpdateLstmCell(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ // Gate outputs
+ CellType* forget_gate_output,
+ const CellType* input_gate_output,
+ const CellType* cell_gate_output,
+ // Mul parameters
+ const ArithmeticParams& forget_cell_mul_params,
+ const ArithmeticParams& input_mul_params,
+ const CellStateInfo& cell_state_info, CellType* buffer) {
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(
+ step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+ auto cell_state_shape = step_info.StateShape();
+ // Forget Gate x Cell State
+ Mul(cell_state_shape, forget_cell_mul_params, forget_gate_output,
+ tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset(),
+ tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset());
+ // Input Gate x Cell Gate
+ Mul(cell_state_shape, input_mul_params, input_gate_output, cell_gate_output,
+ buffer);
+
+ // Update the cell state
+ AddElementWise(tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset(),
+ buffer,
+ /*n_batch=*/cell_state_shape.DimsData()[0],
+ /*n_state=*/cell_state_shape.DimsData()[1],
+ tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset());
+
+ if (cell_state_info.cell_clip > 0) {
+ Clipping(cell_state_shape.FlatSize(), cell_state_info,
+ tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset());
+ }
+}
+#else // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+ typename BiasType>
+void CalculateLstmGate(
+ const LstmStepManager& step_info, const GateParameters& gate_params,
+ // Input FC
+ const TfLiteEvalTensor* input, const TfLiteEvalTensor* input_weight,
+ const TfLiteEvalTensor* input_bias,
+ // Recurrent FC
+ const TfLiteEvalTensor* recurrent, const TfLiteEvalTensor* recurrent_weight,
+ const TfLiteEvalTensor* recurrent_bias,
+ // Output
+ CellType* gate_output,
+ // Scratch arrays
+ CellType* fc_output_buffer, const TfLiteFusedActivation activation,
+ const int num_batches, const int input_dimension,
+ const int state_dimension) {
+ // RuntimeShape step_input_shape = step_info.InputShape();
+ // RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+ // RuntimeShape step_state_shape = step_info.StateShape();
+ // RuntimeShape recurrent_shape = tflite::micro::GetTensorShape(recurrent);
+
+ // Moved these to LstmStep function
+ // Check offset validity to avoid memory overflow
+ // TFLITE_DCHECK_LE(step_info.InputOffset() + step_input_shape.FlatSize(),
+ // input_shape.FlatSize());
+ // TFLITE_DCHECK_LE(
+ // step_info.HiddenStateOffset() + step_state_shape.FlatSize(),
+ // recurrent_shape.FlatSize());
+
+ // Input FC
+ FullyConnected(gate_params.input_fc_params,
+ tflite::micro::GetTensorData<ActivationType>(input) +
+ step_info.InputOffset(),
+ tflite::micro::GetTensorData<WeightType>(input_weight),
+ tflite::micro::GetOptionalTensorData<BiasType>(input_bias),
+ gate_output, num_batches, state_dimension, input_dimension);
+
+ // Recurrent FC
+ FullyConnected(gate_params.recurrent_fc_params,
+ tflite::micro::GetTensorData<ActivationType>(recurrent) +
+ step_info.HiddenStateOffset(),
+ tflite::micro::GetTensorData<WeightType>(recurrent_weight),
+ tflite::micro::GetOptionalTensorData<BiasType>(recurrent_bias),
+ fc_output_buffer, num_batches, state_dimension,
+ state_dimension);
+
+ AddElementWise(gate_output, fc_output_buffer,
+ /*n_batch=*/num_batches,
+ /*n_state=*/state_dimension, gate_output);
+ // Apply activation
+ switch (activation) {
+ case kTfLiteActSigmoid:
+ Sigmoid(gate_output, num_batches * state_dimension);
+ break;
+ case kTfLiteActTanh: {
+ // Set the scale power to -12 to avoid shift
+ Tanh(/*cell_state_scale_power=*/-12, gate_output, gate_output,
+ num_batches * state_dimension);
+ } break;
+ default:
+ // Only Sigmoid or Tanh is used.
+ TFLITE_ASSERT_FALSE;
+ }
+}
+
+// Update the cell state using the outputs of the forget gate, input gate, and
+// cell gate. Formula:
+//   updated_cell_state = forget_gate_output * cell_state +
+//                        input_gate_output * cell_gate_output
+// where * denotes element-wise multiplication.
+void UpdateLstmCell(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ // Gate outputs
+ int16_t* forget_gate_output,
+ const int16_t* input_gate_output,
+ const int16_t* cell_gate_output,
+ // Mul parameters
+ const ArithmeticParams& forget_cell_mul_params,
+ const ArithmeticParams& input_mul_params,
+ const CellStateInfo& cell_state_info, int16_t* buffer);
+
+void UpdateLstmCell(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ // Gate outputs
+ float* forget_gate_output, const float* input_gate_output,
+ const float* cell_gate_output,
+ // Mul parameters
+ const ArithmeticParams& forget_cell_mul_params,
+ const ArithmeticParams& input_mul_params,
+ const CellStateInfo& cell_state_info, float* buffer);
+#endif // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+// Update the hidden state of the LSTM kernel using the following formula:
+//   updated_hidden_state = Tanh(updated_cell_state) * output_gate_output
+// where * denotes element-wise multiplication.
+template <typename CellType, typename ActivationType>
+void UpdateLstmHidden(const LstmStepManager& step_info,
+ TfLiteEvalTensor* cell_state,
+ TfLiteEvalTensor* hidden_state,
+ const CellType* output_gate_output,
+ const ArithmeticParams& mul_params,
+ int32_t cell_state_scale_power, CellType* buffer) {
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(
+ step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(cell_state).FlatSize());
+ TFLITE_DCHECK_LE(
+ step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(hidden_state).FlatSize());
+
+ auto cell_state_shape = step_info.StateShape();
+ CellType* cell_state_data =
+ tflite::micro::GetTensorData<CellType>(cell_state) +
+ step_info.CellStateOffset();
+ // Tanh(cell_state)
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+ Tanh(cell_state_scale_power, cell_state_shape, cell_state_data,
+ cell_state_shape, buffer);
+ // Update the hidden state
+ Mul(cell_state_shape, mul_params, buffer, output_gate_output,
+ tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+ step_info.HiddenStateOffset());
+#else
+ int32_t cell_state_size = cell_state_shape.FlatSize();
+ Tanh(cell_state_scale_power, cell_state_data, buffer, cell_state_size);
+ // Update the hidden state
+ Mul(mul_params, buffer, output_gate_output,
+ tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+ step_info.HiddenStateOffset(),
+ cell_state_size);
+#endif
+}
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+ typename BiasType>
+void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
+ LSTMKernelContents& kernel_content,
+ const LSTMBuffers<CellType>& buffers) {
+ /*Step1: Calculate gate outputs to prepare cell state update*/
+ CellType* gate_internal_buffer = buffers.buffer3;
+ CellType* forget_gate_output = buffers.buffer0;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.forget_gate_parameters,
+ // Input FC
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor),
+ // Recurrent FC
+ kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToForgetWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ forget_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid);
+
+ // Input Gate calculation;
+ CellType* input_gate_output = buffers.buffer1;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.input_gate_parameters,
+ // Input FC
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor),
+ // Recurrent FC
+ kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToInputWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ input_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid);
+
+ // Cell Gate calculation
+ CellType* cell_gate_output = buffers.buffer2;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.cell_gate_parameters,
+ // Input FC
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor),
+ // Recurrent FC
+ kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToCellWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ cell_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, op_data.cell_gate_nonlinear_type);
+
+ /*Step2: update the cell state */
+ const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
+ CellType* updated_input_buffer = buffers.buffer1; // reuse buffer
+
+ UpdateLstmCell<CellType>(step_info, kernel_content.CellStateTensor(),
+ forget_gate_output, input_gate_output,
+ cell_gate_output,
+ inter_gate_params.forget_cell_mul_params,
+ inter_gate_params.input_mul_params,
+ op_data.cell_state_info, updated_input_buffer);
+
+ /*Step3: update the hidden state */
+ CellType* output_gate_output = buffers.buffer1; // reuse buffer
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.output_gate_parameters,
+ // Input FC
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor),
+ // Recurrent FC
+ kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToOutputWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ output_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid);
+
+ CellType* tanh_activated_cell_buffer = buffers.buffer0; // reuse buffer
+ tflite::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
+ step_info, kernel_content.CellStateTensor(),
+ kernel_content.HiddenStateTensor(), output_gate_output,
+ inter_gate_params.output_mul_params,
+ op_data.cell_state_info.cell_state_scale_power,
+ tanh_activated_cell_buffer);
+
+  /*Step4: copy the updated hidden state to the output*/
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(
+ step_info.OutputOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(kernel_content.output_tensor).FlatSize());
+ // record the output (from the updated hidden state)
+ ActivationType* output_ptr = tflite::micro::GetTensorData<ActivationType>(
+ kernel_content.output_tensor);
+ const auto* hidden_state = kernel_content.HiddenStateTensor();
+ std::memcpy(output_ptr + step_info.OutputOffset(),
+ tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+ step_info.HiddenStateOffset(),
+ step_info.StateShape().FlatSize() * sizeof(ActivationType));
+}
+#else // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+ typename BiasType>
+void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
+ LSTMKernelContents& kernel_content,
+ const LSTMBuffers<CellType>& buffers) {
+ const TfLiteEvalTensor* input =
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
+ TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
+
+ int time_major = step_info.time_major();
+ int num_batches = time_major == 0 ? 1 : step_info.batch_size();
+ int input_dimension = step_info.input_dimension();
+ int state_dimension = step_info.state_dimension();
+
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
+ tflite::micro::GetTensorShape(input).FlatSize());
+ TFLITE_DCHECK_LE(
+ step_info.HiddenStateOffset() + num_batches * state_dimension,
+ tflite::micro::GetTensorShape(recurrent).FlatSize());
+
+ /*Step1: Calculate gate outputs to prepare cell state update*/
+ CellType* gate_internal_buffer = buffers.buffer3;
+ CellType* forget_gate_output = buffers.buffer0;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.forget_gate_parameters,
+ // Input FC
+ input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor),
+ // Recurrent FC
+ recurrent, // kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToForgetWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ forget_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+ state_dimension);
+
+ // Input Gate calculation;
+ CellType* input_gate_output = buffers.buffer1;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.input_gate_parameters,
+ // Input FC
+ input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor),
+ // Recurrent FC
+ recurrent, // kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToInputWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ input_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+ state_dimension);
+
+ // Cell Gate calculation
+ CellType* cell_gate_output = buffers.buffer2;
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.cell_gate_parameters,
+ // Input FC
+ input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor),
+ // Recurrent FC
+ recurrent, // kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToCellWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ cell_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, op_data.cell_gate_nonlinear_type, num_batches,
+ input_dimension, state_dimension);
+
+ /*Step2: update the cell state */
+ const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
+ CellType* updated_input_buffer = buffers.buffer1; // reuse buffer
+
+ UpdateLstmCell(step_info, kernel_content.CellStateTensor(),
+ forget_gate_output, input_gate_output, cell_gate_output,
+ inter_gate_params.forget_cell_mul_params,
+ inter_gate_params.input_mul_params, op_data.cell_state_info,
+ updated_input_buffer);
+
+ /*Step3: update the hidden state */
+ CellType* output_gate_output = buffers.buffer1; // reuse buffer
+ CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data.output_gate_parameters,
+ // Input FC
+ input, // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor),
+ kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor),
+ // Recurrent FC
+ recurrent, // kernel_content.HiddenStateTensor(),
+ kernel_content.GetInternalTensor(
+ tflite::kLstmRecurrentToOutputWeightsTensor),
+ /*recurrent_bias*/ nullptr,
+ // Output
+ output_gate_output,
+ // Scratch arrays
+ gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+ state_dimension);
+
+ CellType* tanh_activated_cell_buffer = buffers.buffer0; // reuse buffer
+ tflite::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
+ step_info, kernel_content.CellStateTensor(), recurrent,
+ /* kernel_content.HiddenStateTensor(), */ output_gate_output,
+ inter_gate_params.output_mul_params,
+ op_data.cell_state_info.cell_state_scale_power,
+ tanh_activated_cell_buffer);
+
+  /*Step4: copy the updated hidden state to the output*/
+ // Check offset validity to avoid memory overflow
+ TFLITE_DCHECK_LE(
+ step_info.OutputOffset() + step_info.StateShape().FlatSize(),
+ tflite::micro::GetTensorShape(kernel_content.output_tensor).FlatSize());
+ // record the output (from the updated hidden state)
+ ActivationType* output_ptr = tflite::micro::GetTensorData<ActivationType>(
+ kernel_content.output_tensor);
+ // const auto* hidden_state = kernel_content.HiddenStateTensor();
+ std::memcpy(output_ptr + step_info.OutputOffset(),
+ tflite::micro::GetTensorData<ActivationType>(recurrent) +
+ step_info.HiddenStateOffset(),
+ step_info.StateShape().FlatSize() * sizeof(ActivationType));
+}
+#endif // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+} // namespace lstm_internal
+
+// Evaluate the LSTM kernel with (potentially) multi-step and multi-batch
+// input.
+template <typename ActivationType, typename WeightType, typename CellType,
+ typename BiasType>
+TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
+ LSTMKernelContents& kernel_content,
+ const LSTMBuffers<CellType>& buffers) {
+ lstm_internal::LstmStepManager step_info(&op_data.size_info);
+ const auto& size_info = op_data.size_info;
+  // time is the first dimension, so batch computation is enabled
+ if (size_info.time_major) {
+ for (int t = 0; t < size_info.time_steps; t++) {
+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data, kernel_content, buffers);
+ // prepare for the next time step
+ step_info.UpdateTime();
+ }
+ } else {
+    // batch first: process one batch at a time (single-batch inference)
+ for (int b = 0; b < size_info.batch_size; b++) {
+ for (int t = 0; t < size_info.time_steps; t++) {
+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+ step_info, op_data, kernel_content, buffers);
+ // prepare for the next time step
+ step_info.UpdateTime();
+ }
+ // prepare for the next batch
+ step_info.UpdateBatch();
+ step_info.ResetTime();
+ }
+ }
+ return kTfLiteOk;
+}
} // namespace tflite
-#endif // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
+
+#endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_16ACT_H_
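
For context on the `EvalLstm` loop above: time-major inputs let a single time loop cover all batches at once, while batch-major inputs fall back to per-batch inference with the step manager reset between batches. A minimal, self-contained sketch of that ordering (simplified names, not the real kernel API):

```cpp
#include <cstdio>

// Sketch of the iteration order EvalLstm uses (hypothetical, simplified):
// time-major runs one LstmStep per time step covering all batches; batch-major
// runs a full time loop per batch and resets the time index between batches.
int main() {
  const int time_steps = 3;
  const int batch_size = 2;
  const bool time_major = false;

  if (time_major) {
    for (int t = 0; t < time_steps; ++t) {
      std::printf("LstmStep: all batches, t=%d\n", t);  // then UpdateTime()
    }
  } else {
    for (int b = 0; b < batch_size; ++b) {
      for (int t = 0; t < time_steps; ++t) {
        std::printf("LstmStep: batch=%d, t=%d\n", b, t);  // then UpdateTime()
      }
      // UpdateBatch() and ResetTime() would run here in the real kernel.
    }
  }
  return 0;
}
```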
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
index 2b49f26..a2d04ee 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
@@ -12,17 +12,65 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+
+#include <xtensa/tie/xt_hifi2.h>
+
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
#if defined(HIFI5)
+#if TFLITE_SINGLE_ROUNDING
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, \
+ right_shift) \
+ { \
+ ae_int64 out64_0, out64_1; \
+ ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1)); \
+ ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \
+ AE_MUL32X2S_HH_LL(out64_0, out64_1, inp, AE_MOVDA32(multiplier)); \
+ out64_0 = AE_ADD64S(out64_0, round_val); \
+ out64_1 = AE_ADD64S(out64_1, round_val); \
+ out = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \
+ }
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+ left_shift, right_shift) \
+ { \
+ ae_int64 out64_0, out64_1, out64_2, out64_3; \
+ ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1)); \
+ ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift); \
+ AE_MUL32X2S_HH_LL(out64_0, out64_1, inp1, AE_MOVDA32(multiplier)); \
+ AE_MUL32X2S_HH_LL(out64_2, out64_3, inp2, AE_MOVDA32(multiplier)); \
+ out64_0 = AE_ADD64S(out64_0, round_val); \
+ out64_1 = AE_ADD64S(out64_1, round_val); \
+ out64_2 = AE_ADD64S(out64_2, round_val); \
+ out64_3 = AE_ADD64S(out64_3, round_val); \
+ out1 = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift); \
+ out2 = AE_TRUNCA32X2F64S(out64_2, out64_3, 1 + left_shift); \
+ }
+#else /* #if TFLITE_SINGLE_ROUNDING */
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, \
+ right_shift) \
+ out = AE_SLAA32(inp, left_shift); \
+ out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier)); \
+ out = AE_SRAA32SYMS(out, right_shift);
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+ left_shift, right_shift) \
+ { \
+ ae_int32x2 d_ls = AE_MOVDA32(1 << left_shift); \
+ AE_MUL2P32X4(out1, out2, inp1, inp2, d_ls, d_ls); \
+ AE_MULF2P32X4RAS(out1, out2, out1, out2, AE_MOVDA32(multiplier), \
+ AE_MOVDA32(multiplier)); \
+ out1 = AE_SRAA32SYMS(out1, right_shift); \
+ out2 = AE_SRAA32SYMS(out2, right_shift); \
+ }
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
void calc_cell_state_without_cifg(int16_t* cell_state,
const int16_t* forget_gate,
const int16_t* cell_gate,
@@ -124,7 +172,7 @@
AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
- d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -187,11 +235,11 @@
AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
- d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
- d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -298,7 +346,7 @@
d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
- d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -360,12 +408,12 @@
AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
- d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
- d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -404,8 +452,13 @@
d_multiplier = AE_MOVDA32(multiplier);
d_zp = AE_MOVDA16(zero_point);
+#if TFLITE_SINGLE_ROUNDING
+ left_shift = shift;
+ (void)right_shift;
+#else /* #if TFLITE_SINGLE_ROUNDING */
left_shift = shift < 0 ? 0 : shift;
right_shift = shift > 0 ? 0 : -shift;
+#endif /* #if TFLITE_SINGLE_ROUNDING */
d_left_shift = AE_MOVDA32(1 << left_shift);
#pragma concurrent
@@ -415,18 +468,10 @@
AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
AE_MUL16X4(data_ab_2, data_ab_3, data_a_1, data_b_1);
- AE_MUL2P32X4(data_ab_0, data_ab_1, data_ab_0, data_ab_1, d_left_shift,
- d_left_shift);
- AE_MUL2P32X4(data_ab_2, data_ab_3, data_ab_2, data_ab_3, d_left_shift,
- d_left_shift);
- AE_MULF2P32X4RAS(data_ab_0, data_ab_1, data_ab_0, data_ab_1, d_multiplier,
- d_multiplier);
- AE_MULF2P32X4RAS(data_ab_2, data_ab_3, data_ab_2, data_ab_3, d_multiplier,
- d_multiplier);
- data_ab_0 = AE_SRAA32SYMS(data_ab_0, right_shift);
- data_ab_1 = AE_SRAA32SYMS(data_ab_1, right_shift);
- data_ab_2 = AE_SRAA32SYMS(data_ab_2, right_shift);
- data_ab_3 = AE_SRAA32SYMS(data_ab_3, right_shift);
+ MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1,
+ multiplier, left_shift, right_shift);
+ MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_2, data_ab_3, data_ab_2, data_ab_3,
+ multiplier, left_shift, right_shift);
data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
data_c_1 = AE_SAT16X4(data_ab_2, data_ab_3);
data_c_0 = AE_SUB16S(data_c_0, d_zp);
@@ -445,18 +490,532 @@
AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2);
AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
- data_ab_0 = AE_MULP32X2(data_ab_0, d_left_shift);
- data_ab_0 = AE_MULFP32X2RAS(data_ab_0, d_multiplier);
- data_ab_0 = AE_SRAA32SYMS(data_ab_0, right_shift);
- data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
+ MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift,
+ right_shift);
+ data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0);
data_c_0 = AE_SUB16S(data_c_0, d_zp);
data_c = AE_SAT8X8X16(data_c_0, data_c_0);
AE_S8_0_IP(data_c, (ae_int8*)output, 1);
}
}
+#elif defined(HIFI3) || defined(HIFI4)
+#if TFLITE_SINGLE_ROUNDING
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \
+ { \
+ ae_int64 out64_0, out64_1; \
+ out64_0 = AE_MUL32_HH(inp, AE_MOVDA32(multiplier)); \
+ out64_1 = AE_MUL32_LL(inp, AE_MOVDA32(multiplier)); \
+ out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \
+ out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \
+ out = AE_ROUND32X2F64SASYM(out64_0, out64_1); \
+ }
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+ l_shift, r_shift) \
+ { \
+ ae_int64 out64_0, out64_1, out64_2, out64_3; \
+ out64_0 = AE_MUL32_HH(inp1, AE_MOVDA32(multiplier)); \
+ out64_1 = AE_MUL32_LL(inp1, AE_MOVDA32(multiplier)); \
+ out64_2 = AE_MUL32_HH(inp2, AE_MOVDA32(multiplier)); \
+ out64_3 = AE_MUL32_LL(inp2, AE_MOVDA32(multiplier)); \
+ out64_0 = AE_SLAA64S(out64_0, 1 + l_shift); \
+ out64_1 = AE_SLAA64S(out64_1, 1 + l_shift); \
+ out64_2 = AE_SLAA64S(out64_2, 1 + l_shift); \
+ out64_3 = AE_SLAA64S(out64_3, 1 + l_shift); \
+ out1 = AE_ROUND32X2F64SASYM(out64_0, out64_1); \
+ out2 = AE_ROUND32X2F64SASYM(out64_2, out64_3); \
+ }
+#else /* #if TFLITE_SINGLE_ROUNDING */
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \
+ out = AE_SLAA32(inp, l_shift); \
+ out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier)); \
+ out = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out), r_shift), \
+ AE_SRAA64(AE_CVT64F32_L(out), r_shift));
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+ l_shift, r_shift) \
+ { \
+ ae_int32x2 d_ls = AE_MOVDA32(1 << l_shift); \
+ out1 = AE_MULP32X2(inp1, d_ls); \
+ out2 = AE_MULP32X2(inp2, d_ls); \
+ out1 = AE_MULFP32X2RAS(out1, AE_MOVDA32(multiplier)); \
+ out2 = AE_MULFP32X2RAS(out2, AE_MOVDA32(multiplier)); \
+ out1 = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out1), r_shift), \
+ AE_SRAA64(AE_CVT64F32_L(out1), r_shift)); \
+ out2 = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out2), r_shift), \
+ AE_SRAA64(AE_CVT64F32_L(out2), r_shift)); \
+ }
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
+#ifndef AE_MULFP16X4RS
+static inline ae_f16x4 AE_MULFP16X4RS(ae_f16x4 d0, ae_f16x4 d1) {
+ ae_f16x4 output;
+ ae_f32x2 d0_32_0, d0_32_1, out32_0, out32_1;
+ ae_f16x4 one_d = AE_MOVDA16(1);
+ AE_MUL16X4(d0_32_0, d0_32_1, d0, one_d);
+ out32_0 = AE_MULFP32X16X2RS_H(d0_32_0, d1);
+ out32_1 = AE_MULFP32X16X2RS_L(d0_32_1, d1);
+ output = AE_SEL16_6420(AE_MOVF16X4_FROMF32X2(out32_0),
+ AE_MOVF16X4_FROMF32X2(out32_1));
+ return output;
+}
+#endif
+
+#ifndef AE_MINMAX16
+#define AE_MINMAX16(dinout, d_min, d_max) \
+ { \
+ xtbool4 b0 = AE_LT16(dinout, d_min); \
+ AE_MOVT16X4(dinout, d_min, b0); \
+ b0 = AE_LT16(d_max, dinout); \
+ AE_MOVT16X4(dinout, d_max, b0); \
+ }
+#endif
+
+#ifndef AE_SRAA32SYMS
+#define AE_SRAA32SYMS(inp, right_shift) \
+ AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(inp), right_shift), \
+ AE_SRAA64(AE_CVT64F32_L(inp), right_shift))
+#endif
+
+void calc_cell_state_without_cifg(int16_t* cell_state,
+ const int16_t* forget_gate,
+ const int16_t* cell_gate,
+ const int16_t* input_gate, int shift1,
+ int shift2, int clip, int num_elms) {
+ const ae_int16x4 *p16x4_cs_r, *p16x4_fg_r;
+ const ae_int16x4 *p16x4_cg_r, *p16x4_ig_r;
+
+ ae_int16x4* p16x4_cs_w;
+
+ ae_valign align_cs_r, align_fg_r;
+ ae_valign align_cg_r, align_ig_r;
+ ae_valign align_cs_w;
+
+ ae_int16x4 d_cs_r_0, d_cs_r_1;
+ ae_int16x4 d_fg_0, d_fg_1;
+ ae_int16x4 d_cg_0, d_cg_1;
+ ae_int16x4 d_ig_0, d_ig_1;
+ ae_int16x4 d_cs_w_0, d_cs_w_1;
+ ae_int32x2 d_mul_0, d_mul_1, d_mul_2, d_mul_3;
+ ae_int32x2 d_mul_4, d_mul_5, d_mul_6, d_mul_7;
+
+ ae_int16x4 d_min, d_max;
+
+ int i = 0;
+ p16x4_cs_r = (const ae_int16x4*)cell_state;
+ p16x4_fg_r = (const ae_int16x4*)forget_gate;
+ p16x4_cg_r = (const ae_int16x4*)cell_gate;
+ p16x4_ig_r = (const ae_int16x4*)input_gate;
+
+ p16x4_cs_w = (ae_int16x4*)cell_state;
+
+ align_cs_r = AE_LA64_PP(p16x4_cs_r);
+ align_fg_r = AE_LA64_PP(p16x4_fg_r);
+ align_cg_r = AE_LA64_PP(p16x4_cg_r);
+ align_ig_r = AE_LA64_PP(p16x4_ig_r);
+
+ align_cs_w = AE_ZALIGN64();
+
+ if (clip > 0) {
+ d_min = AE_MOVDA16(-clip);
+ d_max = AE_MOVDA16(clip);
+ } else {
+ d_min = AE_MOVDA16(-32768);
+ d_max = AE_MOVDA16(32767);
+ }
+
+#pragma concurrent
+ if (shift1 == 15) {
+ for (i = 0; i < (num_elms >> 3); i++) {
+ AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_ig_0, align_ig_r, p16x4_ig_r);
+ AE_LA16X4_IP(d_ig_1, align_ig_r, p16x4_ig_r);
+
+ d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+ d_cs_w_1 = AE_MULFP16X4RS(d_cs_r_1, d_fg_1);
+
+ AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_ig_0);
+ AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_ig_1);
+ d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+ d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+ d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+ d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+
+ d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+ d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+ AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+ AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+ }
+ AE_SA64POS_FP(align_cs_w, p16x4_cs_w); // finalize the stream
+
+ const ae_int16 *p16_cs_r, *p16_fg_r;
+ const ae_int16 *p16_cg_r, *p16_ig_r;
+
+ ae_int16* p16_cs_w;
+
+ p16_cs_r = (const ae_int16*)p16x4_cs_r;
+ p16_fg_r = (const ae_int16*)p16x4_fg_r;
+ p16_cg_r = (const ae_int16*)p16x4_cg_r;
+ p16_ig_r = (const ae_int16*)p16x4_ig_r;
+
+ p16_cs_w = (ae_int16*)p16x4_cs_w;
+ // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+ for (i = 0; i < ((num_elms)&7); i++) {
+ d_cs_r_0 = p16_cs_r[i];
+ d_fg_0 = p16_fg_r[i];
+ d_cg_0 = p16_cg_r[i];
+ d_ig_0 = p16_ig_r[i];
+
+ d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ p16_cs_w[i] = d_cs_w_0;
+ }
+ } else {
+ for (i = 0; i < (num_elms >> 3); i++) {
+ AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_ig_0, align_ig_r, p16x4_ig_r);
+ AE_LA16X4_IP(d_ig_1, align_ig_r, p16x4_ig_r);
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+ AE_MUL16X4(d_mul_2, d_mul_3, d_cs_r_1, d_fg_1);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+ d_mul_1 = AE_SRAA32SYMS(d_mul_1, shift1);
+ d_mul_2 = AE_SRAA32SYMS(d_mul_2, shift1);
+ d_mul_3 = AE_SRAA32SYMS(d_mul_3, shift1);
+
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cs_w_1 = AE_SAT16X4(d_mul_2, d_mul_3);
+
+ AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_ig_0);
+ AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_ig_1);
+ d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+ d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+ d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+ d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+
+ d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+ d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+ AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+ AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+ }
+ AE_SA64POS_FP(align_cs_w, p16x4_cs_w); // finalize the stream
+
+ const ae_int16 *p16_cs_r, *p16_fg_r;
+ const ae_int16 *p16_cg_r, *p16_ig_r;
+
+ ae_int16* p16_cs_w;
+
+ p16_cs_r = (const ae_int16*)p16x4_cs_r;
+ p16_fg_r = (const ae_int16*)p16x4_fg_r;
+ p16_cg_r = (const ae_int16*)p16x4_cg_r;
+ p16_ig_r = (const ae_int16*)p16x4_ig_r;
+
+ p16_cs_w = (ae_int16*)p16x4_cs_w;
+ // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+ for (i = 0; i < ((num_elms)&7); i++) {
+ d_cs_r_0 = p16_cs_r[i];
+ d_fg_0 = p16_fg_r[i];
+ d_cg_0 = p16_cg_r[i];
+ d_ig_0 = p16_ig_r[i];
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ p16_cs_w[i] = d_cs_w_0;
+ }
+ }
+}
+
+void calc_cell_state_with_cifg(int16_t* cell_state, const int16_t* forget_gate,
+ const int16_t* cell_gate, int shift1, int shift2,
+ int clip, int num_elms) {
+ const ae_int16x4 *p16x4_cs_r, *p16x4_fg_r;
+ const ae_int16x4* p16x4_cg_r;
+
+ ae_int16x4* p16x4_cs_w;
+
+ ae_valign align_cs_r, align_fg_r;
+ ae_valign align_cg_r;
+ ae_valign align_cs_w;
+
+ ae_int16x4 d_cs_r_0, d_cs_r_1;
+ ae_int16x4 d_fg_0, d_fg_1;
+ ae_int16x4 d_cg_0, d_cg_1;
+ ae_int16x4 d_1mfg_0, d_1mfg_1;
+ ae_int16x4 d_cs_w_0, d_cs_w_1;
+ ae_int32x2 d_mul_0, d_mul_1, d_mul_2, d_mul_3;
+ ae_int32x2 d_mul_4, d_mul_5, d_mul_6, d_mul_7;
+
+ ae_int16x4 d_min, d_max, d_one;
+
+ int i = 0;
+ p16x4_cs_r = (const ae_int16x4*)cell_state;
+ p16x4_fg_r = (const ae_int16x4*)forget_gate;
+ p16x4_cg_r = (const ae_int16x4*)cell_gate;
+
+ p16x4_cs_w = (ae_int16x4*)cell_state;
+
+ align_cs_r = AE_LA64_PP(p16x4_cs_r);
+ align_fg_r = AE_LA64_PP(p16x4_fg_r);
+ align_cg_r = AE_LA64_PP(p16x4_cg_r);
+
+ align_cs_w = AE_ZALIGN64();
+
+ if (clip > 0) {
+ d_min = AE_MOVDA16(-clip);
+ d_max = AE_MOVDA16(clip);
+ } else {
+ d_min = AE_MOVDA16(-32768);
+ d_max = AE_MOVDA16(32767);
+ }
+ d_one = AE_MOVDA16(32767);
+
+#pragma concurrent
+ if (shift1 == 15) {
+ for (i = 0; i < (num_elms >> 3); i++) {
+ AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+
+ d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+ d_cs_w_1 = AE_MULFP16X4RS(d_cs_r_1, d_fg_1);
+
+ d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+ d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
+ AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
+ AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
+ d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+ d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+ d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+ d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+ d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+ AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+ AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+ }
+ AE_SA64POS_FP(align_cs_w, p16x4_cs_w); // finalize the stream
+
+ const ae_int16 *p16_cs_r, *p16_fg_r;
+ const ae_int16* p16_cg_r;
+
+ ae_int16* p16_cs_w;
+
+ p16_cs_r = (const ae_int16*)p16x4_cs_r;
+ p16_fg_r = (const ae_int16*)p16x4_fg_r;
+ p16_cg_r = (const ae_int16*)p16x4_cg_r;
+
+ p16_cs_w = (ae_int16*)p16x4_cs_w;
+ // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+ for (i = 0; i < ((num_elms)&7); i++) {
+ d_cs_r_0 = p16_cs_r[i];
+ d_fg_0 = p16_fg_r[i];
+ d_cg_0 = p16_cg_r[i];
+
+ d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+
+ d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ p16_cs_w[i] = d_cs_w_0;
+ }
+ } else {
+ for (i = 0; i < (num_elms >> 3); i++) {
+ AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+ AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+ AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+ AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+ AE_MUL16X4(d_mul_2, d_mul_3, d_cs_r_1, d_fg_1);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+ d_mul_1 = AE_SRAA32SYMS(d_mul_1, shift1);
+ d_mul_2 = AE_SRAA32SYMS(d_mul_2, shift1);
+ d_mul_3 = AE_SRAA32SYMS(d_mul_3, shift1);
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+ d_cs_w_1 = AE_SAT16X4(d_mul_2, d_mul_3);
+
+ d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+ d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
+ AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
+ AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
+ d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+ d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+ d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+ d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+ d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+ AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+ AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+ }
+ AE_SA64POS_FP(align_cs_w, p16x4_cs_w); // finalize the stream
+
+ const ae_int16 *p16_cs_r, *p16_fg_r;
+ const ae_int16* p16_cg_r;
+
+ ae_int16* p16_cs_w;
+
+ p16_cs_r = (const ae_int16*)p16x4_cs_r;
+ p16_fg_r = (const ae_int16*)p16x4_fg_r;
+ p16_cg_r = (const ae_int16*)p16x4_cg_r;
+
+ p16_cs_w = (ae_int16*)p16x4_cs_w;
+ // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+ for (i = 0; i < ((num_elms)&7); i++) {
+ d_cs_r_0 = p16_cs_r[i];
+ d_fg_0 = p16_fg_r[i];
+ d_cg_0 = p16_cg_r[i];
+
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+ d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+ AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
+ d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+ d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+ d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+ AE_MINMAX16(d_cs_w_0, d_min, d_max);
+ p16_cs_w[i] = d_cs_w_0;
+ }
+ }
+}
+
+void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1,
+ const int16_t* input_2, int32_t multiplier,
+ int32_t shift, int32_t zero_point,
+ int num_elms) {
+ ae_int16x4* tmp_input_1;
+ ae_int16x4* tmp_input_2;
+
+ ae_valign align_src_input_1, align_src_input_2;
+
+ ae_int16x4 data_a_0, data_b_0;
+ ae_int32x2 data_ab_0, data_ab_1;
+ ae_int16x4 d_zp;
+ ae_int16x4 data_c_0;
+ ae_int16x4 d_min8 = AE_MOVDA16(-128);
+ ae_int16x4 d_max8 = AE_MOVDA16(127);
+
+ int i = 0;
+ int left_shift, right_shift;
+ tmp_input_1 = (ae_int16x4*)(input_1);
+ tmp_input_2 = (ae_int16x4*)(input_2);
+
+ align_src_input_1 = AE_LA64_PP((ae_int16x4*)tmp_input_1);
+ align_src_input_2 = AE_LA64_PP((ae_int16x4*)tmp_input_2);
+
+ d_zp = AE_MOVDA16(zero_point);
+
+#if TFLITE_SINGLE_ROUNDING
+ left_shift = shift;
+ (void)right_shift;
+#else /* #if TFLITE_SINGLE_ROUNDING */
+ left_shift = shift < 0 ? 0 : shift;
+ right_shift = shift > 0 ? 0 : -shift;
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
+#pragma concurrent
+ for (i = 0; i < (num_elms >> 2); i++) {
+ AE_LA16X4_IP(data_a_0, align_src_input_1, tmp_input_1);
+ AE_LA16X4_IP(data_b_0, align_src_input_2, tmp_input_2);
+
+ AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
+ MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1,
+ multiplier, left_shift, right_shift);
+ data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
+ data_c_0 = AE_SUB16S(data_c_0, d_zp);
+ AE_MINMAX16(data_c_0, d_min8, d_max8);
+
+ *output++ = AE_MOVAD16_3(data_c_0);
+ *output++ = AE_MOVAD16_2(data_c_0);
+ *output++ = AE_MOVAD16_1(data_c_0);
+ *output++ = AE_MOVAD16_0(data_c_0);
+ }
+
+ // residue iterations
+#pragma concurrent
+#pragma loop_count max = 3
+ for (int j = 0; j < ((num_elms)&3); j++) {
+ AE_L16_IP(data_a_0, (ae_int16*)tmp_input_1, 2);
+ AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2);
+
+ AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
+ MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift,
+ right_shift);
+ data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0);
+ data_c_0 = AE_SUB16S(data_c_0, d_zp);
+ AE_MINMAX16(data_c_0, d_min8, d_max8);
+
+ *output++ = AE_MOVAD16_0(data_c_0);
+ }
+}
#endif // defined(HIFI5)
-} // namespace lstm_eval
-} // namespace micro
-} // namespace ops
} // namespace tflite
+
+#endif // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
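
The HiFi changes above replace the open-coded shift/multiply/round sequences with the `MPY_BY_QUANT_MULT_X2_OUT32` / `MPY_BY_QUANT_MULT_X2X2_OUT32` macros so the same call sites work under both rounding modes. As a reference point, here is a scalar sketch of what the single-rounding variant computes per lane; it is a hypothetical helper mirroring TFLite's `MultiplyByQuantizedMultiplier`, not the HiFi intrinsics:

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar sketch (assumed semantics) of the single-rounding quantized multiply:
// out = saturating_round(inp * multiplier / 2^(31 - shift)), computed with a
// 64-bit intermediate and a rounding constant added before the shift, which is
// what the HiFi5 macro does with AE_MUL32X2S / AE_ADD64S / AE_TRUNCA32X2F64S.
int32_t MultiplyByQuantizedMultiplierScalar(int32_t inp, int32_t multiplier,
                                            int shift) {
  const int64_t product = static_cast<int64_t>(inp) * multiplier;
  const int total_right_shift = 31 - shift;  // shift <= 30 assumed
  const int64_t rounding = int64_t{1} << (total_right_shift - 1);
  int64_t result = (product + rounding) >> total_right_shift;
  // Saturate to the int32 range, mirroring the saturating truncation on HiFi.
  result = std::min<int64_t>(
      std::max<int64_t>(result, std::numeric_limits<int32_t>::min()),
      std::numeric_limits<int32_t>::max());
  return static_cast<int32_t>(result);
}
```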
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h b/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h
deleted file mode 100644
index 4bcff1a..0000000
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
-#define TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm {
-// For full inputs kernel (24-inputs).
-// Please note the 20-input full kernel is deprecated and only kept
-// here for backward compatibility.
-namespace full {
-
-// Input Tensors of size {n_batch, n_input}
-constexpr int kInputTensor = 0;
-
-// Input weight tensors of size: {n_cell, n_input}
-constexpr int kInputToInputWeightsTensor = 1; // Optional
-constexpr int kInputToForgetWeightsTensor = 2;
-constexpr int kInputToCellWeightsTensor = 3;
-constexpr int kInputToOutputWeightsTensor = 4;
-
-// Recurrent weight tensors of size {n_cell, n_output}
-constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
-constexpr int kRecurrentToForgetWeightsTensor = 6;
-constexpr int kRecurrentToCellWeightsTensor = 7;
-constexpr int kRecurrentToOutputWeightsTensor = 8;
-
-// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kCellToInputWeightsTensor = 9; // Optional
-constexpr int kCellToForgetWeightsTensor = 10; // Optional
-constexpr int kCellToOutputWeightsTensor = 11; // Optional
-
-// Gates bias tensors of size {n_cell}
-constexpr int kInputGateBiasTensor = 12; // Optional
-constexpr int kForgetGateBiasTensor = 13;
-constexpr int kCellGateBiasTensor = 14;
-constexpr int kOutputGateBiasTensor = 15;
-
-// Projection weight tensor of size {n_output, n_cell}
-constexpr int kProjectionWeightsTensor = 16; // Optional
-// Projection bias tensor of size {n_output}
-constexpr int kProjectionBiasTensor = 17; // Optional
-
-// These state tensors are defined as variable tensors, and will be modified by
-// this op.
-constexpr int kOutputStateTensor = 18;
-constexpr int kCellStateTensor = 19;
-
-// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
-// matrix.
-constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
-constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
-constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
-constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
-
-// Output tensors.
-constexpr int kOutputTensor = 0;
-} // namespace full
-
-} // namespace lstm
-} // namespace micro
-} // namespace ops
-} // namespace tflite
-#endif // TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
index 3a71f94..da34e09 100644
--- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -63,7 +63,7 @@
#if defined(HIFI5)
memcpy(state_ptr, state_ptr + 1, num_bytes);
#else
- xa_nn_memmove_16(state_ptr, state_ptr + 1, num_bytes);
+ xa_nn_memmove_16(state_ptr, state_ptr + 1, (num_bytes >> 1));
#endif // defined(HIFI5)
// Note: no need to clear the latest activation, matmul is not accumulative.
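
The svdf.cc change halves the third argument, which suggests `xa_nn_memmove_16` takes a count of int16 elements rather than bytes (the HIFI5 path's `memcpy` keeps the byte count). A small sketch of that distinction using the standard byte-oriented `memmove`; the helper name is illustrative only:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Shift an int16 state buffer left by one element. std::memmove takes a byte
// count; an element-wise 16-bit move routine would take num_bytes >> 1
// elements, which appears to be what the xa_nn_memmove_16 fix accounts for.
void ShiftStateLeft(int16_t* state, int num_elements) {
  const std::size_t num_bytes = (num_elements - 1) * sizeof(int16_t);
  std::memmove(state, state + 1, num_bytes);  // byte count here
  // Equivalent element count for a 16-bit routine: num_bytes >> 1.
}
```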
diff --git a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
index 20b0ab8..44a9f86 100644
--- a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
@@ -183,19 +183,57 @@
// Quantized kernels use an int32 scratch buffer.
if (input->type == kTfLiteInt8) {
TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+ const int stride_width = params->stride_width;
+ const int stride_height = params->stride_height;
+
+ const int input_height = SizeOfDimension(input, 1);
+ const int input_width = SizeOfDimension(input, 2);
+ const int input_depth = SizeOfDimension(input, 3);
+ const int output_height = height;
+ const int output_width = width;
+ int32_t scratch_buffer_size = 0;
+ scratch_buffer_size = xa_nn_transpose_conv_getsize(
+ input_height, input_width, input_depth, filter_height, filter_width,
+ stride_width, stride_height, output_height, output_width, num_channels,
+ PREC_SYM8S, PREC_ASYM8S);
+ TFLITE_DCHECK(context->RequestScratchBufferInArena(
+ context, scratch_buffer_size,
+ &(data->scratch_buffer_index)) == kTfLiteOk);
+#else // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
TFLITE_DCHECK(context->RequestScratchBufferInArena(
context,
GetTensorShape(output).FlatSize() * sizeof(int32_t),
&(data->scratch_buffer_index)) == kTfLiteOk);
+#endif
}
// Quantized 16x8 kernels use an int64 scratch buffer.
if (input->type == kTfLiteInt16) {
TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+ const int stride_width = params->stride_width;
+ const int stride_height = params->stride_height;
+
+ const int input_height = SizeOfDimension(input, 1);
+ const int input_width = SizeOfDimension(input, 2);
+ const int input_depth = SizeOfDimension(input, 3);
+ const int output_height = height;
+ const int output_width = width;
+ int32_t scratch_buffer_size = 0;
+ scratch_buffer_size = xa_nn_transpose_conv_getsize(
+ input_height, input_width, input_depth, filter_height, filter_width,
+ stride_width, stride_height, output_height, output_width, num_channels,
+ PREC_SYM8S, PREC_SYM16S);
+ TFLITE_DCHECK(context->RequestScratchBufferInArena(
+ context, scratch_buffer_size,
+ &(data->scratch_buffer_index)) == kTfLiteOk);
+#else // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
TFLITE_DCHECK(context->RequestScratchBufferInArena(
context,
GetTensorShape(output).FlatSize() * sizeof(std::int64_t),
&(data->scratch_buffer_index)) == kTfLiteOk);
+#endif // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
}
// All per-channel quantized tensors need valid zero point and scale arrays.
@@ -282,6 +320,63 @@
case kTfLiteInt8: {
int32_t* scratch_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_buffer_index));
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+ if (bias->type == kTfLiteInt32) {
+ const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+ const RuntimeShape& filter_shape =
+ tflite::micro::GetTensorShape(filter);
+ const RuntimeShape& output_shape =
+ tflite::micro::GetTensorShape(output);
+ const int stride_width = data.params.stride_width;
+ const int stride_height = data.params.stride_height;
+ const int pad_width = data.params.padding_values.width;
+ const int pad_height = data.params.padding_values.height;
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+ const int8_t* filter_data =
+ tflite::micro::GetTensorData<int8_t>(filter);
+ const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+ int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
+
+ const int num_elements = output_shape.FlatSize();
+
+ for (int b = 0; b < batches; b++) {
+ xa_nn_transpose_conv_sym8sxasym8s(
+ &output_data[b * output_height * output_width * output_depth],
+ const_cast<WORD8*>(
+ &input_data[b * input_height * input_width * input_depth]),
+ const_cast<WORD8*>(filter_data), const_cast<WORD32*>(bias_data),
+ stride_width, stride_height, pad_width, pad_height, input_depth,
+ output_depth, input_height, input_width, filter_height,
+ filter_width, output_height, output_width, num_elements / batches,
+ data.params.input_offset, data.params.output_offset,
+ data.per_channel_output_shift, data.per_channel_output_multiplier,
+ scratch_buffer);
+ }
+ } else {
+ reference_integer_ops::TransposeConv(
+ data.params, data.per_channel_output_multiplier,
+ data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(filter),
+ tflite::micro::GetTensorData<int8_t>(filter),
+ tflite::micro::GetTensorShape(bias),
+ tflite::micro::GetTensorData<int32_t>(bias),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output),
+ tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
+ }
+#else
reference_integer_ops::TransposeConv(
data.params, data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
@@ -293,6 +388,7 @@
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
+#endif
break;
}
case kTfLiteInt16: {
@@ -319,7 +415,7 @@
tflite::micro::GetTensorData<int16_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
} else {
-#if defined(HIFI3) || defined(HIFI4)
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
const RuntimeShape& filter_shape =
tflite::micro::GetTensorShape(filter);
@@ -359,9 +455,9 @@
output_depth, input_height, input_width, filter_height,
filter_width, output_height, output_width, num_elements / batches,
data.per_channel_output_shift, data.per_channel_output_multiplier,
- &scratch_buffer[b * output_height * output_width * output_depth]);
+ scratch_buffer);
}
-#else
+#else // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
reference_integer_ops::TransposeConv(
data.params, data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
@@ -373,7 +469,7 @@
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
-#endif // defined(HIFI3) || defined(HIFI4)
+#endif // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
}
break;
}
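
On the transpose_conv.cc changes: the reference path sizes the scratch buffer from the output tensor (one 32-bit accumulator per output element for int8, 64-bit for the 16x8 kernel), while the HiFi path defers to `xa_nn_transpose_conv_getsize` for the required size, and the eval loop now reuses a single scratch buffer across batches. A rough sketch of the reference sizing only, with illustrative helper names:

```cpp
#include <cstddef>
#include <cstdint>

// Reference-kernel scratch sizing, as in the #else branches above: one
// accumulator per output element, 32-bit for the int8 kernel and 64-bit for
// the 16x8 kernel. The HiFi sizes come from xa_nn_transpose_conv_getsize.
std::size_t ReferenceScratchBytesInt8(int output_flat_size) {
  return static_cast<std::size_t>(output_flat_size) * sizeof(int32_t);
}
std::size_t ReferenceScratchBytesInt16(int output_flat_size) {
  return static_cast<std::size_t>(output_flat_size) * sizeof(int64_t);
}
```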
diff --git a/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc b/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
index cbce1e1..0f6a02e 100644
--- a/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,1109 +13,156 @@
limitations under the License.
==============================================================================*/
-#include <math.h>
-#include <stdio.h>
+// Integer version of unidirectional sequence LSTM. Only the standard LSTM
+// (as defined in the Keras LSTM layer, i.e., no peephole, etc.) is supported
+// here. Currently used by the 16-bit activation case only.
-#include <cstddef>
+#include <algorithm>
+#include <limits>
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/lstm_shared.h"
#include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
-#include "tensorflow/lite/micro/kernels/xtensa/lstm_shared.h"
-#include "tensorflow/lite/micro/micro_log.h"
-// TODO(b/230666079): Flatten the namespace to match the builtin kernel
-// implementation
namespace tflite {
-namespace ops {
-namespace micro {
-// namespace unidirectional_sequence_lstm {
+
namespace {
+/*Helper Functions*/
-struct OpData {
- // If the lstm is layer norm.
- bool use_layer_norm;
- // The scratch tensor index.
- int scratch_tensor_index;
- bool compute_row_sums = false;
+/*Kernel functions*/
- lstm_eval::IntegerLstmParameter integer_lstm_param;
-};
+void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(OpDataLSTM));
+}
-TfLiteStatus PopulateQuantizedLstmParams8x8_16(
- TfLiteContext* context, TfLiteNode* node,
- lstm_eval::IntegerLstmParameter* integer_lstm_param) {
- // Calculate quantized clip for projection and cell.
- const auto* params =
+TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
+ TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+ TFLITE_DCHECK(node->user_data != nullptr);
+
+ OpDataLSTM* op_data = reinterpret_cast<OpDataLSTM*>(node->user_data);
+ const auto* builtin_data =
static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
- const float cell_clip = static_cast<float>(params->cell_clip);
- const float proj_clip = static_cast<float>(params->proj_clip);
+ // All TempTfLiteTensors will be deallocated through the destructor.
+ LstmTensors lstm_tensors(context, node);
+ TF_LITE_ENSURE_OK(context, lstm_tensors.ValidateTensorStatus(context));
- const TfLiteTensor* cell_state =
- GetVariableInput(context, node, micro::lstm::full::kCellStateTensor);
- TF_LITE_ENSURE(context, cell_state != nullptr);
- TfLiteTensor* output_tensor;
+ op_data->cell_gate_nonlinear_type = builtin_data->activation;
+ op_data->size_info =
+ CreateLstmSizeInfo(builtin_data->time_major,
+ lstm_tensors.GetInternalTensor(kLstmInputTensor)->dims,
+ lstm_tensors.HiddenStateTensor()->dims);
TF_LITE_ENSURE_OK(
- context, GetOutputSafe(context, node, micro::lstm::full::kOutputTensor,
- &output_tensor));
+ context, ValidateTensorSize(context, lstm_tensors, op_data->size_info));
- auto* cell_state_params =
- static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
- auto* proj_params = static_cast<TfLiteAffineQuantization*>(
- output_tensor->quantization.params);
- if (cell_clip > static_cast<float>(0.0)) {
- integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
- std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
- 32767.0f));
+ // Create cell state information and gate parameters (Fully Connected and Mul)
+ auto cell_state_type =
+ lstm_tensors.GetInternalTensor(kLstmCellStateTensor)->type;
+ if (cell_state_type == kTfLiteFloat32) {
+ op_data->cell_state_info =
+ CreateLstmCellStateInfoFloat(builtin_data->cell_clip);
+ TF_LITE_ENSURE_OK(
+ context, PrepareGateParametersFloat(context, lstm_tensors, op_data));
+ } else if (cell_state_type == kTfLiteInt16) {
+ op_data->cell_state_info = CreateLstmCellStateInfo(
+ lstm_tensors.CellStateTensor()->params.scale, builtin_data->cell_clip);
+ TF_LITE_ENSURE_OK(
+ context, PrepareGateParametersInteger(context, lstm_tensors, op_data));
} else {
- integer_lstm_param->quantized_cell_clip = 0;
+    MicroPrintf(
+        "Cell state type %s (%d) not supported. The quantized Unidirectional "
+        "Sequence LSTM Op only supports an int16 cell state",
+        TfLiteTypeGetName(cell_state_type), cell_state_type);
+ return kTfLiteError;
}
- if (proj_clip > static_cast<float>(0.0)) {
- integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
- std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
- } else {
- integer_lstm_param->quantized_proj_clip = 0;
+ // request buffers (four buffers)
+ for (size_t i = 0; i < 4; i++) {
+ TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
+ context,
+ op_data->size_info.batch_size *
+ op_data->size_info.state_dimension *
+ TfLiteTypeGetSize(cell_state_type),
+ &(op_data->buffer_indices[i])));
}
+ return kTfLiteOk;
+}
- // Calculate effective scales.
- OpData* op_data = static_cast<OpData*>(node->user_data);
- const bool use_layer_norm = op_data->use_layer_norm;
+TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context,
+ TfLiteNode* node) {
+ TFLITE_DCHECK(node->user_data != nullptr);
+ const OpDataLSTM& op_data = *reinterpret_cast<OpDataLSTM*>(node->user_data);
+ auto kernel_content = CreateLSTMKernelContent(context, node);
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(
- context,
- GetInputSafe(context, node, micro::lstm::full::kInputTensor, &input));
+ const auto activation_type =
+ kernel_content.internal_tensors[kLstmInputTensor]->type;
+ const auto weight_type =
+ kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;
- const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights;
- TF_LITE_ENSURE_OK(context,
- GetInputSafe(context, node,
- micro::lstm::full::kInputToForgetWeightsTensor,
- &input_to_forget_weights));
- const TfLiteTensor* input_to_cell_weights;
- TF_LITE_ENSURE_OK(
- context,
- GetInputSafe(context, node, micro::lstm::full::kInputToCellWeightsTensor,
- &input_to_cell_weights));
- const TfLiteTensor* input_to_output_weights;
- TF_LITE_ENSURE_OK(context,
- GetInputSafe(context, node,
- micro::lstm::full::kInputToOutputWeightsTensor,
- &input_to_output_weights));
-
- const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToForgetWeightsTensor,
- &recurrent_to_forget_weights));
- const TfLiteTensor* recurrent_to_cell_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToCellWeightsTensor,
- &recurrent_to_cell_weights));
- const TfLiteTensor* recurrent_to_output_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToOutputWeightsTensor,
- &recurrent_to_output_weights));
-
- const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
- context, node, micro::lstm::full::kInputLayerNormCoefficientsTensor);
- const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
- context, node, micro::lstm::full::kForgetLayerNormCoefficientsTensor);
- const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellLayerNormCoefficientsTensor);
- const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
- context, node, micro::lstm::full::kOutputLayerNormCoefficientsTensor);
-
- const TfLiteTensor* projection_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kProjectionWeightsTensor);
-
- TfLiteTensor* output_state =
- GetVariableInput(context, node, micro::lstm::full::kOutputStateTensor);
- TF_LITE_ENSURE(context, output_state != nullptr);
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
- const bool use_projection = (projection_weights != nullptr);
-
- // Get intermediate scales and zero points.
- constexpr size_t kIntermediateCount = 5;
- float intermediate_scale[kIntermediateCount];
- int32_t intermediate_zp[kIntermediateCount];
- for (int i = 0; i < 4; ++i) {
- if (use_layer_norm) {
- TfLiteTensor* intermediate =
- context->GetTensor(context, node->intermediates->data[i]);
- auto* tmp_params = static_cast<TfLiteAffineQuantization*>(
- intermediate->quantization.params);
- intermediate_scale[i] = tmp_params->scale->data[0];
- intermediate_zp[i] = tmp_params->zero_point->data[0];
- } else {
- // Q3.12 for activation functions.
- intermediate_scale[i] = std::pow(2, -12);
- intermediate_zp[i] = 0;
+ switch (activation_type) {
+ case kTfLiteFloat32: {
+ LSTMBuffers<float> buffers =
+ CreateLSTMBuffers<float>(context, op_data.buffer_indices);
+ EvalLstm<float, float, float, float>(op_data, kernel_content, buffers);
+ break;
+ }
+ case kTfLiteInt8: {
+ switch (weight_type) {
+ case kTfLiteInt8: {
+ // 8(activation)x8(weight)->16(cell) LSTM with 32 bits bias
+ LSTMBuffers<int16_t> buffers =
+ CreateLSTMBuffers<int16_t>(context, op_data.buffer_indices);
+ EvalLstm<int8_t, int8_t, int16_t, int32_t>(op_data, kernel_content,
+ buffers);
+ break;
+ }
+ default: {
+ MicroPrintf("Filter type %s (%d) not supported.",
+                        TfLiteTypeGetName(weight_type), weight_type);
+ return kTfLiteError;
+ }
+ }
+ break;
+ }
+ case kTfLiteInt16: {
+ switch (weight_type) {
+ case kTfLiteInt8: {
+ // 16(activation)x8(weight)->16(cell) LSTM with 64 bits bias
+ LSTMBuffers<int16_t> buffers =
+ CreateLSTMBuffers<int16_t>(context, op_data.buffer_indices);
+ EvalLstm<int16_t, int8_t, int16_t, int64_t>(op_data, kernel_content,
+ buffers);
+ break;
+ }
+ default: {
+ MicroPrintf("Filter type %s (%d) not supported.",
+ TfLiteTypeGetName(weight_type), weight_type);
+ return kTfLiteError;
+ }
+ }
+ break;
+ }
+ default: {
+ MicroPrintf("Input type %s (%d) not supported.",
+ TfLiteTypeGetName(activation_type), activation_type);
+ return kTfLiteError;
}
}
- // In the absence of projection, hidden becomes otuput and this intermediate
- // is ignored.
- TfLiteTensor* hidden =
- context->GetTensor(context, node->intermediates->data[4]);
- auto* hidden_params =
- static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
- intermediate_scale[4] = hidden_params->scale->data[0];
- intermediate_zp[4] = hidden_params->zero_point->data[0];
-
- // Scales.
- const float default_scale = 1.0;
- float input_scale = default_scale;
- float input_to_input_weight_scale = default_scale;
- float recurrent_to_input_weight_scale = default_scale;
- float cell_to_input_weight_scale = default_scale;
- float input_to_forget_weight_scale = default_scale;
- float recurrent_to_forget_weight_scale = default_scale;
- float cell_to_forget_weight_scale = default_scale;
- float input_to_cell_weight_scale = default_scale;
- float recurrent_to_cell_weight_scale = default_scale;
- float input_to_output_weight_scale = default_scale;
- float recurrent_to_output_weight_scale = default_scale;
- float cell_to_output_weight_scale = default_scale;
- float projection_weight_scale = default_scale;
- float layer_norm_input_scale = default_scale;
- float layer_norm_forget_scale = default_scale;
- float layer_norm_cell_scale = default_scale;
- float layer_norm_output_scale = default_scale;
- float output_state_scale = default_scale;
- int cell_scale = 1;
-
- // Effective scales.
- float effective_input_to_input_scale = default_scale;
- float effective_recurrent_to_input_scale = default_scale;
- float effective_cell_to_input_scale = default_scale;
- float effective_input_to_forget_scale = default_scale;
- float effective_recurrent_to_forget_scale = default_scale;
- float effective_cell_to_forget_scale = default_scale;
- float effective_input_to_cell_scale = default_scale;
- float effective_recurrent_to_cell_scale = default_scale;
- float effective_input_to_output_scale = default_scale;
- float effective_recurrent_to_output_scale = default_scale;
- float effective_cell_to_output_scale = default_scale;
- float effective_proj_scale = default_scale;
- float effective_hidden_scale = default_scale;
-
- // Populate scales.
- if (!use_cifg) {
- input_to_input_weight_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
- }
-
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weight_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weight_scale = cell_to_output_weights->params.scale;
- }
-
- if (use_layer_norm) {
- if (!use_cifg) {
- layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
- }
- layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
- layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
- layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
- }
-
- if (use_projection) {
- projection_weight_scale = projection_weights->params.scale;
- }
- output_state_scale = output_state->params.scale;
-
- input_to_forget_weight_scale = input_to_forget_weights->params.scale;
- input_to_cell_weight_scale = input_to_cell_weights->params.scale;
- input_to_output_weight_scale = input_to_output_weights->params.scale;
- recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
- recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
- recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
-
- // Check cell state (already used above)
- TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
- // TF_LITE_ENSURE(context, cell_scale <= -9);
- integer_lstm_param->cell_scale = cell_scale;
- input_scale = input->params.scale;
-
- // Calculate effective scales.
- if (!use_cifg) {
- effective_input_to_input_scale =
- input_to_input_weight_scale * input_scale / intermediate_scale[0];
- effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
- output_state_scale /
- intermediate_scale[0];
- }
- effective_input_to_forget_scale =
- input_to_forget_weight_scale * input_scale / intermediate_scale[1];
- effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
- output_state_scale /
- intermediate_scale[1];
-
- effective_input_to_cell_scale =
- input_to_cell_weight_scale * input_scale / intermediate_scale[2];
- effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
- output_state_scale /
- intermediate_scale[2];
-
- effective_input_to_output_scale =
- input_to_output_weight_scale * input_scale / intermediate_scale[3];
- effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
- output_state_scale /
- intermediate_scale[3];
-
- effective_hidden_scale = std::pow((float)2, (float)-15) /
- intermediate_scale[4] *
- std::pow((float)2, (float)-15);
-
- effective_proj_scale =
- projection_weight_scale * intermediate_scale[4] / output_state_scale;
-
- if (use_peephole) {
- if (!use_cifg) {
- effective_cell_to_input_scale =
- std::pow((float)(2), (float)cell_scale) * // NOLINT
- (float)(cell_to_input_weight_scale) / intermediate_scale[0];
- }
- effective_cell_to_forget_scale =
- std::pow((float)2, (float)cell_scale) * // NOLINT
- (float)cell_to_forget_weight_scale / intermediate_scale[1];
- effective_cell_to_output_scale =
- std::pow((float)2, (float)cell_scale) * // NOLINT
- (float)cell_to_output_weight_scale / intermediate_scale[3];
- }
-
- // Decompose scales.
- QuantizeMultiplier(static_cast<double>(effective_input_to_input_scale),
- &integer_lstm_param->effective_input_to_input_scale_a,
- &integer_lstm_param->effective_input_to_input_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_recurrent_to_input_scale),
- &integer_lstm_param->effective_recurrent_to_input_scale_a,
- &integer_lstm_param->effective_recurrent_to_input_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_cell_to_input_scale),
- &integer_lstm_param->effective_cell_to_input_scale_a,
- &integer_lstm_param->effective_cell_to_input_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_input_to_forget_scale),
- &integer_lstm_param->effective_input_to_forget_scale_a,
- &integer_lstm_param->effective_input_to_forget_scale_b);
- QuantizeMultiplier(
- static_cast<double>(effective_recurrent_to_forget_scale),
- &integer_lstm_param->effective_recurrent_to_forget_scale_a,
- &integer_lstm_param->effective_recurrent_to_forget_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_cell_to_forget_scale),
- &integer_lstm_param->effective_cell_to_forget_scale_a,
- &integer_lstm_param->effective_cell_to_forget_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_input_to_cell_scale),
- &integer_lstm_param->effective_input_to_cell_scale_a,
- &integer_lstm_param->effective_input_to_cell_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_recurrent_to_cell_scale),
- &integer_lstm_param->effective_recurrent_to_cell_scale_a,
- &integer_lstm_param->effective_recurrent_to_cell_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_input_to_output_scale),
- &integer_lstm_param->effective_input_to_output_scale_a,
- &integer_lstm_param->effective_input_to_output_scale_b);
- QuantizeMultiplier(
- static_cast<double>(effective_recurrent_to_output_scale),
- &integer_lstm_param->effective_recurrent_to_output_scale_a,
- &integer_lstm_param->effective_recurrent_to_output_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_cell_to_output_scale),
- &integer_lstm_param->effective_cell_to_output_scale_a,
- &integer_lstm_param->effective_cell_to_output_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_proj_scale),
- &integer_lstm_param->effective_proj_scale_a,
- &integer_lstm_param->effective_proj_scale_b);
- QuantizeMultiplier(static_cast<double>(effective_hidden_scale),
- &integer_lstm_param->effective_hidden_scale_a,
- &integer_lstm_param->effective_hidden_scale_b);
- QuantizeMultiplier(static_cast<double>(layer_norm_input_scale),
- &integer_lstm_param->layer_norm_input_scale_a,
- &integer_lstm_param->layer_norm_input_scale_b);
- QuantizeMultiplier(static_cast<double>(layer_norm_forget_scale),
- &integer_lstm_param->layer_norm_forget_scale_a,
- &integer_lstm_param->layer_norm_forget_scale_b);
- QuantizeMultiplier(static_cast<double>(layer_norm_cell_scale),
- &integer_lstm_param->layer_norm_cell_scale_a,
- &integer_lstm_param->layer_norm_cell_scale_b);
- QuantizeMultiplier(static_cast<double>(layer_norm_output_scale),
- &integer_lstm_param->layer_norm_output_scale_a,
- &integer_lstm_param->layer_norm_output_scale_b);
-
- integer_lstm_param->hidden_zp = intermediate_zp[4];
-
- // 10000 is used to make sure the kernel logic does not overflow.
- if (!use_cifg) {
- integer_lstm_param->input_variance_guard =
- std::max(static_cast<int32_t>(1),
- static_cast<int32_t>(10000 * layer_norm_input_scale));
- }
- integer_lstm_param->forget_variance_guard =
- std::max(static_cast<int32_t>(1),
- static_cast<int32_t>(10000 * layer_norm_forget_scale));
- integer_lstm_param->cell_variance_guard =
- std::max(static_cast<int32_t>(1),
- static_cast<int32_t>(10000 * layer_norm_cell_scale));
- integer_lstm_param->output_variance_guard =
- std::max(static_cast<int32_t>(1),
- static_cast<int32_t>(10000 * layer_norm_output_scale));
-
return kTfLiteOk;
}
} // namespace
-// Temporary tensors
-enum TemporaryTensor {
- kScratchBuffer = 0,
- kInputQuantized = 1,
- kOutputStateQuantized = 2,
- kCellStateQuantized = 3,
- kInputScalingFactors = 4,
- kOutputStateScalingFactors = 5,
- kProductScalingFactors = 6,
- kRecoveredCellWeights = 7,
- kAccumScratch = 8,
- kInputZeroPoints = 9,
- kOutputStateZeroPoints = 10,
- kRowSums = 11,
- kNumTemporaryTensors = 12,
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- OpData* op_data = reinterpret_cast<OpData*>(
- context->AllocatePersistentBuffer(context, sizeof(OpData)));
-
- return op_data;
-}
-
-// Check that input tensor dimensions matches with each other.
-TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
- TfLiteNode* node, int n_input,
- int n_output, int n_cell,
- bool use_layer_norm, bool is_integer) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
-
- // Making sure clipping parameters have valid values.
- // == 0 means no clipping
- // > 0 means clipping
- TF_LITE_ENSURE(context, params->cell_clip >= 0);
- TF_LITE_ENSURE(context, params->proj_clip >= 0);
- const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToInputWeightsTensor);
- if (input_to_input_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
- }
- const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToForgetWeightsTensor);
-
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
- const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToCellWeightsTensor);
-
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
- const TfLiteEvalTensor* recurrent_to_input_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
- if (recurrent_to_input_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
- n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
- n_output);
- }
- const TfLiteEvalTensor* recurrent_to_forget_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
-
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
- n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
- n_output);
- const TfLiteEvalTensor* recurrent_to_cell_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
-
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
- n_output);
-
- // We make sure the input-gate's parameters are either both present (regular
- // LSTM) or not at all (CIFG-LSTM).
- const bool cifg_weights_all_or_none =
- ((input_to_input_weights != nullptr) &&
- (recurrent_to_input_weights != nullptr)) ||
- ((input_to_input_weights == nullptr) &&
- (recurrent_to_input_weights == nullptr));
- TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
-
- const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellToInputWeightsTensor);
- if (cell_to_input_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_TYPES_EQ(
- context, cell_to_input_weights->type,
- is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
- }
-
- const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
- context, node, lstm::full::kCellToForgetWeightsTensor);
- if (cell_to_forget_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_TYPES_EQ(
- context, cell_to_forget_weights->type,
- is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
- }
-
- const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kCellToOutputWeightsTensor);
- if (cell_to_output_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_TYPES_EQ(
- context, cell_to_output_weights->type,
- is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
- }
-
- // Making sure the peephole weights are there all or none.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool peephole_weights_all_or_none =
- ((cell_to_input_weights != nullptr || use_cifg) &&
- (cell_to_forget_weights != nullptr) &&
- (cell_to_output_weights != nullptr)) ||
- ((cell_to_input_weights == nullptr) &&
- (cell_to_forget_weights == nullptr) &&
- (cell_to_output_weights == nullptr));
- TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
- const TfLiteEvalTensor* input_gate_bias = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputGateBiasTensor);
-
- if (use_cifg) {
- TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
- } else {
- TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
- }
- }
- const TfLiteEvalTensor* forget_gate_bias = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kForgetGateBiasTensor);
-
- TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
- }
- const TfLiteEvalTensor* cell_gate_bias = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kCellGateBiasTensor);
-
- TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
- }
- const TfLiteEvalTensor* output_gate_bias = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kOutputGateBiasTensor);
- TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
- }
-
- const TfLiteTensor* projection_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kProjectionWeightsTensor);
- if (projection_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
- }
-
- const TfLiteTensor* projection_bias = GetOptionalInputTensor(
- context, node, micro::lstm::full::kProjectionBiasTensor);
- if (projection_bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
- }
- }
-
- // Making sure the projection tensors are consistent:
- // 1) If projection weight is not present, then projection bias should not be
- // present.
- // 2) If projection weight is present, then projection bias is optional.
- const bool projecton_tensors_consistent =
- ((projection_weights != nullptr) || (projection_bias == nullptr));
- TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
-
- if (use_layer_norm) {
- const TfLiteEvalTensor* input_layer_norm_coefficients =
- tflite::micro::GetEvalInput(
- context, node,
- micro::lstm::full::kInputLayerNormCoefficientsTensor);
- if (use_cifg) {
- TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
- } else {
- TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
- TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
- n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
- kTfLiteInt16);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
- kTfLiteFloat32);
- }
- }
- const TfLiteEvalTensor* forget_layer_norm_coefficients =
- tflite::micro::GetEvalInput(
- context, node,
- micro::lstm::full::kForgetLayerNormCoefficientsTensor);
- TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
- n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
- kTfLiteInt16);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
- kTfLiteFloat32);
- }
- const TfLiteEvalTensor* cell_layer_norm_coefficients =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kCellLayerNormCoefficientsTensor);
- TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
- n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
- kTfLiteInt16);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
- kTfLiteFloat32);
- }
- const TfLiteEvalTensor* output_layer_norm_coefficients =
- tflite::micro::GetEvalInput(
- context, node,
- micro::lstm::full::kOutputLayerNormCoefficientsTensor);
-
- TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
- n_cell);
- if (is_integer) {
- TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
- kTfLiteInt16);
- } else {
- TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
- kTfLiteFloat32);
- }
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
- TfLiteContext* context, int32_t zero_point,
- const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
- std::unique_ptr<int32_t[]>* output) {
- if (weight_tensor == nullptr) {
- return kTfLiteOk;
- }
-
- const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
- TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
- const int row = weight_shape.Dims(0);
- const int col = weight_shape.Dims(1);
- output->reset(new int32_t[row]);
- if (bias_tensor == nullptr) {
- memset(output->get(), 0, row * sizeof(int32_t));
- } else {
- const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
- memcpy(output->get(), bias, row * sizeof(int32_t));
- }
- if (zero_point != 0) {
- const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
- tensor_utils::PortableMatrixScalarMultiplyAccumulate(
- weight, zero_point, row, col, output->get());
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
- OpData* op_data,
- TfLiteNode* node) {
- const TfLiteTensor* input;
- TF_LITE_ENSURE_OK(
- context,
- GetInputSafe(context, node, micro::lstm::full::kInputTensor, &input));
- const TfLiteTensor* output_state =
- GetVariableInput(context, node, micro::lstm::full::kOutputStateTensor);
- TF_LITE_ENSURE(context, output_state != nullptr);
-
- const int32_t input_zero_point = -input->params.zero_point;
- const int32_t output_state_zero_point = -output_state->params.zero_point;
-
- const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights;
- TF_LITE_ENSURE_OK(context,
- GetInputSafe(context, node,
- micro::lstm::full::kInputToForgetWeightsTensor,
- &input_to_forget_weights));
- const TfLiteTensor* input_to_cell_weights;
- TF_LITE_ENSURE_OK(
- context,
- GetInputSafe(context, node, micro::lstm::full::kInputToCellWeightsTensor,
- &input_to_cell_weights));
- const TfLiteTensor* input_to_output_weights;
- TF_LITE_ENSURE_OK(context,
- GetInputSafe(context, node,
- micro::lstm::full::kInputToOutputWeightsTensor,
- &input_to_output_weights));
-
- const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
- context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToForgetWeightsTensor,
- &recurrent_to_forget_weights));
- const TfLiteTensor* recurrent_to_cell_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToCellWeightsTensor,
- &recurrent_to_cell_weights));
- const TfLiteTensor* recurrent_to_output_weights;
- TF_LITE_ENSURE_OK(
- context, GetInputSafe(context, node,
- micro::lstm::full::kRecurrentToOutputWeightsTensor,
- &recurrent_to_output_weights));
-
- const TfLiteTensor* projection_weights = GetOptionalInputTensor(
- context, node, lstm::full::kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias = GetOptionalInputTensor(
- context, node, micro::lstm::full::kProjectionBiasTensor);
-
- lstm_eval::IntegerLstmParameter* integer_lstm_params =
- &op_data->integer_lstm_param;
-
- TfLiteTensor* intermediate =
- context->GetTensor(context, node->intermediates->data[4]);
- const auto* params =
- static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
- const int32_t hidden_zp = params->zero_point->data[0];
-
- // Get bias and perform zero point calculation.
- // When there is layer normalization, the gate bias does not apply to matmul
- // directly:
- // y = ln(w * x + w * r + w * c) + b.
- const bool is_layer_norm = op_data->use_layer_norm;
-
- // Forget gate.
- const TfLiteTensor* forget_gate_bias =
- is_layer_norm
- ? nullptr
- : GetInput(context, node, micro::lstm::full::kForgetGateBiasTensor);
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, input_zero_point, input_to_forget_weights, forget_gate_bias,
- &(integer_lstm_params->input_to_forget_effective_bias)));
-
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, output_state_zero_point, recurrent_to_forget_weights,
- nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
-
- // Modulation gate.
- const TfLiteTensor* cell_gate_bias =
- is_layer_norm
- ? nullptr
- : GetInput(context, node, micro::lstm::full::kCellGateBiasTensor);
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, input_zero_point, input_to_cell_weights, cell_gate_bias,
- &(integer_lstm_params->input_to_cell_effective_bias)));
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
- &(integer_lstm_params->recurrent_to_cell_effective_bias)));
-
- // Output gate.
- const TfLiteTensor* output_gate_bias =
- is_layer_norm
- ? nullptr
- : GetInput(context, node, micro::lstm::full::kOutputGateBiasTensor);
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, input_zero_point, input_to_output_weights, output_gate_bias,
- &(integer_lstm_params->input_to_output_effective_bias)));
-
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, output_state_zero_point, recurrent_to_output_weights,
- nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
-
- // Input gate. The calculation is only meaningful for non-cifg case.
- const TfLiteTensor* input_gate_bias =
- is_layer_norm
- ? nullptr
- : GetInput(context, node, micro::lstm::full::kInputGateBiasTensor);
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, input_zero_point, input_to_input_weights, input_gate_bias,
- &(integer_lstm_params->input_to_input_effective_bias)));
- TF_LITE_ENSURE_OK(
- context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, output_state_zero_point, recurrent_to_input_weights, nullptr,
- &(integer_lstm_params->recurrent_to_input_effective_bias)));
-
- // Projection bias. The calculation is only meaningful for with projection.
- TF_LITE_ENSURE_OK(context,
- PrecomputeZeroPointTimesWeightWithBias(
- context, hidden_zp, projection_weights, projection_bias,
- &(integer_lstm_params->projection_effective_bias)));
- return kTfLiteOk;
-}
-
-// Resize the output and state tensors based on the sizes of the input tensors.
-// Allocate a temporary scratch tensor. Also check that the sizes of the input
-// tensors match each other.
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- // const int scratch_tensor_index = op_data->scratch_tensor_index;
-
- // Check we have all the inputs and outputs we need.
- bool use_layer_norm = false;
- if (node->inputs->size == 24) {
- const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
- context, node, micro::lstm::full::kForgetLayerNormCoefficientsTensor);
- if (forget_layer_norm_coefficients == nullptr) {
- use_layer_norm = false;
- } else {
- use_layer_norm = true;
- }
- } else if (node->inputs->size == 20) {
- // This is deprecated and is only kept here for backward compatibility.
- use_layer_norm = false;
- } else {
- MicroPrintf("The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
- op_data->use_layer_norm = use_layer_norm;
-
- // Inferring batch size, number of outputs and sequence length and
- // number of cells from the input tensors.
- const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputTensor);
- const bool is_integer = input->type == kTfLiteInt8;
- TF_LITE_ENSURE(context, input->dims->size > 1);
- const auto* params =
- reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
- node->builtin_data);
- const bool time_major = params->time_major;
- const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
- const int n_input = input->dims->data[2];
- const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToOutputWeightsTensor);
- const int n_cell = input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
- const TfLiteEvalTensor* recurrent_to_output_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);
-
- TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
- n_cell);
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Check that input tensor dimensions matches with each other.
- TF_LITE_ENSURE_OK(
- context, CheckInputTensorDimensions(context, node, n_input, n_output,
- n_cell, use_layer_norm, is_integer));
- // Get the pointer to output, output_state and cell_state buffer tensors.
- // TfLiteEvalTensor* output =
- // tflite::micro::GetEvalOutput(context, node,
- // micro::lstm::full::kOutputTensor);
- TfLiteEvalTensor* output_state = tflite::micro::GetMutableEvalInput(
- context, node, micro::lstm::full::kOutputStateTensor);
- TFLITE_DCHECK(output_state != nullptr);
- TfLiteEvalTensor* cell_state = tflite::micro::GetMutableEvalInput(
- context, node, micro::lstm::full::kCellStateTensor);
- TFLITE_DCHECK(cell_state != nullptr);
- // Check the shape of input state tensors.
- // These tensor may be 1D or 2D. It's fine as long as the total size is
- // correct.
- TF_LITE_ENSURE_EQ(context, NumElements(output_state->dims),
- n_batch * n_output);
- TF_LITE_ENSURE_EQ(context, NumElements(cell_state->dims), n_batch * n_cell);
-
- if (is_integer) {
- const int num_intermediate_tensors = node->intermediates->size;
- TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
- }
-
- if (is_integer) {
- // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
- // This code path needs 5 intermediate tensors per Op.
- // Populate quantization parameters.
- PopulateQuantizedLstmParams8x8_16(context, node,
- &op_data->integer_lstm_param);
- // Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell
- // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
- // buffer with size n_batch * n_cell.
- //
- // Handle cifg case as well, which might save one buffer.
-
- int scratch_idx = 0;
-
- context->RequestScratchBufferInArena(
- context, n_batch * n_cell * sizeof(int32_t), &(scratch_idx));
- op_data->scratch_tensor_index = scratch_idx;
-
- for (int scratch_index = 1; scratch_index < 6; ++scratch_index) {
- // node->temporaries->data[scratch_index] = op_data->scratch_tensor_index
- // + scratch_index;
- context->RequestScratchBufferInArena(
- context, n_batch * n_cell * sizeof(int32_t), &(scratch_idx));
- TFLITE_DCHECK(scratch_idx ==
- (op_data->scratch_tensor_index + scratch_index));
- }
-
- // Populate precomputed zp * weight.
- TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
- context, op_data, node));
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params =
- reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
- node->builtin_data);
- const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- // const bool use_layer_norm = op_data->use_layer_norm;
- // const bool time_major = params->time_major;
-
- const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputTensor);
- const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToInputWeightsTensor);
- const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToForgetWeightsTensor);
- const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToCellWeightsTensor);
- const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kInputToOutputWeightsTensor);
- const TfLiteEvalTensor* recurrent_to_input_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
- const TfLiteEvalTensor* recurrent_to_forget_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
- const TfLiteEvalTensor* recurrent_to_cell_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
- const TfLiteEvalTensor* recurrent_to_output_weights =
- tflite::micro::GetEvalInput(
- context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);
- const TfLiteEvalTensor* cell_to_input_weights = context->GetEvalTensor(
- context,
- node->inputs->data[micro::lstm::full::kCellToInputWeightsTensor]);
- const TfLiteEvalTensor* cell_to_forget_weights = context->GetEvalTensor(
- context,
- node->inputs->data[micro::lstm::full::kCellToForgetWeightsTensor]);
- const TfLiteEvalTensor* cell_to_output_weights = context->GetEvalTensor(
- context,
- node->inputs->data[micro::lstm::full::kCellToOutputWeightsTensor]);
- const TfLiteEvalTensor* input_gate_bias = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kInputGateBiasTensor]);
-
- const TfLiteEvalTensor* forget_gate_bias = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kForgetGateBiasTensor]);
- const TfLiteEvalTensor* cell_gate_bias = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kCellGateBiasTensor]);
- const TfLiteEvalTensor* output_gate_bias = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kOutputGateBiasTensor]);
-
- const TfLiteEvalTensor* projection_weights = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kProjectionWeightsTensor]);
- const TfLiteEvalTensor* projection_bias = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kProjectionBiasTensor]);
-
- TfLiteEvalTensor* output_state = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kOutputStateTensor]);
- TFLITE_DCHECK(output_state != nullptr);
- TfLiteEvalTensor* cell_state = context->GetEvalTensor(
- context, node->inputs->data[micro::lstm::full::kCellStateTensor]);
- TFLITE_DCHECK(cell_state != nullptr);
- const TfLiteEvalTensor* input_layer_norm_coefficients =
- context->GetEvalTensor(
- context,
- node->inputs
- ->data[micro::lstm::full::kInputLayerNormCoefficientsTensor]);
-
- const TfLiteEvalTensor* forget_layer_norm_coefficients =
- context->GetEvalTensor(
- context,
- node->inputs
- ->data[micro::lstm::full::kForgetLayerNormCoefficientsTensor]);
- const TfLiteEvalTensor* cell_layer_norm_coefficients = context->GetEvalTensor(
- context,
- node->inputs->data[micro::lstm::full::kCellLayerNormCoefficientsTensor]);
-
- const TfLiteEvalTensor* output_layer_norm_coefficients =
- context->GetEvalTensor(
- context,
- node->inputs
- ->data[micro::lstm::full::kOutputLayerNormCoefficientsTensor]);
-
- TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(
- context, node, micro::lstm::full::kOutputTensor);
-
- // Copy out the LSTM specific params so they can be passed in the function.
- TfLiteLSTMParams lstm_params;
- lstm_params.activation = params->activation;
- lstm_params.cell_clip = params->cell_clip;
- lstm_params.proj_clip = params->proj_clip;
- lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
- switch (input_to_output_weights->type) {
- case kTfLiteInt8: {
- const bool is_hybrid = input->type == kTfLiteFloat32;
- if (is_hybrid) {
- MicroPrintf(" hybrid type is not supported.");
- return kTfLiteError;
-
- } else {
- TfLiteEvalTensor* scratch[6];
- // Allocate scratch buffer. Need 6 16bit buffer with size n_batch *
- // n_cell
- // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
- // buffer with size n_batch * n_cell.
- //
- // Handle cifg case as well, which might save one buffer.
-
- const auto* tmp_params =
- reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
- node->builtin_data);
- const bool time_major = tmp_params->time_major;
- for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
- TFLITE_DCHECK(context != nullptr);
- TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
- int32_t* scratch_tensor =
- static_cast<int32_t*>(context->GetScratchBuffer(
- context, op_data->scratch_tensor_index + scratch_index));
- scratch[scratch_index] = (TfLiteEvalTensor*)scratch_tensor;
- }
- /*
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 0,
- &scratch0));
-
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 1,
- &scratch1));
-
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 2,
- &scratch2));
-
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 3,
- &scratch3));
-
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 4,
- &scratch4));
-
- TF_LITE_ENSURE_OK(context,
- GetScratchSafe(context, node, 5,
- &scratch5));
- */
- return lstm_eval::EvalInteger8x8_16(
- context, node, input, input_to_input_weights,
- input_to_forget_weights, input_to_cell_weights,
- input_to_output_weights, recurrent_to_input_weights,
- recurrent_to_forget_weights, recurrent_to_cell_weights,
- recurrent_to_output_weights, cell_to_input_weights,
- cell_to_forget_weights, cell_to_output_weights,
- input_layer_norm_coefficients, forget_layer_norm_coefficients,
- cell_layer_norm_coefficients, output_layer_norm_coefficients,
- input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
- projection_weights, projection_bias, &lstm_params,
- /*forward_sequence=*/true, time_major, &op_data->integer_lstm_param,
- output_state, cell_state, output, scratch[0], scratch[1],
- scratch[2], scratch[3], scratch[4], scratch[5]);
- }
- }
-
- default:
- MicroPrintf("Type %s is not currently supported.",
- TfLiteTypeGetName(input_to_output_weights->type));
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-//} // namespace unidirectional_sequence_lstm
-
-} // namespace micro
-} // namespace ops
-
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
- return tflite::micro::RegisterOp(ops::micro::Init, ops::micro::Prepare,
- ops::micro::Eval);
+ return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
+ UnidirectionalSequenceLstmPrepare,
+ UnidirectionalSequenceLstmEval);
}
} // namespace tflite
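For context on the new registration style: the op now routes through the shared UnidirectionalSequenceLstm* entry points instead of the old ops::micro namespace. Below is a minimal, assumed application-side sketch (not part of this diff) of pulling the op in through MicroMutableOpResolver, where AddUnidirectionalSequenceLSTM() resolves to the Register_UNIDIRECTIONAL_SEQUENCE_LSTM() shown above; the op-count template argument and the wrapper function are illustrative only.

```cpp
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

namespace {
// Sized for a single op here; real applications size this for their model.
tflite::MicroMutableOpResolver<1> op_resolver;
}  // namespace

TfLiteStatus RegisterOps() {
  // Resolves to Register_UNIDIRECTIONAL_SEQUENCE_LSTM() from this change.
  return op_resolver.AddUnidirectionalSequenceLSTM();
}
```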
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch b/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
index 125e3f7..118845c 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
+++ b/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
@@ -12,3 +12,20 @@
#define onchip
#define NASSERT(x) {(void)__builtin_expect((x)!=0,1);}
#else
+diff --git a/library/include_private/common.h b/library/include_private/common.h
+index 2eaf70f..9df811c 100644
+--- a/library/include_private/common.h
++++ b/library/include_private/common.h
+@@ -172,6 +172,12 @@ __pragma (warning(pop))
+ #if defined(COMPILER_XTENSA) || defined(COMPILER_GNU)
+ #define DISCARD_FUN(retval_type,name,arglist) \
+ __asm__(".type "#name", @object\n\t.global "#name"\n\t.align 4\n\t"#name":\n\t.long 0x49438B96,0x4D73F192\n\t");
++
++#define DISCARD_FUN_FOR_NONVOID_RETURN(retval_type,name,arglist) \
++__attribute__ ((section ("/DISCARD/"))) \
++retval_type name arglist \
++{ return (retval_type) 0; }
++
+ #endif
+
+ /*------ LIST OF DEFINES DEPENDING ON ISA OPTIONS ------*/
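To make the new macro's intent concrete, an illustrative expansion is shown below; the function name and signature are hypothetical. Where the existing DISCARD_FUN emits an assembly-level object stub, the added DISCARD_FUN_FOR_NONVOID_RETURN keeps a well-formed C definition for functions with non-void return types: the stub body is placed in the /DISCARD/ section and simply returns a zero value of the declared type.

```cpp
// Macro as added by the patch above.
#define DISCARD_FUN_FOR_NONVOID_RETURN(retval_type,name,arglist) \
__attribute__ ((section ("/DISCARD/"))) \
retval_type name arglist \
{ return (retval_type) 0; }

// Hypothetical use for a kernel variant that is compiled out:
DISCARD_FUN_FOR_NONVOID_RETURN(int, example_discarded_kernel, (const short* p_inp, int len))

// ...which expands to:
// __attribute__ ((section ("/DISCARD/")))
// int example_discarded_kernel (const short* p_inp, int len) { return (int) 0; }
```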
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
index 227ee92..1bb15aa 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
@@ -1,207 +1,5 @@
-From 0a68f2ffa640d1b52314278cec838384722eb1d0 Mon Sep 17 00:00:00 2001
-From: William Huang <yushh@google.com>
-Date: Tue, 16 May 2023 09:18:55 +0000
-Subject: [PATCH] Optimize Xtensa transpose convolution for more kernel sizes
- and input channels.
-
-Previously, there were three code paths, in decreasing performance:
-
-1. Kernel size (H*W) multiple of 4, input channels multiple of 16
-2. Kernel size (H*W) multiple of 4, input channels multiple of 4
-3. Others (unoptimized case)
-
-This patch reduces them to the follow two cases:
-
-1. Input channels multiple of 4
-2. Others (unoptimized case)
-
-Original CL=cl/516144094
-
-BUG=227374718
-
-Signed-off-by: William Huang <yushh@google.com>
-
-Optimize Xtensa CONV2D circular buffer copy.
-
-In Xtensa's CONV2D kernel, data is shuffled around and padded so the 2D
-convolution turns into sequential vector products. Unfortunately, this
-process is somewhat slow, and the overhead is especially high for small
-vector lengths.
-
-This patch introduces the following:
-
-- Faster code path for no padding (since our models use VALID padding,
- i.e., no padding at all)
-- Manual loop if array is small and memcpy if array is large
-- Skip memset on padded channels as the corresponding kernels are
- already zero
-
-BUG=249796929
-
-Signed-off-by: William Huang <yushh@google.com>
-
-Add implementation for zero-copy CONV2D kernels.
-
-The previous `xa_nn_conv2d_std_sym8sxsym16s` implementation shuffles the
-input tensor into a circular buffer, flattening the dimensions, so that
-the 2D convolution turns into sequential vector products. However, this
-created significant overhead for layers where the resulting vector
-lengths are small.
-
-This patch implements an alternative zero-copy method that takes
-advantage of two facts:
-
-1. If `x_padding == 0`, the width dimension is automatically flattened
- with the channel dimension, and we need only `kernel_height`
- sequential vector products, even without the data shuffling
-2. Similar to the loop tiling done in
- `xa_nn_matXvec_sym8sxsym16s_sym16s_circ`, we can tile the `out_width`
- and `out_channels` dimensions, achieving the throughput of
- `_xa_nn_dot_product_2row_4vec_mat_vecs_4bytes_aligned` (i.e., 1.6
- MULAAAAQs/cycle), even when `out_height < 2`
-
-As a result, the patch significantly benefits layers where the kernel
-and output heights are small, leading to 25%+ cycle reductions in some
-use cases.
-
-Signed-off-by: William Huang <yushh@google.com>
----
- .../cnn/hifi4/xa_nn_conv2d_std_circ_buf.c | 84 +++++++-
- .../cnn/hifi4/xa_nn_conv2d_std_state.h | 15 ++
- .../cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c | 203 +++++++++++++++---
- .../hifi4/xa_nn_transpose_conv_sym8sxsym16s.c | 36 +---
- 4 files changed, 275 insertions(+), 63 deletions(-)
-
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-index f8adba2..1a5f186 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-@@ -642,7 +642,8 @@ VOID conv2d_std_init_cir_buf(
- }
-
- // Add x_stride (but not more than kernel_width) x (input_height x input_channels) new planes to circular buffer
--VOID conv2d_std_update_cir_buf(
-+// Slow version of conv2d_std_update_cir_buf with fewer requirements
-+VOID conv2d_std_update_cir_buf_slow(
- WORD32 input_channels,
- WORD32 input_channels_pad,
- WORD32 input_bytewidth,
-@@ -742,6 +743,87 @@ VOID conv2d_std_update_cir_buf(
- *pp_inp = (VOID *)p_inp;
- }
-
-+// Add x_stride (but not more than kernel_width) x (input_height x input_channels) new planes to circular buffer
-+VOID conv2d_std_update_cir_buf(
-+ WORD32 input_channels,
-+ WORD32 input_channels_pad,
-+ WORD32 input_bytewidth,
-+ WORD32 input_width,
-+ WORD32 input_height,
-+ WORD32 y_padding,
-+ WORD32 y_b_pad,
-+ WORD32 x_padding,
-+ WORD32 kernel_width,
-+ WORD32 x_stride,
-+ VOID **pp_inp,
-+ WORD32 idx_beg_inp_width_pad,
-+ xa_nn_conv_state_t *p_state)
-+{
-+ if (y_padding != 0 || y_b_pad != 0 || x_padding != 0) {
-+ conv2d_std_update_cir_buf_slow(
-+ input_channels,
-+ input_channels_pad,
-+ input_bytewidth,
-+ input_width,
-+ input_height,
-+ y_padding,
-+ y_b_pad,
-+ x_padding,
-+ kernel_width,
-+ x_stride,
-+ pp_inp,
-+ idx_beg_inp_width_pad,
-+ p_state
-+ );
-+ return;
-+ }
-+
-+ WORD32 i,k;
-+ WORD8 *p_inp = (WORD8 *)*pp_inp;
-+ WORD32 planes_to_add = x_stride > kernel_width ? kernel_width : x_stride;
-+ WORD32 planes_to_keep = kernel_width - planes_to_add;
-+
-+ // Copy 'planes_to_add' planes of data to circular buffer
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_state->cir_buf.p_curr, planes_to_add * input_channels_pad * input_bytewidth);
-+ WORD8 *p_dst = (WORD8 *)p_state->cir_buf.p_curr;
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+
-+ WORD32 copy_inp_width = planes_to_add;
-+ WORD32 to_skip_inp_width = x_stride - planes_to_add; // Non-zero for x_stride > kernel_width
-+
-+ int size = input_channels * input_bytewidth;
-+ if (size <= 32) {
-+ for(i=0;i<input_height;i++)
-+ {
-+ for(k=0;k<copy_inp_width;k++)
-+ {
-+ for (int j = 0; j < size; ++j) {
-+ p_dst[j] = p_inp[j];
-+ }
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, input_channels_pad * input_bytewidth);
-+ p_inp += input_channels * input_bytewidth;
-+ }
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+ p_inp += (input_width - copy_inp_width) * input_channels * input_bytewidth;
-+ }
-+ } else {
-+ for(i=0;i<input_height;i++)
-+ {
-+ for(k=0;k<copy_inp_width;k++)
-+ {
-+ memcpy(p_dst, p_inp, input_channels * input_bytewidth);
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, input_channels_pad * input_bytewidth);
-+ p_inp += input_channels * input_bytewidth;
-+ }
-+ AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+ p_inp += (input_width - copy_inp_width) * input_channels * input_bytewidth;
-+ }
-+ }
-+ p_inp += (-input_height * input_width + copy_inp_width + to_skip_inp_width) * input_channels * input_bytewidth;
-+
-+ *pp_inp = (VOID *)p_inp;
-+}
-+
- VOID xa_nn_dilated_conv2d_std_load_cir_buf_asym8(
- WORD32 input_channels,
- WORD32 input_channels_pad,
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-index a2ba355..8d33bad 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-@@ -214,6 +214,21 @@ VOID conv2d_std_init_cir_buf(
- VOID **pp_inp,
- xa_nn_conv_state_t *p_state);
-
-+VOID conv2d_std_update_cir_buf_slow(
-+ WORD32 input_channels,
-+ WORD32 input_channels_pad,
-+ WORD32 input_bytewidth,
-+ WORD32 input_width,
-+ WORD32 input_height,
-+ WORD32 y_padding,
-+ WORD32 y_b_pad,
-+ WORD32 x_padding,
-+ WORD32 kernel_width,
-+ WORD32 x_stride,
-+ VOID **pp_inp,
-+ WORD32 idx_beg_inp_width_pad,
-+ xa_nn_conv_state_t *p_state);
-+
- VOID conv2d_std_update_cir_buf(
- WORD32 input_channels,
- WORD32 input_channels_pad,
diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
-index 92721bc..6f868be 100644
+index b9905e9..990b713 100644
--- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
@@ -49,6 +49,24 @@ static inline ae_int32x2 MultiplyByQuantizedMultiplier_ref(ae_int64 d_x,
@@ -229,78 +27,35 @@
static WORD32 conv_x_left_pad(
WORD32 x_padding,
WORD32 kernel_width,
-@@ -238,41 +256,166 @@ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
- WORD32 y_b_pad = kernel_height + (out_height - 1) * y_stride - (y_padding + input_height);
- y_b_pad = y_b_pad < 0 ? 0 : y_b_pad;
+@@ -129,6 +147,160 @@ static WORD32 conv_x_right_pad(
+ return out_width_over_x_r_pad;
+ }
-- conv2d_std_init_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, p_state);
-+ if (x_padding || (input_channels & 0x3) || (out_channels & 0x3) || (out_width & 0x1)) {
-+ conv2d_std_init_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, p_state);
-
-- // Index to padded input width
-- WORD32 idx_beg_inp_width_pad = kernel_width - x_stride;
-- idx_beg_inp_width_pad = idx_beg_inp_width_pad < 0 ? 0 : idx_beg_inp_width_pad;
-+ // Index to padded input width
-+ WORD32 idx_beg_inp_width_pad = kernel_width - x_stride;
-+ idx_beg_inp_width_pad = idx_beg_inp_width_pad < 0 ? 0 : idx_beg_inp_width_pad;
-
-
-- // Process Loop to compute one output plane [out_height x out_channels] per iteration
-- for(j=0;j<out_width-out_width_over_x_pad-out_width_over_x_r_pad;j++)
-- {
-- // Add x_stride x (input_height x input_channels) new planes to circular buffer
-- conv2d_std_update_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, idx_beg_inp_width_pad, p_state);
-+ // Process Loop to compute one output plane [out_height x out_channels] per iteration
-+ for(j=0;j<out_width-out_width_over_x_pad-out_width_over_x_r_pad;j++)
++static WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s_no_circ_buf(
++ WORD16* __restrict__ p_out,
++ const WORD16* __restrict__ p_inp,
++ const WORD8* __restrict__ p_kernel,
++ const WORD64* __restrict__ p_bias,
++ WORD32 input_height,
++ WORD32 input_width,
++ WORD32 input_channels,
++ WORD32 kernel_height,
++ WORD32 kernel_width,
++ WORD32 out_channels,
++ WORD32 x_stride,
++ WORD32 y_stride,
++ WORD32 x_padding,
++ WORD32 y_padding,
++ WORD32 out_height,
++ WORD32 out_width,
++ WORD32 input_zero_bias,
++ WORD32 * p_out_multiplier,
++ WORD32 * p_out_shift,
++ WORD32 out_zero_bias,
++ WORD32 out_data_format
++ )
+ {
-+ // Add x_stride x (input_height x input_channels) new planes to circular buffer
-+ conv2d_std_update_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, idx_beg_inp_width_pad, p_state);
-
-- // Update index to input width padded
-- idx_beg_inp_width_pad += x_stride;
-+ // Update index to input width padded
-+ idx_beg_inp_width_pad += x_stride;
-
-- // Convolution using matXvec with matrix as circular buffer
-- xa_nn_matXvec_sym8sxsym16s_sym16s_circ
-- (p_out /* output */
-- ,p_state->cir_buf.p_curr/* matrix: rows x cols */
-- ,p_state->p_kernel_padded /* vec: cols */
-- ,p_bias /* bias */
-- ,out_height /* rows */
-- ,input_channels_pad * kernel_width * kernel_height /* cols */
-- ,input_channels_pad * kernel_width * y_stride/* row_offset */
-- ,out_channels /* vec_count */
-- ,input_channels_pad * kernel_width * kernel_height /* vec_stride */
-- ,out_channels_offset /* out_col_offset */
-- ,out_height_offset /* out_row_offset */
-- ,input_zero_bias
-- ,p_out_multiplier
-- ,p_out_shift
-- ,out_zero_bias
-- );
-- p_out += out_width_offset;
-+ // Convolution using matXvec with matrix as circular buffer
-+ xa_nn_matXvec_sym8sxsym16s_sym16s_circ
-+ (p_out /* output */
-+ ,p_state->cir_buf.p_curr/* matrix: rows x cols */
-+ ,p_state->p_kernel_padded /* vec: cols */
-+ ,p_bias /* bias */
-+ ,out_height /* rows */
-+ ,input_channels_pad * kernel_width * kernel_height /* cols */
-+ ,input_channels_pad * kernel_width * y_stride/* row_offset */
-+ ,out_channels /* vec_count */
-+ ,input_channels_pad * kernel_width * kernel_height /* vec_stride */
-+ ,out_channels_offset /* out_col_offset */
-+ ,out_height_offset /* out_row_offset */
-+ ,input_zero_bias
-+ ,p_out_multiplier
-+ ,p_out_shift
-+ ,out_zero_bias
-+ );
-+ p_out += out_width_offset;
-+ }
-+ } else {
++
+ const WORD16 *p_dst0_0 = p_out + 0;
+ const WORD16 *p_dst0_1 = p_out + 1;
+ const WORD16 *p_dst0_2 = p_out + 2;
@@ -310,8 +65,8 @@
+ const WORD16 *p_dst1_2 = p_out + out_channels + 2;
+ const WORD16 *p_dst1_3 = p_out + out_channels + 3;
+ int kernel_out_ch_offset = kernel_height * kernel_width * input_channels;
-+ int input_x_offset = input_channels * x_stride / 4;
-+ int p_inp_vec_stride = input_width * input_channels / 4;
++ int input_x_offset = (input_channels * x_stride) / 4;
++ int p_inp_vec_stride = (input_width * input_channels) / 4;
+ int p_kern_vec_stride = kernel_width * input_channels;
+ int vec_len = kernel_width * input_channels;
+ for (int out_y = 0; out_y < out_height; ++out_y) {
@@ -325,6 +80,7 @@
+ ae_int64 out1_1 = p_bias[out_ch + 1];
+ ae_int64 out1_2 = p_bias[out_ch + 2];
+ ae_int64 out1_3 = p_bias[out_ch + 3];
++
+ out0_0 = AE_SLAI64(out0_0, 8);
+ out0_1 = AE_SLAI64(out0_1, 8);
+ out0_2 = AE_SLAI64(out0_2, 8);
@@ -333,10 +89,11 @@
+ out1_1 = AE_SLAI64(out1_1, 8);
+ out1_2 = AE_SLAI64(out1_2, 8);
+ out1_3 = AE_SLAI64(out1_3, 8);
++
+ int in_x_o = out_x * x_stride;
+ int in_y_o = out_y * y_stride - y_padding;
+ int k_y_min = -in_y_o;
-+ int k_y_max = input_width - in_y_o;
++ int k_y_max = input_height - in_y_o;
+ k_y_min = (k_y_min < 0) ? 0 : k_y_min;
+ k_y_min = (k_y_min < kernel_height) ? k_y_min : kernel_height;
+ k_y_max = (k_y_max < 0) ? 0 : k_y_max;
@@ -382,6 +139,7 @@
+ AE_MULAAAAQ16(out1_3, d_inp1, d_kern3);
+ }
+ }
++
+ out0_0 = AE_SRAI64(out0_0, 8);
+ out0_1 = AE_SRAI64(out0_1, 8);
+ out0_2 = AE_SRAI64(out0_2, 8);
@@ -390,6 +148,7 @@
+ out1_1 = AE_SRAI64(out1_1, 8);
+ out1_2 = AE_SRAI64(out1_2, 8);
+ out1_3 = AE_SRAI64(out1_3, 8);
++
+ ae_int32x2 acc_vec0 = MultiplyByQuantizedMultiplier_x2_opt(
+ out0_0, out1_0, p_out_multiplier[out_ch + 0],
+ p_out_shift[out_ch + 0]);
@@ -423,70 +182,45 @@
+ p_dst1_3 += out_channels;
+ }
+ }
++ return 0;
++}
++
+ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
+ WORD16* __restrict__ p_out,
+ const WORD16* __restrict__ p_inp,
+@@ -180,6 +352,35 @@ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
+ XA_NNLIB_ARG_CHK_COND((p_out_shift[itr] < -31 || p_out_shift[itr] > 31), -1);
}
- return 0;
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c b/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-index 7f31b75..a010d45 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-+++ b/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-@@ -157,7 +157,7 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- */
- if(input_data && filter_data && output_data && scratch_buffer &&
- (((unsigned int)input_data&0x7)==0) && (((unsigned int)filter_data&0x3)==0) && (((unsigned int)output_data&0x7) == 0) &&
-- (((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0xF)==0) && ((filter_height*filter_width&0x3)==0))
-+ (((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0x3)==0))
- {
- {
- //tbd : batch = 1, need to handle other values and in_x_min/max= 0 .. need toc heck for other values
-@@ -180,7 +180,8 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- filt_y_max = (filt_y_max < filter_height) ? filt_y_max : filter_height;
- filt_y_max = (filt_y_max < 0) ? 0 : filt_y_max;
- pinp = (WORD16*)&input_data[in_y*input_width*input_depth+in_x*input_depth];
-- for (int in_channel = 0; in_channel < input_depth; in_channel+=16)
-+ int in_channel = 0;
-+ for (; in_channel + 15 < input_depth; in_channel+=16)
- {
- ae_int16x4 d_inp, d_inp1, d_inp2, d_inp3;
- AE_L16X4_IP(d_inp, (ae_int16x4*)pinp, sizeof(WORD64));
-@@ -235,36 +236,7 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- }
- }
- }
-- }
-- }
-- }
-- }
-- else if(input_data && filter_data && output_data && scratch_buffer &&
-- (((unsigned int)input_data&0x7)==0) && (((unsigned int)filter_data&0x3)==0) && (((unsigned int)output_data&0x7) == 0) &&
-- (((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0x3)==0) && ((filter_height*filter_width&0x3)==0))
-- {
-- {
-- //tbd : batch = 1, need to handle other values and in_x_min/max= 0 .. need toc heck for other values
-- for (int in_y = 0; in_y < input_height; ++in_y)
-- {
-- for (int in_x = 0; in_x < input_width; ++in_x)
-- {
-- const int out_x_orig = in_x*stride_width - pad_width;
-- const int out_y_orig = in_y*stride_height - pad_height;
-- int filt_x_min = -out_x_orig;
-- int filt_x_max = output_width - out_x_orig;
-- int filt_y_min = -out_y_orig;
-- int filt_y_max = output_height - out_y_orig;
-- filt_x_min = (filt_x_min < filter_width) ? filt_x_min : filter_width;
-- filt_x_min = (filt_x_min < 0) ? 0 : filt_x_min;
-- filt_x_max = (filt_x_max < filter_width) ? filt_x_max : filter_width;
-- filt_x_max = (filt_x_max < 0) ? 0 : filt_x_max;
-- filt_y_min = (filt_y_min < filter_height) ? filt_y_min : filter_height;
-- filt_y_min = (filt_y_min < 0) ? 0 : filt_y_min;
-- filt_y_max = (filt_y_max < filter_height) ? filt_y_max : filter_height;
-- filt_y_max = (filt_y_max < 0) ? 0 : filt_y_max;
-- pinp = (WORD16*)&input_data[in_y*input_width*input_depth+in_x*input_depth];
-- for (int in_channel = 0; in_channel < input_depth; in_channel+=4)
-+ for (; in_channel + 3 < input_depth; in_channel+=4)
- {
- ae_int16x4 d_inp;
- AE_L16X4_IP(d_inp, (ae_int16x4*)pinp, sizeof(WORD64));
---
-2.41.0.162.gfafddb0af9-goog
-
++ if ( !(x_padding) && !(input_channels & 0x3) && !(out_channels & 0x3) && !(out_width & 0x1) && (out_data_format == 0) && ((out_width-1)*x_stride <=(input_width-kernel_width) ) )
++ {
++ int ret_val=0;
++ ret_val=xa_nn_conv2d_std_per_chan_sym8sxsym16s_no_circ_buf(p_out,
++ p_inp,
++ p_kernel,
++ p_bias,
++ input_height,
++ input_width,
++ input_channels,
++ kernel_height,
++ kernel_width,
++ out_channels,
++ x_stride,
++ y_stride,
++ x_padding,
++ y_padding,
++ out_height,
++ out_width,
++ input_zero_bias,
++ p_out_multiplier,
++ p_out_shift,
++ out_zero_bias,
++ out_data_format
++ );
++
++ return ret_val;
++ }
++
+ WORD32 j;
+ WORD32 input_bytewidth = 2;
+ VOID *pp_inp = (VOID *)p_inp;
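The new dispatch in xa_nn_conv2d_std_per_chan_sym8sxsym16s above only takes the zero-copy path when x_padding is zero, the input and output channel counts are multiples of 4, out_width is even, out_data_format is 0, and the last convolution window stays inside the input row; otherwise it falls back to the existing circular-buffer flow. A small self-contained check of that predicate follows, with made-up shape values (everything here other than the predicate itself is illustrative).

```cpp
#include <cstdio>

int main() {
  // Made-up shapes for illustration; not taken from the patch.
  const int x_padding = 0, input_channels = 16, out_channels = 8;
  const int out_width = 8, x_stride = 1, input_width = 10, kernel_width = 3;
  const int out_data_format = 0;

  // Same condition the patched kernel uses to select the zero-copy path.
  const bool zero_copy =
      !x_padding && !(input_channels & 0x3) && !(out_channels & 0x3) &&
      !(out_width & 0x1) && (out_data_format == 0) &&
      ((out_width - 1) * x_stride <= (input_width - kernel_width));

  // Here (8 - 1) * 1 = 7 and 10 - 3 = 7, so the zero-copy path is chosen.
  std::printf("zero-copy path: %s\n", zero_copy ? "yes" : "no");
  return 0;
}
```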
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch
deleted file mode 100644
index 9d95c63..0000000
--- a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch
+++ /dev/null
@@ -1,36 +0,0 @@
-diff --git a/algo/kernels/fc/hifi4/xa_nn_fully_connected.c b/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-index 26a2b73..61f0a64 100644
---- a/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-+++ b/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-@@ -298,7 +298,6 @@ WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s
- XA_NNLIB_ARG_CHK_PTR(p_out, -1);
- XA_NNLIB_ARG_CHK_PTR(p_weight, -1);
- XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
-- XA_NNLIB_ARG_CHK_PTR(p_bias, -1);
- /* Pointer alignment checks */
- #if 0
- XA_NNLIB_ARG_CHK_ALIGN(p_out, ALIGNMENT, -1);
-@@ -310,7 +309,8 @@ WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s
- XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD8), -1);
- XA_NNLIB_ARG_CHK_ALIGN(p_weight, sizeof(WORD8), -1);
- XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(WORD8), -1);
-- XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-+ if (p_bias != NULL)
-+ XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
- #endif
- /* Basic Parameter checks */
- XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1);
-diff --git a/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c b/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-index 5350cbe..a91e043 100644
---- a/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-+++ b/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-@@ -704,7 +704,8 @@ WORD32 xa_nn_matXvec_sym8sxasym8s_asym8s(
- XA_NNLIB_ARG_CHK_PTR(p_mat1, -1);
- XA_NNLIB_ARG_CHK_PTR(p_vec1, -1);
- /* Pointer alignment checks */
-- XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-+ if (p_bias != NULL)
-+ XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
- /* Basic Parameter checks */
- XA_NNLIB_ARG_CHK_COND((rows <= 0), -1);
- XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1);
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
index fb45123..d80855f 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
@@ -44,13 +44,13 @@
fi
if [[ ${2} == "hifi4" ]]; then
- LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_10_14_2022.zip"
+ LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_09_05_2023.zip"
LIBRARY_DIRNAME="xa_nnlib_hifi4"
- LIBRARY_MD5="2bf3c1c7fd5a23f157babc8e24fd2c55"
+ LIBRARY_MD5="2a54e056aef73a4fcffde4643998501a"
elif [[ ${2} == "hifi5" ]]; then
- LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi5/raw/master/archive/xa_nnlib_hifi5_12_19_2022.zip"
+ LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi5/raw/master/archive/xa_nnlib_hifi5_09_05_2023.zip"
LIBRARY_DIRNAME="xa_nnlib_hifi5"
- LIBRARY_MD5="83306809191f42a064bde688b94e1eb1"
+ LIBRARY_MD5="1deb55ef200bf5dbedc70b99b02140c0"
elif [[ ${2} == "vision_p6" ]]; then
LIBRARY_URL="https://github.com/foss-xtensa/tflmlib_vision/raw/main/archive/xi_tflmlib_vision_p6_22_06_29.zip"
LIBRARY_DIRNAME="xi_tflmlib_vision_p6"