SPACE_TO_BATCH_ND: update output tensor shape (#2335)

@tensorflow/micro

Update the output tensor shape during the prepare phase when the computed shape does not match the shape stored in the flatbuffer.

Added additional tests from TfLite.

See #2319 for additional details.

Resolves [b/313360569](https://issuetracker.google.com/313360569)

bug=fixes #2332 fixes #2319 #1646 #1629 #1231
diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
index f31728c..6fe75c1 100644
--- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
@@ -32,13 +32,6 @@
 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
-#define TF_LITE_MICRO_CHECK_FAIL()   \
-  do {                               \
-    if (micro_test::did_test_fail) { \
-      return kTfLiteError;           \
-    }                                \
-  } while (false)
-
 namespace {
 
 // Arena size is a guesstimate, followed by use of
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
index 6b536ee..fd7ef92 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,7 +15,10 @@
 
 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
 
+#include <algorithm>
+
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -24,12 +27,11 @@
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
-
 namespace {
 
 constexpr int kInputTensor = 0;
 constexpr int kBlockShapeTensor = 1;
-constexpr int kCropsTensor = 2;
+constexpr int kPaddingTensor = 2;
 constexpr int kOutputTensor = 0;
 
 // Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
@@ -44,6 +46,68 @@
   return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
 }
 
+TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
+                                 const TfLiteTensor* input,
+                                 const TfLiteTensor* block_shape,
+                                 const TfLiteTensor* padding,
+                                 TfLiteTensor* output) {
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(padding));
+  const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
+  const int32_t* padding_data = GetTensorData<int32_t>(padding);
+
+  TfLiteIntArray* input_dims = input->dims;
+  int spatial_dims_num = input_dims->size - 2;
+  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
+  TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
+  // Padding should be a 2D tensor with dimension [spatial_dims_num, 2].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(padding), 2);
+  TF_LITE_ENSURE_EQ(context, padding->dims->data[0], spatial_dims_num);
+  TF_LITE_ENSURE_EQ(context, padding->dims->data[1], 2);
+
+  // Start from the input tensor shape, mirroring the TfLite reference code.
+  RuntimeShape output_shape = GetTensorShape(input);
+  // keep a copy of the output tensor shape for later comparison
+  RuntimeShape old_output_shape = GetTensorShape(output);
+
+  // Ensure that each padded spatial dimension of the input is evenly
+  // divisible by the corresponding block shape value.
+  int output_batch_size = input_dims->data[0];
+  for (int dim = 0; dim < spatial_dims_num; ++dim) {
+    int final_dim_size = (input_dims->data[dim + 1] + padding_data[dim * 2] +
+                          padding_data[dim * 2 + 1]);
+    TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
+    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape_data[dim], 0);
+    output_shape.SetDim(dim + 1, final_dim_size / block_shape_data[dim]);
+    output_batch_size *= block_shape_data[dim];
+  }
+  output_shape.SetDim(0, output_batch_size);
+  output_shape.SetDim(input_dims->size - 1,
+                      input_dims->data[input_dims->size - 1]);
+
+  // Check whether the output tensor dims need to be relocated and updated.
+  if (output_shape == old_output_shape) {
+    return kTfLiteOk;
+  } else if (output_shape.FlatSize() > old_output_shape.FlatSize() &&
+             output->data.data != nullptr) {
+    MicroPrintf(
+        "SPACE_TO_BATCH_ND: resizing flatbuffer tensor data is not supported");
+    return kTfLiteError;
+  }
+
+  // set the output tensor dims from output_shape
+  TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
+  std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
+              output->dims->data);
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
 
@@ -52,19 +116,47 @@
 
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* block_shape =
+      micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
+  TF_LITE_ENSURE(context, block_shape != nullptr);
+  TfLiteTensor* padding =
+      micro_context->AllocateTempInputTensor(node, kPaddingTensor);
+  TF_LITE_ENSURE(context, padding != nullptr);
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
-  TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  SpaceToBatchParams& params =
+      *(static_cast<SpaceToBatchParams*>(node->user_data));
+
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
+    TF_LITE_ENSURE(context,
+                   input->params.zero_point == output->params.zero_point);
+    params.output_offset = output->params.zero_point;
+  } else {
+    params.output_offset = 0;
+  }
+
+  TfLiteStatus status =
+      ReshapeOutputTensor(context, node, input, block_shape, padding, output);
 
   micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(block_shape);
+  micro_context->DeallocateTempTfLiteTensor(padding);
   micro_context->DeallocateTempTfLiteTensor(output);
-  return kTfLiteOk;
+
+  return status;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -76,8 +168,8 @@
       tflite::micro::GetEvalInput(context, node, kInputTensor);
   const TfLiteEvalTensor* block_shape =
       tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
-  const TfLiteEvalTensor* crops =
-      tflite::micro::GetEvalInput(context, node, kCropsTensor);
+  const TfLiteEvalTensor* padding =
+      tflite::micro::GetEvalInput(context, node, kPaddingTensor);
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
 
