Alex Van Damme | 59112ad | 2023-08-11 11:10:33 -0700 | [diff] [blame] | 1 | // Copyright 2023 Google LLC |
| 2 | // Licensed under the Apache License, Version 2.0, see LICENSE for details. |
| 3 | // SPDX-License-Identifier: Apache-2.0 |
| 4 | |
| 5 | #include "crt/kelvin.h" |
| 6 | #include "tflm/opt/opt.h" |
| 7 | #include "tflm/opt/util.h" |
| 8 | |
| 9 | namespace kelvin::opt { |
| 10 | |
| 11 | void elementwise_add_s16(const int16_t* input1, const int16_t* input2, |
| 12 | const int32_t input1_offset, const int32_t input1_mult, |
| 13 | const int32_t input1_shift, |
| 14 | const int32_t input2_offset, const int32_t input2_mult, |
| 15 | const int32_t input2_shift, const int32_t left_shift, |
| 16 | int16_t* output, const int32_t output_offset, |
| 17 | const int32_t output_mult, const int32_t output_shift, |
| 18 | const int32_t output_activation_min, |
| 19 | const int32_t output_activation_max, |
| 20 | const int32_t block_size) { |
| 21 | int blocks = block_size; |
| 22 | int vl; |
| 23 | getmaxvl_h(vl); |
| 24 | while (blocks) { |
| 25 | int count = std::min(blocks, vl); |
| 26 | |
| 27 | // Widen input1 to 32-bit wide values (in vm0, vm1). |
| 28 | vld_h_lp_xx_m(vm0, input1, count); |
| 29 | vaddw_w_vx_m(vm0, vm0, input1_offset); |
| 30 | |
| 31 | // Widen input2 to 32-bit wide values (in vm2, vm3). |
| 32 | vld_h_lp_xx_m(vm2, input2, count); |
| 33 | vaddw_w_vx_m(vm2, vm2, input2_offset); |
| 34 | |
| 35 | // Apply left_shift to all inputs. |
| 36 | vsll_w_vx_m(vm0, vm0, left_shift); |
| 37 | vsll_w_vx_m(vm1, vm1, left_shift); |
| 38 | vsll_w_vx_m(vm2, vm2, left_shift); |
| 39 | vsll_w_vx_m(vm3, vm3, left_shift); |
| 40 | |
| 41 | int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift); |
| 42 | int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift); |
| 43 | vmul_w_vx_m(vm0, vm0, input1_shift_mul); |
| 44 | vmul_w_vx_m(vm1, vm1, input1_shift_mul); |
| 45 | vmul_w_vx_m(vm2, vm2, input2_shift_mul); |
| 46 | vmul_w_vx_m(vm3, vm3, input2_shift_mul); |
| 47 | |
Alex Van Damme | 2c72b28 | 2023-08-25 10:04:59 -0700 | [diff] [blame] | 48 | rescale_m(vm0, vm0, vm15, input1_mult, input1_shift, input1_offset); |
| 49 | rescale_m(vm1, vm1, vm15, input1_mult, input1_shift, input1_offset); |
| 50 | rescale_m(vm2, vm2, vm15, input2_mult, input2_shift, input2_offset); |
| 51 | rescale_m(vm3, vm3, vm15, input2_mult, input2_shift, input2_offset); |
Alex Van Damme | 59112ad | 2023-08-11 11:10:33 -0700 | [diff] [blame] | 52 | |
| 53 | // Sum the rescaled inputs. |
| 54 | vadd_w_vv_m(vm0, vm0, vm2); |
| 55 | vadd_w_vv_m(vm1, vm1, vm3); |
| 56 | |
| 57 | // Rescale the summed output. |
Alex Van Damme | 2c72b28 | 2023-08-25 10:04:59 -0700 | [diff] [blame] | 58 | rescale_m(vm0, vm0, vm15, output_mult, output_shift, output_offset); |
| 59 | rescale_m(vm1, vm1, vm15, output_mult, output_shift, output_offset); |
Alex Van Damme | 59112ad | 2023-08-11 11:10:33 -0700 | [diff] [blame] | 60 | |
| 61 | // Clamp to the provided range. |
| 62 | vmin_w_vx_m(vm0, vm0, output_activation_max); |
| 63 | vmin_w_vx_m(vm1, vm1, output_activation_max); |
| 64 | vmax_w_vx_m(vm0, vm0, output_activation_min); |
| 65 | vmax_w_vx_m(vm1, vm1, output_activation_min); |
| 66 | |
| 67 | // Swizzle and narrow back to bytes. |
| 68 | vand_w_vx_m(vm0, vm0, 0xFFFF); |
| 69 | vand_w_vx_m(vm1, vm1, 0xFFFF); |
| 70 | vsll_w_vx_m(vm1, vm1, 16); |
| 71 | vor_vv_m(vm0, vm0, vm1); |
| 72 | |
| 73 | // Store to memory. |
| 74 | vst_h_lp_xx_m(vm0, output, count); |
| 75 | |
| 76 | blocks -= count; |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | } // namespace kelvin::opt |