tflite-micro: add max_pooling int16 kernel for Kelvin Change-Id: I71cde103ead25bc6ebcf158217008a78a351e666
diff --git a/tensorflow/lite/micro/kernels/kelvin/pooling.cc b/tensorflow/lite/micro/kernels/kelvin/pooling.cc index d343a55..94fc6f2 100644 --- a/tensorflow/lite/micro/kernels/kelvin/pooling.cc +++ b/tensorflow/lite/micro/kernels/kelvin/pooling.cc
@@ -74,27 +74,34 @@ TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kPoolingOutputTensor); + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = data->activation_min; + op_params.quantized_activation_max = data->activation_max; + op_params.float_activation_min = data->activation_min_f32; + op_params.float_activation_max = data->activation_max_f32; + switch (input->type) { case kTfLiteFloat32: - MaxPoolingEvalFloat(context, node, params, data, input, output); + reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData<float>(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<float>(output)); break; case kTfLiteInt8: - tflite::PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.quantized_activation_min = data->activation_min; - op_params.quantized_activation_max = data->activation_max; kelvin::opt::MaxPoolS8( op_params, tflite::micro::GetTensorShape(input), input->data.int8, tflite::micro::GetTensorShape(output), output->data.int8); break; case kTfLiteInt16: - MaxPoolingEvalQuantized<int16_t>(context, node, params, data, input, - output); + kelvin::opt::MaxPoolS16( + op_params, tflite::micro::GetTensorShape(input), input->data.i16, + tflite::micro::GetTensorShape(output), output->data.i16); break; default: MicroPrintf("Type %s not currently supported.",