Remove extra offset addition from ElementwiseAdd
BUG=345095360
Change-Id: I08ed0b49ff43cb146adec7e233efcf8e9b68b61b
diff --git a/tflm/opt/elementwise_add_s16.cc b/tflm/opt/elementwise_add_s16.cc
index 853ad9d..106742b 100644
--- a/tflm/opt/elementwise_add_s16.cc
+++ b/tflm/opt/elementwise_add_s16.cc
@@ -57,10 +57,10 @@
vmul_w_vx_m(vm2, vm2, input2_shift_mul);
vmul_w_vx_m(vm3, vm3, input2_shift_mul);
- rescale_m(vm0, vm0, input1_mult, input1_shift, input1_offset);
- rescale_m(vm1, vm1, input1_mult, input1_shift, input1_offset);
- rescale_m(vm2, vm2, input2_mult, input2_shift, input2_offset);
- rescale_m(vm3, vm3, input2_mult, input2_shift, input2_offset);
+ rescale_m(vm0, vm0, input1_mult, input1_shift, 0);
+ rescale_m(vm1, vm1, input1_mult, input1_shift, 0);
+ rescale_m(vm2, vm2, input2_mult, input2_shift, 0);
+ rescale_m(vm3, vm3, input2_mult, input2_shift, 0);
// Sum the rescaled inputs.
vadd_w_vv_m(vm0, vm0, vm2);
diff --git a/tflm/opt/elementwise_add_s8.cc b/tflm/opt/elementwise_add_s8.cc
index e4119d4..e664769 100644
--- a/tflm/opt/elementwise_add_s8.cc
+++ b/tflm/opt/elementwise_add_s8.cc
@@ -88,22 +88,16 @@
vdmulh_w_r_vx_m(v12, v12, input2_mult);
vsha_w_r_vx_m(v4, v4, -input1_shift);
vsha_w_r_vx_m(v12, v12, -input2_shift);
- vadd_w_vx_m(v4, v4, input1_offset);
- vadd_w_vx_m(v12, v12, input2_offset);
vdmulh_w_r_vx_m(v20, v20, input1_mult);
vsha_w_r_vx_m(v20, v20, -input1_shift);
- vadd_w_vx_m(v20, v20, input1_offset);
vdmulh_w_r_vx_m(v28, v28, input2_mult);
vsha_w_r_vx_m(v28, v28, -input2_shift);
- vadd_w_vx_m(v28, v28, input2_offset);
vdmulh_w_r_vx_m(v36, v36, input1_mult);
vsha_w_r_vx_m(v36, v36, -input1_shift);
- vadd_w_vx_m(v36, v36, input1_offset);
vdmulh_w_r_vx_m(v44, v44, input2_mult);
vsha_w_r_vx_m(v44, v44, -input2_shift);
- vadd_w_vx_m(v44, v44, input2_offset);
vadd_w_vv_m(v12, v4, v12);
vadd_w_vv_m(v28, v20, v28);
@@ -158,8 +152,6 @@
vdmulh_w_r_vx_m(v12, v12, input2_mult);
vsha_w_r_vx_m(v4, v4, -input1_shift);
vsha_w_r_vx_m(v12, v12, -input2_shift);
- vadd_w_vx_m(v4, v4, input1_offset);
- vadd_w_vx_m(v12, v12, input2_offset);
vadd_w_vv_m(v16, v4, v12);