Improved strategy for filter padding of conv I3xD8

*This is to ensure padding strategy is more readable
*Make loop efficient by reducing if statements and using memcpy

Change-Id: I337efe180d7a5ea57d0b4b02ae9fa820edf74cef
diff --git a/tflm/opt/conv_s8_I3xD8.cc b/tflm/opt/conv_s8_I3xD8.cc
index 370bd97..dd17056 100644
--- a/tflm/opt/conv_s8_I3xD8.cc
+++ b/tflm/opt/conv_s8_I3xD8.cc
@@ -26,6 +26,12 @@
 #include "tflm/opt/conv_util.h"
 #include "tflm/opt/opt.h"
 
+namespace {
+
+constexpr int32_t kBytesPerRegister = 32;
+constexpr int8_t kFilterZeroPoint = 0;
+}  // namespace
+
 namespace kelvin::opt {
 namespace {
 
@@ -57,65 +63,39 @@
   out2[7] = in[7];
 }
 
-void PaddedFilter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H,
-                          int W, int M) {
-  if (M != 3) {
-    MicroPrintf("Filter shuffling can only handle M(input_depth) == 3");
-    exit(-1);
-  }
-
-  const int8_t(&in)[N][H][W][M] = *(int8_t(*)[N][H][W][M])input;
-  int8_t(&out)[N / 8][3][8 * 4 * 3] = *(int8_t(*)[N / 8][3][8 * 4 * 3]) output;
-  int group = 0;
+void PaddedFilter(const int8_t* input, int8_t* output, int output_channels) {
   // Filter data is being reorganized into groups of 8 channels and falttening
   // row. 9th element of 3x3 filter is padded (9000 9000 9000 9000) 8 channels
   // are aligned this way     (  c0  c1   c2    c3)
-  for (int ky = 0; ky < H; ++ky) {
-    int filter_element[N / 8]{0};
-    for (int kx = 0; kx < W; ++kx) {
-      for (int output_channel = 0; output_channel < N; ++output_channel) {
-        for (int input_channel = 0; input_channel < M; ++input_channel) {
-          group = output_channel >> 3;
-          if (kx == 1 && input_channel == 0) {
-            continue;
-          }
-          if (kx == 2 && (input_channel < 2)) {
-            continue;
-          }
 
-          if (kx == 0 && input_channel == 2) {
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx][input_channel];
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx + 1][0];
-            filter_element[group] += 1;
-          } else if (kx == 1 && input_channel == 2) {
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx][input_channel];
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx + 1][0];
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx + 1][1];
-            filter_element[group] += 1;
-          } else if (kx == 2 && input_channel == 2) {
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx][input_channel];
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] = 0;
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] = 0;
-            filter_element[group] += 1;
-            out[group][ky][filter_element[group]] = 0;
-            filter_element[group] += 1;
-          } else {
-            out[group][ky][filter_element[group]] =
-                in[output_channel][ky][kx][input_channel];
-            filter_element[group] += 1;
-          }
-        }
+  for (int group = 0; group < output_channels / 8; group++) {
+    const int8_t* input_group_pointer = input + (group * 8 * 3 * 3 * 3);
+    int8_t* output_group_pointer = output + (group * 8 * 3 * 3 * 4);
+
+    for (int channel = 0; channel < 8; channel++) {
+      for (int row = 0; row < 3; row++) {
+        int out_row_offset = (channel * 4) + (row * 3 * kBytesPerRegister);
+
+        const int8_t* input_c1_offset =
+            input_group_pointer + (channel * 27) + (9 * row);
+        int8_t* output_c1_offset = output_group_pointer + out_row_offset;
+        memcpy(output_c1_offset, input_c1_offset, 4);
+
+        const int8_t* input_c2_offset =
+            input_group_pointer + (channel * 27) + (9 * row + 4);
+        int8_t* output_c2_offset =
+            output_group_pointer + out_row_offset + kBytesPerRegister;
+        memcpy(output_c2_offset, input_c2_offset, 4);
+
+        const int8_t* input_c3_offset =
+            input_group_pointer + (channel * 27) + (9 * row + 8);
+        int8_t* output_c3_offset =
+            output_group_pointer + out_row_offset + (2 * kBytesPerRegister);
+        memcpy(output_c3_offset, input_c3_offset, 1);
+
+        *(output_c3_offset + 1) = kFilterZeroPoint;
+        *(output_c3_offset + 2) = kFilterZeroPoint;
+        *(output_c3_offset + 3) = kFilterZeroPoint;
       }
     }
   }
@@ -216,6 +196,7 @@
   const int groups = input_depth / filter_input_depth;
   TFLITE_DCHECK_NE(groups, 0);
   TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  TFLITE_DCHECK_EQ(filter_input_depth, 3);
   const int filters_per_group = output_depth / groups;
   TFLITE_DCHECK_NE(filters_per_group, 0);
   const int output_height = output_shape.Dims(1);
@@ -253,9 +234,7 @@
       ::aligned_alloc(32, swizzled_filter_data_size)));
   int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
 
-  PaddedFilter_N_H_W_M(filter_data, p_swizzled_filter_data, output_depth,
-                       filter_height, filter_width, filter_input_depth);
-
+  PaddedFilter(filter_data, p_swizzled_filter_data, output_depth);
   // structure of padded filter data : 1st row 0-8 channels 0-95 , 2nd row 0-8
   // channels 96-191, 3rd row 0-8 channels 192-287