Xtensa LSTM: (#2150)

Enabled LSTM kernel support for XTENSA target.
Updated xtensa_downloads script to use the latest HiFi NN Libraries.

The 8x16 unit test cases has non-zero zero_point for 16 bit output.
[https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/kernels/testdata/lstm_test_data.cc#L255C1-L258C61](url)

Default run for all the 8x16 unit test cases result: FAIL. This is due to non-zero output offset value.

BUG=#1867
diff --git a/tensorflow/lite/micro/kernels/lstm_eval_test.cc b/tensorflow/lite/micro/kernels/lstm_eval_test.cc
index 53c0d7c..eaba2c4 100644
--- a/tensorflow/lite/micro/kernels/lstm_eval_test.cc
+++ b/tensorflow/lite/micro/kernels/lstm_eval_test.cc
@@ -454,6 +454,6 @@
                                        cell_state_tolerance,
                                        int16_node_contents);
 }
-
 #endif  // !defined(XTENSA)
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
index c85e56f..ea11afc 100644
--- a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc
@@ -150,7 +150,6 @@
 TF_LITE_MICRO_TESTS_BEGIN
 // TODO(b/230666079) enable below tests for xtensa when the xtensa
 // kernel is reconciled with reference kernel
-#if !defined(XTENSA)
 TF_LITE_MICRO_TEST(TestUnidirectionalLSTMFloat) {
   const tflite::testing::LstmEvalCheckData<12, 4, 12> kernel_eval_data =
       tflite::testing::Get2X2LstmEvalCheckData();
@@ -193,5 +192,4 @@
       kernel_eval_data, hidden_state_tolerance, cell_state_tolerance,
       int16_node_contents);
 }
-#endif  // !defined(XTENSA)
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
index 9065388..af5bad7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,1204 +14,478 @@
 ==============================================================================*/
 #include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
 
