Merge "sw/tflite-micro: clean up function definitions"
diff --git a/tensorflow/lite/micro/kernels/kelvin/add.cc b/tensorflow/lite/micro/kernels/kelvin/add.cc
index 9a647a0..8c33716 100644
--- a/tensorflow/lite/micro/kernels/kelvin/add.cc
+++ b/tensorflow/lite/micro/kernels/kelvin/add.cc
@@ -81,16 +81,13 @@
tflite::micro::GetTensorData<int32_t>(output));
} else {
kelvin::opt::ElementwiseAddS32(
+ op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int32_t>(input1),
+ tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int32_t>(input2),
- tflite::micro::GetTensorData<int32_t>(output),
- op_params.quantized_activation_min,
- op_params.quantized_activation_max,
- MatchingElementsSize(tflite::micro::GetTensorShape(input1),
- tflite::micro::GetTensorShape(input2),
- tflite::micro::GetTensorShape(output)));
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int32_t>(output));
}
- return kTfLiteOk;
} else if (output->type == kTfLiteInt16) {
tflite::ArithmeticParams op_params;
op_params.left_shift = data->left_shift;
@@ -120,20 +117,13 @@
tflite::micro::GetTensorData<int16_t>(output));
} else {
kelvin::opt::ElementwiseAddS16(
+ op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int16_t>(input1),
+ tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int16_t>(input2),
- op_params.input1_offset, op_params.input1_multiplier,
- op_params.input1_shift, op_params.input2_offset,
- op_params.input2_multiplier, op_params.input2_shift,
- op_params.left_shift, tflite::micro::GetTensorData<int16_t>(output),
- op_params.output_offset, op_params.output_multiplier,
- op_params.output_shift, op_params.quantized_activation_min,
- op_params.quantized_activation_max,
- MatchingElementsSize(tflite::micro::GetTensorShape(input1),
- tflite::micro::GetTensorShape(input2),
- tflite::micro::GetTensorShape(output)));
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int16_t>(output));
}
- return kTfLiteOk;
} else if (output->type == kTfLiteInt8) {
tflite::ArithmeticParams op_params;
op_params.left_shift = data->left_shift;
@@ -163,18 +153,12 @@
tflite::micro::GetTensorData<int8_t>(output));
} else {
kelvin::opt::ElementwiseAddS8(
+ op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int8_t>(input1),
- tflite::micro::GetTensorData<int8_t>(input2), op_params.input1_offset,
- op_params.input1_multiplier, op_params.input1_shift,
- op_params.input2_offset, op_params.input2_multiplier,
- op_params.input2_shift, op_params.left_shift,
- tflite::micro::GetTensorData<int8_t>(output), op_params.output_offset,
- op_params.output_multiplier, op_params.output_shift,
- op_params.quantized_activation_min,
- op_params.quantized_activation_max,
- MatchingElementsSize(tflite::micro::GetTensorShape(input1),
- tflite::micro::GetTensorShape(input2),
- tflite::micro::GetTensorShape(output)));
+ tflite::micro::GetTensorShape(input2),
+ tflite::micro::GetTensorData<int8_t>(input2),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
}
} else {
MicroPrintf("Unsupported output type: %s", TfLiteTypeGetName(output->type));
diff --git a/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc b/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc
index 2ca8a31..fafcfed 100644
--- a/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc
+++ b/tensorflow/lite/micro/kernels/kelvin/leaky_relu.cc
@@ -53,25 +53,31 @@
return kTfLiteOk;
} break;
case kTfLiteInt8: {
- kelvin::opt::LeakyReluS8(
- tflite::micro::GetTensorData<int8_t>(input),
- tflite::micro::GetTensorData<int8_t>(output),
- MatchingFlatSize(tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorShape(output)),
- data.input_zero_point, data.output_zero_point,
- data.output_multiplier_alpha, data.output_shift_alpha,
- data.output_multiplier_identity, data.output_shift_identity);
+ LeakyReluParams op_params = {};
+ op_params.input_offset = data.input_zero_point;
+ op_params.output_offset = data.output_zero_point;
+ op_params.output_multiplier_alpha = data.output_multiplier_alpha;
+ op_params.output_shift_alpha = data.output_shift_alpha;
+ op_params.output_multiplier_identity = data.output_multiplier_identity;
+ op_params.output_shift_identity = data.output_shift_identity;
+ kelvin::opt::LeakyReluS8(op_params, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int8_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
} break;
case kTfLiteInt16: {
- kelvin::opt::LeakyReluS16(
- tflite::micro::GetTensorData<int16_t>(input),
- tflite::micro::GetTensorData<int16_t>(output),
- MatchingFlatSize(tflite::micro::GetTensorShape(input),
- tflite::micro::GetTensorShape(output)),
- data.input_zero_point, data.output_zero_point,
- data.output_multiplier_alpha, data.output_shift_alpha,
- data.output_multiplier_identity, data.output_shift_identity);
+ LeakyReluParams op_params = {};
+ op_params.input_offset = data.input_zero_point;
+ op_params.output_offset = data.output_zero_point;
+ op_params.output_multiplier_alpha = data.output_multiplier_alpha;
+ op_params.output_shift_alpha = data.output_shift_alpha;
+ op_params.output_multiplier_identity = data.output_multiplier_identity;
+ op_params.output_shift_identity = data.output_shift_identity;
+ kelvin::opt::LeakyReluS16(op_params, tflite::micro::GetTensorShape(input),
+ tflite::micro::GetTensorData<int16_t>(input),
+ tflite::micro::GetTensorShape(output),
+ tflite::micro::GetTensorData<int16_t>(output));
return kTfLiteOk;
} break;
default:
diff --git a/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc b/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc
index 039bd0d..5b700ae 100644
--- a/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc
+++ b/tensorflow/lite/micro/kernels/kelvin/resize_nearest_neighbor.cc
@@ -90,8 +90,7 @@
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int32_t>(output));
} else if (output->type == kTfLiteInt8) {
-
- kelvin::opt::KelvinResizeNearestNeighbor(
+ kelvin::opt::ResizeNearestNeighborS8(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(size),