blob: a33c3406ae59962d5277832188e06e568992e8cf [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
// Op selectors passed as a template argument to PrepareAbsRsqrt.
constexpr int kAbsNameId = 0;
constexpr int kRsrqtNameId = 1;  // (sic) RSQRT; spelling kept for existing uses.
// Every op in this file has exactly one input tensor and one output tensor.
constexpr int kElementwiseInputTensor = 0;
constexpr int kElementwiseOutputTensor = 0;
// Per-node state shared by ABS and RSQRT. Storage comes from
// ElementWiseAbsRsqrtInit (AllocatePersistentBuffer) and is filled in by
// PrepareAbsRsqrt; the Eval functions only read it.
struct OpDataAbsRsqrt {
  // Fixed-point requantization factor produced by QuantizeMultiplier
  // (only set for quantized inputs).
  int32_t multiplier;
  int shift;
  // Zero points of the input and output tensors (quantized inputs only).
  int input_offset;
  int output_offset;
  // True when input and output scales differ; read by AbsEvalQuantized to
  // decide whether multiplier/shift must be applied.
  bool needs_rescale;
  // Quantization scheme of the input; AbsEval uses it to choose between the
  // quantized and non-quantized int16 paths.
  TfLiteQuantizationType input_quantization_type;
  // Element type of the input tensor; the Eval functions switch on it.
  TfLiteType input_type;
};
// Input types accepted by the float-only math ops (SIN, COS, LOG, SQRT,
// SQUARE).
bool IsNumericSupportedType(const TfLiteType type) {
  const bool is_float32 = (type == kTfLiteFloat32);
  return is_float32;
}
// Input types accepted by LOGICAL_NOT: boolean tensors only.
bool IsLogicalSupportedType(const TfLiteType type) {
  return kTfLiteBool == type;
}
// Input types accepted by ABS: float32, plus int8 and int16 (int16 may be
// quantized or not; see AbsEval).
bool IsAbsSupportedType(const TfLiteType type) {
  switch (type) {
    case kTfLiteFloat32:
    case kTfLiteInt8:
    case kTfLiteInt16:
      return true;
    default:
      return false;
  }
}
bool IsRsqrtSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32 || type == kTfLiteInt8;
}
// Converts the ABS requantization ratio input_scale / output_scale into a
// fixed-point multiplier/shift pair via QuantizeMultiplier. The division is
// done in float and only then widened to double, as before.
inline void SetAbsOutputMultiplier(const float input_scale,
                                   const float output_scale,
                                   int32_t* multiplier, int* shift) {
  const double ratio = static_cast<double>(input_scale / output_scale);
  QuantizeMultiplier(ratio, multiplier, shift);
}
// Computes the RSQRT output rescale factor 1 / (sqrt(input_scale) *
// output_scale) and stores it as a fixed-point multiplier/shift pair via
// QuantizeMultiplier. The product is formed in float, then widened to double
// before the reciprocal.
inline void SetRsqrtOutputMultiplier(const float input_scale,
                                     const float output_scale,
                                     int32_t* multiplier, int* shift) {
  const double denominator =
      static_cast<double>((std::sqrt(input_scale) * output_scale));
  QuantizeMultiplier(1. / denominator, multiplier, shift);
}
typedef bool (*IsSupportedType)(TfLiteType);
// Prepare shared by the stateless element-wise ops (SIN, COS, LOG, SQRT,
// SQUARE, LOGICAL_NOT).
//
// Verifies that the node has exactly one input and one output, that both
// have the same type, and that `is_supported_type` accepts that type.
// Returns kTfLiteError on any violation, kTfLiteOk otherwise.
template <IsSupportedType is_supported_type>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  // Call the predicate supplied as the template argument. The former
  // `IsSupportedType(input->type)` named the typedef, making it a functional
  // cast of the enum to a function pointer. That cast only rejected the
  // zero-valued type and let every other unsupported type through.
  if (!is_supported_type(input->type)) {
    MicroPrintf("Input data type %s (%d) is not supported.",
                TfLiteTypeGetName(input->type), input->type);
    micro_context->DeallocateTempTfLiteTensor(input);
    micro_context->DeallocateTempTfLiteTensor(output);
    return kTfLiteError;
  }
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  return kTfLiteOk;
}
// Redeclaration of the same typedef (harmless) so this template stands alone.
typedef bool (*IsSupportedType)(TfLiteType);
// Prepare shared by ABS and RSQRT, which keep per-node state in
// OpDataAbsRsqrt.
//
// Verifies one input and one output of the same type accepted by
// `is_supported_type`. Records the input type and quantization type for Eval.
// For inputs evaluated through the quantized kernels, it also caches the zero
// points and the fixed-point multiplier/shift used to requantize.
//
// Template parameters:
//   is_supported_type - predicate over the input type.
//   op_nameid         - kAbsNameId or kRsrqtNameId; selects the rescale math.
//
// Returns kTfLiteError for unsupported types or missing/invalid quantization
// parameters, kTfLiteOk otherwise.
template <IsSupportedType is_supported_type, const int op_nameid>
TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  // Call the predicate supplied as the template argument. The former
  // `IsSupportedType(input->type)` named the typedef, making it a functional
  // cast of the enum to a function pointer that only rejected the
  // zero-valued type.
  if (!is_supported_type(input->type)) {
    MicroPrintf("Input data type %s (%d) is not supported.",
                TfLiteTypeGetName(input->type), input->type);
    micro_context->DeallocateTempTfLiteTensor(input);
    micro_context->DeallocateTempTfLiteTensor(output);
    return kTfLiteError;
  }
  auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
  op_data->input_type = input->type;
  // Only AbsEval reads this, to pick between quantized and non-quantized
  // int16 evaluation; recording it for RSQRT too avoids leaving it unset.
  op_data->input_quantization_type = input->quantization.type;
  // int8 always goes through the quantized kernels. int16 ABS may be
  // quantized or not. int16 RSQRT is always evaluated by RsqrtEvalQuantized,
  // so it must carry quantization parameters as well; otherwise the
  // offsets/multiplier it reads would never be initialized.
  const bool is_quantized_eval =
      input->type == kTfLiteInt8 ||
      (input->type == kTfLiteInt16 &&
       (op_nameid == kRsrqtNameId ||
        input->quantization.type != kTfLiteNoQuantization));
  if (is_quantized_eval) {
    TF_LITE_ENSURE_EQ(context, input->quantization.type,
                      kTfLiteAffineQuantization);
    TF_LITE_ENSURE_EQ(context, output->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    TF_LITE_ENSURE(context, input_params != nullptr);
    TF_LITE_ENSURE(context, input_params->scale != nullptr);
    TF_LITE_ENSURE(context, input_params->scale->size > 0);
    // zero_point is dereferenced below; check it like scale.
    TF_LITE_ENSURE(context, input_params->zero_point != nullptr);
    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
    TF_LITE_ENSURE(context, output_params != nullptr);
    TF_LITE_ENSURE(context, output_params->scale != nullptr);
    TF_LITE_ENSURE(context, output_params->scale->size > 0);
    TF_LITE_ENSURE(context, output_params->zero_point != nullptr);
    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
    op_data->input_offset = input_params->zero_point->data[0];
    op_data->output_offset = output_params->zero_point->data[0];
    // int16 inputs/outputs must have zero points of 0.
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
    }
    const float input_scale = input_params->scale->data[0];
    const float output_scale = output_params->scale->data[0];
    op_data->needs_rescale = input_scale != output_scale;
    if (op_nameid == kAbsNameId && op_data->needs_rescale) {
      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                             &op_data->shift);
    } else if (op_nameid == kRsrqtNameId) {
      // RSQRT always applies the multiplier, even when the scales match.
      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                               &op_data->shift);
    }
  }
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  return kTfLiteOk;
}
// Applies the quantized per-element kernel `func` to every element of input
// 0 and writes the results to output 0.
//
// If `validate_input_func` is non-null, each raw input value is validated
// before `func` runs on it. The first failing status is returned immediately,
// and elements before it have already been written. Fails with kTfLiteError
// if the input type is not `expected_type`.
template <typename T>
inline TfLiteStatus EvalImplQuantized(
    TfLiteContext* context, TfLiteNode* node,
    T func(TfLiteContext*, TfLiteNode*, T),
    TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
    TfLiteType expected_type) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);

  const T* in_data = tflite::micro::GetTensorData<T>(input);
  T* out_data = tflite::micro::GetTensorData<T>(output);
  const size_t count = ElementCount(*input->dims);
  const bool has_validator = validate_input_func != nullptr;

  for (size_t i = 0; i != count; ++i) {
    if (has_validator) {
      TF_LITE_ENSURE_OK(context,
                        validate_input_func(context, node, in_data[i]));
    }
    out_data[i] = func(context, node, in_data[i]);
  }
  return kTfLiteOk;
}
// Non-quantized absolute value for element type T (used for raw int16 ABS).
// For small integer types std::abs works on the promoted int, and the result
// is narrowed back to T. The cast makes that narrowing explicit; it is the
// same conversion the implicit return performed.
template <typename T>
inline T AbsHelper(T i) {
  return static_cast<T>(std::abs(i));
}
// Applies the stateless per-element function `func` to every element of
// input 0 and writes the results to output 0.
//
// If `validate_input_func` is non-null, each input value is validated before
// `func` runs on it, and the first failing status is returned. Fails with
// kTfLiteError if the input type is not `expected_type`.
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             T func(T), TfLiteStatus validate_input_func(T),
                             TfLiteType expected_type) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);

  const T* in_data = tflite::micro::GetTensorData<T>(input);
  T* out_data = tflite::micro::GetTensorData<T>(output);
  const size_t count = ElementCount(*input->dims);
  const bool has_validator = validate_input_func != nullptr;

  for (size_t i = 0; i != count; ++i) {
    if (has_validator) {
      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
    }
    out_data[i] = func(in_data[i]);
  }
  return kTfLiteOk;
}
// Runs `float_func` over every element of a float32 tensor (no validation).
inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                float float_func(float)) {
  const TfLiteType expected = kTfLiteFloat32;
  return EvalImpl<float>(context, node, float_func, nullptr, expected);
}
// Runs `bool_func` over every element of a bool tensor (no validation).
inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
                                bool bool_func(bool)) {
  const TfLiteType expected = kTfLiteBool;
  return EvalImpl<bool>(context, node, bool_func, nullptr, expected);
}
// Init for ABS and RSQRT: reserves persistent storage for one OpDataAbsRsqrt
// and returns it as the node's user_data. `buffer` and `length` are unused.
void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
                              size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  const size_t bytes = sizeof(OpDataAbsRsqrt);
  return context->AllocatePersistentBuffer(context, bytes);
}
// Quantized ABS for one element.
//
// Takes |x - input_zero_point| and, when PrepareAbsRsqrt found differing
// scales, rescales it with the cached multiplier/shift. It then adds the
// output zero point and saturates to the range of T. `context` is unused.
template <typename T>
inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
  const auto* params = static_cast<const OpDataAbsRsqrt*>(node->user_data);
  // T promotes to int here, so even INT16_MIN - offset is representable.
  const int32_t magnitude = std::abs(i - params->input_offset);

  int32_t requantized;
  if (params->needs_rescale) {
    requantized = tflite::MultiplyByQuantizedMultiplier(
                      magnitude, params->multiplier, params->shift) +
                  params->output_offset;
  } else {
    requantized = magnitude + params->output_offset;
  }

  const long int lower = static_cast<long int>(std::numeric_limits<T>::min());
  const long int upper = static_cast<long int>(std::numeric_limits<T>::max());
  return static_cast<T>(
      std::min(std::max(static_cast<long int>(requantized), lower), upper));
}
// Quantized RSQRT for one element.
//
// Works on value = x - input_zero_point; RsqrtEval runs validate_input_func
// first, which rejects raw values below the zero point. A value of exactly 0
// saturates to the maximum of T. Otherwise a fixed-point inverse-square-root
// factor for `value` is obtained from GetInvSqrtQuantizedMultiplierExp. That
// factor is applied to 1 with an extra left shift of kShift to keep
// precision, and the result is rescaled by the cached output multiplier with
// kShift subtracted back out. Finally the output zero point is added and the
// result saturates to the range of T. `context` is unused.
template <typename T>
inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
  const int kMin = std::numeric_limits<T>::min();
  const int kMax = std::numeric_limits<T>::max();
  // Input value relative to the zero point (int arithmetic).
  const int32_t value = (i - op_data->input_offset);
  const int32_t kShift = 20;  // Shift to keep value integer.
  if (value == 0) {
    // Assume that any value close to 0 represents the max output value.
    return static_cast<T>(kMax);
  }
  // Fixed-point multiplier/shift for the inverse square root of `value`
  // (semantics per quantization_util's helper; NOTE(review): confirm the
  // exact representation there before touching the shifts below).
  int32_t inv_sqrt_multiplier;
  int inv_sqrt_shift;
  GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                   &inv_sqrt_shift);
  // Materialize the factor as an integer, scaled up by 2^kShift.
  const int32_t data = tflite::MultiplyByQuantizedMultiplier(
      static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
  // Apply the output rescale from SetRsqrtOutputMultiplier while removing
  // the 2^kShift scaling introduced above, then add the output zero point.
  const int32_t output =
      tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                            op_data->shift - kShift) +
      op_data->output_offset;
  // Saturate to the representable range of T.
  return static_cast<T>(std::min(
      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
      static_cast<long int>(kMax)));
}
// Per-element input check for quantized RSQRT. A raw value below the input
// zero point represents a negative real number, for which rsqrt is
// undefined; such values fail with kTfLiteError.
template <typename T>
TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
                                 T i) {
  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
  TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                     "Rsqrt is only defined for positive values");
  return kTfLiteOk;
}
// Eval for ABS.
//
// Dispatches on the input type recorded by PrepareAbsRsqrt:
//   float32 -> std::abs per element;
//   int8    -> quantized kernel;
//   int16   -> plain AbsHelper if the input is not quantized, else the
//              quantized kernel.
// Any other type is reported and fails with kTfLiteError.
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
  const TfLiteType type = op_data->input_type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalNumeric(context, node, std::abs);
    case kTfLiteInt8:
      return EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
                                       /*validate_input_func=*/nullptr, type);
    case kTfLiteInt16:
      if (op_data->input_quantization_type == kTfLiteNoQuantization) {
        return EvalImpl<int16_t>(context, node, AbsHelper,
                                 /*validate_input_func=*/nullptr, type);
      }
      return EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
                                        /*validate_input_func=*/nullptr, type);
    default:
      MicroPrintf("Current data type %s is not supported.",
                  TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}
