/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {

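// Tensor indices used by the FULLY_CONNECTED op: the activation input, the
// weights (filter), the optional bias, and the single output.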
const int kFullyConnectedInputTensor = 0;
const int kFullyConnectedWeightsTensor = 1;
const int kFullyConnectedBiasTensor = 2;
const int kFullyConnectedOutputTensor = 0;

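// Packs the precomputed quantization parameters into the FullyConnectedParams
// struct consumed by the reference kernels. The input and weights zero points
// are negated because the kernels add the offsets: e.g. with an input zero
// point of 128 the kernel computes (q + (-128)), i.e. (q - zero_point). The
// output zero point is not negated since it is added back after the
// accumulator has been rescaled to the output range.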
FullyConnectedParams FullyConnectedParamsQuantized(
    const OpDataFullyConnected& op_data) {
  FullyConnectedParams op_params;
  op_params.input_offset = -op_data.input_zero_point;
  op_params.weights_offset = -op_data.filter_zero_point;
  op_params.output_offset = op_data.output_zero_point;
  op_params.output_multiplier = op_data.output_multiplier;
  op_params.output_shift = op_data.output_shift;
  op_params.quantized_activation_min = op_data.output_activation_min;
  op_params.quantized_activation_max = op_data.output_activation_max;
  return op_params;
}

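// For float execution only the fused-activation clamp range is needed, e.g.
// kTfLiteActRelu6 yields [0.0f, 6.0f] while kTfLiteActNone leaves the range
// effectively unbounded.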
FullyConnectedParams FullyConnectedParamsFloat(
    TfLiteFusedActivation activation) {
  FullyConnectedParams op_params;
  CalculateActivationRange(activation, &op_params.float_activation_min,
                           &op_params.float_activation_max);
  return op_params;
}

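// Computes the per-tensor requantization parameters (multiplier, shift and
// quantized activation range) and caches the zero points needed at Eval time.
// For float32 no quantization data is required, so this reduces to a no-op
// that returns kTfLiteOk.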
TfLiteStatus CalculateOpDataFullyConnected(
    TfLiteContext* context, TfLiteFusedActivation activation,
    TfLiteType data_type, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias,
    TfLiteTensor* output, OpDataFullyConnected* data) {
  // TODO(b/324385802): Support per-channel quantization for FullyConnected.
  // If you hit this failure, disable per-channel quantization for dense
  // layers in the converter by setting the following flag to true:
  // TfLiteConverter._experimental_disable_per_channel_quantization_for_dense_layers
  // https://github.com/tensorflow/tensorflow/blob/377f47694fa790e98db6665b9adecde00b5e0d68/tensorflow/lite/python/lite.py#L674
  if (filter->quantization.type == kTfLiteAffineQuantization &&
      filter->quantization.params != nullptr) {
    TfLiteAffineQuantization* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE_MSG(
        context, affine_quantization->scale->size == 1,
        "FullyConnected per-channel quantization not yet supported. Please set "
        "converter._experimental_disable_per_channel_quantization_for_dense_"
        "layers = True.");
  }

  if (data_type != kTfLiteFloat32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);
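
    // Note on the math (a sketch of the standard TFLite requantization
    // scheme): GetQuantizedConvolutionMultipler computes roughly
    //   real_multiplier = input_scale * filter_scale / output_scale,
    // and QuantizeMultiplier re-expresses it as a 32-bit fixed-point
    // significand plus a power-of-two shift, approximately
    //   real_multiplier ~= output_multiplier * 2^(output_shift - 31),
    // so the kernel can rescale the int32 accumulator to the output type
    // using only integer multiplies and shifts.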

    // Filter weights will always be symmetrically quantized since we only
    // support int8 quantization. See
    // https://github.com/tensorflow/tensorflow/issues/44912 for additional
    // context.
    TFLITE_DCHECK(filter->params.zero_point == 0);

    data->input_zero_point = input->params.zero_point;
    data->filter_zero_point = filter->params.zero_point;
    data->output_zero_point = output->params.zero_point;

    return CalculateActivationRangeQuantized(context, activation, output,
                                             &data->output_activation_min,
                                             &data->output_activation_max);
  }
  return kTfLiteOk;
}

}  // namespace tflite