sw/tflite-micro: clean up function definitions Try to use a same function definition as in reference implementations. This makes codes cleaner and more expandable. Change-Id: I8cecbb405eef4ae28d60b94a42abb9c2587d25c4
diff --git a/tensorflow/lite/micro/kernels/kelvin/add.cc b/tensorflow/lite/micro/kernels/kelvin/add.cc index 9a647a0..8c33716 100644 --- a/tensorflow/lite/micro/kernels/kelvin/add.cc +++ b/tensorflow/lite/micro/kernels/kelvin/add.cc
@@ -81,16 +81,13 @@ tflite::micro::GetTensorData<int32_t>(output)); } else { kelvin::opt::ElementwiseAddS32( + op_params, tflite::micro::GetTensorShape(input1), tflite::micro::GetTensorData<int32_t>(input1), + tflite::micro::GetTensorShape(input2), tflite::micro::GetTensorData<int32_t>(input2), - tflite::micro::GetTensorData<int32_t>(output), - op_params.quantized_activation_min, - op_params.quantized_activation_max, - MatchingElementsSize(tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorShape(output))); + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int32_t>(output)); } - return kTfLiteOk; } else if (output->type == kTfLiteInt16) { tflite::ArithmeticParams op_params; op_params.left_shift = data->left_shift; @@ -120,20 +117,13 @@ tflite::micro::GetTensorData<int16_t>(output)); } else { kelvin::opt::ElementwiseAddS16( + op_params, tflite::micro::GetTensorShape(input1), tflite::micro::GetTensorData<int16_t>(input1), + tflite::micro::GetTensorShape(input2), tflite::micro::GetTensorData<int16_t>(input2), - op_params.input1_offset, op_params.input1_multiplier, - op_params.input1_shift, op_params.input2_offset, - op_params.input2_multiplier, op_params.input2_shift, - op_params.left_shift, tflite::micro::GetTensorData<int16_t>(output), - op_params.output_offset, op_params.output_multiplier, - op_params.output_shift, op_params.quantized_activation_min, - op_params.quantized_activation_max, - MatchingElementsSize(tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorShape(output))); + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int16_t>(output)); } - return kTfLiteOk; } else if (output->type == kTfLiteInt8) { tflite::ArithmeticParams op_params; op_params.left_shift = data->left_shift; @@ -163,18 +153,12 @@ tflite::micro::GetTensorData<int8_t>(output)); } else { kelvin::opt::ElementwiseAddS8( + op_params, tflite::micro::GetTensorShape(input1), tflite::micro::GetTensorData<int8_t>(input1), - tflite::micro::GetTensorData<int8_t>(input2), op_params.input1_offset, - op_params.input1_multiplier, op_params.input1_shift, - op_params.input2_offset, op_params.input2_multiplier, - op_params.input2_shift, op_params.left_shift, - tflite::micro::GetTensorData<int8_t>(output), op_params.output_offset, - op_params.output_multiplier, op_params.output_shift, - op_params.quantized_activation_min, - op_params.quantized_activation_max, - MatchingElementsSize(tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorShape(output))); + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData<int8_t>(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int8_t>(output)); } } else { MicroPrintf("Unsupported output type: %s", TfLiteTypeGetName(output->type));
diff --git a/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc b/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc index 2ca8a31..fafcfed 100644 --- a/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc +++ b/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc
@@ -53,25 +53,31 @@ return kTfLiteOk; } break; case kTfLiteInt8: { - kelvin::opt::LeakyReluS8( - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorData<int8_t>(output), - MatchingFlatSize(tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorShape(output)), - data.input_zero_point, data.output_zero_point, - data.output_multiplier_alpha, data.output_shift_alpha, - data.output_multiplier_identity, data.output_shift_identity); + LeakyReluParams op_params = {}; + op_params.input_offset = data.input_zero_point; + op_params.output_offset = data.output_zero_point; + op_params.output_multiplier_alpha = data.output_multiplier_alpha; + op_params.output_shift_alpha = data.output_shift_alpha; + op_params.output_multiplier_identity = data.output_multiplier_identity; + op_params.output_shift_identity = data.output_shift_identity; + kelvin::opt::LeakyReluS8(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData<int8_t>(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int8_t>(output)); return kTfLiteOk; } break; case kTfLiteInt16: { - kelvin::opt::LeakyReluS16( - tflite::micro::GetTensorData<int16_t>(input), - tflite::micro::GetTensorData<int16_t>(output), - MatchingFlatSize(tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorShape(output)), - data.input_zero_point, data.output_zero_point, - data.output_multiplier_alpha, data.output_shift_alpha, - data.output_multiplier_identity, data.output_shift_identity); + LeakyReluParams op_params = {}; + op_params.input_offset = data.input_zero_point; + op_params.output_offset = data.output_zero_point; + op_params.output_multiplier_alpha = data.output_multiplier_alpha; + op_params.output_shift_alpha = data.output_shift_alpha; + op_params.output_multiplier_identity = data.output_multiplier_identity; + op_params.output_shift_identity = data.output_shift_identity; + kelvin::opt::LeakyReluS16(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData<int16_t>(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int16_t>(output)); return kTfLiteOk; } break; default:
diff --git a/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc b/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc index 039bd0d..5b700ae 100644 --- a/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc +++ b/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc
@@ -90,8 +90,7 @@ tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int32_t>(output)); } else if (output->type == kTfLiteInt8) { - - kelvin::opt::KelvinResizeNearestNeighbor( + kelvin::opt::ResizeNearestNeighborS8( op_params, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int8_t>(input), tflite::micro::GetTensorShape(size),