cmsis-nn: update fully connected int8 (#2469)
- Adds non-zero filter offset support.
- Adds support for batch-matmul-like behavior, where the weights behave like an input tensor, i.e. their data is not available before Eval.
BUG=non-zero filter offset not supported for CMSIS-NN
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
index 0c4f8aa..dc7b78c 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -104,9 +104,7 @@
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
- } else if (input->type == kTfLiteInt8 &&
- data->reference_op_data.filter_zero_point == 0 &&
- filter->type != kTfLiteInt4) {
+ } else if (input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) {
const RuntimeShape input_shape = GetTensorShape(input);
TFLITE_DCHECK_GE(output_dim_count, 2);
@@ -130,11 +128,13 @@
} else {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
- if (buf_size > 0) {
+ int8_t* filter_data = GetTensorData<int8_t>(filter);
+ data->kernel_sums = nullptr;
+
+ if (buf_size > 0 && filter_data != nullptr) {
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));
- int8_t* filter_data = GetTensorData<int8_t>(filter);
arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth,
filter_data, 1, nullptr);
@@ -298,12 +298,20 @@
} else {
cmsis_nn_fc_params fc_params;
fc_params.input_offset = -data.reference_op_data.input_zero_point;
+ fc_params.filter_offset = -data.reference_op_data.filter_zero_point;
fc_params.output_offset = data.reference_op_data.output_zero_point;
- fc_params.filter_offset = 0;
fc_params.activation.min = data.reference_op_data.output_activation_min;
fc_params.activation.max = data.reference_op_data.output_activation_max;
- ctx.buf = data.kernel_sums;
+ if (data.kernel_sums != nullptr) {
+ ctx.buf = data.kernel_sums;
+ } else if (ctx.buf != nullptr) {
+ // If behaving like batch matmul we calculate kernel sums in eval.
+ arm_vector_sum_s8(
+ static_cast<int32_t*>(ctx.buf), filter_dims.n, data.output_depth,
+ tflite::micro::GetTensorData<int8_t>(filter), 1, nullptr);
+ }
+
TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s8(
@@ -393,22 +401,8 @@
return EvalQuantizedInt4(context, node, data, input, filter, bias,
output);
case kTfLiteInt8:
- if (data.reference_op_data.filter_zero_point == 0) {
- return EvalQuantizedInt8(context, node, data, input, filter, bias,
- output);
- } else {
- tflite::reference_integer_ops::FullyConnected(
- FullyConnectedParamsQuantized(data.reference_op_data),
- tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorData<int8_t>(input),
- tflite::micro::GetTensorShape(filter),
- tflite::micro::GetTensorData<int8_t>(filter),
- tflite::micro::GetTensorShape(bias),
- tflite::micro::GetOptionalTensorData<int32_t>(bias),
- tflite::micro::GetTensorShape(output),
- tflite::micro::GetTensorData<int8_t>(output));
- return kTfLiteOk;
- }
+ return EvalQuantizedInt8(context, node, data, input, filter, bias,
+ output);
default:
MicroPrintf("Filter Type %s (%d) not supported.",
TfLiteTypeGetName(filter->type), filter->type);
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
index cc79116..601c4d4 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
@@ -47,9 +47,9 @@
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else
- ZIP_PREFIX_NN="2a999a2fd887c98042353accac77479f00b5f99d"
+ ZIP_PREFIX_NN="72e1ebf623ab1660a3e14e4e36fdcddce46f1991"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
- CMSIS_NN_MD5="c6cfe1f8e0f6518c92f7e42ed7b7afd4"
+ CMSIS_NN_MD5="23a623f4eca6c8f11ee5366c2cf61a44"
# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.