// Eval for SIN: element-wise sine over a float32 tensor.
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float x) { return std::sin(x); });
}
// Eval for COS: element-wise cosine over a float32 tensor.
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float x) { return std::cos(x); });
}
// Eval for LOG: element-wise natural logarithm over a float32 tensor.
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float x) { return std::log(x); });
}
// Eval for SQRT: element-wise square root over a float32 tensor.
TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float x) { return std::sqrt(x); });
}
// Eval for RSQRT.
//
// float32 computes 1/sqrt(x) directly without input validation. int8 and
// int16 use the quantized kernel and reject raw inputs below the input zero
// point. Any other type is reported and fails with kTfLiteError.
TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
  const OpDataAbsRsqrt& op_data =
      *static_cast<const OpDataAbsRsqrt*>(node->user_data);
  const TfLiteType type = op_data.input_type;
  if (type == kTfLiteFloat32) {
    return EvalImpl<float>(
        context, node, [](float x) { return 1.f / std::sqrt(x); },
        /*validate_input_func=*/nullptr, type);
  }
  if (type == kTfLiteInt8) {
    return EvalImplQuantized<int8_t>(context, node, RsqrtEvalQuantized,
                                     validate_input_func, type);
  }
  if (type == kTfLiteInt16) {
    return EvalImplQuantized<int16_t>(context, node, RsqrtEvalQuantized,
                                      validate_input_func, type);
  }
  MicroPrintf("Current data type %s is not supported.",
              TfLiteTypeGetName(type));
  return kTfLiteError;
}
// Eval for SQUARE: element-wise x * x over a float32 tensor.
TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
  const auto square = [](float x) -> float { return x * x; };
  return EvalNumeric(context, node, square);
}
// Eval for LOGICAL_NOT: element-wise negation over a bool tensor.
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
  const auto negate = [](bool value) -> bool { return value ? false : true; };
  return EvalLogical(context, node, negate);
}
} // namespace
// ABS keeps per-node state (OpDataAbsRsqrt), so it registers an init
// function in addition to prepare and invoke.
TFLMRegistration Register_ABS() {
  return tflite::micro::RegisterOp(
      /*init=*/ElementWiseAbsRsqrtInit,
      /*prepare=*/PrepareAbsRsqrt<IsAbsSupportedType, kAbsNameId>,
      /*invoke=*/AbsEval);
}
// SIN has no per-node state: init is nullptr.
TFLMRegistration Register_SIN() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsNumericSupportedType>,
      /*invoke=*/SinEval);
}
// COS has no per-node state: init is nullptr.
TFLMRegistration Register_COS() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsNumericSupportedType>,
      /*invoke=*/CosEval);
}
// LOG has no per-node state: init is nullptr.
TFLMRegistration Register_LOG() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsNumericSupportedType>,
      /*invoke=*/LogEval);
}
// SQRT has no per-node state: init is nullptr.
TFLMRegistration Register_SQRT() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsNumericSupportedType>,
      /*invoke=*/SqrtEval);
}
// RSQRT keeps per-node state (OpDataAbsRsqrt), so it registers an init
// function in addition to prepare and invoke.
TFLMRegistration Register_RSQRT() {
  return tflite::micro::RegisterOp(
      /*init=*/ElementWiseAbsRsqrtInit,
      /*prepare=*/PrepareAbsRsqrt<IsRsqrtSupportedType, kRsrqtNameId>,
      /*invoke=*/RsqrtEval);
}
// SQUARE has no per-node state: init is nullptr.
TFLMRegistration Register_SQUARE() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsNumericSupportedType>,
      /*invoke=*/SquareEval);
}
// LOGICAL_NOT has no per-node state: init is nullptr.
TFLMRegistration Register_LOGICAL_NOT() {
  return tflite::micro::RegisterOp(
      /*init=*/nullptr,
      /*prepare=*/GenericPrepare<IsLogicalSupportedType>,
      /*invoke=*/LogicalNotEval);
}
} // namespace tflite