Specialize 3x3, stride 1 DepthwiseConv

- Specialized variant of 3x3 DepthwiseConv for inputs with stride of 1,
  in the hottest loop this computes two outputs per iteration.

Change-Id: Iad93a69069e09c83b321ab36245e6dafe6034871
diff --git a/tflm/opt/depthwise_conv_s8.cc b/tflm/opt/depthwise_conv_s8.cc
index a130324..1b450b5 100644
--- a/tflm/opt/depthwise_conv_s8.cc
+++ b/tflm/opt/depthwise_conv_s8.cc
@@ -41,6 +41,427 @@
     *out3 = *p_in++;
   }
 }
+
+// special case of input depth = 32n, filter shape of 3x3, strides of 1
+void DepthwiseConvS83x3D32_Stride1(
+    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data
+) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_shift_multi[32];
+  int32_t swizzled_output_multi[32];
+
+  for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
+    const int output_channel = in_channel;
+    int8_t* p_output = output_data + output_channel;
+    VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
+    VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
+    VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
+
+    vld_w_x_m(v52, swizzled_bias_data);
+    vld_w_x_m(v56, swizzled_output_multi);
+    vld_w_x_m(v60, swizzled_shift_multi);
+    vrsub_w_vx_m(v60, v60, 0);
+
+    union {
+      vdwconv_u8_t dwconv;
+      uint32_t raw;
+    } cmds;
+    cmds.raw = 0;
+    cmds.dwconv.sdata1 = true;
+    cmds.dwconv.sbias1 = input_offset;
+    cmds.dwconv.sdata2 = true;
+    cmds.dwconv.sbias2 = 0;
+    cmds.dwconv.mode = 0;
+    cmds.dwconv.sparsity = 0;
+    cmds.dwconv.regbase = 0;
+
+    // Don't reorder me, otherwise data will not be
+    // loaded in the correct order
+    // (we can reuse the p_flt* due to the `p` vld variant).
+    const int8_t* p_flt0 = filter_data + in_channel;
+    const int8_t* p_flt1 = p_flt0 + input_depth;
+    const int32_t stride = 2 * input_depth;
+    vld_b_sp_xx(v6, p_flt0, stride);
+    vld_b_sp_xx(v7, p_flt1, stride);
+    vld_b_sp_xx(v8, p_flt0, stride);
+    vld_b_sp_xx(v9, p_flt1, stride);
+    vld_b_sp_xx(v10, p_flt0, stride);
+    vld_b_sp_xx(v11, p_flt1, stride);
+    vld_b_sp_xx(v12, p_flt0, stride);
+    vld_b_sp_xx(v13, p_flt1, stride);
+    vld_b_sp_xx(v14, p_flt0, stride);
+
+    for (int batch = 0; batch < batches; ++batch) {
+      int out_y = 0;
+      for (; out_y < pad_height; ++out_y) {
+        int out_x = 0;
+        const int in_y_origin = (out_y * stride_height) - pad_height;
+        assert(in_y_origin < 0);
+        vdup_b_x(v15, -input_offset);
+        vdup_b_x(v16, -input_offset);
+        vdup_b_x(v17, -input_offset);
+        const int8_t* p_in_0 = input_data +
+            (batch * input_height * input_width * input_depth) +
+            (in_y_origin * input_width * input_depth) +
+            (((out_x * stride_width) - pad_width) * input_depth) +
+            in_channel;
+        const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
+        const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
+        for (; out_x < pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vdup_b_x(v18, -input_offset);
+          p_in_1 += input_depth;
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+          vdup_b_x(v21, -input_offset);
+          p_in_2 += input_depth;
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+          vld_b_sp_xx(v23, p_in_2, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width - pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+          vld_b_sp_xx(v21, p_in_2, input_depth);
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+          vld_b_sp_xx(v23, p_in_2, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vdup_b_x(v20, -input_offset);
+          vld_b_sp_xx(v21, p_in_2, input_depth);
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+          vdup_b_x(v23, -input_offset);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+      }
+      for (; out_y < output_height - pad_height; ++out_y) {
+        const int in_y_origin = (out_y * stride_height) - pad_height;
+        int out_x = 0;
+        const int8_t* p_in_0 = input_data +
+            (batch * input_height * input_width * input_depth) +
+            (in_y_origin * input_width * input_depth) +
+            (((out_x * stride_width) - pad_width) * input_depth) +
+            in_channel;
+        const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
+        const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
+        for (; out_x < pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vdup_b_x(v15, -input_offset);
+          vdup_b_x(v18, -input_offset);
+          vdup_b_x(v21, -input_offset);
+          p_in_0 += input_depth;
+          p_in_1 += input_depth;
+          p_in_2 += input_depth;
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vld_b_sp_xx(v17, p_in_0, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+          vld_b_sp_xx(v23, p_in_2, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x + 2 <= output_width - pad_width; out_x += 2) {
+          // Initialize accumulators w/ bias data.
+          vmv_v_m(v44, v52);
+          vmv_v_m(v48, v52);
+
+          vld_b_sp_xx(v15, p_in_0, stride_width * input_depth);
+          vld_b_sp_xx(v16, p_in_0, stride_width * input_depth);
+          vld_b_sp_xx(v17, p_in_0, stride_width * input_depth);
+          vld_b_sp_xx(v18, p_in_0, stride_width * input_depth);
+          vld_b_sp_xx(v19, p_in_1, stride_width * input_depth);
+          vld_b_sp_xx(v20, p_in_1, stride_width * input_depth);
+          vld_b_sp_xx(v21, p_in_1, stride_width * input_depth);
+          vld_b_sp_xx(v22, p_in_1, stride_width * input_depth);
+          vld_b_sp_xx(v23, p_in_2, stride_width * input_depth);
+          vld_b_sp_xx(v24, p_in_2, stride_width * input_depth);
+          vld_b_sp_xx(v25, p_in_2, stride_width * input_depth);
+          vld_b_sp_xx(v26, p_in_2, stride_width * input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v19, cmds, v9);
+          vdwconv_vxv(v48, v23, cmds, v12);
+
+          adwinit_v(v44, v44);
+          adwconv_vxv(v44, v16, cmds, v6);
+          adwconv_vxv(v44, v20, cmds, v9);
+          vdwconv_vxv(v44, v24, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v44, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vsraqs_b_vx(v44, v44, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          vst_b_x(v44, p_output);
+          p_output += output_depth;
+
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width - pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vld_b_sp_xx(v15, p_in_0, input_depth);
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vld_b_sp_xx(v17, p_in_0, input_depth);
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+          vld_b_sp_xx(v21, p_in_2, input_depth);
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+          vld_b_sp_xx(v23, p_in_2, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vdup_b_x(v17, -input_offset);
+          vdup_b_x(v20, -input_offset);
+          vdup_b_x(v23, -input_offset);
+          vld_b_sp_xx(v15, p_in_0, input_depth);
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v21, p_in_2, input_depth);
+          vld_b_sp_xx(v22, p_in_2, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+          p_in_2 -= (2 * stride_width * input_depth);
+        }
+      }
+      for (; out_y < output_height; ++out_y) {
+        const int in_y_origin = (out_y * stride_height) - pad_height;
+        assert(in_y_origin + 2 >= input_height);
+        vdup_b_x(v21, -input_offset);
+        vdup_b_x(v22, -input_offset);
+        vdup_b_x(v23, -input_offset);
+        int out_x = 0;
+        const int8_t* p_in_0 = input_data +
+            (batch * input_height * input_width * input_depth) +
+            (in_y_origin * input_width * input_depth) +
+            (((out_x * stride_width) - pad_width) * input_depth) +
+            in_channel;
+        const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
+        for (; out_x < pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vdup_b_x(v15, -input_offset);
+          p_in_0 += input_depth;
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vld_b_sp_xx(v17, p_in_0, input_depth);
+          vdup_b_x(v18, -input_offset);
+          p_in_1 += input_depth;
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width - pad_width; ++out_x) {
+          vmv_v_m(v48, v52);
+          vld_b_sp_xx(v15, p_in_0, input_depth);
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vld_b_sp_xx(v17, p_in_0, input_depth);
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vld_b_sp_xx(v20, p_in_1, input_depth);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+        }
+        for (; out_x < output_width; ++out_x) {
+          vmv_v_m(v48, v52);
+
+          vld_b_sp_xx(v15, p_in_0, input_depth);
+          vld_b_sp_xx(v16, p_in_0, input_depth);
+          vdup_b_x(v17, -input_offset);
+          vld_b_sp_xx(v18, p_in_1, input_depth);
+          vld_b_sp_xx(v19, p_in_1, input_depth);
+          vdup_b_x(v20, -input_offset);
+
+          adwinit_v(v48, v48);
+          adwconv_vxv(v48, v15, cmds, v6);
+          adwconv_vxv(v48, v18, cmds, v9);
+          vdwconv_vxv(v48, v21, cmds, v12);
+
+          INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+              v48, v56, v60,
+              output_activation_min,
+              output_activation_max,
+              output_offset);
+          vsraqs_b_vx(v48, v48, 0);
+          vst_b_x(v48, p_output);
+          p_output += output_depth;
+          p_in_0 -= (2 * stride_width * input_depth);
+          p_in_1 -= (2 * stride_width * input_depth);
+        }
+      }
+    }
+  }
+}
+
 // special case of input depth = 32n, filter shape of 3x3
 void DepthwiseConvS83x3D32(
     const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
@@ -829,6 +1250,8 @@
         } else {
           fn = DepthwiseConvS85x5D32;
         }
+      } else if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1 && stride_width == 1 && stride_height == 1) {
+        fn = DepthwiseConvS83x3D32_Stride1;
       } else if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) {
         fn = DepthwiseConvS83x3D32;
       } else {