Enable int16_t support for CMSIS-NN LSTM kernel (#2551)
BUG=New feature in CMSIS-NN
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc
index 75ba5ea..49da4d9 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc
@@ -56,6 +56,12 @@
arm_vector_sum_s8(kernel_sum, size1, size2, weights, offset, biases);
}
+void CMSIS_NN_VectorSum(int64_t* kernel_sum, const int32_t size1,
+ const int32_t size2, const int8_t* weights,
+ const int32_t offset, const int64_t* biases) {
+ arm_vector_sum_s8_s64(kernel_sum, size1, size2, weights, offset, biases);
+}
+
template <typename BiasType>
TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref,
const LSTMKernelContents& kernel_content,
@@ -289,6 +295,32 @@
return kTfLiteOk;
}
+TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
+ const OpData& op_data, const LSTMKernelContents& kernel_content,
+ const LSTMBuffers<int16_t>& buffers) {
+ TFLITE_DCHECK(
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
+ 2 &&
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size <=
+ 3);
+
+ const int16_t* input = tflite::micro::GetOptionalTensorData<int16_t>(
+ kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
+ int16_t* output =
+ tflite::micro::GetTensorData<int16_t>(kernel_content.output_tensor);
+
+ // Create lstm buffer struct
+ cmsis_nn_lstm_context cmsis_buffers;
+ cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
+ cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
+ cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);
+
+ arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn,
+ &cmsis_buffers);
+
+ return kTfLiteOk;
+}
+
/*Kernel functions*/
void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer,
size_t length) {
@@ -351,6 +383,12 @@
number_of_buffers = 3;
CMSIS_NN_PortOpData<int32_t>(context, op_data_lstm, kernel_content,
&op_data->params_cmsis_nn);
+ } else if (activation_type == kTfLiteInt16 &&
+ cell_state_type == kTfLiteInt16) {
+ auto kernel_content = CreateLSTMKernelContent(context, node);
+ number_of_buffers = 3;
+ CMSIS_NN_PortOpData<int64_t>(context, op_data_lstm, kernel_content,
+ &op_data->params_cmsis_nn);
} else {
number_of_buffers = 4;
}
@@ -394,8 +432,7 @@
// 8(activation)x8(weight)->16(cell) LSTM with 32 bits bias
LSTMBuffers<int16_t> buffers =
CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);
- return CMSIS_NN_EvalInteger8x8_16Lstm(op_data, kernel_content,
- buffers);
+ CMSIS_NN_EvalInteger8x8_16Lstm(op_data, kernel_content, buffers);
break;
}
default: {
@@ -411,9 +448,8 @@
case kTfLiteInt8: {
// 16(activation)x8(weight)->16(cell) LSTM with 64 bits bias
LSTMBuffers<int16_t> buffers =
- CreateLSTMBuffers<int16_t>(context, op_data_lstm.buffer_indices);
- EvalLstm<int16_t, int8_t, int16_t, int64_t>(op_data_lstm,
- kernel_content, buffers);
+ CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);
+ CMSIS_NN_EvalInteger16x8_16Lstm(op_data, kernel_content, buffers);
break;
}
default: {
@@ -460,6 +496,33 @@
return kTfLiteOk;
}
+TfLiteStatus UnidirectionalSequenceLstmEvalInt16(TfLiteContext* context,
+                                                 TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& op_data = *reinterpret_cast<const OpData*>(node->user_data);
+  const OpDataLSTM& op_data_lstm = op_data.params_ref;
+  auto kernel_content = CreateLSTMKernelContent(context, node);
+  const auto activation_type =
+      kernel_content.internal_tensors[kLstmInputTensor]->type;
+  const auto weight_type =
+      kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;
+
+  TFLITE_DCHECK(weight_type == kTfLiteInt8 &&  // 16x8: int16 act, int8 weights
+                "Only int8 filter type supported.");
+
+  if (activation_type == kTfLiteInt16) {
+    LSTMBuffers<int16_t> buffers =
+        CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);
+
+    return CMSIS_NN_EvalInteger16x8_16Lstm(op_data, kernel_content, buffers);
+  } else {
+    MicroPrintf("Input type %s (%d) not supported.",
+                TfLiteTypeGetName(activation_type), activation_type);
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
} // namespace
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
@@ -474,4 +537,10 @@
UnidirectionalSequenceLstmEvalInt8);
}
+TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16() {
+ return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
+ UnidirectionalSequenceLstmPrepare,
+ UnidirectionalSequenceLstmEvalInt16);
+}
+
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h
index 16aa23b..46f6b2d 100644
--- a/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h
+++ b/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,10 +36,19 @@
// implementations.
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8();
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16();
+
#else
inline TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8() {
return Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
}
+
+inline TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16() {
+ return Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
+}
#endif
} // namespace tflite
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
index fae77aa..482766c 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
@@ -47,9 +47,9 @@
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else
- ZIP_PREFIX_NN="6cc31fb36fa330325b2bb0ffde3a7288384e58ab"
- CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/6cc31fb36fa330325b2bb0ffde3a7288384e58ab.zip"
- CMSIS_NN_MD5="42000f264b93b7b6cd60c1b507792daf"
+ ZIP_PREFIX_NN="8492d82a1a81651977c5f5128492b4a0f0cf6715"
+ CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
+ CMSIS_NN_MD5="2cb03e4f044b78af6751009cd53247a8"
# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.