-#include <math.h>
-#include <string.h>
+#include <limits>
 
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <vector>
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/compatibility.h"
-#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
+#include "tensorflow/lite/kernels/internal/reference/logistic.h"
+#include "tensorflow/lite/kernels/internal/reference/mul.h"
+#include "tensorflow/lite/kernels/internal/reference/tanh.h"
+#include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
-namespace {
 
-// Calculates a single LSTM gate, int8x8_16 version.
-// Implements the same functionality as CalculateLstmGateFloat.
-void CalculateLstmGateInteger8x8_16(
-    // Input and weights
-    const int8_t* input, const int8_t* input_to_gate_weights,
-    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
-    const int32_t input_to_gate_scale_b,
-    // Output state and weights
-    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
-    const int32_t* recurrent_to_gate_bias,
-    const int32_t recurrent_to_gate_scale_a,
-    const int32_t recurrent_to_gate_scale_b,
-    // Cell state and weights
-    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
-    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
-    // Layer normalization parameters (layer norm LSTM)
-    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
-    const int32_t layer_norm_input_scale_a,
-    const int32_t layer_norm_input_scale_b,
-    const int32_t layer_norm_variance_guard,
-    // Array sizes
-    const int n_batch, const int n_input, const int n_output, const int n_cell,
-    const TfLiteFusedActivation activation,
-    // Output
-    int16_t* gate,
-    // Parameters for performance optimizations
-    // CpuBackendContext* context,
-    // Scratch arrays
-    int32_t* scratch5) {
-  const bool use_peephole = (cell_to_gate_weights != nullptr);
-  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
-
-  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
-  // versions, bias is only used in layer normalization.
-  std::fill_n(gate, n_batch * n_cell, 0);
-#if !defined(HIFI5)
-  // For each batch and cell: compute input_weight * input.
-  tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
-      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
-      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate, NULL);
-#else
-  {
-    xa_nn_matXvec_acc_batch_sym8sx8_asym16s(
-        gate, input_to_gate_weights, input, input_to_gate_bias, n_cell, n_input,
-        n_input, input_to_gate_scale_a, input_to_gate_scale_b, 0, n_batch);
+LstmTensors::LstmTensors(TfLiteContext* context, TfLiteNode* node) {
+  micro_context_ = GetMicroContext(context);
+  // 24 internal tensors. see lstm_shared.h for tensor names
+  for (size_t i = 0; i < 24; i++) {
+    internal_tensors_[i] = micro_context_->AllocateTempInputTensor(node, i);
   }
-#endif  // !defined(HIFI5)
-// Note: no aux_input.
-
-// For each batch and cell: compute recurrent_weight * output_state.
-#if !defined(HIFI5)
-  tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
-      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
-      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
-      n_cell, 0, scratch5, gate, NULL);
-#else
-  {
-    xa_nn_matXvec_acc_batch_sym8sx8_asym16s(
-        gate, recurrent_to_gate_weights, output_state, recurrent_to_gate_bias,
-        n_cell, n_output, n_output, recurrent_to_gate_scale_a,
-        recurrent_to_gate_scale_b, 0, n_batch);
-  }
-#endif  // !defined(HIFI5)
-  // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
-  if (use_peephole) {
-    tensor_utils::PortableVectorBatchVectorCwiseProductAccumulate(
-        cell_to_gate_weights, n_output, cell_state, n_batch,
-        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
-  }
-  // Do layer normalization (if layer norm LSTM)
-  if (use_layer_norm) {
-    tensor_utils::PortableApplyLayerNorm(
-        gate, layer_norm_coefficients, layer_norm_bias,
-        layer_norm_input_scale_a, layer_norm_input_scale_b,
-        layer_norm_variance_guard, n_batch, n_cell, gate);
-  }
-  // Apply activation
-  switch (activation) {
-    case kTfLiteActSigmoid:
-#if !defined(HIFI5)
-      tensor_utils::PortableApplySigmoid(gate, n_batch, n_cell, gate);
-#else
-      xa_nn_vec_sigmoid_16_16(gate, gate, n_batch * n_cell);
-#endif  // !defined(HIFI5)
-      break;
-    case kTfLiteActTanh:
-#if !defined(HIFI5)
-      tensor_utils::PortableApplyTanh(3, gate, n_batch, n_cell, gate);
-#else
-      xa_nn_vec_tanh_16_16(gate, gate, 3, n_batch * n_cell);
-#endif  // !defined(HIFI5)
-      break;
-    default:
-      // Only Sigmoid or Tanh is used.
-      TFLITE_ASSERT_FALSE;
-  }
+  output_tensor_ =
+      micro_context_->AllocateTempOutputTensor(node, kLstmOutputTensor);
 }
 
-// Updates the LSTM cell state, used by both integer LSTM versions.
-// Also see UpdateLstmCellFloat.
-//
-// Parameters:
-//  - n_batch, n_cell: sizes of vectors
-//  - cell_state: input/output vector, size n_batch*n_cell
-//  - cell_state_scale: scaling factor of cell state.
-//  - input_gate: input vector, size n_batch*n_cell.
-//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
-//  - cell_gate: input vector, size n_batch*n_cell.
-//  - use_cifg: use 1-forget_gate instead of input_gate.
-//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
-void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
-                           int32_t cell_state_scale, const int16_t* input_gate,
-                           int16_t* forget_gate, const int16_t* cell_gate,
-                           bool use_cifg, int16_t clip) {
-#if !defined(HIFI5)
-  // Use the forget_gate array as scratch, as input_gate array is not allocated
-  // in CIFG case. (Be careful not to write to the scratch before reading the
-  // forget gate data.)
-  int16_t* scratch = forget_gate;
-
-  tensor_utils::PortableCwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
-                                 cell_state);
-  if (use_cifg) {
-    tensor_utils::PortableSub1Vector(forget_gate, n_batch * n_cell, scratch);
-    tensor_utils::PortableCwiseMul(scratch, cell_gate, n_batch, n_cell,
-                                   30 + cell_state_scale, scratch);
-  } else {
-    tensor_utils::PortableCwiseMul(input_gate, cell_gate, n_batch, n_cell,
-                                   30 + cell_state_scale, scratch);
-  }
-  tensor_utils::PortableCwiseAdd(cell_state, scratch, n_batch, n_cell,
-                                 cell_state);
-
-  if (clip > 0) {
-    tensor_utils::PortableCwiseClipping(cell_state, n_batch * n_cell, clip);
-  }
-#else
-  if (use_cifg) {
-    calc_cell_state_with_cifg(cell_state, forget_gate, cell_gate, 15,
-                              30 + cell_state_scale, clip, n_batch * n_cell);
-  } else {
-    calc_cell_state_without_cifg(cell_state, forget_gate, cell_gate, input_gate,
-                                 15, 30 + cell_state_scale, clip,
-                                 n_batch * n_cell);
-  }
-
-#endif  // !defined(HIFI5)
-}
-
-// Calculates the output state tensor of an LSTM step. See Float and hybrid
-// versions as well.
-//
-// Parameters:
-//  - n_batch: batches: the number of distinct vectors in each array.
-//  - n_cell, n_output: sizes of vectors.
-//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
-//  - cell_state_scale: scaling of cell_state.
-//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
-//  - hidden_zp: zero_point for cell_state.*output_gate
-//  - projection_weights, proj_scale_[a|b], projection_bias:
-//      constant inputs, describing projection matrix and bias.
-//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
-//  - quantized_proj_clip: if > 0, clip the output of the projection.
-//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
-//  - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
-//  - scratch0: scratch area of size n_batch*n_cell
-//  - scratch1: scratch area of size n_batch*n_cell
-//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
-void CalculateLstmOutputInteger8x8_16(
-    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
-    int32_t cell_state_scale, const int16_t* output_gate,
-    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
-    const int8_t* projection_weights, int32_t proj_scale_a,
-    int32_t proj_scale_b, const int32_t* projection_bias,
-    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
-    int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
-// Note: unlike float/hybrid, the activation is always Tanh.
-#if !defined(HIFI5)
-  tensor_utils::PortableApplyTanh(15 + cell_state_scale, cell_state, n_batch,
-                                  n_cell, scratch0);
-#else
-  xa_nn_vec_tanh_16_16(scratch0, cell_state, (15 + cell_state_scale),
-                       n_batch * n_cell);
-#endif  // !defined(HIFI5)
-
-#if !defined(HIFI5)
-  tensor_utils::PortableCwiseMul(output_gate, scratch0, hidden_scale_a,
-                                 hidden_scale_b, n_batch, n_cell, hidden_zp,
-                                 scratch1);
-#else
-  xa_nn_elm_mul_16x16_asym8s(scratch1, output_gate, scratch0, hidden_scale_a,
-                             hidden_scale_b, hidden_zp, n_batch * n_cell);
-#endif  // !defined(HIFI5)
-
-  const bool use_projection = (projection_weights != nullptr);
-
-  if (use_projection) {
-    // Note: no bias like in float/hybrid
-    std::fill_n(output_state, n_batch * n_output, 0);
-    tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
-        scratch1, projection_bias, projection_weights, proj_scale_a,
-        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
-        output_state, NULL);
-    if (quantized_proj_clip > 0) {
-      tensor_utils::PortableCwiseClipping(output_state, n_batch * n_output,
-                                          quantized_proj_clip);
-    }
-  } else {
-    std::copy_n(scratch1, n_batch * n_output, output_state);
-  }
-}
-
-// Calculates a single LSTM gate, int8x8_8 version.
-// Implements the same functionality as CalculateLstmGateFloat.
-void CalculateLstmGateInteger8x8_8(
-    // Inputs and weights
-    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
-    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
-    const int32_t input_times_weights_scale_a,
-    const int32_t input_times_weights_scale_b,
-    const int32_t input_times_weights_zp,
-    // Output state and weights
-    const int8_t* output_state, const int32_t output_state_zp,
-    const int8_t* recurrent_to_gate_weight,
-    const int32_t recurrent_to_gate_scale_a,
-    const int32_t recurrent_to_gate_scale_b,
-    const int32_t output_state_times_weights_scale_a,
-    const int32_t output_state_times_weights_scale_b,
-    const int32_t output_state_times_weights_zp,
-    // Layer normalization parameters (layer norm LSTM)
-    const int16_t* layer_norm_gate_weight,
-    const int32_t layer_norm_gate_scale_a,
-    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
-    // Array sizes
-    const int n_batch, const int n_input, const int n_output, const int n_cell,
-    const TfLiteFusedActivation activation,
-    // Output
-    int16_t* gate,
-    // Scratch arrays, both sized n_batch*n_cell
-    int8_t* scratch0, int8_t* scratch1) {
-  // Multiply input * input_weights => scratch0
-  tensor_utils::PortableMatrixBatchVectorMultiply(
-      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
-      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
-      input_times_weights_zp);
-  // Multiply output_state * recurrent_weights => scratch1
-  tensor_utils::PortableMatrixBatchVectorMultiply(
-      output_state, output_state_zp, recurrent_to_gate_weight,
-      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
-      n_cell, scratch1, output_state_times_weights_zp);
-  // Add scratch0 + scratch1 => gate
-  tensor_utils::PortableTwoGateSaturatingAdd(
-      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
-      input_times_weights_scale_a, input_times_weights_scale_b,
-      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
-      n_batch, n_cell, gate);
-  // Apply layer normalization.
-  tensor_utils::PortableApplyLayerNormFloat(
-      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
-      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
-  // Apply activation.
-  switch (activation) {
-    case kTfLiteActSigmoid:
-      tensor_utils::PortableApplySigmoidFloat(gate, n_batch, n_cell, gate);
-      break;
-    case kTfLiteActTanh:
-      tensor_utils::PortableApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
-      break;
-    default:
-      // Only Sigmoid or Tanh is used.
-      TFLITE_ASSERT_FALSE;
-  }
-}
-
-// Calculates the output state tensor of an LSTM step. See Float and hybrid
-// versions as well.
-//
-// Parameters:
-//  - n_batch: batches: the number of distinct vectors in each array.
-//  - n_cell, n_output: sizes of vectors.
-//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
-//  - projection_weights, proj_scale_[a|b], projection_bias:
-//      constant inputs, describing projection matrix and bias.
-//  - output_state_zp: zero point of the output state.
-//  - quantized_proj_clip: if > 0, clip the output of the projection.
-//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
-//  - scratch: scratch area of size n_batch*n_cell
-void CalculateLstmOutputInteger8x8_8(
-    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
-    const int16_t* output_gate, const int8_t* projection_weights,
-    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
-    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
-    int16_t* scratch) {
-  // Note: unlike float/hybrid, the activation is always Tanh.
-  tensor_utils::PortableApplyTanhFloat(cell_state, n_batch, n_cell, -15,
-                                       scratch);
-  tensor_utils::PortableCwiseMul(output_gate, scratch, n_batch, n_cell,
-                                 15 + 15 - 15, scratch);
-  // Note: no bias like in float/hybrid
-  tensor_utils::PortableMatrixBatchVectorMultiply(
-      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
-      n_batch, n_cell, n_output, output_state_zp, output_state);
-  if (quantized_proj_clip > 0) {
-    tensor_utils::PortableCwiseClipping(output_state, n_batch * n_output,
-                                        (int8_t)quantized_proj_clip);
-  }
-}
-
-// Fully quantized lstm kernel for 16 bit gate matmul output.
-//
-// Input tensor of size n_batch * n_input:
-//   input_ptr
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-//   input_to_input_weight_ptr            - optional
-//   input_to_forget_weight_ptr           - optional
-//   input_to_cell_weight_ptr             - optional
-//   input_to_output_weight_ptr           - optional
-//
-// Quantized recurrent weights of size 'n_cell * n_output':
-//   recurrent_to_input_weight_ptr        - optional
-//   recurrent_to_forget_weights_ptr
-//   recurrent_to_cell_weights_ptr
-//   recurrent_to_input_weights_ptr
-//
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-//   cell_to_input_weights               - optional
-//   cell_to_cell_weights                - optional
-//   cell_to_output_weights              - optional
-//
-// Quantized projection weights of size 'n_output * n_cell'
-//   projection_weight_ptr                     - optional
-//
-// Weight scales (scalars) for each of the weights above.
-//   effective_input_to_input_scale_a    - optional
-//   effective_input_to_input_scale_b    - optional
-//   effective_input_to_forget_scale_a
-//   effective_input_to_forget_scale_b
-//   effective_input_to_cell_scale_a
-//   effective_input_to_cell_scale_b
-//   effective_input_to_output_scale_a
-//   effective_input_to_output_scale_b
-//   effective_recurrent_to_input_scale_a    - optional
-//   effective_recurrent_to_input_scale_b    - optional
-//   effective_recurrent_to_forget_scale_a
-//   effective_recurrent_to_forget_scale_b
-//   effective_recurrent_to_cell_scale_a
-//   effective_recurrent_to_cell_scale_b
-//   effective_recurrent_to_output_scale_a
-//   effective_recurrent_to_output_scale_b
-//   effective_proj_scale_a                  - optional
-//   effective_proj_scale_b                  - optional
-//
-// Gate biases of size 'n_cell':
-//   input_gate_bias_ptr                 - optional
-//   forget_gate_bias_ptr
-//   cell_gate_bias_ptr
-//   output_gate_bias_ptr
-//
-// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
-//   layer_norm_input_weight_ptr    - optional
-//   layer_norm_forget_weight_ptr   - optional
-//   layer_norm_cell_weight_ptr     - optional
-//   layer_norm_output_weight_ptr   - optional
-//
-// Layer norm scales of size 'n_cell'.
-//   layer_norm_input_scale_a     - optional
-//   layer_norm_input_scale_b     - optional
-//   layer_norm_forget_scale_a    - optional
-//   layer_norm_forget_scale_b    - optional
-//   layer_norm_cell_scale_a      - optional
-//   layer_norm_cell_scale_b      - optional
-//   layer_norm_output_scale_a    - optional
-//   layer_norm_output_scale_b    - optional
-//
-// Scalar values:
-//   quantized_cell_clip: quantized clip value for cell.
-//   quantized_proj_clip: quantized clip value for projection.
-//   cell_state_scale: the power of two scale for cell state.
-//
-// Zero points:
-//   output_state_zp: zero point of output state
-//   hidden_zp: zero point for hidden state.
-//
-// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
-// n_batch.
-//   scratch0
-//   scratch1
-//   scratch2
-//   scratch3
-//   scratch4
-//   scratch5: this scratch buffer is created purely for optimizing the
-//              MatrixBatchVectorMultiplyAccumulate.
-//
-// Outputs:
-//   output_state_ptr - size 'n_batch * n_output'
-//   cell_state_ptr   - size 'n_batch * n_cell'
-//   output_ptr       - size 'n_batch * n_output'
-// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
-inline void LstmStepInteger8x8_16(
-    const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
-    int32_t effective_input_to_input_scale_a,
-    int32_t effective_input_to_input_scale_b,
-    const int8_t* input_to_forget_weight_ptr,
-    int32_t effective_input_to_forget_scale_a,
-    int32_t effective_input_to_forget_scale_b,
-    const int8_t* input_to_cell_weight_ptr,
-    int32_t effective_input_to_cell_scale_a,
-    int32_t effective_input_to_cell_scale_b,
-    const int8_t* input_to_output_weight_ptr,
-    int32_t effective_input_to_output_scale_a,
-    int32_t effective_input_to_output_scale_b,
-    const int8_t* recurrent_to_input_weight_ptr,
-    int32_t effective_recurrent_to_input_scale_a,
-    int32_t effective_recurrent_to_input_scale_b,
-    const int8_t* recurrent_to_forget_weight_ptr,
-    int32_t effective_recurrent_to_forget_scale_a,
-    int32_t effective_recurrent_to_forget_scale_b,
-    const int8_t* recurrent_to_cell_weight_ptr,
-    int32_t effective_recurrent_to_cell_scale_a,
-    int32_t effective_recurrent_to_cell_scale_b,
-    const int8_t* recurrent_to_output_weight_ptr,
-    int32_t effective_recurrent_to_output_scale_a,
-    int32_t effective_recurrent_to_output_scale_b,
-    const int16_t* cell_to_input_weight_ptr,
-    int32_t effective_cell_to_input_scale_a,
-    int32_t effective_cell_to_input_scale_b,
-    const int16_t* cell_to_forget_weight_ptr,
-    int32_t effective_cell_to_forget_scale_a,
-    int32_t effective_cell_to_forget_scale_b,
-    const int16_t* cell_to_output_weight_ptr,
-    int32_t effective_cell_to_output_scale_a,
-    int32_t effective_cell_to_output_scale_b,
-    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
-    int32_t effective_proj_scale_b, int32_t hidden_zp,
-    int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
-    const int16_t* layer_norm_input_weight_ptr,
-    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
-    const int16_t* layer_norm_forget_weight_ptr,
-    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
-    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
-    int32_t layer_norm_cell_scale_b,
-    const int16_t* layer_norm_output_weight_ptr,
-    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
-    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
-    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
-    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
-    int32_t cell_state_scale, int32_t input_variance_guard,
-    int32_t forget_variance_guard, int32_t cell_variance_guard,
-    int32_t output_variance_guard,
-    const int32_t* input_to_forget_effective_bias,
-    const int32_t* recurrent_to_forget_effective_bias,
-    const int32_t* input_to_cell_effective_bias,
-    const int32_t* recurrent_to_cell_effective_bias,
-    const int32_t* input_to_output_effective_bias,
-    const int32_t* recurrent_to_output_effective_bias,
-    const int32_t* input_to_input_effective_bias,
-    const int32_t* recurrent_to_input_effective_bias,
-    const int32_t* projection_effective_bias, int n_batch, int n_cell,
-    int n_input, int n_output, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
-    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
-    int8_t* scratch4, int32_t* scratch5) {
-  // ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
-  // Make named scratch buffers for the different gates.
-  int16_t* input_gate_scratch = scratch0;
-  int16_t* forget_gate_scratch = scratch1;
-  int16_t* cell_gate_scratch = scratch2;
-  int16_t* output_gate_scratch = scratch3;
-
-  // Since we have already checked that weights are all there or none, we
-  // can check the existence of only one to the get the condition.
-  const bool use_cifg = (input_to_input_weight_ptr == nullptr);
-
-  // Check for nullptrs.
-  TFLITE_DCHECK(input_to_forget_effective_bias);
-  TFLITE_DCHECK(recurrent_to_forget_effective_bias);
-  TFLITE_DCHECK(input_to_cell_effective_bias);
-  TFLITE_DCHECK(recurrent_to_cell_effective_bias);
-  TFLITE_DCHECK(input_to_output_effective_bias);
-  TFLITE_DCHECK(recurrent_to_output_effective_bias);
-  if (!use_cifg) {
-    TFLITE_DCHECK(input_to_input_effective_bias);
-    TFLITE_DCHECK(recurrent_to_input_effective_bias);
-  }
-  const bool use_projection = (projection_weight_ptr != nullptr);
-  if (use_projection) {
-    TFLITE_DCHECK(projection_effective_bias);
-  }
-  if (!use_cifg) {
-    // Calculate the input gate. (If not CIFG.)
-    CalculateLstmGateInteger8x8_16(
-        input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
-        effective_input_to_input_scale_a, effective_input_to_input_scale_b,
-        output_state_ptr, recurrent_to_input_weight_ptr,
-        recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
-        effective_recurrent_to_input_scale_b, cell_state_ptr,
-        cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
-        effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
-        input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
-        input_variance_guard, n_batch, n_input, n_output, n_cell,
-        kTfLiteActSigmoid, input_gate_scratch, scratch5);
-  }
-  // Calculate the forget gate.
-  CalculateLstmGateInteger8x8_16(
-      input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
-      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
-      output_state_ptr, recurrent_to_forget_weight_ptr,
-      recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
-      effective_recurrent_to_forget_scale_b, cell_state_ptr,
-      cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
-      effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
-      forget_gate_bias_ptr, layer_norm_forget_scale_a,
-      layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
-      n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
-  // Calculate the cell update gate.
-  CalculateLstmGateInteger8x8_16(
-      input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
-      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
-      output_state_ptr, recurrent_to_cell_weight_ptr,
-      recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
-      effective_recurrent_to_cell_scale_b, cell_state_ptr,
-      /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
-      /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
-      cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
-      cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
-      cell_gate_scratch, scratch5);
-  // Update the cell state.
-  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
-                        input_gate_scratch, forget_gate_scratch,
-                        cell_gate_scratch, use_cifg, quantized_cell_clip);
-  // Calculate the output gate.
-  CalculateLstmGateInteger8x8_16(
-      input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
-      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
-      output_state_ptr, recurrent_to_output_weight_ptr,
-      recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
-      effective_recurrent_to_output_scale_b, cell_state_ptr,
-      cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
-      effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
-      output_gate_bias_ptr, layer_norm_output_scale_a,
-      layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
-      n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
-  // Update the output state.
-  CalculateLstmOutputInteger8x8_16(
-      n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
-      output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
-      hidden_zp, projection_weight_ptr, effective_proj_scale_a,
-      effective_proj_scale_b, projection_effective_bias, output_state_zp,
-      quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
-  // Copy output state to the output. Note that unlike float or hybrid, output
-  // is always contiguous.
-  std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
-}
-
-// Fully quantized lstm kernel for 8 bit gate matmul output.
-//
-// Input tensor of size n_batch * n_input:
-//   input_ptr
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-//   input_to_input_weight_ptr            - optional
-//   input_to_forget_weight_ptr           - optional
-//   input_to_cell_weight_ptr             - optional
-//   input_to_output_weight_ptr           - optional
-//
-// Quantized recurrent weights of size 'n_cell * n_output':
-//   recurrent_to_input_weight_ptr        - optional
-//   recurrent_to_forget_weights_ptr
-//   recurrent_to_cell_weights_ptr
-//   recurrent_to_input_weights_ptr
-//
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-//   cell_to_input_weights               - optional
-//   cell_to_cell_weights                - optional
-//   cell_to_output_weights              - optional
-//
-// Quantized projection weights of size 'n_output * n_cell'
-//   projection_weight_ptr                     - optional
-//
-// Weight scales (scalars) for each of the weights above.
-//   effective_input_to_input_scale_a    - optional
-//   effective_input_to_input_scale_b    - optional
-//   effective_input_to_forget_scale_a
-//   effective_input_to_forget_scale_b
-//   effective_input_to_cell_scale_a
-//   effective_input_to_cell_scale_b
-//   effective_input_to_output_scale_a
-//   effective_input_to_output_scale_b
-//   effective_recurrent_to_input_scale_a    - optional
-//   effective_recurrent_to_input_scale_b    - optional
-//   effective_recurrent_to_forget_scale_a
-//   effective_recurrent_to_forget_scale_b
-//   effective_recurrent_to_cell_scale_a
-//   effective_recurrent_to_cell_scale_b
-//   effective_recurrent_to_output_scale_a
-//   effective_recurrent_to_output_scale_b
-//   effective_proj_scale_a                  - optional
-//   effective_proj_scale_b                  - optional
-//
-// Gate biases of size 'n_cell':
-//   input_gate_bias_ptr                 - optional
-//   forget_gate_bias_ptr
-//   cell_gate_bias_ptr
-//   output_gate_bias_ptr
-//
-// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
-//   layer_norm_input_weight_ptr    - optional
-//   layer_norm_forget_weight_ptr   - optional
-//   layer_norm_cell_weight_ptr     - optional
-//   layer_norm_output_weight_ptr   - optional
-//
-// Layer norm scales of size 'n_cell'.
-//   layer_norm_input_scale_a     - optional
-//   layer_norm_input_scale_b     - optional
-//   layer_norm_forget_scale_a    - optional
-//   layer_norm_forget_scale_b    - optional
-//   layer_norm_cell_scale_a      - optional
-//   layer_norm_cell_scale_b      - optional
-//   layer_norm_output_scale_a    - optional
-//   layer_norm_output_scale_b    - optional
-//
-// Scalar values:
-//   quantized_cell_clip: quantized clip value for cell.
-//   quantized_proj_clip: quantized clip value for projection.
-//   cell_state_scale: the power of two scale for cell state.
-//
-// Zero points:
-//   output_state_zp: zero point of output state.
-//   hidden_zp: zero point for hidden state.
-//
-// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
-// n_batch.
-//   scratch0
-//   scratch1
-//   scratch2
-//   scratch3
-//   scratch4
-//   scratch5
-//   scratch6
-//   scratch7
-//
-// Outputs:
-//   output_state_ptr - size 'n_batch * n_output'
-//   cell_state_ptr   - size 'n_batch * n_cell'
-//   output_ptr       - size 'n_batch * n_output'
-// TODO(b/148688698): Move zero point calculation into Prepare().
-// TODO(b/159947023): scratch5 is unused, remove.
-inline void LstmStepInteger8x8_8(
-    const int8_t* input_ptr, int32_t input_zp,
-    const int8_t* input_to_input_weight_ptr,
-    int32_t effective_input_to_input_scale_a,
-    int32_t effective_input_to_input_scale_b,
-    const int8_t* input_to_forget_weight_ptr,
-    int32_t effective_input_to_forget_scale_a,
-    int32_t effective_input_to_forget_scale_b,
-    const int8_t* input_to_cell_weight_ptr,
-    int32_t effective_input_to_cell_scale_a,
-    int32_t effective_input_to_cell_scale_b,
-    const int8_t* input_to_output_weight_ptr,
-    int32_t effective_input_to_output_scale_a,
-    int32_t effective_input_to_output_scale_b,
-    const int8_t* recurrent_to_input_weight_ptr,
-    int32_t effective_recurrent_to_input_scale_a,
-    int32_t effective_recurrent_to_input_scale_b,
-    const int8_t* recurrent_to_forget_weight_ptr,
-    int32_t effective_recurrent_to_forget_scale_a,
-    int32_t effective_recurrent_to_forget_scale_b,
-    const int8_t* recurrent_to_cell_weight_ptr,
-    int32_t effective_recurrent_to_cell_scale_a,
-    int32_t effective_recurrent_to_cell_scale_b,
-    const int8_t* recurrent_to_output_weight_ptr,
-    int32_t effective_recurrent_to_output_scale_a,
-    int32_t effective_recurrent_to_output_scale_b,
-    const int8_t* cell_to_input_weight_ptr,
-    int32_t effective_cell_to_input_scale_a,
-    int32_t effective_cell_to_input_scale_b,
-    const int8_t* cell_to_forget_weight_ptr,
-    int32_t effective_cell_to_forget_scale_a,
-    int32_t effective_cell_to_forget_scale_b,
-    const int8_t* cell_to_output_weight_ptr,
-    int32_t effective_cell_to_output_scale_a,
-    int32_t effective_cell_to_output_scale_b,
-    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
-    int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
-    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
-    const int16_t* layer_norm_forget_weight_ptr,
-    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
-    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
-    int32_t layer_norm_cell_scale_b,
-    const int16_t* layer_norm_output_weight_ptr,
-    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
-    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
-    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
-    const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
-    const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
-    const int32_t* intermediate_zp, int16_t quantized_cell_clip,
-    int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
-    int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
-    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
-    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
-    int16_t* scratch7) {
-  // TODO(b/159066113): scratch5 is unused, remove.
-
-  // ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
-  // Make named scratch buffers for the different gates.
-  int16_t* forget_gate_scratch = scratch2;
-  int16_t* cell_gate_scratch = scratch3;
-  int16_t* output_gate_scratch = scratch4;
-  // no-CIFG is not supported here
-
-  // Calculate the forget gate.
-  CalculateLstmGateInteger8x8_8(
-      input_ptr, input_zp, input_to_forget_weight_ptr,
-      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
-      intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
-      output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
-      effective_recurrent_to_forget_scale_a,
-      effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
-      intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
-      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
-      forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
-      kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
-  // Calculate the cell update gate.
-  CalculateLstmGateInteger8x8_8(
-      input_ptr, input_zp, input_to_cell_weight_ptr,
-      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
-      intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
-      output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
-      effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
-      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
-      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
-      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
-      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
-  // Update the cell state.
-  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
-                        /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
-                        forget_gate_scratch, cell_gate_scratch,
-                        /*use_cifg=*/true, quantized_cell_clip);
-  // Calculate the output gate.
-  CalculateLstmGateInteger8x8_8(
-      input_ptr, input_zp, input_to_output_weight_ptr,
-      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
-      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
-      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
-      effective_recurrent_to_output_scale_a,
-      effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
-      intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
-      layer_norm_output_scale_a, layer_norm_output_scale_b,
-      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
-      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
-  // Update the output state.
-  CalculateLstmOutputInteger8x8_8(
-      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
-      projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
-      projection_bias_ptr, output_state_zp, quantized_proj_clip,
-      output_state_ptr, scratch2);
-  // Copy output state to the output. Note that unlike float or hybrid, output
-  // is always contiguous.
-  std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
-}
-
-}  // namespace
-
-// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
-TfLiteStatus EvalInteger8x8_16(
-    TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* input_to_input_weights,
-    const TfLiteEvalTensor* input_to_forget_weights,
-    const TfLiteEvalTensor* input_to_cell_weights,
-    const TfLiteEvalTensor* input_to_output_weights,
-    const TfLiteEvalTensor* recurrent_to_input_weights,
-    const TfLiteEvalTensor* recurrent_to_forget_weights,
-    const TfLiteEvalTensor* recurrent_to_cell_weights,
-    const TfLiteEvalTensor* recurrent_to_output_weights,
-    const TfLiteEvalTensor* cell_to_input_weights,
-    const TfLiteEvalTensor* cell_to_forget_weights,
-    const TfLiteEvalTensor* cell_to_output_weights,
-    const TfLiteEvalTensor* input_layer_norm_coefficients,
-    const TfLiteEvalTensor* forget_layer_norm_coefficients,
-    const TfLiteEvalTensor* cell_layer_norm_coefficients,
-    const TfLiteEvalTensor* output_layer_norm_coefficients,
-    const TfLiteEvalTensor* input_gate_bias,
-    const TfLiteEvalTensor* forget_gate_bias,
-    const TfLiteEvalTensor* cell_gate_bias,
-    const TfLiteEvalTensor* output_gate_bias,
-    const TfLiteEvalTensor* projection_weights,
-    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
-    bool forward_sequence, bool time_major,
-    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
-    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
-    TfLiteEvalTensor* output, TfLiteEvalTensor* scratch0,
-    TfLiteEvalTensor* scratch1, TfLiteEvalTensor* scratch2,
-    TfLiteEvalTensor* scratch3, TfLiteEvalTensor* scratch4,
-    TfLiteEvalTensor* scratch5) {
-  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
-  const int n_input = input->dims->data[input->dims->size - 1];
-  int max_time, n_batch;
-  if (input->dims->size == 2) {
-    max_time = 1;
-    n_batch = input->dims->data[0];
-  } else {
-    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
-    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
-  }
-
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_cell = input_to_output_weights->dims->data[0];
-  const int n_output = recurrent_to_output_weights->dims->data[1];
-
-  // Activation zero point
-  //  TODO@is data.output_zero_point equal to output_state->params.zero_point
-  // int output_state_zp = output_state->params.zero_point;
-  int output_state_zp = 0;
-
-  // Get params for time/batch/sequence.
-  const int output_batch_leading_dim =
-      output->dims->data[output->dims->size - 1];
-
-  if (time_major) {
-    const int input_step = n_batch * n_input;
-    const int output_step = n_batch * output_batch_leading_dim;
-    for (int t = 0; t < max_time; t++) {
-      const int t_rel = t;
-      int8_t* output_ptr =
-          tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
-      const int8_t* input_ptr =
-          tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
-      LstmStepInteger8x8_16(
-          input_ptr,
-          tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
-          integer_lstm_param->effective_input_to_input_scale_a,
-          integer_lstm_param->effective_input_to_input_scale_b,
-          tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
-          integer_lstm_param->effective_input_to_forget_scale_a,
-          integer_lstm_param->effective_input_to_forget_scale_b,
-          tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
-          integer_lstm_param->effective_input_to_cell_scale_a,
-          integer_lstm_param->effective_input_to_cell_scale_b,
-          tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
-          integer_lstm_param->effective_input_to_output_scale_a,
-          integer_lstm_param->effective_input_to_output_scale_b,
-          tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
-          integer_lstm_param->effective_recurrent_to_input_scale_a,
-          integer_lstm_param->effective_recurrent_to_input_scale_b,
-          tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
-          integer_lstm_param->effective_recurrent_to_forget_scale_a,
-          integer_lstm_param->effective_recurrent_to_forget_scale_b,
-          tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
-          integer_lstm_param->effective_recurrent_to_cell_scale_a,
-          integer_lstm_param->effective_recurrent_to_cell_scale_b,
-          tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
-          integer_lstm_param->effective_recurrent_to_output_scale_a,
-          integer_lstm_param->effective_recurrent_to_output_scale_b,
-          tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
-          integer_lstm_param->effective_cell_to_input_scale_a,
-          integer_lstm_param->effective_cell_to_input_scale_b,
-          tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
-          integer_lstm_param->effective_cell_to_forget_scale_a,
-          integer_lstm_param->effective_cell_to_forget_scale_b,
-          tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
-          integer_lstm_param->effective_cell_to_output_scale_a,
-          integer_lstm_param->effective_cell_to_output_scale_b,
-          tflite::micro::GetTensorData<int8_t>(projection_weights),
-          integer_lstm_param->effective_proj_scale_a,
-          integer_lstm_param->effective_proj_scale_b,
-          integer_lstm_param->hidden_zp,
-          integer_lstm_param->effective_hidden_scale_a,
-          integer_lstm_param->effective_hidden_scale_b,
-          tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
-          integer_lstm_param->layer_norm_input_scale_a,
-          integer_lstm_param->layer_norm_input_scale_b,
-          tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
-          integer_lstm_param->layer_norm_forget_scale_a,
-          integer_lstm_param->layer_norm_forget_scale_b,
-          tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
-          integer_lstm_param->layer_norm_cell_scale_a,
-          integer_lstm_param->layer_norm_cell_scale_b,
-          tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
-          integer_lstm_param->layer_norm_output_scale_a,
-          integer_lstm_param->layer_norm_output_scale_b,
-          tflite::micro::GetTensorData<int32_t>(input_gate_bias),
-          tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
-          tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
-          tflite::micro::GetTensorData<int32_t>(output_gate_bias),
-          integer_lstm_param->quantized_cell_clip,
-          integer_lstm_param->quantized_proj_clip,
-          integer_lstm_param->cell_scale,
-          integer_lstm_param->input_variance_guard,
-          integer_lstm_param->forget_variance_guard,
-          integer_lstm_param->cell_variance_guard,
-          integer_lstm_param->output_variance_guard,
-          integer_lstm_param->input_to_forget_effective_bias.get(),
-          integer_lstm_param->recurrent_to_forget_effective_bias.get(),
-          integer_lstm_param->input_to_cell_effective_bias.get(),
-          integer_lstm_param->recurrent_to_cell_effective_bias.get(),
-          integer_lstm_param->input_to_output_effective_bias.get(),
-          integer_lstm_param->recurrent_to_output_effective_bias.get(),
-          integer_lstm_param->input_to_input_effective_bias.get(),
-          integer_lstm_param->recurrent_to_input_effective_bias.get(),
-          integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
-          n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
-          output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
-          output_ptr, (int16_t*)(scratch0), (int16_t*)(scratch1),
-          (int16_t*)(scratch2), (int16_t*)(scratch3), (int8_t*)(scratch4),
-          (int32_t*)(scratch5));
-    }
-  } else {
-    for (int b = 0; b < n_batch; b++) {
-      const int input_step = n_input;
-      const int output_step = output_batch_leading_dim;
-      for (int t = 0; t < max_time; t++) {
-        // If this is the forward_sequence, step forward, otherwise step
-        // backwards.
-        const int t_rel = forward_sequence ? t : max_time - t - 1;
-        const int time_offset = b * max_time + t_rel;
-        const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
-                                  time_offset * input_step;
-        int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
-                             time_offset * output_step;
-
-        // Offset the {output,cell}_state pointers to the right batch.
-        int8_t* output_state_ptr =
-            tflite::micro::GetTensorData<int8_t>(output_state) +
-            b * output_batch_leading_dim;
-        int16_t* cell_state_ptr =
-            tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
-
-        LstmStepInteger8x8_16(
-            input_ptr,
-            tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
-            integer_lstm_param->effective_input_to_input_scale_a,
-            integer_lstm_param->effective_input_to_input_scale_b,
-            tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
-            integer_lstm_param->effective_input_to_forget_scale_a,
-            integer_lstm_param->effective_input_to_forget_scale_b,
-            tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
-            integer_lstm_param->effective_input_to_cell_scale_a,
-            integer_lstm_param->effective_input_to_cell_scale_b,
-            tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
-            integer_lstm_param->effective_input_to_output_scale_a,
-            integer_lstm_param->effective_input_to_output_scale_b,
-            tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
-            integer_lstm_param->effective_recurrent_to_input_scale_a,
-            integer_lstm_param->effective_recurrent_to_input_scale_b,
-            tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
-            integer_lstm_param->effective_recurrent_to_forget_scale_a,
-            integer_lstm_param->effective_recurrent_to_forget_scale_b,
-            tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
-            integer_lstm_param->effective_recurrent_to_cell_scale_a,
-            integer_lstm_param->effective_recurrent_to_cell_scale_b,
-            tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
-            integer_lstm_param->effective_recurrent_to_output_scale_a,
-            integer_lstm_param->effective_recurrent_to_output_scale_b,
-            tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
-            integer_lstm_param->effective_cell_to_input_scale_a,
-            integer_lstm_param->effective_cell_to_input_scale_b,
-            tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
-            integer_lstm_param->effective_cell_to_forget_scale_a,
-            integer_lstm_param->effective_cell_to_forget_scale_b,
-            tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
-            integer_lstm_param->effective_cell_to_output_scale_a,
-            integer_lstm_param->effective_cell_to_output_scale_b,
-            tflite::micro::GetTensorData<int8_t>(projection_weights),
-            integer_lstm_param->effective_proj_scale_a,
-            integer_lstm_param->effective_proj_scale_b,
-            integer_lstm_param->hidden_zp,
-            integer_lstm_param->effective_hidden_scale_a,
-            integer_lstm_param->effective_hidden_scale_b,
-            tflite::micro::GetTensorData<int16_t>(
-                input_layer_norm_coefficients),
-            integer_lstm_param->layer_norm_input_scale_a,
-            integer_lstm_param->layer_norm_input_scale_b,
-            tflite::micro::GetTensorData<int16_t>(
-                forget_layer_norm_coefficients),
-            integer_lstm_param->layer_norm_forget_scale_a,
-            integer_lstm_param->layer_norm_forget_scale_b,
-            tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
-            integer_lstm_param->layer_norm_cell_scale_a,
-            integer_lstm_param->layer_norm_cell_scale_b,
-            tflite::micro::GetTensorData<int16_t>(
-                output_layer_norm_coefficients),
-            integer_lstm_param->layer_norm_output_scale_a,
-            integer_lstm_param->layer_norm_output_scale_b,
-            tflite::micro::GetTensorData<int32_t>(input_gate_bias),
-            tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
-            tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
-            tflite::micro::GetTensorData<int32_t>(output_gate_bias),
-            integer_lstm_param->quantized_cell_clip,
-            integer_lstm_param->quantized_proj_clip,
-            integer_lstm_param->cell_scale,
-            integer_lstm_param->input_variance_guard,
-            integer_lstm_param->forget_variance_guard,
-            integer_lstm_param->cell_variance_guard,
-            integer_lstm_param->output_variance_guard,
-            integer_lstm_param->input_to_forget_effective_bias.get(),
-            integer_lstm_param->recurrent_to_forget_effective_bias.get(),
-            integer_lstm_param->input_to_cell_effective_bias.get(),
-            integer_lstm_param->recurrent_to_cell_effective_bias.get(),
-            integer_lstm_param->input_to_output_effective_bias.get(),
-            integer_lstm_param->recurrent_to_output_effective_bias.get(),
-            integer_lstm_param->input_to_input_effective_bias.get(),
-            integer_lstm_param->recurrent_to_input_effective_bias.get(),
-            integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
-            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
-            cell_state_ptr, output_ptr, (int16_t*)(scratch0),
-            (int16_t*)(scratch1), (int16_t*)(scratch2), (int16_t*)(scratch3),
-            (int8_t*)(scratch4), (int32_t*)(scratch5));
-      }
+LstmTensors::~LstmTensors() {
+  for (size_t i = 0; i < 24; i++) {
+    if (internal_tensors_[i] != nullptr) {
+      micro_context_->DeallocateTempTfLiteTensor(internal_tensors_[i]);
     }
   }
+  micro_context_->DeallocateTempTfLiteTensor(output_tensor_);
+}
 
+// Verify the LSTM internal tensor properties (e.g., type checks)
+// Input/output/states/fc weights tensors are required for kernel evaulation.
+// The state tensors should be variables. Variants of the standard LSTM
+// are not supported here, therefore their corresponding tensors should be
+// invalid
+TfLiteStatus LstmTensors::ValidateTensorStatus(TfLiteContext* context) const {
+  // Verify certain tensor properties
+  // input tensor
+  TF_LITE_ENSURE(context, internal_tensors_[kLstmInputTensor] != nullptr);
+  // hidden state
+  TF_LITE_ENSURE(context, internal_tensors_[kLstmOutputStateTensor] != nullptr);
+  TF_LITE_ENSURE(context,
+                 internal_tensors_[kLstmOutputStateTensor]->is_variable);
+  // hidden state becomes input so they must have the same type
+  TF_LITE_ENSURE_EQ(context, internal_tensors_[kLstmOutputStateTensor]->type,
+                    internal_tensors_[kLstmInputTensor]->type);
+  // cell state
+  TF_LITE_ENSURE(context, internal_tensors_[kLstmCellStateTensor] != nullptr);
+  TF_LITE_ENSURE(context, internal_tensors_[kLstmCellStateTensor]->is_variable);
+  // output
+  TF_LITE_ENSURE(context, output_tensor_ != nullptr);
+  // output type is the same as the input type (activations)
+  TF_LITE_ENSURE_EQ(context, output_tensor_->type,
+                    internal_tensors_[kLstmInputTensor]->type);
+
+  // weight tensors (1-9, see lstm_shared for index definition)
+  const auto weight_type =
+      internal_tensors_[kLstmInputToForgetWeightsTensor]->type;
+  for (size_t i = 1; i < 9; i++) {
+    TF_LITE_ENSURE(context, internal_tensors_[i] != nullptr);
+    TF_LITE_ENSURE_EQ(context, internal_tensors_[i]->type, weight_type);
+  }
+
+  // bias tensors (12-15, see lstm_shared for index definition)
+  const auto bias_type = internal_tensors_[kLstmForgetGateBiasTensor]->type;
+  for (size_t i = 12; i < 16; i++) {
+    TF_LITE_ENSURE(context, internal_tensors_[i] != nullptr);
+    TF_LITE_ENSURE_EQ(context, internal_tensors_[i]->type, bias_type);
+  }
+  // Tensors from LSTM variants are invalid
+  // No peephole
+  for (size_t i = 9; i < 12; i++) {
+    TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+  }
+  // No projection
+  for (size_t i = 16; i < 18; i++) {
+    TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+  }
+  // No internal layer norm
+  for (size_t i = 20; i < 24; i++) {
+    TF_LITE_ENSURE(context, internal_tensors_[i] == nullptr);
+  }
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalInteger8x8_8(
-    const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* input_to_input_weights,
-    const TfLiteEvalTensor* input_to_forget_weights,
-    const TfLiteEvalTensor* input_to_cell_weights,
-    const TfLiteEvalTensor* input_to_output_weights,
-    const TfLiteEvalTensor* recurrent_to_input_weights,
-    const TfLiteEvalTensor* recurrent_to_forget_weights,
-    const TfLiteEvalTensor* recurrent_to_cell_weights,
-    const TfLiteEvalTensor* recurrent_to_output_weights,
-    const TfLiteEvalTensor* cell_to_input_weights,
-    const TfLiteEvalTensor* cell_to_forget_weights,
-    const TfLiteEvalTensor* cell_to_output_weights,
-    const TfLiteEvalTensor* input_layer_norm_coefficients,
-    const TfLiteEvalTensor* forget_layer_norm_coefficients,
-    const TfLiteEvalTensor* cell_layer_norm_coefficients,
-    const TfLiteEvalTensor* output_layer_norm_coefficients,
-    const TfLiteEvalTensor* input_gate_bias,
-    const TfLiteEvalTensor* forget_gate_bias,
-    const TfLiteEvalTensor* cell_gate_bias,
-    const TfLiteEvalTensor* output_gate_bias,
-    const TfLiteEvalTensor* projection_weights,
-    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
-    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
-    TfLiteEvalTensor* output,
-    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
-    TfLiteEvalTensor* scratch0, TfLiteEvalTensor* scratch1,
-    TfLiteEvalTensor* scratch2, TfLiteEvalTensor* scratch3,
-    TfLiteEvalTensor* scratch4, TfLiteEvalTensor* scratch5,
-    TfLiteEvalTensor* scratch6, TfLiteEvalTensor* scratch7) {
-  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
-  const int n_input = input->dims->data[input->dims->size - 1];
-  int max_time, n_batch;
-  if (input->dims->size == 2) {
-    max_time = 1;
-    n_batch = input->dims->data[0];
-  } else {
-    max_time = input->dims->data[0];
-    n_batch = input->dims->data[1];
+namespace lstm_internal {
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+const int32_t kInt16Max = std::numeric_limits<int16_t>::max();
+const int32_t kInt16Min = std::numeric_limits<int16_t>::min();
+#endif
+
+void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
+                    int n_input, int16_t* output) {
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      int32_t sum = input_1[index] + input_2[index];
+      const int32_t sum_clamped = std::min(kInt16Max, std::max(kInt16Min, sum));
+      output[index] = static_cast<int16_t>(sum_clamped);
+    }
   }
-
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_cell = input_to_output_weights->dims->data[0];
-  const int n_output = recurrent_to_output_weights->dims->data[1];
-  //@TODO input zero point and output zeropoint
-  // const int32_t input_zp = input->params.zero_point;
-  /// const int32_t output_state_zp = output_state->params.zero_point;
-  const int32_t input_zp = 0;
-  const int32_t output_state_zp = 0;
-
-  // Get params for time/batch/sequence.
-  const int output_batch_leading_dim =
-      output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-
-  for (int t = 0; t < max_time; t++) {
-    const int t_rel = t;
-    int8_t* output_ptr =
-        tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
-    // Input can be int8 asymmetric or int16 symmetric.
-    const int8_t* input_ptr =
-        tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
-    lstm_eval::LstmStepInteger8x8_8(
-        input_ptr, input_zp,
-
-        tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
-        integer_lstm_param->effective_input_to_input_scale_a,
-        integer_lstm_param->effective_input_to_input_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
-        integer_lstm_param->effective_input_to_forget_scale_a,
-        integer_lstm_param->effective_input_to_forget_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
-        integer_lstm_param->effective_input_to_cell_scale_a,
-        integer_lstm_param->effective_input_to_cell_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
-        integer_lstm_param->effective_input_to_output_scale_a,
-        integer_lstm_param->effective_input_to_output_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
-        integer_lstm_param->effective_recurrent_to_input_scale_a,
-        integer_lstm_param->effective_recurrent_to_input_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
-        integer_lstm_param->effective_recurrent_to_forget_scale_a,
-        integer_lstm_param->effective_recurrent_to_forget_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
-        integer_lstm_param->effective_recurrent_to_cell_scale_a,
-        integer_lstm_param->effective_recurrent_to_cell_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
-        integer_lstm_param->effective_recurrent_to_output_scale_a,
-        integer_lstm_param->effective_recurrent_to_output_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
-        integer_lstm_param->effective_cell_to_input_scale_a,
-        integer_lstm_param->effective_cell_to_input_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
-        integer_lstm_param->effective_cell_to_forget_scale_a,
-        integer_lstm_param->effective_cell_to_forget_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
-        integer_lstm_param->effective_cell_to_output_scale_a,
-        integer_lstm_param->effective_cell_to_output_scale_b,
-
-        tflite::micro::GetTensorData<int8_t>(projection_weights),
-        integer_lstm_param->effective_proj_scale_a,
-        integer_lstm_param->effective_proj_scale_b,
-
-        tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
-        integer_lstm_param->layer_norm_input_scale_a,
-        integer_lstm_param->layer_norm_input_scale_b,
-
-        tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
-        integer_lstm_param->layer_norm_forget_scale_a,
-        integer_lstm_param->layer_norm_forget_scale_b,
-
-        tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
-        integer_lstm_param->layer_norm_cell_scale_a,
-        integer_lstm_param->layer_norm_cell_scale_b,
-
-        tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
-        integer_lstm_param->layer_norm_output_scale_a,
-        integer_lstm_param->layer_norm_output_scale_b,
-
-        tflite::micro::GetTensorData<int32_t>(input_gate_bias),
-        tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
-        tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
-        tflite::micro::GetTensorData<int32_t>(output_gate_bias),
-        tflite::micro::GetTensorData<int32_t>(projection_bias),
-
-        params, integer_lstm_param->intermediate_scale_a,
-        integer_lstm_param->intermediate_scale_b,
-        integer_lstm_param->intermediate_zp,
-        integer_lstm_param->quantized_cell_clip,
-        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
-        n_output, output_batch_leading_dim,
-        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
-        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr,
-        tflite::micro::GetTensorData<int8_t>(scratch0),
-        tflite::micro::GetTensorData<int8_t>(scratch1),
-        tflite::micro::GetTensorData<int16_t>(scratch2),
-        tflite::micro::GetTensorData<int16_t>(scratch3),
-        tflite::micro::GetTensorData<int16_t>(scratch4),
-        tflite::micro::GetTensorData<int16_t>(scratch5),
-        tflite::micro::GetTensorData<int16_t>(scratch6),
-        tflite::micro::GetTensorData<int16_t>(scratch7));
-  }
-
-  return kTfLiteOk;
+#else
+  WORD32 err;
+  err = xa_nn_elm_add_16x16_16(output, input_1, input_2, n_batch * n_input);
+#endif
 }
 
-}  // namespace lstm_eval
-}  // namespace micro
-}  // namespace ops
+void AddElementWise(const float* input_1, const float* input_2, int n_batch,
+                    int n_input, float* output) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    for (int i = 0; i < n_input; ++i) {
+      const int index = batch * n_input + i;
+      output[index] = input_1[index] + input_2[index];
+    }
+  }
+}
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(const RuntimeShape& data_shape, int16_t* data) {
+  reference_integer_ops::Logistic(
+      0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
+      data_shape.FlatSize() /*NumElements(input->dims)*/,
+      data /* tflite::micro::GetTensorData<int16_t>(input) */,
+      data /*tflite::micro::GetTensorData<int16_t>(output) */);
+}
+
+void Sigmoid(const RuntimeShape& data_shape, float* data) {
+  reference_ops::Logistic(data_shape, data, data_shape, data);
+}
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+          int16_t* input_data, const RuntimeShape& output_data_shape,
+          int16_t* output_data) {
+  int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
+  int32_t input_multiplier = 0;
+  if (tanh_input_left_shift < 0) /* handling negative shift value */
+  {
+    tanh_input_left_shift = -tanh_input_left_shift;
+    input_multiplier = 3;
+  }
+  reference_integer_ops::Tanh(input_multiplier, tanh_input_left_shift,
+                              input_data_shape, input_data, output_data_shape,
+                              output_data);
+}
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+          float* input_data, const RuntimeShape& output_data_shape,
+          float* output_data) {
+  reference_ops::Tanh(input_data_shape, input_data, output_data_shape,
+                      output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const int16_t* input1_data, const int16_t* input2_data,
+         int8_t* output_data) {
+  return reference_integer_ops::MulElementwise(
+      shape.FlatSize(), params, input1_data, input2_data, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const int16_t* input1_data, const int16_t* input2_data,
+         int16_t* output_data) {
+  return reference_integer_ops::MulElementwise(
+      shape.FlatSize(), params, input1_data, input2_data, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const float* input1_data, const float* input2_data,
+         float* output_data) {
+  return reference_ops::Mul(params, shape, input1_data, shape, input2_data,
+                            shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int32_t* bias_data,
+                    const RuntimeShape& output_shape, int16_t* output_data) {
+  return tflite::reference_integer_ops::FullyConnected(
+      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+      bias_data, output_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int16_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int64_t* bias_data,
+                    const RuntimeShape& output_shape, int16_t* output_data) {
+  return tflite::reference_integer_ops::FullyConnected(
+      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+      bias_data, output_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& filter_shape, const float* filter_data,
+                    const RuntimeShape& bias_shape, const float* bias_data,
+                    const RuntimeShape& output_shape, float* output_data) {
+  return tflite::reference_ops::FullyConnected(
+      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+      bias_data, output_shape, output_data);
+}
+#else  // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(int16_t* data, int32_t data_size) {
+  WORD32 err;
+  err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, data_size);
+}
+
+void Sigmoid(float* data, int32_t data_size) {
+  int data_dims[2] = {1, data_size};
+  RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
+  reference_ops::Logistic(data_shape, data, data_shape, data);
+}
+
+void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
+          int16_t* output_data, int32_t data_size) {
+  int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
+  int32_t input_multiplier = 0;
+  if (tanh_input_left_shift < 0) /* handling negative shift value */
+  {
+    tanh_input_left_shift = -tanh_input_left_shift;
+#if (defined(USE_HIFI_ACT_TIE) && \
+     (defined(AE_TANH16X4X2) || defined(AE_TANH16X4)))
+    input_multiplier = 1;
+#else
+    input_multiplier = 3;
+#endif
+  }
+  WORD32 err;
+  err = xa_nn_vec_tanh_sym16s_sym16s(output_data, input_data, input_multiplier,
+                                     tanh_input_left_shift, data_size);
+}
+
+void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
+          int32_t data_size) {
+  int data_dims[2] = {1, data_size};
+  RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
+  reference_ops::Tanh(data_shape, input_data, data_shape, output_data);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+         const int16_t* input2_data, int8_t* output_data, int32_t data_size) {
+  WORD32 err;
+  err = xa_nn_elm_mul_sym16sxsym16s_asym8s(
+      output_data, params.output_offset, params.output_shift,
+      params.output_multiplier, params.quantized_activation_min,
+      params.quantized_activation_max, input1_data, input2_data, data_size);
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+         const int16_t* input2_data, int16_t* output_data, int32_t data_size) {
+  int dims_4D[4] = {1, 1, 1, data_size};
+  WORD32 err;
+  err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(
+      output_data, dims_4D, params.output_shift, params.output_multiplier,
+      params.quantized_activation_min, params.quantized_activation_max,
+      input1_data, dims_4D, input2_data, dims_4D);
+  return;
+}
+
+// Input and output have the same shape in LSTM
+void Mul(const ArithmeticParams& params, const float* input1_data,
+         const float* input2_data, float* output_data, int32_t data_size) {
+  int dims_2D[2] = {1, data_size};
+  RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(dims_2D));
+  return reference_ops::Mul(params, data_shape, input1_data, data_shape,
+                            input2_data, data_shape, output_data);
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const int8_t* input_data, const int8_t* filter_data,
+                    const int32_t* bias_data, int16_t* output_data,
+                    const int num_batches, const int output_depth,
+                    const int accum_depth) {
+  WORD32 err;
+#pragma loop_count min = 1
+  for (int b = 0; b < num_batches; b++) {
+    err = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
+        output_data + b * output_depth, filter_data,
+        input_data + b * accum_depth, bias_data, output_depth, accum_depth,
+        accum_depth, 1, params.input_offset, params.output_multiplier,
+        params.output_shift);
+  }
+  return;
+}
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const int16_t* input_data, const int8_t* filter_data,
+                    const int64_t* bias_data, int16_t* output_data,
+                    const int num_batches, const int output_depth,
+                    const int accum_depth) {
+  WORD32 err;
+
+  err = xa_nn_matmul_sym8sxsym16s_sym16s(
+      output_data, filter_data, input_data, bias_data, output_depth,
+      accum_depth, accum_depth, num_batches, accum_depth, output_depth, 1,
+      params.input_offset, params.output_multiplier, params.output_shift,
+      params.output_offset);
+  return;
+}
+
+void FullyConnected(const FullyConnectedParams& params, const float* input_data,
+                    const float* filter_data, const float* bias_data,
+                    float* output_data, const int num_batches,
+                    const int output_depth, const int accum_depth) {
+  int input_dims[2] = {num_batches, output_depth};
+  RuntimeShape input_shape(2, reinterpret_cast<const int32_t*>(input_dims));
+  RuntimeShape bias_shape(1, bias_data == NULL ? 0 : output_depth);
+  int filter_dims[2] = {output_depth, accum_depth};
+  RuntimeShape filter_shape(2, reinterpret_cast<const int32_t*>(filter_dims));
+  int output_dims[2] = {num_batches, output_depth};
+  RuntimeShape output_shape(2, reinterpret_cast<const int32_t*>(output_dims));
+  return tflite::reference_ops::FullyConnected(
+      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+      bias_data, output_shape, output_data);
+}
+#endif  // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+              int16_t* vector) {
+  for (int i = 0; i < v_size; i++) {
+    vector[i] =
+        std::max(std::min(cell_state_info.quantized_cell_clip, vector[i]),
+                 static_cast<int16_t>(-cell_state_info.quantized_cell_clip));
+  }
+}
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+              float* vector) {
+  for (int i = 0; i < v_size; i++) {
+    vector[i] = std::max(std::min(cell_state_info.cell_clip, vector[i]),
+                         -cell_state_info.cell_clip);
+  }
+}
+
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+void UpdateLstmCell(const LstmStepManager& step_info,
+                    TfLiteEvalTensor* cell_state,
+                    // Gate outputs
+                    int16_t* forget_gate_output,
+                    const int16_t* input_gate_output,
+                    const int16_t* cell_gate_output,
+                    // Mul parameters
+                    const ArithmeticParams& forget_cell_mul_params,
+                    const ArithmeticParams& input_mul_params,
+                    const CellStateInfo& cell_state_info, int16_t* buffer) {
+  auto cell_state_shape = step_info.StateShape();
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(step_info.CellStateOffset() + cell_state_shape.FlatSize(),
+                   tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+  WORD32 err;
+  // Multiplier is equivalent to 0.5 here so adding 1 to shifts
+  err = xa_nn_lstm_cell_state_update_16(
+      tflite::micro::GetTensorData<int16_t>(cell_state) +
+          step_info.CellStateOffset(),
+      forget_gate_output, cell_gate_output, input_gate_output,
+      forget_cell_mul_params.output_shift - 1,
+      input_mul_params.output_shift - 1, cell_state_info.quantized_cell_clip,
+      cell_state_shape.FlatSize());
+}
+
+void UpdateLstmCell(const LstmStepManager& step_info,
+                    TfLiteEvalTensor* cell_state,
+                    // Gate outputs
+                    float* forget_gate_output, const float* input_gate_output,
+                    const float* cell_gate_output,
+                    // Mul parameters
+                    const ArithmeticParams& forget_cell_mul_params,
+                    const ArithmeticParams& input_mul_params,
+                    const CellStateInfo& cell_state_info, float* buffer) {
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(
+      step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+  auto cell_state_shape = step_info.StateShape();
+  // Forget Gate x Cell State
+  Mul(forget_cell_mul_params, forget_gate_output,
+      tflite::micro::GetTensorData<float>(cell_state) +
+          step_info.CellStateOffset(),
+      tflite::micro::GetTensorData<float>(cell_state) +
+          step_info.CellStateOffset(),
+      cell_state_shape.FlatSize());
+  // Input Gate x Cell Gate
+  Mul(input_mul_params, input_gate_output, cell_gate_output, buffer,
+      cell_state_shape.FlatSize());
+
+  // Update the cell state
+  AddElementWise(tflite::micro::GetTensorData<float>(cell_state) +
+                     step_info.CellStateOffset(),
+                 buffer,
+                 /*n_batch=*/cell_state_shape.DimsData()[0],
+                 /*n_state=*/cell_state_shape.DimsData()[1],
+                 tflite::micro::GetTensorData<float>(cell_state) +
+                     step_info.CellStateOffset());
+
+  if (cell_state_info.cell_clip > 0) {
+    Clipping(cell_state_shape.FlatSize(), cell_state_info,
+             tflite::micro::GetTensorData<float>(cell_state) +
+                 step_info.CellStateOffset());
+  }
+}
+#endif  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+
+// Increment the data offset so the sigle time step invocation call can access
+// the corresponding input/output tensor data at the time step
+void LstmStepManager::UpdateTime() {
+  current_time_ += 1;
+  TFLITE_DCHECK_LE(current_time_, size_info_.time_steps);
+  // default as one batch per inference
+  int input_step = size_info_.input_dimension;
+  int output_step = size_info_.state_dimension;
+  // time major: batch inference
+  if (size_info_.time_major) {
+    input_step = input_step * size_info_.batch_size;
+    output_step = output_step * size_info_.batch_size;
+  }
+
+  input_offset_ += input_step;
+  output_offset_ += output_step;
+}
+
+// Increment the data offset so the sigle time step invocation call can access
+// the corresponding hidden/cell state tensor data at the time step (for single
+// batch inference only)
+void LstmStepManager::UpdateBatch() {
+  current_batch_ += 1;
+  TFLITE_DCHECK_LE(current_batch_, size_info_.batch_size);
+  // batch inference for time major: no action needed
+  if (size_info_.time_major) {
+    return;
+  }
+  // otherwise: singe batch inference, go to the next batch
+  hidden_state_offset_ += size_info_.state_dimension;
+  cell_state_offset_ += size_info_.state_dimension;
+}
+
+// Input shape for each single time LSTM invocation.
+// Multi-batch for time_major input
+RuntimeShape LstmStepManager::InputShape() const {
+  int batch_size = 1;
+  if (size_info_.time_major) {
+    batch_size = size_info_.batch_size;
+  }
+  const int dims[2] = {batch_size, size_info_.input_dimension};
+  const int32_t* dims_data = reinterpret_cast<const int32_t*>(dims);
+  return RuntimeShape(2, dims_data);
+}
+
+// State shape (both hidden and cell) for each single time LSTM invocation.
+// Multi-batch for time_major input
+RuntimeShape LstmStepManager::StateShape() const {
+  int batch_size = 1;
+  if (size_info_.time_major) {
+    batch_size = size_info_.batch_size;
+  }
+  const int dims[2] = {batch_size, size_info_.state_dimension};
+  const int32_t* dims_data = reinterpret_cast<const int32_t*>(dims);
+  return RuntimeShape(2, dims_data);
+}
+
+}  // namespace lstm_internal
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
index 5dd746a..0ba5e22 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,205 +12,813 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
-#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
 
+// Functions to perform integer evaulation for standard LSTM (e.g., defined in
+// the keras lstm layer, no peephole etc.). Currently used by the 16 bits
+// activation case only
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
+#include <algorithm>
 #include <cstdint>
-#include <memory>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
-#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/lstm_shared.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
+#include "tensorflow/lite/micro/micro_log.h"
 
 namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
 
-#if defined(HIFI5)
-void calc_cell_state_without_cifg(int16_t* cell_state,
-                                  const int16_t* forget_gate,
-                                  const int16_t* cell_gate,
-                                  const int16_t* input_gate, int shift1,
-                                  int shift2, int clip, int num_elms);
+// Interface to access all the TempTfLiteTensors of the LSTM kernel during the
+// preparation phase. Can only be constructed through the constructor to avoid
+// memory leakage. All TempTfLiteTensors will be deallocated through the
+// destructor.
+class LstmTensors {
+ public:
+  LstmTensors(const LstmTensors& other) = delete;
+  LstmTensors& operator=(const LstmTensors& other) = delete;
 
-void calc_cell_state_with_cifg(int16_t* cell_state, const int16_t* forget_gate,
-                               const int16_t* cell_gate, int shift1, int shift2,
-                               int clip, int num_elms);
+  LstmTensors(TfLiteContext* context, TfLiteNode* node);
+  ~LstmTensors();
 
-void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1,
-                                const int16_t* input_2, int32_t multiplier,
-                                int32_t shift, int32_t zero_point,
-                                int num_elms);
-#endif  // defined(HIFI5)
+  // Verify the LSTM internal tensor properties (e.g., type checks)
+  // Input/output/states/fc weights tensors are required for kernel evaulation.
+  // The state tensors should be variables. Variants of the standard LSTM
+  // are not supported here, therefore their corresponding tensors should be
+  // invalid
+  TfLiteStatus ValidateTensorStatus(TfLiteContext* context) const;
 
-// Pamameters for integer LSTM.
-// Consider split this into two Integer Parameters if more fields are added.
-struct IntegerLstmParameter {
-  int32_t effective_input_to_input_scale_a;
-  int effective_input_to_input_scale_b;
-  int32_t effective_recurrent_to_input_scale_a;
-  int effective_recurrent_to_input_scale_b;
-  int32_t effective_cell_to_input_scale_a;
-  int effective_cell_to_input_scale_b;
-  int32_t effective_input_to_forget_scale_a;
-  int effective_input_to_forget_scale_b;
-  int32_t effective_recurrent_to_forget_scale_a;
-  int effective_recurrent_to_forget_scale_b;
-  int32_t effective_cell_to_forget_scale_a;
-  int effective_cell_to_forget_scale_b;
-  int32_t effective_input_to_cell_scale_a;
-  int effective_input_to_cell_scale_b;
-  int32_t effective_recurrent_to_cell_scale_a;
-  int effective_recurrent_to_cell_scale_b;
-  int32_t effective_input_to_output_scale_a;
-  int effective_input_to_output_scale_b;
-  int32_t effective_recurrent_to_output_scale_a;
-  int effective_recurrent_to_output_scale_b;
-  int32_t effective_cell_to_output_scale_a;
-  int effective_cell_to_output_scale_b;
-  int32_t effective_proj_scale_a;
-  int effective_proj_scale_b;
-  int32_t effective_hidden_scale_a;
-  int effective_hidden_scale_b;
-  int32_t layer_norm_input_scale_a;
-  int layer_norm_input_scale_b;
-  int32_t layer_norm_forget_scale_a;
-  int layer_norm_forget_scale_b;
-  int32_t layer_norm_cell_scale_a;
-  int layer_norm_cell_scale_b;
-  int32_t layer_norm_output_scale_a;
-  int layer_norm_output_scale_b;
-  // Quantized clip value for cell and projection. Zero value means no clipping.
-  int16_t quantized_cell_clip;
-  int8_t quantized_proj_clip;
-  int32_t hidden_zp;
-  int32_t cell_scale;
+  // Internal tensors. see lstm_shared.h for tensor names
+  const TfLiteTensor* GetInternalTensor(const int tensor_index) const {
+    return internal_tensors_[tensor_index];
+  }
 
-  int32_t input_variance_guard;
-  int32_t forget_variance_guard;
-  int32_t cell_variance_guard;
-  int32_t output_variance_guard;
+  const TfLiteTensor* HiddenStateTensor() const {
+    return internal_tensors_[kLstmOutputStateTensor];
+  }
+  const TfLiteTensor* CellStateTensor() const {
+    return internal_tensors_[kLstmCellStateTensor];
+  }
+  const TfLiteTensor* OutputTensor() const { return output_tensor_; }
 
-  // Pre-calculate bias + zero_point * weight.
-  // Unabled to use temporary tensors since those are used in Prepare() and
-  // scratch buffer is only allocated after Preapre().
-  std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
-  std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
-  std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
-  std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
-  std::unique_ptr<int32_t[]> input_to_output_effective_bias;
-  std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
-  std::unique_ptr<int32_t[]> input_to_input_effective_bias;
-  std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
-  std::unique_ptr<int32_t[]> projection_effective_bias;
-
-  // Scale and zero point for intermediate tensors.
-  // Used only in the 8x8_8 case.
-  int32_t intermediate_scale_a[8];
-  int32_t intermediate_scale_b[8];
-  int32_t intermediate_zp[12];
+ private:
+  // see lstm_shared.h for tensor names
+  MicroContext* micro_context_;
+  TfLiteTensor* internal_tensors_[24];
+  TfLiteTensor* output_tensor_;
 };
 
-TfLiteStatus EvalFloat(const TfLiteEvalTensor* input,
-                       const TfLiteEvalTensor* input_to_input_weights,
-                       const TfLiteEvalTensor* input_to_forget_weights,
-                       const TfLiteEvalTensor* input_to_cell_weights,
-                       const TfLiteEvalTensor* input_to_output_weights,
-                       const TfLiteEvalTensor* recurrent_to_input_weights,
-                       const TfLiteEvalTensor* recurrent_to_forget_weights,
-                       const TfLiteEvalTensor* recurrent_to_cell_weights,
-                       const TfLiteEvalTensor* recurrent_to_output_weights,
-                       const TfLiteEvalTensor* cell_to_input_weights,
-                       const TfLiteEvalTensor* cell_to_forget_weights,
-                       const TfLiteEvalTensor* cell_to_output_weights,
-                       const TfLiteEvalTensor* input_layer_norm_coefficients,
-                       const TfLiteEvalTensor* forget_layer_norm_coefficients,
-                       const TfLiteEvalTensor* cell_layer_norm_coefficients,
-                       const TfLiteEvalTensor* output_layer_norm_coefficients,
-                       const TfLiteEvalTensor* aux_input,
-                       const TfLiteEvalTensor* aux_input_to_input_weights,
-                       const TfLiteEvalTensor* aux_input_to_forget_weights,
-                       const TfLiteEvalTensor* aux_input_to_cell_weights,
-                       const TfLiteEvalTensor* aux_input_to_output_weights,
-                       const TfLiteEvalTensor* input_gate_bias,
-                       const TfLiteEvalTensor* forget_gate_bias,
-                       const TfLiteEvalTensor* cell_gate_bias,
-                       const TfLiteEvalTensor* output_gate_bias,
-                       const TfLiteEvalTensor* projection_weights,
-                       const TfLiteEvalTensor* projection_bias,
-                       const TfLiteLSTMParams* params, bool forward_sequence,
-                       bool time_major, int output_offset,
-                       TfLiteEvalTensor* scratch_buffer,
-                       TfLiteEvalTensor* output_state,
-                       TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
+// Deduce the size information (Batch (B), Time Steps (T), Input dimension (I),
+// State dimension (S)) that defines the LSTM using the input and hidden state
+// tensor
+LstmSizeInfo CreateLstmSizeInfo(
+    const bool time_major, const TfLiteIntArray* input_tensor_shape,
+    const TfLiteIntArray* hidden_state_tensor_shape);
 
-TfLiteStatus EvalInteger8x8_16(
-    TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* input_to_input_weights,
-    const TfLiteEvalTensor* input_to_forget_weights,
-    const TfLiteEvalTensor* input_to_cell_weights,
-    const TfLiteEvalTensor* input_to_output_weights,
-    const TfLiteEvalTensor* recurrent_to_input_weights,
-    const TfLiteEvalTensor* recurrent_to_forget_weights,
-    const TfLiteEvalTensor* recurrent_to_cell_weights,
-    const TfLiteEvalTensor* recurrent_to_output_weights,
-    const TfLiteEvalTensor* cell_to_input_weights,
-    const TfLiteEvalTensor* cell_to_forget_weights,
-    const TfLiteEvalTensor* cell_to_output_weights,
-    const TfLiteEvalTensor* input_layer_norm_coefficients,
-    const TfLiteEvalTensor* forget_layer_norm_coefficients,
-    const TfLiteEvalTensor* cell_layer_norm_coefficients,
-    const TfLiteEvalTensor* output_layer_norm_coefficients,
-    const TfLiteEvalTensor* input_gate_bias,
-    const TfLiteEvalTensor* forget_gate_bias,
-    const TfLiteEvalTensor* cell_gate_bias,
-    const TfLiteEvalTensor* output_gate_bias,
-    const TfLiteEvalTensor* projection_weights,
-    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
-    bool forward_sequence, bool time_major,
-    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
-    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
-    TfLiteEvalTensor* output, TfLiteEvalTensor* scratch0,
-    TfLiteEvalTensor* scratch1, TfLiteEvalTensor* scratch2,
-    TfLiteEvalTensor* scratch3, TfLiteEvalTensor* scratch4,
-    TfLiteEvalTensor* scratch5);
+TfLiteStatus ValidateWeightTensorSize(TfLiteContext* context,
+                                      const TfLiteTensor* tensor, int dim1_size,
+                                      int dim2_size);
 
-TfLiteStatus EvalInteger8x8_8(
-    const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* input_to_input_weights,
-    const TfLiteEvalTensor* input_to_forget_weights,
-    const TfLiteEvalTensor* input_to_cell_weights,
-    const TfLiteEvalTensor* input_to_output_weights,
-    const TfLiteEvalTensor* recurrent_to_input_weights,
-    const TfLiteEvalTensor* recurrent_to_forget_weights,
-    const TfLiteEvalTensor* recurrent_to_cell_weights,
-    const TfLiteEvalTensor* recurrent_to_output_weights,
-    const TfLiteEvalTensor* cell_to_input_weights,
-    const TfLiteEvalTensor* cell_to_forget_weights,
-    const TfLiteEvalTensor* cell_to_output_weights,
-    const TfLiteEvalTensor* input_layer_norm_coefficients,
-    const TfLiteEvalTensor* forget_layer_norm_coefficients,
-    const TfLiteEvalTensor* cell_layer_norm_coefficients,
-    const TfLiteEvalTensor* output_layer_norm_coefficients,
-    const TfLiteEvalTensor* input_gate_bias,
-    const TfLiteEvalTensor* forget_gate_bias,
-    const TfLiteEvalTensor* cell_gate_bias,
-    const TfLiteEvalTensor* output_gate_bias,
-    const TfLiteEvalTensor* projection_weights,
-    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
-    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
-    TfLiteEvalTensor* output,
-    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
-    TfLiteEvalTensor* scratch0, TfLiteEvalTensor* scratch1,
-    TfLiteEvalTensor* scratch2, TfLiteEvalTensor* scratch3,
-    TfLiteEvalTensor* scratch4, TfLiteEvalTensor* scratch5,
-    TfLiteEvalTensor* scratch6, TfLiteEvalTensor* scratch7);
+TfLiteStatus ValidateBiasTensorSize(TfLiteContext* context,
+                                    const TfLiteTensor* tensor, int size);
 
-}  // namespace lstm_eval
-}  // namespace micro
-}  // namespace ops
+// Go through every tensors and make sure their shape match the kernel
+// configuration
+TfLiteStatus ValidateTensorSize(TfLiteContext* context,
+                                const LstmTensors& tensors,
+                                const LstmSizeInfo& size_info);
+
+// Wrapper function to create gate parameters for the four internal LSTM gates
+TfLiteStatus CreateGateParams(
+    TfLiteContext* context,
+    /*Input tensors*/
+    const TfLiteTensor* input, const TfLiteTensor* input_weight,
+    const TfLiteTensor* input_bias,
+    /*Hidden state tensors*/
+    const TfLiteTensor* hidden_state, const TfLiteTensor* hidden_state_weight,
+    const TfLiteTensor* hidden_state_bias,
+    /*Scale of the fc output (input to non-linear activation)*/
+    const float nonlinear_activation_input_scale, const TfLiteType cell_type,
+    const tflite::GateParameters& gate_params);
+
+// Create parameters for element wise multiplication that happens in a) cell
+// state update ; b) hidden state update
+// Note that all the output of gates are symmetrically quantized so only scales
+// are required for input. However, during the hidden state update phase, the
+// output is the updated hidden state, which is asymmetrically quantized. Thus
+// output may require zero point
+tflite::ArithmeticParams CreateInterGateMulParams(const float input1_scale,
+                                                  const float input2_scale,
+                                                  const float output_scale,
+                                                  const TfLiteType output_type,
+                                                  const int output_zp = 0);
+
+// Create the additional information about the cell state, which include:
+// cell_state_scale_power: used in integer nonlinear function (e.g., tanh)
+// quantized_cell_clip: quantized cell clip range
+CellStateInfo CreateLstmCellStateInfo(const float cell_state_scale,
+                                      const float cell_clip);
+
+CellStateInfo CreateLstmCellStateInfoFloat(const float cell_clip);
+tflite::FullyConnectedParams CreateFCParamsFloat();
+
+tflite::GateParameters CreateGateParamsFloat();
+
+tflite::ArithmeticParams CreateInterGateMulParamsFloat();
+
+TfLiteStatus PrepareGateParametersFloat(TfLiteContext* context,
+                                        const LstmTensors& lstm_tensors,
+                                        OpDataLSTM* op_data_lstm);
+
+TfLiteStatus PrepareGateParametersInteger(TfLiteContext* context,
+                                          const LstmTensors& lstm_tensors,
+                                          OpDataLSTM* op_data_lstm);
+
+LSTMKernelContents CreateLSTMKernelContent(TfLiteContext* context,
+                                           TfLiteNode* node);
+
+template <typename CellType>
+LSTMBuffers<CellType> CreateLSTMBuffers(TfLiteContext* context,
+                                        const int* buffer_indices) {
+  LSTMBuffers<CellType> buffers;
+  buffers.buffer0 = reinterpret_cast<CellType*>(
+      context->GetScratchBuffer(context, buffer_indices[0]));
+  buffers.buffer1 = reinterpret_cast<CellType*>(
+      context->GetScratchBuffer(context, buffer_indices[1]));
+  buffers.buffer2 = reinterpret_cast<CellType*>(
+      context->GetScratchBuffer(context, buffer_indices[2]));
+  buffers.buffer3 = reinterpret_cast<CellType*>(
+      context->GetScratchBuffer(context, buffer_indices[3]));
+  return buffers;
+}
+
+// Since LSTM includes multiple intermediate stages, introducing the internal
+// namespace to expose them for testing
+namespace lstm_internal {
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(const RuntimeShape& data_shape, int16_t* data);
+
+void Sigmoid(const RuntimeShape& data_shape, float* data);
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+          int16_t* input_data, const RuntimeShape& output_data_shape,
+          int16_t* output_data);
+
+void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
+          float* input_data, const RuntimeShape& output_data_shape,
+          float* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const int16_t* input1_data, const int16_t* input2_data,
+         int8_t* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const int16_t* input1_data, const int16_t* input2_data,
+         int16_t* output_data);
+
+void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
+         const float* input1_data, const float* input2_data,
+         float* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int32_t* bias_data,
+                    const RuntimeShape& output_shape, int16_t* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int16_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int64_t* bias_data,
+                    const RuntimeShape& output_shape, int16_t* output_data);
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& filter_shape, const float* filter_data,
+                    const RuntimeShape& bias_shape, const float* bias_data,
+                    const RuntimeShape& output_shape, float* output_data);
+#else   // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+void Sigmoid(int16_t* data, int32_t data_size);
+
+void Sigmoid(float* data, int32_t data_size);
+
+void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
+          int16_t* output_data, int32_t data_size);
+
+void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
+          int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+         const int16_t* input2_data, int8_t* output_data, int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const int16_t* input1_data,
+         const int16_t* input2_data, int16_t* output_data, int32_t data_size);
+
+void Mul(const ArithmeticParams& params, const float* input1_data,
+         const float* input2_data, float* output_data, int32_t data_size);
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const int8_t* input_data, const int8_t* filter_data,
+                    const int32_t* bias_data, int16_t* output_data,
+                    const int num_batches, const int output_depth,
+                    const int accum_depth);
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const int16_t* input_data, const int8_t* filter_data,
+                    const int64_t* bias_data, int16_t* output_data,
+                    const int num_batches, const int output_depth,
+                    const int accum_depth);
+
+void FullyConnected(const FullyConnectedParams& params, const float* input_data,
+                    const float* filter_data, const float* bias_data,
+                    float* output_data, const int num_batches,
+                    const int output_depth, const int accum_depth);
+#endif  // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
+                    int n_input, int16_t* output);
+
+void AddElementWise(const float* input_1, const float* input_2, int n_batch,
+                    int n_input, float* output);
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+              int16_t* vector);
+
+void Clipping(const int v_size, const CellStateInfo& cell_state_info,
+              float* vector);
+
+// Manages the slice position (offset), slice length (sliced tensor shape),
+// and update rules for input/output/hidden state/cell state tensors at each
+// time step.
+class LstmStepManager {
+ public:
+  LstmStepManager() = delete;
+  // Does not take any ownership, and all pointers must refer to valid objects
+  // that outlive the one constructed.
+  explicit LstmStepManager(const LstmSizeInfo* size_info)
+      : size_info_(*size_info) {}
+
+  void UpdateTime();
+  void UpdateBatch();
+
+  void ResetTime() { current_time_ = 0; }
+  RuntimeShape InputShape() const;
+  RuntimeShape StateShape() const;
+
+  int InputOffset() const { return input_offset_; }
+  int OutputOffset() const { return output_offset_; }
+  int HiddenStateOffset() const { return hidden_state_offset_; }
+  int CellStateOffset() const { return cell_state_offset_; }
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+  int time_major() const { return size_info_.time_major; }
+
+  int batch_size() const { return size_info_.batch_size; }
+
+  int input_dimension() const { return size_info_.input_dimension; }
+
+  int state_dimension() const { return size_info_.state_dimension; }
+#endif
+
+ private:
+  int current_time_ = 0;
+  int current_batch_ = 0;
+  int input_offset_ = 0;
+  int output_offset_ = 0;
+  int hidden_state_offset_ = 0;
+  int cell_state_offset_ = 0;
+  // Sizeinfo is from LstmOpData, which reside in the memory arena
+  // (guarante to outlast LSTMStepManager, which reside in stack)
+  const LstmSizeInfo& size_info_;
+};
+
+// Calculates a single LSTM gate.
+// Implements the following formula:
+//   gate = activate(FC(input) + FC(recurrent))
+// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+          typename BiasType>
+void CalculateLstmGate(
+    const LstmStepManager& step_info, const GateParameters& gate_params,
+    // Input FC
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* input_weight,
+    const TfLiteEvalTensor* input_bias,
+    // Recurrent FC
+    const TfLiteEvalTensor* recurrent, const TfLiteEvalTensor* recurrent_weight,
+    const TfLiteEvalTensor* recurrent_bias,
+    // Output
+    CellType* gate_output,
+    // Scratch arrays
+    CellType* fc_output_buffer, const TfLiteFusedActivation activation) {
+  const auto gate_output_shape = step_info.StateShape();
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(step_info.InputOffset() + step_info.InputShape().FlatSize(),
+                   tflite::micro::GetTensorShape(input).FlatSize());
+  TFLITE_DCHECK_LE(
+      step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(recurrent).FlatSize());
+
+  // Input FC
+  FullyConnected(gate_params.input_fc_params, step_info.InputShape(),
+                 tflite::micro::GetTensorData<ActivationType>(input) +
+                     step_info.InputOffset(),
+                 micro::GetTensorShape(input_weight),
+                 tflite::micro::GetTensorData<WeightType>(input_weight),
+                 tflite::micro::GetTensorShape(input_bias),
+                 tflite::micro::GetOptionalTensorData<BiasType>(input_bias),
+                 gate_output_shape, gate_output);
+
+  // Recurrent FC
+  FullyConnected(gate_params.recurrent_fc_params, step_info.StateShape(),
+                 tflite::micro::GetTensorData<ActivationType>(recurrent) +
+                     step_info.HiddenStateOffset(),
+                 tflite::micro::GetTensorShape(recurrent_weight),
+                 tflite::micro::GetTensorData<WeightType>(recurrent_weight),
+                 tflite::micro::GetTensorShape(recurrent_bias),
+                 tflite::micro::GetOptionalTensorData<BiasType>(recurrent_bias),
+                 gate_output_shape, fc_output_buffer);
+
+  AddElementWise(gate_output, fc_output_buffer,
+                 /*n_batch=*/gate_output_shape.DimsData()[0],
+                 /*n_state=*/gate_output_shape.DimsData()[1], gate_output);
+  // Apply activation
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      Sigmoid(gate_output_shape, gate_output);
+      break;
+    case kTfLiteActTanh: {
+      // Set the scale power to -12 to avoid shift
+      Tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output,
+           gate_output_shape, gate_output);
+    } break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
+
+// Update the cell state using the output from the forget gate, input gate, and
+// cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
+// input_gate_output * cell_gate_output, where * denotes element wise
+// multiplication
+template <typename CellType>
+void UpdateLstmCell(const LstmStepManager& step_info,
+                    TfLiteEvalTensor* cell_state,
+                    // Gate outputs
+                    CellType* forget_gate_output,
+                    const CellType* input_gate_output,
+                    const CellType* cell_gate_output,
+                    // Mul parameters
+                    const ArithmeticParams& forget_cell_mul_params,
+                    const ArithmeticParams& input_mul_params,
+                    const CellStateInfo& cell_state_info, CellType* buffer) {
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(
+      step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(cell_state).FlatSize());
+
+  auto cell_state_shape = step_info.StateShape();
+  // Forget Gate x Cell State
+  Mul(cell_state_shape, forget_cell_mul_params, forget_gate_output,
+      tflite::micro::GetTensorData<CellType>(cell_state) +
+          step_info.CellStateOffset(),
+      tflite::micro::GetTensorData<CellType>(cell_state) +
+          step_info.CellStateOffset());
+  // Input Gate x Cell Gate
+  Mul(cell_state_shape, input_mul_params, input_gate_output, cell_gate_output,
+      buffer);
+
+  // Update the cell state
+  AddElementWise(tflite::micro::GetTensorData<CellType>(cell_state) +
+                     step_info.CellStateOffset(),
+                 buffer,
+                 /*n_batch=*/cell_state_shape.DimsData()[0],
+                 /*n_state=*/cell_state_shape.DimsData()[1],
+                 tflite::micro::GetTensorData<CellType>(cell_state) +
+                     step_info.CellStateOffset());
+
+  if (cell_state_info.cell_clip > 0) {
+    Clipping(cell_state_shape.FlatSize(), cell_state_info,
+             tflite::micro::GetTensorData<CellType>(cell_state) +
+                 step_info.CellStateOffset());
+  }
+}
+#else   // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+          typename BiasType>
+void CalculateLstmGate(
+    const LstmStepManager& step_info, const GateParameters& gate_params,
+    // Input FC
+    const TfLiteEvalTensor* input, const TfLiteEvalTensor* input_weight,
+    const TfLiteEvalTensor* input_bias,
+    // Recurrent FC
+    const TfLiteEvalTensor* recurrent, const TfLiteEvalTensor* recurrent_weight,
+    const TfLiteEvalTensor* recurrent_bias,
+    // Output
+    CellType* gate_output,
+    // Scratch arrays
+    CellType* fc_output_buffer, const TfLiteFusedActivation activation,
+    const int num_batches, const int input_dimension,
+    const int state_dimension) {
+  // RuntimeShape step_input_shape = step_info.InputShape();
+  // RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+  // RuntimeShape step_state_shape = step_info.StateShape();
+  // RuntimeShape recurrent_shape = tflite::micro::GetTensorShape(recurrent);
+
+  // Moved these to LstmStep function
+  // Check offset validity to avoid memory overflow
+  // TFLITE_DCHECK_LE(step_info.InputOffset() + step_input_shape.FlatSize(),
+  // input_shape.FlatSize());
+  // TFLITE_DCHECK_LE(
+  // step_info.HiddenStateOffset() + step_state_shape.FlatSize(),
+  // recurrent_shape.FlatSize());
+
+  // Input FC
+  FullyConnected(gate_params.input_fc_params,
+                 tflite::micro::GetTensorData<ActivationType>(input) +
+                     step_info.InputOffset(),
+                 tflite::micro::GetTensorData<WeightType>(input_weight),
+                 tflite::micro::GetOptionalTensorData<BiasType>(input_bias),
+                 gate_output, num_batches, state_dimension, input_dimension);
+
+  // Recurrent FC
+  FullyConnected(gate_params.recurrent_fc_params,
+                 tflite::micro::GetTensorData<ActivationType>(recurrent) +
+                     step_info.HiddenStateOffset(),
+                 tflite::micro::GetTensorData<WeightType>(recurrent_weight),
+                 tflite::micro::GetOptionalTensorData<BiasType>(recurrent_bias),
+                 fc_output_buffer, num_batches, state_dimension,
+                 state_dimension);
+
+  AddElementWise(gate_output, fc_output_buffer,
+                 /*n_batch=*/num_batches,
+                 /*n_state=*/state_dimension, gate_output);
+  // Apply activation
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      Sigmoid(gate_output, num_batches * state_dimension);
+      break;
+    case kTfLiteActTanh: {
+      // Set the scale power to -12 to avoid shift
+      Tanh(/*cell_state_scale_power=*/-12, gate_output, gate_output,
+           num_batches * state_dimension);
+    } break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
+
+// Update the cell state using the output from the forget gate, input gate, and
+// cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
+// input_gate_output * cell_gate_output, where * denotes element wise
+// multiplication
+void UpdateLstmCell(const LstmStepManager& step_info,
+                    TfLiteEvalTensor* cell_state,
+                    // Gate outputs
+                    int16_t* forget_gate_output,
+                    const int16_t* input_gate_output,
+                    const int16_t* cell_gate_output,
+                    // Mul parameters
+                    const ArithmeticParams& forget_cell_mul_params,
+                    const ArithmeticParams& input_mul_params,
+                    const CellStateInfo& cell_state_info, int16_t* buffer);
+
+void UpdateLstmCell(const LstmStepManager& step_info,
+                    TfLiteEvalTensor* cell_state,
+                    // Gate outputs
+                    float* forget_gate_output, const float* input_gate_output,
+                    const float* cell_gate_output,
+                    // Mul parameters
+                    const ArithmeticParams& forget_cell_mul_params,
+                    const ArithmeticParams& input_mul_params,
+                    const CellStateInfo& cell_state_info, float* buffer);
+#endif  // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+// Update the hidden state of the LSTM kernel using the following formula:
+// updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
+// element wise multiplication
+template <typename CellType, typename ActivationType>
+void UpdateLstmHidden(const LstmStepManager& step_info,
+                      TfLiteEvalTensor* cell_state,
+                      TfLiteEvalTensor* hidden_state,
+                      const CellType* output_gate_output,
+                      const ArithmeticParams& mul_params,
+                      int32_t cell_state_scale_power, CellType* buffer) {
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(
+      step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(cell_state).FlatSize());
+  TFLITE_DCHECK_LE(
+      step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(hidden_state).FlatSize());
+
+  auto cell_state_shape = step_info.StateShape();
+  CellType* cell_state_data =
+      tflite::micro::GetTensorData<CellType>(cell_state) +
+      step_info.CellStateOffset();
+  // Tanh(cell_state)
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+  Tanh(cell_state_scale_power, cell_state_shape, cell_state_data,
+       cell_state_shape, buffer);
+  // Update the hidden state
+  Mul(cell_state_shape, mul_params, buffer, output_gate_output,
+      tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+          step_info.HiddenStateOffset());
+#else
+  int32_t cell_state_size = cell_state_shape.FlatSize();
+  Tanh(cell_state_scale_power, cell_state_data, buffer, cell_state_size);
+  // Update the hidden state
+  Mul(mul_params, buffer, output_gate_output,
+      tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+          step_info.HiddenStateOffset(),
+      cell_state_size);
+#endif
+}
+
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+          typename BiasType>
+void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
+              LSTMKernelContents& kernel_content,
+              const LSTMBuffers<CellType>& buffers) {
+  /*Step1: Calculate gate outputs to prepare cell state update*/
+  CellType* gate_internal_buffer = buffers.buffer3;
+  CellType* forget_gate_output = buffers.buffer0;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.forget_gate_parameters,
+      // Input FC
+      kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor),
+      // Recurrent FC
+      kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToForgetWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      forget_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid);
+
+  // Input Gate calculation;
+  CellType* input_gate_output = buffers.buffer1;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.input_gate_parameters,
+      // Input FC
+      kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor),
+      // Recurrent FC
+      kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToInputWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      input_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid);
+
+  // Cell Gate calculation
+  CellType* cell_gate_output = buffers.buffer2;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.cell_gate_parameters,
+      // Input FC
+      kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor),
+      // Recurrent FC
+      kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToCellWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      cell_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, op_data.cell_gate_nonlinear_type);
+
+  /*Step2: update the cell state */
+  const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
+  CellType* updated_input_buffer = buffers.buffer1;  // reuse buffer
+
+  UpdateLstmCell<CellType>(step_info, kernel_content.CellStateTensor(),
+                           forget_gate_output, input_gate_output,
+                           cell_gate_output,
+                           inter_gate_params.forget_cell_mul_params,
+                           inter_gate_params.input_mul_params,
+                           op_data.cell_state_info, updated_input_buffer);
+
+  /*Step3: update the hidden state */
+  CellType* output_gate_output = buffers.buffer1;  // reuse buffer
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.output_gate_parameters,
+      // Input FC
+      kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor),
+      // Recurrent FC
+      kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToOutputWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      output_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid);
+
+  CellType* tanh_activated_cell_buffer = buffers.buffer0;  // reuse buffer
+  tflite::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
+      step_info, kernel_content.CellStateTensor(),
+      kernel_content.HiddenStateTensor(), output_gate_output,
+      inter_gate_params.output_mul_params,
+      op_data.cell_state_info.cell_state_scale_power,
+      tanh_activated_cell_buffer);
+
+  /*Step4: copy the update the hidden state to output*/
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(
+      step_info.OutputOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(kernel_content.output_tensor).FlatSize());
+  // record the output (from the updated hidden state)
+  ActivationType* output_ptr = tflite::micro::GetTensorData<ActivationType>(
+      kernel_content.output_tensor);
+  const auto* hidden_state = kernel_content.HiddenStateTensor();
+  std::memcpy(output_ptr + step_info.OutputOffset(),
+              tflite::micro::GetTensorData<ActivationType>(hidden_state) +
+                  step_info.HiddenStateOffset(),
+              step_info.StateShape().FlatSize() * sizeof(ActivationType));
+}
+#else   // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+template <typename ActivationType, typename WeightType, typename CellType,
+          typename BiasType>
+void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
+              LSTMKernelContents& kernel_content,
+              const LSTMBuffers<CellType>& buffers) {
+  const TfLiteEvalTensor* input =
+      kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
+  TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
+
+  int time_major = step_info.time_major();
+  int num_batches = time_major == 0 ? 1 : step_info.batch_size();
+  int input_dimension = step_info.input_dimension();
+  int state_dimension = step_info.state_dimension();
+
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
+                   tflite::micro::GetTensorShape(input).FlatSize());
+  TFLITE_DCHECK_LE(
+      step_info.HiddenStateOffset() + num_batches * state_dimension,
+      tflite::micro::GetTensorShape(recurrent).FlatSize());
+
+  /*Step1: Calculate gate outputs to prepare cell state update*/
+  CellType* gate_internal_buffer = buffers.buffer3;
+  CellType* forget_gate_output = buffers.buffer0;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.forget_gate_parameters,
+      // Input FC
+      input,  // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor),
+      // Recurrent FC
+      recurrent,  // kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToForgetWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      forget_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+      state_dimension);
+
+  // Input Gate calculation;
+  CellType* input_gate_output = buffers.buffer1;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.input_gate_parameters,
+      // Input FC
+      input,  // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor),
+      // Recurrent FC
+      recurrent,  // kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToInputWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      input_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+      state_dimension);
+
+  // Cell Gate calculation
+  CellType* cell_gate_output = buffers.buffer2;
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.cell_gate_parameters,
+      // Input FC
+      input,  // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor),
+      // Recurrent FC
+      recurrent,  // kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToCellWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      cell_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, op_data.cell_gate_nonlinear_type, num_batches,
+      input_dimension, state_dimension);
+
+  /*Step2: update the cell state */
+  const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
+  CellType* updated_input_buffer = buffers.buffer1;  // reuse buffer
+
+  UpdateLstmCell(step_info, kernel_content.CellStateTensor(),
+                 forget_gate_output, input_gate_output, cell_gate_output,
+                 inter_gate_params.forget_cell_mul_params,
+                 inter_gate_params.input_mul_params, op_data.cell_state_info,
+                 updated_input_buffer);
+
+  /*Step3: update the hidden state */
+  CellType* output_gate_output = buffers.buffer1;  // reuse buffer
+  CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
+      step_info, op_data.output_gate_parameters,
+      // Input FC
+      input,  // kernel_content.GetInternalTensor(tflite::kLstmInputTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor),
+      kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor),
+      // Recurrent FC
+      recurrent,  // kernel_content.HiddenStateTensor(),
+      kernel_content.GetInternalTensor(
+          tflite::kLstmRecurrentToOutputWeightsTensor),
+      /*recurrent_bias*/ nullptr,
+      // Output
+      output_gate_output,
+      // Scratch arrays
+      gate_internal_buffer, kTfLiteActSigmoid, num_batches, input_dimension,
+      state_dimension);
+
+  CellType* tanh_activated_cell_buffer = buffers.buffer0;  // reuse buffer
+  tflite::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
+      step_info, kernel_content.CellStateTensor(), recurrent,
+      /* kernel_content.HiddenStateTensor(), */ output_gate_output,
+      inter_gate_params.output_mul_params,
+      op_data.cell_state_info.cell_state_scale_power,
+      tanh_activated_cell_buffer);
+
+  /*Step4: copy the update the hidden state to output*/
+  // Check offset validity to avoid memory overflow
+  TFLITE_DCHECK_LE(
+      step_info.OutputOffset() + step_info.StateShape().FlatSize(),
+      tflite::micro::GetTensorShape(kernel_content.output_tensor).FlatSize());
+  // record the output (from the updated hidden state)
+  ActivationType* output_ptr = tflite::micro::GetTensorData<ActivationType>(
+      kernel_content.output_tensor);
+  // const auto* hidden_state = kernel_content.HiddenStateTensor();
+  std::memcpy(output_ptr + step_info.OutputOffset(),
+              tflite::micro::GetTensorData<ActivationType>(recurrent) +
+                  step_info.HiddenStateOffset(),
+              step_info.StateShape().FlatSize() * sizeof(ActivationType));
+}
+#endif  // #if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
+}  // namespace lstm_internal
+
+// Evaulate the LSTM kernel with (potential) multi-steps and multi-batch input
+// Since
+template <typename ActivationType, typename WeightType, typename CellType,
+          typename BiasType>
+TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
+                      LSTMKernelContents& kernel_content,
+                      const LSTMBuffers<CellType>& buffers) {
+  lstm_internal::LstmStepManager step_info(&op_data.size_info);
+  const auto& size_info = op_data.size_info;
+  // time is the first dimention, enable batch computation
+  if (size_info.time_major) {
+    for (int t = 0; t < size_info.time_steps; t++) {
+      lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+          step_info, op_data, kernel_content, buffers);
+      // prepare for the next time step
+      step_info.UpdateTime();
+    }
+  } else {
+    // batch first, unable to size the input data. single batch inference
+    for (int b = 0; b < size_info.batch_size; b++) {
+      for (int t = 0; t < size_info.time_steps; t++) {
+        lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+            step_info, op_data, kernel_content, buffers);
+        // prepare for the next time step
+        step_info.UpdateTime();
+      }
+      // prepare for the next batch
+      step_info.UpdateBatch();
+      step_info.ResetTime();
+    }
+  }
+  return kTfLiteOk;
+}
 }  // namespace tflite
