Revert "SPACE_TO_BATCH_ND: update output tensor shape" (#2339)

Temporarily reverting so that we can reland with SpaceToBatch, BatchToSpace, Conv2D and ExpandDims output tensor resizing atomically.

Reverts tensorflow/tflite-micro#2335

BUG=#2338
diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
index 6fe75c1..f31728c 100644
--- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
@@ -32,6 +32,13 @@
 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
+#define TF_LITE_MICRO_CHECK_FAIL()   \
+  do {                               \
+    if (micro_test::did_test_fail) { \
+      return kTfLiteError;           \
+    }                                \
+  } while (false)
+
 namespace {
 
 // Arena size is a guesstimate, followed by use of
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
index fd7ef92..6b536ee 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
 
 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
 
-#include <algorithm>
-
 #include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -27,11 +24,12 @@
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
+
 namespace {
 
 constexpr int kInputTensor = 0;
 constexpr int kBlockShapeTensor = 1;
-constexpr int kPaddingTensor = 2;
+constexpr int kCropsTensor = 2;
 constexpr int kOutputTensor = 0;
 
 // Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
@@ -46,68 +44,6 @@
   return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
 }
 
-TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
-                                 const TfLiteTensor* input,
-                                 const TfLiteTensor* block_shape,
-                                 const TfLiteTensor* padding,
-                                 TfLiteTensor* output) {
-  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
-  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(padding));
-  const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
-  const int32_t* padding_data = GetTensorData<int32_t>(padding);
-
-  TfLiteIntArray* input_dims = input->dims;
-  int spatial_dims_num = input_dims->size - 2;
-  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
-  TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
-  TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
-  // Padding should be a 2D tensor with dimension [spatial_dims_num, 2].
-  TF_LITE_ENSURE_EQ(context, NumDimensions(padding), 2);
-  TF_LITE_ENSURE_EQ(context, padding->dims->data[0], spatial_dims_num);
-  TF_LITE_ENSURE_EQ(context, padding->dims->data[1], 2);
-
-  // copy from input tensor as per TfLite code
-  RuntimeShape output_shape = GetTensorShape(input);
-  // keep a copy of the output tensor shape for later comparison
-  RuntimeShape old_output_shape = GetTensorShape(output);
-
-  // Ensures the input height and width (with padding) is a multiple of block
-  // shape height and width.
-  int output_batch_size = input_dims->data[0];
-  for (int dim = 0; dim < spatial_dims_num; ++dim) {
-    int final_dim_size = (input_dims->data[dim + 1] + padding_data[dim * 2] +
-                          padding_data[dim * 2 + 1]);
-    TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
-    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape_data[dim], 0);
-    output_shape.SetDim(dim + 1, final_dim_size / block_shape_data[dim]);
-    output_batch_size *= block_shape_data[dim];
-  }
-  output_shape.SetDim(0, output_batch_size);
-  output_shape.SetDim(input_dims->size - 1,
-                      input_dims->data[input_dims->size - 1]);
-
-  // check if need to relocate output tensor dims
-  if (output_shape == old_output_shape) {
-    return kTfLiteOk;
-  } else if (output_shape.FlatSize() > old_output_shape.FlatSize() &&
-             output->data.data != nullptr) {
-    MicroPrintf(
-        "SPACE_TO_BATCH_ND: resizing flatbuffer tensor data is not supported");
-    return kTfLiteError;
-  }
-
-  // set the output tensor dims from output_shape
-  TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
-  TfLiteEvalTensor* output_eval =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
-      context, output, output_eval));
-  std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
-              output->dims->data);
-
-  return kTfLiteOk;
-}
-
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
 
@@ -116,47 +52,19 @@
 
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  TfLiteTensor* block_shape =
-      micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
-  TF_LITE_ENSURE(context, block_shape != nullptr);
-  TfLiteTensor* padding =
-      micro_context->AllocateTempInputTensor(node, kPaddingTensor);
-  TF_LITE_ENSURE(context, padding != nullptr);
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
+  TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
 
   TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
