CMSIS-NN: Add int32 bias support for int16x8 conv (#2557)
BUG=CMSIS-NN glue is updated to support int32 bias for int16x8 convolution.
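
For context, the updated CMSIS-NN wrapper no longer takes a raw int64_t* bias pointer; it takes a cmsis_nn_bias_data struct whose flag tells the kernel how to interpret the pointer. The standalone sketch below mirrors that pattern as it appears in the diff. The struct layout is inferred from the {bias, false} / {bias, true} initializers, and first_bias_element() plus main() are hypothetical stand-ins for illustration, not CMSIS-NN code.

// Minimal standalone sketch (does not include the real CMSIS-NN headers).
// It mirrors the shape of the glue change: one wrapper call receives a
// tagged bias struct instead of a raw int64_t* pointer.
#include <cstdint>
#include <cstdio>

// Mirrors cmsis_nn_bias_data as used in the diff: a type-erased bias
// pointer plus a flag saying whether the data is int32 or int64.
struct cmsis_nn_bias_data {
  const void* data;
  bool is_int32_bias;
};

// Hypothetical stand-in for arm_convolve_wrapper_s16: it only reads the
// first bias element to show how the flag selects the interpretation.
int64_t first_bias_element(const cmsis_nn_bias_data* bias_data) {
  if (bias_data->is_int32_bias) {
    return static_cast<const int32_t*>(bias_data->data)[0];
  }
  return static_cast<const int64_t*>(bias_data->data)[0];
}

// Glue-side pattern from the diff: one template specialization per bias
// type packs the pointer and sets the flag before calling the shared
// wrapper.
template <typename BiasType>
int64_t convolve_wrapper(const BiasType* bias);

template <>
int64_t convolve_wrapper(const int64_t* bias) {
  const cmsis_nn_bias_data bias_data = {bias, /*is_int32_bias=*/false};
  return first_bias_element(&bias_data);
}

template <>
int64_t convolve_wrapper(const int32_t* bias) {
  const cmsis_nn_bias_data bias_data = {bias, /*is_int32_bias=*/true};
  return first_bias_element(&bias_data);
}

int main() {
  const int64_t bias64[] = {7};
  const int32_t bias32[] = {42};
  std::printf("%lld %lld\n",
              static_cast<long long>(convolve_wrapper(bias64)),
              static_cast<long long>(convolve_wrapper(bias32)));
  return 0;
}
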
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index 6691b59..4c35970 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -21,7 +21,6 @@
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
@@ -214,9 +213,26 @@
const cmsis_nn_dims* filter_dims, const int8_t* filter,
const cmsis_nn_dims* bias_dims, const int64_t* bias,
const cmsis_nn_dims* output_dims, int16_t* output, TfLiteType weightsT) {
+ const cmsis_nn_bias_data bias_data = {bias, false};
+
return arm_convolve_wrapper_s16(ctx, conv_params, quant_params, input_dims,
- input, filter_dims, filter, bias_dims, bias,
- output_dims, output);
+ input, filter_dims, filter, bias_dims,
+ &bias_data, output_dims, output);
+}
+
+template <>
+arm_cmsis_nn_status convolve_wrapper(
+ const cmsis_nn_context* ctx, const cmsis_nn_conv_params* conv_params,
+ const cmsis_nn_per_channel_quant_params* quant_params,
+ const cmsis_nn_dims* input_dims, const int16_t* input,
+ const cmsis_nn_dims* filter_dims, const int8_t* filter,
+ const cmsis_nn_dims* bias_dims, const int32_t* bias,
+ const cmsis_nn_dims* output_dims, int16_t* output, TfLiteType weightsT) {
+ const cmsis_nn_bias_data bias_data = {bias, true};
+
+ return arm_convolve_wrapper_s16(ctx, conv_params, quant_params, input_dims,
+ input, filter_dims, filter, bias_dims,
+ &bias_data, output_dims, output);
}
template <typename ActType, typename BiasType, TfLiteType type>
@@ -362,25 +378,17 @@
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
- if (bias == nullptr || bias->type == kTfLiteInt64) {
+ if (bias == nullptr || bias->type == kTfLiteInt32) {
+ return EvalQuantizedPerChannel<int16_t, int32_t, kTfLiteInt16>(
+ context, node, params, data, input, filter, bias, output);
+ } else if (bias->type == kTfLiteInt64) {
return EvalQuantizedPerChannel<int16_t, int64_t, kTfLiteInt16>(
context, node, params, data, input, filter, bias, output);
} else {
- reference_integer_ops::ConvPerChannel(
- ConvParamsQuantized(params, data.reference_op_data),
- data.reference_op_data.per_channel_output_multiplier,
- data.reference_op_data.per_channel_output_shift,
- tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int16_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int16_t>(output));
+ MicroPrintf("Bias type %s (%d) not supported.",
+ TfLiteTypeGetName(bias->type), bias->type);
+ return kTfLiteError;
}
-
- return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -443,22 +451,12 @@
break;
}
case kTfLiteInt16: {
- if (bias == nullptr || bias->type == kTfLiteInt64) {
+ if (bias == nullptr || bias->type == kTfLiteInt32) {
+ return EvalQuantizedPerChannel<int16_t, int32_t, kTfLiteInt16>(
+ context, node, params, data, input, filter, bias, output);
+ } else if (bias->type == kTfLiteInt64) {
return EvalQuantizedPerChannel<int16_t, int64_t, kTfLiteInt16>(
context, node, params, data, input, filter, bias, output);
- } else if (bias->type == kTfLiteInt32) {
- reference_integer_ops::ConvPerChannel(
- ConvParamsQuantized(params, data.reference_op_data),
- data.reference_op_data.per_channel_output_multiplier,
- data.reference_op_data.per_channel_output_shift,
- tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int16_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int16_t>(output));
} else {
MicroPrintf("Bias type %s (%d) not supported.",
TfLiteTypeGetName(bias->type), bias->type);
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
index 482766c..195c0ca 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
@@ -47,9 +47,9 @@
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else
- ZIP_PREFIX_NN="8492d82a1a81651977c5f5128492b4a0f0cf6715"
+ ZIP_PREFIX_NN="15dbe7cda7130766777f18a69388238bc6540cef"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
- CMSIS_NN_MD5="2cb03e4f044b78af6751009cd53247a8"
+ CMSIS_NN_MD5="6121fa707d4dac17acb4e222453a38a6"
# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.
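
For reference, the eval-path change amounts to a dispatch on the bias tensor type: a null or int32 bias now takes the int32 CMSIS-NN route, int64 keeps its existing route, and the former reference-kernel fallback is replaced by an error. The sketch below is illustrative only; BiasKind and DispatchInt16x8 are hypothetical names that mirror the branching in the diff, not TFLM or CMSIS-NN APIs.

// Standalone sketch of the int16x8 bias dispatch introduced by this change.
#include <cstdio>

enum BiasKind { kNone, kInt32, kInt64, kFloat32 };

const char* DispatchInt16x8(BiasKind bias) {
  if (bias == kNone || bias == kInt32) {
    // New default: CMSIS-NN consumes the bias as int32 (is_int32_bias=true).
    return "EvalQuantizedPerChannel<int16_t, int32_t>";
  } else if (bias == kInt64) {
    // Existing path: 64-bit bias (is_int32_bias=false).
    return "EvalQuantizedPerChannel<int16_t, int64_t>";
  }
  // The reference-kernel fallback for other bias types is gone; unsupported
  // types now fail fast.
  return "kTfLiteError";
}

int main() {
  std::printf("%s\n", DispatchInt16x8(kNone));
  std::printf("%s\n", DispatchInt16x8(kInt32));
  std::printf("%s\n", DispatchInt16x8(kInt64));
  std::printf("%s\n", DispatchInt16x8(kFloat32));
  return 0;
}
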