Update cmsis prepare functions to exit earlier (#2630)

* Adds additional checks to various cmsis-nn prepare functions

BUG=1121
Authored-by: Ryan O'Shea <ryan.oshea3@arm.com>
Change-Id: Ic6481873f064a94a4dd0b4a49790842180d73dd9
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
index 898410a..fb166a1 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -301,6 +301,15 @@
       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
 
+  TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+  TF_LITE_ENSURE_MSG(
+      context,
+      input1->type == kTfLiteFloat32 || input1->type == kTfLiteInt32 ||
+          input1->type == kTfLiteInt16 || input1->type == kTfLiteInt8,
+      "Input data type not supported");
+  TF_LITE_ENSURE_MSG(context, input1->type == input2->type,
+                     "Hybrid models are not supported on TFLite Micro.");
+
   if (input1->type == kTfLiteInt16) {
     TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0);
     TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index 4c35970..cae68c7 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -67,11 +67,17 @@
   TfLiteType bias_type = bias != nullptr ? bias->type : kTfLiteNoType;
 
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(context,
+                     input->type == kTfLiteFloat32 ||
+                         input->type == kTfLiteInt16 ||
+                         input->type == kTfLiteInt8,
+                     "Input data type not supported");
   TF_LITE_ENSURE_MSG(
       context,
-      input->type == filter->type ||
+      (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
           (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
-          (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
+          (input->type == kTfLiteInt8 &&
+           (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
       "Hybrid models are not supported on TFLite Micro.");
 
   // Consistency check tensor dims
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
index f30a952..7183a28 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -75,6 +75,20 @@
       micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
 
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(context,
+                     input->type == kTfLiteFloat32 ||
+                         input->type == kTfLiteInt16 ||
+                         input->type == kTfLiteInt8,
+                     "Input data type not supported");
+  TF_LITE_ENSURE_MSG(
+      context,
+      (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+          (input->type == kTfLiteInt8 &&
+           (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
+      "Hybrid models are not supported on TFLite Micro.");
+
   const TfLiteType data_type = input->type;
   int input_width = SizeOfDimension(input, 2);
   int input_height = SizeOfDimension(input, 1);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
index dc7b78c..7c373b5 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -76,7 +76,19 @@
       node, kFullyConnectedOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
 
-  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(context,
+                     input->type == kTfLiteFloat32 ||
+                         input->type == kTfLiteInt16 ||
+                         input->type == kTfLiteInt8,
+                     "Input data type not supported");
+  TF_LITE_ENSURE_MSG(
+      context,
+      (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+          (input->type == kTfLiteInt8 &&
+           (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
+      "Hybrid models are not supported on TFLite Micro.");
 
   const RuntimeShape filter_shape = GetTensorShape(filter);
   const RuntimeShape output_shape = GetTensorShape(output);
@@ -125,7 +137,7 @@
       input_dims.c = data->accum_depth;
 
       buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
-    } else {
+    } else if (input->type == kTfLiteInt8) {
       buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
 
       int8_t* filter_data = GetTensorData<int8_t>(filter);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
index f83a090..6515691 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/softmax.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,6 +52,17 @@
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
   TF_LITE_ENSURE(context, output != nullptr);
 
+  TF_LITE_ENSURE_MSG(
+      context,
+      input->type == output->type ||
+          (input->type == kTfLiteInt8 && output->type == kTfLiteInt16),
+      "Input and output data types are not supported together.");
+  TF_LITE_ENSURE_MSG(context,
+                     input->type == kTfLiteFloat32 ||
+                         input->type == kTfLiteInt16 ||
+                         input->type == kTfLiteInt8,
+                     "Input data type not supported");
+
   TF_LITE_ENSURE(context, node->user_data != nullptr);
   CMSISNNSoftmaxParams* op_data =
       static_cast<CMSISNNSoftmaxParams*>(node->user_data);
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
index 06305bc..20cf0e1 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
@@ -174,10 +174,17 @@
       micro_context->AllocateTempInputTensor(node, kFilterTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
 
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(context,
+                     input->type == kTfLiteFloat32 ||
+                         input->type == kTfLiteInt16 ||
+                         input->type == kTfLiteInt8,
+                     "Input data type not supported");
   TF_LITE_ENSURE_MSG(
       context,
-      input->type == filter->type ||
-          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8),
+      (input->type == kTfLiteFloat32 && filter->type == kTfLiteFloat32) ||
+          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+          (input->type == kTfLiteInt8 && filter->type == kTfLiteInt8),
       "Hybrid models are not supported on TFLite Micro.");
 
   // Get height and width of the output.