Specialize 5x5 DepthwiseConv

- Use the adwconv/vdwconv accumulate instructions when stride == 1.
- Improve reuse of weights by loading the filter only once per
  group of 32 channels, outside the batch and spatial loops.

Change-Id: Idb865668f039450c314dbc5f3046203d9e621240
Bypass-Presubmit-Reason: Flaky test_nexus_boot_robot
diff --git a/crt/kelvin.h b/crt/kelvin.h
index a631593..10cc34f 100644
--- a/crt/kelvin.h
+++ b/crt/kelvin.h
@@ -62,4 +62,16 @@
 };
 static_assert(sizeof(struct vconv_u8_t) == 4);
 
+struct vdwconv_u8_t {
+  uint32_t mode:2;      // 1:0
+  uint32_t sparsity:2;  // 3:2
+  uint32_t regbase:4;   // 7:4
+  uint32_t rsvd:4;      // 11:8
+  int32_t sbias1:9;    // 20:12
+  uint32_t sdata1:1;    // 21
+  int32_t sbias2:9;    // 30:22
+  uint32_t sdata2:1;    // 31
+};
+static_assert(sizeof(struct vdwconv_u8_t) == 4);
+
 #endif  // CRT_KELVIN_H_
diff --git a/tests/kelvin_isa/vdwconv.cc b/tests/kelvin_isa/vdwconv.cc
index ec155dc..31e391f 100644
--- a/tests/kelvin_isa/vdwconv.cc
+++ b/tests/kelvin_isa/vdwconv.cc
@@ -20,18 +20,6 @@
 #include "tests/kelvin_isa/kelvin_test.h"
 #include "tests/kelvin_isa/vdwconv_data.h"
 
