Add int16 pack and transpose
Change-Id: I8d63eadd27cb20868045666efb3b0369efc2c1cb
diff --git a/tensorflow/lite/micro/kernels/pack.cc b/tensorflow/lite/micro/kernels/pack.cc
index 7b4aeef..5ee2759 100644
--- a/tensorflow/lite/micro/kernels/pack.cc
+++ b/tensorflow/lite/micro/kernels/pack.cc
@@ -85,6 +85,9 @@
return PackImpl<int8_t>(context, node, output, data->values_count,
data->axis);
}
+ case kTfLiteInt16: {
+ return PackImpl<int16_t>(context, node, output, data->values_count, data->axis);
+ }
case kTfLiteInt32: {
return PackImpl<int32_t>(context, node, output, data->values_count,
data->axis);
diff --git a/tensorflow/lite/micro/kernels/transpose.cc b/tensorflow/lite/micro/kernels/transpose.cc
index 710bfca..c57812b 100644
--- a/tensorflow/lite/micro/kernels/transpose.cc
+++ b/tensorflow/lite/micro/kernels/transpose.cc
@@ -97,6 +97,12 @@
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
+ case kTfLiteInt16:
+ reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int16_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int16_t>(output));
+ break;
case kTfLiteInt8:
reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),