Revert "Compute output shapes for some kernels (#2356)" (#2390)
This change broke some internal models, so reverting this until we can better understand why.
This also reverts PR #2383. When relanded, both should go back in together.
BUG=b/318738218
diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
index 3ce32f3..f31728c 100644
--- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
@@ -32,13 +32,19 @@
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
+#define TF_LITE_MICRO_CHECK_FAIL() \
+ do { \
+ if (micro_test::did_test_fail) { \
+ return kTfLiteError; \
+ } \
+ } while (false)
+
namespace {
// Arena size is a guesstimate, followed by use of
// MicroInterpreter::arena_used_bytes() on both the AudioPreprocessor and
-// MicroSpeech models and using the larger of the two results plus the
-// arena alignment size (16).
-constexpr size_t kArenaSize = 28664; // xtensa p6
+// MicroSpeech models and using the larger of the two results.
+constexpr size_t kArenaSize = 28584; // xtensa p6
alignas(16) uint8_t g_arena[kArenaSize];
using Features = int8_t[kFeatureCount][kFeatureSize];
diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
index 896e228..41d2c53 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h"
-#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
@@ -123,7 +122,7 @@
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams* params, int width,
int height, int filter_width, int filter_height,
- int* out_width, int* out_height,
+ int out_width, int out_height,
const TfLiteType data_type, OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
@@ -135,7 +134,7 @@
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
- width, filter_height, filter_width, padding, out_height, out_width);
+ width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
#if !defined(TF_LITE_STRIP_REFERENCE_IMPL)
@@ -168,7 +167,6 @@
#endif
return kTfLiteOk;
}
-
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -192,17 +190,6 @@
TfLiteTensor* bias =
micro_context->AllocateTempInputTensor(context, node, kBiasTensor);
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
-
- // Check input channels matching filter
- const int input_channels = input->dims->data[3];
- const int filter_input_channels = filter->dims->data[3];
- TF_LITE_ENSURE(context, filter_input_channels > 0);
- TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
-
int input_width = input->dims->data[2];
int input_height = input->dims->data[1];
#if defined(MLI_2_0) && !defined(MLI_2_0_KRNL_TEST)
@@ -212,8 +199,8 @@
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
#endif
- int output_width = 0;
- int output_height = 0;
+ int output_width = output->dims->data[2];
+ int output_height = output->dims->data[1];
// Dynamically allocate per-channel quantization parameters.
const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -248,11 +235,7 @@
TF_LITE_ENSURE_STATUS(CalculateOpData(
context, node, params, input_width, input_height, filter_width,
- filter_height, &output_width, &output_height, input->type, data));
-
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
+ filter_height, output_width, output_height, input->type, data));
data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
index 4fa5e94..c2c9cd5 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -30,7 +30,6 @@
#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h"
-#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
@@ -119,16 +118,17 @@
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, int width,
int height, int filter_width, int filter_height,
- int* out_width, int* out_height,
const TfLiteType data_type, OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ int unused_output_height, unused_output_width;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width, 1, 1, height, width,
- filter_height, filter_width, params->padding, out_height, out_width);
+ filter_height, filter_width, params->padding, &unused_output_height,
+ &unused_output_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -182,25 +182,6 @@
const TfLiteTensor* bias =
AllocateTempInputTensor(context, node, kBiasTensor);
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
- TF_LITE_ENSURE(context, params.dilation_height_factor > 0);
- TF_LITE_ENSURE(context, params.dilation_width_factor > 0);
-
- // Filter in DepthwiseConv is expected to be [1, height, width, channels].
- TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
-
- // Check input channels matching filter
- const int num_filter_channels = filter->dims->data[3];
- const int num_input_channels = input->dims->data[3];
- TF_LITE_ENSURE(context, num_input_channels != 0);
- TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
-
- int output_width = 0;
- int output_height = 0;
-
const TfLiteType data_type = input->type;
int width = SizeOfDimension(input, 2);
int height = SizeOfDimension(input, 1);
@@ -246,13 +227,9 @@
affine_quantization->zero_point->size);
}
- TF_LITE_ENSURE_STATUS(CalculateOpData(
- context, node, params, width, height, filter_width, filter_height,
- &output_width, &output_height, data_type, data));
-
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
+ TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
+ filter_width, filter_height, data_type,
+ data));
data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
index 94d1228..31a1c28 100644
--- a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
+++ b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
-#include <algorithm>
-
#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -41,68 +38,6 @@
const int kInputOutputMinDimensionNum = 3;
const int kInputOutputMaxDimensionNum = 4;
-TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
- const TfLiteTensor* input,
- const TfLiteTensor* block_shape,
- const TfLiteTensor* crops,
- TfLiteTensor* output) {
- TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
- TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(crops));
- const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
- const int32_t* crops_data = GetTensorData<int32_t>(crops);
-
- TfLiteIntArray* input_dims = input->dims;
- int spatial_dims_num = input_dims->size - 2;
- // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
- TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
- TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
- // Crops should be a 2D tensor with dimension [spatial_dims_num, 2].
- TF_LITE_ENSURE_EQ(context, NumDimensions(crops), 2);
- TF_LITE_ENSURE_EQ(context, crops->dims->data[0], spatial_dims_num);
- TF_LITE_ENSURE_EQ(context, crops->dims->data[1], 2);
-
- for (int i = 0; i < spatial_dims_num * 2; ++i) {
- TF_LITE_ENSURE(context, crops_data[i] >= 0);
- }
-
- // copy from input tensor as per TfLite code
- TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
- RuntimeShape output_shape = GetTensorShape(input);
- // keep a copy of the output tensor shape for later comparison
- RuntimeShape old_output_shape = GetTensorShape(output);
-
- int output_batch_size = input_dims->data[0];
- for (int dim = 0; dim < spatial_dims_num; ++dim) {
- // Number of batch must be multiple of (block_shape[dim]).
- TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
- TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape_data[dim], 0);
- output_batch_size = output_batch_size / block_shape_data[dim];
- output_shape.SetDim(dim + 1,
- input_dims->data[dim + 1] * block_shape_data[dim] -
- crops_data[dim * 2] - crops_data[dim * 2 + 1]);
- }
- output_shape.SetDim(0, output_batch_size);
- output_shape.SetDim(input_dims->size - 1,
- input_dims->data[input_dims->size - 1]);
-
- // check if need to relocate output tensor dims
- if (output_shape == old_output_shape) {
- return kTfLiteOk;
- }
- TF_LITE_ENSURE(context,
- output_shape.FlatSize() <= old_output_shape.FlatSize());
-
- // set the output tensor dims from output_shape
- TfLiteEvalTensor* output_eval =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
- context, output, output_eval));
- std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
- output->dims->data);
-
- return kTfLiteOk;
-}
-
TfLiteStatus BatchToSpaceNDPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -111,40 +46,20 @@
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
- TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* block_shape =
- micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
- TF_LITE_ENSURE(context, block_shape != nullptr);
- TfLiteTensor* crops =
- micro_context->AllocateTempInputTensor(node, kCropsTensor);
- TF_LITE_ENSURE(context, crops != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
- TF_LITE_ENSURE(context,
- input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
-
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
- TF_LITE_ENSURE(context,
- input->params.zero_point == output->params.zero_point);
- }
-
- TfLiteStatus status =
- ReshapeOutputTensor(context, node, input, block_shape, crops, output);
micro_context->DeallocateTempTfLiteTensor(input);
- micro_context->DeallocateTempTfLiteTensor(block_shape);
- micro_context->DeallocateTempTfLiteTensor(crops);
micro_context->DeallocateTempTfLiteTensor(output);
- return status;
+ return kTfLiteOk;
}
TfLiteStatus BatchToSpaceNDEval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
index 1b42a29..455c325 100644
--- a/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,8 +14,6 @@
==============================================================================*/
#include <cstdint>
-#include <limits>
-#include <type_traits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
@@ -27,165 +25,98 @@
namespace testing {
namespace {
-constexpr float kTestTolerance = 1e-05;
-constexpr int kNumInputs = 3;
-constexpr int kNumOutputs = 1;
-constexpr int kInputTensorIndex = 0;
-constexpr int kBlockShapeTensorIndex = 1;
-constexpr int kCropTensorIndex = 2;
-constexpr int kOutputTensorIndex = 3;
+constexpr int kBasicInputOutputSize = 16;
+int basic_input_dims[] = {4, 4, 2, 2, 1};
+const float basic_input[kBasicInputOutputSize] = {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+int basic_block_shape_dims[] = {1, 2};
+const int32_t basic_block_shape[] = {2, 2};
+int basic_crops_dims[] = {1, 4};
+const int32_t basic_crops[] = {0, 0, 0, 0};
+int basic_output_dims[] = {4, 1, 4, 4, 1};
+const float basic_golden[kBasicInputOutputSize] = {1, 5, 2, 6, 9, 13, 10, 14,
+ 3, 7, 4, 8, 11, 15, 12, 16};
-// min/max are used to compute scale, zero-point (asymmetric)
-template <typename T, size_t kInputSize, size_t kOutputSize>
-struct TestQuantParams {
- // quantization parameters
- float data_min; // input data minimum value
- float data_max; // input data maximum value
- T output_data[kOutputSize]; // quantized output storage
- T input_data[kInputSize]; // quantized input storage
-};
+template <typename T>
+TfLiteStatus ValidateBatchToSpaceNdGoldens(TfLiteTensor* tensors,
+ int tensors_size, const T* golden,
+ T* output, int output_size) {
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-TfLiteStatus ExecuteBatchToSpaceNdTest(TfLiteTensor* tensors,
- int tensors_count) {
- int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
- kBlockShapeTensorIndex, kCropTensorIndex};
- TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
- int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
- TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
-
- const TFLMRegistration registration = tflite::Register_BATCH_TO_SPACE_ND();
- micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
+ const TFLMRegistration registration = Register_BATCH_TO_SPACE_ND();
+ micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr);
- TfLiteStatus status = runner.InitAndPrepare();
- if (status != kTfLiteOk) {
- return status;
- }
- status = runner.Invoke();
+ TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
+ TF_LITE_ENSURE_STATUS(runner.Invoke());
- return status;
+ for (int i = 0; i < output_size; ++i) {
+ // TODO(b/158102673): workaround for not having fatal test assertions.
+ TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
+ if (golden[i] != output[i]) {
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus TestBatchToSpaceNdFloat(
+ int* input_dims_data, const float* input_data, int* block_shape_dims_data,
+ const int32_t* block_shape_data, int* crops_dims_data,
+ const int32_t* crops_data, int* output_dims_data, const float* golden,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+ TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+ TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateTensor(input_data, input_dims),
+ CreateTensor(block_shape_data, block_shape_dims),
+ CreateTensor(crops_data, crops_dims),
+ CreateTensor(output_data, output_dims),
+ };
+
+ return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden,
+ output_data, ElementCount(*output_dims));
}
template <typename T>
-TfLiteStatus TestBatchToSpaceNd(int* input_dims_data[kNumInputs],
- const T* input_data,
- const int32_t* block_shape_data,
- const int32_t* crop_data, int* output_dims_data,
- const T* golden_data, T* output_data) {
- TfLiteIntArray* input_dims =
- IntArrayFromInts(input_dims_data[kInputTensorIndex]);
- TfLiteIntArray* block_shape_dims =
- IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
- TfLiteIntArray* crop_dims =
- IntArrayFromInts(input_dims_data[kCropTensorIndex]);
- TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-
- constexpr int kTensorsCount = kNumInputs + kNumOutputs;
- TfLiteTensor tensors[kTensorsCount];
- tensors[kInputTensorIndex] =
- tflite::testing::CreateTensor(input_data, input_dims);
- tensors[kBlockShapeTensorIndex] =
- tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
- tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kCropTensorIndex] =
- tflite::testing::CreateTensor(crop_data, crop_dims);
- tensors[kCropTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kOutputTensorIndex] =
- tflite::testing::CreateTensor(output_data, output_dims);
-
- TfLiteStatus status = ExecuteBatchToSpaceNdTest(tensors, kTensorsCount);
- if (status != kTfLiteOk) {
- return status;
- }
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
- tensors[kOutputTensorIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < output_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
- tensors[kOutputTensorIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- // check output data against expected
- const int output_count = ElementCount(*output_dims);
- for (int i = 0; i < output_count; i++) {
- TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- return kTfLiteOk;
-}
-
-template <typename T, size_t kInCount, size_t kOutCount>
TfLiteStatus TestBatchToSpaceNdQuantized(
- TestQuantParams<T, kInCount, kOutCount>& params,
- int* input_dims_data[kNumInputs], const float* input_data,
- const int32_t* block_shape_data, const int32_t* crop_data,
- int* output_dims_data, const float* golden_data) {
- TfLiteIntArray* input_dims =
- IntArrayFromInts(input_dims_data[kInputTensorIndex]);
- TfLiteIntArray* block_shape_dims =
- IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
- TfLiteIntArray* crop_dims =
- IntArrayFromInts(input_dims_data[kCropTensorIndex]);
+ int* input_dims_data, const float* input_data, T* input_quantized,
+ float input_scale, int input_zero_point, int* block_shape_dims_data,
+ const int32_t* block_shape_data, int* crops_dims_data,
+ const int32_t* crops_data, int* output_dims_data, const float* golden,
+ T* golden_quantized, float output_scale, int output_zero_point,
+ T* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+ TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+ TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
- constexpr float kMaxMultiplier =
- std::numeric_limits<T>::max() /
- static_cast<float>(std::numeric_limits<T>::max() + 1);
- int zero_point = tflite::testing::ZeroPointFromMinMax<T>(
- params.data_min, params.data_max * kMaxMultiplier);
- float scale = tflite::testing::ScaleFromMinMax<T>(
- params.data_min, params.data_max * kMaxMultiplier);
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
+ input_dims, input_scale,
+ input_zero_point),
+ tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
+ tflite::testing::CreateTensor(crops_data, crops_dims),
+ tflite::testing::CreateQuantizedTensor(output_data, output_dims,
+ output_scale, output_zero_point),
+ };
+ tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
+ output_scale, output_zero_point);
- constexpr int kTensorsCount = kNumInputs + kNumOutputs;
- TfLiteTensor tensors[kTensorsCount];
- tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
- input_data, params.input_data, input_dims, scale, zero_point);
- tensors[kBlockShapeTensorIndex] =
- tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
- tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kCropTensorIndex] =
- tflite::testing::CreateTensor(crop_data, crop_dims);
- tensors[kCropTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
- params.output_data, output_dims, scale, zero_point);
-
- TfLiteStatus status = ExecuteBatchToSpaceNdTest(tensors, kTensorsCount);
- if (status != kTfLiteOk) {
- return status;
- }
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
- tensors[kOutputTensorIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < output_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
- tensors[kOutputTensorIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- // check output data against expected
- const int output_count = ElementCount(*output_dims);
- const float quantization_tolerance =
- (params.data_max - params.data_min) /
- (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
- for (int i = 0; i < output_count; i++) {
- float output_dequantized_data =
- (params.output_data[i] - zero_point) * scale;
- TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
- quantization_tolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- return kTfLiteOk;
+ return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden_quantized,
+ output_data, ElementCount(*output_dims));
}
} // namespace
@@ -194,291 +125,30 @@
TF_LITE_MICRO_TESTS_BEGIN
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidOutputShapeTest) {
- int kInputDims[] = {3, 2, 4, 1};
- int kBlockShapeDims[] = {1, 1};
- int kCropDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 1, 1, 1}; // invalid shape
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kCrop[] = {0, 0};
- constexpr float kGolden[] = {0}; // placeholder data, not used
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
- tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestValidOutputShapeTest) {
- int kInputDims[] = {3, 2, 4, 1};
- int kBlockShapeDims[] = {1, 1};
- int kCropDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 1, 8, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kCrop[] = {0, 0};
- constexpr float kGolden[] = {1, 5, 2, 6, 3, 7, 4, 8};
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
+TF_LITE_MICRO_TEST(BatchToSpaceBasicFloat) {
+ float output[tflite::testing::kBasicInputOutputSize];
TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
+ kTfLiteOk,
+ tflite::testing::TestBatchToSpaceNdFloat(
+ tflite::testing::basic_input_dims, tflite::testing::basic_input,
+ tflite::testing::basic_block_shape_dims,
+ tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+ tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+ tflite::testing::basic_golden, output));
}
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTest) {
- int kInputDims[] = {4, 4, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 1, 4, 4, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kCrop[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {
- 1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
+TF_LITE_MICRO_TEST(BatchToSpaceBasicInt8) {
+ int8_t output[tflite::testing::kBasicInputOutputSize];
+ int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
+ int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-// non-quantized test
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTestInt8) {
- int kInputDims[] = {4, 4, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 1, 4, 4, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr int8_t kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kCrop[] = {0, 0, 0, 0};
- constexpr int8_t kGolden[] = {
- 1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- int8_t output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestBatchOneConstTest) {
- int kInputDims[] = {4, 1, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 1, 2, 2, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {1, 2, 3, 4};
- constexpr int32_t kBlockShape[] = {1, 1};
- constexpr int32_t kCrop[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {1, 2, 3, 4};
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-// non-quantized test
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTestInt8EmptyOutput) {
- int kInputDims[] = {4, 4, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 1, 4, 0, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr int8_t kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kCrop[] = {0, 0, 2, 2};
- constexpr int8_t kGolden[] = {0}; // placeholder data, not used
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- int8_t output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidShapeTest) {
- int kInputDims[] = {4, 3, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 0, 0, 0, 0}; // placeholder dims, not used
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {0}; // placeholder data, not used
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kCrop[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {0}; // placeholder data, not used
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
- tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidCropsConstTest) {
- int kInputDims[] = {4, 3, 2, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kCropDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 0, 0, 0, 0}; // placeholder dims, not used
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {0}; // placeholder data, not used
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kCrop[] = {0, 0, 0, -1};
- constexpr float kGolden[] = {0}; // placeholder data, not used
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
- tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTest) {
- int kInputDims[] = {3, 4, 4, 1};
- int kBlockShapeDims[] = {1, 1};
- int kCropDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 8, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kCrop[] = {0, 0};
- constexpr float kGolden[] = {
- 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTestWithCrops) {
- int kInputDims[] = {3, 4, 4, 1};
- int kBlockShapeDims[] = {1, 1};
- int kCropDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 6, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kCrop[] = {1, 1};
- constexpr float kGolden[] = {9, 2, 10, 3, 11, 4, 13, 6, 14, 7, 15, 8};
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
- kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
- kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTestWithCropsINT8) {
- int kInputDims[] = {3, 4, 4, 1};
- int kBlockShapeDims[] = {1, 1};
- int kCropDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 6, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int kInputCount = std::extent<decltype(kInput)>::value;
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kCrop[] = {1, 1};
- constexpr float kGolden[] = {9, 2, 10, 3, 11, 4, 13, 6, 14, 7, 15, 8};
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-
- tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
- -16, 16, {}, {}};
-
- TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestBatchToSpaceNdQuantized(
- params, kInputDimsArray, kInput, kBlockShape, kCrop,
- kOutputDims, kGolden));
+ kTfLiteOk,
+ tflite::testing::TestBatchToSpaceNdQuantized(
+ tflite::testing::basic_input_dims, tflite::testing::basic_input,
+ input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
+ tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+ tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+ tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
}
TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index ef15da7..d3d1552 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -75,26 +75,29 @@
(input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
"Hybrid models are not supported on TFLite Micro.");
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
-
- // Check input channels matching filter
- const int input_channels = input->dims->data[3];
- const int filter_input_channels = filter->dims->data[3];
- TF_LITE_ENSURE(context, filter_input_channels > 0);
- TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
+ RuntimeShape input_shape = GetTensorShape(input);
+ RuntimeShape output_shape = GetTensorShape(output);
// Initialize cmsis_nn input dimensions
cmsis_nn_dims input_dims;
+ input_dims.n = MatchingDim(input_shape, 0, output_shape, 0);
input_dims.h = input->dims->data[1];
input_dims.w = input->dims->data[2];
+ input_dims.c = input_shape.Dims(3);
// Initialize cmsis_nn filter dimensions
cmsis_nn_dims filter_dims;
+ filter_dims.n = output_shape.Dims(3);
filter_dims.h = filter->dims->data[1];
filter_dims.w = filter->dims->data[2];
+ filter_dims.c = input_dims.c;
+
+ // Initialize cmsis_nn output dimensions
+ cmsis_nn_dims output_dims;
+ output_dims.n = input_dims.n;
+ output_dims.h = output->dims->data[1];
+ output_dims.w = output->dims->data[2];
+ output_dims.c = output_shape.Dims(3);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -106,32 +109,11 @@
context, num_channels * sizeof(int32_t)));
}
- int output_height = 0;
- int output_width = 0;
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_dims.w, input_dims.h, filter_dims.w,
- filter_dims.h, &output_width, &output_height, input->type,
+ filter_dims.h, output_dims.w, output_dims.h, input->type,
&data->reference_op_data));
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
-
- // Finish initializing cmsis_nn input, filter dimensions
- RuntimeShape input_shape = GetTensorShape(input);
- RuntimeShape output_shape = GetTensorShape(output);
- input_dims.n = MatchingDim(input_shape, 0, output_shape, 0);
- input_dims.c = input_shape.Dims(3);
- filter_dims.n = output_shape.Dims(3);
- filter_dims.c = input_dims.c;
-
- // Initialize cmsis_nn output dimensions
- cmsis_nn_dims output_dims;
- output_dims.n = input_dims.n;
- output_dims.h = output_shape.Dims(1);
- output_dims.w = output_shape.Dims(2);
- output_dims.c = output_shape.Dims(3);
-
// CMSIS_NN allows INT64 or nullptr bias data pointer
if (input->type == kTfLiteInt8 ||
(input->type == kTfLiteInt16 &&
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
index 77d6712..f30a952 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -75,29 +75,13 @@
micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
- TF_LITE_ENSURE(context, params.dilation_height_factor > 0);
- TF_LITE_ENSURE(context, params.dilation_width_factor > 0);
-
- // Filter in DepthwiseConv is expected to be [1, height, width, channels].
- TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
-
- // Check input channels matching filter
- const int num_filter_channels = filter->dims->data[3];
- const int num_input_channels = input->dims->data[3];
- TF_LITE_ENSURE(context, num_input_channels != 0);
- TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
-
const TfLiteType data_type = input->type;
int input_width = SizeOfDimension(input, 2);
int input_height = SizeOfDimension(input, 1);
int filter_width = SizeOfDimension(filter, 2);
int filter_height = SizeOfDimension(filter, 1);
- int output_width = 0;
- int output_height = 0;
+ int output_width = SizeOfDimension(output, 2);
+ int output_height = SizeOfDimension(output, 1);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
@@ -136,13 +120,9 @@
TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
context, node, params, input_width, input_height, filter_width,
- filter_height, &output_width, &output_height, data_type,
+ filter_height, output_width, output_height, data_type,
&data->reference_op_data));
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
-
if (input->type == kTfLiteInt8) {
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
index a927980..0c8073f 100644
--- a/tensorflow/lite/micro/kernels/conv.h
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -70,20 +70,10 @@
TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params, int width,
int height, int filter_width,
- int filter_height, int* out_width,
- int* out_height, const TfLiteType data_type,
+ int filter_height, int out_width,
+ int out_height, const TfLiteType data_type,
OpDataConv* data);
-// When this method is called, the output tensor shape is computed and
-// relocated to persistent arena memory.
-// The height and width parameters should be the computed results from
-// CalculateOpDataConv.
-TfLiteStatus ConvReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
- const TfLiteTensor* input,
- const TfLiteTensor* filter,
- TfLiteTensor* output, int height,
- int width);
-
void* ConvInit(TfLiteContext* context, const char* buffer, size_t length);
TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node);
diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc
index ddcd5e5..51c7a6f 100644
--- a/tensorflow/lite/micro/kernels/conv_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_common.cc
@@ -79,8 +79,8 @@
TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params, int width,
int height, int filter_width,
- int filter_height, int* out_width,
- int* out_height, const TfLiteType data_type,
+ int filter_height, int out_width,
+ int out_height, const TfLiteType data_type,
OpDataConv* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
@@ -92,7 +92,7 @@
data->padding = ComputePaddingHeightWidth(
params.stride_height, params.stride_width, params.dilation_height_factor,
params.dilation_width_factor, height, width, filter_height, filter_width,
- padding, out_height, out_width);
+ padding, &out_height, &out_width);
MicroContext* micro_context = GetMicroContext(context);
@@ -135,28 +135,6 @@
return kTfLiteOk;
}
-TfLiteStatus ConvReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
- const TfLiteTensor* input,
- const TfLiteTensor* filter,
- TfLiteTensor* output, int height,
- int width) {
- const int filter_output_channels = filter->dims->data[0];
- const int batches = input->dims->data[0];
-
- // relocate output tensor dims so they can be updated
- TfLiteEvalTensor* output_eval =
- tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
- TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
- context, output, output_eval));
-
- output->dims->data[0] = batches;
- output->dims->data[1] = height;
- output->dims->data[2] = width;
- output->dims->data[3] = filter_output_channels;
-
- return kTfLiteOk;
-}
-
TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
@@ -185,23 +163,12 @@
(filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
"Hybrid models are not supported on TFLite Micro.");
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
-
- // Check input channels matching filter
- const int input_channels = input->dims->data[3];
- const int filter_input_channels = filter->dims->data[3];
- TF_LITE_ENSURE(context, filter_input_channels > 0);
- TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
-
const int input_width = input->dims->data[2];
const int input_height = input->dims->data[1];
const int filter_width = filter->dims->data[2];
const int filter_height = filter->dims->data[1];
- int output_width = 0;
- int output_height = 0;
+ const int output_width = output->dims->data[2];
+ const int output_height = output->dims->data[1];
// Dynamically allocate per-channel quantization parameters.
const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -231,11 +198,7 @@
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_width, input_height, filter_width,
- filter_height, &output_width, &output_height, input->type, data));
-
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
+ filter_height, output_width, output_height, input->type, data));
if (filter->type == kTfLiteInt4) {
int filter_size =
diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc
index 3cfc594..0fb9411 100644
--- a/tensorflow/lite/micro/kernels/conv_test.cc
+++ b/tensorflow/lite/micro/kernels/conv_test.cc
@@ -277,6 +277,9 @@
}
TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) {
+ const int output_dims_count = 24;
+ int8_t output_data[output_dims_count];
+
const float input_scale = 0.5f;
const float output_scale = 1.0f;
const int input_zero_point = 0;
@@ -289,10 +292,10 @@
1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4,
// b = 1
1, 2, 3, 4, 5, 6, 2, 6, 2, 4, 4, 2, 3, 2, 6, 5, 1, 4, 1, 2, 1, 4, 6, 3};
- constexpr int output_elements = 12;
- int output_shape[] = {4, 2, 1, 2, 3};
- int8_t output_data[output_elements];
- const float golden_data[] = {25, 2, 7, 25, 2, 7, 39, 7, 6, 50, 3, 4};
+ const int output_elements = 24;
+ int output_shape[] = {4, 2, 2, 2, 3};
+ const float golden_data[] = {25, 2, 7, 25, 2, 7, 10, 2, -3, 10, 2, -3,
+ 39, 7, 6, 50, 3, 4, 14, 4, -5, 15, 0, -7};
int8_t input_quantized[input_elements];
int8_t filter_quantized[tflite::testing::kFilterElements];
@@ -1084,7 +1087,7 @@
using tflite::ElementCount;
using tflite::kConvBiasQuantized8;
using tflite::kConvFilter8x3x3x3;
- using tflite::kConvGoldenOutput1x15x15x8;
+ using tflite::kConvGoldenOutput1x16x16x8;
using tflite::kConvInput1x32x32x3;
using tflite::testing::CreateTensor;
using tflite::testing::FloatArrayFromFloats;
@@ -1156,8 +1159,8 @@
0};
// Create output tensor of 16x16x8
- int8_t output_data[1 * 15 * 15 * kOutDepth];
- int output_shape[] = {4, 1, 15, 15, kOutDepth};
+ int8_t output_data[1 * 16 * 16 * kOutDepth];
+ int output_shape[] = {4, 1, 16, 16, kOutDepth};
TfLiteIntArray* output_dims = IntArrayFromInts(output_shape);
const int output_dims_count = ElementCount(*output_dims);
TfLiteTensor output_tensor = CreateTensor(output_data, output_dims);
@@ -1180,7 +1183,7 @@
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
- ValidateConvGoldens(tensors, tensors_size, kConvGoldenOutput1x15x15x8,
+ ValidateConvGoldens(tensors, tensors_size, kConvGoldenOutput1x16x16x8,
output_dims_count, &conv_params,
tflite::Register_CONV_2D(), output_data,
1.0 /* tolerance */));
diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc
index eda801b..a0f733b 100644
--- a/tensorflow/lite/micro/kernels/conv_test_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_test_common.cc
@@ -18,18 +18,13 @@
namespace tflite {
namespace testing {
-constexpr int kInputIndex = 0;
-constexpr int kFilterIndex = 1;
-constexpr int kBiasIndex = 2;
-constexpr int kOutputIndex = 3;
-
template <typename T>
TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
int output_length, TfLiteConvParams* conv_params,
TFLMRegistration registration, T* output_data) {
- int inputs_array_data[] = {3, kInputIndex, kFilterIndex, kBiasIndex};
+ int inputs_array_data[] = {3, 0, 1, 2};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
- int outputs_array_data[] = {1, kOutputIndex};
+ int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
@@ -50,37 +45,15 @@
TfLiteConvParams* conv_params,
TFLMRegistration registration, T* output_data,
float tolerance) {
- // grab pointer to expected tensor shape
- TfLiteIntArray* expected_out_dims = tensors[kOutputIndex].dims;
-
TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length,
conv_params, registration, output_data);
if (status != kTfLiteOk) {
return status;
}
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->size,
- tensors[kOutputIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < expected_out_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->data[i],
- tensors[kOutputIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- // compare expected against output data
- const int actual_output_length = ElementCount(*expected_out_dims);
- TF_LITE_MICRO_EXPECT_EQ(output_length, actual_output_length);
- TF_LITE_MICRO_CHECK_FAIL();
for (int i = 0; i < output_length; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
tolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
}
-
return kTfLiteOk;
}
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.h b/tensorflow/lite/micro/kernels/depthwise_conv.h
index b6712cd..5f2d87e 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv.h
+++ b/tensorflow/lite/micro/kernels/depthwise_conv.h
@@ -44,17 +44,9 @@
TfLiteStatus CalculateOpDataDepthwiseConv(
TfLiteContext* context, TfLiteNode* node,
const TfLiteDepthwiseConvParams& params, int width, int height,
- int filter_width, int filter_height, int* out_width, int* out_height,
+ int filter_width, int filter_height, int out_width, int out_height,
const TfLiteType data_type, OpDataConv* data);
-// When this method is called, the output tensor shape is computed and
-// relocated to persistent arena memory.
-// The height and width parameters should be the computed results from
-// CalculateOpDataConv.
-TfLiteStatus DepthwiseConvReshapeOutputTensor(
- TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
- const TfLiteTensor* filter, TfLiteTensor* output, int height, int width);
-
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
// This is the most generic TFLMRegistration. The actual supported types
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
index 431bec0..52804de 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -80,7 +80,7 @@
TfLiteStatus CalculateOpDataDepthwiseConv(
TfLiteContext* context, TfLiteNode* node,
const TfLiteDepthwiseConvParams& params, int width, int height,
- int filter_width, int filter_height, int* out_width, int* out_height,
+ int filter_width, int filter_height, int out_width, int out_height,
const TfLiteType data_type, OpDataConv* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
@@ -92,7 +92,7 @@
data->padding = ComputePaddingHeightWidth(
params.stride_height, params.stride_width, params.dilation_height_factor,
params.dilation_width_factor, height, width, filter_height, filter_width,
- padding, out_height, out_width);
+ padding, &out_height, &out_width);
MicroContext* micro_context = GetMicroContext(context);
@@ -133,26 +133,6 @@
return kTfLiteOk;
}
-TfLiteStatus DepthwiseConvReshapeOutputTensor(
- TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
- const TfLiteTensor* filter, TfLiteTensor* output, int height, int width) {
- const int filter_output_channels = filter->dims->data[3];
- const int batches = input->dims->data[0];
-
- // relocate output tensor dims so they can be updated
- TfLiteEvalTensor* output_eval =
- tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
- TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
- context, output, output_eval));
-
- output->dims->data[0] = batches;
- output->dims->data[1] = height;
- output->dims->data[2] = width;
- output->dims->data[3] = filter_output_channels;
-
- return kTfLiteOk;
-}
-
TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
@@ -172,28 +152,12 @@
micro_context->AllocateTempInputTensor(node, kDepthwiseConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
- // Check dimensionality of input, filter, output
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
- TF_LITE_ENSURE(context, params.dilation_height_factor > 0);
- TF_LITE_ENSURE(context, params.dilation_width_factor > 0);
-
- // Filter in DepthwiseConv is expected to be [1, height, width, channels].
- TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
-
- // Check input channels matching filter
- const int num_filter_channels = filter->dims->data[3];
- const int num_input_channels = input->dims->data[3];
- TF_LITE_ENSURE(context, num_input_channels != 0);
- TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
-
const int input_width = input->dims->data[2];
const int input_height = input->dims->data[1];
const int filter_width = filter->dims->data[2];
const int filter_height = filter->dims->data[1];
- int output_width = 0;
- int output_height = 0;
+ const int output_width = output->dims->data[2];
+ const int output_height = output->dims->data[1];
// Dynamically allocate per-channel quantization parameters.
const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
@@ -243,11 +207,7 @@
TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
context, node, params, input_width, input_height, filter_width,
- filter_height, &output_width, &output_height, input->type, data));
-
- // compute output tensor shape and relocate shape data
- TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
- context, node, input, filter, output, output_height, output_width));
+ filter_height, output_width, output_height, input->type, data));
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(input);
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
index 4cf090c..b50b40a 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
@@ -1,5 +1,5 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -25,11 +25,8 @@
namespace testing {
namespace {
-// Indices of the tensors in context->tensors, specific to
+// Index of the output tensor in context->tensors, specific to
// DepthwiseConv.
-constexpr int kInputTensorIndex = 0;
-constexpr int kFilterTensorIndex = 1;
-constexpr int kBiasTensorIndex = 2;
constexpr int kOutputTensorIndex = 3;
constexpr int kMaxFilterChannels = 64;
@@ -46,10 +43,9 @@
const T* expected_output_data, int output_length,
TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size,
TfLiteTensor* tensors) {
- int inputs_array_data[] = {3, kInputTensorIndex, kFilterTensorIndex,
- kBiasTensorIndex};
+ int inputs_array_data[] = {3, 0, 1, 2};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
- int outputs_array_data[] = {1, kOutputTensorIndex};
+ int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D();
@@ -65,8 +61,6 @@
conv_params->depth_multiplier = depth_mul;
const char* init_data = reinterpret_cast<const char*>(conv_params);
- // grab pointer to expected tensor shape
- TfLiteIntArray* expected_out_dims = tensors[kOutputTensorIndex].dims;
// TODO(b/154240825): Use a test macro here which fails and returns.
TfLiteStatus status = runner.InitAndPrepare(init_data);
@@ -75,28 +69,12 @@
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->size,
- tensors[kOutputTensorIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < expected_out_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->data[i],
- tensors[kOutputTensorIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- const int actual_output_length = ElementCount(*expected_out_dims);
- TF_LITE_MICRO_EXPECT_EQ(output_length, actual_output_length);
- TF_LITE_MICRO_CHECK_FAIL();
const T* output_data = tflite::GetTensorData<T>(&tensors[kOutputTensorIndex]);
+
for (int i = 0; i < output_length; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
tolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
}
-
return kTfLiteOk;
}
diff --git a/tensorflow/lite/micro/kernels/expand_dims.cc b/tensorflow/lite/micro/kernels/expand_dims.cc
index 19c54d7..6bae37b 100644
--- a/tensorflow/lite/micro/kernels/expand_dims.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -47,43 +47,38 @@
}
}
-// Rewrite the output tensor's dimension shape so it is equivalent to inserting
+// Verifies that the output tensor's dimension shape is equivalent to inserting
// a dimension of length 1 at the dimension index axis of input's shape as
// defined in https://www.tensorflow.org/api_docs/python/tf/expand_dims.
-TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
- const TfLiteTensor* input,
- const TfLiteTensor* axis_tensor,
- TfLiteTensor* output) {
+TfLiteStatus VerifyTensorDim(TfLiteContext* context, const TfLiteTensor* input,
+ const TfLiteTensor* axis_tensor,
+ const TfLiteTensor* output) {
int32_t axis_value = 0;
TF_LITE_ENSURE_OK(context,
GetAxisValueFromTensor(context, axis_tensor, &axis_value));
- TfLiteIntArray* input_shape = input->dims;
+ tflite::RuntimeShape input_shape = tflite::GetTensorShape(input);
if (axis_value < 0) {
- axis_value = input_shape->size + 1 + axis_value;
+ axis_value = input_shape.DimensionsCount() + 1 + axis_value;
}
+ TF_LITE_ENSURE(context, axis_value <= input_shape.DimensionsCount());
- // relocate output tensor dims so they can be updated
- TfLiteEvalTensor* output_eval =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
- context, output, output_eval));
+ // TFLM only supports fixed dimension tensor and assumes that the output shape
+ // is fully specified in the model. As such, TFLM directly use the pointer to
+ // the dimension array in the model buffer.
+ tflite::RuntimeShape output_shape = tflite::GetTensorShape(output);
- TfLiteIntArray* output_shape = output->dims;
- TF_LITE_ENSURE(context, output_shape->size == input_shape->size + 1);
- TF_LITE_ENSURE(context, axis_value < output_shape->size);
- TF_LITE_ENSURE(context, axis_value >= 0);
-
- for (int i = 0; i < output_shape->size; ++i) {
+ TF_LITE_ENSURE(context, output_shape.DimensionsCount() ==
+ input_shape.DimensionsCount() + 1);
+ for (int i = 0; i < output_shape.DimensionsCount(); ++i) {
if (i < axis_value) {
- output_shape->data[i] = input_shape->data[i];
+ TF_LITE_ENSURE(context, output_shape.Dims(i) == input_shape.Dims(i));
} else if (i == axis_value) {
- output_shape->data[i] = 1;
+ TF_LITE_ENSURE(context, output_shape.Dims(i) == 1);
} else {
- output_shape->data[i] = input_shape->data[i - 1];
+ TF_LITE_ENSURE(context, output_shape.Dims(i) == input_shape.Dims(i - 1));
}
}
-
return kTfLiteOk;
}
@@ -106,8 +101,7 @@
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
return kTfLiteError;
}
- TF_LITE_ENSURE_OK(context,
- ReshapeOutputTensor(context, node, input, axis, output));
+ TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(axis);
diff --git a/tensorflow/lite/micro/kernels/expand_dims_test.cc b/tensorflow/lite/micro/kernels/expand_dims_test.cc
index 4a49a8e..d8e217e 100644
--- a/tensorflow/lite/micro/kernels/expand_dims_test.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -34,24 +34,28 @@
constexpr int kInputTensors[] = {2, kDimsTensorIndex, kAxisTensorIndex};
constexpr int kOutputTensors[] = {1, kOutputTensorIndex};
-// Some targets do not support dynamic memory (i.e., no malloc or new), thus,
-// the test need to place non-transitent memories in static variables. This is
-// safe because tests are guaranteed to run serially.
-// Both below structures are trivially destructible.
-static TFLMRegistration registration;
-static TfLiteTensor tensors[kTensorsSize];
-
template <typename T>
micro::KernelRunner CreateExpandDimsKernelRunner(
int* input_dims, const T* input_data, int* axis_dims,
const int32_t* axis_data, int* output_dims, T* output_data) {
+ // Some targets do not support dynamic memory (i.e., no malloc or new), thus,
+ // the test need to place non-transitent memories in static variables. This is
+ // safe because tests are guaranteed to run serially.
+ // Both below structures are trivially destructible.
+ static TFLMRegistration registration;
+ static TfLiteTensor tensors[kTensorsSize];
+
TfLiteIntArray* in_dims = IntArrayFromInts(input_dims);
TfLiteIntArray* ax_dims = IntArrayFromInts(axis_dims);
TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
+ const int out_dims_size = out_dims->size;
+ const int in_dims_size = in_dims->size;
+ TF_LITE_MICRO_EXPECT_EQ(out_dims_size, (in_dims_size + 1));
+
tensors[kDimsTensorIndex] = CreateTensor(input_data, in_dims);
tensors[kAxisTensorIndex] = CreateTensor(axis_data, ax_dims);
- tensors[kOutputTensorIndex] = CreateTensor(output_data, out_dims);
+ tensors[kOutputTensorIndex] = CreateTensor(output_data, out_dims, true);
TfLiteIntArray* inputs_array =
IntArrayFromInts(const_cast<int*>(kInputTensors));
@@ -77,16 +81,9 @@
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
- // The output tensor shape has been updated by the kernel.
- TfLiteIntArray* actual_out_dims = tensors[kOutputTensorIndex].dims;
+ // The output tensor's data have been updated by the kernel.
+ TfLiteIntArray* actual_out_dims = IntArrayFromInts(output_dims);
const int output_size = ElementCount(*actual_out_dims);
- TfLiteIntArray* golden_out_dims = IntArrayFromInts(expected_output_dims);
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(golden_out_dims->size, actual_out_dims->size);
- for (int i = 0; i < golden_out_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(golden_out_dims->data[i], actual_out_dims->data[i]);
- }
for (int i = 0; i < output_size; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
@@ -106,9 +103,10 @@
const int8_t golden_data[] = {-1, 1, -2, 2};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {0};
+ int golden_dims[] = {1, 2, 2};
int output_dims[] = {3, 1, 2, 2};
tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -119,9 +117,10 @@
const float golden_data[] = {-1.1, 1.2, -2.1, 2.2};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {1};
+ int golden_dims[] = {2, 1, 2};
int output_dims[] = {3, 2, 1, 2};
tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -132,9 +131,10 @@
const int8_t golden_data[] = {-1, 1, -2, 2};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {2};
+ int golden_dims[] = {2, 2, 1};
int output_dims[] = {3, 2, 2, 1};
tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -145,9 +145,10 @@
const int8_t golden_data[] = {-1, 1, 2, -2, 0, 3};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {-4};
+ int golden_dims[] = {1, 3, 1, 2};
int output_dims[] = {4, 1, 3, 1, 2};
tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -158,9 +159,10 @@
const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {-3};
+ int golden_dims[] = {3, 1, 1, 2};
int output_dims[] = {4, 3, 1, 1, 2};
tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -171,9 +173,10 @@
const int8_t golden_data[] = {-1, 1, 2, -2, 0, 3};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {-2};
+ int golden_dims[] = {1, 2, 1, 3};
int output_dims[] = {4, 1, 2, 1, 3};
tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
+ axis_data, golden_dims, output_dims,
golden_data, output_data);
}
@@ -184,40 +187,24 @@
const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
int axis_dims[] = {1, 1};
const int32_t axis_data[] = {-1};
+ int golden_dims[] = {1, 3, 2, 1};
int output_dims[] = {4, 1, 3, 2, 1};
tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
- axis_data, output_dims, output_dims,
- golden_data, output_data);
-}
-
-TF_LITE_MICRO_TEST(ExpandDimsInputOutputDimsMismatch) {
- float output_data[6];
- int input_dims[] = {3, 1, 3, 2};
- const float input_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
- const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
- int axis_dims[] = {1, 1};
- const int32_t axis_data[] = {-1};
- // When input dimension is [1, 3, 2] and the axis is -1, the output dimension
- // should be [1, 3, 2, 1] as in the test case ExpandDimsNegativeAxisTest1.
- // Shuffle the output dimension to make it incorrect to test that output
- // tensor dimensions are computed correctly.
- int golden_dims[] = {4, 1, 3, 2, 1};
- int output_dims[] = {4, 1, 3, 1, 2};
-
- tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
axis_data, golden_dims, output_dims,
golden_data, output_data);
}
-TF_LITE_MICRO_TEST(ExpandDimsAxisPositiveOutOfRangeShallFailTest) {
- int8_t output_data[6];
+TF_LITE_MICRO_TEST(ExpandDimsInputOutputDimsMismatchShallFail) {
+ float output_data[6];
int input_dims[] = {3, 1, 3, 2};
- const int8_t input_data[] = {1, 8, 2, 5, 9, 3};
+ const float input_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
int axis_dims[] = {1, 1};
- // The input dimension is 3-D, so that axis value should not exceed 3.
- // The below axis value 4 shall lead to failure at prepare.
- const int32_t axis_data[] = {4};
- int output_dims[] = {4, 1, 3, 2, 1};
+ const int32_t axis_data[] = {-1};
+ // When input dimension is [1, 3, 2] and the axis is -1, the output dimension
+ // should be [1, 3, 2, 1] as in the test case ExpandDimsNegativeAxisTest1.
+ // Shuffle the output dimension to make it incorrect so that the EXPAND_DIMS
+ // op would fail at prepare.
+ int output_dims[] = {4, 1, 3, 1, 2};
tflite::micro::KernelRunner runner =
tflite::testing::CreateExpandDimsKernelRunner(input_dims, input_data,
@@ -227,14 +214,14 @@
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, runner.InitAndPrepare());
}
-TF_LITE_MICRO_TEST(ExpandDimsAxisNegativeOutOfRangeShallFailTest) {
+TF_LITE_MICRO_TEST(ExpandDimsAxisOutOfRangeShallFail) {
int8_t output_data[6];
int input_dims[] = {3, 1, 3, 2};
const int8_t input_data[] = {1, 8, 2, 5, 9, 3};
int axis_dims[] = {1, 1};
- // The input dimension is 3-D, so that axis value should be less than zero.
- // The below axis value -5 shall lead to failure at prepare.
- const int32_t axis_data[] = {-5};
+ // The input dimension is 3-D, so that axis value should not exceed 3.
+ // The below axis value 4 shall lead to failure at prepare.
+ const int32_t axis_data[] = {4};
int output_dims[] = {4, 1, 3, 2, 1};
tflite::micro::KernelRunner runner =
diff --git a/tensorflow/lite/micro/kernels/reshape.h b/tensorflow/lite/micro/kernels/reshape.h
index 15e5f61..02bda32 100644
--- a/tensorflow/lite/micro/kernels/reshape.h
+++ b/tensorflow/lite/micro/kernels/reshape.h
@@ -19,7 +19,6 @@
namespace tflite {
constexpr int kReshapeInputTensor = 0;
-constexpr int kReshapeShapeTensor = 1;
constexpr int kReshapeOutputTensor = 0;
TfLiteStatus PrepareReshapeReference(TfLiteContext* context, TfLiteNode* node);
diff --git a/tensorflow/lite/micro/kernels/reshape_common.cc b/tensorflow/lite/micro/kernels/reshape_common.cc
index 75bc2eb..b86e2be 100644
--- a/tensorflow/lite/micro/kernels/reshape_common.cc
+++ b/tensorflow/lite/micro/kernels/reshape_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -35,9 +35,6 @@
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kReshapeInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
- // The shape tensor is optional
- TfLiteTensor* new_shape =
- micro_context->AllocateTempInputTensor(node, kReshapeShapeTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kReshapeOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
@@ -46,35 +43,20 @@
// input. Here we calculate what that dimension should be so that the number
// of output elements in the same as the number of input elements.
int num_input_elements = NumElements(input);
-
- int output_shape_size = 0;
- int* output_shape_data = nullptr;
- if (new_shape != nullptr && new_shape->dims->size > 0) {
- // use new shape tensor data
- TF_LITE_ENSURE_EQ(context, new_shape->type, kTfLiteInt32);
- output_shape_data = GetTensorData<int>(new_shape);
- output_shape_size = new_shape->dims->data[new_shape->dims->size - 1];
-
- TF_LITE_ENSURE_EQ(context, output_shape_size,
- static_cast<int32_t>(output->dims->size));
- } else {
- // use output shape
- output_shape_size = output->dims->size;
- output_shape_data = output->dims->data;
- }
+ TfLiteIntArray* output_shape = output->dims;
if (NumInputs(node) == 1 && // Legacy scalar supported with params.
- output_shape_size == 1 && output_shape_data[0] == 0) {
+ output_shape->size == 1 && output_shape->data[0] == 0) {
// Legacy tflite models use a shape parameter of [0] to indicate scalars,
// so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
// toco conversion.
- output_shape_size = 0;
+ output_shape->size = 0;
}
int num_output_elements = 1;
int stretch_dim = -1;
- for (int i = 0; i < output_shape_size; ++i) {
- int value = output_shape_data[i];
+ for (int i = 0; i < output_shape->size; ++i) {
+ int value = output_shape->data[i];
if (value == -1) {
TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
stretch_dim = i;
@@ -82,26 +64,20 @@
num_output_elements *= value;
}
}
- if (stretch_dim != -1 || output_shape_size == 0) {
+ if (stretch_dim != -1) {
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kReshapeOutputTensor);
TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
- if (stretch_dim != -1) {
- output->dims->data[stretch_dim] =
- num_input_elements / num_output_elements;
- num_output_elements *= output->dims->data[stretch_dim];
- }
- output->dims->size = output_shape_size;
+ output_shape = output->dims; // output tensor dims were moved
+ output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
+ num_output_elements *= output_shape->data[stretch_dim];
}
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
micro_context->DeallocateTempTfLiteTensor(input);
- if (new_shape != nullptr) {
- micro_context->DeallocateTempTfLiteTensor(new_shape);
- }
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc
index d97007a..d78d9fa 100644
--- a/tensorflow/lite/micro/kernels/reshape_test.cc
+++ b/tensorflow/lite/micro/kernels/reshape_test.cc
@@ -265,7 +265,7 @@
golden_dims_len, false);
}
-// Stretch is supported with TF Micro
+// Stretch is not supported with TF Micro
TF_LITE_MICRO_TEST(ReshapeWithStretchDimensionShouldSucceed) {
float output_data_float[32];
int8_t output_data_int8[32];
@@ -347,7 +347,7 @@
const float input_data[] = {3.0f};
auto input_tensor = CreateTensor(input_data, input_dims);
- float output_data[] = {0.0f};
+ float output_data[1];
int output_dims_data[2] = {1, 0};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateTensor(output_data, output_dims);
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
index e5e86d1..f8df149 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
-#include <algorithm>
-
#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -27,11 +24,12 @@
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
+
namespace {
constexpr int kInputTensor = 0;
constexpr int kBlockShapeTensor = 1;
-constexpr int kPaddingTensor = 2;
+constexpr int kCropsTensor = 2;
constexpr int kOutputTensor = 0;
// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
@@ -47,68 +45,6 @@
return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
}
-TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
- const TfLiteTensor* input,
- const TfLiteTensor* block_shape,
- const TfLiteTensor* padding,
- TfLiteTensor* output) {
- TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
- TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(padding));
- const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
- const int32_t* padding_data = GetTensorData<int32_t>(padding);
-
- TfLiteIntArray* input_dims = input->dims;
- int spatial_dims_num = input_dims->size - 2;
- // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
- TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
- TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
- // Padding should be a 2D tensor with dimension [spatial_dims_num, 2].
- TF_LITE_ENSURE_EQ(context, NumDimensions(padding), 2);
- TF_LITE_ENSURE_EQ(context, padding->dims->data[0], spatial_dims_num);
- TF_LITE_ENSURE_EQ(context, padding->dims->data[1], 2);
-
- // copy from input tensor as per TfLite code
- TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
- RuntimeShape output_shape = GetTensorShape(input);
- // keep a copy of the output tensor shape for later comparison
- RuntimeShape old_output_shape = GetTensorShape(output);
-
- // Ensures the input height and width (with padding) is a multiple of block
- // shape height and width.
- int output_batch_size = input_dims->data[0];
- for (int dim = 0; dim < spatial_dims_num; ++dim) {
- int final_dim_size = (input_dims->data[dim + 1] + padding_data[dim * 2] +
- padding_data[dim * 2 + 1]);
- TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
- TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape_data[dim], 0);
- output_shape.SetDim(dim + 1, final_dim_size / block_shape_data[dim]);
- output_batch_size *= block_shape_data[dim];
- }
- output_shape.SetDim(0, output_batch_size);
- output_shape.SetDim(input_dims->size - 1,
- input_dims->data[input_dims->size - 1]);
-
- // check if need to relocate output tensor dims
- if (output_shape == old_output_shape) {
- return kTfLiteOk;
- } else if (output_shape.FlatSize() > old_output_shape.FlatSize() &&
- output->data.data != nullptr) {
- MicroPrintf(
- "SPACE_TO_BATCH_ND: resizing flatbuffer tensor data is not supported");
- return kTfLiteError;
- }
-
- // set the output tensor dims from output_shape
- TfLiteEvalTensor* output_eval =
- tflite::micro::GetEvalOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
- context, output, output_eval));
- std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
- output->dims->data);
-
- return kTfLiteOk;
-}
-
TfLiteStatus SpaceToBatchNDPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);
@@ -117,47 +53,19 @@
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
- TF_LITE_ENSURE(context, input != nullptr);
- TfLiteTensor* block_shape =
- micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
- TF_LITE_ENSURE(context, block_shape != nullptr);
- TfLiteTensor* padding =
- micro_context->AllocateTempInputTensor(node, kPaddingTensor);
- TF_LITE_ENSURE(context, padding != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
- TF_LITE_ENSURE(context,
- input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
-
- TF_LITE_ENSURE(context, node->user_data != nullptr);
- SpaceToBatchParams& params =
- *(static_cast<SpaceToBatchParams*>(node->user_data));
-
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
- TF_LITE_ENSURE(context,
- input->params.zero_point == output->params.zero_point);
- params.output_offset = output->params.zero_point;
- } else {
- params.output_offset = 0;
- }
-
- TfLiteStatus status =
- ReshapeOutputTensor(context, node, input, block_shape, padding, output);
micro_context->DeallocateTempTfLiteTensor(input);
- micro_context->DeallocateTempTfLiteTensor(block_shape);
- micro_context->DeallocateTempTfLiteTensor(padding);
micro_context->DeallocateTempTfLiteTensor(output);
-
- return status;
+ return kTfLiteOk;
}
TfLiteStatus SpaceToBatchNDEval(TfLiteContext* context, TfLiteNode* node) {
@@ -169,8 +77,8 @@
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* block_shape =
tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
- const TfLiteEvalTensor* padding =
- tflite::micro::GetEvalInput(context, node, kPaddingTensor);
+ const TfLiteEvalTensor* crops =
+ tflite::micro::GetEvalInput(context, node, kCropsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
@@ -181,8 +89,8 @@
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(block_shape),
tflite::micro::GetTensorData<int32_t>(block_shape),
- tflite::micro::GetTensorShape(padding),
- tflite::micro::GetTensorData<int32_t>(padding),
+ tflite::micro::GetTensorShape(crops),
+ tflite::micro::GetTensorData<int32_t>(crops),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
@@ -192,8 +100,8 @@
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(block_shape),
tflite::micro::GetTensorData<int32_t>(block_shape),
- tflite::micro::GetTensorShape(padding),
- tflite::micro::GetTensorData<int32_t>(padding),
+ tflite::micro::GetTensorShape(crops),
+ tflite::micro::GetTensorData<int32_t>(crops),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
index d45a606..eae185b 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,8 +14,6 @@
==============================================================================*/
#include <cstdint>
-#include <limits>
-#include <type_traits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
@@ -27,160 +25,98 @@
namespace testing {
namespace {
-constexpr float kTestTolerance = 1e-05;
-constexpr int kNumInputs = 3;
-constexpr int kNumOutputs = 1;
-constexpr int kInputTensorIndex = 0;
-constexpr int kBlockShapeTensorIndex = 1;
-constexpr int kPaddingTensorIndex = 2;
-constexpr int kOutputTensorIndex = 3;
+constexpr int kBasicInputOutputSize = 16;
+int basic_input_dims[] = {4, 1, 4, 4, 1};
+const float basic_input[kBasicInputOutputSize] = {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+int basic_block_shape_dims[] = {1, 2};
+const int32_t basic_block_shape[] = {2, 2};
+int basic_crops_dims[] = {1, 4};
+const int32_t basic_crops[] = {0, 0, 0, 0};
+int basic_output_dims[] = {4, 4, 2, 2, 1};
+const float basic_golden[kBasicInputOutputSize] = {1, 3, 9, 11, 2, 4, 10, 12,
+ 5, 7, 13, 15, 6, 8, 14, 16};
-// min/max are used to compute scale, zero-point (asymmetric)
-template <typename T, size_t kInputSize, size_t kOutputSize>
-struct TestQuantParams {
- // quantization parameters
- float data_min; // input data minimum value
- float data_max; // input data maximum value
- T output_data[kOutputSize]; // quantized output storage
- T input_data[kInputSize]; // quantized input storage
-};
+template <typename T>
+TfLiteStatus ValidateSpaceToBatchNdGoldens(TfLiteTensor* tensors,
+ int tensors_size, const T* golden,
+ T* output, int output_size) {
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-TfLiteStatus ExecuteSpaceToBatchNdTest(TfLiteTensor* tensors,
- int tensors_count) {
- int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
- kBlockShapeTensorIndex, kPaddingTensorIndex};
- TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
- int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
- TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
-
- const TFLMRegistration registration = tflite::Register_SPACE_TO_BATCH_ND();
- micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
+ const TFLMRegistration registration = Register_SPACE_TO_BATCH_ND();
+ micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr);
- TfLiteStatus status = runner.InitAndPrepare();
- if (status != kTfLiteOk) {
- return status;
- }
- status = runner.Invoke();
+ TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
+ TF_LITE_ENSURE_STATUS(runner.Invoke());
- return status;
+ for (int i = 0; i < output_size; ++i) {
+ // TODO(b/158102673): workaround for not having fatal test assertions.
+ TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
+ if (golden[i] != output[i]) {
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
}
TfLiteStatus TestSpaceToBatchNdFloat(
- int* input_dims_data[kNumInputs], const float* input_data,
- const int32_t* block_shape_data, const int32_t* padding_data,
- int* output_dims_data, const float* golden_data, float* output_data) {
- TfLiteIntArray* input_dims =
- IntArrayFromInts(input_dims_data[kInputTensorIndex]);
- TfLiteIntArray* block_shape_dims =
- IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
- TfLiteIntArray* padding_dims =
- IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
+ int* input_dims_data, const float* input_data, int* block_shape_dims_data,
+ const int32_t* block_shape_data, int* crops_dims_data,
+ const int32_t* crops_data, int* output_dims_data, const float* golden,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+ TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+ TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
- constexpr int kTensorsCount = kNumInputs + kNumOutputs;
- TfLiteTensor tensors[kTensorsCount];
- tensors[kInputTensorIndex] =
- tflite::testing::CreateTensor(input_data, input_dims);
- tensors[kBlockShapeTensorIndex] =
- tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
- tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kPaddingTensorIndex] =
- tflite::testing::CreateTensor(padding_data, padding_dims);
- tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kOutputTensorIndex] =
- tflite::testing::CreateTensor(output_data, output_dims);
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateTensor(input_data, input_dims),
+ CreateTensor(block_shape_data, block_shape_dims),
+ CreateTensor(crops_data, crops_dims),
+ CreateTensor(output_data, output_dims),
+ };
- TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
- if (status != kTfLiteOk) {
- return status;
- }
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
- tensors[kOutputTensorIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < output_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
- tensors[kOutputTensorIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- // check output data against golden
- const int output_count = ElementCount(*output_dims);
- for (int i = 0; i < output_count; i++) {
- TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- return kTfLiteOk;
+ return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden,
+ output_data, ElementCount(*output_dims));
}
-template <typename T, size_t kInCount, size_t kOutCount>
+template <typename T>
TfLiteStatus TestSpaceToBatchNdQuantized(
- TestQuantParams<T, kInCount, kOutCount>& params,
- int* input_dims_data[kNumInputs], const float* input_data,
- const int32_t* block_shape_data, const int32_t* padding_data,
- int* output_dims_data, const float* golden_data) {
- TfLiteIntArray* input_dims =
- IntArrayFromInts(input_dims_data[kInputTensorIndex]);
- TfLiteIntArray* block_shape_dims =
- IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
- TfLiteIntArray* padding_dims =
- IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
+ int* input_dims_data, const float* input_data, T* input_quantized,
+ float input_scale, int input_zero_point, int* block_shape_dims_data,
+ const int32_t* block_shape_data, int* crops_dims_data,
+ const int32_t* crops_data, int* output_dims_data, const float* golden,
+ T* golden_quantized, float output_scale, int output_zero_point,
+ T* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+ TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+ TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
- int zero_point =
- tflite::testing::ZeroPointFromMinMax<T>(params.data_min, params.data_max);
- float scale =
- tflite::testing::ScaleFromMinMax<T>(params.data_min, params.data_max);
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
+ input_dims, input_scale,
+ input_zero_point),
+ tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
+ tflite::testing::CreateTensor(crops_data, crops_dims),
+ tflite::testing::CreateQuantizedTensor(output_data, output_dims,
+ output_scale, output_zero_point),
+ };
+ tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
+ output_scale, output_zero_point);
- constexpr int kTensorsCount = kNumInputs + kNumOutputs;
- TfLiteTensor tensors[kTensorsCount];
- tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
- input_data, params.input_data, input_dims, scale, zero_point);
- tensors[kBlockShapeTensorIndex] =
- tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
- tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kPaddingTensorIndex] =
- tflite::testing::CreateTensor(padding_data, padding_dims);
- tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
- tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
- params.output_data, output_dims, scale, zero_point);
-
- TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
- if (status != kTfLiteOk) {
- return status;
- }
-
- // check output dimensions (relocated) against original dimensions
- TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
- tensors[kOutputTensorIndex].dims->size);
- TF_LITE_MICRO_CHECK_FAIL();
- for (int i = 0; i < output_dims->size; i++) {
- TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
- tensors[kOutputTensorIndex].dims->data[i]);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- // check output data against golden
- const int output_count = ElementCount(*output_dims);
- const float quantization_tolerance =
- (params.data_max - params.data_min) /
- (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
- for (int i = 0; i < output_count; i++) {
- float output_dequantized_data =
- (params.output_data[i] - zero_point) * scale;
- TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
- quantization_tolerance);
- // TODO(b/158102673): workaround for not having fatal test assertions.
- TF_LITE_MICRO_CHECK_FAIL();
- }
-
- return kTfLiteOk;
+ return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden_quantized,
+ output_data, ElementCount(*output_dims));
}
} // namespace
@@ -189,313 +125,30 @@
TF_LITE_MICRO_TESTS_BEGIN
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidShapeTest) {
- int kInputDims[] = {4, 1, 3, 3, 1}; // invalid shape
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 0, 0, 0, 0}; // placeholder dims, not used
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kPadding[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {0}; // placeholder data, not used
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidOutputShapeTest) {
- int kInputDims[] = {3, 1, 8, 1};
- int kBlockShapeDims[] = {1, 1};
- int kPaddingDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 1, 6, 1}; // invalid shape
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kPadding[] = {2, 2};
- constexpr float kGolden[] = {
- 0, 0, 0, 0, 0, 0, // placeholder data, not used
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestValidOutputShapeTest) {
- int kInputDims[] = {3, 1, 8, 1};
- int kBlockShapeDims[] = {1, 1};
- int kPaddingDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 6, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kPadding[] = {2, 2};
- constexpr float kGolden[] = {0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0};
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimpleConstTest) {
- int kInputDims[] = {4, 1, 4, 4, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 4, 2, 2, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kPadding[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {
- 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestMultipleInputBatchesConstTest) {
- int kInputDims[] = {4, 2, 2, 4, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 8, 1, 2, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2, 2};
- constexpr int32_t kPadding[] = {0, 0, 0, 0};
- constexpr float kGolden[] = {
- 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimplePaddingConstTest) {
- int kInputDims[] = {4, 1, 5, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 6, 2, 2, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
- constexpr int32_t kBlockShape[] = {3, 2};
- constexpr int32_t kPadding[] = {1, 0, 2, 0};
- constexpr float kGolden[] = {
- 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestComplexPaddingConstTest) {
- int kInputDims[] = {4, 1, 4, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 6, 2, 4, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
- constexpr int32_t kBlockShape[] = {3, 2};
- constexpr int32_t kPadding[] = {1, 1, 2, 4};
- constexpr float kGolden[] = {
- 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 7, 0, 0,
- 0, 2, 0, 0, 0, 8, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestSimplePaddingConstTestInt8) {
- int kInputDims[] = {4, 1, 5, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 6, 2, 2, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1,
- };
- constexpr int kInputCount = std::extent<decltype(kInput)>::value;
- constexpr int32_t kBlockShape[] = {3, 2};
- constexpr int32_t kPadding[] = {1, 0, 2, 0};
- constexpr float kGolden[] = {
- 0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
- 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-
- tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
- -1,
- std::numeric_limits<int8_t>::max() /
- static_cast<float>(std::numeric_limits<int8_t>::max() + 1),
- {},
- {}};
-
+TF_LITE_MICRO_TEST(SpaceToBatchBasicFloat) {
+ float output[tflite::testing::kBasicInputOutputSize];
TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
- params, kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden));
+ kTfLiteOk,
+ tflite::testing::TestSpaceToBatchNdFloat(
+ tflite::testing::basic_input_dims, tflite::testing::basic_input,
+ tflite::testing::basic_block_shape_dims,
+ tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+ tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+ tflite::testing::basic_golden, output));
}
-TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestComplexPaddingConstTest) {
- int kInputDims[] = {4, 1, 4, 2, 1};
- int kBlockShapeDims[] = {1, 2};
- int kPaddingDims[] = {2, 2, 2};
- int kOutputDims[] = {4, 6, 2, 4, 1};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8,
- };
- constexpr int kInputCount = std::extent<decltype(kInput)>::value;
- constexpr int32_t kBlockShape[] = {3, 2};
- constexpr int32_t kPadding[] = {1, 1, 2, 4};
- constexpr float kGolden[] = {
- 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
- 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
- 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-
- tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
- -1, 1, {}, {}};
-
+TF_LITE_MICRO_TEST(SpaceToBatchBasicInt8) {
+ int8_t output[tflite::testing::kBasicInputOutputSize];
+ int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
+ int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
TF_LITE_MICRO_EXPECT_EQ(
- kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
- params, kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DConstTest) {
- int kInputDims[] = {3, 1, 4, 4};
- int kBlockShapeDims[] = {1, 1};
- int kPaddingDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 2, 4};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kPadding[] = {0, 0};
- constexpr float kGolden[] = {
- 1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DPaddingConstTest) {
- int kInputDims[] = {3, 1, 4, 4};
- int kBlockShapeDims[] = {1, 1};
- int kPaddingDims[] = {2, 1, 2};
- int kOutputDims[] = {3, 2, 4, 4};
-
- int* kInputDimsArray[tflite::testing::kNumInputs];
- kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
- kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
- kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
- constexpr float kInput[] = {
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- };
- constexpr int32_t kBlockShape[] = {2};
- constexpr int32_t kPadding[] = {2, 2};
- constexpr float kGolden[] = {
- 0, 0, 0, 0, 1, 2, 3, 4, 9, 10, 11, 12, 0, 0, 0, 0,
- 0, 0, 0, 0, 5, 6, 7, 8, 13, 14, 15, 16, 0, 0, 0, 0,
- };
- constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
- float output_data[kOutputCount];
-
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
- tflite::testing::TestSpaceToBatchNdFloat(
- kInputDimsArray, kInput, kBlockShape, kPadding,
- kOutputDims, kGolden, output_data));
+ kTfLiteOk,
+ tflite::testing::TestSpaceToBatchNdQuantized(
+ tflite::testing::basic_input_dims, tflite::testing::basic_input,
+ input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
+ tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+ tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+ tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
}
TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc b/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
index 6ab2e9c..094aab6 100644
--- a/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
+++ b/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -303,7 +303,7 @@
55295, 184082, 75855, 233991};
// Kernel Conv Test Case: Int8Filter8x3x3x3PerChannelScaleRelu6ShouldMatchGolden
-const int8_t kConvGoldenOutput1x15x15x8[1 * 15 * 15 * 8] = {
+const int8_t kConvGoldenOutput1x16x16x8[1 * 16 * 16 * 8] = {
-128, -21, -81, 67, -20, -109, -29, 4, -128, -19, -81, 68,
-19, -109, -31, 3, -128, -19, -80, 68, -20, -109, -32, 2,
-128, -19, -80, 68, -20, -109, -32, 1, -128, -19, -80, 68,
@@ -314,26 +314,28 @@
-36, -108, -41, 0, -128, -17, -78, 69, -20, -108, -37, -3,
-128, -18, -77, 68, -21, -107, -37, -3, -128, -18, -77, 69,
-20, -107, -38, -4, -128, -18, -77, 69, -22, -107, -38, -4,
- -128, -20, -81, 67, -19, -109, -30, 3, -128, -19, -81, 68,
- -19, -109, -31, 2, -128, -19, -80, 68, -20, -109, -31, 2,
- -128, -19, -80, 68, -20, -109, -33, 1, -128, -20, -79, 68,
- -19, -108, -33, 1, -128, -19, -88, 67, -19, -113, -34, 0,
- -128, -19, -99, 66, -13, -118, -1, 26, -128, -19, -120, 66,
- 32, -128, 2, 64, -128, -20, -124, 67, 8, -128, 13, 76,
- -128, -19, -98, 68, 1, -118, -17, 31, -128, -18, -89, 67,
- -12, -113, -33, 25, -128, -17, -76, 69, -22, -107, -37, -5,
- -128, -17, -77, 68, -22, -107, -38, -4, -128, -17, -77, 69,
- -22, -107, -38, -5, -128, -18, -76, 69, -23, -107, -38, -5,
- -128, -19, -81, 68, -20, -109, -31, 2, -128, -19, -80, 68,
- -21, -109, -32, 1, -128, -20, -79, 67, -20, -109, -32, 1,
- -128, -21, -79, 67, -20, -108, -32, 0, -128, -20, -79, 67,
- -21, -108, -33, -1, -128, -20, -86, 67, -20, -113, -12, -1,
- -128, -21, -93, 66, -15, -115, -10, 21, -128, -23, -113, 65,
- -16, -126, 77, 46, -128, -23, -120, 65, -9, -128, 74, 71,
- -128, -26, -97, 60, -26, -118, -12, 29, -128, -27, -89, 61,
- -11, -114, 1, 28, -128, -32, -77, 58, -15, -108, -42, -2,
- -128, -30, -78, 58, -16, -108, -40, 0, -128, -26, -77, 60,
- -18, -108, -40, -2, -128, -20, -77, 66, -24, -107, -41, -4,
+ -128, -19, -82, 69, -18, -110, -31, -4, -128, -20, -81, 67,
+ -19, -109, -30, 3, -128, -19, -81, 68, -19, -109, -31, 2,
+ -128, -19, -80, 68, -20, -109, -31, 2, -128, -19, -80, 68,
+ -20, -109, -33, 1, -128, -20, -79, 68, -19, -108, -33, 1,
+ -128, -19, -88, 67, -19, -113, -34, 0, -128, -19, -99, 66,
+ -13, -118, -1, 26, -128, -19, -120, 66, 32, -128, 2, 64,
+ -128, -20, -124, 67, 8, -128, 13, 76, -128, -19, -98, 68,
+ 1, -118, -17, 31, -128, -18, -89, 67, -12, -113, -33, 25,
+ -128, -17, -76, 69, -22, -107, -37, -5, -128, -17, -77, 68,
+ -22, -107, -38, -4, -128, -17, -77, 69, -22, -107, -38, -5,
+ -128, -18, -76, 69, -23, -107, -38, -5, -128, -19, -82, 68,
+ -18, -110, -31, -5, -128, -19, -81, 68, -20, -109, -31, 2,
+ -128, -19, -80, 68, -21, -109, -32, 1, -128, -20, -79, 67,
+ -20, -109, -32, 1, -128, -21, -79, 67, -20, -108, -32, 0,
+ -128, -20, -79, 67, -21, -108, -33, -1, -128, -20, -86, 67,
+ -20, -113, -12, -1, -128, -21, -93, 66, -15, -115, -10, 21,
+ -128, -23, -113, 65, -16, -126, 77, 46, -128, -23, -120, 65,
+ -9, -128, 74, 71, -128, -26, -97, 60, -26, -118, -12, 29,
+ -128, -27, -89, 61, -11, -114, 1, 28, -128, -32, -77, 58,
+ -15, -108, -42, -2, -128, -30, -78, 58, -16, -108, -40, 0,
+ -128, -26, -77, 60, -18, -108, -40, -2, -128, -20, -77, 66,
+ -24, -107, -41, -4, -128, -19, -82, 68, -17, -110, -32, -5,
-128, -20, -81, 67, -20, -109, -32, 3, -128, -20, -82, 67,
-14, -110, -34, 0, -128, -20, -90, 67, -62, -113, -67, 16,
-128, -24, -80, 64, -25, -109, -38, 1, -128, -25, -80, 61,
@@ -344,26 +346,28 @@
-8, -112, -19, 16, -128, -42, -83, 52, -26, -111, -42, 9,
-128, -40, -82, 52, -20, -111, -38, 8, -128, -38, -82, 52,
-18, -111, -38, 7, -128, -29, -81, 54, -22, -110, -42, 9,
- -128, -31, -80, 58, -17, -109, -33, 1, -128, -24, -107, 60,
- 0, -123, -46, 32, -128, -23, -121, 66, -27, -128, 24, 76,
- -128, -34, -92, 57, -9, -115, -39, 32, -128, -35, -86, 53,
- -16, -113, -29, 10, -128, -30, -97, 58, -25, -117, -17, 24,
- -128, -27, -97, 62, -46, -117, -15, 28, -128, -26, -91, 66,
- -4, -114, -44, 23, -128, -24, -87, 68, 15, -112, -71, 17,
- -128, -25, -85, 66, 10, -111, -71, 13, -128, -29, -83, 63,
- -38, -110, -39, 5, -128, -36, -82, 59, -88, -110, -14, -2,
- -128, -41, -90, 52, 25, -115, -9, 19, -128, -43, -94, 47,
- 1, -117, -21, 34, -128, -32, -84, 52, -13, -111, -16, 15,
- -128, -38, -85, 54, -23, -112, -36, 6, -128, -32, -116, 54,
- -12, -128, 61, 50, -128, -24, -127, 63, -12, -128, 68, 87,
- -128, -24, -106, 63, -47, -122, -6, 44, -128, -22, -102, 63,
- 8, -120, -17, 38, -128, -20, -99, 66, 6, -118, 30, 38,
- -128, -20, -90, 66, -8, -114, 35, 21, -128, -20, -89, 66,
- -16, -114, 18, 17, -128, -20, -89, 65, -15, -113, 14, 25,
- -128, -18, -84, 66, -14, -111, 22, 9, -128, -18, -89, 66,
- -4, -113, 8, 15, -128, -20, -89, 66, -16, -113, -27, 20,
- -128, -27, -90, 65, 9, -114, -30, 14, -128, -37, -90, 56,
- -63, -114, -5, 20, -128, -43, -84, 50, -21, -112, -32, 8,
+ -128, -19, -81, 68, -18, -110, -32, -7, -128, -31, -80, 58,
+ -17, -109, -33, 1, -128, -24, -107, 60, 0, -123, -46, 32,
+ -128, -23, -121, 66, -27, -128, 24, 76, -128, -34, -92, 57,
+ -9, -115, -39, 32, -128, -35, -86, 53, -16, -113, -29, 10,
+ -128, -30, -97, 58, -25, -117, -17, 24, -128, -27, -97, 62,
+ -46, -117, -15, 28, -128, -26, -91, 66, -4, -114, -44, 23,
+ -128, -24, -87, 68, 15, -112, -71, 17, -128, -25, -85, 66,
+ 10, -111, -71, 13, -128, -29, -83, 63, -38, -110, -39, 5,
+ -128, -36, -82, 59, -88, -110, -14, -2, -128, -41, -90, 52,
+ 25, -115, -9, 19, -128, -43, -94, 47, 1, -117, -21, 34,
+ -128, -32, -84, 52, -13, -111, -16, 15, -128, -24, -83, 63,
+ -20, -111, -39, -1, -128, -38, -85, 54, -23, -112, -36, 6,
+ -128, -32, -116, 54, -12, -128, 61, 50, -128, -24, -127, 63,
+ -12, -128, 68, 87, -128, -24, -106, 63, -47, -122, -6, 44,
+ -128, -22, -102, 63, 8, -120, -17, 38, -128, -20, -99, 66,
+ 6, -118, 30, 38, -128, -20, -90, 66, -8, -114, 35, 21,
+ -128, -20, -89, 66, -16, -114, 18, 17, -128, -20, -89, 65,
+ -15, -113, 14, 25, -128, -18, -84, 66, -14, -111, 22, 9,
+ -128, -18, -89, 66, -4, -113, 8, 15, -128, -20, -89, 66,
+ -16, -113, -27, 20, -128, -27, -90, 65, 9, -114, -30, 14,
+ -128, -37, -90, 56, -63, -114, -5, 20, -128, -43, -84, 50,
+ -21, -112, -32, 8, -128, -38, -89, 49, -17, -114, -24, 16,
-128, -41, -96, 49, -13, -118, -21, 31, -128, -37, -101, 53,
-33, -120, 23, 28, -128, -33, -116, 54, -24, -127, 56, 61,
-128, -22, -108, 63, -15, -123, 45, 53, -128, -20, -89, 66,
@@ -374,26 +378,28 @@
-22, -105, -30, -14, -128, -20, -78, 66, -30, -108, 1, -4,
-128, -19, -88, 66, -9, -113, 8, 10, -128, -25, -96, 65,
-11, -117, -5, 27, -128, -43, -96, 53, -6, -118, -26, 31,
- -128, -46, -106, 44, -8, -123, 9, 46, -128, -38, -109, 48,
- -17, -124, -10, 45, -128, -26, -110, 59, -12, -124, 28, 57,
- -128, -20, -89, 68, -22, -113, 15, 23, -128, -21, -74, 69,
- -14, -106, -41, -4, -128, -24, -68, 63, -2, -103, -53, -12,
- -128, -23, -73, 67, -9, -105, -54, -12, -128, -22, -76, 67,
- -12, -107, -45, -13, -128, -21, -84, 67, -44, -110, -51, 21,
- -128, -23, -69, 67, -21, -104, -40, -21, -128, -21, -68, 67,
- -16, -103, -64, -13, -128, -25, -64, 65, -14, -101, -54, -21,
- -128, -18, -70, 69, -18, -104, -28, -19, -128, -20, -86, 66,
- -19, -112, 5, 5, -128, -31, -102, 64, -16, -120, 16, 34,
- -128, -42, -99, 48, -43, -120, 33, 24, -128, -36, -116, 51,
- 23, -128, 52, 69, -128, -24, -98, 63, -31, -118, 31, 43,
- -128, -20, -76, 69, -28, -106, -40, -5, -128, -24, -71, 64,
- -13, -105, -45, -7, -128, -17, -67, 70, -22, -102, -59, -20,
- -128, -23, -72, 67, -21, -105, -38, -13, -128, -21, -77, 67,
- -23, -108, -54, -9, -128, -21, -86, 67, -51, -112, 15, 22,
- -128, -23, -71, 67, -32, -105, -67, -17, -128, -22, -67, 66,
- -15, -102, -46, -13, -128, -20, -60, 69, -21, -99, -60, -30,
- -128, -23, -61, 65, 0, -100, -61, -25, -128, -18, -70, 69,
- -35, -104, -31, -22, -128, -25, -94, 67, -21, -116, 15, 5,
+ -128, -41, -97, 48, -13, -118, -13, 34, -128, -46, -106, 44,
+ -8, -123, 9, 46, -128, -38, -109, 48, -17, -124, -10, 45,
+ -128, -26, -110, 59, -12, -124, 28, 57, -128, -20, -89, 68,
+ -22, -113, 15, 23, -128, -21, -74, 69, -14, -106, -41, -4,
+ -128, -24, -68, 63, -2, -103, -53, -12, -128, -23, -73, 67,
+ -9, -105, -54, -12, -128, -22, -76, 67, -12, -107, -45, -13,
+ -128, -21, -84, 67, -44, -110, -51, 21, -128, -23, -69, 67,
+ -21, -104, -40, -21, -128, -21, -68, 67, -16, -103, -64, -13,
+ -128, -25, -64, 65, -14, -101, -54, -21, -128, -18, -70, 69,
+ -18, -104, -28, -19, -128, -20, -86, 66, -19, -112, 5, 5,
+ -128, -31, -102, 64, -16, -120, 16, 34, -128, -41, -104, 49,
+ -25, -121, 5, 49, -128, -42, -99, 48, -43, -120, 33, 24,
+ -128, -36, -116, 51, 23, -128, 52, 69, -128, -24, -98, 63,
+ -31, -118, 31, 43, -128, -20, -76, 69, -28, -106, -40, -5,
+ -128, -24, -71, 64, -13, -105, -45, -7, -128, -17, -67, 70,
+ -22, -102, -59, -20, -128, -23, -72, 67, -21, -105, -38, -13,
+ -128, -21, -77, 67, -23, -108, -54, -9, -128, -21, -86, 67,
+ -51, -112, 15, 22, -128, -23, -71, 67, -32, -105, -67, -17,
+ -128, -22, -67, 66, -15, -102, -46, -13, -128, -20, -60, 69,
+ -21, -99, -60, -30, -128, -23, -61, 65, 0, -100, -61, -25,
+ -128, -18, -70, 69, -35, -104, -31, -22, -128, -25, -94, 67,
+ -21, -116, 15, 5, -128, -31, -103, 59, -47, -120, 24, 42,
-128, -41, -92, 50, -21, -116, -12, 21, -128, -30, -99, 53,
-25, -119, 30, 28, -128, -22, -88, 66, -29, -113, -13, 23,
-128, -20, -76, 69, -11, -107, -37, -6, -128, -20, -72, 67,
@@ -404,26 +410,28 @@
-27, -98, -65, -25, -128, -20, -50, 69, -36, -94, -73, -47,
-128, -20, -52, 68, -39, -95, -74, -46, -128, -19, -60, 69,
-44, -99, -62, -36, -128, -22, -86, 69, -22, -113, -24, -15,
- -128, -45, -98, 48, -13, -119, -5, 31, -128, -32, -101, 52,
- 1, -120, -17, 34, -128, -20, -90, 66, -23, -113, -28, 27,
- -128, -22, -75, 69, -21, -107, -52, -13, -128, -23, -74, 64,
- -45, -105, -42, -7, -128, -18, -68, 69, -21, -103, -54, -18,
- -128, -19, -67, 69, -23, -103, -53, -19, -128, -19, -75, 69,
- -23, -107, -54, -13, -128, -20, -85, 67, -29, -111, -1, 20,
- -128, -19, -68, 69, -36, -103, -68, -18, -128, -20, -63, 69,
- -23, -100, -70, -23, -128, -19, -59, 69, -6, -98, -99, -31,
- -128, -20, -61, 68, -15, -100, -88, -31, -128, -19, -67, 69,
- -20, -102, -73, -23, -128, -22, -92, 69, -19, -115, -38, 1,
- -128, -42, -91, 50, -29, -116, 10, 20, -128, -37, -104, 50,
- -2, -122, 18, 37, -128, -23, -101, 61, -5, -119, -13, 48,
- -128, -21, -78, 70, -33, -107, -46, -3, -128, -26, -73, 65,
- -31, -106, -37, -11, -128, -22, -69, 66, -13, -103, -55, -12,
- -128, -19, -66, 69, -21, -102, -55, -21, -128, -20, -72, 69,
- -28, -105, -25, -15, -128, -22, -75, 69, -29, -106, -16, -6,
- -128, -21, -70, 69, -19, -104, -35, -9, -128, -21, -65, 69,
- -23, -101, -60, -23, -128, -23, -68, 68, -25, -103, -61, -23,
- -128, -24, -68, 63, -12, -103, -55, -15, -128, -20, -78, 69,
- -33, -108, -55, -18, -128, -25, -107, 66, -2, -123, -1, 43,
+ -128, -34, -100, 60, -27, -119, 3, 39, -128, -45, -98, 48,
+ -13, -119, -5, 31, -128, -32, -101, 52, 1, -120, -17, 34,
+ -128, -20, -90, 66, -23, -113, -28, 27, -128, -22, -75, 69,
+ -21, -107, -52, -13, -128, -23, -74, 64, -45, -105, -42, -7,
+ -128, -18, -68, 69, -21, -103, -54, -18, -128, -19, -67, 69,
+ -23, -103, -53, -19, -128, -19, -75, 69, -23, -107, -54, -13,
+ -128, -20, -85, 67, -29, -111, -1, 20, -128, -19, -68, 69,
+ -36, -103, -68, -18, -128, -20, -63, 69, -23, -100, -70, -23,
+ -128, -19, -59, 69, -6, -98, -99, -31, -128, -20, -61, 68,
+ -15, -100, -88, -31, -128, -19, -67, 69, -20, -102, -73, -23,
+ -128, -22, -92, 69, -19, -115, -38, 1, -128, -37, -102, 57,
+ -17, -120, -2, 45, -128, -42, -91, 50, -29, -116, 10, 20,
+ -128, -37, -104, 50, -2, -122, 18, 37, -128, -23, -101, 61,
+ -5, -119, -13, 48, -128, -21, -78, 70, -33, -107, -46, -3,
+ -128, -26, -73, 65, -31, -106, -37, -11, -128, -22, -69, 66,
+ -13, -103, -55, -12, -128, -19, -66, 69, -21, -102, -55, -21,
+ -128, -20, -72, 69, -28, -105, -25, -15, -128, -22, -75, 69,
+ -29, -106, -16, -6, -128, -21, -70, 69, -19, -104, -35, -9,
+ -128, -21, -65, 69, -23, -101, -60, -23, -128, -23, -68, 68,
+ -25, -103, -61, -23, -128, -24, -68, 63, -12, -103, -55, -15,
+ -128, -20, -78, 69, -33, -108, -55, -18, -128, -25, -107, 66,
+ -2, -123, -1, 43, -128, -37, -102, 51, -10, -120, 17, 46,
-128, -40, -81, 52, -33, -110, -21, 3, -128, -43, -96, 50,
-27, -118, 7, 17, -128, -32, -112, 53, -21, -125, 15, 57,
-128, -23, -94, 66, -11, -115, -42, 34, -128, -20, -76, 70,
@@ -434,26 +442,39 @@
-34, -104, -64, -17, -128, -23, -69, 64, -3, -103, -48, -15,
-128, -18, -74, 69, -12, -106, -65, -16, -128, -22, -93, 67,
-18, -116, -31, 20, -128, -26, -90, 62, -16, -114, 49, 23,
- -128, -44, -87, 49, -18, -113, -34, 14, -128, -45, -93, 49,
- -11, -117, -3, 22, -128, -40, -110, 48, -18, -125, 49, 48,
- -128, -30, -113, 57, -7, -125, 5, 61, -128, -23, -93, 66,
- -13, -115, -47, 29, -128, -20, -77, 69, -14, -107, -67, 0,
- -128, -20, -73, 69, -27, -105, -48, -8, -128, -26, -70, 64,
- -47, -104, -43, -14, -128, -22, -72, 68, 7, -105, -52, -6,
- -128, -24, -70, 64, -50, -104, -46, -15, -128, -19, -73, 69,
- -30, -105, -41, -12, -128, -19, -81, 70, -40, -109, -85, -9,
- -128, -22, -97, 67, 21, -118, -45, 27, -128, -29, -93, 59,
- -37, -115, 9, 31, -128, -31, -81, 59, 3, -110, -54, 5,
- -128, -44, -88, 47, -11, -114, 1, 21, -128, -38, -85, 52,
- -29, -112, -4, 10, -128, -39, -92, 50, -27, -116, 18, 16,
- -128, -39, -112, 46, -14, -126, 57, 50, -128, -31, -117, 57,
- 9, -128, 4, 67, -128, -25, -105, 63, -17, -121, -46, 47,
- -128, -22, -91, 67, -2, -114, -50, 20, -128, -21, -88, 68,
- -35, -112, -63, 12, -128, -21, -89, 69, -25, -113, -70, 12,
- -128, -21, -94, 68, -30, -116, -71, 18, -128, -22, -99, 66,
- 6, -119, -57, 32, -128, -26, -101, 60, 24, -120, 19, 44,
- -128, -35, -94, 54, -27, -117, 28, 27, -128, -41, -89, 50,
- -24, -114, -16, 19, -128, -41, -89, 52, -12, -114, -24, 19};
+ -128, -24, -86, 62, -74, -112, 2, 0, -128, -44, -87, 49,
+ -18, -113, -34, 14, -128, -45, -93, 49, -11, -117, -3, 22,
+ -128, -40, -110, 48, -18, -125, 49, 48, -128, -30, -113, 57,
+ -7, -125, 5, 61, -128, -23, -93, 66, -13, -115, -47, 29,
+ -128, -20, -77, 69, -14, -107, -67, 0, -128, -20, -73, 69,
+ -27, -105, -48, -8, -128, -26, -70, 64, -47, -104, -43, -14,
+ -128, -22, -72, 68, 7, -105, -52, -6, -128, -24, -70, 64,
+ -50, -104, -46, -15, -128, -19, -73, 69, -30, -105, -41, -12,
+ -128, -19, -81, 70, -40, -109, -85, -9, -128, -22, -97, 67,
+ 21, -118, -45, 27, -128, -29, -93, 59, -37, -115, 9, 31,
+ -128, -31, -81, 59, 3, -110, -54, 5, -128, -35, -87, 55,
+ -9, -113, -29, 10, -128, -44, -88, 47, -11, -114, 1, 21,
+ -128, -38, -85, 52, -29, -112, -4, 10, -128, -39, -92, 50,
+ -27, -116, 18, 16, -128, -39, -112, 46, -14, -126, 57, 50,
+ -128, -31, -117, 57, 9, -128, 4, 67, -128, -25, -105, 63,
+ -17, -121, -46, 47, -128, -22, -91, 67, -2, -114, -50, 20,
+ -128, -21, -88, 68, -35, -112, -63, 12, -128, -21, -89, 69,
+ -25, -113, -70, 12, -128, -21, -94, 68, -30, -116, -71, 18,
+ -128, -22, -99, 66, 6, -119, -57, 32, -128, -26, -101, 60,
+ 24, -120, 19, 44, -128, -35, -94, 54, -27, -117, 28, 27,
+ -128, -41, -89, 50, -24, -114, -16, 19, -128, -41, -89, 52,
+ -12, -114, -24, 19, -128, -37, -91, 52, -17, -115, -19, 21,
+ -128, -37, -85, 55, -39, -112, -37, 8, -128, -38, -84, 55,
+ -32, -111, -42, 8, -128, -37, -86, 54, -31, -113, -32, 10,
+ -128, -37, -95, 52, -36, -117, 10, 19, -128, -35, -110, 52,
+ -10, -125, 71, 51, -128, -28, -110, 56, -21, -124, 71, 55,
+ -128, -29, -105, 57, 37, -122, 21, 48, -128, -32, -106, 57,
+ 30, -123, 32, 52, -128, -30, -107, 58, 25, -123, 38, 51,
+ -128, -30, -105, 57, -9, -122, 49, 45, -128, -30, -98, 57,
+ -59, -118, 33, 30, -128, -33, -90, 57, -46, -114, -13, 18,
+ -128, -37, -90, 56, -15, -115, -27, 16, -128, -42, -96, 50,
+ 12, -118, -6, 33, -128, -43, -95, 49, 7, -117, -4, 32,
+ -128, -37, -95, 53, -7, -117, -1, 30};
// Conv Test Case: Int8Filter1x3x3x1ShouldMatchGolden
const int8_t kConvFilter1x3x3x1[1 * 3 * 3 * 1]{
diff --git a/tensorflow/lite/micro/kernels/testdata/conv_test_data.h b/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
index 72ac8a2..bdac510 100644
--- a/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
+++ b/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
extern const int8_t kConvInput1x32x32x3[];
extern const int8_t kConvFilter8x3x3x3[];
extern const int32_t kConvBiasQuantized8[];
-extern const int8_t kConvGoldenOutput1x15x15x8[];
+extern const int8_t kConvGoldenOutput1x16x16x8[];
// Kernel Conv Test Cases: Int8Filter1x3x3x1ShouldMatchGolden
extern const int8_t kConvInput1x4x4x1[];
diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc b/tensorflow/lite/micro/memory_arena_threshold_test.cc
index 90d7f6c..3017f56 100644
--- a/tensorflow/lite/micro/memory_arena_threshold_test.cc
+++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -93,24 +93,24 @@
// Total size contributed by the conv model excluding the
// RecordingMicroAllocator's overhead
// TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTotalSize = 9558;
+constexpr int kTestConvModelOnlyTotalSize = 9488;
// Tail size contributed by the conv model excluding the
// RecordingMicroAllocator's overhead
// TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTailSize = 1886;
+constexpr int kTestConvModelOnlyTailSize = 1816;
constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 128;
-constexpr int kTestConvModelPersistentBufferDataSize = 798;
+constexpr int kTestConvModelPersistentBufferDataSize = 728;
#else
// Total size contributed by the conv model excluding the
// RecordingMicroAllocator's overhead
// TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTotalSize = 9830;
+constexpr int kTestConvModelOnlyTotalSize = 9760;
// Tail size contributed by the conv model excluding the
// RecordingMicroAllocator's overhead
// TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTailSize = 2158;
+constexpr int kTestConvModelOnlyTailSize = 2088;
constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 224;
-constexpr int kTestConvModelPersistentBufferDataSize = 790;
+constexpr int kTestConvModelPersistentBufferDataSize = 720;
#endif
constexpr int kTestConvModelHeadSize = 7744;
constexpr int kTestConvModelOpRuntimeDataSize = 136;
diff --git a/tensorflow/lite/micro/python/keras_tests/BUILD b/tensorflow/lite/micro/python/keras_tests/BUILD
deleted file mode 100644
index 12fd6a9..0000000
--- a/tensorflow/lite/micro/python/keras_tests/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-# Description:
-# TensorFlow Lite microcontroller example.
-load("@rules_python//python:defs.bzl", "py_binary", "py_test")
-load("@tflm_pip_deps//:requirements.bzl", "requirement")
-
-package(
- default_visibility = ["//visibility:public"],
- # Disabling layering_check because of http://b/177257332
- features = ["-layering_check"],
- licenses = ["notice"],
-)
-
-py_test(
- name = "conv_tests",
- srcs = ["conv_tests.py"],
- main = "conv_tests.py",
- python_version = "PY3",
- tags = [
- "noasan",
- "nomsan", # Python doesn't like these symbols
- "noubsan",
- ],
- deps = [
- requirement("numpy"),
- requirement("tensorflow-cpu"),
- "//python/tflite_micro:runtime",
- ],
-)
diff --git a/tensorflow/lite/micro/python/keras_tests/conv_tests.py b/tensorflow/lite/micro/python/keras_tests/conv_tests.py
deleted file mode 100644
index 80c43bf..0000000
--- a/tensorflow/lite/micro/python/keras_tests/conv_tests.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-"""
-Convolution kernel testing with dilation > 1, using the TfLiteConverter to
-convert models directly from Keras.
-
-Run:
-bazel build tensorflow/lite/micro/python/keras_tests:conv_tests
-bazel-bin/tensorflow/lite/micro/python/keras_tests/conv_tests
-"""
-
-from __future__ import annotations
-
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-import tensorflow as tf
-import keras.api._v2.keras as keras # for Visual Studio Code to work correctly
-from tflite_micro.python.tflite_micro import runtime
-
-
-class KerasConvTest(test_util.TensorFlowTestCase):
-
- def MakeConv1dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.Conv1D(1,
- 3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def MakeConv2dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.Conv2D(1,
- 3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def MakeDepthwiseConv1dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.DepthwiseConv1D(3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def MakeDepthwiseConv2dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.DepthwiseConv2D(3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def MakeTransposeConv1dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.Conv1DTranspose(1,
- 3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def MakeTransposeConv2dModel(self, *, shape, dilation):
- input_layer = keras.layers.Input(shape=shape)
- conv_layer = keras.layers.Conv2DTranspose(1,
- 3,
- dilation_rate=dilation,
- padding='same')(input_layer)
- model = keras.Model(inputs=input_layer, outputs=conv_layer)
- return model
-
- def ExecuteModelTest(self, model: keras.Model):
- model_shape = list(model.layers[0].input_shape[0])
- model_shape[0] = 1
- input_data = tf.ones(shape=model_shape, dtype=tf.float32)
- tf_result: tf.Tensor = model(input_data) # type: ignore
-
- converter = tf.lite.TFLiteConverter.from_keras_model(model=model)
- tflite_model = converter.convert()
- tf.lite.experimental.Analyzer.analyze(model_content=tflite_model)
-
- tflm_interpreter = runtime.Interpreter.from_bytes(
- tflite_model,
- intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors)
- tflm_interpreter.set_input(input_data, 0)
- tflm_interpreter.invoke()
- tflm_result = tflm_interpreter.get_output(0)
- tflm_output_details = tflm_interpreter.get_output_details(0)
- tflm_shape = tflm_output_details['shape']
-
- print(f'{tf_result=}')
- print(f'{tflm_result=} {tflm_shape=}')
-
- self.assertAllClose(tf_result, tflm_result)
- self.assertAllEqual(tf_result.shape, tflm_shape)
-
- def setUp(self):
- pass
-
- def testConv1dWithDilation1(self):
- model = self.MakeConv1dModel(shape=(8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
- def testConv1dWithDilation2(self):
- model = self.MakeConv1dModel(shape=(8, 1), dilation=2)
- self.ExecuteModelTest(model)
-
- def testConv2dWithDilation1(self):
- model = self.MakeConv2dModel(shape=(1, 8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
- def testConv2dWithDilation2(self):
- model = self.MakeConv2dModel(shape=(1, 8, 1), dilation=2)
- self.ExecuteModelTest(model)
-
- def testDepthwiseConv1dWithDilation1(self):
- model = self.MakeDepthwiseConv1dModel(shape=(8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
- def testDepthwiseConv1dWithDilation2(self):
- model = self.MakeDepthwiseConv1dModel(shape=(8, 1), dilation=2)
- self.ExecuteModelTest(model)
-
- def testDepthwiseConv2dWithDilation1(self):
- model = self.MakeDepthwiseConv2dModel(shape=(1, 8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
- def testDepthwiseConv2dWithDilation2(self):
- model = self.MakeDepthwiseConv2dModel(shape=(1, 8, 1), dilation=2)
- self.ExecuteModelTest(model)
-
- def testTransposeConv1dWithDilation1(self):
- model = self.MakeTransposeConv1dModel(shape=(8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
- def testTransposeConv2dWithDilation1(self):
- model = self.MakeTransposeConv2dModel(shape=(1, 8, 1), dilation=1)
- self.ExecuteModelTest(model)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h
index bd2d93c..a28f4b6 100644
--- a/tensorflow/lite/micro/testing/micro_test.h
+++ b/tensorflow/lite/micro/testing/micro_test.h
@@ -264,11 +264,4 @@
} \
} while (false)
-#define TF_LITE_MICRO_CHECK_FAIL() \
- do { \
- if (micro_test::did_test_fail) { \
- return kTfLiteError; \
- } \
- } while (false)
-
#endif // TENSORFLOW_LITE_MICRO_TESTING_MICRO_TEST_H_