-struct vdwconv_u8_t {
-  uint32_t mode:2;      // 1:0
-  uint32_t sparsity:2;  // 3:2
-  uint32_t regbase:4;   // 7:4
-  uint32_t rsvd:4;      // 11:8
-  int32_t sbias1:9;    // 20:12
-  uint32_t sdata1:1;    // 21
-  int32_t sbias2:9;    // 30:22
-  uint32_t sdata2:1;    // 31
-};
-static_assert(sizeof(vdwconv_u8_t) == 4);
-
 #ifdef TEST_GEN
 static int32_t dwconv(const vdwconv_u8_t& cmd, uint8_t ina[3], uint8_t inb[3]) {
   int32_t sbias1 = cmd.sbias1;
diff --git a/tflm/opt/depthwise_conv_s8.cc b/tflm/opt/depthwise_conv_s8.cc
index 111fdc1..9e15b3c 100644
--- a/tflm/opt/depthwise_conv_s8.cc
+++ b/tflm/opt/depthwise_conv_s8.cc
@@ -38,8 +38,475 @@
     *out1 = *p_in++;
     *out2 = *p_in++;
     *out3 = *p_in++;
+  }
+}
+
+// Special case: input depth is a multiple of 32, 5x5 filter, stride == 1.
+void DepthwiseConvS85x5D32_Stride1(
+    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data
+) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_shift_multi[32];
+  int32_t swizzled_output_multi[32];
+
+  for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
+    const int output_channel = in_channel;
+    VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
+    VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
+    VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
+
+    union {
+      vdwconv_u8_t dwconv;
+      uint32_t raw;
+    } cmds;
+    cmds.raw = 0;
+    cmds.dwconv.sdata1 = true;
+    cmds.dwconv.sbias1 = input_offset;
+    cmds.dwconv.sdata2 = true;
+    cmds.dwconv.sbias2 = 0;
+    cmds.dwconv.mode = 0;
+    cmds.dwconv.sparsity = 0;
+    cmds.dwconv.regbase = 0;
+
+    // Don't reorder these loads: each vld_b_sp_xx advances p_flt0 by stride,
+    // so the order determines which filter row lands in which register.
+    const int8_t* p_flt0 = filter_data + in_channel;
+    const int32_t stride = input_depth;
+    vld_b_sp_xx_m(v0, p_flt0, stride);
+    vld_b_sp_xx_m(v4, p_flt0, stride);
+    vld_b_sp_xx_m(v8, p_flt0, stride);
+    vld_b_sp_xx_m(v12, p_flt0, stride);
+    vld_b_sp_xx_m(v16, p_flt0, stride);
+    vld_b_sp_xx_m(v20, p_flt0, stride);
+    vld_b_sp_xx(v24, p_flt0, stride);
+
+    // Zero two extra registers (v25, v26) so the register count fed to the
+    // dwconv sequence below is a multiple of 3 (operands are read in triples).
+    vdup_b_x(v25, 0);
+    vdup_b_x(v26, 0);
+
+    for (int batch = 0; batch < batches; ++batch) {
+      const int8_t* p_output = output_data + (batch * output_height * output_width * output_depth) + output_channel;
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        const int y_offset = out_y * output_width * output_depth;
+        for (int out_x = 0; out_x < output_width; ++out_x) {
+          const int in_x_origin = (out_x * stride_width) - pad_width;
+          const int in_y_origin = (out_y * stride_height) - pad_height;
+
+          bool top_pad = in_y_origin < 0;
+          bool left_pad = in_x_origin < 0;
+          int top_pad_count = top_pad ? 0 - in_y_origin : 0;
+          int left_pad_count = left_pad ? 0 - in_x_origin : 0;
+          bool bottom_pad = (in_y_origin + 4) >= input_height;
+          bool right_pad = (in_x_origin + 4) >= input_width;
+          int bottom_pad_count = std::abs(bottom_pad ? (in_y_origin + 4) - input_height + 1: 0);
+          int right_pad_count = std::abs(right_pad ? (in_x_origin + 4) - input_width + 1 : 0);
+          bool padding_required = top_pad || left_pad || bottom_pad || right_pad;
+          assert(top_pad_count <= pad_height);
+          assert(bottom_pad_count <= pad_height);
+          assert(left_pad_count <= pad_width);
+          assert(right_pad_count <= pad_width);
+          assert(!(left_pad && right_pad));
+          const int8_t* p_in_0 = input_data +
+            (batch * input_height * input_width * input_depth) +
+            (in_y_origin * input_width * input_depth) +
+            (in_x_origin * input_depth) +
+            in_channel;
+          const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
+          const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
+          const int8_t* p_in_3 = p_in_2 + (input_width * input_depth);
+          const int8_t* p_in_4 = p_in_3 + (input_width * input_depth);
+          // Extra two registers to get our
+          // total usage to a multiple of 3 for dwconv.
+          vdup_b_x(v52, -input_offset);
+          vdup_b_x(v53, -input_offset);
+          if (!padding_required) {
+            vld_b_sp_xx(v27, p_in_0, input_depth);
+            vld_b_sp_xx_m(v28, p_in_0, input_depth);
+            vld_b_sp_xx_m(v32, p_in_1, input_depth);
+            vld_b_sp_xx(v36, p_in_1, input_depth);
+            vld_b_sp_xx(v37, p_in_2, input_depth);
+            vld_b_sp_xx(v38, p_in_2, input_depth);
+            vld_b_sp_xx(v39, p_in_2, input_depth);
+            vld_b_sp_xx(v40, p_in_2, input_depth);
+            vld_b_sp_xx(v41, p_in_2, input_depth);
+            vld_b_sp_xx(v42, p_in_3, input_depth);
+            vld_b_sp_xx(v43, p_in_3, input_depth);
+            vld_b_sp_xx(v44, p_in_3, input_depth);
+            vld_b_sp_xx(v45, p_in_3, input_depth);
+            vld_b_sp_xx(v46, p_in_3, input_depth);
+            vld_b_sp_xx(v47, p_in_4, input_depth);
+            vld_b_sp_xx_m(v48, p_in_4, input_depth);
+          } else {
+            // Top row
+            if (top_pad_count >= 1) {
+              vdup_b_x(v27, -input_offset);
+              vdup_b_x_m(v28, -input_offset);
+            } else {
+              switch (left_pad_count) {
+                case 2:
+                  vdup_b_x(v28, -input_offset);
+                case 1:
+                  vdup_b_x(v27, -input_offset);
+              }
+              switch (left_pad_count) {
+                case 0:
+                  vld_b_x(v27, p_in_0);
+                case 1:
+                  vld_b_x(v28, p_in_0 + input_depth);
+              }
+              vld_b_x(v29, p_in_0 + (2 * input_depth));
+              switch (right_pad_count) {
+                case 2:
+                  vdup_b_x(v30, -input_offset);
+                case 1:
+                  vdup_b_x(v31, -input_offset);
+              }
+              switch (right_pad_count) {
+                case 0:
+                  vld_b_x(v31, p_in_0 + (4 * input_depth));
+                case 1:
+                  vld_b_x(v30, p_in_0 + (3 * input_depth));
+              }
+            }
+
+            // 2nd row
+            if (top_pad_count == 2) {
+              vdup_b_x_m(v32, -input_offset);
+              vdup_b_x(v36, -input_offset);
+            } else {
+              switch (left_pad_count) {
+                case 2:
+                  vdup_b_x(v33, -input_offset);
+                case 1:
+                  vdup_b_x(v32, -input_offset);
+              }
+              switch (left_pad_count) {
+                case 0:
+                  vld_b_x(v32, p_in_1);
+                case 1:
+                  vld_b_x(v33, p_in_1 + input_depth);
+              }
+              vld_b_x(v34, p_in_1 + (2 * input_depth));
+              switch (right_pad_count) {
+                case 2:
+                  vdup_b_x(v35, -input_offset);
+                case 1:
+                  vdup_b_x(v36, -input_offset);
+              }
+              switch (right_pad_count) {
+                case 0:
+                  vld_b_x(v36, p_in_1 + (4 * input_depth));
+                case 1:
+                  vld_b_x(v35, p_in_1 + (3 * input_depth));
+              }
+            }
+
+            // 3rd row
+            switch (left_pad_count) {
+              case 2:
+                vdup_b_x(v38, -input_offset);
+              case 1:
+                vdup_b_x(v37, -input_offset);
+            }
+            switch (left_pad_count) {
+              case 0:
+                vld_b_x(v37, p_in_2);
+              case 1:
+                vld_b_x(v38, p_in_2 + input_depth);
+            }
+            vld_b_x(v39, p_in_2 + (2 * input_depth));
+            switch (right_pad_count) {
+              case 2:
+                vdup_b_x(v40, -input_offset);
+              case 1:
+                vdup_b_x(v41, -input_offset);
+            }
+            switch (right_pad_count) {
+              case 0:
+                vld_b_x(v41, p_in_2 + (4 * input_depth));
+              case 1:
+                vld_b_x(v40, p_in_2 + (3 * input_depth));
+            }
+
+            // 4th row
+            if (bottom_pad_count == 2) {
+              vdup_b_x(v42, -input_offset);
+              vdup_b_x(v43, -input_offset);
+              vdup_b_x(v44, -input_offset);
+              vdup_b_x(v45, -input_offset);
+              vdup_b_x(v46, -input_offset);
+            } else {
+              switch (left_pad_count) {
+                case 2:
+                  vdup_b_x(v43, -input_offset);
+                case 1:
+                  vdup_b_x(v42, -input_offset);
+              }
+              switch (left_pad_count) {
+                case 0:
+                  vld_b_x(v42, p_in_3);
+                case 1:
+                  vld_b_x(v43, p_in_3 + input_depth);
+              }
+              switch (right_pad_count) {
+                case 2:
+                  vdup_b_x(v45, -input_offset);
+                case 1:
+                  vdup_b_x(v46, -input_offset);
+              }
+              vld_b_x(v44, p_in_3 + (2 * input_depth));
+              switch (right_pad_count) {
+                case 0:
+                  vld_b_x(v46, p_in_3 + (4 * input_depth));
+                case 1:
+                  vld_b_x(v45, p_in_3 + (3 * input_depth));
+              }
+            }
+
+            // 5th row
+            if (bottom_pad_count >= 1) {
+              vdup_b_x(v47, -input_offset);
+              vdup_b_x(v48, -input_offset);
+              vdup_b_x(v49, -input_offset);
+              vdup_b_x(v50, -input_offset);
+              vdup_b_x(v51, -input_offset);
+            } else {
+              switch (left_pad_count) {
+                case 2:
+                  vdup_b_x(v48, -input_offset);
+                case 1:
+                  vdup_b_x(v47, -input_offset);
+              }
+              switch (left_pad_count) {
+                case 0:
+                  vld_b_x(v47, p_in_4);
+                case 1:
+                  vld_b_x(v48, p_in_4 + input_depth);
+              }
+              vld_b_x(v49, p_in_4 + (2 * input_depth));
+              switch (right_pad_count) {
+                case 2:
+                  vdup_b_x(v50, -input_offset);
+                case 1:
+                  vdup_b_x(v51, -input_offset);
+              }
+              switch (right_pad_count) {
+                case 0:
+                  vld_b_x(v51, p_in_4 + (4 * input_depth));
+                case 1:
+                  vld_b_x(v50, p_in_4 + (3 * input_depth));
+              }
+            }
+          }
+
+          vld_w_x_m(v60, swizzled_bias_data);
+          adwinit_v(v60, v60);
+          adwconv_vxv(v60, v27, cmds, v0);
+          adwconv_vxv(v60, v30, cmds, v3);
+          adwconv_vxv(v60, v33, cmds, v6);
+          adwconv_vxv(v60, v36, cmds, v9);
+          adwconv_vxv(v60, v39, cmds, v12);
+          adwconv_vxv(v60, v42, cmds, v15);
+          adwconv_vxv(v60, v45, cmds, v18);
+          adwconv_vxv(v60, v48, cmds, v21);
+          vdwconv_vxv(v60, v51, cmds, v24);
+
+          vld_w_x_m(v56, swizzled_output_multi);
+          vdmulh_w_rn_vv_m(v60, v60, v56);
+          vld_w_x_m(v56, swizzled_shift_multi);
+          vrsub_w_vx_m(v56, v56, 0);
+          vsha_w_r_vv_m(v60, v60, v56);
+          vadd_w_vx_m(v60, v60, output_offset);
+          vmax_w_vx_m(v60, v60, output_activation_min);
+          vmin_w_vx_m(v60, v60, output_activation_max);
+          vsraqs_b_vx(v60, v60, 0);
+          vst_b_x(v60, p_output + y_offset + (out_x * output_depth));
+        }
+      }
     }
   }
