blob: 2ba058cce00076f22cd193a1678d69d455442f62 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
namespace tflite {
namespace {
// Positions of the tensors this kernel reads from / writes to on its node.
// Inputs and outputs are indexed separately, so input 0 and output 0 coexist.
constexpr int kInputTensor{0};   // input: the data tensor to search
constexpr int kAxis{1};          // input: tensor holding the axis to search along
constexpr int kOutputTensor{0};  // output: tensor receiving the result indices
// Dispatches to reference_ops::ArgMinMax with the comparator matching the
// requested operation.
//
// Template parameters:
//   T1 - element type of the searched input tensor (`input1_data`)
//   T2 - element type of the index output tensor (`output_data`)
//   T3 - element type of the axis tensor (`input2_data`)
//
// `is_arg_max` == true selects reference_ops::GreaterFn (arg max);
// false selects reference_ops::LessFn (arg min). Both come from
// comparisons.h. They are used instead of a generic comparator utility to keep
// dependencies and binary size small in the micro environment.
template <typename T1, typename T2, typename T3>
inline void ArgMinMaxHelper(const RuntimeShape& input1_shape,
                            const T1* input1_data, const T3* input2_data,
                            const RuntimeShape& output_shape, T2* output_data,
                            bool is_arg_max) {
  if (!is_arg_max) {
    // Arg min: "less than" decides which candidate wins.
    reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
                             output_shape, output_data,
                             reference_ops::LessFn<T1>);
    return;
  }
  // Arg max: "greater than" decides which candidate wins.
  reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
                           output_shape, output_data,
                           reference_ops::GreaterFn<T1>);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* axis =
tflite::micro::GetEvalInput(context, node, kAxis);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
ArgMinMaxHelper(tflite::micro::GetTensorShape(input), \
tflite::micro::GetTensorData<data_type>(input), \
tflite::micro::GetTensorData<axis_type>(axis), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<output_type>(output), \
is_arg_max)
if (axis->type == kTfLiteInt32) {
if (output->type == kTfLiteInt32) {
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
break;
case kTfLiteInt8:
TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
break;
default:
MicroPrintf(
"Only float32, uint8_t and int8_t are "
"supported currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
} else {
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
MicroPrintf("Only int32_t are supported currently, got %s.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}
#undef TF_LITE_ARG_MIN_MAX
return kTfLiteOk;
}
// Invoke hook for ARG_MIN: forwards to the shared Eval with arg-max disabled,
// so the minimum-selecting comparator is used.
TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
  constexpr bool kSelectMaximum = false;
  return Eval(context, node, kSelectMaximum);
}
// Invoke hook for ARG_MAX: forwards to the shared Eval with arg-max enabled,
// so the maximum-selecting comparator is used.
TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
  constexpr bool kSelectMaximum = true;
  return Eval(context, node, kSelectMaximum);
}
} // namespace
// Builds the TFLM registration for ARG_MAX. No init or prepare hooks are
// supplied; all work happens in the invoke hook ArgMaxEval.
TFLMRegistration Register_ARG_MAX() {
  return micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                           /*invoke=*/ArgMaxEval);
}
// Builds the TFLM registration for ARG_MIN. No init or prepare hooks are
// supplied; all work happens in the invoke hook ArgMinEval.
TFLMRegistration Register_ARG_MIN() {
  return micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                           /*invoke=*/ArgMinEval);
}
} // namespace tflite