Improvements to elementwise_add_{s8,s16}

- Implement the quantized multiplication with fewer instructions.

Change-Id: I31b7b7b742fad91be7aebe6a2970c3519501e115
diff --git a/tflm/opt/elementwise_add_s16.cc b/tflm/opt/elementwise_add_s16.cc
index 001113e..853ad9d 100644
--- a/tflm/opt/elementwise_add_s16.cc
+++ b/tflm/opt/elementwise_add_s16.cc
@@ -57,18 +57,18 @@
     vmul_w_vx_m(vm2, vm2, input2_shift_mul);
     vmul_w_vx_m(vm3, vm3, input2_shift_mul);
 
-    rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm2, vm2, vm15, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm3, vm3, vm15, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm2, vm2, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm3, vm3, input2_mult, input2_shift, input2_offset);
 
     // Sum the rescaled inputs.
     vadd_w_vv_m(vm0, vm0, vm2);
     vadd_w_vv_m(vm1, vm1, vm3);
 
     // Rescale the summed output.
-    rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
-    rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
+    rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+    rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
 
     // Clamp to the provided range.
     vmin_w_vx_m(vm0, vm0, output_activation_max);
diff --git a/tflm/opt/elementwise_add_s8.cc b/tflm/opt/elementwise_add_s8.cc
index 762d7af..b751ceb 100644
--- a/tflm/opt/elementwise_add_s8.cc
+++ b/tflm/opt/elementwise_add_s8.cc
@@ -71,14 +71,14 @@
     vmul_w_vx_m(vm6, vm6, input2_shift_mul);
     vmul_w_vx_m(vm7, vm7, input2_shift_mul);
 
-    rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm2, vm2, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm3, vm3, vm15, input1_mult, input1_shift, input1_offset);
-    rescale_m(vm4, vm4, vm15, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm5, vm5, vm15, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm6, vm6, vm15, input2_mult, input2_shift, input2_offset);
-    rescale_m(vm7, vm7, vm15, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm2, vm2, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm3, vm3, input1_mult, input1_shift, input1_offset);
+    rescale_m(vm4, vm4, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm5, vm5, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm6, vm6, input2_mult, input2_shift, input2_offset);
+    rescale_m(vm7, vm7, input2_mult, input2_shift, input2_offset);
 
     // Sum the rescaled inputs.
     vadd_w_vv_m(vm0, vm0, vm4);
@@ -87,10 +87,10 @@
     vadd_w_vv_m(vm3, vm3, vm7);
 
     // Rescale the summed output.
-    rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
-    rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
-    rescale_m(vm2, vm2, vm15, output_mult, output_shift, output_offset);
-    rescale_m(vm3, vm3, vm15, output_mult, output_shift, output_offset);
+    rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+    rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
+    rescale_m(vm2, vm2, output_mult, output_shift, output_offset);
+    rescale_m(vm3, vm3, output_mult, output_shift, output_offset);
 
     // Clamp to the provided range.
     vmin_w_vx_m(vm0, vm0, output_activation_max);
diff --git a/tflm/opt/util.h b/tflm/opt/util.h
index d0c16db..7fd36be 100644
--- a/tflm/opt/util.h
+++ b/tflm/opt/util.h
@@ -25,22 +25,17 @@
 
 // Use this in place of Tensorflow's
 // MultiplyByQuantizedMultiplierSmallerThanOneExp
-#define rescale_internal(Vd, Vs, Vscratch, mult, shift, offset, m) \
+#define rescale_internal(Vd, Vs, mult, shift, offset, m)           \
   do {                                                             \
-    int32_t _shift = RIGHT_SHIFT(shift);                           \
     vdmulh_w_r_vx##m(Vd, Vs, mult);                                \
-    vdup_w_x##m(Vscratch, -_shift);                                \
-    vand_vv##m(Vscratch, Vscratch, Vd);                            \
-    vsra_w_vx##m(Vscratch, Vscratch, 31);                          \
-    vadd_w_vv##m(Vd, Vd, Vscratch);                                \
-    vsha_w_r_vx##m(Vd, Vd, _shift);                                \
+    vsha_w_r_vx##m(Vd, Vd, -(shift));                              \
     vadd_w_vx##m(Vd, Vd, offset);                                  \
   } while (0);
 
-#define rescale(Vd, Vs, Vscratch, mult, shift, offset) \
-  rescale_internal(Vd, Vs, Vscratch, mult, shift,      \
+#define rescale(Vd, Vs, mult, shift, offset) \
+  rescale_internal(Vd, Vs, mult, shift,      \
                    offset, );  // NOLINT(whitespace/parens)
-#define rescale_m(Vd, Vs, Vscratch, mult, shift, offset) \
-  rescale_internal(Vd, Vs, Vscratch, mult, shift, offset, _m);
+#define rescale_m(Vd, Vs, mult, shift, offset) \
+  rescale_internal(Vd, Vs, mult, shift, offset, _m);
 
 #endif  // TFLM_OPT_UTIL_H_