-#endif  // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_16ACT_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
index 2b49f26..a2d04ee 100644
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval_hifi.cc
@@ -12,17 +12,65 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+
+#include <xtensa/tie/xt_hifi2.h>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
 #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
 
 namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm_eval {
 
 #if defined(HIFI5)
+#if TFLITE_SINGLE_ROUNDING
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift,  \
+                                   right_shift)                       \
+  {                                                                   \
+    ae_int64 out64_0, out64_1;                                        \
+    ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1)); \
+    ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift);      \
+    AE_MUL32X2S_HH_LL(out64_0, out64_1, inp, AE_MOVDA32(multiplier)); \
+    out64_0 = AE_ADD64S(out64_0, round_val);                          \
+    out64_1 = AE_ADD64S(out64_1, round_val);                          \
+    out = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift);        \
+  }
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+                                     left_shift, right_shift)            \
+  {                                                                      \
+    ae_int64 out64_0, out64_1, out64_2, out64_3;                         \
+    ae_int64 INT64_ONE = AE_MOVINT64_FROMINT32X2(AE_MOVDA32X2(0, 1));    \
+    ae_int64 round_val = AE_SLAA64S(INT64_ONE, 30 - left_shift);         \
+    AE_MUL32X2S_HH_LL(out64_0, out64_1, inp1, AE_MOVDA32(multiplier));   \
+    AE_MUL32X2S_HH_LL(out64_2, out64_3, inp2, AE_MOVDA32(multiplier));   \
+    out64_0 = AE_ADD64S(out64_0, round_val);                             \
+    out64_1 = AE_ADD64S(out64_1, round_val);                             \
+    out64_2 = AE_ADD64S(out64_2, round_val);                             \
+    out64_3 = AE_ADD64S(out64_3, round_val);                             \
+    out1 = AE_TRUNCA32X2F64S(out64_0, out64_1, 1 + left_shift);          \
+    out2 = AE_TRUNCA32X2F64S(out64_2, out64_3, 1 + left_shift);          \
+  }
+#else /* #if TFLITE_SINGLE_ROUNDING */
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, left_shift, \
+                                   right_shift)                      \
+  out = AE_SLAA32(inp, left_shift);                                  \
+  out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier));                \
+  out = AE_SRAA32SYMS(out, right_shift);
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+                                     left_shift, right_shift)            \
+  {                                                                      \
+    ae_int32x2 d_ls = AE_MOVDA32(1 << left_shift);                       \
+    AE_MUL2P32X4(out1, out2, inp1, inp2, d_ls, d_ls);                    \
+    AE_MULF2P32X4RAS(out1, out2, out1, out2, AE_MOVDA32(multiplier),     \
+                     AE_MOVDA32(multiplier));                            \
+    out1 = AE_SRAA32SYMS(out1, right_shift);                             \
+    out2 = AE_SRAA32SYMS(out2, right_shift);                             \
+  }
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
 void calc_cell_state_without_cifg(int16_t* cell_state,
                                   const int16_t* forget_gate,
                                   const int16_t* cell_gate,
@@ -124,7 +172,7 @@
 
       AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
-      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
       AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -187,11 +235,11 @@
 
       AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
-      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
-      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
       AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -298,7 +346,7 @@
       d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
       AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
-      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
       AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -360,12 +408,12 @@
 
       AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
-      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
       AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
       d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
-      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
 
       d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
       AE_MINMAX16(d_cs_w_0, d_min, d_max);
@@ -404,8 +452,13 @@
   d_multiplier = AE_MOVDA32(multiplier);
   d_zp = AE_MOVDA16(zero_point);
 
+#if TFLITE_SINGLE_ROUNDING
+  left_shift = shift;
+  (void)right_shift;
+#else  /* #if TFLITE_SINGLE_ROUNDING */
   left_shift = shift < 0 ? 0 : shift;
   right_shift = shift > 0 ? 0 : -shift;
+#endif /* #if TFLITE_SINGLE_ROUNDING */
 
   d_left_shift = AE_MOVDA32(1 << left_shift);
 #pragma concurrent
@@ -415,18 +468,10 @@
 
     AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
     AE_MUL16X4(data_ab_2, data_ab_3, data_a_1, data_b_1);
-    AE_MUL2P32X4(data_ab_0, data_ab_1, data_ab_0, data_ab_1, d_left_shift,
-                 d_left_shift);
-    AE_MUL2P32X4(data_ab_2, data_ab_3, data_ab_2, data_ab_3, d_left_shift,
-                 d_left_shift);
-    AE_MULF2P32X4RAS(data_ab_0, data_ab_1, data_ab_0, data_ab_1, d_multiplier,
-                     d_multiplier);
-    AE_MULF2P32X4RAS(data_ab_2, data_ab_3, data_ab_2, data_ab_3, d_multiplier,
-                     d_multiplier);
-    data_ab_0 = AE_SRAA32SYMS(data_ab_0, right_shift);
-    data_ab_1 = AE_SRAA32SYMS(data_ab_1, right_shift);
-    data_ab_2 = AE_SRAA32SYMS(data_ab_2, right_shift);
-    data_ab_3 = AE_SRAA32SYMS(data_ab_3, right_shift);
+    MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1,
+                                 multiplier, left_shift, right_shift);
+    MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_2, data_ab_3, data_ab_2, data_ab_3,
+                                 multiplier, left_shift, right_shift);
     data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
     data_c_1 = AE_SAT16X4(data_ab_2, data_ab_3);
     data_c_0 = AE_SUB16S(data_c_0, d_zp);
