Improvements to elementwise_add_{s8,s16}
- Implement the quantized multiplication with fewer instructions.
Change-Id: I31b7b7b742fad91be7aebe6a2970c3519501e115
diff --git a/tflm/opt/elementwise_add_s16.cc b/tflm/opt/elementwise_add_s16.cc
index 001113e..853ad9d 100644
--- a/tflm/opt/elementwise_add_s16.cc
+++ b/tflm/opt/elementwise_add_s16.cc
@@ -57,18 +57,18 @@
vmul_w_vx_m(vm2, vm2, input2_shift_mul);
vmul_w_vx_m(vm3, vm3, input2_shift_mul);
- rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm2, vm2, vm15, input2_mult, input2_shift, input2_offset);
- rescale_m(vm3, vm3, vm15, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm2, vm2, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm3, vm3, input2_mult, input2_shift, input2_offset);
// Sum the rescaled inputs.
vadd_w_vv_m(vm0, vm0, vm2);
vadd_w_vv_m(vm1, vm1, vm3);
// Rescale the summed output.
- rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
- rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
+ rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+ rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
// Clamp to the provided range.
vmin_w_vx_m(vm0, vm0, output_activation_max);
diff --git a/tflm/opt/elementwise_add_s8.cc b/tflm/opt/elementwise_add_s8.cc
index 762d7af..b751ceb 100644
--- a/tflm/opt/elementwise_add_s8.cc
+++ b/tflm/opt/elementwise_add_s8.cc
@@ -71,14 +71,14 @@
vmul_w_vx_m(vm6, vm6, input2_shift_mul);
vmul_w_vx_m(vm7, vm7, input2_shift_mul);
- rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm2, vm2, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm3, vm3, vm15, input1_mult, input1_shift, input1_offset);
- rescale_m(vm4, vm4, vm15, input2_mult, input2_shift, input2_offset);
- rescale_m(vm5, vm5, vm15, input2_mult, input2_shift, input2_offset);
- rescale_m(vm6, vm6, vm15, input2_mult, input2_shift, input2_offset);
- rescale_m(vm7, vm7, vm15, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm2, vm2, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm3, vm3, input1_mult, input1_shift, input1_offset);
+ rescale_m(vm4, vm4, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm5, vm5, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm6, vm6, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm7, vm7, input2_mult, input2_shift, input2_offset);
// Sum the rescaled inputs.
vadd_w_vv_m(vm0, vm0, vm4);
@@ -87,10 +87,10 @@
vadd_w_vv_m(vm3, vm3, vm7);
// Rescale the summed output.
- rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
- rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
- rescale_m(vm2, vm2, vm15, output_mult, output_shift, output_offset);
- rescale_m(vm3, vm3, vm15, output_mult, output_shift, output_offset);
+ rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+ rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
+ rescale_m(vm2, vm2, output_mult, output_shift, output_offset);
+ rescale_m(vm3, vm3, output_mult, output_shift, output_offset);
// Clamp to the provided range.
vmin_w_vx_m(vm0, vm0, output_activation_max);
diff --git a/tflm/opt/util.h b/tflm/opt/util.h
index d0c16db..7fd36be 100644
--- a/tflm/opt/util.h
+++ b/tflm/opt/util.h
@@ -25,22 +25,17 @@
// Use this in place of Tensorflow's
// MultiplyByQuantizedMultiplierSmallerThanOneExp
-#define rescale_internal(Vd, Vs, Vscratch, mult, shift, offset, m) \
+#define rescale_internal(Vd, Vs, mult, shift, offset, m) \
do { \
- int32_t _shift = RIGHT_SHIFT(shift); \
vdmulh_w_r_vx##m(Vd, Vs, mult); \
- vdup_w_x##m(Vscratch, -_shift); \
- vand_vv##m(Vscratch, Vscratch, Vd); \
- vsra_w_vx##m(Vscratch, Vscratch, 31); \
- vadd_w_vv##m(Vd, Vd, Vscratch); \
- vsha_w_r_vx##m(Vd, Vd, _shift); \
+ vsha_w_r_vx##m(Vd, Vd, -shift); \
vadd_w_vx##m(Vd, Vd, offset); \
} while (0);
-#define rescale(Vd, Vs, Vscratch, mult, shift, offset) \
- rescale_internal(Vd, Vs, Vscratch, mult, shift, \
+#define rescale(Vd, Vs, mult, shift, offset) \
+ rescale_internal(Vd, Vs, mult, shift, \
offset, ); // NOLINT(whitespace/parens)
-#define rescale_m(Vd, Vs, Vscratch, mult, shift, offset) \
- rescale_internal(Vd, Vs, Vscratch, mult, shift, offset, _m);
+#define rescale_m(Vd, Vs, mult, shift, offset) \
+ rescale_internal(Vd, Vs, mult, shift, offset, _m);
#endif // TFLM_OPT_UTIL_H_