Merge "Process 96 elements per iteration in elementwise add"
diff --git a/tflm/opt/elementwise_add_s8.cc b/tflm/opt/elementwise_add_s8.cc
index b751ceb..e4119d4 100644
--- a/tflm/opt/elementwise_add_s8.cc
+++ b/tflm/opt/elementwise_add_s8.cc
@@ -31,82 +31,164 @@
                       const int32_t output_activation_max,
                       const int32_t block_size) {
   int blocks = block_size;
-  int vl;
-  getmaxvl_b(vl);
 
   const int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift);
   const int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift);
 
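+  // Main loop: process 96 int8 elements per iteration, as three batches of
+  // 32 bytes from each input.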
+  while (blocks >= 96) {
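+    // Batch 1: load 32 bytes per input and widen to 32-bit values, adding
+    // the input offsets (input1 in v4..v7, input2 in v12..v15).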
+    vld_b_lp_xx(v0, input1, 32);
+    vld_b_lp_xx(v8, input2, 32);
+
+    vaddw_h_vx(v2, v0, 0);
+    vaddw_h_vx(v10, v8, 0);
+    vaddw_w_vx(v4, v2, input1_offset);
+    vaddw_w_vx(v6, v3, input1_offset);
+    vaddw_w_vx(v12, v10, input2_offset);
+    vaddw_w_vx(v14, v11, input2_offset);
+
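+    // Batch 2: widen into v20..v23 and v28..v31.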
+    vld_b_lp_xx(v16, input1, 32);
+    vld_b_lp_xx(v24, input2, 32);
+
+    vaddw_h_vx(v18, v16, 0);
+    vaddw_h_vx(v26, v24, 0);
+    vaddw_w_vx(v20, v18, input1_offset);
+    vaddw_w_vx(v22, v19, input1_offset);
+    vaddw_w_vx(v28, v26, input2_offset);
+    vaddw_w_vx(v30, v27, input2_offset);
+
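+    // Batch 3: widen into v36..v39 and v44..v47.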
+    vld_b_lp_xx(v32, input1, 32);
+    vld_b_lp_xx(v40, input2, 32);
+
+    vaddw_h_vx(v34, v32, 0);
+    vaddw_h_vx(v42, v40, 0);
+    vaddw_w_vx(v36, v34, input1_offset);
+    vaddw_w_vx(v38, v35, input1_offset);
+    vaddw_w_vx(v44, v42, input2_offset);
+    vaddw_w_vx(v46, v43, input2_offset);
+
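+    // Apply left_shift to all widened inputs.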
+    vsll_w_vx_m(v4, v4, left_shift);
+    vsll_w_vx_m(v12, v12, left_shift);
+
+    vsll_w_vx_m(v20, v20, left_shift);
+    vsll_w_vx_m(v28, v28, left_shift);
+
+    vsll_w_vx_m(v36, v36, left_shift);
+    vsll_w_vx_m(v44, v44, left_shift);
+
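+    // Multiply by the precomputed per-input shift multipliers.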
+    vmul_w_vx_m(v4, v4, input1_shift_mul);
+    vmul_w_vx_m(v12, v12, input2_shift_mul);
+
+    vmul_w_vx_m(v20, v20, input1_shift_mul);
+    vmul_w_vx_m(v28, v28, input2_shift_mul);
+
+    vmul_w_vx_m(v36, v36, input1_shift_mul);
+    vmul_w_vx_m(v44, v44, input2_shift_mul);
+
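+    // Rescale each input (the expansion of rescale_m): rounding doubling
+    // multiply-high, rounding arithmetic shift, then add the input offset.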
+    vdmulh_w_r_vx_m(v4, v4, input1_mult);
+    vdmulh_w_r_vx_m(v12, v12, input2_mult);
+    vsha_w_r_vx_m(v4, v4, -input1_shift);
+    vsha_w_r_vx_m(v12, v12, -input2_shift);
+    vadd_w_vx_m(v4, v4, input1_offset);
+    vadd_w_vx_m(v12, v12, input2_offset);
+
+    vdmulh_w_r_vx_m(v20, v20, input1_mult);
+    vsha_w_r_vx_m(v20, v20, -input1_shift);
+    vadd_w_vx_m(v20, v20, input1_offset);
+    vdmulh_w_r_vx_m(v28, v28, input2_mult);
+    vsha_w_r_vx_m(v28, v28, -input2_shift);
+    vadd_w_vx_m(v28, v28, input2_offset);
+
+    vdmulh_w_r_vx_m(v36, v36, input1_mult);
+    vsha_w_r_vx_m(v36, v36, -input1_shift);
+    vadd_w_vx_m(v36, v36, input1_offset);
+    vdmulh_w_r_vx_m(v44, v44, input2_mult);
+    vsha_w_r_vx_m(v44, v44, -input2_shift);
+    vadd_w_vx_m(v44, v44, input2_offset);
+
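+    // Sum the rescaled inputs.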
+    vadd_w_vv_m(v12, v4, v12);
+    vadd_w_vv_m(v28, v20, v28);
+    vadd_w_vv_m(v44, v36, v44);
+
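+    // Rescale the sums to the output scale.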
+    vdmulh_w_r_vx_m(v12, v12, output_mult);
+    vdmulh_w_r_vx_m(v28, v28, output_mult);
+    vdmulh_w_r_vx_m(v44, v44, output_mult);
+    vsha_w_r_vx_m(v12, v12, -output_shift);
+    vsha_w_r_vx_m(v28, v28, -output_shift);
+    vsha_w_r_vx_m(v44, v44, -output_shift);
+    vadd_w_vx_m(v12, v12, output_offset);
+    vadd_w_vx_m(v28, v28, output_offset);
+    vadd_w_vx_m(v44, v44, output_offset);
+
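+    // Clamp to the provided activation range.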
+    vmin_w_vx_m(v12, v12, output_activation_max);
+    vmin_w_vx_m(v28, v28, output_activation_max);
+    vmin_w_vx_m(v44, v44, output_activation_max);
+    vmax_w_vx_m(v12, v12, output_activation_min);
+    vmax_w_vx_m(v28, v28, output_activation_min);
+    vmax_w_vx_m(v44, v44, output_activation_min);
+
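+    // Narrow each batch back to bytes and store 32 results.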
+    vsraqs_b_vx(v12, v12, 0);
+    vst_b_lp_xx(v12, output, 32);
+    vsraqs_b_vx(v28, v28, 0);
+    vst_b_lp_xx(v28, output, 32);
+    vsraqs_b_vx(v44, v44, 0);
+    vst_b_lp_xx(v44, output, 32);
+
+    blocks -= 96;
+  }
+
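+  // Tail: process any remaining elements up to 32 at a time.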
   while (blocks) {
-    int count = std::min(blocks, vl);
+    int count = std::min(blocks, 32);
+    vld_b_lp_xx(v0, input1, count);
+    vld_b_lp_xx(v8, input2, count);
 
-    // Widen input1 to 32-bit wide values (in vm0, vm1, vm2, vm3).
-    vld_b_lp_xx_m(vm0, input1, count);
-    vaddw_h_vx_m(vm0, vm0, 0);
-    vaddw_w_vx_m(vm2, vm1, input1_offset);
-    vaddw_w_vx_m(vm0, vm0, input1_offset);
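+    // Widen both inputs to 32-bit values and add the input offsets.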
+    vaddw_h_vx(v2, v0, 0);
+    vaddw_h_vx(v10, v8, 0);
+    vaddw_w_vx(v4, v2, input1_offset);
+    vaddw_w_vx(v6, v3, input1_offset);
+    vaddw_w_vx(v12, v10, input2_offset);
+    vaddw_w_vx(v14, v11, input2_offset);
 
-    // Widen input2 to 32-bit wide values (in vm4, vm5, vm6, vm7).
-    vld_b_lp_xx_m(vm4, input2, count);
-    vaddw_h_vx_m(vm4, vm4, 0);
-    vaddw_w_vx_m(vm6, vm5, input2_offset);
-    vaddw_w_vx_m(vm4, vm4, input2_offset);
+    vsll_w_vx_m(v4, v4, left_shift);
+    vsll_w_vx_m(v12, v12, left_shift);
 
-    // Apply left_shift to all inputs.
-    vsll_w_vx_m(vm0, vm0, left_shift);
-    vsll_w_vx_m(vm1, vm1, left_shift);
-    vsll_w_vx_m(vm2, vm2, left_shift);
-    vsll_w_vx_m(vm3, vm3, left_shift);
-    vsll_w_vx_m(vm4, vm4, left_shift);
-    vsll_w_vx_m(vm5, vm5, left_shift);
-    vsll_w_vx_m(vm6, vm6, left_shift);
-    vsll_w_vx_m(vm7, vm7, left_shift);
+    vmul_w_vx_m(v4, v4, input1_shift_mul);
+    vmul_w_vx_m(v12, v12, input2_shift_mul);
 
-    vmul_w_vx_m(vm0, vm0, input1_shift_mul);
-    vmul_w_vx_m(vm1, vm1, input1_shift_mul);
-    vmul_w_vx_m(vm2, vm2, input1_shift_mul);
-    vmul_w_vx_m(vm3, vm3, input1_shift_mul);
-    vmul_w_vx_m(vm4, vm4, input2_shift_mul);
-    vmul_w_vx_m(vm5, vm5, input2_shift_mul);
-    vmul_w_vx_m(vm6, vm6, input2_shift_mul);
-    vmul_w_vx_m(vm7, vm7, input2_shift_mul);
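+    // Rescale each input as in the main loop.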
+    vdmulh_w_r_vx_m(v4, v4, input1_mult);
+    vdmulh_w_r_vx_m(v12, v12, input2_mult);
+    vsha_w_r_vx_m(v4, v4, -input1_shift);
+    vsha_w_r_vx_m(v12, v12, -input2_shift);
+    vadd_w_vx_m(v4, v4, input1_offset);
+    vadd_w_vx_m(v12, v12, input2_offset);
 
-    rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm2, vm2, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm3, vm3, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm4, vm4, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm5, vm5, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm6, vm6, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm7, vm7, input2_mult, input2_shift, input2_offset);
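+    // Sum the rescaled inputs.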
+    vadd_w_vv_m(v16, v4, v12);
 
-    // Sum the rescaled inputs.
-    vadd_w_vv_m(vm0, vm0, vm4);
-    vadd_w_vv_m(vm1, vm1, vm5);
-    vadd_w_vv_m(vm2, vm2, vm6);
-    vadd_w_vv_m(vm3, vm3, vm7);
+    rescale_m(v16, v16, output_mult, output_shift, output_offset);
 
-    // Rescale the summed output.
-    rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
-    rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
-    rescale_m(vm2, vm2, output_mult, output_shift, output_offset);
-    rescale_m(vm3, vm3, output_mult, output_shift, output_offset);
+    vmin_w_vx_m(v16, v16, output_activation_max);
+    vmax_w_vx_m(v16, v16, output_activation_min);
 
-    // Clamp to the provided range.
-    vmin_w_vx_m(vm0, vm0, output_activation_max);
-    vmin_w_vx_m(vm1, vm1, output_activation_max);
-    vmin_w_vx_m(vm2, vm2, output_activation_max);
-    vmin_w_vx_m(vm3, vm3, output_activation_max);
-    vmax_w_vx_m(vm0, vm0, output_activation_min);
-    vmax_w_vx_m(vm1, vm1, output_activation_min);
-    vmax_w_vx_m(vm2, vm2, output_activation_min);
-    vmax_w_vx_m(vm3, vm3, output_activation_min);
-
-    // Swizzle and narrow back to bytes.
-    vsraqs_b_vx_m(vm0, vm0, 0);
-
-    // Store to memory.
-    vst_b_lp_xx_m(vm0, output, count);
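+    // Narrow back to bytes and store count results.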
+    vsraqs_b_vx(v16, v16, 0);
+    vst_b_lp_xx(v16, output, count);
 
     blocks -= count;
   }