blob: ac83e1f3cddba914a5fa03d287d8db90a5ff627e [file] [log] [blame]
Cindy Liu43879e42023-10-18 11:18:03 -07001/*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
Alex Van Damme59112ad2023-08-11 11:10:33 -070016
17#include "crt/kelvin.h"
18#include "tflm/opt/opt.h"
19#include "tflm/opt/util.h"
20
21namespace kelvin::opt {
22
23void elementwise_add_s8(const int8_t* input1, const int8_t* input2,
24 const int32_t input1_offset, const int32_t input1_mult,
25 const int32_t input1_shift, const int32_t input2_offset,
26 const int32_t input2_mult, const int32_t input2_shift,
27 const int32_t left_shift, int8_t* output,
28 const int32_t output_offset, const int32_t output_mult,
29 const int32_t output_shift,
30 const int32_t output_activation_min,
31 const int32_t output_activation_max,
32 const int32_t block_size) {
33 int blocks = block_size;
34 int vl;
35 getmaxvl_b(vl);
36
37 const int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift);
38 const int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift);
39
40 while (blocks) {
41 int count = std::min(blocks, vl);
42
43 // Widen input1 to 32-bit wide values (in vm0, vm1, vm2, vm3).
44 vld_b_lp_xx_m(vm0, input1, count);
45 vaddw_h_vx_m(vm0, vm0, 0);
46 vaddw_w_vx_m(vm2, vm1, input1_offset);
47 vaddw_w_vx_m(vm0, vm0, input1_offset);
48
49 // Widen input2 to 32-bit wide values (in vm4, vm5, vm6, vm7).
50 vld_b_lp_xx_m(vm4, input2, count);
51 vaddw_h_vx_m(vm4, vm4, 0);
52 vaddw_w_vx_m(vm6, vm5, input2_offset);
53 vaddw_w_vx_m(vm4, vm4, input2_offset);
54
55 // Apply left_shift to all inputs.
56 vsll_w_vx_m(vm0, vm0, left_shift);
57 vsll_w_vx_m(vm1, vm1, left_shift);
58 vsll_w_vx_m(vm2, vm2, left_shift);
59 vsll_w_vx_m(vm3, vm3, left_shift);
60 vsll_w_vx_m(vm4, vm4, left_shift);
61 vsll_w_vx_m(vm5, vm5, left_shift);
62 vsll_w_vx_m(vm6, vm6, left_shift);
63 vsll_w_vx_m(vm7, vm7, left_shift);
64
65 vmul_w_vx_m(vm0, vm0, input1_shift_mul);
66 vmul_w_vx_m(vm1, vm1, input1_shift_mul);
67 vmul_w_vx_m(vm2, vm2, input1_shift_mul);
68 vmul_w_vx_m(vm3, vm3, input1_shift_mul);
69 vmul_w_vx_m(vm4, vm4, input2_shift_mul);
70 vmul_w_vx_m(vm5, vm5, input2_shift_mul);
71 vmul_w_vx_m(vm6, vm6, input2_shift_mul);
72 vmul_w_vx_m(vm7, vm7, input2_shift_mul);
73
Alex Van Damme2c72b282023-08-25 10:04:59 -070074 rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
75 rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
76 rescale_m(vm2, vm2, vm15, input1_mult, input1_shift, input1_offset);
77 rescale_m(vm3, vm3, vm15, input1_mult, input1_shift, input1_offset);
78 rescale_m(vm4, vm4, vm15, input2_mult, input2_shift, input2_offset);
79 rescale_m(vm5, vm5, vm15, input2_mult, input2_shift, input2_offset);
80 rescale_m(vm6, vm6, vm15, input2_mult, input2_shift, input2_offset);
81 rescale_m(vm7, vm7, vm15, input2_mult, input2_shift, input2_offset);
Alex Van Damme59112ad2023-08-11 11:10:33 -070082
83 // Sum the rescaled inputs.
84 vadd_w_vv_m(vm0, vm0, vm4);
85 vadd_w_vv_m(vm1, vm1, vm5);
86 vadd_w_vv_m(vm2, vm2, vm6);
87 vadd_w_vv_m(vm3, vm3, vm7);
88
89 // Rescale the summed output.
Alex Van Damme2c72b282023-08-25 10:04:59 -070090 rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
91 rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
92 rescale_m(vm2, vm2, vm15, output_mult, output_shift, output_offset);
93 rescale_m(vm3, vm3, vm15, output_mult, output_shift, output_offset);
Alex Van Damme59112ad2023-08-11 11:10:33 -070094
95 // Clamp to the provided range.
96 vmin_w_vx_m(vm0, vm0, output_activation_max);
97 vmin_w_vx_m(vm1, vm1, output_activation_max);
98 vmin_w_vx_m(vm2, vm2, output_activation_max);
99 vmin_w_vx_m(vm3, vm3, output_activation_max);
100 vmax_w_vx_m(vm0, vm0, output_activation_min);
101 vmax_w_vx_m(vm1, vm1, output_activation_min);
102 vmax_w_vx_m(vm2, vm2, output_activation_min);
103 vmax_w_vx_m(vm3, vm3, output_activation_min);
104
105 // Swizzle and narrow back to bytes.
106 vsraqs_b_vx_m(vm0, vm0, 0);
107
108 // Store to memory.
109 vst_b_lp_xx_m(vm0, output, count);
110
111 blocks -= count;
112 }
113}
114
115} // namespace kelvin::opt