blob: 3774dddb58167c8c7afb235d5181e6fc61c7adab [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kParams = 0;
constexpr int kIndices = 1;
constexpr int kOutputTensor = 0;
constexpr int MAX_INDICES_ND = 5;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* params = micro_context->AllocateTempInputTensor(node, kParams);
TF_LITE_ENSURE(context, params != nullptr);
TfLiteTensor* indices =
micro_context->AllocateTempInputTensor(node, kIndices);
TF_LITE_ENSURE(context, indices != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
switch (params->type) {
case kTfLiteFloat32:
case kTfLiteInt8:
break;
default:
MicroPrintf("Params of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
break;
}
switch (indices->type) {
case kTfLiteInt32:
break;
default:
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
const int params_rank = NumDimensions(params);
const int indices_rank = NumDimensions(indices);
const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
if (params_rank < 1) {
MicroPrintf("Params must be at least a vector.");
return kTfLiteError;
}
if (indices_rank < 1) {
MicroPrintf("Indices must be at least a vector.");
return kTfLiteError;
}
if (indices_nd > params_rank) {
MicroPrintf("Index innermost dimension length must be <= params rank.");
return kTfLiteError;
}
if (indices_nd > MAX_INDICES_ND) {
MicroPrintf("Index innermost dimension length must not exceed %d.",
MAX_INDICES_ND);
return kTfLiteError;
}
// Assign to output the input type.
output->type = params->type;
// The tensor output dims must be relocated
// from the FlatBuffer to the persistent storage arena.
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
// TFLM gather_nd does not create the output tensor, but it needs to ensure
// that the output shape is correct. The result shape is
// indices.shape[:-1] + params.shape[indices.shape[-1]:]
TfLiteIntArray* output_shape = output->dims;
int output_index = 0;
for (int i = 0; i < indices_rank - 1; ++i) {
output_shape->data[output_index++] = indices->dims->data[i];
}
for (int i = indices_nd; i < params_rank; ++i) {
output_shape->data[output_index++] = params->dims->data[i];
}
output_shape->size = output_index;
micro_context->DeallocateTempTfLiteTensor(params);
micro_context->DeallocateTempTfLiteTensor(indices);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename ParamsT, typename IndicesT>
TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
const TfLiteEvalTensor* indices,
TfLiteEvalTensor* output) {
const int indices_dims = indices->dims->size;
const int indices_nd = indices->dims->data[indices_dims - 1];
const int params_dims = params->dims->size;
const IndicesT* index_data = tflite::micro::GetTensorData<IndicesT>(indices);
const ParamsT* param_data = tflite::micro::GetTensorData<ParamsT>(params);
ParamsT* output_data = tflite::micro::GetTensorData<ParamsT>(output);
int n_slices = 1;
for (int i = 0; i < indices_dims - 1; ++i) {
n_slices *= indices->dims->data[i];
}
// If indices[-1] == params.rank, fetch single elements.
// If indices[-1] < params.rank, fetch slices.
int slice_size = 1;
for (int i = indices_nd; i < params_dims; ++i) {
slice_size *= params->dims->data[i];
}
int params_flat_size = ElementCount(*params->dims);
int remain_flat_size = params_flat_size;
// Number of elements per dimension
int dims_to_count[MAX_INDICES_ND];
for (int i = 0; i < indices_nd; ++i) {
dims_to_count[i] = remain_flat_size / params->dims->data[i];
remain_flat_size = dims_to_count[i];
}
for (int i = 0; i < n_slices; ++i) {
int from_pos = 0;
for (int j = 0; j < indices_nd; ++j) {
int offset = i * indices_nd + j;
IndicesT index = index_data[offset];
from_pos += index * dims_to_count[j];
}
if (from_pos < 0 || from_pos + slice_size > params_flat_size) {
return kTfLiteError;
}
std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);
}
return kTfLiteOk;
}
template <typename IndicesT>
TfLiteStatus EvalGatherNd(TfLiteContext* context,
const TfLiteEvalTensor* params,
const TfLiteEvalTensor* indices,
TfLiteEvalTensor* output) {
TfLiteStatus status = kTfLiteError;
switch (params->type) {
case kTfLiteFloat32:
status = GatherNd<float, IndicesT>(params, indices, output);
break;
case kTfLiteInt8:
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
MicroPrintf("Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
MicroPrintf("gather_nd index out of bounds");
}
return status;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* params =
tflite::micro::GetEvalInput(context, node, kParams);
const TfLiteEvalTensor* indices =
tflite::micro::GetEvalInput(context, node, kIndices);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
switch (indices->type) {
case kTfLiteInt32:
return EvalGatherNd<int32_t>(context, params, indices, output);
break;
default:
MicroPrintf("Indices of type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}
} // namespace
TFLMRegistration Register_GATHER_ND() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite