tflite-micro: add max_pooling int16 kernel for Kelvin
Change-Id: I71cde103ead25bc6ebcf158217008a78a351e666
diff --git a/tensorflow/lite/micro/kernels/kelvin/pooling.cc b/tensorflow/lite/micro/kernels/kelvin/pooling.cc
index d343a55..94fc6f2 100644
--- a/tensorflow/lite/micro/kernels/kelvin/pooling.cc
+++ b/tensorflow/lite/micro/kernels/kelvin/pooling.cc
@@ -74,27 +74,34 @@
TfLiteEvalTensor* output =
micro::GetEvalOutput(context, node, kPoolingOutputTensor);
+ tflite::PoolParams op_params;
+ op_params.stride_height = params->stride_height;
+ op_params.stride_width = params->stride_width;
+ op_params.filter_height = params->filter_height;
+ op_params.filter_width = params->filter_width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.padding_values.width = data->padding.width;
+ op_params.quantized_activation_min = data->activation_min;
+ op_params.quantized_activation_max = data->activation_max;
+ op_params.float_activation_min = data->activation_min_f32;
+ op_params.float_activation_max = data->activation_max_f32;
+
switch (input->type) {
case kTfLiteFloat32:
- MaxPoolingEvalFloat(context, node, params, data, input, output);
+ reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<float>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteInt8:
- tflite::PoolParams op_params;
- op_params.stride_height = params->stride_height;
- op_params.stride_width = params->stride_width;
- op_params.filter_height = params->filter_height;
- op_params.filter_width = params->filter_width;
- op_params.padding_values.height = data->padding.height;
- op_params.padding_values.width = data->padding.width;
- op_params.quantized_activation_min = data->activation_min;
- op_params.quantized_activation_max = data->activation_max;
kelvin::opt::MaxPoolS8(
op_params, tflite::micro::GetTensorShape(input), input->data.int8,
tflite::micro::GetTensorShape(output), output->data.int8);
break;
case kTfLiteInt16:
- MaxPoolingEvalQuantized<int16_t>(context, node, params, data, input,
- output);
+ kelvin::opt::MaxPoolS16(
+ op_params, tflite::micro::GetTensorShape(input), input->data.i16,
+ tflite::micro::GetTensorShape(output), output->data.i16);
break;
default:
MicroPrintf("Type %s not currently supported.",