// Copyright 2023 Google LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#include "crt/kelvin.h"
#include "tflm/opt/opt.h"
#include "tflm/opt/util.h"
namespace kelvin::opt {

11void elementwise_add_s16(const int16_t* input1, const int16_t* input2,
12 const int32_t input1_offset, const int32_t input1_mult,
13 const int32_t input1_shift,
14 const int32_t input2_offset, const int32_t input2_mult,
15 const int32_t input2_shift, const int32_t left_shift,
16 int16_t* output, const int32_t output_offset,
17 const int32_t output_mult, const int32_t output_shift,
18 const int32_t output_activation_min,
19 const int32_t output_activation_max,
20 const int32_t block_size) {
21 int blocks = block_size;
22 int vl;
23 getmaxvl_h(vl);
24 while (blocks) {
25 int count = std::min(blocks, vl);
26
27 // Widen input1 to 32-bit wide values (in vm0, vm1).
28 vld_h_lp_xx_m(vm0, input1, count);
29 vaddw_w_vx_m(vm0, vm0, input1_offset);
30
31 // Widen input2 to 32-bit wide values (in vm2, vm3).
32 vld_h_lp_xx_m(vm2, input2, count);
33 vaddw_w_vx_m(vm2, vm2, input2_offset);
34
35 // Apply left_shift to all inputs.
36 vsll_w_vx_m(vm0, vm0, left_shift);
37 vsll_w_vx_m(vm1, vm1, left_shift);
38 vsll_w_vx_m(vm2, vm2, left_shift);
39 vsll_w_vx_m(vm3, vm3, left_shift);
40
41 int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift);
42 int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift);
43 vmul_w_vx_m(vm0, vm0, input1_shift_mul);
44 vmul_w_vx_m(vm1, vm1, input1_shift_mul);
45 vmul_w_vx_m(vm2, vm2, input2_shift_mul);
46 vmul_w_vx_m(vm3, vm3, input2_shift_mul);
47
Alex Van Damme2c72b282023-08-25 10:04:59 -070048 rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset);
49 rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset);
50 rescale_m(vm2, vm2, vm15, input2_mult, input2_shift, input2_offset);
51 rescale_m(vm3, vm3, vm15, input2_mult, input2_shift, input2_offset);
Alex Van Damme59112ad2023-08-11 11:10:33 -070052
53 // Sum the rescaled inputs.
54 vadd_w_vv_m(vm0, vm0, vm2);
55 vadd_w_vv_m(vm1, vm1, vm3);
56
57 // Rescale the summed output.
Alex Van Damme2c72b282023-08-25 10:04:59 -070058 rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset);
59 rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset);
Alex Van Damme59112ad2023-08-11 11:10:33 -070060
61 // Clamp to the provided range.
62 vmin_w_vx_m(vm0, vm0, output_activation_max);
63 vmin_w_vx_m(vm1, vm1, output_activation_max);
64 vmax_w_vx_m(vm0, vm0, output_activation_min);
65 vmax_w_vx_m(vm1, vm1, output_activation_min);
66
67 // Swizzle and narrow back to bytes.
68 vand_w_vx_m(vm0, vm0, 0xFFFF);
69 vand_w_vx_m(vm1, vm1, 0xFFFF);
70 vsll_w_vx_m(vm1, vm1, 16);
71 vor_vv_m(vm0, vm0, vm1);
72
73 // Store to memory.
74 vst_h_lp_xx_m(vm0, output, count);
75
76 blocks -= count;
77 }
78}

}  // namespace kelvin::opt