Add support for 16 bit activations in ExpandDims operation. (#2589)
The change introduces the support for ML models with 16 bit activations and 8 bit weights in the ExpandDims operation.
BUG=fixes #68293
diff --git a/tensorflow/lite/micro/kernels/expand_dims.cc b/tensorflow/lite/micro/kernels/expand_dims.cc
index 6bae37b..d47b42c 100644
--- a/tensorflow/lite/micro/kernels/expand_dims.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims.cc
@@ -13,6 +13,8 @@
limitations under the License.
==============================================================================*/
+#include <cstdint>
+
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -128,13 +130,18 @@
memCopyN(tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorData<float>(input), flat_size);
} break;
+ case kTfLiteInt16: {
+ memCopyN(tflite::micro::GetTensorData<int16_t>(output),
+ tflite::micro::GetTensorData<int16_t>(input), flat_size);
+ } break;
case kTfLiteInt8: {
memCopyN(tflite::micro::GetTensorData<int8_t>(output),
tflite::micro::GetTensorData<int8_t>(input), flat_size);
} break;
default:
MicroPrintf(
- "Expand_Dims only currently supports int8 and float32, got %d.",
+ "Expand_Dims only currently supports int8, int16 and float32, got "
+ "%d.",
input->type);
return kTfLiteError;
}
diff --git a/tensorflow/lite/micro/kernels/expand_dims_test.cc b/tensorflow/lite/micro/kernels/expand_dims_test.cc
index d8e217e..39a83b5 100644
--- a/tensorflow/lite/micro/kernels/expand_dims_test.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims_test.cc
@@ -13,6 +13,8 @@
limitations under the License.
==============================================================================*/
+#include <cstdint>
+
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
@@ -138,6 +140,20 @@
golden_data, output_data);
}
+TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest3) {
+ int16_t output_data[6];
+ int input_dims[] = {3, 3, 1, 2};
+ const int16_t input_data[] = {-1, 1, 2, -2, 0, 3};
+ const int16_t golden_data[] = {-1, 1, 2, -2, 0, 3};
+ int axis_dims[] = {1, 1};
+ const int32_t axis_data[] = {3};
+ int golden_dims[] = {1, 3, 1, 2};
+ int output_dims[] = {4, 3, 1, 2, 1};
+ tflite::testing::TestExpandDims<int16_t>(input_dims, input_data, axis_dims,
+ axis_data, golden_dims, output_dims,
+ golden_data, output_data);
+}
+
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest4) {
int8_t output_data[6];
int input_dims[] = {3, 3, 1, 2};