Relax output_depth requirement for ConvS8D4

- ConvS8D4 can now handle output depths that are not a multiple of 8
  (a minimum output_depth of 8 is still required).
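
  Illustrative standalone sketch (not part of this change; output_depth = 20
  is an arbitrary example) of how the reworked output-channel loop covers a
  depth that is not a multiple of 8 in groups of at most 8 channels:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int output_depth = 20;  // example value only
      int out_channel = 0;
      do {
        // Same grouping as the rewritten ConvS8D4 channel loop.
        const int out_channels_this_iter =
            std::min(8, output_depth - out_channel);
        std::printf("channels [%d, %d)\n", out_channel,
                    out_channel + out_channels_this_iter);
        out_channel += out_channels_this_iter;
      } while (out_channel < output_depth);
      return 0;  // prints [0, 8), [8, 16), [16, 20)
    }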

Change-Id: Ibfc445e09f392453ddb8c038a5b224b8bcb33c4e
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 2d49dbc..ce56fda 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -198,7 +198,7 @@
   // special case of filter_depth = 4n
   if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
       stride_width <= 2 && stride_height <= 2 && filter_depth % 4 == 0 &&
-      output_depth % 8 == 0 && output_width >= 8 && pad_width <= 1) {
+      output_depth >= 8 && output_width >= 8 && pad_width <= 1) {
     fn = kelvin::opt::ConvS8D4;
   }
 
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 0dd3e50..18df3e7 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -29,11 +29,13 @@
 
 namespace kelvin::opt {
 namespace {
-void Filter_8_H_W_M(const int8_t* input, int8_t* output, int H, int W, int M) {
+
+void Filter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H, int W,
+                    int M) {
   const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
   int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
   assert(M >= 4);
-  for (int zo = 0; zo < 8; ++zo) {
+  for (int zo = 0; zo < N; ++zo) {
     for (int ky = 0; ky < H; ++ky) {
       for (int kx = 0; kx < W; ++kx) {
         for (int zi = 0; zi < M; ++zi) {
@@ -44,28 +46,35 @@
       }
     }
   }
-}
-
-void Swizzle(const int32_t* input, int32_t* output, int N) {
-  const int32_t(&in)[N] = *(int32_t(*)[N])input;
-  int32_t(&out)[N * 4] = *(int32_t(*)[N * 4]) output;
-  // Convert to accumulator swizzle pattern.
-  for (int i = 0; i < N / 8; ++i) {
-    int32_t* out0 = out + i * 32 + 0;
-    int32_t* out1 = out + i * 32 + 16;
-    int32_t* out2 = out + i * 32 + 8;
-    int32_t* out3 = out + i * 32 + 24;
-    for (int j = 0; j < 4; ++j) {
-      const int32_t* p_in = in + i * 8;
-      for (int k = 0; k < 2; ++k) {
-        *out0++ = *p_in++;
-        *out1++ = *p_in++;
-        *out2++ = *p_in++;
-        *out3++ = *p_in++;
+  // Zero out the rest of the output.
+  for (int zo = N; zo < 8; ++zo) {
+    for (int ky = 0; ky < H; ++ky) {
+      for (int kx = 0; kx < W; ++kx) {
+        for (int zi = 0; zi < M; ++zi) {
+          const int zi_hi = zi >> 2;  // div4
+          const int zi_lo = zi & 3;   // rem4
+          out[ky][kx][zi_hi][zo][zi_lo] = 0;
+        }
       }
     }
   }
 }
+
+void Swizzle(const int32_t* input, int32_t* output, int N) {
+  assert(N <= 8);
+  const int32_t(&in)[8] = *(int32_t(*)[8])input;
+  int32_t(&out)[32] = *(int32_t(*)[32]) output;
+  // Convert to accumulator swizzle pattern.
+  memset(out, 0, 32 * sizeof(int32_t));
+  int offsets[] = {0, 16, 8, 24, 1, 17, 9, 25};
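+  // Each of the N inputs is replicated into four lanes of the 32-word block:
+  // offsets[i], offsets[i] + 2, offsets[i] + 4 and offsets[i] + 6.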
+  for (int i = 0; i < N; ++i) {
+    int offset = offsets[i];
+    out[0 + offset] = out[2 + offset] = out[4 + offset] = out[6 + offset] = in[i];
+  }
+}
+
 }  // namespace
 
 void ConvS8D4(
@@ -137,14 +146,19 @@
   int32_t swizzled_mult_data[32];
   int32_t swizzled_shift_data[32];
 
-  for (int out_channel = 0; out_channel + 8 <= output_depth; out_channel += 8) {
-    Filter_8_H_W_M(filter_data + (out_channel * filter_height * filter_width *
+  int out_channel = 0;
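+  // Output channels are processed in groups of at most 8; the final group
+  // may be narrower when output_depth is not a multiple of 8.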
+  do {
+    int out_channels_this_iter = std::min(8, output_depth - out_channel);
+    Filter_N_H_W_M(filter_data + (out_channel * filter_height * filter_width *
                                   filter_input_depth),
-                   p_swizzled_filter_data, filter_height, filter_width,
+                   p_swizzled_filter_data, out_channels_this_iter,
+                   filter_height, filter_width,
                    filter_input_depth);
-    Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
-    Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
-    Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
+    Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+    Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+    Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
     vld_w_x_m(v16, swizzled_bias_data);
     vld_w_x_m(v20, swizzled_mult_data);
     vld_w_x_m(v24, swizzled_shift_data);
@@ -162,7 +176,7 @@
           acset_v(v48, v48);
           int in_channel = 0;
           do {
-            int channels_this_iter = std::min(filter_input_depth, 32);
+            int in_channels_this_iter = std::min(filter_input_depth, 32);
             for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
               const int in_y = in_y_origin + dilation_height_factor * filter_y;
               const bool is_row_inside_input =
@@ -320,11 +334,11 @@
                 vld_b_p_x_m(v8, p_local_filter_start);
                 vld_b_x_m(v12, p_local_filter_start);
 
-                cmds.conv.stop = (channels_this_iter / 4) - 1;
+                cmds.conv.stop = (in_channels_this_iter / 4) - 1;
                 aconv_vxv(v48, v0, cmds, v8);
               }
             }
-            in_channel += channels_this_iter;
+            in_channel += in_channels_this_iter;
           } while (in_channel < filter_input_depth);
           vcget(v48);
           vadd_w_vv_m(v48, v48, v16);
@@ -341,44 +355,72 @@
           vmax_w_vx_m(v52, v52, output_activation_min);
           vsraqs_b_vx(v56, v48, 0);
           vsraqs_b_vx(v57, v52, 0);
-          if (out_xs_this_iter >= 4) {
-            vstq_b_s_xx(v56,
-                        &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                    out_x, out_channel)],
-                        output_depth);
-          } else {
-            for (int i = 0; i < out_xs_this_iter; ++i) {
-              if (i > 0) {
-                vsliden_b_4_vv(v58, v56, v0);
-                vsliden_b_4_vv(v56, v58, v0);
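+          // Groups of 8 output channels can use the strided quad store; a
+          // partial final group stores out_channels_this_iter bytes per pixel.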
+          if (out_channels_this_iter == 8) {
+            if (out_xs_this_iter >= 4) {
+              vstq_b_s_xx(v56,
+                          &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                      out_x, out_channel)],
+                          output_depth);
+            } else {
+              for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
+                if (i > 0) {
+                  vsliden_b_4_vv(v58, v56, v0);
+                  vsliden_b_4_vv(v56, v58, v0);
+                }
+                vst_b_l_xx(v56,
+                          &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                      out_x + i, out_channel)],
+                          out_channels_this_iter);
               }
-              vst_b_l_xx(v56,
-                        &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                    out_x + i, out_channel)],
-                        8);
             }
-          }
-          if (out_xs_this_iter == 8) {
-            vstq_b_s_xx(v57,
-                        &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                    out_x + 4, out_channel)],
-                        output_depth);
-          } else if (out_xs_this_iter > 4) {
-            for (int i = 4; i < out_xs_this_iter; ++i) {
-              if (i > 4) {
-                vsliden_b_4_vv(v58, v57, v0);
-                vsliden_b_4_vv(v57, v58, v0);
+            if (out_xs_this_iter == 8) {
+              vstq_b_s_xx(v57,
+                          &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                      out_x + 4, out_channel)],
+                          output_depth);
+            } else if (out_xs_this_iter > 4) {
+              for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
+                if (i > 4) {
+                  vsliden_b_4_vv(v58, v57, v0);
+                  vsliden_b_4_vv(v57, v58, v0);
+                }
+                vst_b_l_xx(v57,
+                          &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                      out_x + i, out_channel)],
+                          out_channels_this_iter);
               }
-              vst_b_l_xx(v57,
-                        &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                    out_x + i, out_channel)],
-                        8);
+            }
+          } else {
+            for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
+              if (i > 0) {
+                vsliden_b_4_vv(v58, v56, v0);
+                vsliden_b_4_vv(v56, v58, v0);
+              }
+              vst_b_l_xx(v56,
+                        &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                    out_x + i, out_channel)],
+                        out_channels_this_iter);
+            }
+            if (out_xs_this_iter > 4) {
+              for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
+                if (i > 4) {
+                  vsliden_b_4_vv(v58, v57, v0);
+                  vsliden_b_4_vv(v57, v58, v0);
+                }
+                vst_b_l_xx(v57,
+                          &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                      out_x + i, out_channel)],
+                          out_channels_this_iter);
+              }
             }
           }
           out_x += out_xs_this_iter;
         } while (out_x < output_width);
       }
     }
-  }
+    out_channel += out_channels_this_iter;
+  } while (out_channel < output_depth);
 }
 }  // namespace kelvin::opt