-
-  TF_LITE_ENSURE(context, node->user_data != nullptr);
-  SpaceToBatchParams& params =
-      *(static_cast<SpaceToBatchParams*>(node->user_data));
-
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
-    TF_LITE_ENSURE(context,
-                   input->params.zero_point == output->params.zero_point);
-    params.output_offset = output->params.zero_point;
-  } else {
-    params.output_offset = 0;
-  }
-
-  TfLiteStatus status =
-      ReshapeOutputTensor(context, node, input, block_shape, padding, output);
 
   micro_context->DeallocateTempTfLiteTensor(input);
-  micro_context->DeallocateTempTfLiteTensor(block_shape);
-  micro_context->DeallocateTempTfLiteTensor(padding);
   micro_context->DeallocateTempTfLiteTensor(output);
-
-  return status;
+  return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -168,8 +76,8 @@
       tflite::micro::GetEvalInput(context, node, kInputTensor);
   const TfLiteEvalTensor* block_shape =
       tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
-  const TfLiteEvalTensor* padding =
-      tflite::micro::GetEvalInput(context, node, kPaddingTensor);
+  const TfLiteEvalTensor* crops =
+      tflite::micro::GetEvalInput(context, node, kCropsTensor);
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
 
@@ -180,8 +88,8 @@
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(padding),
-          tflite::micro::GetTensorData<int32_t>(padding),
+          tflite::micro::GetTensorShape(crops),
+          tflite::micro::GetTensorData<int32_t>(crops),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -191,8 +99,8 @@
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(padding),
-          tflite::micro::GetTensorData<int32_t>(padding),
+          tflite::micro::GetTensorShape(crops),
+          tflite::micro::GetTensorData<int32_t>(crops),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
index d45a606..eae185b 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,8 +14,6 @@
 ==============================================================================*/
 
 #include <cstdint>
-#include <limits>
-#include <type_traits>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -27,160 +25,98 @@
 namespace testing {
 namespace {
 
-constexpr float kTestTolerance = 1e-05;
-constexpr int kNumInputs = 3;
-constexpr int kNumOutputs = 1;
-constexpr int kInputTensorIndex = 0;
-constexpr int kBlockShapeTensorIndex = 1;
-constexpr int kPaddingTensorIndex = 2;
-constexpr int kOutputTensorIndex = 3;
+constexpr int kBasicInputOutputSize = 16;
+int basic_input_dims[] = {4, 1, 4, 4, 1};
+const float basic_input[kBasicInputOutputSize] = {
+    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+int basic_block_shape_dims[] = {1, 2};
+const int32_t basic_block_shape[] = {2, 2};
+int basic_crops_dims[] = {1, 4};
+const int32_t basic_crops[] = {0, 0, 0, 0};
+int basic_output_dims[] = {4, 4, 2, 2, 1};
+const float basic_golden[kBasicInputOutputSize] = {1, 3, 9,  11, 2, 4, 10, 12,
+                                                   5, 7, 13, 15, 6, 8, 14, 16};
 
-// min/max are used to compute scale, zero-point (asymmetric)
-template <typename T, size_t kInputSize, size_t kOutputSize>
-struct TestQuantParams {
-  // quantization parameters
-  float data_min;              // input data minimum value
-  float data_max;              // input data maximum value
-  T output_data[kOutputSize];  // quantized output storage
-  T input_data[kInputSize];    // quantized input storage
-};
+template <typename T>
+TfLiteStatus ValidateSpaceToBatchNdGoldens(TfLiteTensor* tensors,
+                                           int tensors_size, const T* golden,
+                                           T* output, int output_size) {
+  int inputs_array_data[] = {3, 0, 1, 2};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 3};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
-TfLiteStatus ExecuteSpaceToBatchNdTest(TfLiteTensor* tensors,
-                                       int tensors_count) {
-  int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
-                           kBlockShapeTensorIndex, kPaddingTensorIndex};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
-  int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
-
-  const TFLMRegistration registration = tflite::Register_SPACE_TO_BATCH_ND();
-  micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
+  const TFLMRegistration registration = Register_SPACE_TO_BATCH_ND();
+  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                              outputs_array, nullptr);
 
-  TfLiteStatus status = runner.InitAndPrepare();
-  if (status != kTfLiteOk) {
-    return status;
-  }
-  status = runner.Invoke();
+  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
+  TF_LITE_ENSURE_STATUS(runner.Invoke());
 
-  return status;
+  for (int i = 0; i < output_size; ++i) {
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
+    if (golden[i] != output[i]) {
+      return kTfLiteError;
+    }
+  }
+  return kTfLiteOk;
 }
 
 TfLiteStatus TestSpaceToBatchNdFloat(
-    int* input_dims_data[kNumInputs], const float* input_data,
-    const int32_t* block_shape_data, const int32_t* padding_data,
-    int* output_dims_data, const float* golden_data, float* output_data) {
-  TfLiteIntArray* input_dims =
-      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
-  TfLiteIntArray* block_shape_dims =
-      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
-  TfLiteIntArray* padding_dims =
-      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
+    int* input_dims_data, const float* input_data, int* block_shape_dims_data,
+    const int32_t* block_shape_data, int* crops_dims_data,
+    const int32_t* crops_data, int* output_dims_data, const float* golden,
+    float* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
-  TfLiteTensor tensors[kTensorsCount];
-  tensors[kInputTensorIndex] =
-      tflite::testing::CreateTensor(input_data, input_dims);
-  tensors[kBlockShapeTensorIndex] =
-      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
-  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
-  tensors[kPaddingTensorIndex] =
-      tflite::testing::CreateTensor(padding_data, padding_dims);
-  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
-  tensors[kOutputTensorIndex] =
-      tflite::testing::CreateTensor(output_data, output_dims);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(input_data, input_dims),
+      CreateTensor(block_shape_data, block_shape_dims),
+      CreateTensor(crops_data, crops_dims),
+      CreateTensor(output_data, output_dims),
+  };
 
-  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-
-  // check output dimensions (relocated) against original dimensions
-  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
-                          tensors[kOutputTensorIndex].dims->size);
-  TF_LITE_MICRO_CHECK_FAIL();
-  for (int i = 0; i < output_dims->size; i++) {
-    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
-                            tensors[kOutputTensorIndex].dims->data[i]);
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_CHECK_FAIL();
-  }
-
-  // check output data against golden
-  const int output_count = ElementCount(*output_dims);
-  for (int i = 0; i < output_count; i++) {
-    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_CHECK_FAIL();
-  }
-
-  return kTfLiteOk;
+  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden,
+                                       output_data, ElementCount(*output_dims));
 }
 