@@ -445,18 +490,532 @@
     AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2);
 
     AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
-    data_ab_0 = AE_MULP32X2(data_ab_0, d_left_shift);
-    data_ab_0 = AE_MULFP32X2RAS(data_ab_0, d_multiplier);
-    data_ab_0 = AE_SRAA32SYMS(data_ab_0, right_shift);
-    data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
+    MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift,
+                               right_shift);
+    data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0);
     data_c_0 = AE_SUB16S(data_c_0, d_zp);
     data_c = AE_SAT8X8X16(data_c_0, data_c_0);
     AE_S8_0_IP(data_c, (ae_int8*)output, 1);
   }
 }
+#elif defined(HIFI3) || defined(HIFI4)
+#if TFLITE_SINGLE_ROUNDING
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \
+  {                                                                        \
+    ae_int64 out64_0, out64_1;                                             \
+    out64_0 = AE_MUL32_HH(inp, AE_MOVDA32(multiplier));                    \
+    out64_1 = AE_MUL32_LL(inp, AE_MOVDA32(multiplier));                    \
+    out64_0 = AE_SLAA64S(out64_0, 1 + l_shift);                            \
+    out64_1 = AE_SLAA64S(out64_1, 1 + l_shift);                            \
+    out = AE_ROUND32X2F64SASYM(out64_0, out64_1);                          \
+  }
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+                                     l_shift, r_shift)                   \
+  {                                                                      \
+    ae_int64 out64_0, out64_1, out64_2, out64_3;                         \
+    out64_0 = AE_MUL32_HH(inp1, AE_MOVDA32(multiplier));                 \
+    out64_1 = AE_MUL32_LL(inp1, AE_MOVDA32(multiplier));                 \
+    out64_2 = AE_MUL32_HH(inp2, AE_MOVDA32(multiplier));                 \
+    out64_3 = AE_MUL32_LL(inp2, AE_MOVDA32(multiplier));                 \
+    out64_0 = AE_SLAA64S(out64_0, 1 + l_shift);                          \
+    out64_1 = AE_SLAA64S(out64_1, 1 + l_shift);                          \
+    out64_2 = AE_SLAA64S(out64_2, 1 + l_shift);                          \
+    out64_3 = AE_SLAA64S(out64_3, 1 + l_shift);                          \
+    out1 = AE_ROUND32X2F64SASYM(out64_0, out64_1);                       \
+    out2 = AE_ROUND32X2F64SASYM(out64_2, out64_3);                       \
+  }
+#else /* #if TFLITE_SINGLE_ROUNDING */
+#define MPY_BY_QUANT_MULT_X2_OUT32(out, inp, multiplier, l_shift, r_shift) \
+  out = AE_SLAA32(inp, l_shift);                                           \
+  out = AE_MULFP32X2RAS(out, AE_MOVDA32(multiplier));                      \
+  out = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out), r_shift),        \
+                            AE_SRAA64(AE_CVT64F32_L(out), r_shift));
+
+#define MPY_BY_QUANT_MULT_X2X2_OUT32(out1, out2, inp1, inp2, multiplier, \
+                                     l_shift, r_shift)                   \
+  {                                                                      \
+    ae_int32x2 d_ls = AE_MOVDA32(1 << l_shift);                          \
+    out1 = AE_MULP32X2(inp1, d_ls);                                      \
+    out2 = AE_MULP32X2(inp2, d_ls);                                      \
+    out1 = AE_MULFP32X2RAS(out1, AE_MOVDA32(multiplier));                \
+    out2 = AE_MULFP32X2RAS(out2, AE_MOVDA32(multiplier));                \
+    out1 = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out1), r_shift),  \
+                               AE_SRAA64(AE_CVT64F32_L(out1), r_shift)); \
+    out2 = AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(out2), r_shift),  \
+                               AE_SRAA64(AE_CVT64F32_L(out2), r_shift)); \
+  }
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
+#ifndef AE_MULFP16X4RS
+static inline ae_f16x4 AE_MULFP16X4RS(ae_f16x4 d0, ae_f16x4 d1) {
+  ae_f16x4 output;
+  ae_f32x2 d0_32_0, d0_32_1, out32_0, out32_1;
+  ae_f16x4 one_d = AE_MOVDA16(1);
+  AE_MUL16X4(d0_32_0, d0_32_1, d0, one_d);
+  out32_0 = AE_MULFP32X16X2RS_H(d0_32_0, d1);
+  out32_1 = AE_MULFP32X16X2RS_L(d0_32_1, d1);
+  output = AE_SEL16_6420(AE_MOVF16X4_FROMF32X2(out32_0),
+                         AE_MOVF16X4_FROMF32X2(out32_1));
+  return output;
+}
+#endif
+
+#ifndef AE_MINMAX16
+#define AE_MINMAX16(dinout, d_min, d_max) \
+  {                                       \
+    xtbool4 b0 = AE_LT16(dinout, d_min);  \
+    AE_MOVT16X4(dinout, d_min, b0);       \
+    b0 = AE_LT16(d_max, dinout);          \
+    AE_MOVT16X4(dinout, d_max, b0);       \
+  }
+#endif
+
+#ifndef AE_SRAA32SYMS
+#define AE_SRAA32SYMS(inp, right_shift)                           \
+  AE_ROUND32X2F64SSYM(AE_SRAA64(AE_CVT64F32_H(inp), right_shift), \
+                      AE_SRAA64(AE_CVT64F32_L(inp), right_shift))
+#endif
+
+void calc_cell_state_without_cifg(int16_t* cell_state,
+                                  const int16_t* forget_gate,
+                                  const int16_t* cell_gate,
+                                  const int16_t* input_gate, int shift1,
+                                  int shift2, int clip, int num_elms) {
+  const ae_int16x4 *p16x4_cs_r, *p16x4_fg_r;
+  const ae_int16x4 *p16x4_cg_r, *p16x4_ig_r;
+
+  ae_int16x4* p16x4_cs_w;
+
+  ae_valign align_cs_r, align_fg_r;
+  ae_valign align_cg_r, align_ig_r;
+  ae_valign align_cs_w;
+
+  ae_int16x4 d_cs_r_0, d_cs_r_1;
+  ae_int16x4 d_fg_0, d_fg_1;
+  ae_int16x4 d_cg_0, d_cg_1;
+  ae_int16x4 d_ig_0, d_ig_1;
+  ae_int16x4 d_cs_w_0, d_cs_w_1;
+  ae_int32x2 d_mul_0, d_mul_1, d_mul_2, d_mul_3;
+  ae_int32x2 d_mul_4, d_mul_5, d_mul_6, d_mul_7;
+
+  ae_int16x4 d_min, d_max;
+
+  int i = 0;
+  p16x4_cs_r = (const ae_int16x4*)cell_state;
+  p16x4_fg_r = (const ae_int16x4*)forget_gate;
+  p16x4_cg_r = (const ae_int16x4*)cell_gate;
+  p16x4_ig_r = (const ae_int16x4*)input_gate;
+
+  p16x4_cs_w = (ae_int16x4*)cell_state;
+
+  align_cs_r = AE_LA64_PP(p16x4_cs_r);
+  align_fg_r = AE_LA64_PP(p16x4_fg_r);
+  align_cg_r = AE_LA64_PP(p16x4_cg_r);
+  align_ig_r = AE_LA64_PP(p16x4_ig_r);
+
+  align_cs_w = AE_ZALIGN64();
+
+  if (clip > 0) {
+    d_min = AE_MOVDA16(-clip);
+    d_max = AE_MOVDA16(clip);
+  } else {
+    d_min = AE_MOVDA16(-32768);
+    d_max = AE_MOVDA16(32767);
+  }
+
+#pragma concurrent
+  if (shift1 == 15) {
+    for (i = 0; i < (num_elms >> 3); i++) {
+      AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_ig_0, align_ig_r, p16x4_ig_r);
+      AE_LA16X4_IP(d_ig_1, align_ig_r, p16x4_ig_r);
+
+      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+      d_cs_w_1 = AE_MULFP16X4RS(d_cs_r_1, d_fg_1);
+
+      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_ig_0);
+      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_ig_1);
+      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+
+      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+      AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+      AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+    }
+    AE_SA64POS_FP(align_cs_w, p16x4_cs_w);  // finalize the stream
+
+    const ae_int16 *p16_cs_r, *p16_fg_r;
+    const ae_int16 *p16_cg_r, *p16_ig_r;
+
+    ae_int16* p16_cs_w;
+
+    p16_cs_r = (const ae_int16*)p16x4_cs_r;
+    p16_fg_r = (const ae_int16*)p16x4_fg_r;
+    p16_cg_r = (const ae_int16*)p16x4_cg_r;
+    p16_ig_r = (const ae_int16*)p16x4_ig_r;
+
+    p16_cs_w = (ae_int16*)p16x4_cs_w;
+    // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+    for (i = 0; i < ((num_elms)&7); i++) {
+      d_cs_r_0 = p16_cs_r[i];
+      d_fg_0 = p16_fg_r[i];
+      d_cg_0 = p16_cg_r[i];
+      d_ig_0 = p16_ig_r[i];
+
+      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      p16_cs_w[i] = d_cs_w_0;
+    }
+  } else {
+    for (i = 0; i < (num_elms >> 3); i++) {
+      AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_ig_0, align_ig_r, p16x4_ig_r);
+      AE_LA16X4_IP(d_ig_1, align_ig_r, p16x4_ig_r);
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+      AE_MUL16X4(d_mul_2, d_mul_3, d_cs_r_1, d_fg_1);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+      d_mul_1 = AE_SRAA32SYMS(d_mul_1, shift1);
+      d_mul_2 = AE_SRAA32SYMS(d_mul_2, shift1);
+      d_mul_3 = AE_SRAA32SYMS(d_mul_3, shift1);
+
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cs_w_1 = AE_SAT16X4(d_mul_2, d_mul_3);
+
+      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_ig_0);
+      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_ig_1);
+      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+
+      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+      AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+      AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+    }
+    AE_SA64POS_FP(align_cs_w, p16x4_cs_w);  // finalize the stream
+
+    const ae_int16 *p16_cs_r, *p16_fg_r;
+    const ae_int16 *p16_cg_r, *p16_ig_r;
+
+    ae_int16* p16_cs_w;
+
+    p16_cs_r = (const ae_int16*)p16x4_cs_r;
+    p16_fg_r = (const ae_int16*)p16x4_fg_r;
+    p16_cg_r = (const ae_int16*)p16x4_cg_r;
+    p16_ig_r = (const ae_int16*)p16x4_ig_r;
+
+    p16_cs_w = (ae_int16*)p16x4_cs_w;
+    // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+    for (i = 0; i < ((num_elms)&7); i++) {
+      d_cs_r_0 = p16_cs_r[i];
+      d_fg_0 = p16_fg_r[i];
+      d_cg_0 = p16_cg_r[i];
+      d_ig_0 = p16_ig_r[i];
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_ig_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      p16_cs_w[i] = d_cs_w_0;
+    }
+  }
+}
+
+void calc_cell_state_with_cifg(int16_t* cell_state, const int16_t* forget_gate,
+                               const int16_t* cell_gate, int shift1, int shift2,
+                               int clip, int num_elms) {
+  const ae_int16x4 *p16x4_cs_r, *p16x4_fg_r;
+  const ae_int16x4* p16x4_cg_r;
+
+  ae_int16x4* p16x4_cs_w;
+
+  ae_valign align_cs_r, align_fg_r;
+  ae_valign align_cg_r;
+  ae_valign align_cs_w;
+
+  ae_int16x4 d_cs_r_0, d_cs_r_1;
+  ae_int16x4 d_fg_0, d_fg_1;
+  ae_int16x4 d_cg_0, d_cg_1;
+  ae_int16x4 d_1mfg_0, d_1mfg_1;
+  ae_int16x4 d_cs_w_0, d_cs_w_1;
+  ae_int32x2 d_mul_0, d_mul_1, d_mul_2, d_mul_3;
+  ae_int32x2 d_mul_4, d_mul_5, d_mul_6, d_mul_7;
+
+  ae_int16x4 d_min, d_max, d_one;
+
+  int i = 0;
+  p16x4_cs_r = (const ae_int16x4*)cell_state;
+  p16x4_fg_r = (const ae_int16x4*)forget_gate;
+  p16x4_cg_r = (const ae_int16x4*)cell_gate;
+
+  p16x4_cs_w = (ae_int16x4*)cell_state;
+
+  align_cs_r = AE_LA64_PP(p16x4_cs_r);
+  align_fg_r = AE_LA64_PP(p16x4_fg_r);
+  align_cg_r = AE_LA64_PP(p16x4_cg_r);
+
+  align_cs_w = AE_ZALIGN64();
+
+  if (clip > 0) {
+    d_min = AE_MOVDA16(-clip);
+    d_max = AE_MOVDA16(clip);
+  } else {
+    d_min = AE_MOVDA16(-32768);
+    d_max = AE_MOVDA16(32767);
+  }
+  d_one = AE_MOVDA16(32767);
+
+#pragma concurrent
+  if (shift1 == 15) {
+    for (i = 0; i < (num_elms >> 3); i++) {
+      AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+
+      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+      d_cs_w_1 = AE_MULFP16X4RS(d_cs_r_1, d_fg_1);
+
+      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+      d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
+      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
+      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
+      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+      AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+      AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+    }
+    AE_SA64POS_FP(align_cs_w, p16x4_cs_w);  // finalize the stream
+
+    const ae_int16 *p16_cs_r, *p16_fg_r;
+    const ae_int16* p16_cg_r;
+
+    ae_int16* p16_cs_w;
+
+    p16_cs_r = (const ae_int16*)p16x4_cs_r;
+    p16_fg_r = (const ae_int16*)p16x4_fg_r;
+    p16_cg_r = (const ae_int16*)p16x4_cg_r;
+
+    p16_cs_w = (ae_int16*)p16x4_cs_w;
+    // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+    for (i = 0; i < ((num_elms)&7); i++) {
+      d_cs_r_0 = p16_cs_r[i];
+      d_fg_0 = p16_fg_r[i];
+      d_cg_0 = p16_cg_r[i];
+
+      d_cs_w_0 = AE_MULFP16X4RS(d_cs_r_0, d_fg_0);
+
+      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      p16_cs_w[i] = d_cs_w_0;
+    }
+  } else {
+    for (i = 0; i < (num_elms >> 3); i++) {
+      AE_LA16X4_IP(d_cs_r_0, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_cs_r_1, align_cs_r, p16x4_cs_r);
+      AE_LA16X4_IP(d_fg_0, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_fg_1, align_fg_r, p16x4_fg_r);
+      AE_LA16X4_IP(d_cg_0, align_cg_r, p16x4_cg_r);
+      AE_LA16X4_IP(d_cg_1, align_cg_r, p16x4_cg_r);
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+      AE_MUL16X4(d_mul_2, d_mul_3, d_cs_r_1, d_fg_1);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+      d_mul_1 = AE_SRAA32SYMS(d_mul_1, shift1);
+      d_mul_2 = AE_SRAA32SYMS(d_mul_2, shift1);
+      d_mul_3 = AE_SRAA32SYMS(d_mul_3, shift1);
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_1);
+      d_cs_w_1 = AE_SAT16X4(d_mul_2, d_mul_3);
+
+      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+      d_1mfg_1 = AE_SUB16S(d_one, d_fg_1);
+      AE_MUL16X4(d_mul_4, d_mul_5, d_cg_0, d_1mfg_0);
+      AE_MUL16X4(d_mul_6, d_mul_7, d_cg_1, d_1mfg_1);
+      d_mul_4 = AE_SRAA32SYMS(d_mul_4, shift2);
+      d_mul_5 = AE_SRAA32SYMS(d_mul_5, shift2);
+      d_mul_6 = AE_SRAA32SYMS(d_mul_6, shift2);
+      d_mul_7 = AE_SRAA32SYMS(d_mul_7, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_4, d_mul_5);
+      d_cg_1 = AE_SAT16X4(d_mul_6, d_mul_7);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      d_cs_w_1 = AE_ADD16S(d_cs_w_1, d_cg_1);
+
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      AE_MINMAX16(d_cs_w_1, d_min, d_max);
+
+      AE_SA16X4_IP(d_cs_w_0, align_cs_w, p16x4_cs_w);
+      AE_SA16X4_IP(d_cs_w_1, align_cs_w, p16x4_cs_w);
+    }
+    AE_SA64POS_FP(align_cs_w, p16x4_cs_w);  // finalize the stream
+
+    const ae_int16 *p16_cs_r, *p16_fg_r;
+    const ae_int16* p16_cg_r;
+
+    ae_int16* p16_cs_w;
+
+    p16_cs_r = (const ae_int16*)p16x4_cs_r;
+    p16_fg_r = (const ae_int16*)p16x4_fg_r;
+    p16_cg_r = (const ae_int16*)p16x4_cg_r;
+
+    p16_cs_w = (ae_int16*)p16x4_cs_w;
+    // residue iterations
+#pragma concurrent
+#pragma loop_count max = 7
+    for (i = 0; i < ((num_elms)&7); i++) {
+      d_cs_r_0 = p16_cs_r[i];
+      d_fg_0 = p16_fg_r[i];
+      d_cg_0 = p16_cg_r[i];
+
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cs_r_0, d_fg_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift1);
+      d_cs_w_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      d_1mfg_0 = AE_SUB16S(d_one, d_fg_0);
+      AE_MUL16X4(d_mul_0, d_mul_1, d_cg_0, d_1mfg_0);
+      d_mul_0 = AE_SRAA32SYMS(d_mul_0, shift2);
+      d_cg_0 = AE_SAT16X4(d_mul_0, d_mul_0);
+
+      d_cs_w_0 = AE_ADD16S(d_cs_w_0, d_cg_0);
+      AE_MINMAX16(d_cs_w_0, d_min, d_max);
+      p16_cs_w[i] = d_cs_w_0;
+    }
+  }
+}
+
+void xa_nn_elm_mul_16x16_asym8s(int8_t* output, const int16_t* input_1,
+                                const int16_t* input_2, int32_t multiplier,
+                                int32_t shift, int32_t zero_point,
+                                int num_elms) {
+  ae_int16x4* tmp_input_1;
+  ae_int16x4* tmp_input_2;
+
+  ae_valign align_src_input_1, align_src_input_2;
+
+  ae_int16x4 data_a_0, data_b_0;
+  ae_int32x2 data_ab_0, data_ab_1;
+  ae_int16x4 d_zp;
+  ae_int16x4 data_c_0;
+  ae_int16x4 d_min8 = AE_MOVDA16(-128);
+  ae_int16x4 d_max8 = AE_MOVDA16(127);
+
+  int i = 0;
+  int left_shift, right_shift;
+  tmp_input_1 = (ae_int16x4*)(input_1);
+  tmp_input_2 = (ae_int16x4*)(input_2);
+
+  align_src_input_1 = AE_LA64_PP((ae_int16x4*)tmp_input_1);
+  align_src_input_2 = AE_LA64_PP((ae_int16x4*)tmp_input_2);
+
+  d_zp = AE_MOVDA16(zero_point);
+
+#if TFLITE_SINGLE_ROUNDING
+  left_shift = shift;
+  (void)right_shift;
+#else  /* #if TFLITE_SINGLE_ROUNDING */
+  left_shift = shift < 0 ? 0 : shift;
+  right_shift = shift > 0 ? 0 : -shift;
+#endif /* #if TFLITE_SINGLE_ROUNDING */
+
+#pragma concurrent
+  for (i = 0; i < (num_elms >> 2); i++) {
+    AE_LA16X4_IP(data_a_0, align_src_input_1, tmp_input_1);
+    AE_LA16X4_IP(data_b_0, align_src_input_2, tmp_input_2);
+
+    AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
+    MPY_BY_QUANT_MULT_X2X2_OUT32(data_ab_0, data_ab_1, data_ab_0, data_ab_1,
+                                 multiplier, left_shift, right_shift);
+    data_c_0 = AE_SAT16X4(data_ab_0, data_ab_1);
+    data_c_0 = AE_SUB16S(data_c_0, d_zp);
+    AE_MINMAX16(data_c_0, d_min8, d_max8);
+
+    *output++ = AE_MOVAD16_3(data_c_0);
+    *output++ = AE_MOVAD16_2(data_c_0);
+    *output++ = AE_MOVAD16_1(data_c_0);
+    *output++ = AE_MOVAD16_0(data_c_0);
+  }
+
+  // residue iterations
+#pragma concurrent
+#pragma loop_count max = 3
+  for (int j = 0; j < ((num_elms)&3); j++) {
+    AE_L16_IP(data_a_0, (ae_int16*)tmp_input_1, 2);
+    AE_L16_IP(data_b_0, (ae_int16*)tmp_input_2, 2);
+
+    AE_MUL16X4(data_ab_0, data_ab_1, data_a_0, data_b_0);
+    MPY_BY_QUANT_MULT_X2_OUT32(data_ab_0, data_ab_0, multiplier, left_shift,
+                               right_shift);
+    data_c_0 = AE_SAT16X4(data_ab_0, data_ab_0);
+    data_c_0 = AE_SUB16S(data_c_0, d_zp);
+    AE_MINMAX16(data_c_0, d_min8, d_max8);
+
+    *output++ = AE_MOVAD16_0(data_c_0);
+  }
+}
 #endif  // defined(HIFI5)
 
