Vectorized Logistic for int8

- Vectorize the int8 logistic kernel for TFLM, processing up to 32
  input values per loop iteration (see the scalar reference sketch
  below).
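
  For reference, the per-element math this kernel is intended to match
  is the scalar TFLM reference integer-ops logistic (gemmlowp fixed
  point, Q4.27 input, Q0.31 internal, int8 output). A minimal scalar
  sketch written against that reference follows; the helper name
  LogisticS8Reference is illustrative only and is not part of this
  change:

      #include <algorithm>
      #include <cstdint>

      #include "fixedpoint/fixedpoint.h"
      #include "tensorflow/lite/kernels/internal/common.h"

      void LogisticS8Reference(int32_t input_zero_point,
                               int32_t input_range_radius,
                               int32_t input_multiplier,
                               int32_t input_left_shift, int32_t input_size,
                               const int8_t* input_data,
                               int8_t* output_data) {
        for (int i = 0; i < input_size; ++i) {
          const int32_t input =
              static_cast<int32_t>(input_data[i]) - input_zero_point;
          if (input <= -input_range_radius) {
            output_data[i] = -128;
          } else if (input >= input_range_radius) {
            output_data[i] = 127;
          } else {
            // Rescale the zero-point-adjusted input into Q4.27.
            const int32_t input_q4 = tflite::MultiplyByQuantizedMultiplier(
                input, input_multiplier, input_left_shift);
            // gemmlowp::logistic returns a Q0.31 value in (0, 1).
            const int32_t output_q0 =
                gemmlowp::logistic(
                    gemmlowp::FixedPoint<int32_t, 4>::FromRaw(input_q4))
                    .raw();
            // Keep 8 fractional bits, add the output zero point (-128),
            // and clamp to int8.
            int32_t result =
                gemmlowp::RoundingDivideByPOT(output_q0, 23) - 128;
            result = std::min<int32_t>(std::max<int32_t>(result, -128), 127);
            output_data[i] = static_cast<int8_t>(result);
          }
        }
      }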

Change-Id: I26ae6a04946caf7fc00323cd4199df8f2229f68f
diff --git a/tests/tflm/BUILD b/tests/tflm/BUILD
index 747b0e8..33a91a9 100644
--- a/tests/tflm/BUILD
+++ b/tests/tflm/BUILD
@@ -86,6 +86,22 @@
 )
 
 kelvin_test(
+    name = "logistic_test",
+    srcs = [
+        "@tflite-micro//tensorflow/lite/micro/kernels:logistic_test.cc",
+    ],
+    deps = [
+        "//crt",
+        "@tflite-micro//tensorflow/lite/c:common",
+        "@tflite-micro//tensorflow/lite/kernels/internal:tensor",
+        "@tflite-micro//tensorflow/lite/micro:micro_utils",
+        "@tflite-micro//tensorflow/lite/micro:test_helpers",
+        "@tflite-micro//tensorflow/lite/micro/kernels:kernel_runner",
+        "@tflite-micro//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
+
+kelvin_test(
     name = "pooling_test",
     srcs = [
         "@tflite-micro//tensorflow/lite/micro/kernels:pooling_test.cc",
diff --git a/tflm/opt/BUILD b/tflm/opt/BUILD
index 3464e0d..a23b7df 100644
--- a/tflm/opt/BUILD
+++ b/tflm/opt/BUILD
@@ -32,6 +32,7 @@
         "elementwise_add_s8.cc",
         "leaky_relu_s16.cc",
         "leaky_relu_s8.cc",
+        "logistic_s8.cc",
         "max_pool_s8.cc",
         "memcpy.cc",
     ],
diff --git a/tflm/opt/logistic_s8.cc b/tflm/opt/logistic_s8.cc
new file mode 100644
index 0000000..9c349c7
--- /dev/null
+++ b/tflm/opt/logistic_s8.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "crt/kelvin.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tflm/opt/util.h"
+
+namespace kelvin::opt {
+
+void LogisticS8(int32_t input_zero_point, int32_t input_range_radius,
+                int32_t input_multiplier, int32_t input_left_shift,
+                int32_t input_size, const int8_t* input_data,
+                int8_t* output_data) {
+  static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
+  static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
+  static constexpr int32_t kOutputZeroPoint = -128;
+
+#define INPUTS v0
+#define MASK_IF_POSITIVE v4
+#define MASK_IF_ZERO v8
+#define NEG_ABS_INPUT v12
+  int i = 0;
+  do {
+    int count_this_iter = std::min(32L, input_size - i);
+
+    // Load inputs, widen to 32 bits, and subtract the input zero point.
+    vld_b_l_xx(INPUTS, input_data + i, count_this_iter);
+    vaddw_h_vx(INPUTS, INPUTS, 0);
+    vaddw_w_vx(v2, v1, -input_zero_point);
+    vaddw_w_vx(INPUTS, v0, -input_zero_point);
+
+    // MultiplyByQuantizedMultiplier
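+    // Rescale the zero-point-adjusted input into Q4.27 fixed point, mirroring
+    // tflite::MultiplyByQuantizedMultiplier: apply the left shift, a
+    // saturating rounding doubling high multiply, then a rounding right shift.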
+    vmul_w_vx_m(v16, INPUTS, (1 << LEFT_SHIFT(input_left_shift)));
+    vdmulh_w_r_vx_m(v16, v16, input_multiplier);
+    vsha_w_r_vx_m(v16, v16, RIGHT_SHIFT(input_left_shift));
+
+    // Start gemmlowp::logistic
+    // Compute a mask of positive inputs
+    constexpr int32_t kInt32_AllOnes = ~0L;
+    vgt_w_vx_m(v20, v16, 0);
+    vdup_w_x_m(v4, kInt32_AllOnes);
+    vsel_w_vx_m(MASK_IF_POSITIVE, v20, 0);
+
+    // Compute a mask of zero inputs
+    veq_w_vx_m(v20, v16, 0);
+    vdup_w_x_m(v8, kInt32_AllOnes);
+    vsel_w_vx_m(MASK_IF_ZERO, v20, 0);
+
+    // Calculate absolute values of inputs, and negate
+    vabsd_w_vx_m(NEG_ABS_INPUT, v16, 0);
+    vrsub_w_vx_m(NEG_ABS_INPUT, NEG_ABS_INPUT, 0);
+
+    // Start gemmlowp::exp_on_negative_values
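+    // exp(a) for a <= 0 in Q4.27: split a into a fractional part in [-1/4, 0)
+    // plus the remaining (more negative) bits; the polynomial below handles
+    // the fractional part and the barrel shifter applies precomputed exp()
+    // factors for the remainder bits.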
+    constexpr int32_t kQ4_OneQuarter = 0x02000000;
+    constexpr int32_t kQ4_OneQuarterMinusOne = 0x01FFFFFF;
+    vand_w_vx_m(v56, NEG_ABS_INPUT, kQ4_OneQuarterMinusOne);
+    vsub_w_vx_m(v56, v56, kQ4_OneQuarter);
+
+    // remainders -- live until after barrel shifters
+    vsub_w_vv_m(v60, v56, NEG_ABS_INPUT);
+
+    // Start gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
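+    // Rescale from Q4.27 to Q0.31, recenter around -1/8, and evaluate a
+    // fourth-order Taylor polynomial of exp(); kQ4_ConstantTerm below is
+    // exp(-1/8) in Q0.31.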
+    vsha_w_r_vx_m(v56, v56, -4);
+    constexpr int32_t kQ4_OneEighth = 0x10000000;
+    vdup_w_x_m(v20, kQ4_OneEighth);
+    vadd_w_vv_m(v56, v56, v20);
+
+    vdmulh_w_r_vv_m(v16, v56, v56);  // x2
+    vdmulh_w_r_vv_m(v24, v56, v16);  // x3
+    vdmulh_w_r_vv_m(v20, v16, v16);  // x4
+    vsha_w_r_vx_m(v20, v20, 2);      // x4_over_4
+
+    constexpr int32_t kQ4_ConstantTerm = 0x70f5a894;
+    constexpr int32_t kQ4_ConstantOneOverThree = 0x2aaaaaab;
+    vadd_w_vv_m(v20, v20, v24);  // x4_over_4 + x3
+    vdmulh_w_r_vx_m(v20, v20,
+                    kQ4_ConstantOneOverThree);  // _ * constant_1_over_3
+    vadd_w_vv_m(v20, v20, v16);                 // _ + x2
+    vsha_w_r_vx_m(v20, v20, 1);  // SaturatingRoundingMultiplyByPOT<-1>(_)
+
+    vadd_w_vv_m(v20, v56, v20);                   // x + x4_over_24...
+    vdmulh_w_r_vx_m(v20, v20, kQ4_ConstantTerm);  // constant_term * _
+    vadd_w_vx_m(v20, v20, kQ4_ConstantTerm);
+    // End gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
+
+#define BARREL_SHIFTER(shift, multiplier)  \
+  {                                        \
+    vand_w_vx_m(v28, v60, 1 << shift);     \
+    vne_w_vx_m(v28, v28, 0);               \
+    vdmulh_w_r_vx_m(v24, v20, multiplier); \
+    vsel_w_vv_m(v24, v28, v20);            \
+    vmv_v_m(v20, v24);                     \
+  }
+
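+    // Each step multiplies the running result by a fixed factor when the
+    // corresponding remainder bit is set; the multipliers are the Q0.31
+    // encodings of exp(-1/4), exp(-1/2), exp(-1), exp(-2), exp(-4), exp(-8)
+    // and exp(-16).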
+    BARREL_SHIFTER(25, 0x63afbe7b);
+    BARREL_SHIFTER(26, 0x4da2cbf2);
+    BARREL_SHIFTER(27, 0x2f16ac6c);
+    BARREL_SHIFTER(28, 0x1152aaa4);
+    BARREL_SHIFTER(29, 0x02582ab7);
+    BARREL_SHIFTER(30, 0x000afe11);
+    BARREL_SHIFTER(0, 0x000000f2);
+#undef BARREL_SHIFTER
+
+    constexpr int32_t kResultF_One = 0x7fffffff;
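+    // exp(0) == 1: force lanes whose input is exactly zero to 1.0 in Q0.31.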
+    vne_w_vx_m(v56, NEG_ABS_INPUT, 0);
+    vsel_w_vx_m(v24, v56, kResultF_One);
+    // End gemmlowp::exp_on_negative_values
+
+    // Begin gemmlowp::one_over_one_plus_x_for_x_in_0_1
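+    // 1/(1+x) for x in [0, 1): form half_denominator = (1 + x)/2 with
+    // rounding, seed with 48/17 - 32/17 * half_denominator in Q2.29, then
+    // refine with three Newton-Raphson steps (DIVISION below).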
+    constexpr int32_t kF2_Constant48Over17 = 0x5a5a5a5a;
+    constexpr int32_t kF2_ConstantNeg32Over17 = 0xc3c3c3c4;
+    constexpr int32_t kF0_OneHalf = 0x40000000;
+    vshl_w_vx_m(v24, v24, 1);            // x0 >> 1
+    vadd_w_vx_m(v24, v24, kF0_OneHalf);  // _ + ((x1 + 1) >> 1)
+    vmv_v_m(v20, v24);                   // half_denominators
+
+    vdmulh_w_r_vx_m(v24, v24,
+                    kF2_ConstantNeg32Over17);     // half_denominator * -32/17
+    vadd_w_vx_m(v24, v24, kF2_Constant48Over17);  // _ + 48/17
+
+    constexpr int32_t kF2_One = 0x20000000;
+
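+// One Newton-Raphson refinement toward 1/half_denominator in Q2.29:
+// x <- x + x * (1 - half_denominator * x), evaluated with saturating
+// fixed-point arithmetic.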
+#define DIVISION()                  \
+  {                                 \
+    vdmulh_w_r_vv_m(v28, v24, v20); \
+    vmv_v_m(v36, v28);              \
+    vgt_w_vx_m(v32, v28, kF2_One);  \
+    vsel_w_vx_m(v28, v32, kF2_One); \
+    vdup_w_x_m(v32, kF2_One);       \
+    vsub_w_vv_m(v40, v28, v32);     \
+    vdup_w_x_m(v32, 0xffffffff);    \
+    vsub_w_vv_m(v40, v32, v40);     \
+    vadd_w_vx_m(v40, v40, 1);       \
+    vle_w_vx_m(v32, v36, kF2_One);  \
+    vsel_w_vx_m(v36, v32, kF2_One); \
+    vdup_w_x_m(v32, kF2_One);       \
+    vsub_w_vv_m(v36, v32, v36);     \
+    vor_vv_m(v40, v36, v40);        \
+    vdmulh_w_r_vv_m(v40, v40, v24); \
+    vsha_w_r_vx_m(v40, v40, -2);    \
+    vadd_w_vv_m(v40, v40, v24);     \
+    vmv_v_m(v24, v40);              \
+  }
+
+    DIVISION();
+    DIVISION();
+    DIVISION();
+#undef DIVISION
+
+    vsll_w_vx_m(v40, v40, 1);  // result_if_positive
+    // End gemmlowp::one_over_one_plus_x_for_x_in_0_1
+
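+    // For negative inputs, logistic(x) = 1 - logistic(-x): derive
+    // result_if_negative as a saturating (1 - result_if_positive).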
+    vgt_w_vx_m(v32, v40, kResultF_One);
+    vsel_w_vx_m(v44, v32, kResultF_One);  // values >1
+    vdup_w_x_m(v32, kResultF_One);
+    vsub_w_vv_m(v44, v32, v44);
+    vdup_w_x_m(v32, 0xffffffff);
+    vsub_w_vv_m(v44, v32, v44);
+    vadd_w_vx_m(v44, v44, 1);
+    vle_w_vx_m(v36, v40, kResultF_One);
+    vsel_w_vx_m(v32, v36, kResultF_One);
+    vdup_w_x_m(v36, kResultF_One);
+    vsub_w_vv_m(v32, v36, v40);
+    vor_vv_m(v44, v44, v32);  // result_if_negative
+
+    vsel_w_vv_m(v40, MASK_IF_POSITIVE, v44);
+    vmv_v_m(v56, v40);
+
+    constexpr int32_t kResultF_OneHalf = 0x40000000;
+    vdup_w_x_m(v48, kResultF_OneHalf);  // 1/2
+    vsel_w_vv_m(v48, MASK_IF_ZERO, v56);
+    vmv_v_m(v16, v48);
+    // End gemmlowp::logistic
+
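+    // Saturate lanes outside the input range radius: inputs <= -radius map to
+    // kMinInt8 and inputs >= radius map to kMaxInt8. In-range lanes keep the
+    // computed logistic value, rescaled to 8 fractional bits and offset by the
+    // output zero point, before everything is narrowed back to int8.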
+    vle_w_vx_m(v48, INPUTS, -input_range_radius);
+    vge_w_vx_m(v56, INPUTS, input_range_radius);
+    vor_vv_m(v12, v48, v56);
+    vne_w_vx_m(v12, v12, 1);
+
+    vmul_w_vx_m(v48, v48, static_cast<int32_t>(kMinInt8));
+    vmul_w_vx_m(v8, v56, static_cast<int32_t>(kMaxInt8));
+
+    vmul_w_vv_m(v16, v16, v12);
+    vsha_w_r_vx_m(v16, v16, 23);
+    vadd_w_vx_m(v16, v16, kOutputZeroPoint);
+    vmax_w_vx_m(v16, v16, kMinInt8);
+    vmin_w_vx_m(v16, v16, kMaxInt8);
+    vmul_w_vv_m(v12, v12, v16);
+    vor_vv_m(v48, v48, v56);
+    vor_vv_m(v48, v48, v8);
+    vor_vv_m(v48, v48, v12);
+    vsraqs_b_vx(v48, v48, 0);
+    vst_b_l_xx(v48, output_data + i, count_this_iter);
+
+    i += count_this_iter;
+  } while (i < input_size);
+}
+}  // namespace kelvin::opt
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h
index 76d5218..4fb9c51 100644
--- a/tflm/opt/opt.h
+++ b/tflm/opt/opt.h
@@ -105,6 +105,10 @@
                const tflite::RuntimeShape& input_shape,
                const int8_t* input_data,
                const tflite::RuntimeShape& output_shape, int8_t* output_data);
+void LogisticS8(int32_t input_zero_point, int32_t input_range_radius,
+                int32_t input_multiplier, int32_t input_left_shift,
+                int32_t input_size, const int8_t* input_data,
+                int8_t* output_data);
 
 }  // namespace kelvin::opt