Improve Conv2dD32

- Specialize Conv2D for an exact input depth of 32, improving performance
  on hardware by ~40%

Change-Id: Ic4f193d2c9104b91abd11ddcba6616b1d5bd0710
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 0dc432b..7842b0d 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -184,6 +184,8 @@
   const auto dilation_height_factor = params.dilation_height_factor;
   const auto pad_width = params.padding_values.width;
   const auto pad_height = params.padding_values.height;
+  const auto input_batch = input_shape.Dims(0);
+  const auto input_height = input_shape.Dims(1);
   const auto input_width = input_shape.Dims(2);
   const auto input_depth = input_shape.Dims(3);
   const auto filter_height = filter_shape.Dims(1);
@@ -205,13 +207,18 @@
       stride_width == 1 && dilation_height_factor == 1 &&
       dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
       (input_depth == filter_depth)) {
+    // D32 kernel writes 8 output channels and 32 pixels per main-loop pass.
+    if ((output_depth % 8) == 0 && (input_depth == 32) && (input_batch * input_height * input_width) >= 32) {
      RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
    }
 
+    if ((output_depth % 8) == 0 && (input_depth % 32) == 0) {
+      RUN_KERNEL(kelvin::opt::ConvS8K1x1DMod32);
+    }
+
     // TODO: Relax this kernel for all output_depths
     if ((output_depth < 8) && (input_depth % 32) == 0) {
-      RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
+      RUN_KERNEL(kelvin::opt::ConvS8K1x1DMod32);
     }
diff --git a/tflm/opt/conv_s8.h b/tflm/opt/conv_s8.h
index b79bd65..96cea2e 100644
--- a/tflm/opt/conv_s8.h
+++ b/tflm/opt/conv_s8.h
@@ -23,6 +23,17 @@
 namespace kelvin::opt {
 
 // filter 1x1 d%32==0
+void ConvS8K1x1DMod32(const tflite::ConvParams& params,
+                   const int32_t* output_multiplier, const int32_t* output_shift,
+                   const tflite::RuntimeShape& input_shape,
+                   const int8_t* input_data,
+                   const tflite::RuntimeShape& filter_shape,
+                   const int8_t* filter_data,
+                   const tflite::RuntimeShape& bias_shape,
+                   const int32_t* bias_data,
+                   const tflite::RuntimeShape& output_shape, int8_t* output_data);
+
+// filter 1x1, input depth exactly 32 (specialized fast path)
 void ConvS8K1x1D32(const tflite::ConvParams& params,
                    const int32_t* output_multiplier, const int32_t* output_shift,
                    const tflite::RuntimeShape& input_shape,
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 18046e9..600f765 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -22,7 +22,7 @@
 
 namespace kelvin::opt {
 
-void ConvS8K1x1D32(
+void ConvS8K1x1DMod32(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
     const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
@@ -138,8 +138,8 @@
         }
 
         int8_t* p_local_filter = p_swizzled_filter_data + (in_channel * 8);
-        vld_b_p_x_m(v8, p_local_filter);
-        vld_b_x_m(v12, p_local_filter);
+        vld_b_x_m(v8, p_local_filter);
+        vld_b_x_m(v12, p_local_filter + (4 * 32));
         aconv_vxv(v48, v0, cmds, v8);
       }
 
@@ -152,13 +152,13 @@
 
       int i = 0;
       for (; i < std::min(4, out_this_iter); i++) {
-        vst_b_l_xx(v48, p_out, out_channels_this_iter);
-        p_out += output_depth;
+        vst_b_l_xx(v48, p_out + (i * output_depth), out_channels_this_iter);
+        // Output address is derived from i; no pointer increment needed.
         vsliden_h_4_vv(v48, v48, v48);
       }
       for (; i < out_this_iter; i++) {
-        vst_b_l_xx(v52, p_out, out_channels_this_iter);
-        p_out += output_depth;
+        vst_b_l_xx(v52, p_out + (i * output_depth), out_channels_this_iter);
+        // Output address is derived from i; no pointer increment needed.
         vsliden_h_4_vv(v52, v52, v52);
       }
     }
@@ -167,6 +167,261 @@
   } while (out_channel < output_depth);
 }
 
+void ConvS8K1x1D32(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  assert(input_depth == 32);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int filter_input_depth = filter_shape.Dims(3);
+  assert(filter_input_depth == 32);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+
+  union {
+    vconv_u8_t conv;
+    uint32_t raw;
+  } cmds;
+  cmds.conv.mode = 0;
+  cmds.conv.start = 0;
+  cmds.conv.stop = 7;
+  cmds.conv.sbias1 = input_offset;
+  cmds.conv.sdata1 = true;
+  cmds.conv.sbias2 = 0;
+  cmds.conv.sdata2 = true;
+  const size_t swizzled_filter_data_size =
+      8 * filter_input_depth;
+  std::unique_ptr<int8_t, decltype(&::free)> swizzled_filter_data(
+      reinterpret_cast<int8_t*>(::aligned_alloc(32, swizzled_filter_data_size)), ::free);
+  int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_mult_data[32];
+  int32_t swizzled_shift_data[32];
+
+// v0-v7 INPUT0
+// v8-v15 FLT
+// v16-v23 INPUT1
+// v24-v31 ACCB
+// v32-v43 unused
+// v44-47 BIAS
+// v48-v55 aconv ACC
+// v56-v59 MULT
+// v60-v63 SHFT
+#define INPUT0_0 v0
+#define INPUT0_1 v4
+#define FLT0_0 v8
+#define FLT0_1 v12
+#define INPUT1_0 v16
+#define INPUT1_1 v20
+#define ACCB0 v24
+#define ACCB1 v25
+#define ACCB2 v26
+#define ACCB3 v27
+#define ACCB4 v28
+#define ACCB5 v29
+#define ACCB6 v30
+#define ACCB7 v31
+#define BIAS0 v44
+#define ACC0 v48
+#define ACC1 v52
+#define MULT0 v56
+#define SHFT0 v60
+
+  const int n_elems = (output_width * batches * output_height);  // NOTE(review): main loop assumes n_elems >= 32 — confirm dispatch guarantees this
+  int out_channel = 0;
+  do { // out_channel
+    int out_channels_this_iter = std::min(8, output_depth - out_channel);
+    assert(out_channels_this_iter == 8);
+    Filter_N_H_W_M(filter_data + (out_channel * filter_input_depth),
+                   p_swizzled_filter_data, out_channels_this_iter, 1, 1,
+                   filter_input_depth);
+    if (bias_data) {
+      Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+      vld_w_x_m(BIAS0, swizzled_bias_data);
+    } else {
+      vdup_w_x_m(BIAS0, 0);
+    }
+    Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+    Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
+
+    vld_w_x_m(MULT0, swizzled_mult_data);
+    vld_w_x_m(SHFT0, swizzled_shift_data);
+    vrsub_w_vx_m(SHFT0, SHFT0, 0);
+
+    int out = 0;
+    do {
+      const int8_t* p_in = input_data + (out * input_depth);
+      int8_t* p_out = output_data + (out * output_depth) + out_channel;
+
+      // 8x accumulators
+      vmv_v_m(ACC0, BIAS0);
+      vmv_v_m(ACC1, BIAS0);
+
+      acset_v(ACC0, ACC0);
+      const int8_t* p_input = p_in;
+      // Inputs
+      vld_b_s_xx_m(INPUT0_0, p_input, input_depth);
+      vld_b_s_xx_m(INPUT0_1, p_input + (4 * input_depth), input_depth);
+
+      int8_t* p_local_filter = p_swizzled_filter_data;
+      vld_b_x_m(FLT0_0, p_local_filter);
+      vld_b_x_m(FLT0_1, p_local_filter + (4 * 32));
+
+      aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+      vld_b_s_xx_m(INPUT1_0, p_input + (8 * input_depth), input_depth);
+      vld_b_s_xx_m(INPUT1_1, p_input + (12 * input_depth), input_depth);
+
+      vcget(ACC0);
+      vmv_v_m(INPUT0_0, ACC0);
+      vmv_v_m(INPUT0_1, ACC1);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          INPUT0_0, INPUT0_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(ACCB0, INPUT0_0, 0);
+      vsraqs_b_vx(ACCB1, INPUT0_1, 0);
+
+      vmv_v_m(ACC0, BIAS0);
+      vstq_b_s_xx(ACCB0, p_out, output_depth);
+      vstq_b_s_xx(ACCB1, p_out + (4 * output_depth), output_depth);
+      vmv_v_m(ACC1, BIAS0);
+      acset_v(ACC0, ACC0);
+
+      aconv_vxv(ACC0, INPUT1_0, cmds, FLT0_0);
+      vld_b_s_xx_m(INPUT0_0, p_input + (16 * input_depth), input_depth);
+      vld_b_s_xx_m(INPUT0_1, p_input + (20 * input_depth), input_depth);
+      vcget(ACC0);
+      vmv_v_m(INPUT1_0, ACC0);
+      vmv_v_m(INPUT1_1, ACC1);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          INPUT1_0, INPUT1_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(ACCB2, INPUT1_0, 0);
+      vsraqs_b_vx(ACCB3, INPUT1_1, 0);
+
+      vld_b_s_xx_m(INPUT1_0, p_input + (24 * input_depth), input_depth);
+      vld_b_s_xx_m(INPUT1_1, p_input + (28 * input_depth), input_depth);
+
+      vmv_v_m(ACC0, BIAS0);
+      vstq_b_s_xx(ACCB2, p_out + (8 * output_depth), output_depth);
+      vstq_b_s_xx(ACCB3, p_out + (12 * output_depth), output_depth);
+      vmv_v_m(ACC1, BIAS0);
+      acset_v(ACC0, ACC0);
+      aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+      vcget(ACC0);
+      vmv_v_m(INPUT0_0, ACC0);
+      vmv_v_m(INPUT0_1, ACC1);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          INPUT0_0, INPUT0_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(ACCB4, INPUT0_0, 0);
+      vsraqs_b_vx(ACCB5, INPUT0_1, 0);
+
+      vmv_v_m(ACC0, BIAS0);
+      vstq_b_s_xx(ACCB4, p_out + (16 * output_depth), output_depth);
+      vstq_b_s_xx(ACCB5, p_out + (20 * output_depth), output_depth);
+      vmv_v_m(ACC1, BIAS0);
+      acset_v(ACC0, ACC0);
+      aconv_vxv(ACC0, INPUT1_0, cmds, FLT0_0);
+      vcget(ACC0);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          ACC0, ACC1, MULT0, SHFT0, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(ACCB6, ACC0, 0);
+      vsraqs_b_vx(ACCB7, ACC1, 0);
+
+      vstq_b_s_xx(ACCB6, p_out + (24 * output_depth), output_depth);
+      vstq_b_s_xx(ACCB7, p_out + (28 * output_depth), output_depth);
+
+      out += 32;
+    } while ((n_elems - out) >= 32);
+    while (out < n_elems) {  // remainder loop; skipped when n_elems % 32 == 0
+      int out_this_iter = std::min(8, n_elems - out);
+      const int8_t* p_in = input_data + (out * input_depth);
+      int8_t* p_out = output_data + (out * output_depth) + out_channel;
+
+      // 8x accumulators
+      vmv_v_m(ACC0, BIAS0);
+      vmv_v_m(ACC1, BIAS0);
+      acset_v(ACC0, ACC0);
+      const int8_t* p_input = p_in;
+      if (out_this_iter < 8) {
+        switch (out_this_iter) {
+          case 7:
+            vld_b_x(v6, p_input + (6 * input_depth));
+          case 6:
+            vld_b_x(v5, p_input + (5 * input_depth));
+          case 5:
+            vld_b_x(v4, p_input + (4 * input_depth));
+          case 4:
+            vld_b_x(v3, p_input + (3 * input_depth));
+          case 3:
+            vld_b_x(v2, p_input + (2 * input_depth));
+          case 2:
+            vld_b_x(v1, p_input + input_depth);
+          case 1:
+            vld_b_x(INPUT0_0, p_input);
+        }
+      } else {
+        // Inputs
+        vld_b_s_xx_m(INPUT0_0, p_input, input_depth);
+        vld_b_s_xx_m(INPUT0_1, p_input + (4 * input_depth), input_depth);
+      }
+
+      int8_t* p_local_filter = p_swizzled_filter_data;
+      vld_b_x_m(FLT0_0, p_local_filter);
+      vld_b_x_m(FLT0_1, p_local_filter + (4 * 32));
+      aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+
+      vcget(ACC0);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          ACC0, ACC1, MULT0, SHFT0, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(ACC0, ACC0, 0);
+      vsraqs_b_vx(ACC1, ACC1, 0);
+
+      int i = 0;
+      for (; i < std::min(4, out_this_iter); i++) {
+        vst_b_l_xx(ACC0, p_out + (i * output_depth), out_channels_this_iter);
+        vsliden_h_4_vv(ACC0, ACC0, ACC0);
+      }
+      for (; i < out_this_iter; i++) {
+        vst_b_l_xx(ACC1, p_out + (i * output_depth), out_channels_this_iter);
+        vsliden_h_4_vv(ACC1, ACC1, ACC1);
+      }
+      out += out_this_iter;
+    }
+    out_channel += out_channels_this_iter;
+  } while (out_channel < output_depth);
+}
+
 void ConvS8K1x1D16(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,