Improve ConvS8D4

- Replace tflite::Offset calls with precomputed input/output pointers
- Lift some invariants out of loops
- Add loop unroll pragmas
- Replace nested output-stage conditionals with a fall-through switch

Change-Id: Iabf35537ef3cdba3da4375d8b7473fecd1fa38f2
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 18df3e7..54c20bd 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -27,6 +27,9 @@
 #include "tensorflow/lite/kernels/internal/runtime_shape.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
+#define unlikely(x) (__builtin_expect(false || (x), false))
+#define likely(x) (__builtin_expect(false || (x), true))
+
 namespace kelvin::opt {
 namespace {
 
@@ -159,8 +162,12 @@
     vrsub_w_vx_m(v24, v24, 0);
 
     for (int batch = 0; batch < batches; ++batch) {
+      const int8_t* p_output =
+          output_data + (batch * output_height * output_width * output_depth) +
+          out_channel;
       for (int out_y = 0; out_y < output_height; ++out_y) {
         const int in_y_origin = (out_y * stride_height) - pad_height;
+        const int out_y_offset = (out_y * output_width * output_depth);
         int out_x = 0;
         do {
           int out_xs_this_iter = std::min(8, output_width - out_x);
@@ -171,55 +178,59 @@
           int in_channel = 0;
           do {
             int in_channels_this_iter = std::min(filter_input_depth, 32);
-            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-              const int in_y = in_y_origin + dilation_height_factor * filter_y;
-              const bool is_row_inside_input =
-                  (in_y >= 0) && (in_y < input_height);
-              if (!is_row_inside_input) {
-                continue;
+            // Calculate first valid filter_y
+            int filter_y = 0;
+            {
+              int in_y = in_y_origin;
+              while (in_y < 0) {
+                ++filter_y;
+                in_y += (dilation_height_factor);
               }
+            }
+            for (; filter_y < filter_height; ++filter_y) {
+              const int y_filter_offset =
+                  (filter_y * filter_width * 8 * input_depth);
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              if (in_y >= input_height) {
+                break;
+              }
+              const int8_t* p_in =
+                  input_data + in_channel + (in_y * input_width * input_depth) +
+                  (batch * input_height * input_width * input_depth);
 
+              int in_x[8];
+#pragma GCC unroll 8
+              for (int i = 0; i < 8; ++i) {
+                in_x[i] = ((out_x + i) * stride_width) - pad_width;
+              }
               for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-                int in_x[8];
-                bool right_pad = false;
+                const int8_t* p_in_x[8];
                 int first_right_pad = -1;
+
+#pragma GCC unroll 8
                 for (int i = 0; i < 8; ++i) {
-                  const int in_x_origin =
-                      ((out_x + i) * stride_width) - pad_width;
-                  in_x[i] = in_x_origin + dilation_width_factor * filter_x;
+                  p_in_x[i] = p_in + (in_x[i] * input_depth);
                 }
-                bool left_pad = (in_x[0] < 0);
+
+#pragma GCC unroll 8
                 for (int i = 7; i >= 0; --i) {
                   if (in_x[i] < input_width) {
                     break;
                   }
-                  right_pad = true;
                   first_right_pad = i;
                 }
+                bool left_pad = (in_x[0] < 0);
+                bool right_pad = (first_right_pad != -1);
 
-                if (left_pad) {
+                int stride = input_depth * stride_width;
+
+                if (unlikely(left_pad)) {
                   vdup_b_x(v0, -input_offset);
-                  vld_b_s_xx(
-                      v1,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[1], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx(
-                      v2,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[2], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx(
-                      v3,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[3], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx_m(
-                      v4,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[4], in_channel)],
-                      input_depth * stride_width);
-                } else if (right_pad) {
+                  vld_b_s_xx(v1, p_in_x[1], stride);
+                  vld_b_s_xx(v2, p_in_x[2], stride);
+                  vld_b_s_xx(v3, p_in_x[3], stride);
+                  vld_b_s_xx_m(v4, p_in_x[4], stride);
+                } else if (unlikely(right_pad)) {
                   int first_pad = std::min(first_right_pad, out_xs_this_iter);
                   switch (first_pad) {
                     case 0:
@@ -241,88 +252,36 @@
                   }
                   switch (8 - first_pad) { // rest (stripmines?)
                     case 0:
-                      vld_b_s_xx(
-                          v7,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[7], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v7, p_in_x[7], stride);
                     case 1:
-                      vld_b_s_xx(
-                          v6,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[6], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v6, p_in_x[6], stride);
                     case 2:
-                      vld_b_s_xx(
-                          v5,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[5], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v5, p_in_x[5], stride);
                     case 3:
-                      vld_b_s_xx(
-                          v4,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[4], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v4, p_in_x[4], stride);
                     case 4:
-                      vld_b_s_xx(
-                          v3,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[3], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v3, p_in_x[3], stride);
                     case 5:
-                      vld_b_s_xx(
-                          v2,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[2], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v2, p_in_x[2], stride);
                     case 6:
-                      vld_b_s_xx(
-                          v1,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[1], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v1, p_in_x[1], stride);
                     case 7:
-                      vld_b_s_xx(
-                          v0,
-                          &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                     in_x[0], in_channel)],
-                          input_depth * stride_width);
+                      vld_b_s_xx(v0, p_in_x[0], stride);
                   }
-                } else if (!left_pad && !right_pad) {
+                } else if (likely(!left_pad && !right_pad)) {
                   // Inputs
-                  vld_b_s_xx_m(
-                      v0,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[0], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx_m(
-                      v4,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[4], in_channel)],
-                      input_depth * stride_width);
+                  vld_b_s_xx_m(v0, p_in_x[0], stride);
+                  vld_b_s_xx_m(v4, p_in_x[4], stride);
                 } else {
-                  vdup_b_x(v0, -input_offset);
-                  vdup_b_x(v7, -input_offset);
-                  vld_b_s_xx_m(
-                      v1,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[1], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx(
-                      v5,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[5], in_channel)],
-                      input_depth * stride_width);
-                  vld_b_s_xx(
-                      v6,
-                      &input_data[tflite::Offset(input_shape, batch, in_y,
-                                                 in_x[6], in_channel)],
-                      input_depth * stride_width);
+                  vdup_b_x(v0, neg_input_offset);
+                  vdup_b_x(v7, neg_input_offset);
+                  vld_b_s_xx_m(v1, p_in_x[1], stride);
+                  vld_b_s_xx(v5, p_in_x[5], stride);
+                  vld_b_s_xx(v6, p_in_x[6], stride);
                 }
-                size_t local_filter_offset =
-                    (filter_y * filter_width * 8 * input_depth) +
-                    (filter_x * 8 * input_depth) + (in_channel * 8);
+                size_t local_filter_offset = y_filter_offset +
+                                             (filter_x * 8 * input_depth) +
+                                             (in_channel * 8);
                 int8_t* p_local_filter_start =
                     p_swizzled_filter_data + local_filter_offset;
                 vld_b_p_x_m(v8, p_local_filter_start);
@@ -330,6 +289,11 @@
 
                 cmds.conv.stop = (in_channels_this_iter / 4) - 1;
                 aconv_vxv(v48, v0, cmds, v8);
+
+#pragma GCC unroll 8
+                for (int i = 0; i < 8; ++i) {
+                  in_x[i] += dilation_width_factor;
+                }
               }
             }
             in_channel += in_channels_this_iter;
@@ -349,64 +313,38 @@
           vmax_w_vx_m(v52, v52, output_activation_min);
           vsraqs_b_vx(v56, v48, 0);
           vsraqs_b_vx(v57, v52, 0);
-          if (out_channels_this_iter == 8) {
-            if (out_xs_this_iter >= 4) {
-              vstq_b_s_xx(v56,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x, out_channel)],
-                          output_depth);
-            } else {
-              for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
-                if (i > 0) {
-                  vsliden_b_4_vv(v58, v56, v0);
-                  vsliden_b_4_vv(v56, v58, v0);
-                }
-                vst_b_l_xx(v56,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x + i, out_channel)],
-                          out_channels_this_iter);
-              }
-            }
-            if (out_xs_this_iter == 8) {
-              vstq_b_s_xx(v57,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x + 4, out_channel)],
-                          output_depth);
-            } else if (out_xs_this_iter > 4) {
-              for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
-                if (i > 4) {
-                  vsliden_b_4_vv(v58, v57, v0);
-                  vsliden_b_4_vv(v57, v58, v0);
-                }
-                vst_b_l_xx(v57,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x + i, out_channel)],
-                          out_channels_this_iter);
-              }
-            }
-          } else {
-              for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
-                if (i > 0) {
-                  vsliden_b_4_vv(v58, v56, v0);
-                  vsliden_b_4_vv(v56, v58, v0);
-                }
-                vst_b_l_xx(v56,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x + i, out_channel)],
-                          out_channels_this_iter);
-              }
-            if (out_xs_this_iter > 4) {
-              for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
-                if (i > 4) {
-                  vsliden_b_4_vv(v58, v57, v0);
-                  vsliden_b_4_vv(v57, v58, v0);
-                }
-                vst_b_l_xx(v57,
-                          &output_data[tflite::Offset(output_shape, batch, out_y,
-                                                      out_x + i, out_channel)],
-                          out_channels_this_iter);
-              }
-            }
+
+          const int8_t* p_out_x[8];
+#pragma GCC unroll 8
+          for (int i = 0; i < 8; ++i) {
+            p_out_x[i] = p_output + out_y_offset + ((out_x + i) * output_depth);
+          }
+
+          vslidep_h_4_vv(v58, v57, v57);  // x7
+          vslidep_h_4_vv(v59, v58, v58);  // x6
+          vslidep_h_4_vv(v60, v59, v59);  // x5
+          vslidep_h_4_vv(v61, v60, v60);  // x4
+          vslidep_h_4_vv(v62, v56, v56);  // x3
+          vslidep_h_4_vv(v63, v62, v62);  // x2
+          vslidep_h_4_vv(v57, v63, v63);  // x1
+          vslidep_h_4_vv(v56, v57, v57);  // x0
+          switch (out_xs_this_iter) {
+            case 8:
+              vst_b_l_xx(v58, p_out_x[7], out_channels_this_iter);
+            case 7:
+              vst_b_l_xx(v59, p_out_x[6], out_channels_this_iter);
+            case 6:
+              vst_b_l_xx(v60, p_out_x[5], out_channels_this_iter);
+            case 5:
+              vst_b_l_xx(v61, p_out_x[4], out_channels_this_iter);
+            case 4:
+              vst_b_l_xx(v62, p_out_x[3], out_channels_this_iter);
+            case 3:
+              vst_b_l_xx(v63, p_out_x[2], out_channels_this_iter);
+            case 2:
+              vst_b_l_xx(v57, p_out_x[1], out_channels_this_iter);
+            case 1:
+              vst_b_l_xx(v56, p_out_x[0], out_channels_this_iter);
           }
           out_x += out_xs_this_iter;
         } while (out_x < output_width);