-template <typename T, size_t kInCount, size_t kOutCount>
+template <typename T>
 TfLiteStatus TestSpaceToBatchNdQuantized(
-    TestQuantParams<T, kInCount, kOutCount>& params,
-    int* input_dims_data[kNumInputs], const float* input_data,
-    const int32_t* block_shape_data, const int32_t* padding_data,
-    int* output_dims_data, const float* golden_data) {
-  TfLiteIntArray* input_dims =
-      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
-  TfLiteIntArray* block_shape_dims =
-      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
-  TfLiteIntArray* padding_dims =
-      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
+    int* input_dims_data, const float* input_data, T* input_quantized,
+    float input_scale, int input_zero_point, int* block_shape_dims_data,
+    const int32_t* block_shape_data, int* crops_dims_data,
+    const int32_t* crops_data, int* output_dims_data, const float* golden,
+    T* golden_quantized, float output_scale, int output_zero_point,
+    T* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
+  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  int zero_point =
-      tflite::testing::ZeroPointFromMinMax<T>(params.data_min, params.data_max);
-  float scale =
-      tflite::testing::ScaleFromMinMax<T>(params.data_min, params.data_max);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
+                                             input_dims, input_scale,
+                                             input_zero_point),
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
+      tflite::testing::CreateTensor(crops_data, crops_dims),
+      tflite::testing::CreateQuantizedTensor(output_data, output_dims,
+                                             output_scale, output_zero_point),
+  };
+  tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
+                   output_scale, output_zero_point);
 