-}  // namespace lstm_eval
-}  // namespace micro
-}  // namespace ops
 }  // namespace tflite
+
+#endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h b/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h
deleted file mode 100644
index 4bcff1a..0000000
--- a/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
-#define TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace lstm {
-// For full inputs kernel (24-inputs).
-// Please note the 20-input full kernel is deprecated and only kept
-// here for backward compatibility.
-namespace full {
-
-// Input Tensors of size {n_batch, n_input}
-constexpr int kInputTensor = 0;
-
-// Input weight tensors of size: {n_cell, n_input}
-constexpr int kInputToInputWeightsTensor = 1;  // Optional
-constexpr int kInputToForgetWeightsTensor = 2;
-constexpr int kInputToCellWeightsTensor = 3;
-constexpr int kInputToOutputWeightsTensor = 4;
-
-// Recurrent weight tensors of size {n_cell, n_output}
-constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
-constexpr int kRecurrentToForgetWeightsTensor = 6;
-constexpr int kRecurrentToCellWeightsTensor = 7;
-constexpr int kRecurrentToOutputWeightsTensor = 8;
-
-// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kCellToInputWeightsTensor = 9;    // Optional
-constexpr int kCellToForgetWeightsTensor = 10;  // Optional
-constexpr int kCellToOutputWeightsTensor = 11;  // Optional
-
-// Gates bias tensors of size {n_cell}
-constexpr int kInputGateBiasTensor = 12;  // Optional
-constexpr int kForgetGateBiasTensor = 13;
-constexpr int kCellGateBiasTensor = 14;
-constexpr int kOutputGateBiasTensor = 15;
-
-// Projection weight tensor of size {n_output, n_cell}
-constexpr int kProjectionWeightsTensor = 16;  // Optional
-// Projection bias tensor of size {n_output}
-constexpr int kProjectionBiasTensor = 17;  // Optional
-
-// These state tensors are defined as variable tensors, and will be modified by
-// this op.
-constexpr int kOutputStateTensor = 18;
-constexpr int kCellStateTensor = 19;
-
-// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
-// matrix.
-constexpr int kInputLayerNormCoefficientsTensor = 20;   // Optional
-constexpr int kForgetLayerNormCoefficientsTensor = 21;  // Optional
-constexpr int kCellLayerNormCoefficientsTensor = 22;    // Optional
-constexpr int kOutputLayerNormCoefficientsTensor = 23;  // Optional
-
-// Output tensors.
-constexpr int kOutputTensor = 0;
-}  // namespace full
-
-}  // namespace lstm
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
-#endif  // TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
index 3a71f94..da34e09 100644
--- a/tensorflow/lite/micro/kernels/xtensa/svdf.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -63,7 +63,7 @@
 #if defined(HIFI5)
   memcpy(state_ptr, state_ptr + 1, num_bytes);
 #else