@@ -88,8 +180,8 @@
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(crops),
-          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(padding),
+          tflite::micro::GetTensorData<int32_t>(padding),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -99,8 +191,8 @@
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(crops),
-          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(padding),
+          tflite::micro::GetTensorData<int32_t>(padding),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
index eae185b..d45a606 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include <cstdint>
+#include <limits>
+#include <type_traits>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -25,98 +27,160 @@
 namespace testing {
 namespace {
 
-constexpr int kBasicInputOutputSize = 16;
-int basic_input_dims[] = {4, 1, 4, 4, 1};
-const float basic_input[kBasicInputOutputSize] = {
-    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
-int basic_block_shape_dims[] = {1, 2};
-const int32_t basic_block_shape[] = {2, 2};
-int basic_crops_dims[] = {1, 4};
-const int32_t basic_crops[] = {0, 0, 0, 0};
-int basic_output_dims[] = {4, 4, 2, 2, 1};
-const float basic_golden[kBasicInputOutputSize] = {1, 3, 9,  11, 2, 4, 10, 12,
-                                                   5, 7, 13, 15, 6, 8, 14, 16};
+constexpr float kTestTolerance = 1e-05;
+constexpr int kNumInputs = 3;
+constexpr int kNumOutputs = 1;
+constexpr int kInputTensorIndex = 0;
+constexpr int kBlockShapeTensorIndex = 1;
+constexpr int kPaddingTensorIndex = 2;
+constexpr int kOutputTensorIndex = 3;
 
-template <typename T>
-TfLiteStatus ValidateSpaceToBatchNdGoldens(TfLiteTensor* tensors,
-                                           int tensors_size, const T* golden,
-                                           T* output, int output_size) {
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+// min/max are used to compute scale, zero-point (asymmetric)
+template <typename T, size_t kInputSize, size_t kOutputSize>
+struct TestQuantParams {
+  // quantization parameters
+  float data_min;              // input data minimum value
+  float data_max;              // input data maximum value
+  T output_data[kOutputSize];  // quantized output storage
+  T input_data[kInputSize];    // quantized input storage
+};
 
-  const TFLMRegistration registration = Register_SPACE_TO_BATCH_ND();
-  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+TfLiteStatus ExecuteSpaceToBatchNdTest(TfLiteTensor* tensors,
+                                       int tensors_count) {
+  int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
+                           kBlockShapeTensorIndex, kPaddingTensorIndex};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
+  int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
+
+  const TFLMRegistration registration = tflite::Register_SPACE_TO_BATCH_ND();
+  micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
                              outputs_array, nullptr);
 
-  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
-  TF_LITE_ENSURE_STATUS(runner.Invoke());
-
-  for (int i = 0; i < output_size; ++i) {
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
-    if (golden[i] != output[i]) {
-      return kTfLiteError;
-    }
+  TfLiteStatus status = runner.InitAndPrepare();
+  if (status != kTfLiteOk) {
+    return status;
   }
-  return kTfLiteOk;
+  status = runner.Invoke();
+
+  return status;
 }
 
 TfLiteStatus TestSpaceToBatchNdFloat(
-    int* input_dims_data, const float* input_data, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    float* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+    int* input_dims_data[kNumInputs], const float* input_data,
+    const int32_t* block_shape_data, const int32_t* padding_data,
+    int* output_dims_data, const float* golden_data, float* output_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* padding_dims =
+      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateTensor(input_data, input_dims),
-      CreateTensor(block_shape_data, block_shape_dims),
-      CreateTensor(crops_data, crops_dims),
-      CreateTensor(output_data, output_dims),
-  };
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] =
+      tflite::testing::CreateTensor(input_data, input_dims);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kPaddingTensorIndex] =
+      tflite::testing::CreateTensor(padding_data, padding_dims);
+  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] =
+      tflite::testing::CreateTensor(output_data, output_dims);
 
-  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden,
-                                       output_data, ElementCount(*output_dims));
+  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against golden
+  const int output_count = ElementCount(*output_dims);
+  for (int i = 0; i < output_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
 }
 
-template <typename T>
+template <typename T, size_t kInCount, size_t kOutCount>
 TfLiteStatus TestSpaceToBatchNdQuantized(
-    int* input_dims_data, const float* input_data, T* input_quantized,
-    float input_scale, int input_zero_point, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    T* golden_quantized, float output_scale, int output_zero_point,
-    T* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+    TestQuantParams<T, kInCount, kOutCount>& params,
+    int* input_dims_data[kNumInputs], const float* input_data,
+    const int32_t* block_shape_data, const int32_t* padding_data,
+    int* output_dims_data, const float* golden_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* padding_dims =
+      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
-                                             input_dims, input_scale,
-                                             input_zero_point),
-      tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
-      tflite::testing::CreateTensor(crops_data, crops_dims),
-      tflite::testing::CreateQuantizedTensor(output_data, output_dims,
-                                             output_scale, output_zero_point),
-  };
-  tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
-                   output_scale, output_zero_point);
+  int zero_point =
+      tflite::testing::ZeroPointFromMinMax<T>(params.data_min, params.data_max);
+  float scale =
+      tflite::testing::ScaleFromMinMax<T>(params.data_min, params.data_max);
 