-  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
-  TfLiteTensor tensors[kTensorsCount];
-  tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
-      input_data, params.input_data, input_dims, scale, zero_point);
-  tensors[kBlockShapeTensorIndex] =
-      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
-  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
-  tensors[kPaddingTensorIndex] =
-      tflite::testing::CreateTensor(padding_data, padding_dims);
-  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
-  tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
-      params.output_data, output_dims, scale, zero_point);
-
-  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-
-  // check output dimensions (relocated) against original dimensions
-  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
-                          tensors[kOutputTensorIndex].dims->size);
-  TF_LITE_MICRO_CHECK_FAIL();
-  for (int i = 0; i < output_dims->size; i++) {
-    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
-                            tensors[kOutputTensorIndex].dims->data[i]);
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_CHECK_FAIL();
-  }
-
-  // check output data against golden
-  const int output_count = ElementCount(*output_dims);
-  const float quantization_tolerance =
-      (params.data_max - params.data_min) /
-      (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
-  for (int i = 0; i < output_count; i++) {
-    float output_dequantized_data =
-        (params.output_data[i] - zero_point) * scale;
-    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
-                              quantization_tolerance);
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_CHECK_FAIL();
-  }
-
-  return kTfLiteOk;
+  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden_quantized,
+                                       output_data, ElementCount(*output_dims));
 }
 
 }  // namespace