-  xa_nn_memmove_16(state_ptr, state_ptr + 1, num_bytes);
+  xa_nn_memmove_16(state_ptr, state_ptr + 1, (num_bytes >> 1));
 #endif  // defined(HIFI5)
 
   // Note: no need to clear the latest activation, matmul is not accumulative.
diff --git a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
index 20b0ab8..44a9f86 100644
--- a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
@@ -183,19 +183,57 @@
   // Quantized kernels use an int32 scratch buffer.
   if (input->type == kTfLiteInt8) {
     TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+    const int stride_width = params->stride_width;
+    const int stride_height = params->stride_height;
+
+    const int input_height = SizeOfDimension(input, 1);
+    const int input_width = SizeOfDimension(input, 2);
+    const int input_depth = SizeOfDimension(input, 3);
+    const int output_height = height;
+    const int output_width = width;
+    int32_t scratch_buffer_size = 0;
+    scratch_buffer_size = xa_nn_transpose_conv_getsize(
+        input_height, input_width, input_depth, filter_height, filter_width,
+        stride_width, stride_height, output_height, output_width, num_channels,
+        PREC_SYM8S, PREC_ASYM8S);
+    TFLITE_DCHECK(context->RequestScratchBufferInArena(
+                      context, scratch_buffer_size,
+                      &(data->scratch_buffer_index)) == kTfLiteOk);
+#else  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
     TFLITE_DCHECK(context->RequestScratchBufferInArena(
                       context,
                       GetTensorShape(output).FlatSize() * sizeof(int32_t),
                       &(data->scratch_buffer_index)) == kTfLiteOk);
+#endif
   }
 
   // Quantized 16x8 kernels use an int64 scratch buffer.
   if (input->type == kTfLiteInt16) {
     TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+    const int stride_width = params->stride_width;
+    const int stride_height = params->stride_height;
+
+    const int input_height = SizeOfDimension(input, 1);
+    const int input_width = SizeOfDimension(input, 2);
+    const int input_depth = SizeOfDimension(input, 3);
+    const int output_height = height;
+    const int output_width = width;
+    int32_t scratch_buffer_size = 0;
+    scratch_buffer_size = xa_nn_transpose_conv_getsize(
+        input_height, input_width, input_depth, filter_height, filter_width,
+        stride_width, stride_height, output_height, output_width, num_channels,
+        PREC_SYM8S, PREC_SYM16S);
+    TFLITE_DCHECK(context->RequestScratchBufferInArena(
+                      context, scratch_buffer_size,
+                      &(data->scratch_buffer_index)) == kTfLiteOk);
+#else   // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
     TFLITE_DCHECK(context->RequestScratchBufferInArena(
                       context,
                       GetTensorShape(output).FlatSize() * sizeof(std::int64_t),
                       &(data->scratch_buffer_index)) == kTfLiteOk);
+#endif  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
   }
 
   // All per-channel quantized tensors need valid zero point and scale arrays.
@@ -282,6 +320,63 @@
     case kTfLiteInt8: {
       int32_t* scratch_buffer = static_cast<int32_t*>(
           context->GetScratchBuffer(context, data.scratch_buffer_index));
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
+      if (bias->type == kTfLiteInt32) {
+        const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+        const RuntimeShape& filter_shape =
+            tflite::micro::GetTensorShape(filter);
+        const RuntimeShape& output_shape =
+            tflite::micro::GetTensorShape(output);
+        const int stride_width = data.params.stride_width;
+        const int stride_height = data.params.stride_height;
+        const int pad_width = data.params.padding_values.width;
+        const int pad_height = data.params.padding_values.height;
+
+        const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+        const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+        const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+
+        const int input_height = input_shape.Dims(1);
+        const int input_width = input_shape.Dims(2);
+        const int filter_height = filter_shape.Dims(1);
+        const int filter_width = filter_shape.Dims(2);
+        const int output_height = output_shape.Dims(1);
+        const int output_width = output_shape.Dims(2);
+        const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+        const int8_t* filter_data =
+            tflite::micro::GetTensorData<int8_t>(filter);
+        const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+        int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
+
+        const int num_elements = output_shape.FlatSize();
+
+        for (int b = 0; b < batches; b++) {
+          xa_nn_transpose_conv_sym8sxasym8s(
+              &output_data[b * output_height * output_width * output_depth],
+              const_cast<WORD8*>(
+                  &input_data[b * input_height * input_width * input_depth]),
+              const_cast<WORD8*>(filter_data), const_cast<WORD32*>(bias_data),
+              stride_width, stride_height, pad_width, pad_height, input_depth,
+              output_depth, input_height, input_width, filter_height,
+              filter_width, output_height, output_width, num_elements / batches,
+              data.params.input_offset, data.params.output_offset,
+              data.per_channel_output_shift, data.per_channel_output_multiplier,
+              scratch_buffer);
+        }
+      } else {
+        reference_integer_ops::TransposeConv(
+            data.params, data.per_channel_output_multiplier,
+            data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<int8_t>(input),
+            tflite::micro::GetTensorShape(filter),
+            tflite::micro::GetTensorData<int8_t>(filter),
+            tflite::micro::GetTensorShape(bias),
+            tflite::micro::GetTensorData<int32_t>(bias),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output),
+            tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
+      }
+#else
       reference_integer_ops::TransposeConv(
           data.params, data.per_channel_output_multiplier,
           data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
@@ -293,6 +388,7 @@
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
+#endif
       break;
     }
     case kTfLiteInt16: {
@@ -319,7 +415,7 @@
             tflite::micro::GetTensorData<int16_t>(output),
             tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
       } else {
-#if defined(HIFI3) || defined(HIFI4)
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
         const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
         const RuntimeShape& filter_shape =
             tflite::micro::GetTensorShape(filter);
@@ -359,9 +455,9 @@
               output_depth, input_height, input_width, filter_height,
               filter_width, output_height, output_width, num_elements / batches,
               data.per_channel_output_shift, data.per_channel_output_multiplier,
-              &scratch_buffer[b * output_height * output_width * output_depth]);
+              scratch_buffer);
         }
-#else
+#else   // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
         reference_integer_ops::TransposeConv(
             data.params, data.per_channel_output_multiplier,
             data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
@@ -373,7 +469,7 @@
             tflite::micro::GetTensorShape(output),
             tflite::micro::GetTensorData<int16_t>(output),
             tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
-#endif  // defined(HIFI3) || defined(HIFI4)
+#endif  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
       }
       break;
     }
diff --git a/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc b/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
index cbce1e1..0f6a02e 100644
--- a/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,1109 +13,156 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <math.h>
-#include <stdio.h>
+// Integer version of unidirectional sequence lstm. Only the standard LSTM
+// (defined in the keras LSTM layer, e.g., no peephole etc.) is supported here.
+// Currently used by the 16 bits activation case only
 
-#include <cstddef>
+#include <algorithm>
+#include <limits>
 
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/lstm_shared.h"
 #include "tensorflow/lite/micro/kernels/xtensa/lstm_eval.h"
-#include "tensorflow/lite/micro/kernels/xtensa/lstm_shared.h"
-#include "tensorflow/lite/micro/micro_log.h"
 
-// TODO(b/230666079): Flatten the namespace to match the builtin kernel
-// implementation
 namespace tflite {
-namespace ops {
-namespace micro {
-// namespace unidirectional_sequence_lstm {
+
 namespace {
+/*Helper Functions*/
 
-struct OpData {
-  // If the lstm is layer norm.
-  bool use_layer_norm;
-  // The scratch tensor index.
-  int scratch_tensor_index;
-  bool compute_row_sums = false;
+/*Kernel functions*/
 
-  lstm_eval::IntegerLstmParameter integer_lstm_param;
-};
+void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer,
+                                     size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataLSTM));
+}
 
-TfLiteStatus PopulateQuantizedLstmParams8x8_16(
-    TfLiteContext* context, TfLiteNode* node,
-    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
-  // Calculate quantized clip for projection and cell.
-  const auto* params =
+TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
+                                               TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  TFLITE_DCHECK(node->user_data != nullptr);
+
+  OpDataLSTM* op_data = reinterpret_cast<OpDataLSTM*>(node->user_data);
+  const auto* builtin_data =
       static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
-  const float cell_clip = static_cast<float>(params->cell_clip);
-  const float proj_clip = static_cast<float>(params->proj_clip);
+  // All TempTfLiteTensors will be deallocated through the destructor.
+  LstmTensors lstm_tensors(context, node);
+  TF_LITE_ENSURE_OK(context, lstm_tensors.ValidateTensorStatus(context));
 
-  const TfLiteTensor* cell_state =
-      GetVariableInput(context, node, micro::lstm::full::kCellStateTensor);
-  TF_LITE_ENSURE(context, cell_state != nullptr);
-  TfLiteTensor* output_tensor;
+  op_data->cell_gate_nonlinear_type = builtin_data->activation;
+  op_data->size_info =
+      CreateLstmSizeInfo(builtin_data->time_major,
+                         lstm_tensors.GetInternalTensor(kLstmInputTensor)->dims,
+                         lstm_tensors.HiddenStateTensor()->dims);
   TF_LITE_ENSURE_OK(
-      context, GetOutputSafe(context, node, micro::lstm::full::kOutputTensor,
-                             &output_tensor));
+      context, ValidateTensorSize(context, lstm_tensors, op_data->size_info));
 
-  auto* cell_state_params =
-      static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
-  auto* proj_params = static_cast<TfLiteAffineQuantization*>(
-      output_tensor->quantization.params);
-  if (cell_clip > static_cast<float>(0.0)) {
-    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
-        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
-        32767.0f));
+  // Create cell state information and gate parameters (Fully Connected and Mul)
+  auto cell_state_type =
+      lstm_tensors.GetInternalTensor(kLstmCellStateTensor)->type;
+  if (cell_state_type == kTfLiteFloat32) {
+    op_data->cell_state_info =
+        CreateLstmCellStateInfoFloat(builtin_data->cell_clip);
+    TF_LITE_ENSURE_OK(
+        context, PrepareGateParametersFloat(context, lstm_tensors, op_data));
+  } else if (cell_state_type == kTfLiteInt16) {
+    op_data->cell_state_info = CreateLstmCellStateInfo(
+        lstm_tensors.CellStateTensor()->params.scale, builtin_data->cell_clip);
+    TF_LITE_ENSURE_OK(
+        context, PrepareGateParametersInteger(context, lstm_tensors, op_data));
   } else {
-    integer_lstm_param->quantized_cell_clip = 0;
+    MicroPrintf(
+        "Cell state type %s (%d) not supported. The quantized Unidirectional "
+        "Sequence LSTM Op only support int16 cell state",
+        TfLiteTypeGetName(cell_state_type), cell_state_type);
+    return kTfLiteError;
   }
-  if (proj_clip > static_cast<float>(0.0)) {
-    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
-        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
-  } else {
-    integer_lstm_param->quantized_proj_clip = 0;
+  // request buffers (four buffers)
+  for (size_t i = 0; i < 4; i++) {
+    TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
+                                   context,
+                                   op_data->size_info.batch_size *
+                                       op_data->size_info.state_dimension *
+                                       TfLiteTypeGetSize(cell_state_type),
+                                   &(op_data->buffer_indices[i])));
   }
+  return kTfLiteOk;
+}
 
-  // Calculate effective scales.
-  OpData* op_data = static_cast<OpData*>(node->user_data);
-  const bool use_layer_norm = op_data->use_layer_norm;
+TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context,
+                                            TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpDataLSTM& op_data = *reinterpret_cast<OpDataLSTM*>(node->user_data);
+  auto kernel_content = CreateLSTMKernelContent(context, node);
 
-  const TfLiteTensor* input;
-  TF_LITE_ENSURE_OK(
-      context,
-      GetInputSafe(context, node, micro::lstm::full::kInputTensor, &input));
+  const auto activation_type =
+      kernel_content.internal_tensors[kLstmInputTensor]->type;
+  const auto weight_type =
+      kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;
 
-  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node,
-                                 micro::lstm::full::kInputToForgetWeightsTensor,
-                                 &input_to_forget_weights));
-  const TfLiteTensor* input_to_cell_weights;
-  TF_LITE_ENSURE_OK(
-      context,
-      GetInputSafe(context, node, micro::lstm::full::kInputToCellWeightsTensor,
-                   &input_to_cell_weights));
-  const TfLiteTensor* input_to_output_weights;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node,
-                                 micro::lstm::full::kInputToOutputWeightsTensor,
-                                 &input_to_output_weights));
-
-  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToForgetWeightsTensor,
-                            &recurrent_to_forget_weights));
-  const TfLiteTensor* recurrent_to_cell_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToCellWeightsTensor,
-                            &recurrent_to_cell_weights));
-  const TfLiteTensor* recurrent_to_output_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToOutputWeightsTensor,
-                            &recurrent_to_output_weights));
-
-  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellToInputWeightsTensor);
-  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellToForgetWeightsTensor);
-  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellToOutputWeightsTensor);
-
-  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kInputLayerNormCoefficientsTensor);
-  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kForgetLayerNormCoefficientsTensor);
-  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellLayerNormCoefficientsTensor);
-  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kOutputLayerNormCoefficientsTensor);
-
-  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kProjectionWeightsTensor);
-
-  TfLiteTensor* output_state =
-      GetVariableInput(context, node, micro::lstm::full::kOutputStateTensor);
-  TF_LITE_ENSURE(context, output_state != nullptr);
-
-  // Since we have already checked that weights are all there or none, we can
-  // check the existence of only one to get the condition.
-  const bool use_cifg = (input_to_input_weights == nullptr);
-  const bool use_peephole = (cell_to_output_weights != nullptr);
-  const bool use_projection = (projection_weights != nullptr);
-
-  // Get intermediate scales and zero points.
-  constexpr size_t kIntermediateCount = 5;
-  float intermediate_scale[kIntermediateCount];
-  int32_t intermediate_zp[kIntermediateCount];
-  for (int i = 0; i < 4; ++i) {
-    if (use_layer_norm) {
-      TfLiteTensor* intermediate =
-          context->GetTensor(context, node->intermediates->data[i]);
-      auto* tmp_params = static_cast<TfLiteAffineQuantization*>(
-          intermediate->quantization.params);
-      intermediate_scale[i] = tmp_params->scale->data[0];
-      intermediate_zp[i] = tmp_params->zero_point->data[0];
-    } else {
-      // Q3.12 for activation functions.
-      intermediate_scale[i] = std::pow(2, -12);
-      intermediate_zp[i] = 0;
+  switch (activation_type) {
+    case kTfLiteFloat32: {
+      LSTMBuffers<float> buffers =
+          CreateLSTMBuffers<float>(context, op_data.buffer_indices);
+      EvalLstm<float, float, float, float>(op_data, kernel_content, buffers);
+      break;
+    }
+    case kTfLiteInt8: {
+      switch (weight_type) {
+        case kTfLiteInt8: {
+          // 8(activation)x8(weight)->16(cell) LSTM with 32 bits bias
+          LSTMBuffers<int16_t> buffers =
+              CreateLSTMBuffers<int16_t>(context, op_data.buffer_indices);
+          EvalLstm<int8_t, int8_t, int16_t, int32_t>(op_data, kernel_content,
+                                                     buffers);
+          break;
+        }
+        default: {
+          MicroPrintf("Filter type %s (%d) not supported.",
+                      TfLiteTypeGetName(weight_type), activation_type);
+          return kTfLiteError;
+        }
+      }
+      break;
+    }
+    case kTfLiteInt16: {
+      switch (weight_type) {
+        case kTfLiteInt8: {
+          // 16(activation)x8(weight)->16(cell) LSTM with 64 bits bias
+          LSTMBuffers<int16_t> buffers =
+              CreateLSTMBuffers<int16_t>(context, op_data.buffer_indices);
+          EvalLstm<int16_t, int8_t, int16_t, int64_t>(op_data, kernel_content,
+                                                      buffers);
+          break;
+        }
+        default: {
+          MicroPrintf("Filter type %s (%d) not supported.",
+                      TfLiteTypeGetName(weight_type), weight_type);
+          return kTfLiteError;
+        }
+      }
+      break;
+    }
+    default: {
+      MicroPrintf("Input type %s (%d) not supported.",
+                  TfLiteTypeGetName(activation_type), activation_type);
+      return kTfLiteError;
     }
   }
-  // In the absence of projection, hidden becomes otuput and this intermediate
-  // is ignored.
-  TfLiteTensor* hidden =
-      context->GetTensor(context, node->intermediates->data[4]);
-  auto* hidden_params =
-      static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
-  intermediate_scale[4] = hidden_params->scale->data[0];
-  intermediate_zp[4] = hidden_params->zero_point->data[0];
-
-  // Scales.
-  const float default_scale = 1.0;
-  float input_scale = default_scale;
-  float input_to_input_weight_scale = default_scale;
-  float recurrent_to_input_weight_scale = default_scale;
-  float cell_to_input_weight_scale = default_scale;
-  float input_to_forget_weight_scale = default_scale;
-  float recurrent_to_forget_weight_scale = default_scale;
-  float cell_to_forget_weight_scale = default_scale;
-  float input_to_cell_weight_scale = default_scale;
-  float recurrent_to_cell_weight_scale = default_scale;
-  float input_to_output_weight_scale = default_scale;
-  float recurrent_to_output_weight_scale = default_scale;
-  float cell_to_output_weight_scale = default_scale;
-  float projection_weight_scale = default_scale;
-  float layer_norm_input_scale = default_scale;
-  float layer_norm_forget_scale = default_scale;
-  float layer_norm_cell_scale = default_scale;
-  float layer_norm_output_scale = default_scale;
-  float output_state_scale = default_scale;
-  int cell_scale = 1;
-
-  // Effective scales.
-  float effective_input_to_input_scale = default_scale;
-  float effective_recurrent_to_input_scale = default_scale;
-  float effective_cell_to_input_scale = default_scale;
-  float effective_input_to_forget_scale = default_scale;
-  float effective_recurrent_to_forget_scale = default_scale;
-  float effective_cell_to_forget_scale = default_scale;
-  float effective_input_to_cell_scale = default_scale;
-  float effective_recurrent_to_cell_scale = default_scale;
-  float effective_input_to_output_scale = default_scale;
-  float effective_recurrent_to_output_scale = default_scale;
-  float effective_cell_to_output_scale = default_scale;
-  float effective_proj_scale = default_scale;
-  float effective_hidden_scale = default_scale;
-
-  // Populate scales.
-  if (!use_cifg) {
-    input_to_input_weight_scale = input_to_input_weights->params.scale;
-    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
-  }
-
-  if (use_peephole) {
-    if (!use_cifg) {
-      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
-    }
-    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
-    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
-  }
-
-  if (use_layer_norm) {
-    if (!use_cifg) {
-      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
-    }
-    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
-    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
-    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
-  }
-
-  if (use_projection) {
-    projection_weight_scale = projection_weights->params.scale;
-  }
-  output_state_scale = output_state->params.scale;
-
-  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
-  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
-  input_to_output_weight_scale = input_to_output_weights->params.scale;
-  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
-  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
-  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
-
-  // Check cell state (already used above)
-  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
-  // TF_LITE_ENSURE(context, cell_scale <= -9);
-  integer_lstm_param->cell_scale = cell_scale;
-  input_scale = input->params.scale;
-
-  // Calculate effective scales.
-  if (!use_cifg) {
-    effective_input_to_input_scale =
-        input_to_input_weight_scale * input_scale / intermediate_scale[0];
-    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
-                                         output_state_scale /
-                                         intermediate_scale[0];
-  }
-  effective_input_to_forget_scale =
-      input_to_forget_weight_scale * input_scale / intermediate_scale[1];
-  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
-                                        output_state_scale /
-                                        intermediate_scale[1];
-
-  effective_input_to_cell_scale =
-      input_to_cell_weight_scale * input_scale / intermediate_scale[2];
-  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
-                                      output_state_scale /
-                                      intermediate_scale[2];
-
-  effective_input_to_output_scale =
-      input_to_output_weight_scale * input_scale / intermediate_scale[3];
-  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
-                                        output_state_scale /
-                                        intermediate_scale[3];
-
-  effective_hidden_scale = std::pow((float)2, (float)-15) /
-                           intermediate_scale[4] *
-                           std::pow((float)2, (float)-15);
-
-  effective_proj_scale =
-      projection_weight_scale * intermediate_scale[4] / output_state_scale;
-
-  if (use_peephole) {
-    if (!use_cifg) {
-      effective_cell_to_input_scale =
-          std::pow((float)(2), (float)cell_scale) *  // NOLINT
-          (float)(cell_to_input_weight_scale) / intermediate_scale[0];
-    }
-    effective_cell_to_forget_scale =
-        std::pow((float)2, (float)cell_scale) *  // NOLINT
-        (float)cell_to_forget_weight_scale / intermediate_scale[1];
-    effective_cell_to_output_scale =
-        std::pow((float)2, (float)cell_scale) *  // NOLINT
-        (float)cell_to_output_weight_scale / intermediate_scale[3];
-  }
-
-  // Decompose scales.
-  QuantizeMultiplier(static_cast<double>(effective_input_to_input_scale),
-                     &integer_lstm_param->effective_input_to_input_scale_a,
-                     &integer_lstm_param->effective_input_to_input_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_recurrent_to_input_scale),
-                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
-                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_cell_to_input_scale),
-                     &integer_lstm_param->effective_cell_to_input_scale_a,
-                     &integer_lstm_param->effective_cell_to_input_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_input_to_forget_scale),
-                     &integer_lstm_param->effective_input_to_forget_scale_a,
-                     &integer_lstm_param->effective_input_to_forget_scale_b);
-  QuantizeMultiplier(
-      static_cast<double>(effective_recurrent_to_forget_scale),
-      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
-      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_cell_to_forget_scale),
-                     &integer_lstm_param->effective_cell_to_forget_scale_a,
-                     &integer_lstm_param->effective_cell_to_forget_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_input_to_cell_scale),
-                     &integer_lstm_param->effective_input_to_cell_scale_a,
-                     &integer_lstm_param->effective_input_to_cell_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_recurrent_to_cell_scale),
-                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
-                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_input_to_output_scale),
-                     &integer_lstm_param->effective_input_to_output_scale_a,
-                     &integer_lstm_param->effective_input_to_output_scale_b);
-  QuantizeMultiplier(
-      static_cast<double>(effective_recurrent_to_output_scale),
-      &integer_lstm_param->effective_recurrent_to_output_scale_a,
-      &integer_lstm_param->effective_recurrent_to_output_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_cell_to_output_scale),
-                     &integer_lstm_param->effective_cell_to_output_scale_a,
-                     &integer_lstm_param->effective_cell_to_output_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_proj_scale),
-                     &integer_lstm_param->effective_proj_scale_a,
-                     &integer_lstm_param->effective_proj_scale_b);
-  QuantizeMultiplier(static_cast<double>(effective_hidden_scale),
-                     &integer_lstm_param->effective_hidden_scale_a,
-                     &integer_lstm_param->effective_hidden_scale_b);
-  QuantizeMultiplier(static_cast<double>(layer_norm_input_scale),
-                     &integer_lstm_param->layer_norm_input_scale_a,
-                     &integer_lstm_param->layer_norm_input_scale_b);
-  QuantizeMultiplier(static_cast<double>(layer_norm_forget_scale),
-                     &integer_lstm_param->layer_norm_forget_scale_a,
-                     &integer_lstm_param->layer_norm_forget_scale_b);
-  QuantizeMultiplier(static_cast<double>(layer_norm_cell_scale),
-                     &integer_lstm_param->layer_norm_cell_scale_a,
-                     &integer_lstm_param->layer_norm_cell_scale_b);
-  QuantizeMultiplier(static_cast<double>(layer_norm_output_scale),
-                     &integer_lstm_param->layer_norm_output_scale_a,
-                     &integer_lstm_param->layer_norm_output_scale_b);
-
-  integer_lstm_param->hidden_zp = intermediate_zp[4];
-
-  // 10000 is used to make sure the kernel logic does not overflow.
-  if (!use_cifg) {
-    integer_lstm_param->input_variance_guard =
-        std::max(static_cast<int32_t>(1),
-                 static_cast<int32_t>(10000 * layer_norm_input_scale));
-  }
-  integer_lstm_param->forget_variance_guard =
-      std::max(static_cast<int32_t>(1),
-               static_cast<int32_t>(10000 * layer_norm_forget_scale));
-  integer_lstm_param->cell_variance_guard =
-      std::max(static_cast<int32_t>(1),
-               static_cast<int32_t>(10000 * layer_norm_cell_scale));
-  integer_lstm_param->output_variance_guard =
-      std::max(static_cast<int32_t>(1),
-               static_cast<int32_t>(10000 * layer_norm_output_scale));
-
   return kTfLiteOk;
 }
 
 }  // namespace
 
