Merge "Process 16 channels per iteration in 1x1 d16"
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 64e22e1..933e9f8 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -214,7 +214,7 @@
       RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
     }
 
-    if ((output_depth % 8) == 0 && (input_depth == 16)) {
+    if ((output_depth % 16) == 0 && (input_depth == 16)) {
       RUN_KERNEL(kelvin::opt::ConvS8K1x1D16);
     }
   }
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 26a681b..18046e9 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -223,59 +223,95 @@
 
   const int effective_output_width = batches * output_width * output_height;
 
-  // TODO(derekjchow): Remove this when simulator supports vx vslide ops
-  vdup_b_x_m(v12, 0);
+#define INPUT v0 // v0, v1, v2, v3, v4, v5, v6, v7
+#define INPUT_SLIDE v4 // v4, v5, v6, v7
+#define FLT_0 v8 // v8, v9, v10, v11
+#define FLT_1 v16 // v16, v17, v18, v19
+#define BIAS_0 v20 // v20, v21, v22, v23
+#define BIAS_1 v24 // v24, v25, v26, v27
+#define MULT_0 v28 // v28, v29, v30, v31
+#define MULT_1 v32 // v32, v33, v34, v35
+#define SHFT_0 v36 // v36, v37, v38, v39
+#define SHFT_1 v40 // v40, v41, v42, v43
+#define ACC_0 v48 // v48, v49, v50, v51
+#define ACC_1 v52 // v52, v53, v54, v55
+#define RES_0 v60
+#define RES_1 v61
+#define RES_2 v62
+#define RES_3 v63
 
   int out_channel = 0;
-  for (; out_channel < output_depth; out_channel += 8) {
+  for (; out_channel < output_depth; out_channel += 16) {
     Filter_N_H_W_M(filter_data + (out_channel * 16),
                    swizzled_filter_data, 8, 1, 1, 16);
-    vld_b_x_m(v8, swizzled_filter_data);
+    vld_b_x_m(FLT_0, swizzled_filter_data);
+    Filter_N_H_W_M(filter_data + ((out_channel + 8) * 16),
+                   swizzled_filter_data, 8, 1, 1, 16);
+    vld_b_x_m(FLT_1, swizzled_filter_data);
 
     if (bias_data) {
       Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
-      vld_w_x_m(v16, swizzled_bias_data);
+      vld_w_x_m(BIAS_0, swizzled_bias_data);
+      Swizzle(bias_data + out_channel + 8, swizzled_bias_data, 8);
+      vld_w_x_m(BIAS_1, swizzled_bias_data);
     } else {
-      vdup_w_x_m(v16, 0);
+      vdup_w_x_m(BIAS_0, 0);
+      vdup_w_x_m(BIAS_1, 0);
     }
+
     Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
     Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
+    vld_w_x_m(MULT_0, swizzled_mult_data);
+    vld_w_x_m(SHFT_0, swizzled_shift_data);
+    vrsub_w_vx_m(SHFT_0, SHFT_0, 0);
 
-    vld_w_x_m(v20, swizzled_mult_data);
-    vld_w_x_m(v24, swizzled_shift_data);
-    vrsub_w_vx_m(v24, v24, 0);
+    Swizzle(output_multiplier + out_channel + 8, swizzled_mult_data, 8);
+    Swizzle(output_shift + out_channel + 8, swizzled_shift_data, 8);
+    vld_w_x_m(MULT_1, swizzled_mult_data);
+    vld_w_x_m(SHFT_1, swizzled_shift_data);
+    vrsub_w_vx_m(SHFT_1, SHFT_1, 0);
 
     int8_t* p_output = output_data + out_channel;
     int out = 0;
     for (; out + 8 <= effective_output_width; out += 8) {
       // 8x accumulators
-      vmv_v_m(v48, v16);
-      vmv_v_m(v52, v16);
-      acset_v(v48, v48);
       const int8_t* p_in = input_data + (out * input_depth);
 
-      vld_b_x_m(v0, p_in);
-      vslidevp_w_4_vv_m(v4, v0, v12);
-      aconv_vxv(v48, v0, cmds, v8);
+      vld_b_x_m(INPUT, p_in);
+      vslidevp_w_4_vv_m(INPUT_SLIDE, INPUT, INPUT);
 
-      vcget(v48);
+      vmv_v_m(ACC_0, BIAS_0);
+      vmv_v_m(ACC_1, BIAS_0);
+      acset_v(ACC_0, ACC_0);
+      aconv_vxv(ACC_0, INPUT, cmds, FLT_0);  // first 8 output channels
+
+      vcget(ACC_0);
       INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
-          v48, v52, v20, v24, output_activation_min, output_activation_max,
+          ACC_0, ACC_1, MULT_0, SHFT_0, output_activation_min, output_activation_max,
           output_offset);
-      vsraqs_b_vx(v48, v48, 0);
-      vsraqs_b_vx(v52, v52, 0);
+      vsraqs_b_vx(RES_0, ACC_0, 0);
+      vsraqs_b_vx(RES_1, ACC_1, 0);
+      vstq_b_s_xx(RES_0, p_output, 2 * output_depth);
+      vstq_b_s_xx(RES_1, p_output + output_depth, 2 * output_depth);
 
-      vstq_b_s_xx(v48, p_output, 2 * output_depth);
-      vstq_b_s_xx(v52, p_output + output_depth, 2 * output_depth);
+      vmv_v_m(ACC_0, BIAS_1);
+      vmv_v_m(ACC_1, BIAS_1);
+      acset_v(ACC_0, ACC_0);
+      aconv_vxv(ACC_0, INPUT, cmds, FLT_1);  // second 8 output channels
+      vcget(ACC_0);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          ACC_0, ACC_1, MULT_1, SHFT_1, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(RES_2, ACC_0, 0);
+      vsraqs_b_vx(RES_3, ACC_1, 0);
+      vstq_b_s_xx(RES_2, p_output + 8, 2 * output_depth);
+      vstq_b_s_xx(RES_3, p_output + output_depth + 8, 2 * output_depth);
       p_output += (8 * output_depth);
     }  // out_x
 
     // Remainder
     int remainder_x = (effective_output_width - out);
     if (remainder_x != 0) {
-      vmv_v_m(v48, v16);
-      vmv_v_m(v52, v16);
-      acset_v(v48, v48);
       const int8_t* p_in = input_data + (out * input_depth);
 
       // Load inputs
@@ -298,29 +334,62 @@
           vld_b_l_xx(v0, p_in, 16);
       }
 
-      aconv_vxv(v48, v0, cmds, v8);
-
-      vcget(v48);
+      vmv_v_m(ACC_0, BIAS_0);
+      vmv_v_m(ACC_1, BIAS_0);
+      acset_v(ACC_0, ACC_0);
+      aconv_vxv(ACC_0, INPUT, cmds, FLT_0);  // first 8 output channels
+      vcget(ACC_0);
       INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
-          v48, v52, v20, v24, output_activation_min, output_activation_max,
+          ACC_0, ACC_1, MULT_0, SHFT_0, output_activation_min, output_activation_max,
           output_offset);
-      vsraqs_b_vx(v48, v48, 0);
-      vsraqs_b_vx(v52, v52, 0);
+      vsraqs_b_vx(RES_0, ACC_0, 0);
+      vsraqs_b_vx(RES_1, ACC_1, 0);
+
+      vmv_v_m(ACC_0, BIAS_1);
+      vmv_v_m(ACC_1, BIAS_1);
+      acset_v(ACC_0, ACC_0);
+      aconv_vxv(ACC_0, INPUT, cmds, FLT_1);  // second 8 output channels
+      vcget(ACC_0);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+          ACC_0, ACC_1, MULT_1, SHFT_1, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(RES_2, ACC_0, 0);
+      vsraqs_b_vx(RES_3, ACC_1, 0);
 
       int i = 0;
       for (; i < std::min(4, remainder_x); i++) {
-        vst_b_l_xx(v48, p_output, 8);
+        vst_b_l_xx(RES_0, p_output, 8);
+        vsliden_w_2_vv(RES_0, RES_0, RES_0);
+        vst_b_l_xx(RES_2, p_output + 8, 8);
+        vsliden_w_2_vv(RES_2, RES_2, RES_2);
         p_output += output_depth;
-        vsliden_w_2_vv(v48, v48, v12);
       }
 
       for (; i < remainder_x; i++) {
-        vst_b_l_xx(v52, p_output, 8);
+        vst_b_l_xx(RES_1, p_output, 8);
+        vsliden_w_2_vv(RES_1, RES_1, RES_1);
+        vst_b_l_xx(RES_3, p_output + 8, 8);
+        vsliden_w_2_vv(RES_3, RES_3, RES_3);
         p_output += output_depth;
-        vsliden_w_2_vv(v52, v52, v12);
       }
     }
   }
+#undef INPUT
+#undef INPUT_SLIDE
+#undef FLT_0
+#undef FLT_1
+#undef BIAS_0
+#undef BIAS_1
+#undef MULT_0
+#undef MULT_1
+#undef SHFT_0
+#undef SHFT_1
+#undef ACC_0
+#undef ACC_1
+#undef RES_0
+#undef RES_1
+#undef RES_2
+#undef RES_3
 }
 
 }  // namespace kelvin::opt
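
Note: with the dispatch change above, ConvS8K1x1D16 is now selected only when input_depth == 16 and output_depth is a multiple of 16. Each outer iteration of the kernel then covers 16 output channels (two groups of 8 filters, biases, multipliers, and shifts), while each inner iteration still covers 8 output pixels, issuing two accelerator convolutions per block. The scalar sketch below illustrates that tiling only; it is not the vectorized kernel, the function name and the simplified float requantization are illustrative assumptions, and remainder-pixel handling is omitted.

#include <algorithm>
#include <cstdint>

void Conv1x1S8Ref(const int8_t* input, const int8_t* filter,
                  const int32_t* bias, int8_t* output, int num_pixels,
                  int input_depth, int output_depth, int32_t input_offset,
                  int32_t output_offset, float scale, int32_t act_min,
                  int32_t act_max) {
  for (int oc = 0; oc < output_depth; oc += 16) {      // 16 output channels per iteration
    for (int px = 0; px + 8 <= num_pixels; px += 8) {  // 8 output pixels per iteration
      for (int group = 0; group < 2; ++group) {        // two groups of 8 channels (FLT_0 / FLT_1)
        for (int p = 0; p < 8; ++p) {
          for (int c = 0; c < 8; ++c) {
            const int channel = oc + group * 8 + c;
            int32_t acc = bias ? bias[channel] : 0;
            for (int d = 0; d < input_depth; ++d) {
              acc += (input[(px + p) * input_depth + d] + input_offset) *
                     filter[channel * input_depth + d];
            }
            // Simplified requantization stand-in for the multiplier/shift pipeline.
            int32_t out = static_cast<int32_t>(acc * scale) + output_offset;
            out = std::clamp(out, act_min, act_max);
            output[(px + p) * output_depth + channel] =
                static_cast<int8_t>(out);
          }
        }
      }
    }
    // Remainder pixels (num_pixels % 8) are handled by a separate tail path in
    // the real kernel and are omitted here.
  }
}

In the real kernel the inner loops over channels, pixels, and input depth collapse into the two aconv_vxv calls on FLT_0 and FLT_1, and the requantized bytes are written back with vstq_b_s_xx in the steady state and vst_b_l_xx in the remainder path, as shown in the diff.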