Use dwconv for conv with input_depth == 1, 5x5 filter, output_depth == 24

Change-Id: Icd523652329dc68ef3b73d6cf166cca952950d5d
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index d3529ef..3f77764 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -200,6 +200,11 @@
   return; \
 }
 
+  if (input_depth == 1 && filter_width == 5 && filter_height == 5 &&
+      output_depth == 24) {
+    RUN_KERNEL(kelvin::opt::ConvPerChannelD1OD24_5x5);
+  }
+
   // special case of filter_depth = 4n
   if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
       stride_width <= 2 && stride_height <= 2 && filter_depth % 4 == 0 &&
diff --git a/tflm/opt/conv_s8.h b/tflm/opt/conv_s8.h
index 91c535a..718d9ee 100644
--- a/tflm/opt/conv_s8.h
+++ b/tflm/opt/conv_s8.h
@@ -71,6 +71,15 @@
     const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
     int8_t* output_data);
 
+// Per-channel int8 conv for input_depth == 1, 5x5 filter, output_depth == 24.
+void ConvPerChannelD1OD24_5x5(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data);
+
 }  // namespace kelvin::opt
 
 #endif  // TFLM_OPT_CONV_S8_H_
diff --git a/tflm/opt/conv_s8_d1.cc b/tflm/opt/conv_s8_d1.cc
index b886f32..09f4110 100644
--- a/tflm/opt/conv_s8_d1.cc
+++ b/tflm/opt/conv_s8_d1.cc
@@ -64,8 +64,679 @@
       output[24] = input[3];
   }
 }
+
 }  // namespace
 