-// Temporary tensors
-enum TemporaryTensor {
-  kScratchBuffer = 0,
-  kInputQuantized = 1,
-  kOutputStateQuantized = 2,
-  kCellStateQuantized = 3,
-  kInputScalingFactors = 4,
-  kOutputStateScalingFactors = 5,
-  kProductScalingFactors = 6,
-  kRecoveredCellWeights = 7,
-  kAccumScratch = 8,
-  kInputZeroPoints = 9,
-  kOutputStateZeroPoints = 10,
-  kRowSums = 11,
-  kNumTemporaryTensors = 12,
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  OpData* op_data = reinterpret_cast<OpData*>(
-      context->AllocatePersistentBuffer(context, sizeof(OpData)));
-
-  return op_data;
-}
-
-// Check that input tensor dimensions matches with each other.
-TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
-                                        TfLiteNode* node, int n_input,
-                                        int n_output, int n_cell,
-                                        bool use_layer_norm, bool is_integer) {
-  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
-
-  // Making sure clipping parameters have valid values.
-  // == 0 means no clipping
-  //  > 0 means clipping
-  TF_LITE_ENSURE(context, params->cell_clip >= 0);
-  TF_LITE_ENSURE(context, params->proj_clip >= 0);
-  const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToInputWeightsTensor);
-  if (input_to_input_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
-    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
-    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
-  }
-  const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToForgetWeightsTensor);
-
-  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
-  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
-  const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToCellWeightsTensor);
-
-  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
-  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
-  const TfLiteEvalTensor* recurrent_to_input_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
-  if (recurrent_to_input_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
-    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
-                      n_cell);
-    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
-                      n_output);
-  }
-  const TfLiteEvalTensor* recurrent_to_forget_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
-
-  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
-                    n_cell);
-  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
-                    n_output);
-  const TfLiteEvalTensor* recurrent_to_cell_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
-
-  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
-  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
-                    n_output);
-
-  // We make sure the input-gate's parameters are either both present (regular
-  // LSTM) or not at all (CIFG-LSTM).
-  const bool cifg_weights_all_or_none =
-      ((input_to_input_weights != nullptr) &&
-       (recurrent_to_input_weights != nullptr)) ||
-      ((input_to_input_weights == nullptr) &&
-       (recurrent_to_input_weights == nullptr));
-  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
-
-  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellToInputWeightsTensor);
-  if (cell_to_input_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
-    TF_LITE_ENSURE_TYPES_EQ(
-        context, cell_to_input_weights->type,
-        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
-  }
-
-  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
-      context, node, lstm::full::kCellToForgetWeightsTensor);
-  if (cell_to_forget_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
-    TF_LITE_ENSURE_TYPES_EQ(
-        context, cell_to_forget_weights->type,
-        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
-  }
-
-  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kCellToOutputWeightsTensor);
-  if (cell_to_output_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
-    TF_LITE_ENSURE_TYPES_EQ(
-        context, cell_to_output_weights->type,
-        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
-  }
-
-  // Making sure the peephole weights are there all or none.
-  const bool use_cifg = (input_to_input_weights == nullptr);
-  const bool peephole_weights_all_or_none =
-      ((cell_to_input_weights != nullptr || use_cifg) &&
-       (cell_to_forget_weights != nullptr) &&
-       (cell_to_output_weights != nullptr)) ||
-      ((cell_to_input_weights == nullptr) &&
-       (cell_to_forget_weights == nullptr) &&
-       (cell_to_output_weights == nullptr));
-  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
-  const TfLiteEvalTensor* input_gate_bias = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputGateBiasTensor);
-
-  if (use_cifg) {
-    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
-  } else {
-    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
-    if (is_integer) {
-      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
-    } else {
-      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
-    }
-  }
-  const TfLiteEvalTensor* forget_gate_bias = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kForgetGateBiasTensor);
-
-  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
-  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
-  if (is_integer) {
-    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
-  } else {
-    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
-  }
-  const TfLiteEvalTensor* cell_gate_bias = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kCellGateBiasTensor);
-
-  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
-  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
-  if (is_integer) {
-    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
-  } else {
-    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
-  }
-  const TfLiteEvalTensor* output_gate_bias = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kOutputGateBiasTensor);
-  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
-  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
-  if (is_integer) {
-    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
-  } else {
-    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
-  }
-
-  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kProjectionWeightsTensor);
-  if (projection_weights != nullptr) {
-    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
-    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
-    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
-  }
-
-  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kProjectionBiasTensor);
-  if (projection_bias != nullptr) {
-    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
-    if (is_integer) {
-      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
-    } else {
-      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
-    }
-  }
-
-  // Making sure the projection tensors are consistent:
-  // 1) If projection weight is not present, then projection bias should not be
-  // present.
-  // 2) If projection weight is present, then projection bias is optional.
-  const bool projecton_tensors_consistent =
-      ((projection_weights != nullptr) || (projection_bias == nullptr));
-  TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
-
-  if (use_layer_norm) {
-    const TfLiteEvalTensor* input_layer_norm_coefficients =
-        tflite::micro::GetEvalInput(
-            context, node,
-            micro::lstm::full::kInputLayerNormCoefficientsTensor);
-    if (use_cifg) {
-      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
-    } else {
-      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
-      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
-      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
-                        n_cell);
-      if (is_integer) {
-        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
-                                kTfLiteInt16);
-      } else {
-        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
-                                kTfLiteFloat32);
-      }
-    }
-    const TfLiteEvalTensor* forget_layer_norm_coefficients =
-        tflite::micro::GetEvalInput(
-            context, node,
-            micro::lstm::full::kForgetLayerNormCoefficientsTensor);
-    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
-                      n_cell);
-    if (is_integer) {
-      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
-                              kTfLiteInt16);
-    } else {
-      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
-                              kTfLiteFloat32);
-    }
-    const TfLiteEvalTensor* cell_layer_norm_coefficients =
-        tflite::micro::GetEvalInput(
-            context, node, micro::lstm::full::kCellLayerNormCoefficientsTensor);
-    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
-                      n_cell);
-    if (is_integer) {
-      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
-                              kTfLiteInt16);
-    } else {
-      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
-                              kTfLiteFloat32);
-    }
-    const TfLiteEvalTensor* output_layer_norm_coefficients =
-        tflite::micro::GetEvalInput(
-            context, node,
-            micro::lstm::full::kOutputLayerNormCoefficientsTensor);
-
-    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
-    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
-                      n_cell);
-    if (is_integer) {
-      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
-                              kTfLiteInt16);
-    } else {
-      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
-                              kTfLiteFloat32);
-    }
-  }
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
-    TfLiteContext* context, int32_t zero_point,
-    const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
-    std::unique_ptr<int32_t[]>* output) {
-  if (weight_tensor == nullptr) {
-    return kTfLiteOk;
-  }
-
-  const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
-  TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
-  const int row = weight_shape.Dims(0);
-  const int col = weight_shape.Dims(1);
-  output->reset(new int32_t[row]);
-  if (bias_tensor == nullptr) {
-    memset(output->get(), 0, row * sizeof(int32_t));
-  } else {
-    const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
-    memcpy(output->get(), bias, row * sizeof(int32_t));
-  }
-  if (zero_point != 0) {
-    const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
-    tensor_utils::PortableMatrixScalarMultiplyAccumulate(
-        weight, zero_point, row, col, output->get());
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
-                                                       OpData* op_data,
-                                                       TfLiteNode* node) {
-  const TfLiteTensor* input;
-  TF_LITE_ENSURE_OK(
-      context,
-      GetInputSafe(context, node, micro::lstm::full::kInputTensor, &input));
-  const TfLiteTensor* output_state =
-      GetVariableInput(context, node, micro::lstm::full::kOutputStateTensor);
-  TF_LITE_ENSURE(context, output_state != nullptr);
-
-  const int32_t input_zero_point = -input->params.zero_point;
-  const int32_t output_state_zero_point = -output_state->params.zero_point;
-
-  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node,
-                                 micro::lstm::full::kInputToForgetWeightsTensor,
-                                 &input_to_forget_weights));
-  const TfLiteTensor* input_to_cell_weights;
-  TF_LITE_ENSURE_OK(
-      context,
-      GetInputSafe(context, node, micro::lstm::full::kInputToCellWeightsTensor,
-                   &input_to_cell_weights));
-  const TfLiteTensor* input_to_output_weights;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node,
-                                 micro::lstm::full::kInputToOutputWeightsTensor,
-                                 &input_to_output_weights));
-
-  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToForgetWeightsTensor,
-                            &recurrent_to_forget_weights));
-  const TfLiteTensor* recurrent_to_cell_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToCellWeightsTensor,
-                            &recurrent_to_cell_weights));
-  const TfLiteTensor* recurrent_to_output_weights;
-  TF_LITE_ENSURE_OK(
-      context, GetInputSafe(context, node,
-                            micro::lstm::full::kRecurrentToOutputWeightsTensor,
-                            &recurrent_to_output_weights));
-
-  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
-      context, node, lstm::full::kProjectionWeightsTensor);
-  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
-      context, node, micro::lstm::full::kProjectionBiasTensor);
-
-  lstm_eval::IntegerLstmParameter* integer_lstm_params =
-      &op_data->integer_lstm_param;
-
-  TfLiteTensor* intermediate =
-      context->GetTensor(context, node->intermediates->data[4]);
-  const auto* params =
-      static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
-  const int32_t hidden_zp = params->zero_point->data[0];
-
-  // Get bias and perform zero point calculation.
-  // When there is layer normalization, the gate bias does not apply to matmul
-  // directly:
-  //      y = ln(w * x + w * r + w * c) + b.
-  const bool is_layer_norm = op_data->use_layer_norm;
-
-  // Forget gate.
-  const TfLiteTensor* forget_gate_bias =
-      is_layer_norm
-          ? nullptr
-          : GetInput(context, node, micro::lstm::full::kForgetGateBiasTensor);
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, input_zero_point, input_to_forget_weights, forget_gate_bias,
-          &(integer_lstm_params->input_to_forget_effective_bias)));
-
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, output_state_zero_point, recurrent_to_forget_weights,
-          nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
-
-  // Modulation gate.
-  const TfLiteTensor* cell_gate_bias =
-      is_layer_norm
-          ? nullptr
-          : GetInput(context, node, micro::lstm::full::kCellGateBiasTensor);
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, input_zero_point, input_to_cell_weights, cell_gate_bias,
-          &(integer_lstm_params->input_to_cell_effective_bias)));
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
-          &(integer_lstm_params->recurrent_to_cell_effective_bias)));
-
-  // Output gate.
-  const TfLiteTensor* output_gate_bias =
-      is_layer_norm
-          ? nullptr
-          : GetInput(context, node, micro::lstm::full::kOutputGateBiasTensor);
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, input_zero_point, input_to_output_weights, output_gate_bias,
-          &(integer_lstm_params->input_to_output_effective_bias)));
-
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, output_state_zero_point, recurrent_to_output_weights,
-          nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
-
-  // Input gate. The calculation is only meaningful for non-cifg case.
-  const TfLiteTensor* input_gate_bias =
-      is_layer_norm
-          ? nullptr
-          : GetInput(context, node, micro::lstm::full::kInputGateBiasTensor);
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, input_zero_point, input_to_input_weights, input_gate_bias,
-          &(integer_lstm_params->input_to_input_effective_bias)));
-  TF_LITE_ENSURE_OK(
-      context,
-      PrecomputeZeroPointTimesWeightWithBias(
-          context, output_state_zero_point, recurrent_to_input_weights, nullptr,
-          &(integer_lstm_params->recurrent_to_input_effective_bias)));
-
-  // Projection bias. The calculation is only meaningful for with projection.
-  TF_LITE_ENSURE_OK(context,
-                    PrecomputeZeroPointTimesWeightWithBias(
-                        context, hidden_zp, projection_weights, projection_bias,
-                        &(integer_lstm_params->projection_effective_bias)));
-  return kTfLiteOk;
-}
-
-// Resize the output and  state tensors based on the sizes of the input tensors.
-// Allocate a temporary scratch tensor. Also check that the sizes of the input
-// tensors match each other.
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-  // const int scratch_tensor_index = op_data->scratch_tensor_index;
-
-  // Check we have all the inputs and outputs we need.
-  bool use_layer_norm = false;
-  if (node->inputs->size == 24) {
-    const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
-        context, node, micro::lstm::full::kForgetLayerNormCoefficientsTensor);
-    if (forget_layer_norm_coefficients == nullptr) {
-      use_layer_norm = false;
-    } else {
-      use_layer_norm = true;
-    }
-  } else if (node->inputs->size == 20) {
-    // This is deprecated and is only kept here for backward compatibility.
-    use_layer_norm = false;
-  } else {
-    MicroPrintf("The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
-                node->inputs->size);
-    return kTfLiteError;
-  }
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  op_data->use_layer_norm = use_layer_norm;
-
-  // Inferring batch size, number of outputs and sequence length and
-  // number of cells from the input tensors.
-  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputTensor);
-  const bool is_integer = input->type == kTfLiteInt8;
-  TF_LITE_ENSURE(context, input->dims->size > 1);
-  const auto* params =
-      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
-          node->builtin_data);
-  const bool time_major = params->time_major;
-  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
-  const int n_input = input->dims->data[2];
-  const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToOutputWeightsTensor);
-  const int n_cell = input_to_output_weights->dims->data[0];
-  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
-  const TfLiteEvalTensor* recurrent_to_output_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);
-
-  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
-                    n_cell);
-  const int n_output = recurrent_to_output_weights->dims->data[1];
-
-  // Check that input tensor dimensions matches with each other.
-  TF_LITE_ENSURE_OK(
-      context, CheckInputTensorDimensions(context, node, n_input, n_output,
-                                          n_cell, use_layer_norm, is_integer));
-  // Get the pointer to output, output_state and cell_state buffer tensors.
-  //  TfLiteEvalTensor* output =
-  //      tflite::micro::GetEvalOutput(context, node,
-  //      micro::lstm::full::kOutputTensor);
-  TfLiteEvalTensor* output_state = tflite::micro::GetMutableEvalInput(
-      context, node, micro::lstm::full::kOutputStateTensor);
-  TFLITE_DCHECK(output_state != nullptr);
-  TfLiteEvalTensor* cell_state = tflite::micro::GetMutableEvalInput(
-      context, node, micro::lstm::full::kCellStateTensor);
-  TFLITE_DCHECK(cell_state != nullptr);
-  // Check the shape of input state tensors.
-  // These tensor may be 1D or 2D. It's fine as long as the total size is
-  // correct.
-  TF_LITE_ENSURE_EQ(context, NumElements(output_state->dims),
-                    n_batch * n_output);
-  TF_LITE_ENSURE_EQ(context, NumElements(cell_state->dims), n_batch * n_cell);
-
-  if (is_integer) {
-    const int num_intermediate_tensors = node->intermediates->size;
-    TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
-  }
-
-  if (is_integer) {
-    // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
-    // This code path needs 5 intermediate tensors per Op.
-    // Populate quantization parameters.
-    PopulateQuantizedLstmParams8x8_16(context, node,
-                                      &op_data->integer_lstm_param);
-    // Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell
-    // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
-    // buffer with size n_batch * n_cell.
-    //
-    // Handle cifg case as well, which might save one buffer.
-
-    int scratch_idx = 0;
-
-    context->RequestScratchBufferInArena(
-        context, n_batch * n_cell * sizeof(int32_t), &(scratch_idx));
-    op_data->scratch_tensor_index = scratch_idx;
-
-    for (int scratch_index = 1; scratch_index < 6; ++scratch_index) {
-      // node->temporaries->data[scratch_index] = op_data->scratch_tensor_index
-      // + scratch_index;
-      context->RequestScratchBufferInArena(
-          context, n_batch * n_cell * sizeof(int32_t), &(scratch_idx));
-      TFLITE_DCHECK(scratch_idx ==
-                    (op_data->scratch_tensor_index + scratch_index));
-    }
-
-    // Populate precomputed zp * weight.
-    TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
-                                   context, op_data, node));
-  }
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const auto* params =
-      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
-          node->builtin_data);
-  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-  //  const bool use_layer_norm = op_data->use_layer_norm;
-  //  const bool time_major = params->time_major;
-
-  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputTensor);
-  const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToInputWeightsTensor);
-  const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToForgetWeightsTensor);
-  const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToCellWeightsTensor);
-  const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
-      context, node, micro::lstm::full::kInputToOutputWeightsTensor);
-  const TfLiteEvalTensor* recurrent_to_input_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
-  const TfLiteEvalTensor* recurrent_to_forget_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
-  const TfLiteEvalTensor* recurrent_to_cell_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
-  const TfLiteEvalTensor* recurrent_to_output_weights =
-      tflite::micro::GetEvalInput(
-          context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);
-  const TfLiteEvalTensor* cell_to_input_weights = context->GetEvalTensor(
-      context,
-      node->inputs->data[micro::lstm::full::kCellToInputWeightsTensor]);
-  const TfLiteEvalTensor* cell_to_forget_weights = context->GetEvalTensor(
-      context,
-      node->inputs->data[micro::lstm::full::kCellToForgetWeightsTensor]);
-  const TfLiteEvalTensor* cell_to_output_weights = context->GetEvalTensor(
-      context,
-      node->inputs->data[micro::lstm::full::kCellToOutputWeightsTensor]);
-  const TfLiteEvalTensor* input_gate_bias = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kInputGateBiasTensor]);
-
-  const TfLiteEvalTensor* forget_gate_bias = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kForgetGateBiasTensor]);
-  const TfLiteEvalTensor* cell_gate_bias = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kCellGateBiasTensor]);
-  const TfLiteEvalTensor* output_gate_bias = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kOutputGateBiasTensor]);
-
-  const TfLiteEvalTensor* projection_weights = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kProjectionWeightsTensor]);
-  const TfLiteEvalTensor* projection_bias = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kProjectionBiasTensor]);
-
-  TfLiteEvalTensor* output_state = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kOutputStateTensor]);
-  TFLITE_DCHECK(output_state != nullptr);
-  TfLiteEvalTensor* cell_state = context->GetEvalTensor(
-      context, node->inputs->data[micro::lstm::full::kCellStateTensor]);
-  TFLITE_DCHECK(cell_state != nullptr);
-  const TfLiteEvalTensor* input_layer_norm_coefficients =
-      context->GetEvalTensor(
-          context,
-          node->inputs
-              ->data[micro::lstm::full::kInputLayerNormCoefficientsTensor]);
-
-  const TfLiteEvalTensor* forget_layer_norm_coefficients =
-      context->GetEvalTensor(
-          context,
-          node->inputs
-              ->data[micro::lstm::full::kForgetLayerNormCoefficientsTensor]);
-  const TfLiteEvalTensor* cell_layer_norm_coefficients = context->GetEvalTensor(
-      context,
-      node->inputs->data[micro::lstm::full::kCellLayerNormCoefficientsTensor]);
-
-  const TfLiteEvalTensor* output_layer_norm_coefficients =
-      context->GetEvalTensor(
-          context,
-          node->inputs
-              ->data[micro::lstm::full::kOutputLayerNormCoefficientsTensor]);
-
-  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(
-      context, node, micro::lstm::full::kOutputTensor);
-
-  // Copy out the LSTM specific params so they can be passed in the function.
-  TfLiteLSTMParams lstm_params;
-  lstm_params.activation = params->activation;
-  lstm_params.cell_clip = params->cell_clip;
-  lstm_params.proj_clip = params->proj_clip;
-  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
-  switch (input_to_output_weights->type) {
-    case kTfLiteInt8: {
-      const bool is_hybrid = input->type == kTfLiteFloat32;
-      if (is_hybrid) {
-        MicroPrintf(" hybrid type is not supported.");
-        return kTfLiteError;
-
-      } else {
-        TfLiteEvalTensor* scratch[6];
-        // Allocate scratch buffer. Need 6 16bit buffer with size n_batch *
-        // n_cell
-        // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
-        // buffer with size n_batch * n_cell.
-        //
-        // Handle cifg case as well, which might save one buffer.
-
-        const auto* tmp_params =
-            reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
-                node->builtin_data);
-        const bool time_major = tmp_params->time_major;
-        for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
-          TFLITE_DCHECK(context != nullptr);
-          TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
-          int32_t* scratch_tensor =
-              static_cast<int32_t*>(context->GetScratchBuffer(
-                  context, op_data->scratch_tensor_index + scratch_index));
-          scratch[scratch_index] = (TfLiteEvalTensor*)scratch_tensor;
-        }
-        /*
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 0,
-           &scratch0));
-
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 1,
-           &scratch1));
-
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 2,
-           &scratch2));
-
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 3,
-           &scratch3));
-
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 4,
-           &scratch4));
-
-                                TF_LITE_ENSURE_OK(context,
-                                                GetScratchSafe(context, node, 5,
-           &scratch5));
-        */
-        return lstm_eval::EvalInteger8x8_16(
-            context, node, input, input_to_input_weights,
-            input_to_forget_weights, input_to_cell_weights,
-            input_to_output_weights, recurrent_to_input_weights,
-            recurrent_to_forget_weights, recurrent_to_cell_weights,
-            recurrent_to_output_weights, cell_to_input_weights,
-            cell_to_forget_weights, cell_to_output_weights,
-            input_layer_norm_coefficients, forget_layer_norm_coefficients,
-            cell_layer_norm_coefficients, output_layer_norm_coefficients,
-            input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
-            projection_weights, projection_bias, &lstm_params,
-            /*forward_sequence=*/true, time_major, &op_data->integer_lstm_param,
-            output_state, cell_state, output, scratch[0], scratch[1],
-            scratch[2], scratch[3], scratch[4], scratch[5]);
-      }
-    }
-
-    default:
-      MicroPrintf("Type %s is not currently supported.",
-                  TfLiteTypeGetName(input_to_output_weights->type));
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-//}  // namespace unidirectional_sequence_lstm
-
-}  // namespace micro
-}  // namespace ops
-
 TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
-  return tflite::micro::RegisterOp(ops::micro::Init, ops::micro::Prepare,
-                                   ops::micro::Eval);
+  return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
+                                   UnidirectionalSequenceLstmPrepare,
+                                   UnidirectionalSequenceLstmEval);
 }
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch b/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
index 125e3f7..118845c 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
+++ b/tensorflow/lite/micro/tools/make/ext_libs/ndsplib-hifi5.patch
@@ -12,3 +12,20 @@
    #define onchip
    #define NASSERT(x) {(void)__builtin_expect((x)!=0,1);}
  #else