-  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden_quantized,
-                                       output_data, ElementCount(*output_dims));
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      input_data, params.input_data, input_dims, scale, zero_point);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kPaddingTensorIndex] =
+      tflite::testing::CreateTensor(padding_data, padding_dims);
+  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      params.output_data, output_dims, scale, zero_point);
+
+  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against golden
+  const int output_count = ElementCount(*output_dims);
+  const float quantization_tolerance =
+      (params.data_max - params.data_min) /
+      (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
+  for (int i = 0; i < output_count; i++) {
+    float output_dequantized_data =
+        (params.output_data[i] - zero_point) * scale;
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
+                              quantization_tolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
 }
 
 }  // namespace
@@ -125,30 +189,313 @@
 
 TF_LITE_MICRO_TESTS_BEGIN
 
-TF_LITE_MICRO_TEST(SpaceToBatchBasicFloat) {
-  float output[tflite::testing::kBasicInputOutputSize];
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestSpaceToBatchNdFloat(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, output));
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidShapeTest) {
+  int kInputDims[] = {4, 1, 3, 3, 1};  // invalid shape
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 0, 0, 0, 0};  // placeholder dims, not used
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
 }
 
-TF_LITE_MICRO_TEST(SpaceToBatchBasicInt8) {
-  int8_t output[tflite::testing::kBasicInputOutputSize];
-  int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
-  int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidOutputShapeTest) {
+  int kInputDims[] = {3, 1, 8, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 1, 6, 1};  // invalid shape
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 0, 0,  // placeholder data, not used
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestValidOutputShapeTest) {
+  int kInputDims[] = {3, 1, 8, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 6, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimpleConstTest) {
+  int kInputDims[] = {4, 1, 4, 4, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 4, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {
+      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestMultipleInputBatchesConstTest) {
+  int kInputDims[] = {4, 2, 2, 4, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 8, 1, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {
+      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimplePaddingConstTest) {
+  int kInputDims[] = {4, 1, 5, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 0, 2, 0};
+  constexpr float kGolden[] = {
+      0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestComplexPaddingConstTest) {
+  int kInputDims[] = {4, 1, 4, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 1, 2, 4};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 7, 0, 0,
+      0, 2, 0, 0, 0, 8, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestSimplePaddingConstTestInt8) {
+  int kInputDims[] = {4, 1, 5, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1,
+  };
+  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 0, 2, 0};
+  constexpr float kGolden[] = {
+      0, 0,   0, -0.5, 0, 0,    0, 0.6,  0, -0.1, 0, -0.7,
+      0, 0.2, 0, 0.8,  0, -0.3, 0, -0.9, 0, 0.4,  0, 0.1,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+
+  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
+      -1,
+      std::numeric_limits<int8_t>::max() /
+          static_cast<float>(std::numeric_limits<int8_t>::max() + 1),
+      {},
+      {}};
+
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestSpaceToBatchNdQuantized(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
+      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
+                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
+                     kOutputDims, kGolden));
+}
+
+TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestComplexPaddingConstTest) {
+  int kInputDims[] = {4, 1, 4, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8,
+  };
+  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 1, 2, 4};
+  constexpr float kGolden[] = {
+      0, 0,    0, 0, 0, -0.5, 0, 0, 0, 0,   0, 0, 0, 0.6, 0, 0,
+      0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+      0, -0.3, 0, 0, 0, 0,    0, 0, 0, 0.4, 0, 0, 0, 0,   0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+
+  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
+      -1, 1, {}, {}};
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
+                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
+                     kOutputDims, kGolden));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DConstTest) {
+  int kInputDims[] = {3, 1, 4, 4};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 2, 4};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {0, 0};
+  constexpr float kGolden[] = {
+      1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DPaddingConstTest) {
+  int kInputDims[] = {3, 1, 4, 4};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 4, 4};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 1, 2, 3, 4, 9,  10, 11, 12, 0, 0, 0, 0,
+      0, 0, 0, 0, 5, 6, 7, 8, 13, 14, 15, 16, 0, 0, 0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h
index a28f4b6..bd2d93c 100644
--- a/tensorflow/lite/micro/testing/micro_test.h
+++ b/tensorflow/lite/micro/testing/micro_test.h
@@ -264,4 +264,11 @@
     }                                                                        \
   } while (false)
 
+#define TF_LITE_MICRO_CHECK_FAIL()   \
+  do {                               \
+    if (micro_test::did_test_fail) { \
+      return kTfLiteError;           \
+    }                                \
+  } while (false)
+
 #endif  // TENSORFLOW_LITE_MICRO_TESTING_MICRO_TEST_H_