Point wise convolution input_depth = 16

Change-Id: I68882b62f6455dc981148bb24dab78e7eadc9686
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index a3237a6..079e1f8 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -204,9 +204,14 @@
   if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
       stride_width == 1 && dilation_height_factor == 1 &&
       dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
-      (input_depth == filter_depth) && (output_depth % 8) == 0 &&
-      (input_depth % 32) == 0) {
-    RUN_KERNEL(kelvin::opt::ConvS8K1x1);
+      (input_depth == filter_depth)) {
+    if ((output_depth % 8) == 0 && (input_depth % 32) == 0) {
+      RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
+    }
+
+    if ((output_depth % 8) == 0 && (input_depth == 16)) {
+      RUN_KERNEL(kelvin::opt::ConvS8K1x1D16);
+    }
   }
 
   if (input_depth == 1 && filter_width == 5 && filter_height == 5 &&
diff --git a/tflm/opt/conv_s8.h b/tflm/opt/conv_s8.h
index 718d9ee..6450537 100644
--- a/tflm/opt/conv_s8.h
+++ b/tflm/opt/conv_s8.h
@@ -22,16 +22,27 @@
 
 namespace kelvin::opt {
 
-// filter 1x1
-void ConvS8K1x1(const tflite::ConvParams& params,
-                const int32_t* output_multiplier, const int32_t* output_shift,
-                const tflite::RuntimeShape& input_shape,
-                const int8_t* input_data,
-                const tflite::RuntimeShape& filter_shape,
-                const int8_t* filter_data,
-                const tflite::RuntimeShape& bias_shape,
-                const int32_t* bias_data,
-                const tflite::RuntimeShape& output_shape, int8_t* output_data);
+// filter 1x1 d%32==0
+void ConvS8K1x1D32(const tflite::ConvParams& params,
+                   const int32_t* output_multiplier, const int32_t* output_shift,
+                   const tflite::RuntimeShape& input_shape,
+                   const int8_t* input_data,
+                   const tflite::RuntimeShape& filter_shape,
+                   const int8_t* filter_data,
+                   const tflite::RuntimeShape& bias_shape,
+                   const int32_t* bias_data,
+                   const tflite::RuntimeShape& output_shape, int8_t* output_data);
+
+// filter 1x1 d==16
+void ConvS8K1x1D16(const tflite::ConvParams& params,
+                   const int32_t* output_multiplier, const int32_t* output_shift,
+                   const tflite::RuntimeShape& input_shape,
+                   const int8_t* input_data,
+                   const tflite::RuntimeShape& filter_shape,
+                   const int8_t* filter_data,
+                   const tflite::RuntimeShape& bias_shape,
+                   const int32_t* bias_data,
+                   const tflite::RuntimeShape& output_shape, int8_t* output_data);
 
 // filter depth 4n
 void ConvS8D4(const tflite::ConvParams& params,
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index bc61ddf..66e449c 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -22,7 +22,7 @@
 
 namespace kelvin::opt {
 
-void ConvS8K1x1(
+void ConvS8K1x1D32(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
     const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
@@ -140,7 +140,6 @@
         int8_t* p_local_filter = p_swizzled_filter_data + (in_channel * 8);
         vld_b_p_x_m(v8, p_local_filter);
         vld_b_x_m(v12, p_local_filter);
-
         aconv_vxv(v48, v0, cmds, v8);
       }
 
@@ -171,4 +170,166 @@
   } while (out_channel < output_depth);
 }
 
+void ConvS8K1x1D16(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int filter_input_depth = filter_shape.Dims(3);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+
+  union {
+    vconv_u8_t conv;
+    uint32_t raw;
+  } cmds;
+  cmds.conv.mode = 0;
+  cmds.conv.start = 0;
+  cmds.conv.stop = 3;
+  cmds.conv.sbias1 = input_offset;
+  cmds.conv.sdata1 = true;
+  cmds.conv.sbias2 = 0;
+  cmds.conv.sdata2 = true;
+
+  int8_t swizzled_filter_data[8*16];
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_mult_data[32];
+  int32_t swizzled_shift_data[32];
+
+  const int effective_output_width = batches * output_width * output_height;
+
+  // TODO(derekjchow): Remove this when simulator supports vx vslide ops
+  vdup_b_x_m(v12, 0);
+
+  int out_channel = 0;
+  for (; out_channel < output_depth; out_channel += 8) {
+    Filter_N_H_W_M(filter_data + (out_channel * 16),
+                   swizzled_filter_data, 8, 1, 1, 16);
+    vld_b_x_m(v8, swizzled_filter_data);
+
+    if (bias_data) {
+      Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
+      vld_w_x_m(v16, swizzled_bias_data);
+    } else {
+      vdup_w_x_m(v16, 0);
+    }
+    Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
+    Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
+
+    vld_w_x_m(v20, swizzled_mult_data);
+    vld_w_x_m(v24, swizzled_shift_data);
+    vrsub_w_vx_m(v24, v24, 0);
+
+    int8_t* p_output = output_data + out_channel;
+    int out = 0;
+    for (; out + 8 <= effective_output_width; out += 8) {
+      // 8x accumulators
+      vmv_v_m(v48, v16);
+      vmv_v_m(v52, v16);
+      acset_v(v48, v48);
+      const int8_t* p_in = input_data + (out * input_depth);
+
+      vld_b_x_m(v0, p_in);
+      vslidevp_w_4_vv_m(v4, v0, v12);
+      aconv_vxv(v48, v0, cmds, v8);
+
+      vcget(v48);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v48, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v52, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(v48, v48, 0);
+      vsraqs_b_vx(v52, v52, 0);
+
+      vstq_b_s_xx(v48, p_output, 2 * output_depth);
+      vstq_b_s_xx(v52, p_output + output_depth, 2 * output_depth);
+      p_output += (8 * output_depth);
+    }  // out_x
+
+    // Remainder
+    int remainder_x = (effective_output_width - out);
+    if (remainder_x != 0) {
+      vmv_v_m(v48, v16);
+      vmv_v_m(v52, v16);
+      acset_v(v48, v48);
+      const int8_t* p_in = input_data + (out * input_depth);
+
+      // Load inputs
+      switch (8 - remainder_x) { // rest (stripmines?)
+        case 0:
+          vld_b_l_xx(v7, p_in + (7 * input_depth), 16);
+        case 1:
+          vld_b_l_xx(v6, p_in + (6 * input_depth), 16);
+        case 2:
+          vld_b_l_xx(v5, p_in + (5 * input_depth), 16);
+        case 3:
+          vld_b_l_xx(v4, p_in + (4 * input_depth), 16);
+        case 4:
+          vld_b_l_xx(v3, p_in + (3 * input_depth), 16);
+        case 5:
+          vld_b_l_xx(v2, p_in + (2 * input_depth), 16);
+        case 6:
+          vld_b_l_xx(v1, p_in + (1 * input_depth), 16);
+        case 7:
+          vld_b_l_xx(v0, p_in, 16);
+      }
+
+      aconv_vxv(v48, v0, cmds, v8);
+
+      vcget(v48);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v48, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v52, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(v48, v48, 0);
+      vsraqs_b_vx(v52, v52, 0);
+
+      int i = 0;
+      for (; i < std::min(4, remainder_x); i++) {
+        vst_b_l_xx(v48, p_output, 8);
+        p_output += output_depth;
+        vsliden_w_2_vv(v48, v48, v12);
+      }
+
+      for (; i < remainder_x; i++) {
+        vst_b_l_xx(v52, p_output, 8);
+        p_output += output_depth;
+        vsliden_w_2_vv(v52, v52, v12);
+      }
+    }
+  }
+}
+
 }  // namespace kelvin::opt