+diff --git a/library/include_private/common.h b/library/include_private/common.h
+index 2eaf70f..9df811c 100644
+--- a/library/include_private/common.h
++++ b/library/include_private/common.h
+@@ -172,6 +172,12 @@ __pragma (warning(pop))
+ #if defined(COMPILER_XTENSA) || defined(COMPILER_GNU)
+ #define DISCARD_FUN(retval_type,name,arglist) \
+ __asm__(".type "#name", @object\n\t.global "#name"\n\t.align 4\n\t"#name":\n\t.long 0x49438B96,0x4D73F192\n\t");
++
++#define DISCARD_FUN_FOR_NONVOID_RETURN(retval_type,name,arglist) \
++__attribute__ ((section ("/DISCARD/"))) \
++retval_type name arglist \
++{ return (retval_type) 0; }
++
+ #endif
+ 
+ /*------ LIST OF DEFINES DEPENDING ON ISA OPTIONS ------*/
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
index 227ee92..1bb15aa 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi4.patch
@@ -1,207 +1,5 @@
-From 0a68f2ffa640d1b52314278cec838384722eb1d0 Mon Sep 17 00:00:00 2001
-From: William Huang <yushh@google.com>
-Date: Tue, 16 May 2023 09:18:55 +0000
-Subject: [PATCH] Optimize Xtensa transpose convolution for more kernel sizes
- and input channels.
-
-Previously, there were three code paths, in decreasing performance:
-
-1. Kernel size (H*W) multiple of 4, input channels multiple of 16
-2. Kernel size (H*W) multiple of 4, input channels multiple of 4
-3. Others (unoptimized case)
-
-This patch reduces them to the follow two cases:
-
-1. Input channels multiple of 4
-2. Others (unoptimized case)
-
-Original CL=cl/516144094
-
-BUG=227374718
-
-Signed-off-by: William Huang <yushh@google.com>
-
-Optimize Xtensa CONV2D circular buffer copy.
-
-In Xtensa's CONV2D kernel, data is shuffled around and padded so the 2D
-convolution turns into sequential vector products. Unfortunately, this
-process is somewhat slow, and the overhead is especially high for small
-vector lengths.
-
-This patch introduces the following:
-
-- Faster code path for no padding (since our models use VALID padding,
-  i.e., no padding at all)
-- Manual loop if array is small and memcpy if array is large
-- Skip memset on padded channels as the corresponding kernels are
-  already zero
-
-BUG=249796929
-
-Signed-off-by: William Huang <yushh@google.com>
-
-Add implementation for zero-copy CONV2D kernels.
-
-The previous `xa_nn_conv2d_std_sym8sxsym16s` implementation shuffles the
-input tensor into a circular buffer, flattening the dimensions, so that
-the 2D convolution turns into sequential vector products. However, this
-created significant overhead for layers where the resulting vector
-lengths are small.
-
-This patch implements an alternative zero-copy method that takes
-advantage of two facts:
-
-1. If `x_padding == 0`, the width dimension is automatically flattened
-   with the channel dimension, and we need only `kernel_height`
-   sequential vector products, even without the data shuffling
-2. Similar to the loop tiling done in
-   `xa_nn_matXvec_sym8sxsym16s_sym16s_circ`, we can tile the `out_width`
-   and `out_channels` dimensions, achieving the throughput of
-   `_xa_nn_dot_product_2row_4vec_mat_vecs_4bytes_aligned` (i.e., 1.6
-   MULAAAAQs/cycle), even when `out_height < 2`
-
-As a result, the patch significantly benefits layers where the kernel
-and output heights are small, leading to 25%+ cycle reductions in some
-use cases.
-
-Signed-off-by: William Huang <yushh@google.com>
----
- .../cnn/hifi4/xa_nn_conv2d_std_circ_buf.c     |  84 +++++++-
- .../cnn/hifi4/xa_nn_conv2d_std_state.h        |  15 ++
- .../cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c | 203 +++++++++++++++---
- .../hifi4/xa_nn_transpose_conv_sym8sxsym16s.c |  36 +---
- 4 files changed, 275 insertions(+), 63 deletions(-)
-
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-index f8adba2..1a5f186 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c
-@@ -642,7 +642,8 @@ VOID conv2d_std_init_cir_buf(
- }
- 
- // Add x_stride (but not more than kernel_width) x (input_height x input_channels) new planes to circular buffer
--VOID conv2d_std_update_cir_buf(
-+// Slow version of conv2d_std_update_cir_buf with fewer requirements
-+VOID conv2d_std_update_cir_buf_slow(
-     WORD32 input_channels,
-     WORD32 input_channels_pad,
-     WORD32 input_bytewidth,
-@@ -742,6 +743,87 @@ VOID conv2d_std_update_cir_buf(
-   *pp_inp = (VOID *)p_inp;
- }
- 
-+// Add x_stride (but not more than kernel_width) x (input_height x input_channels) new planes to circular buffer
-+VOID conv2d_std_update_cir_buf(
-+    WORD32 input_channels,
-+    WORD32 input_channels_pad,
-+    WORD32 input_bytewidth,
-+    WORD32 input_width,
-+    WORD32 input_height,
-+    WORD32 y_padding,
-+    WORD32 y_b_pad,
-+    WORD32 x_padding,
-+    WORD32 kernel_width,
-+    WORD32 x_stride,
-+    VOID **pp_inp,
-+    WORD32 idx_beg_inp_width_pad,
-+    xa_nn_conv_state_t *p_state)
-+{
-+  if (y_padding != 0 || y_b_pad != 0 || x_padding != 0) {
-+    conv2d_std_update_cir_buf_slow(
-+      input_channels,
-+      input_channels_pad,
-+      input_bytewidth,
-+      input_width,
-+      input_height,
-+      y_padding,
-+      y_b_pad,
-+      x_padding,
-+      kernel_width,
-+      x_stride,
-+      pp_inp,
-+      idx_beg_inp_width_pad,
-+      p_state
-+    );
-+    return;
-+  }
-+
-+  WORD32 i,k;
-+  WORD8 *p_inp = (WORD8 *)*pp_inp;
-+  WORD32 planes_to_add = x_stride > kernel_width ? kernel_width : x_stride;
-+  WORD32 planes_to_keep = kernel_width - planes_to_add;
-+
-+  // Copy 'planes_to_add' planes of data to circular buffer
-+  AE_ADDCIRC16X4_XC((ae_int16x4 *)p_state->cir_buf.p_curr, planes_to_add * input_channels_pad * input_bytewidth);
-+  WORD8 *p_dst = (WORD8 *)p_state->cir_buf.p_curr;
-+  AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+
-+  WORD32 copy_inp_width = planes_to_add;
-+  WORD32 to_skip_inp_width = x_stride - planes_to_add;     // Non-zero for x_stride > kernel_width
-+
-+  int size = input_channels * input_bytewidth;
-+  if (size <= 32) {
-+    for(i=0;i<input_height;i++)
-+    {
-+      for(k=0;k<copy_inp_width;k++)
-+      {
-+        for (int j = 0; j < size; ++j) {
-+          p_dst[j] = p_inp[j];
-+        }
-+        AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, input_channels_pad * input_bytewidth);
-+        p_inp += input_channels * input_bytewidth;
-+      }
-+      AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+      p_inp += (input_width - copy_inp_width) * input_channels * input_bytewidth;
-+    }
-+  } else {
-+    for(i=0;i<input_height;i++)
-+    {
-+      for(k=0;k<copy_inp_width;k++)
-+      {
-+        memcpy(p_dst, p_inp, input_channels * input_bytewidth);
-+        AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, input_channels_pad * input_bytewidth);
-+        p_inp += input_channels * input_bytewidth;
-+      }
-+      AE_ADDCIRC16X4_XC((ae_int16x4 *)p_dst, planes_to_keep * input_channels_pad * input_bytewidth);
-+      p_inp += (input_width - copy_inp_width) * input_channels * input_bytewidth;
-+    }
-+  }
-+  p_inp += (-input_height * input_width + copy_inp_width + to_skip_inp_width) * input_channels * input_bytewidth;
-+
-+  *pp_inp = (VOID *)p_inp;
-+}
-+
- VOID xa_nn_dilated_conv2d_std_load_cir_buf_asym8(
-     WORD32 input_channels,
-     WORD32 input_channels_pad,
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-index a2ba355..8d33bad 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-+++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_state.h
-@@ -214,6 +214,21 @@ VOID conv2d_std_init_cir_buf(
-     VOID **pp_inp,
-     xa_nn_conv_state_t *p_state);
- 
-+VOID conv2d_std_update_cir_buf_slow(
-+    WORD32 input_channels,
-+    WORD32 input_channels_pad,
-+    WORD32 input_bytewidth,
-+    WORD32 input_width,
-+    WORD32 input_height,
-+    WORD32 y_padding,
-+    WORD32 y_b_pad,
-+    WORD32 x_padding,
-+    WORD32 kernel_width,
-+    WORD32 x_stride,
-+    VOID **pp_inp,
-+    WORD32 idx_beg_inp_width_pad,
-+    xa_nn_conv_state_t *p_state);
-+
- VOID conv2d_std_update_cir_buf(
-     WORD32 input_channels,
-     WORD32 input_channels_pad,
 diff --git a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
-index 92721bc..6f868be 100644
+index b9905e9..990b713 100644
 --- a/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
 +++ b/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_sym8sxsym16s.c
 @@ -49,6 +49,24 @@ static inline ae_int32x2 MultiplyByQuantizedMultiplier_ref(ae_int64 d_x,
@@ -229,78 +27,35 @@
  static WORD32 conv_x_left_pad(
      WORD32 x_padding,
      WORD32 kernel_width,
-@@ -238,41 +256,166 @@ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
-   WORD32 y_b_pad = kernel_height + (out_height - 1) * y_stride - (y_padding + input_height);
-   y_b_pad = y_b_pad < 0 ? 0 : y_b_pad;
+@@ -129,6 +147,160 @@ static WORD32 conv_x_right_pad(
+   return out_width_over_x_r_pad;
+ }
  
--  conv2d_std_init_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, p_state);
-+  if (x_padding || (input_channels & 0x3) || (out_channels & 0x3) || (out_width & 0x1)) {
-+    conv2d_std_init_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, p_state);
- 
--  // Index to padded input width
--  WORD32 idx_beg_inp_width_pad = kernel_width - x_stride;
--  idx_beg_inp_width_pad = idx_beg_inp_width_pad < 0 ? 0 : idx_beg_inp_width_pad;
-+    // Index to padded input width
-+    WORD32 idx_beg_inp_width_pad = kernel_width - x_stride;
-+    idx_beg_inp_width_pad = idx_beg_inp_width_pad < 0 ? 0 : idx_beg_inp_width_pad;
- 
- 
--  // Process Loop to compute one output plane [out_height x out_channels] per iteration
--  for(j=0;j<out_width-out_width_over_x_pad-out_width_over_x_r_pad;j++)
--  {
--    // Add x_stride x (input_height x input_channels) new planes to circular buffer
--    conv2d_std_update_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, idx_beg_inp_width_pad, p_state);
-+    // Process Loop to compute one output plane [out_height x out_channels] per iteration
-+    for(j=0;j<out_width-out_width_over_x_pad-out_width_over_x_r_pad;j++)
++static WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s_no_circ_buf(
++    WORD16* __restrict__ p_out,
++    const WORD16* __restrict__ p_inp,
++    const WORD8* __restrict__ p_kernel,
++    const WORD64* __restrict__ p_bias,
++    WORD32 input_height,
++    WORD32 input_width,
++    WORD32 input_channels,
++    WORD32 kernel_height,
++    WORD32 kernel_width,
++    WORD32 out_channels,
++    WORD32 x_stride,
++    WORD32 y_stride,
++    WORD32 x_padding,
++    WORD32 y_padding,
++    WORD32 out_height,
++    WORD32 out_width,
++    WORD32 input_zero_bias,
++    WORD32 * p_out_multiplier,
++    WORD32 * p_out_shift,
++    WORD32 out_zero_bias,
++    WORD32 out_data_format
++    )
 +    {
-+      // Add x_stride x (input_height x input_channels) new planes to circular buffer
-+      conv2d_std_update_cir_buf(input_channels, input_channels_pad, input_bytewidth, input_width, input_height, y_padding, y_b_pad, x_padding_var, kernel_width, x_stride, (VOID**)&pp_inp, idx_beg_inp_width_pad, p_state);
- 
--    // Update index to input width padded
--    idx_beg_inp_width_pad += x_stride;
-+      // Update index to input width padded
-+      idx_beg_inp_width_pad += x_stride;
- 
--    // Convolution using matXvec with matrix as circular buffer
--    xa_nn_matXvec_sym8sxsym16s_sym16s_circ
--      (p_out /* output */
--       ,p_state->cir_buf.p_curr/* matrix: rows x cols */
--       ,p_state->p_kernel_padded /* vec: cols */
--       ,p_bias /* bias */
--       ,out_height /* rows */
--       ,input_channels_pad * kernel_width * kernel_height /* cols */
--       ,input_channels_pad * kernel_width * y_stride/* row_offset */
--       ,out_channels /* vec_count */
--       ,input_channels_pad * kernel_width * kernel_height /* vec_stride */
--       ,out_channels_offset /* out_col_offset */
--       ,out_height_offset /* out_row_offset */
--       ,input_zero_bias
--       ,p_out_multiplier
--       ,p_out_shift
--       ,out_zero_bias
--      );
--    p_out += out_width_offset;
-+      // Convolution using matXvec with matrix as circular buffer
-+      xa_nn_matXvec_sym8sxsym16s_sym16s_circ
-+        (p_out /* output */
-+        ,p_state->cir_buf.p_curr/* matrix: rows x cols */
-+        ,p_state->p_kernel_padded /* vec: cols */
-+        ,p_bias /* bias */
-+        ,out_height /* rows */
-+        ,input_channels_pad * kernel_width * kernel_height /* cols */
-+        ,input_channels_pad * kernel_width * y_stride/* row_offset */
-+        ,out_channels /* vec_count */
-+        ,input_channels_pad * kernel_width * kernel_height /* vec_stride */
-+        ,out_channels_offset /* out_col_offset */
-+        ,out_height_offset /* out_row_offset */
-+        ,input_zero_bias
-+        ,p_out_multiplier
-+        ,p_out_shift
-+        ,out_zero_bias
-+        );
-+      p_out += out_width_offset;
-+    }
-+  } else {
++
 +    const WORD16 *p_dst0_0 = p_out + 0;
 +    const WORD16 *p_dst0_1 = p_out + 1;
 +    const WORD16 *p_dst0_2 = p_out + 2;
@@ -310,8 +65,8 @@
 +    const WORD16 *p_dst1_2 = p_out + out_channels + 2;
 +    const WORD16 *p_dst1_3 = p_out + out_channels + 3;
 +    int kernel_out_ch_offset = kernel_height * kernel_width * input_channels;
-+    int input_x_offset = input_channels * x_stride / 4;
-+    int p_inp_vec_stride = input_width * input_channels / 4;
++    int input_x_offset = (input_channels * x_stride) / 4;
++    int p_inp_vec_stride = (input_width * input_channels) / 4;
 +    int p_kern_vec_stride = kernel_width * input_channels;
 +    int vec_len = kernel_width * input_channels;
 +    for (int out_y = 0; out_y < out_height; ++out_y) {
@@ -325,6 +80,7 @@
 +          ae_int64 out1_1 = p_bias[out_ch + 1];
 +          ae_int64 out1_2 = p_bias[out_ch + 2];
 +          ae_int64 out1_3 = p_bias[out_ch + 3];
++
 +          out0_0 = AE_SLAI64(out0_0, 8);
 +          out0_1 = AE_SLAI64(out0_1, 8);
 +          out0_2 = AE_SLAI64(out0_2, 8);
@@ -333,10 +89,11 @@
 +          out1_1 = AE_SLAI64(out1_1, 8);
 +          out1_2 = AE_SLAI64(out1_2, 8);
 +          out1_3 = AE_SLAI64(out1_3, 8);
++
 +          int in_x_o = out_x * x_stride;
 +          int in_y_o = out_y * y_stride - y_padding;
 +          int k_y_min = -in_y_o;
-+          int k_y_max = input_width - in_y_o;
++          int k_y_max = input_height - in_y_o;
 +          k_y_min = (k_y_min < 0) ? 0 : k_y_min;
 +          k_y_min = (k_y_min < kernel_height) ? k_y_min : kernel_height;
 +          k_y_max = (k_y_max < 0) ? 0 : k_y_max;
@@ -382,6 +139,7 @@
 +              AE_MULAAAAQ16(out1_3, d_inp1, d_kern3);
 +            }
 +          }
++
 +          out0_0 = AE_SRAI64(out0_0, 8);
 +          out0_1 = AE_SRAI64(out0_1, 8);
 +          out0_2 = AE_SRAI64(out0_2, 8);
@@ -390,6 +148,7 @@
 +          out1_1 = AE_SRAI64(out1_1, 8);
 +          out1_2 = AE_SRAI64(out1_2, 8);
 +          out1_3 = AE_SRAI64(out1_3, 8);
++
 +          ae_int32x2 acc_vec0 = MultiplyByQuantizedMultiplier_x2_opt(
 +              out0_0, out1_0, p_out_multiplier[out_ch + 0],
 +              p_out_shift[out_ch + 0]);
@@ -423,70 +182,45 @@
 +        p_dst1_3 += out_channels;
 +      }
 +    }
++  return 0;
++}
++
+ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
+     WORD16* __restrict__ p_out,
+     const WORD16* __restrict__ p_inp,
+@@ -180,6 +352,35 @@ WORD32 xa_nn_conv2d_std_per_chan_sym8sxsym16s(
+     XA_NNLIB_ARG_CHK_COND((p_out_shift[itr] < -31 || p_out_shift[itr] > 31), -1);
    }
  
-   return 0;
-diff --git a/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c b/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-index 7f31b75..a010d45 100644
---- a/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-+++ b/algo/kernels/cnn/hifi4/xa_nn_transpose_conv_sym8sxsym16s.c
-@@ -157,7 +157,7 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- 	 */
- 	if(input_data && filter_data && output_data && scratch_buffer &&
- 			(((unsigned int)input_data&0x7)==0) && (((unsigned int)filter_data&0x3)==0) && (((unsigned int)output_data&0x7) == 0) &&
--			(((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0xF)==0) && ((filter_height*filter_width&0x3)==0))
-+			(((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0x3)==0))
- 	{
- 		{
- 			//tbd : batch = 1, need to handle other values and in_x_min/max= 0 .. need toc heck for other values
-@@ -180,7 +180,8 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- 					filt_y_max = (filt_y_max < filter_height) ? filt_y_max : filter_height;
- 					filt_y_max = (filt_y_max < 0) ? 0 : filt_y_max;
- 					pinp =  (WORD16*)&input_data[in_y*input_width*input_depth+in_x*input_depth];
--					for (int in_channel = 0; in_channel < input_depth; in_channel+=16)
-+					int in_channel = 0;
-+					for (; in_channel + 15 < input_depth; in_channel+=16)
- 					{
- 						ae_int16x4 d_inp, d_inp1, d_inp2, d_inp3;
- 						AE_L16X4_IP(d_inp, (ae_int16x4*)pinp, sizeof(WORD64));
-@@ -235,36 +236,7 @@ int xa_nn_transpose_conv_sym8sxsym16s(WORD16* output_data,
- 							}
- 						}
- 					}
--				}
--			}
--		}
--	}
--	else if(input_data && filter_data && output_data && scratch_buffer &&
--			(((unsigned int)input_data&0x7)==0) && (((unsigned int)filter_data&0x3)==0) && (((unsigned int)output_data&0x7) == 0) &&
--			(((unsigned int)scratch_buffer&0x7) == 0) && ((input_depth&0x3)==0) && ((filter_height*filter_width&0x3)==0))
--	{
--		{
--			//tbd : batch = 1, need to handle other values and in_x_min/max= 0 .. need toc heck for other values
--			for (int in_y = 0; in_y < input_height; ++in_y)
--			{
--				for (int in_x = 0; in_x < input_width; ++in_x)
--				{
--					const int out_x_orig = in_x*stride_width - pad_width;
--					const int out_y_orig = in_y*stride_height - pad_height;
--					int filt_x_min = -out_x_orig; 
--					int filt_x_max = output_width - out_x_orig; 
--					int filt_y_min = -out_y_orig; 
--					int filt_y_max = output_height - out_y_orig; 
--					filt_x_min = (filt_x_min < filter_width) ? filt_x_min : filter_width;
--					filt_x_min = (filt_x_min < 0) ? 0 : filt_x_min;
--					filt_x_max = (filt_x_max < filter_width) ? filt_x_max : filter_width;
--					filt_x_max = (filt_x_max < 0) ? 0 : filt_x_max;
--					filt_y_min = (filt_y_min < filter_height) ? filt_y_min : filter_height;
--					filt_y_min = (filt_y_min < 0) ? 0 : filt_y_min;
--					filt_y_max = (filt_y_max < filter_height) ? filt_y_max : filter_height;
--					filt_y_max = (filt_y_max < 0) ? 0 : filt_y_max;
--					pinp =  (WORD16*)&input_data[in_y*input_width*input_depth+in_x*input_depth];
--					for (int in_channel = 0; in_channel < input_depth; in_channel+=4)
-+					for (; in_channel + 3 < input_depth; in_channel+=4)
- 					{
- 						ae_int16x4 d_inp;
- 						AE_L16X4_IP(d_inp, (ae_int16x4*)pinp, sizeof(WORD64));
--- 
-2.41.0.162.gfafddb0af9-goog
-
++  if ( !(x_padding) && !(input_channels & 0x3) && !(out_channels & 0x3) && !(out_width & 0x1) && (out_data_format == 0) && ((out_width-1)*x_stride <=(input_width-kernel_width) ) )
++  {
++    int ret_val=0;
++    ret_val=xa_nn_conv2d_std_per_chan_sym8sxsym16s_no_circ_buf(p_out,
++                                                              p_inp,
++                                                              p_kernel,
++                                                              p_bias,
++                                                              input_height,
++                                                              input_width,
++                                                              input_channels,
++                                                              kernel_height,
++                                                              kernel_width,
++                                                              out_channels,
++                                                              x_stride,
++                                                              y_stride,
++                                                              x_padding,
++                                                              y_padding,
++                                                              out_height,
++                                                              out_width,
++                                                              input_zero_bias,
++                                                              p_out_multiplier,
++                                                              p_out_shift,
++                                                              out_zero_bias,
++                                                              out_data_format
++                                                            );
++
++    return ret_val;
++  }
++
+   WORD32 j;
+   WORD32 input_bytewidth = 2;
+   VOID *pp_inp = (VOID *)p_inp;
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch b/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch
deleted file mode 100644
index 9d95c63..0000000
--- a/tensorflow/lite/micro/tools/make/ext_libs/xa_nnlib_hifi5.patch
+++ /dev/null
@@ -1,36 +0,0 @@
-diff --git a/algo/kernels/fc/hifi4/xa_nn_fully_connected.c b/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-index 26a2b73..61f0a64 100644
---- a/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-+++ b/algo/kernels/fc/hifi4/xa_nn_fully_connected.c
-@@ -298,7 +298,6 @@ WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s
-   XA_NNLIB_ARG_CHK_PTR(p_out, -1);
-   XA_NNLIB_ARG_CHK_PTR(p_weight, -1);
-   XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
--  XA_NNLIB_ARG_CHK_PTR(p_bias, -1);
-   /* Pointer alignment checks */
- #if 0
-   XA_NNLIB_ARG_CHK_ALIGN(p_out, ALIGNMENT, -1);
-@@ -310,7 +309,8 @@ WORD32 xa_nn_fully_connected_sym8sxasym8s_asym8s
-   XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD8), -1);
-   XA_NNLIB_ARG_CHK_ALIGN(p_weight, sizeof(WORD8), -1);
-   XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(WORD8), -1);
--  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-+  if (p_bias != NULL)
-+    XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
- #endif
-   /* Basic Parameter checks */
-   XA_NNLIB_ARG_CHK_COND((out_depth <= 0), -1);
-diff --git a/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c b/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-index 5350cbe..a91e043 100644
---- a/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-+++ b/algo/kernels/matXvec/hifi5/xa_nn_matXvec_sym8sxasym8s.c
-@@ -704,7 +704,8 @@ WORD32 xa_nn_matXvec_sym8sxasym8s_asym8s(
-   XA_NNLIB_ARG_CHK_PTR(p_mat1, -1);
-   XA_NNLIB_ARG_CHK_PTR(p_vec1, -1);
-   /* Pointer alignment checks */
--  XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-+  if (p_bias != NULL)
-+    XA_NNLIB_ARG_CHK_ALIGN(p_bias, sizeof(WORD32), -1);
-   /* Basic Parameter checks */
-   XA_NNLIB_ARG_CHK_COND((rows <= 0), -1);
-   XA_NNLIB_ARG_CHK_COND((cols1 <= 0), -1);
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
index fb45123..d80855f 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_download.sh
@@ -44,13 +44,13 @@
 fi
 
 if [[ ${2} == "hifi4" ]]; then
-  LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_10_14_2022.zip"
+  LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_hifi4_09_05_2023.zip"
   LIBRARY_DIRNAME="xa_nnlib_hifi4"
-  LIBRARY_MD5="2bf3c1c7fd5a23f157babc8e24fd2c55"
+  LIBRARY_MD5="2a54e056aef73a4fcffde4643998501a"
 elif [[ ${2} == "hifi5" ]]; then
-  LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi5/raw/master/archive/xa_nnlib_hifi5_12_19_2022.zip"
+  LIBRARY_URL="http://github.com/foss-xtensa/nnlib-hifi5/raw/master/archive/xa_nnlib_hifi5_09_05_2023.zip"
   LIBRARY_DIRNAME="xa_nnlib_hifi5"
-  LIBRARY_MD5="83306809191f42a064bde688b94e1eb1"
+  LIBRARY_MD5="1deb55ef200bf5dbedc70b99b02140c0"
 elif [[ ${2} == "vision_p6" ]]; then
   LIBRARY_URL="https://github.com/foss-xtensa/tflmlib_vision/raw/main/archive/xi_tflmlib_vision_p6_22_06_29.zip"
   LIBRARY_DIRNAME="xi_tflmlib_vision_p6"