Update CMSIS-NN prepare functions to exit earlier (#2630)
* Adds additional type checks to various CMSIS-NN prepare functions
BUG=1121
Authored-by: Ryan O'Shea <ryan.oshea3@arm.com>
Change-Id: Ic6481873f064a94a4dd0b4a49790842180d73dd9
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
index 898410a..fb166a1 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -301,6 +301,15 @@
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+ TF_LITE_ENSURE_MSG(
+ context,
+ input1->type == kTfLiteFloat32 || input1->type == kTfLiteInt32 ||
+ input1->type == kTfLiteInt16 || input1->type == kTfLiteInt8,
+ "Input data type not supported");
+ TF_LITE_ENSURE_MSG(context, input1->type == input2->type,
+ "Hybrid models are not supported on TFLite Micro.");
+
if (input1->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index 4c35970..cae68c7 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -67,11 +67,17 @@
TfLiteType bias_type = bias != nullptr ? bias->type : kTfLiteNoType;
TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context,
+ input->type == kTfLiteFloat32 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteInt8,
+ "Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
- input->type == filter->type ||
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
- (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
+ (input->type == kTfLiteInt8 &&
+ (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
"Hybrid models are not supported on TFLite Micro.");
// Consistency check tensor dims
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
index f30a952..7183a28 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -75,6 +75,20 @@
micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context,
+ input->type == kTfLiteFloat32 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteInt8,
+ "Input data type not supported");
+ TF_LITE_ENSURE_MSG(
+ context,
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+ (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+ (input->type == kTfLiteInt8 &&
+ (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
+ "Hybrid models are not supported on TFLite Micro.");
+
const TfLiteType data_type = input->type;
int input_width = SizeOfDimension(input, 2);
int input_height = SizeOfDimension(input, 1);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
index dc7b78c..7c373b5 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -76,7 +76,19 @@
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
- TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context,
+ input->type == kTfLiteFloat32 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteInt8,
+ "Input data type not supported");
+ TF_LITE_ENSURE_MSG(
+ context,
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+ (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+ (input->type == kTfLiteInt8 &&
+ (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
+ "Hybrid models are not supported on TFLite Micro.");
const RuntimeShape filter_shape = GetTensorShape(filter);
const RuntimeShape output_shape = GetTensorShape(output);
@@ -125,7 +137,7 @@
input_dims.c = data->accum_depth;
buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
- } else {
+ } else if (input->type == kTfLiteInt8) {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
int8_t* filter_data = GetTensorData<int8_t>(filter);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
index f83a090..6515691 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -52,6 +52,17 @@
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE_MSG(
+ context,
+ input->type == output->type ||
+ (input->type == kTfLiteInt8 && output->type == kTfLiteInt16),
+ "Input and output data types are not supported together.");
+ TF_LITE_ENSURE_MSG(context,
+ input->type == kTfLiteFloat32 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteInt8,
+ "Input data type not supported");
+
TF_LITE_ENSURE(context, node->user_data != nullptr);
CMSISNNSoftmaxParams* op_data =
static_cast<CMSISNNSoftmaxParams*>(node->user_data);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
index 06305bc..20cf0e1 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
@@ -174,10 +174,17 @@
micro_context->AllocateTempInputTensor(node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE_MSG(context,
+ input->type == kTfLiteFloat32 ||
+ input->type == kTfLiteInt16 ||
+ input->type == kTfLiteInt8,
+ "Input data type not supported");
TF_LITE_ENSURE_MSG(
context,
- input->type == filter->type ||
- (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8),
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+ (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+ (input->type == kTfLiteInt8 && filter->type == kTfLiteInt8),
"Hybrid models are not supported on TFLite Micro.");
// Get height and width of the output.