| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/lite/c/common.h" |
| #include "tensorflow/lite/kernels/internal/quantization_util.h" |
| #include "tensorflow/lite/kernels/internal/reference/binary_function.h" |
| #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" |
| #include "tensorflow/lite/kernels/kernel_util.h" |
| #include "tensorflow/lite/micro/kernels/kernel_util.h" |
| #include "tensorflow/lite/micro/micro_context.h" |
| #include "tensorflow/lite/micro/micro_log.h" |
| |
| namespace tflite { |
| namespace { |
// Positions of the two operands and the result within the node's tensor lists.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
| |
| struct OpData { |
| bool requires_broadcast; |
| ArithmeticParams arithmetic_params; |
| }; |
| |
// Element kernel for the non-quantized paths.
// Returns (input1 - input2) squared, with every step evaluated in type T.
template <typename T>
T SquaredDifference(T input1, T input2) {
  const T delta = input1 - input2;
  const T squared = delta * delta;
  return squared;
}
| |
| void* SquaredDifferenceInit(TfLiteContext* context, const char* buffer, |
| size_t length) { |
| TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); |
| return context->AllocatePersistentBuffer(context, sizeof(OpData)); |
| } |
| |
| void PrepareQuantized( |
| const TfLiteQuantizationParams& input1_quantization_params, |
| const TfLiteQuantizationParams& input2_quantization_params, |
| const TfLiteQuantizationParams& output_quantization_params, |
| const int left_shift, const int32_t quantized_activation_min, |
| const int32_t quantized_activation_max, OpData* data) { |
| data->arithmetic_params.input1_offset = |
| -input1_quantization_params.zero_point; |
| data->arithmetic_params.input2_offset = |
| -input2_quantization_params.zero_point; |
| data->arithmetic_params.output_offset = output_quantization_params.zero_point; |
| data->arithmetic_params.left_shift = left_shift; |
| const double twice_max_input_scale = |
| 2.0 * static_cast<double>(std::max(input1_quantization_params.scale, |
| input2_quantization_params.scale)); |
| const double real_input1_multiplier = |
| static_cast<double>(input1_quantization_params.scale) / |
| twice_max_input_scale; |
| double real_input2_multiplier = |
| static_cast<double>(input2_quantization_params.scale) / |
| twice_max_input_scale; |
| const double real_output_multiplier = |
| (twice_max_input_scale * twice_max_input_scale) / |
| static_cast<double>((1 << data->arithmetic_params.left_shift * 2) * |
| output_quantization_params.scale); |
| QuantizeMultiplierSmallerThanOneExp( |
| real_input1_multiplier, &data->arithmetic_params.input1_multiplier, |
| &data->arithmetic_params.input1_shift); |
| QuantizeMultiplierSmallerThanOneExp( |
| real_input2_multiplier, &data->arithmetic_params.input2_multiplier, |
| &data->arithmetic_params.input2_shift); |
| QuantizeMultiplier(real_output_multiplier, |
| &data->arithmetic_params.output_multiplier, |
| &data->arithmetic_params.output_shift); |
| data->arithmetic_params.quantized_activation_min = quantized_activation_min; |
| data->arithmetic_params.quantized_activation_max = quantized_activation_max; |
| } |
| |
| TfLiteStatus SquaredDifferencePrepare(TfLiteContext* context, |
| TfLiteNode* node) { |
| TFLITE_DCHECK(node->user_data != nullptr); |
| OpData* data = reinterpret_cast<OpData*>(node->user_data); |
| data->requires_broadcast = false; |
| |
| TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
| TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
| |
| MicroContext* micro_context = GetMicroContext(context); |
| |
| TfLiteTensor* input1 = |
| micro_context->AllocateTempInputTensor(node, kInputTensor1); |
| TF_LITE_ENSURE(context, input1 != nullptr); |
| TfLiteTensor* input2 = |
| micro_context->AllocateTempInputTensor(node, kInputTensor2); |
| TF_LITE_ENSURE(context, input2 != nullptr); |
| TfLiteTensor* output = |
| micro_context->AllocateTempOutputTensor(node, kOutputTensor); |
| TF_LITE_ENSURE(context, output != nullptr); |
| |
| TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type); |
| output->type = input2->type; |
| |
| const TfLiteQuantizationParams& input1_quantization_params = input1->params; |
| const TfLiteQuantizationParams& input2_quantization_params = input2->params; |
| const TfLiteQuantizationParams& output_quantization_params = output->params; |
| if (input1->type == kTfLiteInt8) { |
| const int32_t integer_type_min = std::numeric_limits<int8_t>::min(); |
| const int32_t integer_type_max = std::numeric_limits<int8_t>::max(); |
| TF_LITE_ENSURE(context, |
| input1_quantization_params.zero_point >= integer_type_min); |
| TF_LITE_ENSURE(context, |
| input1_quantization_params.zero_point <= integer_type_max); |
| TF_LITE_ENSURE(context, |
| input2_quantization_params.zero_point >= integer_type_min); |
| TF_LITE_ENSURE(context, |
| input2_quantization_params.zero_point <= integer_type_max); |
| TF_LITE_ENSURE(context, |
| output_quantization_params.zero_point >= integer_type_min); |
| TF_LITE_ENSURE(context, |
| output_quantization_params.zero_point <= integer_type_max); |
| // leftshift = 7 is selected so that maximum shifted result 255^2 * (1 << (7 |
| // * 2 )) does not overflow signed 32-bit integer |
| PrepareQuantized(input1_quantization_params, input2_quantization_params, |
| output_quantization_params, /*left_shift=*/7, |
| /*quantized_activation_min*/ integer_type_min, |
| /*quantized_activation_max*/ integer_type_max, data); |
| } else if (input1->type == kTfLiteInt16) { |
| const int32_t integer_type_min = std::numeric_limits<int16_t>::min(); |
| const int32_t integer_type_max = std::numeric_limits<int16_t>::max(); |
| TF_LITE_ENSURE(context, input1_quantization_params.zero_point == 0); |
| TF_LITE_ENSURE(context, input2_quantization_params.zero_point == 0); |
| TF_LITE_ENSURE(context, output_quantization_params.zero_point == 0); |
| |
| // leftshift = 0 as number is already 16-bit. so that maximum shifted result |
| // 32767^2 * (1 << (0 * 2 )) |
| PrepareQuantized(input1_quantization_params, input2_quantization_params, |
| output_quantization_params, /*left_shift=*/0, |
| /*quantized_activation_min*/ integer_type_min, |
| /*quantized_activation_max*/ integer_type_max, data); |
| } |
| |
| data->requires_broadcast = !HaveSameShapes(input1, input2); |
| |
| micro_context->DeallocateTempTfLiteTensor(input1); |
| micro_context->DeallocateTempTfLiteTensor(input2); |
| micro_context->DeallocateTempTfLiteTensor(output); |
| return kTfLiteOk; |
| } |
| |
| template <typename T> |
| T SquaredDifference(T x, T y, const ArithmeticParams& params) { |
| const int32_t input1_val = params.input1_offset + x; |
| const int32_t input2_val = params.input2_offset + y; |
| const int32_t shifted_input1_val = input1_val * (1 << params.left_shift); |
| const int32_t shifted_input2_val = input2_val * (1 << params.left_shift); |
| const int32_t scaled_input1_val = |
| MultiplyByQuantizedMultiplierSmallerThanOneExp( |
| shifted_input1_val, params.input1_multiplier, params.input1_shift); |
| const int32_t scaled_input2_val = |
| MultiplyByQuantizedMultiplierSmallerThanOneExp( |
| shifted_input2_val, params.input2_multiplier, params.input2_shift); |
| const int32_t raw_diff = scaled_input1_val - scaled_input2_val; |
| |
| // Max of this is 32767^2 * (1 << 0), so won't overflow 32 bits. |
| const int32_t squared_raw_diff = raw_diff * raw_diff; |
| const int32_t raw_output = |
| MultiplyByQuantizedMultiplier(squared_raw_diff, params.output_multiplier, |
| params.output_shift) + |
| params.output_offset; |
| const int32_t clamped_output = |
| std::min(params.quantized_activation_max, |
| std::max(params.quantized_activation_min, raw_output)); |
| return static_cast<T>(clamped_output); |
| } |
| |
| template <typename T> |
| void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node, |
| const OpData* data, |
| const TfLiteEvalTensor* input1, |
| const TfLiteEvalTensor* input2, |
| TfLiteEvalTensor* output) { |
| const auto* op_data = static_cast<const OpData*>(node->user_data); |
| if (data->requires_broadcast) { |
| reference_integer_ops::BroadcastBinaryFunction4DSlow( |
| op_data->arithmetic_params, tflite::micro::GetTensorShape(input1), |
| tflite::micro::GetTensorData<T>(input1), |
| tflite::micro::GetTensorShape(input2), |
| tflite::micro::GetTensorData<T>(input2), |
| tflite::micro::GetTensorShape(output), |
| tflite::micro::GetTensorData<T>(output), |
| reference_integer_ops::CheckArithmeticParams, SquaredDifference); |
| } else { |
| const int flat_size = tflite::micro::GetTensorShape(input1).FlatSize(); |
| reference_integer_ops::ElementWise( |
| flat_size, op_data->arithmetic_params, |
| tflite::micro::GetTensorData<T>(input1), |
| tflite::micro::GetTensorData<T>(input2), |
| tflite::micro::GetTensorData<T>(output), |
| reference_integer_ops::CheckArithmeticParams, SquaredDifference); |
| } |
| } |
| |
| template <typename T> |
| void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node, |
| const OpData* data, const TfLiteEvalTensor* input1, |
| const TfLiteEvalTensor* input2, |
| TfLiteEvalTensor* output) { |
| if (data->requires_broadcast) { |
| reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>( |
| tflite::micro::GetTensorShape(input1), |
| tflite::micro::GetTensorData<T>(input1), |
| tflite::micro::GetTensorShape(input2), |
| tflite::micro::GetTensorData<T>(input2), |
| tflite::micro::GetTensorShape(output), |
| tflite::micro::GetTensorData<T>(output), SquaredDifference<T>); |
| } else { |
| reference_ops::BinaryFunction<T, T, T>( |
| tflite::micro::GetTensorShape(input1), |
| tflite::micro::GetTensorData<T>(input1), |
| tflite::micro::GetTensorShape(input2), |
| tflite::micro::GetTensorData<T>(input2), |
| tflite::micro::GetTensorShape(output), |
| tflite::micro::GetTensorData<T>(output), SquaredDifference<T>); |
| } |
| } |
| |
| TfLiteStatus SquaredDifferenceEval(TfLiteContext* context, TfLiteNode* node) { |
| OpData* data = reinterpret_cast<OpData*>(node->user_data); |
| |
| const TfLiteEvalTensor* input1 = |
| tflite::micro::GetEvalInput(context, node, kInputTensor1); |
| const TfLiteEvalTensor* input2 = |
| tflite::micro::GetEvalInput(context, node, kInputTensor2); |
| TfLiteEvalTensor* output = |
| tflite::micro::GetEvalOutput(context, node, kOutputTensor); |
| |
| if (output->type == kTfLiteFloat32) { |
| EvalSquaredDifference<float>(context, node, data, input1, input2, output); |
| } else if (output->type == kTfLiteInt32) { |
| EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output); |
| } else if (output->type == kTfLiteInt8) { |
| EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2, |
| output); |
| } else if (output->type == kTfLiteInt16) { |
| EvalQuantizedSquaredDifference<int16_t>(context, node, data, input1, input2, |
| output); |
| } else { |
| MicroPrintf( |
| "SquaredDifference only supports FLOAT32, INT32 , INT16 and INT8 now, " |
| "got %d.", |
| output->type); |
| return kTfLiteError; |
| } |
| |
| return kTfLiteOk; |
| } |
| } // namespace |
| |
| TFLMRegistration Register_SQUARED_DIFFERENCE() { |
| return tflite::micro::RegisterOp( |
| SquaredDifferenceInit, SquaredDifferencePrepare, SquaredDifferenceEval); |
| } |
| |
| } // namespace tflite |