@@ -189,313 +125,30 @@
 
 TF_LITE_MICRO_TESTS_BEGIN
 
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidShapeTest) {
-  int kInputDims[] = {4, 1, 3, 3, 1};  // invalid shape
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 0, 0, 0, 0};  // placeholder dims, not used
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
-  constexpr int32_t kBlockShape[] = {2, 2};
-  constexpr int32_t kPadding[] = {0, 0, 0, 0};
-  constexpr float kGolden[] = {0};  // placeholder data, not used
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidOutputShapeTest) {
-  int kInputDims[] = {3, 1, 8, 1};
-  int kBlockShapeDims[] = {1, 1};
-  int kPaddingDims[] = {2, 1, 2};
-  int kOutputDims[] = {3, 1, 6, 1};  // invalid shape
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
-  constexpr int32_t kBlockShape[] = {2};
-  constexpr int32_t kPadding[] = {2, 2};
-  constexpr float kGolden[] = {
-      0, 0, 0, 0, 0, 0,  // placeholder data, not used
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestValidOutputShapeTest) {
-  int kInputDims[] = {3, 1, 8, 1};
-  int kBlockShapeDims[] = {1, 1};
-  int kPaddingDims[] = {2, 1, 2};
-  int kOutputDims[] = {3, 2, 6, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
-  constexpr int32_t kBlockShape[] = {2};
-  constexpr int32_t kPadding[] = {2, 2};
-  constexpr float kGolden[] = {0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0};
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimpleConstTest) {
-  int kInputDims[] = {4, 1, 4, 4, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 4, 2, 2, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-  };
-  constexpr int32_t kBlockShape[] = {2, 2};
-  constexpr int32_t kPadding[] = {0, 0, 0, 0};
-  constexpr float kGolden[] = {
-      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestMultipleInputBatchesConstTest) {
-  int kInputDims[] = {4, 2, 2, 4, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 8, 1, 2, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-  };
-  constexpr int32_t kBlockShape[] = {2, 2};
-  constexpr int32_t kPadding[] = {0, 0, 0, 0};
-  constexpr float kGolden[] = {
-      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimplePaddingConstTest) {
-  int kInputDims[] = {4, 1, 5, 2, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 6, 2, 2, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
-  constexpr int32_t kBlockShape[] = {3, 2};
-  constexpr int32_t kPadding[] = {1, 0, 2, 0};
-  constexpr float kGolden[] = {
-      0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestComplexPaddingConstTest) {
-  int kInputDims[] = {4, 1, 4, 2, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 6, 2, 4, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
-  constexpr int32_t kBlockShape[] = {3, 2};
-  constexpr int32_t kPadding[] = {1, 1, 2, 4};
-  constexpr float kGolden[] = {
-      0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 7, 0, 0,
-      0, 2, 0, 0, 0, 8, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestSimplePaddingConstTestInt8) {
-  int kInputDims[] = {4, 1, 5, 2, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 6, 2, 2, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1,
-  };
-  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
-  constexpr int32_t kBlockShape[] = {3, 2};
-  constexpr int32_t kPadding[] = {1, 0, 2, 0};
-  constexpr float kGolden[] = {
-      0, 0,   0, -0.5, 0, 0,    0, 0.6,  0, -0.1, 0, -0.7,
-      0, 0.2, 0, 0.8,  0, -0.3, 0, -0.9, 0, 0.4,  0, 0.1,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-
-  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
-      -1,
-      std::numeric_limits<int8_t>::max() /
-          static_cast<float>(std::numeric_limits<int8_t>::max() + 1),
-      {},
-      {}};
-
+TF_LITE_MICRO_TEST(SpaceToBatchBasicFloat) {
+  float output[tflite::testing::kBasicInputOutputSize];
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
-                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
-                     kOutputDims, kGolden));
+      kTfLiteOk,
+      tflite::testing::TestSpaceToBatchNdFloat(
+          tflite::testing::basic_input_dims, tflite::testing::basic_input,
+          tflite::testing::basic_block_shape_dims,
+          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+          tflite::testing::basic_golden, output));
 }
 
-TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestComplexPaddingConstTest) {
-  int kInputDims[] = {4, 1, 4, 2, 1};
-  int kBlockShapeDims[] = {1, 2};
-  int kPaddingDims[] = {2, 2, 2};
-  int kOutputDims[] = {4, 6, 2, 4, 1};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8,
-  };
-  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
-  constexpr int32_t kBlockShape[] = {3, 2};
-  constexpr int32_t kPadding[] = {1, 1, 2, 4};
-  constexpr float kGolden[] = {
-      0, 0,    0, 0, 0, -0.5, 0, 0, 0, 0,   0, 0, 0, 0.6, 0, 0,
-      0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
-      0, -0.3, 0, 0, 0, 0,    0, 0, 0, 0.4, 0, 0, 0, 0,   0, 0,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-
-  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
-      -1, 1, {}, {}};
-
+TF_LITE_MICRO_TEST(SpaceToBatchBasicInt8) {
+  int8_t output[tflite::testing::kBasicInputOutputSize];
+  int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
+  int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
-                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
-                     kOutputDims, kGolden));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DConstTest) {
-  int kInputDims[] = {3, 1, 4, 4};
-  int kBlockShapeDims[] = {1, 1};
-  int kPaddingDims[] = {2, 1, 2};
-  int kOutputDims[] = {3, 2, 2, 4};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-  };
-  constexpr int32_t kBlockShape[] = {2};
-  constexpr int32_t kPadding[] = {0, 0};
-  constexpr float kGolden[] = {
-      1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
-}
-
-TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DPaddingConstTest) {
-  int kInputDims[] = {3, 1, 4, 4};
-  int kBlockShapeDims[] = {1, 1};
-  int kPaddingDims[] = {2, 1, 2};
-  int kOutputDims[] = {3, 2, 4, 4};
-
-  int* kInputDimsArray[tflite::testing::kNumInputs];
-  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
-  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
-  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
-
-  constexpr float kInput[] = {
-      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-  };
-  constexpr int32_t kBlockShape[] = {2};
-  constexpr int32_t kPadding[] = {2, 2};
-  constexpr float kGolden[] = {
-      0, 0, 0, 0, 1, 2, 3, 4, 9,  10, 11, 12, 0, 0, 0, 0,
-      0, 0, 0, 0, 5, 6, 7, 8, 13, 14, 15, 16, 0, 0, 0, 0,
-  };
-  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
-  float output_data[kOutputCount];
-
-  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
-                          tflite::testing::TestSpaceToBatchNdFloat(
-                              kInputDimsArray, kInput, kBlockShape, kPadding,
-                              kOutputDims, kGolden, output_data));
+      kTfLiteOk,
+      tflite::testing::TestSpaceToBatchNdQuantized(
+          tflite::testing::basic_input_dims, tflite::testing::basic_input,
+          input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
+          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
+          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
+          tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h
index bd2d93c..a28f4b6 100644
--- a/tensorflow/lite/micro/testing/micro_test.h
+++ b/tensorflow/lite/micro/testing/micro_test.h
@@ -264,11 +264,4 @@
     }                                                                        \
   } while (false)
 
-#define TF_LITE_MICRO_CHECK_FAIL()   \
-  do {                               \
-    if (micro_test::did_test_fail) { \
-      return kTfLiteError;           \
-    }                                \
-  } while (false)
-
 #endif  // TENSORFLOW_LITE_MICRO_TESTING_MICRO_TEST_H_