Compute output shapes for some kernels (#2356)

@tensorflow/micro

Update the output tensor shape during the prepare phase when the computed shape does not match the shape stored in the flatbuffer (a minimal sketch of the shared approach follows the kernel list).

Kernels:
- BATCH_TO_SPACE_ND
- SPACE_TO_BATCH_ND
- CONV
- RESHAPE
- EXPAND_DIMS
- DEPTHWISE_CONV
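
All of these kernels follow the same copy-then-overwrite pattern. A minimal sketch of the CONV case, drawn from the `ConvReshapeOutputTensor` helper added in `conv_common.cc` below (the full version appears in that hunk; the other kernels differ only in how they compute the new dims):

```c++
// Sketch: copy the output dims from the flatbuffer into writable (persistent
// arena) memory, then overwrite them with the shape computed during Prepare().
// height/width are the results of ComputePaddingHeightWidth.
TfLiteStatus ConvReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
                                     const TfLiteTensor* input,
                                     const TfLiteTensor* filter,
                                     TfLiteTensor* output, int height,
                                     int width) {
  // Relocate the output tensor dims so they can be updated.
  TfLiteEvalTensor* output_eval =
      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
      context, output, output_eval));

  // NHWC: batches from the input, spatial dims as computed, output channels
  // from the filter's output-channel dimension.
  output->dims->data[0] = input->dims->data[0];   // batches
  output->dims->data[1] = height;
  output->dims->data[2] = width;
  output->dims->data[3] = filter->dims->data[0];  // filter output channels
  return kTfLiteOk;
}
```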

Update CMSIS_NN and ARC_MLI optimized kernels.
Add additional tests from TfLite for BATCH_TO_SPACE_ND and SPACE_TO_BATCH_ND.
Update existing tests.
Add tests for a Keras model using convolution with dilation > 1.

Update memory_arena_threshold_test to increase total, tail, and persistent allocation sizes:
- Add 20 bytes for CONV output shape
- Add 15 bytes for arena allocation alignment
- Multiplied by 2 for the two convolution layers

Update the micro_speech_test arena size as described in the C++ code.

See #2319 for additional details.

Resolves [b/317362237](https://issuetracker.google.com/317362237)

bug=fixes #2368 #1646 #1629 #1231 #2338 #2319
diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
index f31728c..3ce32f3 100644
--- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
+++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc
@@ -32,19 +32,13 @@
 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
-#define TF_LITE_MICRO_CHECK_FAIL()   \
-  do {                               \
-    if (micro_test::did_test_fail) { \
-      return kTfLiteError;           \
-    }                                \
-  } while (false)
-
 namespace {
 
 // Arena size is a guesstimate, followed by use of
 // MicroInterpreter::arena_used_bytes() on both the AudioPreprocessor and
-// MicroSpeech models and using the larger of the two results.
-constexpr size_t kArenaSize = 28584;  // xtensa p6
+// MicroSpeech models and using the larger of the two results plus the
+// arena alignment size (16).
+constexpr size_t kArenaSize = 28664;  // xtensa p6
 alignas(16) uint8_t g_arena[kArenaSize];
 
 using Features = int8_t[kFeatureCount][kFeatureSize];
diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
index 41d2c53..896e228 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@
 #include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h"
 #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h"
 #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/micro_log.h"
 
@@ -122,7 +123,7 @@
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              const TfLiteConvParams* params, int width,
                              int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
+                             int* out_width, int* out_height,
                              const TfLiteType data_type, OpData* data) {
   bool has_bias = node->inputs->size == 3;
   // Check number of inputs/outputs
@@ -134,7 +135,7 @@
   data->padding = ComputePaddingHeightWidth(
       params->stride_height, params->stride_width,
       params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
+      width, filter_height, filter_width, padding, out_height, out_width);
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
 #if !defined(TF_LITE_STRIP_REFERENCE_IMPL)
@@ -167,6 +168,7 @@
 #endif
   return kTfLiteOk;
 }
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -190,6 +192,17 @@
   TfLiteTensor* bias =
       micro_context->AllocateTempInputTensor(context, node, kBiasTensor);
 
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+
+  // Check input channels matching filter
+  const int input_channels = input->dims->data[3];
+  const int filter_input_channels = filter->dims->data[3];
+  TF_LITE_ENSURE(context, filter_input_channels > 0);
+  TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
+
   int input_width = input->dims->data[2];
   int input_height = input->dims->data[1];
 #if defined(MLI_2_0) && !defined(MLI_2_0_KRNL_TEST)
@@ -199,8 +212,8 @@
   int filter_width = filter->dims->data[2];
   int filter_height = filter->dims->data[1];
 #endif
-  int output_width = output->dims->data[2];
-  int output_height = output->dims->data[1];
+  int output_width = 0;
+  int output_height = 0;
 
   // Dynamically allocate per-channel quantization parameters.
   const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -235,7 +248,11 @@
 
   TF_LITE_ENSURE_STATUS(CalculateOpData(
       context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, input->type, data));
+      filter_height, &output_width, &output_height, input->type, data));
+
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
 
   data->input_zero_point = input->params.zero_point;
   data->filter_zero_point = filter->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
index c2c9cd5..4fa5e94 100644
--- a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
 #include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h"
 #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h"
 #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h"
+#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/micro_log.h"
 
@@ -118,17 +119,16 @@
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              TfLiteDepthwiseConvParams* params, int width,
                              int height, int filter_width, int filter_height,
+                             int* out_width, int* out_height,
                              const TfLiteType data_type, OpData* data) {
   bool has_bias = node->inputs->size == 3;
   // Check number of inputs/outputs
   TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
 
-  int unused_output_height, unused_output_width;
   data->padding = ComputePaddingHeightWidth(
       params->stride_height, params->stride_width, 1, 1, height, width,
-      filter_height, filter_width, params->padding, &unused_output_height,
-      &unused_output_width);
+      filter_height, filter_width, params->padding, out_height, out_width);
 
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
@@ -182,6 +182,25 @@
   const TfLiteTensor* bias =
       AllocateTempInputTensor(context, node, kBiasTensor);
 
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+  TF_LITE_ENSURE(context, params->dilation_height_factor > 0);
+  TF_LITE_ENSURE(context, params->dilation_width_factor > 0);
+
+  // Filter in DepthwiseConv is expected to be [1, height, width, channels].
+  TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
+
+  // Check input channels matching filter
+  const int num_filter_channels = filter->dims->data[3];
+  const int num_input_channels = input->dims->data[3];
+  TF_LITE_ENSURE(context, num_input_channels != 0);
+  TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
+
+  int output_width = 0;
+  int output_height = 0;
+
   const TfLiteType data_type = input->type;
   int width = SizeOfDimension(input, 2);
   int height = SizeOfDimension(input, 1);
@@ -227,9 +246,13 @@
                       affine_quantization->zero_point->size);
   }
 
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
-                                        filter_width, filter_height, data_type,
-                                        data));
+  TF_LITE_ENSURE_STATUS(CalculateOpData(
+      context, node, params, width, height, filter_width, filter_height,
+      &output_width, &output_height, data_type, data));
+
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
 
   data->input_zero_point = input->params.zero_point;
   data->filter_zero_point = filter->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
index 31a1c28..94d1228 100644
--- a/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
+++ b/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,7 +15,10 @@
 
 #include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
 
+#include <algorithm>
+
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -38,6 +41,68 @@
 const int kInputOutputMinDimensionNum = 3;
 const int kInputOutputMaxDimensionNum = 4;
 
+TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
+                                 const TfLiteTensor* input,
+                                 const TfLiteTensor* block_shape,
+                                 const TfLiteTensor* crops,
+                                 TfLiteTensor* output) {
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(crops));
+  const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
+  const int32_t* crops_data = GetTensorData<int32_t>(crops);
+
+  TfLiteIntArray* input_dims = input->dims;
+  int spatial_dims_num = input_dims->size - 2;
+  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
+  TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
+  // Crops should be a 2D tensor with dimension [spatial_dims_num, 2].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(crops), 2);
+  TF_LITE_ENSURE_EQ(context, crops->dims->data[0], spatial_dims_num);
+  TF_LITE_ENSURE_EQ(context, crops->dims->data[1], 2);
+
+  for (int i = 0; i < spatial_dims_num * 2; ++i) {
+    TF_LITE_ENSURE(context, crops_data[i] >= 0);
+  }
+
+  // copy from input tensor as per TfLite code
+  TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
+  RuntimeShape output_shape = GetTensorShape(input);
+  // keep a copy of the output tensor shape for later comparison
+  RuntimeShape old_output_shape = GetTensorShape(output);
+
+  int output_batch_size = input_dims->data[0];
+  for (int dim = 0; dim < spatial_dims_num; ++dim) {
+    // Number of batch must be multiple of (block_shape[dim]).
+    TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
+    TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape_data[dim], 0);
+    output_batch_size = output_batch_size / block_shape_data[dim];
+    output_shape.SetDim(dim + 1,
+                        input_dims->data[dim + 1] * block_shape_data[dim] -
+                            crops_data[dim * 2] - crops_data[dim * 2 + 1]);
+  }
+  output_shape.SetDim(0, output_batch_size);
+  output_shape.SetDim(input_dims->size - 1,
+                      input_dims->data[input_dims->size - 1]);
+
+  // check if need to relocate output tensor dims
+  if (output_shape == old_output_shape) {
+    return kTfLiteOk;
+  }
+  TF_LITE_ENSURE(context,
+                 output_shape.FlatSize() <= old_output_shape.FlatSize());
+
+  // set the output tensor dims from output_shape
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
+  std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
+              output->dims->data);
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus BatchToSpaceNDPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -46,20 +111,40 @@
 
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* block_shape =
+      micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
+  TF_LITE_ENSURE(context, block_shape != nullptr);
+  TfLiteTensor* crops =
+      micro_context->AllocateTempInputTensor(node, kCropsTensor);
+  TF_LITE_ENSURE(context, crops != nullptr);
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
-  TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
+    TF_LITE_ENSURE(context,
+                   input->params.zero_point == output->params.zero_point);
+  }
+
+  TfLiteStatus status =
+      ReshapeOutputTensor(context, node, input, block_shape, crops, output);
 
   micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(block_shape);
+  micro_context->DeallocateTempTfLiteTensor(crops);
   micro_context->DeallocateTempTfLiteTensor(output);
 
-  return kTfLiteOk;
+  return status;
 }
 
 TfLiteStatus BatchToSpaceNDEval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
index 455c325..1b42a29 100644
--- a/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include <cstdint>
+#include <limits>
+#include <type_traits>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -25,98 +27,165 @@
 namespace testing {
 namespace {
 
-constexpr int kBasicInputOutputSize = 16;
-int basic_input_dims[] = {4, 4, 2, 2, 1};
-const float basic_input[kBasicInputOutputSize] = {
-    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
-int basic_block_shape_dims[] = {1, 2};
-const int32_t basic_block_shape[] = {2, 2};
-int basic_crops_dims[] = {1, 4};
-const int32_t basic_crops[] = {0, 0, 0, 0};
-int basic_output_dims[] = {4, 1, 4, 4, 1};
-const float basic_golden[kBasicInputOutputSize] = {1, 5, 2, 6, 9,  13, 10, 14,
-                                                   3, 7, 4, 8, 11, 15, 12, 16};
+constexpr float kTestTolerance = 1e-05;
+constexpr int kNumInputs = 3;
+constexpr int kNumOutputs = 1;
+constexpr int kInputTensorIndex = 0;
+constexpr int kBlockShapeTensorIndex = 1;
+constexpr int kCropTensorIndex = 2;
+constexpr int kOutputTensorIndex = 3;
 
-template <typename T>
-TfLiteStatus ValidateBatchToSpaceNdGoldens(TfLiteTensor* tensors,
-                                           int tensors_size, const T* golden,
-                                           T* output, int output_size) {
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+// min/max are used to compute scale, zero-point (asymmetric)
+template <typename T, size_t kInputSize, size_t kOutputSize>
+struct TestQuantParams {
+  // quantization parameters
+  float data_min;              // input data minimum value
+  float data_max;              // input data maximum value
+  T output_data[kOutputSize];  // quantized output storage
+  T input_data[kInputSize];    // quantized input storage
+};
 
-  const TFLMRegistration registration = Register_BATCH_TO_SPACE_ND();
-  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+TfLiteStatus ExecuteBatchToSpaceNdTest(TfLiteTensor* tensors,
+                                       int tensors_count) {
+  int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
+                           kBlockShapeTensorIndex, kCropTensorIndex};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
+  int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
+
+  const TFLMRegistration registration = tflite::Register_BATCH_TO_SPACE_ND();
+  micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
                              outputs_array, nullptr);
 
-  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
-  TF_LITE_ENSURE_STATUS(runner.Invoke());
-
-  for (int i = 0; i < output_size; ++i) {
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
-    if (golden[i] != output[i]) {
-      return kTfLiteError;
-    }
+  TfLiteStatus status = runner.InitAndPrepare();
+  if (status != kTfLiteOk) {
+    return status;
   }
+  status = runner.Invoke();
+
+  return status;
+}
+
+template <typename T>
+TfLiteStatus TestBatchToSpaceNd(int* input_dims_data[kNumInputs],
+                                const T* input_data,
+                                const int32_t* block_shape_data,
+                                const int32_t* crop_data, int* output_dims_data,
+                                const T* golden_data, T* output_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* crop_dims =
+      IntArrayFromInts(input_dims_data[kCropTensorIndex]);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] =
+      tflite::testing::CreateTensor(input_data, input_dims);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kCropTensorIndex] =
+      tflite::testing::CreateTensor(crop_data, crop_dims);
+  tensors[kCropTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] =
+      tflite::testing::CreateTensor(output_data, output_dims);
+
+  TfLiteStatus status = ExecuteBatchToSpaceNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against expected
+  const int output_count = ElementCount(*output_dims);
+  for (int i = 0; i < output_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
   return kTfLiteOk;
 }
 
-TfLiteStatus TestBatchToSpaceNdFloat(
-    int* input_dims_data, const float* input_data, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    float* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateTensor(input_data, input_dims),
-      CreateTensor(block_shape_data, block_shape_dims),
-      CreateTensor(crops_data, crops_dims),
-      CreateTensor(output_data, output_dims),
-  };
-
-  return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden,
-                                       output_data, ElementCount(*output_dims));
-}
-
-template <typename T>
+template <typename T, size_t kInCount, size_t kOutCount>
 TfLiteStatus TestBatchToSpaceNdQuantized(
-    int* input_dims_data, const float* input_data, T* input_quantized,
-    float input_scale, int input_zero_point, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    T* golden_quantized, float output_scale, int output_zero_point,
-    T* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+    TestQuantParams<T, kInCount, kOutCount>& params,
+    int* input_dims_data[kNumInputs], const float* input_data,
+    const int32_t* block_shape_data, const int32_t* crop_data,
+    int* output_dims_data, const float* golden_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* crop_dims =
+      IntArrayFromInts(input_dims_data[kCropTensorIndex]);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
-                                             input_dims, input_scale,
-                                             input_zero_point),
-      tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
-      tflite::testing::CreateTensor(crops_data, crops_dims),
-      tflite::testing::CreateQuantizedTensor(output_data, output_dims,
-                                             output_scale, output_zero_point),
-  };
-  tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
-                   output_scale, output_zero_point);
+  constexpr float kMaxMultiplier =
+      std::numeric_limits<T>::max() /
+      static_cast<float>(std::numeric_limits<T>::max() + 1);
+  int zero_point = tflite::testing::ZeroPointFromMinMax<T>(
+      params.data_min, params.data_max * kMaxMultiplier);
+  float scale = tflite::testing::ScaleFromMinMax<T>(
+      params.data_min, params.data_max * kMaxMultiplier);
 
-  return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden_quantized,
-                                       output_data, ElementCount(*output_dims));
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      input_data, params.input_data, input_dims, scale, zero_point);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kCropTensorIndex] =
+      tflite::testing::CreateTensor(crop_data, crop_dims);
+  tensors[kCropTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      params.output_data, output_dims, scale, zero_point);
+
+  TfLiteStatus status = ExecuteBatchToSpaceNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against expected
+  const int output_count = ElementCount(*output_dims);
+  const float quantization_tolerance =
+      (params.data_max - params.data_min) /
+      (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
+  for (int i = 0; i < output_count; i++) {
+    float output_dequantized_data =
+        (params.output_data[i] - zero_point) * scale;
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
+                              quantization_tolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
 }
 
 }  // namespace
@@ -125,30 +194,291 @@
 
 TF_LITE_MICRO_TESTS_BEGIN
 
-TF_LITE_MICRO_TEST(BatchToSpaceBasicFloat) {
-  float output[tflite::testing::kBasicInputOutputSize];
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestBatchToSpaceNdFloat(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, output));
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidOutputShapeTest) {
+  int kInputDims[] = {3, 2, 4, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kCropDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 1, 1, 1};  // invalid shape
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kCrop[] = {0, 0};
+  constexpr float kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestBatchToSpaceNd(
+                              kInputDimsArray, kInput, kBlockShape, kCrop,
+                              kOutputDims, kGolden, output_data));
 }
 
-TF_LITE_MICRO_TEST(BatchToSpaceBasicInt8) {
-  int8_t output[tflite::testing::kBasicInputOutputSize];
-  int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
-  int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestValidOutputShapeTest) {
+  int kInputDims[] = {3, 2, 4, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kCropDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 1, 8, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kCrop[] = {0, 0};
+  constexpr float kGolden[] = {1, 5, 2, 6, 3, 7, 4, 8};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestBatchToSpaceNdQuantized(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTest) {
+  int kInputDims[] = {4, 4, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 1, 4, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kCrop[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {
+      1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+// non-quantized test
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTestInt8) {
+  int kInputDims[] = {4, 4, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 1, 4, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr int8_t kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kCrop[] = {0, 0, 0, 0};
+  constexpr int8_t kGolden[] = {
+      1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  int8_t output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestBatchOneConstTest) {
+  int kInputDims[] = {4, 1, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 1, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4};
+  constexpr int32_t kBlockShape[] = {1, 1};
+  constexpr int32_t kCrop[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {1, 2, 3, 4};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+// non-quantized test
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimpleConstTestInt8EmptyOutput) {
+  int kInputDims[] = {4, 4, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 1, 4, 0, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr int8_t kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kCrop[] = {0, 0, 2, 2};
+  constexpr int8_t kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  int8_t output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidShapeTest) {
+  int kInputDims[] = {4, 3, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 0, 0, 0, 0};  // placeholder dims, not used
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {0};  // placeholder data, not used
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kCrop[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestBatchToSpaceNd(
+                              kInputDimsArray, kInput, kBlockShape, kCrop,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestInvalidCropsConstTest) {
+  int kInputDims[] = {4, 3, 2, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kCropDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 0, 0, 0, 0};  // placeholder dims, not used
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {0};  // placeholder data, not used
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kCrop[] = {0, 0, 0, -1};
+  constexpr float kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestBatchToSpaceNd(
+                              kInputDimsArray, kInput, kBlockShape, kCrop,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTest) {
+  int kInputDims[] = {3, 4, 4, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kCropDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 8, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kCrop[] = {0, 0};
+  constexpr float kGolden[] = {
+      1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTestWithCrops) {
+  int kInputDims[] = {3, 4, 4, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kCropDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 6, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kCrop[] = {1, 1};
+  constexpr float kGolden[] = {9, 2, 10, 3, 11, 4, 13, 6, 14, 7, 15, 8};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNd(
+                     kInputDimsArray, kInput, kBlockShape, kCrop, kOutputDims,
+                     kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(BatchToSpaceNDOpTestSimple3DConstTestWithCropsINT8) {
+  int kInputDims[] = {3, 4, 4, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kCropDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 6, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kCropTensorIndex] = kCropDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kCrop[] = {1, 1};
+  constexpr float kGolden[] = {9, 2, 10, 3, 11, 4, 13, 6, 14, 7, 15, 8};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+
+  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
+      -16, 16, {}, {}};
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestBatchToSpaceNdQuantized(
+                     params, kInputDimsArray, kInput, kBlockShape, kCrop,
+                     kOutputDims, kGolden));
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index d3d1552..ef15da7 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -75,29 +75,26 @@
           (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
       "Hybrid models are not supported on TFLite Micro.");
 
-  RuntimeShape input_shape = GetTensorShape(input);
-  RuntimeShape output_shape = GetTensorShape(output);
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+
+  // Check input channels matching filter
+  const int input_channels = input->dims->data[3];
+  const int filter_input_channels = filter->dims->data[3];
+  TF_LITE_ENSURE(context, filter_input_channels > 0);
+  TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
 
   // Initialize cmsis_nn input dimensions
   cmsis_nn_dims input_dims;
-  input_dims.n = MatchingDim(input_shape, 0, output_shape, 0);
   input_dims.h = input->dims->data[1];
   input_dims.w = input->dims->data[2];
-  input_dims.c = input_shape.Dims(3);
 
   // Initialize cmsis_nn filter dimensions
   cmsis_nn_dims filter_dims;
-  filter_dims.n = output_shape.Dims(3);
   filter_dims.h = filter->dims->data[1];
   filter_dims.w = filter->dims->data[2];
-  filter_dims.c = input_dims.c;
-
-  // Initialize cmsis_nn output dimensions
-  cmsis_nn_dims output_dims;
-  output_dims.n = input_dims.n;
-  output_dims.h = output->dims->data[1];
-  output_dims.w = output->dims->data[2];
-  output_dims.c = output_shape.Dims(3);
 
   if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
     const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -109,11 +106,32 @@
             context, num_channels * sizeof(int32_t)));
   }
 
+  int output_height = 0;
+  int output_width = 0;
   TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
       context, node, params, input_dims.w, input_dims.h, filter_dims.w,
-      filter_dims.h, output_dims.w, output_dims.h, input->type,
+      filter_dims.h, &output_width, &output_height, input->type,
       &data->reference_op_data));
 
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
+
+  // Finish initializing cmsis_nn input, filter dimensions
+  RuntimeShape input_shape = GetTensorShape(input);
+  RuntimeShape output_shape = GetTensorShape(output);
+  input_dims.n = MatchingDim(input_shape, 0, output_shape, 0);
+  input_dims.c = input_shape.Dims(3);
+  filter_dims.n = output_shape.Dims(3);
+  filter_dims.c = input_dims.c;
+
+  // Initialize cmsis_nn output dimensions
+  cmsis_nn_dims output_dims;
+  output_dims.n = input_dims.n;
+  output_dims.h = output_shape.Dims(1);
+  output_dims.w = output_shape.Dims(2);
+  output_dims.c = output_shape.Dims(3);
+
   // CMSIS_NN allows INT64 or nullptr bias data pointer
   if (input->type == kTfLiteInt8 ||
       (input->type == kTfLiteInt16 &&
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
index f30a952..77d6712 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -75,13 +75,29 @@
       micro_context->AllocateTempOutputTensor(node, kDepthwiseConvOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
 
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+  TF_LITE_ENSURE(context, params.dilation_height_factor > 0);
+  TF_LITE_ENSURE(context, params.dilation_width_factor > 0);
+
+  // Filter in DepthwiseConv is expected to be [1, height, width, channels].
+  TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
+
+  // Check input channels matching filter
+  const int num_filter_channels = filter->dims->data[3];
+  const int num_input_channels = input->dims->data[3];
+  TF_LITE_ENSURE(context, num_input_channels != 0);
+  TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
+
   const TfLiteType data_type = input->type;
   int input_width = SizeOfDimension(input, 2);
   int input_height = SizeOfDimension(input, 1);
   int filter_width = SizeOfDimension(filter, 2);
   int filter_height = SizeOfDimension(filter, 1);
-  int output_width = SizeOfDimension(output, 2);
-  int output_height = SizeOfDimension(output, 1);
+  int output_width = 0;
+  int output_height = 0;
 
   if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
     TF_LITE_ENSURE_EQ(context, filter->quantization.type,
@@ -120,9 +136,13 @@
 
   TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
       context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, data_type,
+      filter_height, &output_width, &output_height, data_type,
       &data->reference_op_data));
 
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
+
   if (input->type == kTfLiteInt8) {
     RuntimeShape input_shape = GetTensorShape(input);
     RuntimeShape output_shape = GetTensorShape(output);
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
index 0c8073f..a927980 100644
--- a/tensorflow/lite/micro/kernels/conv.h
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -70,10 +70,20 @@
 TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
                                  const TfLiteConvParams& params, int width,
                                  int height, int filter_width,
-                                 int filter_height, int out_width,
-                                 int out_height, const TfLiteType data_type,
+                                 int filter_height, int* out_width,
+                                 int* out_height, const TfLiteType data_type,
                                  OpDataConv* data);
 
+// When this method is called, the output tensor shape is computed and
+// relocated to persistent arena memory.
+// The height and width parameters should be the computed results from
+// CalculateOpDataConv.
+TfLiteStatus ConvReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
+                                     const TfLiteTensor* input,
+                                     const TfLiteTensor* filter,
+                                     TfLiteTensor* output, int height,
+                                     int width);
+
 void* ConvInit(TfLiteContext* context, const char* buffer, size_t length);
 
 TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node);
diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc
index 51c7a6f..ddcd5e5 100644
--- a/tensorflow/lite/micro/kernels/conv_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_common.cc
@@ -79,8 +79,8 @@
 TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
                                  const TfLiteConvParams& params, int width,
                                  int height, int filter_width,
-                                 int filter_height, int out_width,
-                                 int out_height, const TfLiteType data_type,
+                                 int filter_height, int* out_width,
+                                 int* out_height, const TfLiteType data_type,
                                  OpDataConv* data) {
   bool has_bias = node->inputs->size == 3;
   // Check number of inputs/outputs
@@ -92,7 +92,7 @@
   data->padding = ComputePaddingHeightWidth(
       params.stride_height, params.stride_width, params.dilation_height_factor,
       params.dilation_width_factor, height, width, filter_height, filter_width,
-      padding, &out_height, &out_width);
+      padding, out_height, out_width);
 
   MicroContext* micro_context = GetMicroContext(context);
 
@@ -135,6 +135,28 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus ConvReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
+                                     const TfLiteTensor* input,
+                                     const TfLiteTensor* filter,
+                                     TfLiteTensor* output, int height,
+                                     int width) {
+  const int filter_output_channels = filter->dims->data[0];
+  const int batches = input->dims->data[0];
+
+  // relocate output tensor dims so they can be updated
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
+
+  output->dims->data[0] = batches;
+  output->dims->data[1] = height;
+  output->dims->data[2] = width;
+  output->dims->data[3] = filter_output_channels;
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
@@ -163,12 +185,23 @@
            (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt8)),
       "Hybrid models are not supported on TFLite Micro.");
 
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+
+  // Check input channels matching filter
+  const int input_channels = input->dims->data[3];
+  const int filter_input_channels = filter->dims->data[3];
+  TF_LITE_ENSURE(context, filter_input_channels > 0);
+  TF_LITE_ENSURE_EQ(context, input_channels % filter_input_channels, 0);
+
   const int input_width = input->dims->data[2];
   const int input_height = input->dims->data[1];
   const int filter_width = filter->dims->data[2];
   const int filter_height = filter->dims->data[1];
-  const int output_width = output->dims->data[2];
-  const int output_height = output->dims->data[1];
+  int output_width = 0;
+  int output_height = 0;
 
   // Dynamically allocate per-channel quantization parameters.
   const int num_channels = filter->dims->data[kConvQuantizedDimension];
@@ -198,7 +231,11 @@
 
   TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
       context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, input->type, data));
+      filter_height, &output_width, &output_height, input->type, data));
+
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(ConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
 
   if (filter->type == kTfLiteInt4) {
     int filter_size =
diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc
index 0fb9411..3cfc594 100644
--- a/tensorflow/lite/micro/kernels/conv_test.cc
+++ b/tensorflow/lite/micro/kernels/conv_test.cc
@@ -277,9 +277,6 @@
 }
 
 TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) {
-  const int output_dims_count = 24;
-  int8_t output_data[output_dims_count];
-
   const float input_scale = 0.5f;
   const float output_scale = 1.0f;
   const int input_zero_point = 0;
@@ -292,10 +289,10 @@
       1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4,
       // b = 1
       1, 2, 3, 4, 5, 6, 2, 6, 2, 4, 4, 2, 3, 2, 6, 5, 1, 4, 1, 2, 1, 4, 6, 3};
-  const int output_elements = 24;
-  int output_shape[] = {4, 2, 2, 2, 3};
-  const float golden_data[] = {25, 2, 7, 25, 2, 7, 10, 2, -3, 10, 2, -3,
-                               39, 7, 6, 50, 3, 4, 14, 4, -5, 15, 0, -7};
+  constexpr int output_elements = 12;
+  int output_shape[] = {4, 2, 1, 2, 3};
+  int8_t output_data[output_elements];
+  const float golden_data[] = {25, 2, 7, 25, 2, 7, 39, 7, 6, 50, 3, 4};
 
   int8_t input_quantized[input_elements];
   int8_t filter_quantized[tflite::testing::kFilterElements];
@@ -1087,7 +1084,7 @@
   using tflite::ElementCount;
   using tflite::kConvBiasQuantized8;
   using tflite::kConvFilter8x3x3x3;
-  using tflite::kConvGoldenOutput1x16x16x8;
+  using tflite::kConvGoldenOutput1x15x15x8;
   using tflite::kConvInput1x32x32x3;
   using tflite::testing::CreateTensor;
   using tflite::testing::FloatArrayFromFloats;
@@ -1159,8 +1156,8 @@
                                            0};
 
-  // Create output tensor of 16x16x8
-  int8_t output_data[1 * 16 * 16 * kOutDepth];
-  int output_shape[] = {4, 1, 16, 16, kOutDepth};
+  // Create output tensor of 15x15x8
+  int8_t output_data[1 * 15 * 15 * kOutDepth];
+  int output_shape[] = {4, 1, 15, 15, kOutDepth};
   TfLiteIntArray* output_dims = IntArrayFromInts(output_shape);
   const int output_dims_count = ElementCount(*output_dims);
   TfLiteTensor output_tensor = CreateTensor(output_data, output_dims);
@@ -1183,7 +1180,7 @@
 
   TF_LITE_MICRO_EXPECT_EQ(
       kTfLiteOk,
-      ValidateConvGoldens(tensors, tensors_size, kConvGoldenOutput1x16x16x8,
+      ValidateConvGoldens(tensors, tensors_size, kConvGoldenOutput1x15x15x8,
                           output_dims_count, &conv_params,
                           tflite::Register_CONV_2D(), output_data,
                           1.0 /* tolerance */));
diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc
index a0f733b..eda801b 100644
--- a/tensorflow/lite/micro/kernels/conv_test_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_test_common.cc
@@ -18,13 +18,18 @@
 namespace tflite {
 namespace testing {
 
+constexpr int kInputIndex = 0;
+constexpr int kFilterIndex = 1;
+constexpr int kBiasIndex = 2;
+constexpr int kOutputIndex = 3;
+
 template <typename T>
 TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
                         int output_length, TfLiteConvParams* conv_params,
                         TFLMRegistration registration, T* output_data) {
-  int inputs_array_data[] = {3, 0, 1, 2};
+  int inputs_array_data[] = {3, kInputIndex, kFilterIndex, kBiasIndex};
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
+  int outputs_array_data[] = {1, kOutputIndex};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
@@ -45,15 +50,37 @@
                                  TfLiteConvParams* conv_params,
                                  TFLMRegistration registration, T* output_data,
                                  float tolerance) {
+  // grab pointer to expected tensor shape
+  TfLiteIntArray* expected_out_dims = tensors[kOutputIndex].dims;
+
   TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length,
                                    conv_params, registration, output_data);
   if (status != kTfLiteOk) {
     return status;
   }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->size,
+                          tensors[kOutputIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < expected_out_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->data[i],
+                            tensors[kOutputIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // compare expected against output data
+  const int actual_output_length = ElementCount(*expected_out_dims);
+  TF_LITE_MICRO_EXPECT_EQ(output_length, actual_output_length);
+  TF_LITE_MICRO_CHECK_FAIL();
   for (int i = 0; i < output_length; ++i) {
     TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
                               tolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
   }
+
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.h b/tensorflow/lite/micro/kernels/depthwise_conv.h
index 5f2d87e..b6712cd 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv.h
+++ b/tensorflow/lite/micro/kernels/depthwise_conv.h
@@ -44,9 +44,17 @@
 TfLiteStatus CalculateOpDataDepthwiseConv(
     TfLiteContext* context, TfLiteNode* node,
     const TfLiteDepthwiseConvParams& params, int width, int height,
-    int filter_width, int filter_height, int out_width, int out_height,
+    int filter_width, int filter_height, int* out_width, int* out_height,
     const TfLiteType data_type, OpDataConv* data);
 
+// Computes the output tensor shape and relocates the shape data to
+// persistent arena memory.
+// The height and width parameters should be the computed results from
+// CalculateOpDataDepthwiseConv.
+TfLiteStatus DepthwiseConvReshapeOutputTensor(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+    const TfLiteTensor* filter, TfLiteTensor* output, int height, int width);
+
 TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
 
 // This is the most generic TFLMRegistration. The actual supported types
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
index 52804de..431bec0 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -80,7 +80,7 @@
 TfLiteStatus CalculateOpDataDepthwiseConv(
     TfLiteContext* context, TfLiteNode* node,
     const TfLiteDepthwiseConvParams& params, int width, int height,
-    int filter_width, int filter_height, int out_width, int out_height,
+    int filter_width, int filter_height, int* out_width, int* out_height,
     const TfLiteType data_type, OpDataConv* data) {
   bool has_bias = node->inputs->size == 3;
   // Check number of inputs/outputs
@@ -92,7 +92,7 @@
   data->padding = ComputePaddingHeightWidth(
       params.stride_height, params.stride_width, params.dilation_height_factor,
       params.dilation_width_factor, height, width, filter_height, filter_width,
-      padding, &out_height, &out_width);
+      padding, out_height, out_width);
 
   MicroContext* micro_context = GetMicroContext(context);
 
@@ -133,6 +133,26 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus DepthwiseConvReshapeOutputTensor(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+    const TfLiteTensor* filter, TfLiteTensor* output, int height, int width) {
+  const int filter_output_channels = filter->dims->data[3];
+  const int batches = input->dims->data[0];
+
+  // relocate output tensor dims so they can be updated
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
+
+  output->dims->data[0] = batches;
+  output->dims->data[1] = height;
+  output->dims->data[2] = width;
+  output->dims->data[3] = filter_output_channels;
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
@@ -152,12 +172,28 @@
       micro_context->AllocateTempInputTensor(node, kDepthwiseConvWeightsTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
 
+  // Check dimensionality of input, filter, output
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, output->dims->size, 4);
+  TF_LITE_ENSURE(context, params.dilation_height_factor > 0);
+  TF_LITE_ENSURE(context, params.dilation_width_factor > 0);
+
+  // Filter in DepthwiseConv is expected to be [1, height, width, channels].
+  TF_LITE_ENSURE_EQ(context, filter->dims->data[0], 1);
+
+  // Check input channels matching filter
+  const int num_filter_channels = filter->dims->data[3];
+  const int num_input_channels = input->dims->data[3];
+  TF_LITE_ENSURE(context, num_input_channels != 0);
+  TF_LITE_ENSURE_EQ(context, num_filter_channels % num_input_channels, 0);
+
   const int input_width = input->dims->data[2];
   const int input_height = input->dims->data[1];
   const int filter_width = filter->dims->data[2];
   const int filter_height = filter->dims->data[1];
-  const int output_width = output->dims->data[2];
-  const int output_height = output->dims->data[1];
+  int output_width = 0;
+  int output_height = 0;
 
   // Dynamically allocate per-channel quantization parameters.
   const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
@@ -207,7 +243,11 @@
 
   TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
       context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, input->type, data));
+      filter_height, &output_width, &output_height, input->type, data));
+
+  // compute output tensor shape and relocate shape data
+  TF_LITE_ENSURE_STATUS(DepthwiseConvReshapeOutputTensor(
+      context, node, input, filter, output, output_height, output_width));
 
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
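
A worked sketch of the dimensionality checks and the output shape produced by DepthwiseConvPrepare above. The values are illustrative (not taken from a specific test) and the code below is not a TFLM API:

```c++
#include <cstdio>

int main() {
  // input  = {1, 8, 8, 2}  (batch, height, width, channels)
  // filter = {1, 3, 3, 6}  (1, height, width, input_channels * depth_multiplier)
  const int input_dims[] = {1, 8, 8, 2};
  const int filter_dims[] = {1, 3, 3, 6};
  // Channel check from DepthwiseConvPrepare: 6 % 2 == 0 (depth_multiplier = 3).
  if (filter_dims[3] % input_dims[3] != 0) return 1;
  // Assuming stride 1, dilation 1, and SAME padding,
  // CalculateOpDataDepthwiseConv yields out_height = out_width = 8, so
  // DepthwiseConvReshapeOutputTensor writes {1, 8, 8, 6}.
  const int output_dims[] = {input_dims[0], 8, 8, filter_dims[3]};
  std::printf("{%d, %d, %d, %d}\n", output_dims[0], output_dims[1],
              output_dims[2], output_dims[3]);
  return 0;
}
```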
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
index b50b40a..4cf090c 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
@@ -1,5 +1,5 @@
 
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -25,8 +25,11 @@
 namespace testing {
 namespace {
 
-// Index of the output tensor in context->tensors, specific to
+// Indices of the tensors in context->tensors, specific to
 // DepthwiseConv.
+constexpr int kInputTensorIndex = 0;
+constexpr int kFilterTensorIndex = 1;
+constexpr int kBiasTensorIndex = 2;
 constexpr int kOutputTensorIndex = 3;
 
 constexpr int kMaxFilterChannels = 64;
@@ -43,9 +46,10 @@
     const T* expected_output_data, int output_length,
     TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size,
     TfLiteTensor* tensors) {
-  int inputs_array_data[] = {3, 0, 1, 2};
+  int inputs_array_data[] = {3, kInputTensorIndex, kFilterTensorIndex,
+                             kBiasTensorIndex};
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
+  int outputs_array_data[] = {1, kOutputTensorIndex};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
   const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D();
@@ -61,6 +65,8 @@
   conv_params->depth_multiplier = depth_mul;
 
   const char* init_data = reinterpret_cast<const char*>(conv_params);
+  // grab pointer to expected tensor shape
+  TfLiteIntArray* expected_out_dims = tensors[kOutputTensorIndex].dims;
 
   // TODO(b/154240825): Use a test macro here which fails and returns.
   TfLiteStatus status = runner.InitAndPrepare(init_data);
@@ -69,12 +75,28 @@
   }
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
 
-  const T* output_data = tflite::GetTensorData<T>(&tensors[kOutputTensorIndex]);
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < expected_out_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_out_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
 
+  const int actual_output_length = ElementCount(*expected_out_dims);
+  TF_LITE_MICRO_EXPECT_EQ(output_length, actual_output_length);
+  TF_LITE_MICRO_CHECK_FAIL();
+  const T* output_data = tflite::GetTensorData<T>(&tensors[kOutputTensorIndex]);
   for (int i = 0; i < output_length; ++i) {
     TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
                               tolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
   }
+
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/micro/kernels/expand_dims.cc b/tensorflow/lite/micro/kernels/expand_dims.cc
index 6bae37b..19c54d7 100644
--- a/tensorflow/lite/micro/kernels/expand_dims.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -47,38 +47,43 @@
   }
 }
 
-// Verifies that the output tensor's dimension shape is equivalent to inserting
+// Rewrite the output tensor's dimension shape so it is equivalent to inserting
 // a dimension of length 1 at the dimension index axis of input's shape as
 // defined in https://www.tensorflow.org/api_docs/python/tf/expand_dims.
-TfLiteStatus VerifyTensorDim(TfLiteContext* context, const TfLiteTensor* input,
-                             const TfLiteTensor* axis_tensor,
-                             const TfLiteTensor* output) {
+TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
+                                 const TfLiteTensor* input,
+                                 const TfLiteTensor* axis_tensor,
+                                 TfLiteTensor* output) {
   int32_t axis_value = 0;
   TF_LITE_ENSURE_OK(context,
                     GetAxisValueFromTensor(context, axis_tensor, &axis_value));
 
-  tflite::RuntimeShape input_shape = tflite::GetTensorShape(input);
+  TfLiteIntArray* input_shape = input->dims;
   if (axis_value < 0) {
-    axis_value = input_shape.DimensionsCount() + 1 + axis_value;
+    axis_value = input_shape->size + 1 + axis_value;
   }
-  TF_LITE_ENSURE(context, axis_value <= input_shape.DimensionsCount());
 
-  // TFLM only supports fixed dimension tensor and assumes that the output shape
-  // is fully specified in the model. As such, TFLM directly use the pointer to
-  // the dimension array in the model buffer.
-  tflite::RuntimeShape output_shape = tflite::GetTensorShape(output);
+  // relocate output tensor dims so they can be updated
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
 
-  TF_LITE_ENSURE(context, output_shape.DimensionsCount() ==
-                              input_shape.DimensionsCount() + 1);
-  for (int i = 0; i < output_shape.DimensionsCount(); ++i) {
+  TfLiteIntArray* output_shape = output->dims;
+  TF_LITE_ENSURE(context, output_shape->size == input_shape->size + 1);
+  TF_LITE_ENSURE(context, axis_value < output_shape->size);
+  TF_LITE_ENSURE(context, axis_value >= 0);
+
+  for (int i = 0; i < output_shape->size; ++i) {
     if (i < axis_value) {
-      TF_LITE_ENSURE(context, output_shape.Dims(i) == input_shape.Dims(i));
+      output_shape->data[i] = input_shape->data[i];
     } else if (i == axis_value) {
-      TF_LITE_ENSURE(context, output_shape.Dims(i) == 1);
+      output_shape->data[i] = 1;
     } else {
-      TF_LITE_ENSURE(context, output_shape.Dims(i) == input_shape.Dims(i - 1));
+      output_shape->data[i] = input_shape->data[i - 1];
     }
   }
+
   return kTfLiteOk;
 }
 
@@ -101,7 +106,8 @@
     MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
     return kTfLiteError;
   }
-  TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
+  TF_LITE_ENSURE_OK(context,
+                    ReshapeOutputTensor(context, node, input, axis, output));
 
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(axis);
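
A small worked sketch of the axis normalization and shape rewrite performed by ReshapeOutputTensor above. The helper below is hypothetical (not a TFLM API); the values mirror ExpandDimsNegativeAxisTest1:

```c++
#include <cstdio>
#include <vector>

std::vector<int> ExpandDimsShape(const std::vector<int>& in, int axis) {
  const int out_rank = static_cast<int>(in.size()) + 1;
  if (axis < 0) axis += out_rank;  // e.g. axis -1 with a 3-D input becomes 3
  std::vector<int> out(out_rank);
  for (int i = 0; i < out_rank; ++i) {
    if (i < axis) {
      out[i] = in[i];
    } else if (i == axis) {
      out[i] = 1;  // the inserted dimension of length 1
    } else {
      out[i] = in[i - 1];
    }
  }
  return out;
}

int main() {
  // {1, 3, 2} with axis -1 -> {1, 3, 2, 1}
  for (int d : ExpandDimsShape({1, 3, 2}, -1)) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}
```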
diff --git a/tensorflow/lite/micro/kernels/expand_dims_test.cc b/tensorflow/lite/micro/kernels/expand_dims_test.cc
index d8e217e..4a49a8e 100644
--- a/tensorflow/lite/micro/kernels/expand_dims_test.cc
+++ b/tensorflow/lite/micro/kernels/expand_dims_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -34,28 +34,24 @@
 constexpr int kInputTensors[] = {2, kDimsTensorIndex, kAxisTensorIndex};
 constexpr int kOutputTensors[] = {1, kOutputTensorIndex};
 
+// Some targets do not support dynamic memory (i.e., no malloc or new), so the
+// tests need to place non-transient memory in static variables. This is safe
+// because tests are guaranteed to run serially.
+// Both structures below are trivially destructible.
+static TFLMRegistration registration;
+static TfLiteTensor tensors[kTensorsSize];
+
 template <typename T>
 micro::KernelRunner CreateExpandDimsKernelRunner(
     int* input_dims, const T* input_data, int* axis_dims,
     const int32_t* axis_data, int* output_dims, T* output_data) {
-  // Some targets do not support dynamic memory (i.e., no malloc or new), thus,
-  // the test need to place non-transitent memories in static variables. This is
-  // safe because tests are guaranteed to run serially.
-  // Both below structures are trivially destructible.
-  static TFLMRegistration registration;
-  static TfLiteTensor tensors[kTensorsSize];
-
   TfLiteIntArray* in_dims = IntArrayFromInts(input_dims);
   TfLiteIntArray* ax_dims = IntArrayFromInts(axis_dims);
   TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
 
-  const int out_dims_size = out_dims->size;
-  const int in_dims_size = in_dims->size;
-  TF_LITE_MICRO_EXPECT_EQ(out_dims_size, (in_dims_size + 1));
-
   tensors[kDimsTensorIndex] = CreateTensor(input_data, in_dims);
   tensors[kAxisTensorIndex] = CreateTensor(axis_data, ax_dims);
-  tensors[kOutputTensorIndex] = CreateTensor(output_data, out_dims, true);
+  tensors[kOutputTensorIndex] = CreateTensor(output_data, out_dims);
 
   TfLiteIntArray* inputs_array =
       IntArrayFromInts(const_cast<int*>(kInputTensors));
@@ -81,9 +77,16 @@
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
 
-  // The output tensor's data have been updated by the kernel.
-  TfLiteIntArray* actual_out_dims = IntArrayFromInts(output_dims);
+  // The output tensor shape has been updated by the kernel.
+  TfLiteIntArray* actual_out_dims = tensors[kOutputTensorIndex].dims;
   const int output_size = ElementCount(*actual_out_dims);
+  TfLiteIntArray* golden_out_dims = IntArrayFromInts(expected_output_dims);
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(golden_out_dims->size, actual_out_dims->size);
+  for (int i = 0; i < golden_out_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(golden_out_dims->data[i], actual_out_dims->data[i]);
+  }
 
   for (int i = 0; i < output_size; ++i) {
     TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
@@ -103,10 +106,9 @@
   const int8_t golden_data[] = {-1, 1, -2, 2};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {0};
-  int golden_dims[] = {1, 2, 2};
   int output_dims[] = {3, 1, 2, 2};
   tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
-                                          axis_data, golden_dims, output_dims,
+                                          axis_data, output_dims, output_dims,
                                           golden_data, output_data);
 }
 
@@ -117,10 +119,9 @@
   const float golden_data[] = {-1.1, 1.2, -2.1, 2.2};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {1};
-  int golden_dims[] = {2, 1, 2};
   int output_dims[] = {3, 2, 1, 2};
   tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
-                                         axis_data, golden_dims, output_dims,
+                                         axis_data, output_dims, output_dims,
                                          golden_data, output_data);
 }
 
@@ -131,10 +132,9 @@
   const int8_t golden_data[] = {-1, 1, -2, 2};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {2};
-  int golden_dims[] = {2, 2, 1};
   int output_dims[] = {3, 2, 2, 1};
   tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
-                                          axis_data, golden_dims, output_dims,
+                                          axis_data, output_dims, output_dims,
                                           golden_data, output_data);
 }
 
@@ -145,10 +145,9 @@
   const int8_t golden_data[] = {-1, 1, 2, -2, 0, 3};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {-4};
-  int golden_dims[] = {1, 3, 1, 2};
   int output_dims[] = {4, 1, 3, 1, 2};
   tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
-                                          axis_data, golden_dims, output_dims,
+                                          axis_data, output_dims, output_dims,
                                           golden_data, output_data);
 }
 
@@ -159,10 +158,9 @@
   const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {-3};
-  int golden_dims[] = {3, 1, 1, 2};
   int output_dims[] = {4, 3, 1, 1, 2};
   tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
-                                         axis_data, golden_dims, output_dims,
+                                         axis_data, output_dims, output_dims,
                                          golden_data, output_data);
 }
 
@@ -173,10 +171,9 @@
   const int8_t golden_data[] = {-1, 1, 2, -2, 0, 3};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {-2};
-  int golden_dims[] = {1, 2, 1, 3};
   int output_dims[] = {4, 1, 2, 1, 3};
   tflite::testing::TestExpandDims<int8_t>(input_dims, input_data, axis_dims,
-                                          axis_data, golden_dims, output_dims,
+                                          axis_data, output_dims, output_dims,
                                           golden_data, output_data);
 }
 
@@ -187,24 +184,40 @@
   const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
   int axis_dims[] = {1, 1};
   const int32_t axis_data[] = {-1};
-  int golden_dims[] = {1, 3, 2, 1};
   int output_dims[] = {4, 1, 3, 2, 1};
   tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
+                                         axis_data, output_dims, output_dims,
+                                         golden_data, output_data);
+}
+
+TF_LITE_MICRO_TEST(ExpandDimsInputOutputDimsMismatch) {
+  float output_data[6];
+  int input_dims[] = {3, 1, 3, 2};
+  const float input_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
+  const float golden_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
+  int axis_dims[] = {1, 1};
+  const int32_t axis_data[] = {-1};
+  // When the input shape is [1, 3, 2] and the axis is -1, the output shape
+  // should be [1, 3, 2, 1] as in the test case ExpandDimsNegativeAxisTest1.
+  // Shuffle the output dimensions to make them incorrect, verifying that the
+  // kernel computes the output tensor dimensions itself.
+  int golden_dims[] = {4, 1, 3, 2, 1};
+  int output_dims[] = {4, 1, 3, 1, 2};
+
+  tflite::testing::TestExpandDims<float>(input_dims, input_data, axis_dims,
                                          axis_data, golden_dims, output_dims,
                                          golden_data, output_data);
 }
 
-TF_LITE_MICRO_TEST(ExpandDimsInputOutputDimsMismatchShallFail) {
-  float output_data[6];
+TF_LITE_MICRO_TEST(ExpandDimsAxisPositiveOutOfRangeShallFailTest) {
+  int8_t output_data[6];
   int input_dims[] = {3, 1, 3, 2};
-  const float input_data[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
+  const int8_t input_data[] = {1, 8, 2, 5, 9, 3};
   int axis_dims[] = {1, 1};
-  const int32_t axis_data[] = {-1};
-  // When input dimension is [1, 3, 2] and the axis is -1, the output dimension
-  // should be [1, 3, 2, 1] as in the test case ExpandDimsNegativeAxisTest1.
-  // Shuffle the output dimension to make it incorrect so that the EXPAND_DIMS
-  // op would fail at prepare.
-  int output_dims[] = {4, 1, 3, 1, 2};
+  // The input is 3-D, so the axis value must not exceed 3.
+  // The axis value 4 below shall lead to failure at prepare.
+  const int32_t axis_data[] = {4};
+  int output_dims[] = {4, 1, 3, 2, 1};
 
   tflite::micro::KernelRunner runner =
       tflite::testing::CreateExpandDimsKernelRunner(input_dims, input_data,
@@ -214,14 +227,14 @@
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, runner.InitAndPrepare());
 }
 
-TF_LITE_MICRO_TEST(ExpandDimsAxisOutOfRangeShallFail) {
+TF_LITE_MICRO_TEST(ExpandDimsAxisNegativeOutOfRangeShallFailTest) {
   int8_t output_data[6];
   int input_dims[] = {3, 1, 3, 2};
   const int8_t input_data[] = {1, 8, 2, 5, 9, 3};
   int axis_dims[] = {1, 1};
-  // The input dimension is 3-D, so that axis value should not exceed 3.
-  // The below axis value 4 shall lead to failure at prepare.
-  const int32_t axis_data[] = {4};
+  // The input is 3-D, so a negative axis value must not be less than -4.
+  // The axis value -5 below shall lead to failure at prepare.
+  const int32_t axis_data[] = {-5};
   int output_dims[] = {4, 1, 3, 2, 1};
 
   tflite::micro::KernelRunner runner =
diff --git a/tensorflow/lite/micro/kernels/reshape.h b/tensorflow/lite/micro/kernels/reshape.h
index 02bda32..15e5f61 100644
--- a/tensorflow/lite/micro/kernels/reshape.h
+++ b/tensorflow/lite/micro/kernels/reshape.h
@@ -19,6 +19,7 @@
 namespace tflite {
 
 constexpr int kReshapeInputTensor = 0;
+constexpr int kReshapeShapeTensor = 1;
 constexpr int kReshapeOutputTensor = 0;
 
 TfLiteStatus PrepareReshapeReference(TfLiteContext* context, TfLiteNode* node);
diff --git a/tensorflow/lite/micro/kernels/reshape_common.cc b/tensorflow/lite/micro/kernels/reshape_common.cc
index b86e2be..91fd1bc 100644
--- a/tensorflow/lite/micro/kernels/reshape_common.cc
+++ b/tensorflow/lite/micro/kernels/reshape_common.cc
@@ -35,6 +35,9 @@
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kReshapeInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
+  // The shape tensor is optional
+  TfLiteTensor* new_shape =
+      micro_context->AllocateTempInputTensor(node, kReshapeShapeTensor);
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kReshapeOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
@@ -43,20 +46,35 @@
   // input. Here we calculate what that dimension should be so that the number
   // of output elements in the same as the number of input elements.
   int num_input_elements = NumElements(input);
-  TfLiteIntArray* output_shape = output->dims;
+
+  int output_shape_size = 0;
+  int* output_shape_data = nullptr;
+  if (new_shape != nullptr && new_shape->dims->size > 0) {
+    // use new shape tensor data
+    TF_LITE_ENSURE_EQ(context, new_shape->type, kTfLiteInt32);
+    TF_LITE_ENSURE_EQ(context, new_shape->dims->size, 1);
+    output_shape_data = GetTensorData<int>(new_shape);
+    output_shape_size = new_shape->dims->data[0];
+    TF_LITE_ENSURE_EQ(context, output_shape_size,
+                      static_cast<int32_t>(output->dims->size));
+  } else {
+    // use output shape
+    output_shape_size = output->dims->size;
+    output_shape_data = output->dims->data;
+  }
 
   if (NumInputs(node) == 1 &&  // Legacy scalar supported with params.
-      output_shape->size == 1 && output_shape->data[0] == 0) {
+      output_shape_size == 1 && output_shape_data[0] == 0) {
     // Legacy tflite models use a shape parameter of [0] to indicate scalars,
     // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
     // toco conversion.
-    output_shape->size = 0;
+    output_shape_size = 0;
   }
 
   int num_output_elements = 1;
   int stretch_dim = -1;
-  for (int i = 0; i < output_shape->size; ++i) {
-    int value = output_shape->data[i];
+  for (int i = 0; i < output_shape_size; ++i) {
+    int value = output_shape_data[i];
     if (value == -1) {
       TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
       stretch_dim = i;
@@ -64,20 +82,26 @@
       num_output_elements *= value;
     }
   }
-  if (stretch_dim != -1) {
+  if (stretch_dim != -1 || output_shape_size == 0) {
     TfLiteEvalTensor* output_eval =
         tflite::micro::GetEvalOutput(context, node, kReshapeOutputTensor);
     TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
         context, output, output_eval));
-    output_shape = output->dims;  // output tensor dims were moved
-    output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
-    num_output_elements *= output_shape->data[stretch_dim];
+    if (stretch_dim != -1) {
+      output->dims->data[stretch_dim] =
+          num_input_elements / num_output_elements;
+      num_output_elements *= output->dims->data[stretch_dim];
+    }
+    output->dims->size = output_shape_size;
   }
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
 
   micro_context->DeallocateTempTfLiteTensor(input);
+  if (new_shape != nullptr) {
+    micro_context->DeallocateTempTfLiteTensor(new_shape);
+  }
   micro_context->DeallocateTempTfLiteTensor(output);
   return kTfLiteOk;
 }
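
For reference, a minimal sketch of the stretch-dimension (-1) arithmetic used above; the shape and element count are illustrative and nothing here is a TFLM API:

```c++
#include <cstdio>

int main() {
  // Requested shape {2, -1, 4} for an input with 32 elements.
  int output_shape[] = {2, -1, 4};
  const int num_input_elements = 32;
  int num_output_elements = 1;
  int stretch_dim = -1;
  for (int i = 0; i < 3; ++i) {
    if (output_shape[i] == -1) {
      stretch_dim = i;  // at most one -1 is allowed
    } else {
      num_output_elements *= output_shape[i];
    }
  }
  if (stretch_dim != -1) {
    output_shape[stretch_dim] = num_input_elements / num_output_elements;  // 4
    num_output_elements *= output_shape[stretch_dim];
  }
  // As in PrepareReshapeReference, input and output element counts must match.
  std::printf("{%d, %d, %d}, elements = %d\n", output_shape[0], output_shape[1],
              output_shape[2], num_output_elements);
  return 0;
}
```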
diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc
index d78d9fa..d97007a 100644
--- a/tensorflow/lite/micro/kernels/reshape_test.cc
+++ b/tensorflow/lite/micro/kernels/reshape_test.cc
@@ -265,7 +265,7 @@
       golden_dims_len, false);
 }
 
-// Stretch is not supported with TF Micro
+// Stretch is supported by TF Micro
 TF_LITE_MICRO_TEST(ReshapeWithStretchDimensionShouldSucceed) {
   float output_data_float[32];
   int8_t output_data_int8[32];
@@ -347,7 +347,7 @@
   const float input_data[] = {3.0f};
   auto input_tensor = CreateTensor(input_data, input_dims);
 
-  float output_data[1];
+  float output_data[] = {0.0f};
   int output_dims_data[2] = {1, 0};
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
   auto output_tensor = CreateTensor(output_data, output_dims);
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
index f8df149..e5e86d1 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,7 +15,10 @@
 
 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
 
+#include <algorithm>
+
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -24,12 +27,11 @@
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
-
 namespace {
 
 constexpr int kInputTensor = 0;
 constexpr int kBlockShapeTensor = 1;
-constexpr int kCropsTensor = 2;
+constexpr int kPaddingTensor = 2;
 constexpr int kOutputTensor = 0;
 
 // Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
@@ -45,6 +47,68 @@
   return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
 }
 
+TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
+                                 const TfLiteTensor* input,
+                                 const TfLiteTensor* block_shape,
+                                 const TfLiteTensor* padding,
+                                 TfLiteTensor* output) {
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
+  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(padding));
+  const int32_t* block_shape_data = GetTensorData<int32_t>(block_shape);
+  const int32_t* padding_data = GetTensorData<int32_t>(padding);
+
+  TfLiteIntArray* input_dims = input->dims;
+  int spatial_dims_num = input_dims->size - 2;
+  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
+  TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], spatial_dims_num);
+  // Padding should be a 2D tensor with dimension [spatial_dims_num, 2].
+  TF_LITE_ENSURE_EQ(context, NumDimensions(padding), 2);
+  TF_LITE_ENSURE_EQ(context, padding->dims->data[0], spatial_dims_num);
+  TF_LITE_ENSURE_EQ(context, padding->dims->data[1], 2);
+
+  // copy from input tensor as per TfLite code
+  TF_LITE_ENSURE_EQ(context, input_dims->size, output->dims->size);
+  RuntimeShape output_shape = GetTensorShape(input);
+  // keep a copy of the output tensor shape for later comparison
+  RuntimeShape old_output_shape = GetTensorShape(output);
+
+  // Ensures the input height and width (with padding) is a multiple of block
+  // shape height and width.
+  int output_batch_size = input_dims->data[0];
+  for (int dim = 0; dim < spatial_dims_num; ++dim) {
+    int final_dim_size = (input_dims->data[dim + 1] + padding_data[dim * 2] +
+                          padding_data[dim * 2 + 1]);
+    TF_LITE_ENSURE(context, block_shape_data[dim] != 0);
+    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape_data[dim], 0);
+    output_shape.SetDim(dim + 1, final_dim_size / block_shape_data[dim]);
+    output_batch_size *= block_shape_data[dim];
+  }
+  output_shape.SetDim(0, output_batch_size);
+  output_shape.SetDim(input_dims->size - 1,
+                      input_dims->data[input_dims->size - 1]);
+
+  // Check whether the output tensor dims need to be relocated.
+  if (output_shape == old_output_shape) {
+    return kTfLiteOk;
+  } else if (output_shape.FlatSize() > old_output_shape.FlatSize() &&
+             output->data.data != nullptr) {
+    MicroPrintf(
+        "SPACE_TO_BATCH_ND: resizing flatbuffer tensor data is not supported");
+    return kTfLiteError;
+  }
+
+  // set the output tensor dims from output_shape
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
+      context, output, output_eval));
+  std::copy_n(output_shape.DimsData(), output_shape.DimensionsCount(),
+              output->dims->data);
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus SpaceToBatchNDPrepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
 
@@ -53,19 +117,47 @@
 
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* block_shape =
+      micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
+  TF_LITE_ENSURE(context, block_shape != nullptr);
+  TfLiteTensor* padding =
+      micro_context->AllocateTempInputTensor(node, kPaddingTensor);
+  TF_LITE_ENSURE(context, padding != nullptr);
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kOutputTensor);
-  TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  SpaceToBatchParams& params =
+      *(static_cast<SpaceToBatchParams*>(node->user_data));
+
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
+    TF_LITE_ENSURE(context,
+                   input->params.zero_point == output->params.zero_point);
+    params.output_offset = output->params.zero_point;
+  } else {
+    params.output_offset = 0;
+  }
+
+  TfLiteStatus status =
+      ReshapeOutputTensor(context, node, input, block_shape, padding, output);
 
   micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(block_shape);
+  micro_context->DeallocateTempTfLiteTensor(padding);
   micro_context->DeallocateTempTfLiteTensor(output);
-  return kTfLiteOk;
+
+  return status;
 }
 
 TfLiteStatus SpaceToBatchNDEval(TfLiteContext* context, TfLiteNode* node) {
@@ -77,8 +169,8 @@
       tflite::micro::GetEvalInput(context, node, kInputTensor);
   const TfLiteEvalTensor* block_shape =
       tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
-  const TfLiteEvalTensor* crops =
-      tflite::micro::GetEvalInput(context, node, kCropsTensor);
+  const TfLiteEvalTensor* padding =
+      tflite::micro::GetEvalInput(context, node, kPaddingTensor);
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
 
@@ -89,8 +181,8 @@
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(crops),
-          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(padding),
+          tflite::micro::GetTensorData<int32_t>(padding),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -100,8 +192,8 @@
           tflite::micro::GetTensorData<int8_t>(input),
           tflite::micro::GetTensorShape(block_shape),
           tflite::micro::GetTensorData<int32_t>(block_shape),
-          tflite::micro::GetTensorShape(crops),
-          tflite::micro::GetTensorData<int32_t>(crops),
+          tflite::micro::GetTensorShape(padding),
+          tflite::micro::GetTensorData<int32_t>(padding),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
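
A worked sketch of the output-shape arithmetic in ReshapeOutputTensor above, using the same values as the SimplePaddingConstTest case in the test file below (illustrative only, not a TFLM API):

```c++
#include <cstdio>

int main() {
  // input = {1, 5, 2, 1}, block_shape = {3, 2}, padding = {{1, 0}, {2, 0}}
  const int input_dims[] = {1, 5, 2, 1};
  const int block_shape[] = {3, 2};
  const int padding[] = {1, 0, 2, 0};
  int output_dims[] = {0, 0, 0, input_dims[3]};
  int output_batch = input_dims[0];
  for (int dim = 0; dim < 2; ++dim) {
    // Padded spatial size must divide evenly by the block size (checked above).
    const int padded =
        input_dims[dim + 1] + padding[dim * 2] + padding[dim * 2 + 1];
    output_dims[dim + 1] = padded / block_shape[dim];
    output_batch *= block_shape[dim];
  }
  output_dims[0] = output_batch;
  // Expected: {6, 2, 2, 1}
  std::printf("{%d, %d, %d, %d}\n", output_dims[0], output_dims[1],
              output_dims[2], output_dims[3]);
  return 0;
}
```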
diff --git a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
index eae185b..d45a606 100644
--- a/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/lite/micro/kernels/space_to_batch_nd_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include <cstdint>
+#include <limits>
+#include <type_traits>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -25,98 +27,160 @@
 namespace testing {
 namespace {
 
-constexpr int kBasicInputOutputSize = 16;
-int basic_input_dims[] = {4, 1, 4, 4, 1};
-const float basic_input[kBasicInputOutputSize] = {
-    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
-int basic_block_shape_dims[] = {1, 2};
-const int32_t basic_block_shape[] = {2, 2};
-int basic_crops_dims[] = {1, 4};
-const int32_t basic_crops[] = {0, 0, 0, 0};
-int basic_output_dims[] = {4, 4, 2, 2, 1};
-const float basic_golden[kBasicInputOutputSize] = {1, 3, 9,  11, 2, 4, 10, 12,
-                                                   5, 7, 13, 15, 6, 8, 14, 16};
+constexpr float kTestTolerance = 1e-05;
+constexpr int kNumInputs = 3;
+constexpr int kNumOutputs = 1;
+constexpr int kInputTensorIndex = 0;
+constexpr int kBlockShapeTensorIndex = 1;
+constexpr int kPaddingTensorIndex = 2;
+constexpr int kOutputTensorIndex = 3;
 
-template <typename T>
-TfLiteStatus ValidateSpaceToBatchNdGoldens(TfLiteTensor* tensors,
-                                           int tensors_size, const T* golden,
-                                           T* output, int output_size) {
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+// The min/max values are used to compute the scale and zero-point (asymmetric).
+template <typename T, size_t kInputSize, size_t kOutputSize>
+struct TestQuantParams {
+  // quantization parameters
+  float data_min;              // input data minimum value
+  float data_max;              // input data maximum value
+  T output_data[kOutputSize];  // quantized output storage
+  T input_data[kInputSize];    // quantized input storage
+};
 
-  const TFLMRegistration registration = Register_SPACE_TO_BATCH_ND();
-  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
+TfLiteStatus ExecuteSpaceToBatchNdTest(TfLiteTensor* tensors,
+                                       int tensors_count) {
+  int kInputArrayData[] = {kNumInputs, kInputTensorIndex,
+                           kBlockShapeTensorIndex, kPaddingTensorIndex};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
+  int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
+
+  const TFLMRegistration registration = tflite::Register_SPACE_TO_BATCH_ND();
+  micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
                              outputs_array, nullptr);
 
-  TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
-  TF_LITE_ENSURE_STATUS(runner.Invoke());
-
-  for (int i = 0; i < output_size; ++i) {
-    // TODO(b/158102673): workaround for not having fatal test assertions.
-    TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
-    if (golden[i] != output[i]) {
-      return kTfLiteError;
-    }
+  TfLiteStatus status = runner.InitAndPrepare();
+  if (status != kTfLiteOk) {
+    return status;
   }
-  return kTfLiteOk;
+  status = runner.Invoke();
+
+  return status;
 }
 
 TfLiteStatus TestSpaceToBatchNdFloat(
-    int* input_dims_data, const float* input_data, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    float* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+    int* input_dims_data[kNumInputs], const float* input_data,
+    const int32_t* block_shape_data, const int32_t* padding_data,
+    int* output_dims_data, const float* golden_data, float* output_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* padding_dims =
+      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateTensor(input_data, input_dims),
-      CreateTensor(block_shape_data, block_shape_dims),
-      CreateTensor(crops_data, crops_dims),
-      CreateTensor(output_data, output_dims),
-  };
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] =
+      tflite::testing::CreateTensor(input_data, input_dims);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kPaddingTensorIndex] =
+      tflite::testing::CreateTensor(padding_data, padding_dims);
+  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] =
+      tflite::testing::CreateTensor(output_data, output_dims);
 
-  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden,
-                                       output_data, ElementCount(*output_dims));
+  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against golden
+  const int output_count = ElementCount(*output_dims);
+  for (int i = 0; i < output_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], kTestTolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
 }
 
-template <typename T>
+template <typename T, size_t kInCount, size_t kOutCount>
 TfLiteStatus TestSpaceToBatchNdQuantized(
-    int* input_dims_data, const float* input_data, T* input_quantized,
-    float input_scale, int input_zero_point, int* block_shape_dims_data,
-    const int32_t* block_shape_data, int* crops_dims_data,
-    const int32_t* crops_data, int* output_dims_data, const float* golden,
-    T* golden_quantized, float output_scale, int output_zero_point,
-    T* output_data) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
-  TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
+    TestQuantParams<T, kInCount, kOutCount>& params,
+    int* input_dims_data[kNumInputs], const float* input_data,
+    const int32_t* block_shape_data, const int32_t* padding_data,
+    int* output_dims_data, const float* golden_data) {
+  TfLiteIntArray* input_dims =
+      IntArrayFromInts(input_dims_data[kInputTensorIndex]);
+  TfLiteIntArray* block_shape_dims =
+      IntArrayFromInts(input_dims_data[kBlockShapeTensorIndex]);
+  TfLiteIntArray* padding_dims =
+      IntArrayFromInts(input_dims_data[kPaddingTensorIndex]);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
 
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
-                                             input_dims, input_scale,
-                                             input_zero_point),
-      tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
-      tflite::testing::CreateTensor(crops_data, crops_dims),
-      tflite::testing::CreateQuantizedTensor(output_data, output_dims,
-                                             output_scale, output_zero_point),
-  };
-  tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
-                   output_scale, output_zero_point);
+  int zero_point =
+      tflite::testing::ZeroPointFromMinMax<T>(params.data_min, params.data_max);
+  float scale =
+      tflite::testing::ScaleFromMinMax<T>(params.data_min, params.data_max);
 
-  return ValidateSpaceToBatchNdGoldens(tensors, tensors_size, golden_quantized,
-                                       output_data, ElementCount(*output_dims));
+  constexpr int kTensorsCount = kNumInputs + kNumOutputs;
+  TfLiteTensor tensors[kTensorsCount];
+  tensors[kInputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      input_data, params.input_data, input_dims, scale, zero_point);
+  tensors[kBlockShapeTensorIndex] =
+      tflite::testing::CreateTensor(block_shape_data, block_shape_dims);
+  tensors[kBlockShapeTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kPaddingTensorIndex] =
+      tflite::testing::CreateTensor(padding_data, padding_dims);
+  tensors[kPaddingTensorIndex].allocation_type = kTfLiteMmapRo;
+  tensors[kOutputTensorIndex] = tflite::testing::CreateQuantizedTensor(
+      params.output_data, output_dims, scale, zero_point);
+
+  TfLiteStatus status = ExecuteSpaceToBatchNdTest(tensors, kTensorsCount);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+
+  // check output dimensions (relocated) against original dimensions
+  TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
+                          tensors[kOutputTensorIndex].dims->size);
+  TF_LITE_MICRO_CHECK_FAIL();
+  for (int i = 0; i < output_dims->size; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
+                            tensors[kOutputTensorIndex].dims->data[i]);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  // check output data against golden
+  const int output_count = ElementCount(*output_dims);
+  const float quantization_tolerance =
+      (params.data_max - params.data_min) /
+      (std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
+  for (int i = 0; i < output_count; i++) {
+    float output_dequantized_data =
+        (params.output_data[i] - zero_point) * scale;
+    TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_dequantized_data,
+                              quantization_tolerance);
+    // TODO(b/158102673): workaround for not having fatal test assertions.
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
 }
 
 }  // namespace
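
The quantized tests above derive the scale and zero-point from the min/max in TestQuantParams and use one quantization step as the comparison tolerance. A minimal sketch of that asymmetric int8 math, assuming typical rounding (the exact behavior of ScaleFromMinMax/ZeroPointFromMinMax may differ slightly):

```c++
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float data_min = -1.0f;
  const float data_max = 1.0f;
  // 256 representable int8 values -> 255 steps between data_min and data_max.
  const float scale = (data_max - data_min) / 255.0f;
  const int zero_point =
      static_cast<int>(std::lround(-128.0f - data_min / scale));
  // Quantize then dequantize one value; the round-trip error stays within one
  // step, which is the quantization_tolerance used by the tests.
  const float value = 0.6f;
  const int8_t quantized =
      static_cast<int8_t>(std::lround(value / scale) + zero_point);
  const float dequantized = (quantized - zero_point) * scale;
  std::printf("scale=%f zero_point=%d dequantized=%f\n", scale, zero_point,
              dequantized);
  return 0;
}
```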
@@ -125,30 +189,313 @@
 
 TF_LITE_MICRO_TESTS_BEGIN
 
-TF_LITE_MICRO_TEST(SpaceToBatchBasicFloat) {
-  float output[tflite::testing::kBasicInputOutputSize];
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestSpaceToBatchNdFloat(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, output));
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidShapeTest) {
+  int kInputDims[] = {4, 1, 3, 3, 1};  // invalid shape
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 0, 0, 0, 0};  // placeholder dims, not used
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {0};  // placeholder data, not used
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
 }
 
-TF_LITE_MICRO_TEST(SpaceToBatchBasicInt8) {
-  int8_t output[tflite::testing::kBasicInputOutputSize];
-  int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
-  int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestInvalidOutputShapeTest) {
+  int kInputDims[] = {3, 1, 8, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 1, 6, 1};  // invalid shape
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 0, 0,  // placeholder data, not used
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestValidOutputShapeTest) {
+  int kInputDims[] = {3, 1, 8, 1};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 6, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 1, 1, 1, 1, 1, 1, 1};
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0};
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimpleConstTest) {
+  int kInputDims[] = {4, 1, 4, 4, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 4, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {
+      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestMultipleInputBatchesConstTest) {
+  int kInputDims[] = {4, 2, 2, 4, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 8, 1, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2, 2};
+  constexpr int32_t kPadding[] = {0, 0, 0, 0};
+  constexpr float kGolden[] = {
+      1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimplePaddingConstTest) {
+  int kInputDims[] = {4, 1, 5, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 0, 2, 0};
+  constexpr float kGolden[] = {
+      0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestComplexPaddingConstTest) {
+  int kInputDims[] = {4, 1, 4, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 1, 2, 4};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 7, 0, 0,
+      0, 2, 0, 0, 0, 8, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestSimplePaddingConstTestInt8) {
+  int kInputDims[] = {4, 1, 5, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 2, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1,
+  };
+  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 0, 2, 0};
+  constexpr float kGolden[] = {
+      0, 0,   0, -0.5, 0, 0,    0, 0.6,  0, -0.1, 0, -0.7,
+      0, 0.2, 0, 0.8,  0, -0.3, 0, -0.9, 0, 0.4,  0, 0.1,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+
+  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
+      -1,
+      std::numeric_limits<int8_t>::max() /
+          static_cast<float>(std::numeric_limits<int8_t>::max() + 1),
+      {},
+      {}};
+
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      tflite::testing::TestSpaceToBatchNdQuantized(
-          tflite::testing::basic_input_dims, tflite::testing::basic_input,
-          input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
-          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
-          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
-          tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
+      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
+                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
+                     kOutputDims, kGolden));
+}
+
+TF_LITE_MICRO_TEST(QuantizedSpaceToBatchNDOpTestComplexPaddingConstTestInt8) {
+  int kInputDims[] = {4, 1, 4, 2, 1};
+  int kBlockShapeDims[] = {1, 2};
+  int kPaddingDims[] = {2, 2, 2};
+  int kOutputDims[] = {4, 6, 2, 4, 1};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      -0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8,
+  };
+  constexpr int kInputCount = std::extent<decltype(kInput)>::value;
+  constexpr int32_t kBlockShape[] = {3, 2};
+  constexpr int32_t kPadding[] = {1, 1, 2, 4};
+  constexpr float kGolden[] = {
+      0, 0,    0, 0, 0, -0.5, 0, 0, 0, 0,   0, 0, 0, 0.6, 0, 0,
+      0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+      0, -0.3, 0, 0, 0, 0,    0, 0, 0, 0.4, 0, 0, 0, 0,   0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+
+  tflite::testing::TestQuantParams<int8_t, kInputCount, kOutputCount> params = {
+      -1, 1, {}, {}};
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, tflite::testing::TestSpaceToBatchNdQuantized(
+                     params, kInputDimsArray, kInput, kBlockShape, kPadding,
+                     kOutputDims, kGolden));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DConstTest) {
+  int kInputDims[] = {3, 1, 4, 4};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 2, 4};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {0, 0};
+  constexpr float kGolden[] = {
+      1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
+}
+
+TF_LITE_MICRO_TEST(SpaceToBatchNDOpTestSimple3DPaddingConstTest) {
+  int kInputDims[] = {3, 1, 4, 4};
+  int kBlockShapeDims[] = {1, 1};
+  int kPaddingDims[] = {2, 1, 2};
+  int kOutputDims[] = {3, 2, 4, 4};
+
+  int* kInputDimsArray[tflite::testing::kNumInputs];
+  kInputDimsArray[tflite::testing::kInputTensorIndex] = kInputDims;
+  kInputDimsArray[tflite::testing::kBlockShapeTensorIndex] = kBlockShapeDims;
+  kInputDimsArray[tflite::testing::kPaddingTensorIndex] = kPaddingDims;
+
+  constexpr float kInput[] = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+  };
+  constexpr int32_t kBlockShape[] = {2};
+  constexpr int32_t kPadding[] = {2, 2};
+  constexpr float kGolden[] = {
+      0, 0, 0, 0, 1, 2, 3, 4, 9,  10, 11, 12, 0, 0, 0, 0,
+      0, 0, 0, 0, 5, 6, 7, 8, 13, 14, 15, 16, 0, 0, 0, 0,
+  };
+  constexpr int kOutputCount = std::extent<decltype(kGolden)>::value;
+  float output_data[kOutputCount];
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
+                          tflite::testing::TestSpaceToBatchNdFloat(
+                              kInputDimsArray, kInput, kBlockShape, kPadding,
+                              kOutputDims, kGolden, output_data));
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc b/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
index 094aab6..6ab2e9c 100644
--- a/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
+++ b/tensorflow/lite/micro/kernels/testdata/conv_test_data.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -303,7 +303,7 @@
                                         55295,    184082, 75855,  233991};
 
 // Kernel Conv Test Case: Int8Filter8x3x3x3PerChannelScaleRelu6ShouldMatchGolden
-const int8_t kConvGoldenOutput1x16x16x8[1 * 16 * 16 * 8] = {
+const int8_t kConvGoldenOutput1x15x15x8[1 * 15 * 15 * 8] = {
     -128, -21,  -81,  67,  -20,  -109, -29,  4,   -128, -19,  -81,  68,
     -19,  -109, -31,  3,   -128, -19,  -80,  68,  -20,  -109, -32,  2,
     -128, -19,  -80,  68,  -20,  -109, -32,  1,   -128, -19,  -80,  68,
@@ -314,28 +314,26 @@
     -36,  -108, -41,  0,   -128, -17,  -78,  69,  -20,  -108, -37,  -3,
     -128, -18,  -77,  68,  -21,  -107, -37,  -3,  -128, -18,  -77,  69,
     -20,  -107, -38,  -4,  -128, -18,  -77,  69,  -22,  -107, -38,  -4,
-    -128, -19,  -82,  69,  -18,  -110, -31,  -4,  -128, -20,  -81,  67,
-    -19,  -109, -30,  3,   -128, -19,  -81,  68,  -19,  -109, -31,  2,
-    -128, -19,  -80,  68,  -20,  -109, -31,  2,   -128, -19,  -80,  68,
-    -20,  -109, -33,  1,   -128, -20,  -79,  68,  -19,  -108, -33,  1,
-    -128, -19,  -88,  67,  -19,  -113, -34,  0,   -128, -19,  -99,  66,
-    -13,  -118, -1,   26,  -128, -19,  -120, 66,  32,   -128, 2,    64,
-    -128, -20,  -124, 67,  8,    -128, 13,   76,  -128, -19,  -98,  68,
-    1,    -118, -17,  31,  -128, -18,  -89,  67,  -12,  -113, -33,  25,
-    -128, -17,  -76,  69,  -22,  -107, -37,  -5,  -128, -17,  -77,  68,
-    -22,  -107, -38,  -4,  -128, -17,  -77,  69,  -22,  -107, -38,  -5,
-    -128, -18,  -76,  69,  -23,  -107, -38,  -5,  -128, -19,  -82,  68,
-    -18,  -110, -31,  -5,  -128, -19,  -81,  68,  -20,  -109, -31,  2,
-    -128, -19,  -80,  68,  -21,  -109, -32,  1,   -128, -20,  -79,  67,
-    -20,  -109, -32,  1,   -128, -21,  -79,  67,  -20,  -108, -32,  0,
-    -128, -20,  -79,  67,  -21,  -108, -33,  -1,  -128, -20,  -86,  67,
-    -20,  -113, -12,  -1,  -128, -21,  -93,  66,  -15,  -115, -10,  21,
-    -128, -23,  -113, 65,  -16,  -126, 77,   46,  -128, -23,  -120, 65,
-    -9,   -128, 74,   71,  -128, -26,  -97,  60,  -26,  -118, -12,  29,
-    -128, -27,  -89,  61,  -11,  -114, 1,    28,  -128, -32,  -77,  58,
-    -15,  -108, -42,  -2,  -128, -30,  -78,  58,  -16,  -108, -40,  0,
-    -128, -26,  -77,  60,  -18,  -108, -40,  -2,  -128, -20,  -77,  66,
-    -24,  -107, -41,  -4,  -128, -19,  -82,  68,  -17,  -110, -32,  -5,
+    -128, -20,  -81,  67,  -19,  -109, -30,  3,   -128, -19,  -81,  68,
+    -19,  -109, -31,  2,   -128, -19,  -80,  68,  -20,  -109, -31,  2,
+    -128, -19,  -80,  68,  -20,  -109, -33,  1,   -128, -20,  -79,  68,
+    -19,  -108, -33,  1,   -128, -19,  -88,  67,  -19,  -113, -34,  0,
+    -128, -19,  -99,  66,  -13,  -118, -1,   26,  -128, -19,  -120, 66,
+    32,   -128, 2,    64,  -128, -20,  -124, 67,  8,    -128, 13,   76,
+    -128, -19,  -98,  68,  1,    -118, -17,  31,  -128, -18,  -89,  67,
+    -12,  -113, -33,  25,  -128, -17,  -76,  69,  -22,  -107, -37,  -5,
+    -128, -17,  -77,  68,  -22,  -107, -38,  -4,  -128, -17,  -77,  69,
+    -22,  -107, -38,  -5,  -128, -18,  -76,  69,  -23,  -107, -38,  -5,
+    -128, -19,  -81,  68,  -20,  -109, -31,  2,   -128, -19,  -80,  68,
+    -21,  -109, -32,  1,   -128, -20,  -79,  67,  -20,  -109, -32,  1,
+    -128, -21,  -79,  67,  -20,  -108, -32,  0,   -128, -20,  -79,  67,
+    -21,  -108, -33,  -1,  -128, -20,  -86,  67,  -20,  -113, -12,  -1,
+    -128, -21,  -93,  66,  -15,  -115, -10,  21,  -128, -23,  -113, 65,
+    -16,  -126, 77,   46,  -128, -23,  -120, 65,  -9,   -128, 74,   71,
+    -128, -26,  -97,  60,  -26,  -118, -12,  29,  -128, -27,  -89,  61,
+    -11,  -114, 1,    28,  -128, -32,  -77,  58,  -15,  -108, -42,  -2,
+    -128, -30,  -78,  58,  -16,  -108, -40,  0,   -128, -26,  -77,  60,
+    -18,  -108, -40,  -2,  -128, -20,  -77,  66,  -24,  -107, -41,  -4,
     -128, -20,  -81,  67,  -20,  -109, -32,  3,   -128, -20,  -82,  67,
     -14,  -110, -34,  0,   -128, -20,  -90,  67,  -62,  -113, -67,  16,
     -128, -24,  -80,  64,  -25,  -109, -38,  1,   -128, -25,  -80,  61,
@@ -346,28 +344,26 @@
     -8,   -112, -19,  16,  -128, -42,  -83,  52,  -26,  -111, -42,  9,
     -128, -40,  -82,  52,  -20,  -111, -38,  8,   -128, -38,  -82,  52,
     -18,  -111, -38,  7,   -128, -29,  -81,  54,  -22,  -110, -42,  9,
-    -128, -19,  -81,  68,  -18,  -110, -32,  -7,  -128, -31,  -80,  58,
-    -17,  -109, -33,  1,   -128, -24,  -107, 60,  0,    -123, -46,  32,
-    -128, -23,  -121, 66,  -27,  -128, 24,   76,  -128, -34,  -92,  57,
-    -9,   -115, -39,  32,  -128, -35,  -86,  53,  -16,  -113, -29,  10,
-    -128, -30,  -97,  58,  -25,  -117, -17,  24,  -128, -27,  -97,  62,
-    -46,  -117, -15,  28,  -128, -26,  -91,  66,  -4,   -114, -44,  23,
-    -128, -24,  -87,  68,  15,   -112, -71,  17,  -128, -25,  -85,  66,
-    10,   -111, -71,  13,  -128, -29,  -83,  63,  -38,  -110, -39,  5,
-    -128, -36,  -82,  59,  -88,  -110, -14,  -2,  -128, -41,  -90,  52,
-    25,   -115, -9,   19,  -128, -43,  -94,  47,  1,    -117, -21,  34,
-    -128, -32,  -84,  52,  -13,  -111, -16,  15,  -128, -24,  -83,  63,
-    -20,  -111, -39,  -1,  -128, -38,  -85,  54,  -23,  -112, -36,  6,
-    -128, -32,  -116, 54,  -12,  -128, 61,   50,  -128, -24,  -127, 63,
-    -12,  -128, 68,   87,  -128, -24,  -106, 63,  -47,  -122, -6,   44,
-    -128, -22,  -102, 63,  8,    -120, -17,  38,  -128, -20,  -99,  66,
-    6,    -118, 30,   38,  -128, -20,  -90,  66,  -8,   -114, 35,   21,
-    -128, -20,  -89,  66,  -16,  -114, 18,   17,  -128, -20,  -89,  65,
-    -15,  -113, 14,   25,  -128, -18,  -84,  66,  -14,  -111, 22,   9,
-    -128, -18,  -89,  66,  -4,   -113, 8,    15,  -128, -20,  -89,  66,
-    -16,  -113, -27,  20,  -128, -27,  -90,  65,  9,    -114, -30,  14,
-    -128, -37,  -90,  56,  -63,  -114, -5,   20,  -128, -43,  -84,  50,
-    -21,  -112, -32,  8,   -128, -38,  -89,  49,  -17,  -114, -24,  16,
+    -128, -31,  -80,  58,  -17,  -109, -33,  1,   -128, -24,  -107, 60,
+    0,    -123, -46,  32,  -128, -23,  -121, 66,  -27,  -128, 24,   76,
+    -128, -34,  -92,  57,  -9,   -115, -39,  32,  -128, -35,  -86,  53,
+    -16,  -113, -29,  10,  -128, -30,  -97,  58,  -25,  -117, -17,  24,
+    -128, -27,  -97,  62,  -46,  -117, -15,  28,  -128, -26,  -91,  66,
+    -4,   -114, -44,  23,  -128, -24,  -87,  68,  15,   -112, -71,  17,
+    -128, -25,  -85,  66,  10,   -111, -71,  13,  -128, -29,  -83,  63,
+    -38,  -110, -39,  5,   -128, -36,  -82,  59,  -88,  -110, -14,  -2,
+    -128, -41,  -90,  52,  25,   -115, -9,   19,  -128, -43,  -94,  47,
+    1,    -117, -21,  34,  -128, -32,  -84,  52,  -13,  -111, -16,  15,
+    -128, -38,  -85,  54,  -23,  -112, -36,  6,   -128, -32,  -116, 54,
+    -12,  -128, 61,   50,  -128, -24,  -127, 63,  -12,  -128, 68,   87,
+    -128, -24,  -106, 63,  -47,  -122, -6,   44,  -128, -22,  -102, 63,
+    8,    -120, -17,  38,  -128, -20,  -99,  66,  6,    -118, 30,   38,
+    -128, -20,  -90,  66,  -8,   -114, 35,   21,  -128, -20,  -89,  66,
+    -16,  -114, 18,   17,  -128, -20,  -89,  65,  -15,  -113, 14,   25,
+    -128, -18,  -84,  66,  -14,  -111, 22,   9,   -128, -18,  -89,  66,
+    -4,   -113, 8,    15,  -128, -20,  -89,  66,  -16,  -113, -27,  20,
+    -128, -27,  -90,  65,  9,    -114, -30,  14,  -128, -37,  -90,  56,
+    -63,  -114, -5,   20,  -128, -43,  -84,  50,  -21,  -112, -32,  8,
     -128, -41,  -96,  49,  -13,  -118, -21,  31,  -128, -37,  -101, 53,
     -33,  -120, 23,   28,  -128, -33,  -116, 54,  -24,  -127, 56,   61,
     -128, -22,  -108, 63,  -15,  -123, 45,   53,  -128, -20,  -89,  66,
@@ -378,28 +374,26 @@
     -22,  -105, -30,  -14, -128, -20,  -78,  66,  -30,  -108, 1,    -4,
     -128, -19,  -88,  66,  -9,   -113, 8,    10,  -128, -25,  -96,  65,
     -11,  -117, -5,   27,  -128, -43,  -96,  53,  -6,   -118, -26,  31,
-    -128, -41,  -97,  48,  -13,  -118, -13,  34,  -128, -46,  -106, 44,
-    -8,   -123, 9,    46,  -128, -38,  -109, 48,  -17,  -124, -10,  45,
-    -128, -26,  -110, 59,  -12,  -124, 28,   57,  -128, -20,  -89,  68,
-    -22,  -113, 15,   23,  -128, -21,  -74,  69,  -14,  -106, -41,  -4,
-    -128, -24,  -68,  63,  -2,   -103, -53,  -12, -128, -23,  -73,  67,
-    -9,   -105, -54,  -12, -128, -22,  -76,  67,  -12,  -107, -45,  -13,
-    -128, -21,  -84,  67,  -44,  -110, -51,  21,  -128, -23,  -69,  67,
-    -21,  -104, -40,  -21, -128, -21,  -68,  67,  -16,  -103, -64,  -13,
-    -128, -25,  -64,  65,  -14,  -101, -54,  -21, -128, -18,  -70,  69,
-    -18,  -104, -28,  -19, -128, -20,  -86,  66,  -19,  -112, 5,    5,
-    -128, -31,  -102, 64,  -16,  -120, 16,   34,  -128, -41,  -104, 49,
-    -25,  -121, 5,    49,  -128, -42,  -99,  48,  -43,  -120, 33,   24,
-    -128, -36,  -116, 51,  23,   -128, 52,   69,  -128, -24,  -98,  63,
-    -31,  -118, 31,   43,  -128, -20,  -76,  69,  -28,  -106, -40,  -5,
-    -128, -24,  -71,  64,  -13,  -105, -45,  -7,  -128, -17,  -67,  70,
-    -22,  -102, -59,  -20, -128, -23,  -72,  67,  -21,  -105, -38,  -13,
-    -128, -21,  -77,  67,  -23,  -108, -54,  -9,  -128, -21,  -86,  67,
-    -51,  -112, 15,   22,  -128, -23,  -71,  67,  -32,  -105, -67,  -17,
-    -128, -22,  -67,  66,  -15,  -102, -46,  -13, -128, -20,  -60,  69,
-    -21,  -99,  -60,  -30, -128, -23,  -61,  65,  0,    -100, -61,  -25,
-    -128, -18,  -70,  69,  -35,  -104, -31,  -22, -128, -25,  -94,  67,
-    -21,  -116, 15,   5,   -128, -31,  -103, 59,  -47,  -120, 24,   42,
+    -128, -46,  -106, 44,  -8,   -123, 9,    46,  -128, -38,  -109, 48,
+    -17,  -124, -10,  45,  -128, -26,  -110, 59,  -12,  -124, 28,   57,
+    -128, -20,  -89,  68,  -22,  -113, 15,   23,  -128, -21,  -74,  69,
+    -14,  -106, -41,  -4,  -128, -24,  -68,  63,  -2,   -103, -53,  -12,
+    -128, -23,  -73,  67,  -9,   -105, -54,  -12, -128, -22,  -76,  67,
+    -12,  -107, -45,  -13, -128, -21,  -84,  67,  -44,  -110, -51,  21,
+    -128, -23,  -69,  67,  -21,  -104, -40,  -21, -128, -21,  -68,  67,
+    -16,  -103, -64,  -13, -128, -25,  -64,  65,  -14,  -101, -54,  -21,
+    -128, -18,  -70,  69,  -18,  -104, -28,  -19, -128, -20,  -86,  66,
+    -19,  -112, 5,    5,   -128, -31,  -102, 64,  -16,  -120, 16,   34,
+    -128, -42,  -99,  48,  -43,  -120, 33,   24,  -128, -36,  -116, 51,
+    23,   -128, 52,   69,  -128, -24,  -98,  63,  -31,  -118, 31,   43,
+    -128, -20,  -76,  69,  -28,  -106, -40,  -5,  -128, -24,  -71,  64,
+    -13,  -105, -45,  -7,  -128, -17,  -67,  70,  -22,  -102, -59,  -20,
+    -128, -23,  -72,  67,  -21,  -105, -38,  -13, -128, -21,  -77,  67,
+    -23,  -108, -54,  -9,  -128, -21,  -86,  67,  -51,  -112, 15,   22,
+    -128, -23,  -71,  67,  -32,  -105, -67,  -17, -128, -22,  -67,  66,
+    -15,  -102, -46,  -13, -128, -20,  -60,  69,  -21,  -99,  -60,  -30,
+    -128, -23,  -61,  65,  0,    -100, -61,  -25, -128, -18,  -70,  69,
+    -35,  -104, -31,  -22, -128, -25,  -94,  67,  -21,  -116, 15,   5,
     -128, -41,  -92,  50,  -21,  -116, -12,  21,  -128, -30,  -99,  53,
     -25,  -119, 30,   28,  -128, -22,  -88,  66,  -29,  -113, -13,  23,
     -128, -20,  -76,  69,  -11,  -107, -37,  -6,  -128, -20,  -72,  67,
@@ -410,28 +404,26 @@
     -27,  -98,  -65,  -25, -128, -20,  -50,  69,  -36,  -94,  -73,  -47,
     -128, -20,  -52,  68,  -39,  -95,  -74,  -46, -128, -19,  -60,  69,
     -44,  -99,  -62,  -36, -128, -22,  -86,  69,  -22,  -113, -24,  -15,
-    -128, -34,  -100, 60,  -27,  -119, 3,    39,  -128, -45,  -98,  48,
-    -13,  -119, -5,   31,  -128, -32,  -101, 52,  1,    -120, -17,  34,
-    -128, -20,  -90,  66,  -23,  -113, -28,  27,  -128, -22,  -75,  69,
-    -21,  -107, -52,  -13, -128, -23,  -74,  64,  -45,  -105, -42,  -7,
-    -128, -18,  -68,  69,  -21,  -103, -54,  -18, -128, -19,  -67,  69,
-    -23,  -103, -53,  -19, -128, -19,  -75,  69,  -23,  -107, -54,  -13,
-    -128, -20,  -85,  67,  -29,  -111, -1,   20,  -128, -19,  -68,  69,
-    -36,  -103, -68,  -18, -128, -20,  -63,  69,  -23,  -100, -70,  -23,
-    -128, -19,  -59,  69,  -6,   -98,  -99,  -31, -128, -20,  -61,  68,
-    -15,  -100, -88,  -31, -128, -19,  -67,  69,  -20,  -102, -73,  -23,
-    -128, -22,  -92,  69,  -19,  -115, -38,  1,   -128, -37,  -102, 57,
-    -17,  -120, -2,   45,  -128, -42,  -91,  50,  -29,  -116, 10,   20,
-    -128, -37,  -104, 50,  -2,   -122, 18,   37,  -128, -23,  -101, 61,
-    -5,   -119, -13,  48,  -128, -21,  -78,  70,  -33,  -107, -46,  -3,
-    -128, -26,  -73,  65,  -31,  -106, -37,  -11, -128, -22,  -69,  66,
-    -13,  -103, -55,  -12, -128, -19,  -66,  69,  -21,  -102, -55,  -21,
-    -128, -20,  -72,  69,  -28,  -105, -25,  -15, -128, -22,  -75,  69,
-    -29,  -106, -16,  -6,  -128, -21,  -70,  69,  -19,  -104, -35,  -9,
-    -128, -21,  -65,  69,  -23,  -101, -60,  -23, -128, -23,  -68,  68,
-    -25,  -103, -61,  -23, -128, -24,  -68,  63,  -12,  -103, -55,  -15,
-    -128, -20,  -78,  69,  -33,  -108, -55,  -18, -128, -25,  -107, 66,
-    -2,   -123, -1,   43,  -128, -37,  -102, 51,  -10,  -120, 17,   46,
+    -128, -45,  -98,  48,  -13,  -119, -5,   31,  -128, -32,  -101, 52,
+    1,    -120, -17,  34,  -128, -20,  -90,  66,  -23,  -113, -28,  27,
+    -128, -22,  -75,  69,  -21,  -107, -52,  -13, -128, -23,  -74,  64,
+    -45,  -105, -42,  -7,  -128, -18,  -68,  69,  -21,  -103, -54,  -18,
+    -128, -19,  -67,  69,  -23,  -103, -53,  -19, -128, -19,  -75,  69,
+    -23,  -107, -54,  -13, -128, -20,  -85,  67,  -29,  -111, -1,   20,
+    -128, -19,  -68,  69,  -36,  -103, -68,  -18, -128, -20,  -63,  69,
+    -23,  -100, -70,  -23, -128, -19,  -59,  69,  -6,   -98,  -99,  -31,
+    -128, -20,  -61,  68,  -15,  -100, -88,  -31, -128, -19,  -67,  69,
+    -20,  -102, -73,  -23, -128, -22,  -92,  69,  -19,  -115, -38,  1,
+    -128, -42,  -91,  50,  -29,  -116, 10,   20,  -128, -37,  -104, 50,
+    -2,   -122, 18,   37,  -128, -23,  -101, 61,  -5,   -119, -13,  48,
+    -128, -21,  -78,  70,  -33,  -107, -46,  -3,  -128, -26,  -73,  65,
+    -31,  -106, -37,  -11, -128, -22,  -69,  66,  -13,  -103, -55,  -12,
+    -128, -19,  -66,  69,  -21,  -102, -55,  -21, -128, -20,  -72,  69,
+    -28,  -105, -25,  -15, -128, -22,  -75,  69,  -29,  -106, -16,  -6,
+    -128, -21,  -70,  69,  -19,  -104, -35,  -9,  -128, -21,  -65,  69,
+    -23,  -101, -60,  -23, -128, -23,  -68,  68,  -25,  -103, -61,  -23,
+    -128, -24,  -68,  63,  -12,  -103, -55,  -15, -128, -20,  -78,  69,
+    -33,  -108, -55,  -18, -128, -25,  -107, 66,  -2,   -123, -1,   43,
     -128, -40,  -81,  52,  -33,  -110, -21,  3,   -128, -43,  -96,  50,
     -27,  -118, 7,    17,  -128, -32,  -112, 53,  -21,  -125, 15,   57,
     -128, -23,  -94,  66,  -11,  -115, -42,  34,  -128, -20,  -76,  70,
@@ -442,39 +434,26 @@
     -34,  -104, -64,  -17, -128, -23,  -69,  64,  -3,   -103, -48,  -15,
     -128, -18,  -74,  69,  -12,  -106, -65,  -16, -128, -22,  -93,  67,
     -18,  -116, -31,  20,  -128, -26,  -90,  62,  -16,  -114, 49,   23,
-    -128, -24,  -86,  62,  -74,  -112, 2,    0,   -128, -44,  -87,  49,
-    -18,  -113, -34,  14,  -128, -45,  -93,  49,  -11,  -117, -3,   22,
-    -128, -40,  -110, 48,  -18,  -125, 49,   48,  -128, -30,  -113, 57,
-    -7,   -125, 5,    61,  -128, -23,  -93,  66,  -13,  -115, -47,  29,
-    -128, -20,  -77,  69,  -14,  -107, -67,  0,   -128, -20,  -73,  69,
-    -27,  -105, -48,  -8,  -128, -26,  -70,  64,  -47,  -104, -43,  -14,
-    -128, -22,  -72,  68,  7,    -105, -52,  -6,  -128, -24,  -70,  64,
-    -50,  -104, -46,  -15, -128, -19,  -73,  69,  -30,  -105, -41,  -12,
-    -128, -19,  -81,  70,  -40,  -109, -85,  -9,  -128, -22,  -97,  67,
-    21,   -118, -45,  27,  -128, -29,  -93,  59,  -37,  -115, 9,    31,
-    -128, -31,  -81,  59,  3,    -110, -54,  5,   -128, -35,  -87,  55,
-    -9,   -113, -29,  10,  -128, -44,  -88,  47,  -11,  -114, 1,    21,
-    -128, -38,  -85,  52,  -29,  -112, -4,   10,  -128, -39,  -92,  50,
-    -27,  -116, 18,   16,  -128, -39,  -112, 46,  -14,  -126, 57,   50,
-    -128, -31,  -117, 57,  9,    -128, 4,    67,  -128, -25,  -105, 63,
-    -17,  -121, -46,  47,  -128, -22,  -91,  67,  -2,   -114, -50,  20,
-    -128, -21,  -88,  68,  -35,  -112, -63,  12,  -128, -21,  -89,  69,
-    -25,  -113, -70,  12,  -128, -21,  -94,  68,  -30,  -116, -71,  18,
-    -128, -22,  -99,  66,  6,    -119, -57,  32,  -128, -26,  -101, 60,
-    24,   -120, 19,   44,  -128, -35,  -94,  54,  -27,  -117, 28,   27,
-    -128, -41,  -89,  50,  -24,  -114, -16,  19,  -128, -41,  -89,  52,
-    -12,  -114, -24,  19,  -128, -37,  -91,  52,  -17,  -115, -19,  21,
-    -128, -37,  -85,  55,  -39,  -112, -37,  8,   -128, -38,  -84,  55,
-    -32,  -111, -42,  8,   -128, -37,  -86,  54,  -31,  -113, -32,  10,
-    -128, -37,  -95,  52,  -36,  -117, 10,   19,  -128, -35,  -110, 52,
-    -10,  -125, 71,   51,  -128, -28,  -110, 56,  -21,  -124, 71,   55,
-    -128, -29,  -105, 57,  37,   -122, 21,   48,  -128, -32,  -106, 57,
-    30,   -123, 32,   52,  -128, -30,  -107, 58,  25,   -123, 38,   51,
-    -128, -30,  -105, 57,  -9,   -122, 49,   45,  -128, -30,  -98,  57,
-    -59,  -118, 33,   30,  -128, -33,  -90,  57,  -46,  -114, -13,  18,
-    -128, -37,  -90,  56,  -15,  -115, -27,  16,  -128, -42,  -96,  50,
-    12,   -118, -6,   33,  -128, -43,  -95,  49,  7,    -117, -4,   32,
-    -128, -37,  -95,  53,  -7,   -117, -1,   30};
+    -128, -44,  -87,  49,  -18,  -113, -34,  14,  -128, -45,  -93,  49,
+    -11,  -117, -3,   22,  -128, -40,  -110, 48,  -18,  -125, 49,   48,
+    -128, -30,  -113, 57,  -7,   -125, 5,    61,  -128, -23,  -93,  66,
+    -13,  -115, -47,  29,  -128, -20,  -77,  69,  -14,  -107, -67,  0,
+    -128, -20,  -73,  69,  -27,  -105, -48,  -8,  -128, -26,  -70,  64,
+    -47,  -104, -43,  -14, -128, -22,  -72,  68,  7,    -105, -52,  -6,
+    -128, -24,  -70,  64,  -50,  -104, -46,  -15, -128, -19,  -73,  69,
+    -30,  -105, -41,  -12, -128, -19,  -81,  70,  -40,  -109, -85,  -9,
+    -128, -22,  -97,  67,  21,   -118, -45,  27,  -128, -29,  -93,  59,
+    -37,  -115, 9,    31,  -128, -31,  -81,  59,  3,    -110, -54,  5,
+    -128, -44,  -88,  47,  -11,  -114, 1,    21,  -128, -38,  -85,  52,
+    -29,  -112, -4,   10,  -128, -39,  -92,  50,  -27,  -116, 18,   16,
+    -128, -39,  -112, 46,  -14,  -126, 57,   50,  -128, -31,  -117, 57,
+    9,    -128, 4,    67,  -128, -25,  -105, 63,  -17,  -121, -46,  47,
+    -128, -22,  -91,  67,  -2,   -114, -50,  20,  -128, -21,  -88,  68,
+    -35,  -112, -63,  12,  -128, -21,  -89,  69,  -25,  -113, -70,  12,
+    -128, -21,  -94,  68,  -30,  -116, -71,  18,  -128, -22,  -99,  66,
+    6,    -119, -57,  32,  -128, -26,  -101, 60,  24,   -120, 19,   44,
+    -128, -35,  -94,  54,  -27,  -117, 28,   27,  -128, -41,  -89,  50,
+    -24,  -114, -16,  19,  -128, -41,  -89,  52,  -12,  -114, -24,  19};
 
 // Conv Test Case: Int8Filter1x3x3x1ShouldMatchGolden
 const int8_t kConvFilter1x3x3x1[1 * 3 * 3 * 1]{
diff --git a/tensorflow/lite/micro/kernels/testdata/conv_test_data.h b/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
index bdac510..72ac8a2 100644
--- a/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
+++ b/tensorflow/lite/micro/kernels/testdata/conv_test_data.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
 extern const int8_t kConvInput1x32x32x3[];
 extern const int8_t kConvFilter8x3x3x3[];
 extern const int32_t kConvBiasQuantized8[];
-extern const int8_t kConvGoldenOutput1x16x16x8[];
+extern const int8_t kConvGoldenOutput1x15x15x8[];
 
 // Kernel Conv Test Cases: Int8Filter1x3x3x1ShouldMatchGolden
 extern const int8_t kConvInput1x4x4x1[];
diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc b/tensorflow/lite/micro/memory_arena_threshold_test.cc
index 3017f56..90d7f6c 100644
--- a/tensorflow/lite/micro/memory_arena_threshold_test.cc
+++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -93,24 +93,24 @@
 // Total size contributed by the conv model excluding the
 // RecordingMicroAllocator's overhead
 // TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTotalSize = 9488;
+constexpr int kTestConvModelOnlyTotalSize = 9558;
 // Tail size contributed by the conv model excluding the
 // RecordingMicroAllocator's overhead
 // TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTailSize = 1816;
+constexpr int kTestConvModelOnlyTailSize = 1886;
 constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 128;
-constexpr int kTestConvModelPersistentBufferDataSize = 728;
+constexpr int kTestConvModelPersistentBufferDataSize = 798;
 #else
 // Total size contributed by the conv model excluding the
 // RecordingMicroAllocator's overhead
 // TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTotalSize = 9760;
+constexpr int kTestConvModelOnlyTotalSize = 9830;
 // Tail size contributed by the conv model excluding the
 // RecordingMicroAllocator's overhead
 // TODO(b/207157610): replace magic number that depends on OPs
-constexpr int kTestConvModelOnlyTailSize = 2088;
+constexpr int kTestConvModelOnlyTailSize = 2158;
 constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 224;
-constexpr int kTestConvModelPersistentBufferDataSize = 720;
+constexpr int kTestConvModelPersistentBufferDataSize = 790;
 #endif
 constexpr int kTestConvModelHeadSize = 7744;
 constexpr int kTestConvModelOpRuntimeDataSize = 136;
diff --git a/tensorflow/lite/micro/python/keras_tests/BUILD b/tensorflow/lite/micro/python/keras_tests/BUILD
new file mode 100644
index 0000000..12fd6a9
--- /dev/null
+++ b/tensorflow/lite/micro/python/keras_tests/BUILD
@@ -0,0 +1,28 @@
+# Description:
+#   TensorFlow Lite for Microcontrollers Keras-based kernel tests.
+load("@rules_python//python:defs.bzl", "py_binary", "py_test")
+load("@tflm_pip_deps//:requirements.bzl", "requirement")
+
+package(
+    default_visibility = ["//visibility:public"],
+    # Disabling layering_check because of http://b/177257332
+    features = ["-layering_check"],
+    licenses = ["notice"],
+)
+
+py_test(
+    name = "conv_tests",
+    srcs = ["conv_tests.py"],
+    main = "conv_tests.py",
+    python_version = "PY3",
+    tags = [
+        "noasan",
+        "nomsan",  # Python doesn't like these symbols
+        "noubsan",
+    ],
+    deps = [
+        requirement("numpy"),
+        requirement("tensorflow-cpu"),
+        "//python/tflite_micro:runtime",
+    ],
+)
diff --git a/tensorflow/lite/micro/python/keras_tests/conv_tests.py b/tensorflow/lite/micro/python/keras_tests/conv_tests.py
new file mode 100644
index 0000000..80c43bf
--- /dev/null
+++ b/tensorflow/lite/micro/python/keras_tests/conv_tests.py
@@ -0,0 +1,158 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""
+Convolution kernel testing with dilation > 1, using the TfLiteConverter to
+convert models directly from Keras.
+
+Run:
+bazel build tensorflow/lite/micro/python/keras_tests:conv_tests
+bazel-bin/tensorflow/lite/micro/python/keras_tests/conv_tests
+"""
+
+from __future__ import annotations
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+import tensorflow as tf
+import keras.api._v2.keras as keras  # for Visual Studio Code to work correctly
+from tflite_micro.python.tflite_micro import runtime
+
+
+class KerasConvTest(test_util.TensorFlowTestCase):
+
+  def MakeConv1dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.Conv1D(1,
+                                     3,
+                                     dilation_rate=dilation,
+                                     padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def MakeConv2dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.Conv2D(1,
+                                     3,
+                                     dilation_rate=dilation,
+                                     padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def MakeDepthwiseConv1dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.DepthwiseConv1D(3,
+                                              dilation_rate=dilation,
+                                              padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def MakeDepthwiseConv2dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.DepthwiseConv2D(3,
+                                              dilation_rate=dilation,
+                                              padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def MakeTransposeConv1dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.Conv1DTranspose(1,
+                                              3,
+                                              dilation_rate=dilation,
+                                              padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def MakeTransposeConv2dModel(self, *, shape, dilation):
+    input_layer = keras.layers.Input(shape=shape)
+    conv_layer = keras.layers.Conv2DTranspose(1,
+                                              3,
+                                              dilation_rate=dilation,
+                                              padding='same')(input_layer)
+    model = keras.Model(inputs=input_layer, outputs=conv_layer)
+    return model
+
+  def ExecuteModelTest(self, model: keras.Model):
+    model_shape = list(model.layers[0].input_shape[0])
+    model_shape[0] = 1
+    input_data = tf.ones(shape=model_shape, dtype=tf.float32)
+    tf_result: tf.Tensor = model(input_data)  # type: ignore
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(model=model)
+    tflite_model = converter.convert()
+    tf.lite.experimental.Analyzer.analyze(model_content=tflite_model)
+
+    tflm_interpreter = runtime.Interpreter.from_bytes(
+        tflite_model,
+        intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors)
+    tflm_interpreter.set_input(input_data, 0)
+    tflm_interpreter.invoke()
+    tflm_result = tflm_interpreter.get_output(0)
+    tflm_output_details = tflm_interpreter.get_output_details(0)
+    tflm_shape = tflm_output_details['shape']
+
+    print(f'{tf_result=}')
+    print(f'{tflm_result=} {tflm_shape=}')
+
+    self.assertAllClose(tf_result, tflm_result)
+    self.assertAllEqual(tf_result.shape, tflm_shape)
+
+  def setUp(self):
+    pass
+
+  def testConv1dWithDilation1(self):
+    model = self.MakeConv1dModel(shape=(8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+  def testConv1dWithDilation2(self):
+    model = self.MakeConv1dModel(shape=(8, 1), dilation=2)
+    self.ExecuteModelTest(model)
+
+  def testConv2dWithDilation1(self):
+    model = self.MakeConv2dModel(shape=(1, 8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+  def testConv2dWithDilation2(self):
+    model = self.MakeConv2dModel(shape=(1, 8, 1), dilation=2)
+    self.ExecuteModelTest(model)
+
+  def testDepthwiseConv1dWithDilation1(self):
+    model = self.MakeDepthwiseConv1dModel(shape=(8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+  def testDepthwiseConv1dWithDilation2(self):
+    model = self.MakeDepthwiseConv1dModel(shape=(8, 1), dilation=2)
+    self.ExecuteModelTest(model)
+
+  def testDepthwiseConv2dWithDilation1(self):
+    model = self.MakeDepthwiseConv2dModel(shape=(1, 8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+  def testDepthwiseConv2dWithDilation2(self):
+    model = self.MakeDepthwiseConv2dModel(shape=(1, 8, 1), dilation=2)
+    self.ExecuteModelTest(model)
+
+  def testTransposeConv1dWithDilation1(self):
+    model = self.MakeTransposeConv1dModel(shape=(8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+  def testTransposeConv2dWithDilation1(self):
+    model = self.MakeTransposeConv2dModel(shape=(1, 8, 1), dilation=1)
+    self.ExecuteModelTest(model)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h
index a28f4b6..bd2d93c 100644
--- a/tensorflow/lite/micro/testing/micro_test.h
+++ b/tensorflow/lite/micro/testing/micro_test.h
@@ -264,4 +264,11 @@
     }                                                                        \
   } while (false)
 
+#define TF_LITE_MICRO_CHECK_FAIL()   \
+  do {                               \
+    if (micro_test::did_test_fail) { \
+      return kTfLiteError;           \
+    }                                \
+  } while (false)
+
 #endif  // TENSORFLOW_LITE_MICRO_TESTING_MICRO_TEST_H_
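
The TF_LITE_MICRO_CHECK_FAIL macro added above reads micro_test::did_test_fail (which the TF_LITE_MICRO_EXPECT_* macros set on a failed expectation) and returns kTfLiteError from the enclosing function, so it is intended for use inside helpers that return TfLiteStatus. A minimal usage sketch follows; the helper name and interpreter argument are hypothetical and only illustrate the pattern.

// Illustrative sketch only: RunSingleInference is a hypothetical helper, not
// part of this change.
TfLiteStatus RunSingleInference(tflite::MicroInterpreter& interpreter) {
  // Record a test expectation; on mismatch this sets micro_test::did_test_fail.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
  // Convert a failed expectation into an early error return so the caller can
  // stop instead of continuing with an invalid interpreter state.
  TF_LITE_MICRO_CHECK_FAIL();
  return kTfLiteOk;
}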