Add support functions for an optimized Add kernel
- Add elementwise_add_s8, elementwise_add_s16, and elementwise_add_s32.
The s8 and s16 versions operate on quantized inputs; the s32 version
operates on full-range values. These can be used to implement portions
of an optimized Add kernel for TFLM.
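
- As a sketch of intended use (hypothetical glue code, not part of this
change), a TFLM Add kernel could forward the fields of a
tflite::ArithmeticParams to the s8 helper roughly as follows, where
flat_size is the element count of the already-broadcast tensors:

    // Hypothetical call site; field names are those of
    // tflite::ArithmeticParams.
    kelvin::opt::elementwise_add_s8(
        input1_data, input2_data,
        params.input1_offset, params.input1_multiplier, params.input1_shift,
        params.input2_offset, params.input2_multiplier, params.input2_shift,
        params.left_shift, output_data,
        params.output_offset, params.output_multiplier, params.output_shift,
        params.quantized_activation_min, params.quantized_activation_max,
        flat_size);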
Change-Id: I6508e12761f0ec1ede3c40289bf5fa2dd42e236a
diff --git a/tests/tflm/BUILD b/tests/tflm/BUILD
index de9127f..71f70ab 100644
--- a/tests/tflm/BUILD
+++ b/tests/tflm/BUILD
@@ -2,6 +2,22 @@
package(default_visibility = ["//visibility:public"])

kelvin_test(
+ name = "add_test",
+ srcs = [
+ "@tflite-micro//tensorflow/lite/micro/kernels:add_test.cc",
+ ],
+ deps = [
+ "//crt:crt_header",
+ "@tflite-micro//tensorflow/lite/c:common",
+ "@tflite-micro//tensorflow/lite/kernels/internal:tensor",
+ "@tflite-micro//tensorflow/lite/micro/kernels:kernel_runner",
+ "@tflite-micro//tensorflow/lite/micro/testing:micro_test",
+ "@tflite-micro//tensorflow/lite/micro:micro_utils",
+ "@tflite-micro//tensorflow/lite/micro:test_helpers",
+ ],
+)
+
+kelvin_test(
name = "reshape_test",
srcs = [
"@tflite-micro//tensorflow/lite/micro/kernels:reshape_test.cc",
diff --git a/tflm/opt/BUILD b/tflm/opt/BUILD
index abf99e6..2ba3900 100644
--- a/tflm/opt/BUILD
+++ b/tflm/opt/BUILD
@@ -3,10 +3,14 @@
cc_library(
name = "opt",
srcs = [
+ "elementwise_add_s8.cc",
+ "elementwise_add_s16.cc",
+ "elementwise_add_s32.cc",
"memcpy.cc",
],
hdrs = [
"opt.h",
+ "util.h",
],
deps = [
"//crt:crt_header",
diff --git a/tflm/opt/elementwise_add_s16.cc b/tflm/opt/elementwise_add_s16.cc
new file mode 100644
index 0000000..2800eb6
--- /dev/null
+++ b/tflm/opt/elementwise_add_s16.cc
@@ -0,0 +1,88 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "crt/kelvin.h"
+#include "tflm/opt/opt.h"
+#include "tflm/opt/util.h"
+
+namespace kelvin::opt {
+
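+// Elementwise Add on quantized int16 values: widen to 32 bits, apply the
+// input offsets, scale each input, sum, requantize to the output scale,
+// clamp to the activation range, and narrow back to halfwords.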
+void elementwise_add_s16(const int16_t* input1, const int16_t* input2,
+ const int32_t input1_offset, const int32_t input1_mult,
+ const int32_t input1_shift,
+ const int32_t input2_offset, const int32_t input2_mult,
+ const int32_t input2_shift, const int32_t left_shift,
+ int16_t* output, const int32_t output_offset,
+ const int32_t output_mult, const int32_t output_shift,
+ const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size) {
+ int blocks = block_size;
+ int vl;
+ getmaxvl_h(vl);
+
+ const int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift);
+ const int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift);
+
+ while (blocks) {
+ int count = std::min(blocks, vl);
+
+ // Widen input1 to 32-bit wide values (in vm0, vm1).
+ vld_h_lp_xx_m(vm0, input1, count);
+ vaddw_w_vx_m(vm0, vm0, input1_offset);
+
+ // Widen input2 to 32-bit wide values (in vm2, vm3).
+ vld_h_lp_xx_m(vm2, input2, count);
+ vaddw_w_vx_m(vm2, vm2, input2_offset);
+
+ // Apply left_shift to all inputs.
+ vsll_w_vx_m(vm0, vm0, left_shift);
+ vsll_w_vx_m(vm1, vm1, left_shift);
+ vsll_w_vx_m(vm2, vm2, left_shift);
+ vsll_w_vx_m(vm3, vm3, left_shift);
+
+ // Apply the positive component of the input shifts.
+ vmul_w_vx_m(vm0, vm0, input1_shift_mul);
+ vmul_w_vx_m(vm1, vm1, input1_shift_mul);
+ vmul_w_vx_m(vm2, vm2, input2_shift_mul);
+ vmul_w_vx_m(vm3, vm3, input2_shift_mul);
+
+ // Rescale the inputs; the offsets were already applied during widening,
+ // so no post-offset is added here.
+ rescale_m(vm0, vm0, input1_mult, input1_shift, 0);
+ rescale_m(vm1, vm1, input1_mult, input1_shift, 0);
+ rescale_m(vm2, vm2, input2_mult, input2_shift, 0);
+ rescale_m(vm3, vm3, input2_mult, input2_shift, 0);
+
+ // Sum the rescaled inputs.
+ vadd_w_vv_m(vm0, vm0, vm2);
+ vadd_w_vv_m(vm1, vm1, vm3);
+
+ // Rescale the summed output.
+ rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+ rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
+
+ // Clamp to the provided range.
+ vmin_w_vx_m(vm0, vm0, output_activation_max);
+ vmin_w_vx_m(vm1, vm1, output_activation_max);
+ vmax_w_vx_m(vm0, vm0, output_activation_min);
+ vmax_w_vx_m(vm1, vm1, output_activation_min);
+
+ // Swizzle and narrow back to halfwords.
+ vand_w_vx_m(vm0, vm0, 0xFFFF);
+ vand_w_vx_m(vm1, vm1, 0xFFFF);
+ vsll_w_vx_m(vm1, vm1, 16);
+ vor_vv_m(vm0, vm0, vm1);
+
+ // Store to memory.
+ vst_h_lp_xx_m(vm0, output, count);
+
+ blocks -= count;
+ }
+}
+
+} // namespace kelvin::opt
diff --git a/tflm/opt/elementwise_add_s32.cc b/tflm/opt/elementwise_add_s32.cc
new file mode 100644
index 0000000..ff17fe1
--- /dev/null
+++ b/tflm/opt/elementwise_add_s32.cc
@@ -0,0 +1,36 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "crt/kelvin.h"
+#include "tflm/opt/opt.h"
+
+namespace kelvin::opt {
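+
+// Elementwise Add on full-range int32 values: a plain vector add followed
+// by clamping to the activation range; no requantization is involved.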
+void elementwise_add_s32(const int32_t* input1, const int32_t* input2,
+ int32_t* output, const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size) {
+ int blocks = block_size;
+ int vl;
+ getmaxvl_w_m(vl);
+ while (blocks) {
+ int count = std::min(blocks, vl);
+
+ vld_w_p_xx_m(vm0, input1, count);
+ vld_w_p_xx_m(vm1, input2, count);
+
+ vadd_w_vv_m(vm0, vm0, vm1);
+ vmin_w_vx_m(vm0, vm0, output_activation_max);
+ vmax_w_vx_m(vm0, vm0, output_activation_min);
+
+ vst_w_p_xx_m(vm0, output, count);
+
+ blocks -= count;
+ }
+}
+} // namespace kelvin::opt
diff --git a/tflm/opt/elementwise_add_s8.cc b/tflm/opt/elementwise_add_s8.cc
new file mode 100644
index 0000000..8380fa1
--- /dev/null
+++ b/tflm/opt/elementwise_add_s8.cc
@@ -0,0 +1,108 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "crt/kelvin.h"
+#include "tflm/opt/opt.h"
+#include "tflm/opt/util.h"
+
+namespace kelvin::opt {
+
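+// Elementwise Add on quantized int8 values: widen to 32 bits, apply the
+// input offsets, scale each input, sum, requantize to the output scale,
+// clamp to the activation range, and narrow back to bytes.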
+void elementwise_add_s8(const int8_t* input1, const int8_t* input2,
+ const int32_t input1_offset, const int32_t input1_mult,
+ const int32_t input1_shift, const int32_t input2_offset,
+ const int32_t input2_mult, const int32_t input2_shift,
+ const int32_t left_shift, int8_t* output,
+ const int32_t output_offset, const int32_t output_mult,
+ const int32_t output_shift,
+ const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size) {
+ int blocks = block_size;
+ int vl;
+ getmaxvl_b(vl);
+
+ const int32_t input1_shift_mul = 1 << LEFT_SHIFT(input1_shift);
+ const int32_t input2_shift_mul = 1 << LEFT_SHIFT(input2_shift);
+
+ while (blocks) {
+ int count = std::min(blocks, vl);
+
+ // Widen input1 to 32-bit wide values (in vm0, vm1, vm2, vm3).
+ vld_b_lp_xx_m(vm0, input1, count);
+ vaddw_h_vx_m(vm0, vm0, 0);              // bytes -> halfwords in vm0, vm1
+ vaddw_w_vx_m(vm2, vm1, input1_offset);  // vm1 halfwords -> words in vm2, vm3
+ vaddw_w_vx_m(vm0, vm0, input1_offset);  // vm0 halfwords -> words in vm0, vm1
+
+ // Widen input2 to 32-bit wide values (in vm4, vm5, vm6, vm7).
+ vld_b_lp_xx_m(vm4, input2, count);
+ vaddw_h_vx_m(vm4, vm4, 0);              // bytes -> halfwords in vm4, vm5
+ vaddw_w_vx_m(vm6, vm5, input2_offset);  // vm5 halfwords -> words in vm6, vm7
+ vaddw_w_vx_m(vm4, vm4, input2_offset);  // vm4 halfwords -> words in vm4, vm5
+
+ // Apply left_shift to all inputs.
+ vsll_w_vx_m(vm0, vm0, left_shift);
+ vsll_w_vx_m(vm1, vm1, left_shift);
+ vsll_w_vx_m(vm2, vm2, left_shift);
+ vsll_w_vx_m(vm3, vm3, left_shift);
+ vsll_w_vx_m(vm4, vm4, left_shift);
+ vsll_w_vx_m(vm5, vm5, left_shift);
+ vsll_w_vx_m(vm6, vm6, left_shift);
+ vsll_w_vx_m(vm7, vm7, left_shift);
+
+ vmul_w_vx_m(vm0, vm0, input1_shift_mul);
+ vmul_w_vx_m(vm1, vm1, input1_shift_mul);
+ vmul_w_vx_m(vm2, vm2, input1_shift_mul);
+ vmul_w_vx_m(vm3, vm3, input1_shift_mul);
+ vmul_w_vx_m(vm4, vm4, input2_shift_mul);
+ vmul_w_vx_m(vm5, vm5, input2_shift_mul);
+ vmul_w_vx_m(vm6, vm6, input2_shift_mul);
+ vmul_w_vx_m(vm7, vm7, input2_shift_mul);
+
+ // Rescale the inputs; the offsets were already applied during widening,
+ // so no post-offset is added here.
+ rescale_m(vm0, vm0, input1_mult, input1_shift, 0);
+ rescale_m(vm1, vm1, input1_mult, input1_shift, 0);
+ rescale_m(vm2, vm2, input1_mult, input1_shift, 0);
+ rescale_m(vm3, vm3, input1_mult, input1_shift, 0);
+ rescale_m(vm4, vm4, input2_mult, input2_shift, 0);
+ rescale_m(vm5, vm5, input2_mult, input2_shift, 0);
+ rescale_m(vm6, vm6, input2_mult, input2_shift, 0);
+ rescale_m(vm7, vm7, input2_mult, input2_shift, 0);
+
+ // Sum the rescaled inputs.
+ vadd_w_vv_m(vm0, vm0, vm4);
+ vadd_w_vv_m(vm1, vm1, vm5);
+ vadd_w_vv_m(vm2, vm2, vm6);
+ vadd_w_vv_m(vm3, vm3, vm7);
+
+ // Rescale the summed output.
+ rescale_m(vm0, vm0, output_mult, output_shift, output_offset);
+ rescale_m(vm1, vm1, output_mult, output_shift, output_offset);
+ rescale_m(vm2, vm2, output_mult, output_shift, output_offset);
+ rescale_m(vm3, vm3, output_mult, output_shift, output_offset);
+
+ // Clamp to the provided range.
+ vmin_w_vx_m(vm0, vm0, output_activation_max);
+ vmin_w_vx_m(vm1, vm1, output_activation_max);
+ vmin_w_vx_m(vm2, vm2, output_activation_max);
+ vmin_w_vx_m(vm3, vm3, output_activation_max);
+ vmax_w_vx_m(vm0, vm0, output_activation_min);
+ vmax_w_vx_m(vm1, vm1, output_activation_min);
+ vmax_w_vx_m(vm2, vm2, output_activation_min);
+ vmax_w_vx_m(vm3, vm3, output_activation_min);
+
+ // Swizzle and narrow back to bytes.
+ vsraqs_b_vx_m(vm0, vm0, 0);
+
+ // Store to memory.
+ vst_b_lp_xx_m(vm0, output, count);
+
+ blocks -= count;
+ }
+}
+
+} // namespace kelvin::opt
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h
index 5574daf..12075ab 100644
--- a/tflm/opt/opt.h
+++ b/tflm/opt/opt.h
@@ -2,11 +2,38 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

-#ifndef OPT_OPT_H_
-#define OPT_OPT_H_
+#ifndef TFLM_OPT_OPT_H_
+#define TFLM_OPT_OPT_H_

+#include <cstddef>
+#include <cstdint>
+
namespace kelvin::opt {
void *memcpy(void *dst, const void *src, size_t n);
+void elementwise_add_s8(const int8_t* input1, const int8_t* input2,
+ const int32_t input1_offset, const int32_t input1_mult,
+ const int32_t input1_shift, const int32_t input2_offset,
+ const int32_t input2_mult, const int32_t input2_shift,
+ const int32_t left_shift, int8_t* output,
+ const int32_t output_offset, const int32_t output_mult,
+ const int32_t output_shift,
+ const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size);
+void elementwise_add_s16(const int16_t* input1, const int16_t* input2,
+ const int32_t input1_offset, const int32_t input1_mult,
+ const int32_t input1_shift,
+ const int32_t input2_offset, const int32_t input2_mult,
+ const int32_t input2_shift, const int32_t left_shift,
+ int16_t* output, const int32_t output_offset,
+ const int32_t output_mult, const int32_t output_shift,
+ const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size);
+void elementwise_add_s32(const int32_t* input1, const int32_t* input2,
+ int32_t* output, const int32_t output_activation_min,
+ const int32_t output_activation_max,
+ const int32_t block_size);
} // namespace kelvin::opt

-#endif // OPT_OPT_H_
+#endif // TFLM_OPT_OPT_H_
diff --git a/tflm/opt/util.h b/tflm/opt/util.h
new file mode 100644
index 0000000..8f9d079
--- /dev/null
+++ b/tflm/opt/util.h
@@ -0,0 +1,30 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TFLM_OPT_UTIL_H_
+#define TFLM_OPT_UTIL_H_
+
+#include <algorithm>
+#include <cstdint>
+
+// Split a signed shift into its left (positive) and right (negative) parts.
+#define LEFT_SHIFT(_shift) std::max<int32_t>((_shift), 0)
+#define RIGHT_SHIFT(_shift) -std::min<int32_t>((_shift), 0)
+
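+// Requantize 32-bit values in-place: multiply-high by `mult`, double the
+// result, arithmetic-shift right by the right-shift component of `shift`,
+// then add `offset` (cf. TFLM's MultiplyByQuantizedMultiplier).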
+#define rescale_internal(Vd, Vs, mult, shift, offset, m) \
+ do { \
+ int32_t _shift = RIGHT_SHIFT(shift); \
+ vmulh_w_r_vx##m(Vd, Vs, mult); \
+ vmul_w_vx##m(Vd, Vd, 2); \
+ vsha_w_vx##m(Vd, Vd, _shift); \
+ vadd_w_vx##m(Vd, Vd, offset); \
+ } while (0)
+
+#define rescale(Vd, Vs, mult, shift, offset) rescale_internal(Vd, Vs, mult, shift, offset, )
+#define rescale_m(Vd, Vs, mult, shift, offset) rescale_internal(Vd, Vs, mult, shift, offset, _m)
+
+#endif // TFLM_OPT_UTIL_H_