+}
+
+// Special case: input depth is a multiple of 32, 5x5 filter, any stride.
+void DepthwiseConvS85x5D32(
+    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data
+) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_shift_multi[32];
+  int32_t swizzled_output_multi[32];
+
+  for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
+    const int output_channel = in_channel;
+    VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
+    VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
+    VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
+
+    vld_w_x_m(v52, swizzled_bias_data);
+    vld_w_x_m(v56, swizzled_output_multi);
+    vld_w_x_m(v60, swizzled_shift_multi);
+    vrsub_w_vx_m(v60, v60, 0);
+
+    // Don't reorder these loads: each vld_b_sp_xx advances p_flt, so the
+    // order determines which filter tap lands in which register.
+    const int8_t* p_flt = filter_data + in_channel;
+    vld_b_sp_xx(v6, p_flt, input_depth);
+    vld_b_sp_xx(v7, p_flt, input_depth);
+    vld_b_sp_xx_m(v8, p_flt, input_depth);
+    vld_b_sp_xx_m(v12, p_flt, input_depth);
+    vld_b_sp_xx_m(v16, p_flt, input_depth);
+    vld_b_sp_xx_m(v20, p_flt, input_depth);
+    vld_b_sp_xx_m(v24, p_flt, input_depth);
+    vld_b_sp_xx(v28, p_flt, input_depth);
+    vld_b_sp_xx(v29, p_flt, input_depth);
+    vld_b_sp_xx(v30, p_flt, input_depth);
+
+
+    for (int batch = 0; batch < batches; ++batch) {
+      const int8_t* p_input = input_data + (batch * input_width * input_height * input_depth) + in_channel;
+      const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel;
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        const int out_y_offset = (out_y * output_width * output_depth);
+        for (int out_x = 0; out_x < output_width; ++out_x) {
+          const int in_x_origin = (out_x * stride_width) - pad_width;
+          const int in_y_origin = (out_y * stride_height) - pad_height;
+
+          // Initialize accumulators w/ bias_data
+          vmv_v_m(v48, v52);
+
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + filter_y;
+            if ((in_y < 0) || (in_y >= input_height)) {
+              continue;
+            }
+            switch (filter_y) {
+              case 0:
+                vaddw_h_vx(v31, v6, 0);
+                vaddw_h_vx(v33, v7, 0);
+                vaddw_h_vx(v35, v8, 0);
+                vaddw_h_vx(v37, v9, 0);
+                vaddw_h_vx(v39, v10, 0);
+                break;
+              case 1:
+                vaddw_h_vx(v31, v11, 0);
+                vaddw_h_vx(v33, v12, 0);
+                vaddw_h_vx(v35, v13, 0);
+                vaddw_h_vx(v37, v14, 0);
+                vaddw_h_vx(v39, v15, 0);
+                break;
+              case 2:
+                vaddw_h_vx(v31, v16, 0);
+                vaddw_h_vx(v33, v17, 0);
+                vaddw_h_vx(v35, v18, 0);
+                vaddw_h_vx(v37, v19, 0);
+                vaddw_h_vx(v39, v20, 0);
+                break;
+              case 3:
+                vaddw_h_vx(v31, v21, 0);
+                vaddw_h_vx(v33, v22, 0);
+                vaddw_h_vx(v35, v23, 0);
+                vaddw_h_vx(v37, v24, 0);
+                vaddw_h_vx(v39, v25, 0);
+                break;
+              case 4:
+                vaddw_h_vx(v31, v26, 0);
+                vaddw_h_vx(v33, v27, 0);
+                vaddw_h_vx(v35, v28, 0);
+                vaddw_h_vx(v37, v29, 0);
+                vaddw_h_vx(v39, v30, 0);
+                break;
+            }
+            const int in_y_offset = in_y  * input_width * input_depth;
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + filter_x;
+              if ((in_x < 0) || (in_x >= input_width)) {
+                continue;
+              }
+
+              vld_b_x(v0, p_input + (in_x * input_depth) + in_y_offset);
+
+              vaddw_h_vx(v0, v0, 0);
+              vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset));
+              vadd_h_vx(v1, v1,
+                        static_cast<int16_t>(input_offset));  // v0 v1 input
+              switch (filter_x) {
+                case 0:
+                  vmulw_w_vv(v2, v1, v32);
+                  vmulw_w_vv(v0, v0, v31);
+                  break;
+                case 1:
+                  vmulw_w_vv(v2, v1, v34);
+                  vmulw_w_vv(v0, v0, v33);
+                  break;
+                case 2:
+                  vmulw_w_vv(v2, v1, v36);
+                  vmulw_w_vv(v0, v0, v35);
+                  break;
+                case 3:
+                  vmulw_w_vv(v2, v1, v38);
+                  vmulw_w_vv(v0, v0, v37);
+                  break;
+                case 4:
+                  vmulw_w_vv(v2, v1, v40);
+                  vmulw_w_vv(v0, v0, v39);
+                  break;
+              }
+              vadd_w_vv_m(v48, v48, v0);
+            }
+          }
+
+          vdmulh_w_rn_vv_m(v48, v48, v56);
+          vsha_w_r_vv_m(v48, v48, v60);
+          vadd_w_vx_m(v48, v48, output_offset);
+          vmax_w_vx_m(v48, v48, output_activation_min);
+          vmin_w_vx_m(v48, v48, output_activation_max);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output + out_y_offset + (out_x * output_depth));
+        }
+      }
+    }
+  }
+}
 
 // special case of input depth = 32n
 void DepthwiseConvS8D32(
@@ -161,6 +628,8 @@
   // TODO(b/141565753): Re-introduce ScopedProfilingLabel on Micro.
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
   const int dilation_width_factor = params.dilation_width_factor;
   const int dilation_height_factor = params.dilation_height_factor;
   const int depth_multiplier = params.depth_multiplier;
@@ -186,7 +655,15 @@
 
     // special case of output depth = 32n
     if (output_depth % 32 == 0) {
-      fn = DepthwiseConvS8D32;
+      if (filter_width == 5 && filter_height == 5) {
+        if (stride_width <= 1 && stride_height <= 1) {
+          fn = DepthwiseConvS85x5D32_Stride1;
+        } else {
+          fn = DepthwiseConvS85x5D32;
+        }
+      } else {
+        fn = DepthwiseConvS8D32;
+      }
     }
 
     fn(params, output_multiplier, output_shift, input_shape, input_data,