+#define CALCULATE_IN_X(in_x_origin)                                 \
+  { /* in_x[i] = input column of filter tap i; may be off-image. */ \
+    _Pragma("GCC unroll 5") for (int i = 0; i < 5; ++i) {           \
+      in_x[i] = (in_x_origin) + (dilation_width_factor * i);        \
+    } /* in_x, dilation_width_factor come from the call site. */    \
+  }
+
+#define CALCULATE_IN_Y(in_y_origin)                               \
+  { /* in_y[i] = input row of filter tap i; may be off-image. */  \
+    _Pragma("GCC unroll 5") for (int i = 0; i < 5; ++i) {         \
+      in_y[i] = (in_y_origin) + (dilation_height_factor * i);     \
+    } /* in_y, dilation_height_factor come from the call site. */ \
+  }
+
+#define PAD_ROW_0(input_offset)                       \
+  { /* Row 0 is padding: v27..v31 = -input_offset. */ \
+    vdup_b_x(v27, -(input_offset));                   \
+    vdup_b_x(v28, -(input_offset));                   \
+    vdup_b_x(v29, -(input_offset));                   \
+    vdup_b_x(v30, -(input_offset));                   \
+    vdup_b_x(v31, -(input_offset));                   \
+  }
+#define PAD_ROW_1(input_offset)                       \
+  { /* Row 1 is padding: v32..v36 = -input_offset. */ \
+    vdup_b_x(v32, -(input_offset));                   \
+    vdup_b_x(v33, -(input_offset));                   \
+    vdup_b_x(v34, -(input_offset));                   \
+    vdup_b_x(v35, -(input_offset));                   \
+    vdup_b_x(v36, -(input_offset));                   \
+  }
+#define PAD_ROW_2(input_offset)                       \
+  { /* Row 2 is padding: v37..v41 = -input_offset. */ \
+    vdup_b_x(v37, -(input_offset));                   \
+    vdup_b_x(v38, -(input_offset));                   \
+    vdup_b_x(v39, -(input_offset));                   \
+    vdup_b_x(v40, -(input_offset));                   \
+    vdup_b_x(v41, -(input_offset));                   \
+  }
+#define PAD_ROW_3(input_offset)                       \
+  { /* Row 3 is padding: v42..v46 = -input_offset. */ \
+    vdup_b_x(v42, -(input_offset));                   \
+    vdup_b_x(v43, -(input_offset));                   \
+    vdup_b_x(v44, -(input_offset));                   \
+    vdup_b_x(v45, -(input_offset));                   \
+    vdup_b_x(v46, -(input_offset));                   \
+  }
+#define PAD_ROW_4(input_offset)                       \
+  { /* Row 4 is padding: v47..v51 = -input_offset. */ \
+    vdup_b_x(v47, -(input_offset));                   \
+    vdup_b_x(v48, -(input_offset));                   \
+    vdup_b_x(v49, -(input_offset));                   \
+    vdup_b_x(v50, -(input_offset));                   \
+    vdup_b_x(v51, -(input_offset));                   \
+  }
+
+#define LOAD_ROW_0(p_input, input_width, in_y, in_x)               \
+  { /* Row 0, no bounds checks: v27..v31 = taps at in_x[0..4]. */  \
+    const int8_t* p_row = (p_input) + ((in_y)[0] * (input_width)); \
+    vdup_b_x(v27, *(p_row + (in_x)[0]));                           \
+    vdup_b_x(v28, *(p_row + (in_x)[1]));                           \
+    vdup_b_x(v29, *(p_row + (in_x)[2]));                           \
+    vdup_b_x(v30, *(p_row + (in_x)[3]));                           \
+    vdup_b_x(v31, *(p_row + (in_x)[4]));                           \
+  }
+
+#define LOAD_ROW_1(p_input, input_width, in_y, in_x)               \
+  { /* Row 1, no bounds checks: v32..v36 = taps at in_x[0..4]. */  \
+    const int8_t* p_row = (p_input) + ((in_y)[1] * (input_width)); \
+    vdup_b_x(v32, *(p_row + (in_x)[0]));                           \
+    vdup_b_x(v33, *(p_row + (in_x)[1]));                           \
+    vdup_b_x(v34, *(p_row + (in_x)[2]));                           \
+    vdup_b_x(v35, *(p_row + (in_x)[3]));                           \
+    vdup_b_x(v36, *(p_row + (in_x)[4]));                           \
+  }
+
+#define LOAD_ROW_2(p_input, input_width, in_y, in_x)               \
+  { /* Row 2, no bounds checks: v37..v41 = taps at in_x[0..4]. */  \
+    const int8_t* p_row = (p_input) + ((in_y)[2] * (input_width)); \
+    vdup_b_x(v37, *(p_row + (in_x)[0]));                           \
+    vdup_b_x(v38, *(p_row + (in_x)[1]));                           \
+    vdup_b_x(v39, *(p_row + (in_x)[2]));                           \
+    vdup_b_x(v40, *(p_row + (in_x)[3]));                           \
+    vdup_b_x(v41, *(p_row + (in_x)[4]));                           \
+  }
+
+#define LOAD_ROW_3(p_input, input_width, in_y, in_x)               \
+  { /* Row 3, no bounds checks: v42..v46 = taps at in_x[0..4]. */  \
+    const int8_t* p_row = (p_input) + ((in_y)[3] * (input_width)); \
+    vdup_b_x(v42, *(p_row + (in_x)[0]));                           \
+    vdup_b_x(v43, *(p_row + (in_x)[1]));                           \
+    vdup_b_x(v44, *(p_row + (in_x)[2]));                           \
+    vdup_b_x(v45, *(p_row + (in_x)[3]));                           \
+    vdup_b_x(v46, *(p_row + (in_x)[4]));                           \
+  }
+
+#define LOAD_ROW_4(p_input, input_width, in_y, in_x)               \
+  { /* Row 4, no bounds checks: v47..v51 = taps at in_x[0..4]. */  \
+    const int8_t* p_row = (p_input) + ((in_y)[4] * (input_width)); \
+    vdup_b_x(v47, *(p_row + (in_x)[0]));                           \
+    vdup_b_x(v48, *(p_row + (in_x)[1]));                           \
+    vdup_b_x(v49, *(p_row + (in_x)[2]));                           \
+    vdup_b_x(v50, *(p_row + (in_x)[3]));                           \
+    vdup_b_x(v51, *(p_row + (in_x)[4]));                           \
+  }
+
+#define H_PAD_OR_LOAD_ROW_0(p_input, input_width, input_offset, in_y, in_x) \
+  if ((in_x)[0] >= 0 && (in_x)[4] < (input_width)) { /* in_x[] ascending */  \
+    LOAD_ROW_0(p_input, input_width, in_y, in_x);                           \
+  } else { /* per-tap: pad columns outside [0, input_width) */               \
+    const int8_t* p_row = (p_input) + ((in_y)[0] * (input_width));           \
+    if ((in_x)[0] < 0 || (in_x)[0] >= (input_width)) {                       \
+      vdup_b_x(v27, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v27, *(p_row + (in_x)[0]));                                   \
+    }                                                                       \
+    if ((in_x)[1] < 0 || (in_x)[1] >= (input_width)) {                       \
+      vdup_b_x(v28, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v28, *(p_row + (in_x)[1]));                                   \
+    }                                                                       \
+    if ((in_x)[2] < 0 || (in_x)[2] >= (input_width)) {                       \
+      vdup_b_x(v29, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v29, *(p_row + (in_x)[2]));                                   \
+    }                                                                       \
+    if ((in_x)[3] < 0 || (in_x)[3] >= (input_width)) {                       \
+      vdup_b_x(v30, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v30, *(p_row + (in_x)[3]));                                   \
+    }                                                                       \
+    if ((in_x)[4] < 0 || (in_x)[4] >= (input_width)) {                       \
+      vdup_b_x(v31, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v31, *(p_row + (in_x)[4]));                                   \
+    }                                                                       \
+  }
+
+#define H_PAD_OR_LOAD_ROW_1(p_input, input_width, input_offset, in_y, in_x) \
+  if ((in_x)[0] >= 0 && (in_x)[4] < (input_width)) { /* in_x[] ascending */  \
+    LOAD_ROW_1(p_input, input_width, in_y, in_x);                           \
+  } else { /* per-tap: pad columns outside [0, input_width) */               \
+    const int8_t* p_row = (p_input) + ((in_y)[1] * (input_width));           \
+    if ((in_x)[0] < 0 || (in_x)[0] >= (input_width)) {                       \
+      vdup_b_x(v32, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v32, *(p_row + (in_x)[0]));                                   \
+    }                                                                       \
+    if ((in_x)[1] < 0 || (in_x)[1] >= (input_width)) {                       \
+      vdup_b_x(v33, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v33, *(p_row + (in_x)[1]));                                   \
+    }                                                                       \
+    if ((in_x)[2] < 0 || (in_x)[2] >= (input_width)) {                       \
+      vdup_b_x(v34, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v34, *(p_row + (in_x)[2]));                                   \
+    }                                                                       \
+    if ((in_x)[3] < 0 || (in_x)[3] >= (input_width)) {                       \
+      vdup_b_x(v35, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v35, *(p_row + (in_x)[3]));                                   \
+    }                                                                       \
+    if ((in_x)[4] < 0 || (in_x)[4] >= (input_width)) {                       \
+      vdup_b_x(v36, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v36, *(p_row + (in_x)[4]));                                   \
+    }                                                                       \
+  }
+
+#define H_PAD_OR_LOAD_ROW_2(p_input, input_width, input_offset, in_y, in_x) \
+  if ((in_x)[0] >= 0 && (in_x)[4] < (input_width)) { /* in_x[] ascending */  \
+    LOAD_ROW_2(p_input, input_width, in_y, in_x);                           \
+  } else { /* per-tap: pad columns outside [0, input_width) */               \
+    const int8_t* p_row = (p_input) + ((in_y)[2] * (input_width));           \
+    if ((in_x)[0] < 0 || (in_x)[0] >= (input_width)) {                       \
+      vdup_b_x(v37, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v37, *(p_row + (in_x)[0]));                                   \
+    }                                                                       \
+    if ((in_x)[1] < 0 || (in_x)[1] >= (input_width)) {                       \
+      vdup_b_x(v38, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v38, *(p_row + (in_x)[1]));                                   \
+    }                                                                       \
+    if ((in_x)[2] < 0 || (in_x)[2] >= (input_width)) {                       \
+      vdup_b_x(v39, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v39, *(p_row + (in_x)[2]));                                   \
+    }                                                                       \
+    if ((in_x)[3] < 0 || (in_x)[3] >= (input_width)) {                       \
+      vdup_b_x(v40, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v40, *(p_row + (in_x)[3]));                                   \
+    }                                                                       \
+    if ((in_x)[4] < 0 || (in_x)[4] >= (input_width)) {                       \
+      vdup_b_x(v41, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v41, *(p_row + (in_x)[4]));                                   \
+    }                                                                       \
+  }
+
+#define H_PAD_OR_LOAD_ROW_3(p_input, input_width, input_offset, in_y, in_x) \
+  if ((in_x)[0] >= 0 && (in_x)[4] < (input_width)) { /* in_x[] ascending */  \
+    LOAD_ROW_3(p_input, input_width, in_y, in_x);                           \
+  } else { /* per-tap: pad columns outside [0, input_width) */               \
+    const int8_t* p_row = (p_input) + ((in_y)[3] * (input_width));           \
+    if ((in_x)[0] < 0 || (in_x)[0] >= (input_width)) {                       \
+      vdup_b_x(v42, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v42, *(p_row + (in_x)[0]));                                   \
+    }                                                                       \
+    if ((in_x)[1] < 0 || (in_x)[1] >= (input_width)) {                       \
+      vdup_b_x(v43, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v43, *(p_row + (in_x)[1]));                                   \
+    }                                                                       \
+    if ((in_x)[2] < 0 || (in_x)[2] >= (input_width)) {                       \
+      vdup_b_x(v44, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v44, *(p_row + (in_x)[2]));                                   \
+    }                                                                       \
+    if ((in_x)[3] < 0 || (in_x)[3] >= (input_width)) {                       \
+      vdup_b_x(v45, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v45, *(p_row + (in_x)[3]));                                   \
+    }                                                                       \
+    if ((in_x)[4] < 0 || (in_x)[4] >= (input_width)) {                       \
+      vdup_b_x(v46, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v46, *(p_row + (in_x)[4]));                                   \
+    }                                                                       \
+  }
+
+#define H_PAD_OR_LOAD_ROW_4(p_input, input_width, input_offset, in_y, in_x) \
+  if ((in_x)[0] >= 0 && (in_x)[4] < (input_width)) { /* in_x[] ascending */  \
+    LOAD_ROW_4(p_input, input_width, in_y, in_x);                           \
+  } else { /* per-tap: pad columns outside [0, input_width) */               \
+    const int8_t* p_row = (p_input) + ((in_y)[4] * (input_width));           \
+    if ((in_x)[0] < 0 || (in_x)[0] >= (input_width)) {                       \
+      vdup_b_x(v47, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v47, *(p_row + (in_x)[0]));                                   \
+    }                                                                       \
+    if ((in_x)[1] < 0 || (in_x)[1] >= (input_width)) {                       \
+      vdup_b_x(v48, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v48, *(p_row + (in_x)[1]));                                   \
+    }                                                                       \
+    if ((in_x)[2] < 0 || (in_x)[2] >= (input_width)) {                       \
+      vdup_b_x(v49, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v49, *(p_row + (in_x)[2]));                                   \
+    }                                                                       \
+    if ((in_x)[3] < 0 || (in_x)[3] >= (input_width)) {                       \
+      vdup_b_x(v50, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v50, *(p_row + (in_x)[3]));                                   \
+    }                                                                       \
+    if ((in_x)[4] < 0 || (in_x)[4] >= (input_width)) {                       \
+      vdup_b_x(v51, -(input_offset));                                        \
+    } else {                                                                \
+      vdup_b_x(v51, *(p_row + (in_x)[4]));                                   \
+    }                                                                       \
+  }
+
+#define _H_PAD_OR_LOAD_ROW(row, p_input, input_width, input_offset, in_y, \
+                           in_x) /* row: literal 0..4 (token-pasted) */   \
+  H_PAD_OR_LOAD_ROW_##row(p_input, input_width, input_offset, in_y, in_x)
+
+#define _PAD_OR_LOAD_ROW(row, p_input, input_height, input_width, in_y, in_x,  \
+                         input_offset)                                         \
+  { /* Row fully above/below the image -> pad; else check columns. */          \
+    if ((in_y)[row] < 0 || (in_y)[row] >= (input_height)) {                    \
+      PAD_ROW_##row(input_offset);                                             \
+    } else {                                                                   \
+      _H_PAD_OR_LOAD_ROW(row, p_input, input_width, input_offset, in_y, in_x); \
+    }                                                                          \
+  }
+
+#define PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x, \
+                          input_offset) /* row 0 -> v27..v31 */           \
+  _PAD_OR_LOAD_ROW(0, p_input, input_height, input_width, in_y, in_x,     \
+                   input_offset)
+#define PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x, \
+                          input_offset) /* row 1 -> v32..v36 */           \
+  _PAD_OR_LOAD_ROW(1, p_input, input_height, input_width, in_y, in_x,     \
+                   input_offset)
+#define PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x, \
+                          input_offset) /* row 2 -> v37..v41 */           \
+  _PAD_OR_LOAD_ROW(2, p_input, input_height, input_width, in_y, in_x,     \
+                   input_offset)
+#define PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x, \
+                          input_offset) /* row 3 -> v42..v46 */           \
+  _PAD_OR_LOAD_ROW(3, p_input, input_height, input_width, in_y, in_x,     \
+                   input_offset)
+#define PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x, \
+                          input_offset) /* row 4 -> v47..v51 */           \
+  _PAD_OR_LOAD_ROW(4, p_input, input_height, input_width, in_y, in_x,     \
+                   input_offset)
+
+#define COMPUTE(cmds, swizzled_bias_data)                                  \
+  { /* Inputs: taps v27..v51 (LOAD/PAD_ROW_*), filter v0..v24 (caller). */ \
+    vld_w_x_m(v60, swizzled_bias_data);                                    \
+    adwinit_v(v60, v60); /* presumably seeds accumulator with bias */      \
+    adwconv_vxv(v60, v27, cmds, v0); /* data and filter regs step by 3 */  \
+    adwconv_vxv(v60, v30, cmds, v3);                                       \
+    adwconv_vxv(v60, v33, cmds, v6);                                       \
+    adwconv_vxv(v60, v36, cmds, v9);                                       \
+    adwconv_vxv(v60, v39, cmds, v12);                                      \
+    adwconv_vxv(v60, v42, cmds, v15);                                      \
+    adwconv_vxv(v60, v45, cmds, v18);                                      \
+    adwconv_vxv(v60, v48, cmds, v21);                                      \
+    vdwconv_vxv(v60, v51, cmds, v24); /* final call; OUTPUT reads v60 */   \
+  }
+
+#define OUTPUT(output_activation_min, output_activation_max, output_offset, \
+               local_output_data, n_channels)                               \
+  { /* v52 = multipliers, v56 = shifts (set up by the caller). */           \
+    INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(                                  \
+        v60, v52, v56, output_activation_min, output_activation_max,        \
+        output_offset);                                                     \
+    vsraqs_b_vx(v60, v60, 0); /* presumably narrows int32 -> int8 */        \
+    vst_b_l_xx(v60, local_output_data, n_channels);                         \
+  }
+
+// Estimated count of arithmetic ops: 58.297 M ops, equivalently 29.148 M MACs
+void ConvPerChannelD1OD24_5x5(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  // const int batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth =
+      tflite::MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_input_depth = filter_shape.Dims(3);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+
+  // Scratch pads to juggle data
+  const size_t swizzled_filter_data_size = 24 * filter_height * filter_width;
+  std::unique_ptr<int8_t> swizzled_filter_data(reinterpret_cast<int8_t*>(
+      ::aligned_alloc(32, swizzled_filter_data_size)));
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_output_multiplier[32];
+  int32_t swizzled_output_shift[32];
+  // Transpose filter for easy loading
+  for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+    for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+      for (int i = 0; i < 24; i++) {
+        int filter_location =
+            (filter_y * filter_width * 24) + (filter_x * 24) + i;
+        swizzled_filter_data.get()[filter_location] =
+            filter_data[tflite::Offset(filter_shape, i, filter_y, filter_x, 0)];
+      }
+    }
+  }
+  const int8_t* p_flt_0 = swizzled_filter_data.get() + (0 * filter_width * 24);
+  const int8_t* p_flt_1 = swizzled_filter_data.get() + (1 * filter_width * 24);
+  const int8_t* p_flt_2 = swizzled_filter_data.get() + (2 * filter_width * 24);
+  const int8_t* p_flt_3 = swizzled_filter_data.get() + (3 * filter_width * 24);
+  const int8_t* p_flt_4 = swizzled_filter_data.get() + (4 * filter_width * 24);
+  vld_b_l_xx(v0, p_flt_0 + (0 * 24), 24);
+  vld_b_l_xx(v1, p_flt_0 + (1 * 24), 24);
+  vld_b_l_xx(v2, p_flt_0 + (2 * 24), 24);
+  vld_b_l_xx(v3, p_flt_0 + (3 * 24), 24);
+  vld_b_l_xx(v4, p_flt_0 + (4 * 24), 24);
+
+  vld_b_l_xx(v5, p_flt_1 + (0 * 24), 24);
+  vld_b_l_xx(v6, p_flt_1 + (1 * 24), 24);
+  vld_b_l_xx(v7, p_flt_1 + (2 * 24), 24);
+  vld_b_l_xx(v8, p_flt_1 + (3 * 24), 24);
+  vld_b_l_xx(v9, p_flt_1 + (4 * 24), 24);
+
+  vld_b_l_xx(v10, p_flt_2 + (0 * 24), 24);
+  vld_b_l_xx(v11, p_flt_2 + (1 * 24), 24);
+  vld_b_l_xx(v12, p_flt_2 + (2 * 24), 24);
+  vld_b_l_xx(v13, p_flt_2 + (3 * 24), 24);
+  vld_b_l_xx(v14, p_flt_2 + (4 * 24), 24);
+
+  vld_b_l_xx(v15, p_flt_3 + (0 * 24), 24);
+  vld_b_l_xx(v16, p_flt_3 + (1 * 24), 24);
+  vld_b_l_xx(v17, p_flt_3 + (2 * 24), 24);
+  vld_b_l_xx(v18, p_flt_3 + (3 * 24), 24);
+  vld_b_l_xx(v19, p_flt_3 + (4 * 24), 24);
+
+  vld_b_l_xx(v20, p_flt_4 + (0 * 24), 24);
+  vld_b_l_xx(v21, p_flt_4 + (1 * 24), 24);
+  vld_b_l_xx(v22, p_flt_4 + (2 * 24), 24);
+  vld_b_l_xx(v23, p_flt_4 + (3 * 24), 24);
+  vld_b_l_xx(v24, p_flt_4 + (4 * 24), 24);
+  vdup_b_x(v25, 0);
+  vdup_b_x(v26, 0);
+
+  union {
+    vdwconv_u8_t dwconv;
+    uint32_t raw;
+  } cmds;
+  cmds.raw = 0;
+  cmds.dwconv.sdata1 = true;
+  cmds.dwconv.sbias1 = input_offset;
+  cmds.dwconv.sdata2 = true;
+  cmds.dwconv.sbias2 = 0;
+  cmds.dwconv.mode = 0;
+  cmds.dwconv.sparsity = 0;
+  cmds.dwconv.regbase = 0;
+  int out_channel = 0;
+  int n_channels = 24;
+
+  memset(swizzled_bias_data, 0, 32 * sizeof(uint32_t));
+  JumptableSwizzle(bias_data + out_channel, swizzled_bias_data, n_channels);
+  memset(swizzled_output_multiplier, 0, 32 * sizeof(uint32_t));
+  JumptableSwizzle(output_multiplier + out_channel, swizzled_output_multiplier,
+                   n_channels);
+  JumptableSwizzle(output_shift + out_channel, swizzled_output_shift,
+                   n_channels);
+  vld_w_x_m(v52, swizzled_output_multiplier);
+  vld_w_x_m(v56, swizzled_output_shift);
+  vrsub_w_vx_m(v56, v56, 0);
+
+  int8_t* local_output_data = output_data + out_channel;
+  int in_y[5];
+  int in_x[5];
+  int out_y = 0;
+  const int8_t* p_input = input_data;
+  // Handle top row padding
+  for (; out_y < pad_height; ++out_y) {
+    int out_x = 0;
+    const int in_y_origin = (out_y * stride_height) - pad_height;
+    CALCULATE_IN_Y(in_y_origin);
+    // Left padding required
+    for (; out_x < pad_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    // No side padding
+    for (; out_x < (output_width - pad_width); ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    // Right padding required
+    for (; out_x < output_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+  }
+
+  // No height padding
+  for (; out_y < (output_height - pad_height); ++out_y) {
+    const int in_y_origin = (out_y * stride_height) - pad_height;
+    CALCULATE_IN_Y(in_y_origin);
+    // Left padding
+    int out_x = 0;
+    for (; out_x < pad_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    for (; out_x < (output_width - pad_width); ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+
+      CALCULATE_IN_X(in_x_origin);
+      LOAD_ROW_0(p_input, input_width, in_y, in_x);
+      LOAD_ROW_1(p_input, input_width, in_y, in_x);
+      LOAD_ROW_2(p_input, input_width, in_y, in_x);
+      LOAD_ROW_3(p_input, input_width, in_y, in_x);
+      LOAD_ROW_4(p_input, input_width, in_y, in_x);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    // Right padding
+    for (; out_x < output_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+  }
+
+  // Handle bottom row padding
+  for (; out_y < output_height; ++out_y) {
+    const int in_y_origin = (out_y * stride_height) - pad_height;
+    CALCULATE_IN_Y(in_y_origin);
+    int out_x = 0;
+    for (; out_x < pad_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    for (; out_x < (output_width - pad_width); ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+    for (; out_x < output_width; ++out_x) {
+      const int in_x_origin = (out_x * stride_width) - pad_width;
+      CALCULATE_IN_X(in_x_origin);
+      PAD_OR_LOAD_ROW_0(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_1(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_2(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_3(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      PAD_OR_LOAD_ROW_4(p_input, input_height, input_width, in_y, in_x,
+                        input_offset);
+      COMPUTE(cmds, swizzled_bias_data);
+      OUTPUT(output_activation_min, output_activation_max, output_offset,
+             local_output_data, n_channels);
+      local_output_data += output_depth;
+    }
+  }
+}
+
+#undef PAD_OR_LOAD_ROW_0
+#undef PAD_OR_LOAD_ROW_1
+#undef PAD_OR_LOAD_ROW_2
+#undef PAD_OR_LOAD_ROW_3
+#undef PAD_OR_LOAD_ROW_4
+#undef _PAD_OR_LOAD_ROW
+#undef _H_PAD_OR_LOAD_ROW
+#undef H_PAD_OR_LOAD_ROW_0
+#undef H_PAD_OR_LOAD_ROW_1
+#undef H_PAD_OR_LOAD_ROW_2
+#undef H_PAD_OR_LOAD_ROW_3
+#undef H_PAD_OR_LOAD_ROW_4
+#undef LOAD_ROW_0
+#undef LOAD_ROW_1
+#undef LOAD_ROW_2
+#undef LOAD_ROW_3
+#undef LOAD_ROW_4
+#undef PAD_ROW_0
+#undef PAD_ROW_1
+#undef PAD_ROW_2
+#undef PAD_ROW_3
+#undef PAD_ROW_4
+#undef CALCULATE_IN_X
+#undef CALCULATE_IN_Y
+
 void ConvPerChannelD1(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
@@ -207,8 +878,8 @@
               v48, v56, v60, output_activation_min, output_activation_max,
               output_offset);
           vsraqs_b_vx(v48, v48, 0);
-          vst_b_l_xx(v48, output_data, n_channels);
-          output_data += output_depth;
+          vst_b_l_xx(v48, local_output_data, n_channels);
+          local_output_data += output_depth;
         }
       }
     }
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 54c20bd..26f3885 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tflm/opt/conv_s8.h"
 
 #define unlikely(x) (__builtin_expect(false || (x), false))
 #define likely(x) (__builtin_expect(false || (x), true))
@@ -36,7 +37,7 @@
 void Filter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H, int W, int M) {
   const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
   int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
-  assert(M >= 4);
+  // assert(M >= 4);
   for (int zo = 0; zo < N; ++zo) {
     for (int ky = 0; ky < H; ++ky) {
       for (int kx = 0; kx < W; ++kx) {
@@ -52,7 +53,7 @@
   for (int zo = N; zo < 8; ++zo) {
     for (int ky = 0; ky < H; ++ky) {
       for (int kx = 0; kx < W; ++kx) {
-        for (int zi = 0; zi < M; ++zi) {
+        for (int zi = M; zi < 4; ++zi) {
           const int zi_hi = zi >> 2;  // div4
           const int zi_lo = zi & 3;   // rem4
           out[ky][kx][zi_hi][zo][zi_lo] = 0;