Vectorized Logistic for int8
- Vectorize the logistic kernel for TFLM, handling up to 32 input values
  per cycle.
Change-Id: I26ae6a04946caf7fc00323cd4199df8f2229f68f
diff --git a/tests/tflm/BUILD b/tests/tflm/BUILD
index 747b0e8..33a91a9 100644
--- a/tests/tflm/BUILD
+++ b/tests/tflm/BUILD
@@ -86,6 +86,22 @@
)

kelvin_test(
+ name = "logistic_test",
+ srcs = [
+ "@tflite-micro//tensorflow/lite/micro/kernels:logistic_test.cc",
+ ],
+ deps = [
+ "//crt",
+ "@tflite-micro//tensorflow/lite/c:common",
+ "@tflite-micro//tensorflow/lite/kernels/internal:tensor",
+ "@tflite-micro//tensorflow/lite/micro:micro_utils",
+ "@tflite-micro//tensorflow/lite/micro:test_helpers",
+ "@tflite-micro//tensorflow/lite/micro/kernels:kernel_runner",
+ "@tflite-micro//tensorflow/lite/micro/testing:micro_test",
+ ],
+)
+
+kelvin_test(
name = "pooling_test",
srcs = [
"@tflite-micro//tensorflow/lite/micro/kernels:pooling_test.cc",
diff --git a/tflm/opt/BUILD b/tflm/opt/BUILD
index 3464e0d..a23b7df 100644
--- a/tflm/opt/BUILD
+++ b/tflm/opt/BUILD
@@ -32,6 +32,7 @@
"elementwise_add_s8.cc",
"leaky_relu_s16.cc",
"leaky_relu_s8.cc",
+ "logistic_s8.cc",
"max_pool_s8.cc",
"memcpy.cc",
],
diff --git a/tflm/opt/logistic_s8.cc b/tflm/opt/logistic_s8.cc
new file mode 100644
index 0000000..9c349c7
--- /dev/null
+++ b/tflm/opt/logistic_s8.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <limits>
+
+#include "crt/kelvin.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tflm/opt/util.h"
+
+namespace kelvin::opt {
+
+void LogisticS8(int32_t input_zero_point, int32_t input_range_radius,
+ int32_t input_multiplier, int32_t input_left_shift,
+ int32_t input_size, const int8_t* input_data,
+ int8_t* output_data) {
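+  // Per 32-element stripe, this roughly mirrors the scalar TFLM/gemmlowp
+  // reference logistic (int8, 4 input integer bits):
+  //   1. Load int8 inputs, widen to 32 bits, subtract input_zero_point.
+  //   2. MultiplyByQuantizedMultiplier into Q4.27 fixed point.
+  //   3. gemmlowp::logistic: exp_on_negative_values(-|x|), then
+  //      1 / (1 + exp(-|x|)), recombined with sign/zero masks.
+  //   4. Rescale the Q0.31 result to int8: rounding shift right by 23, add
+  //      the output zero point (-128), clamp; inputs beyond
+  //      +/-input_range_radius saturate directly to -128 / 127.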
+ static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
+ static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
+ static constexpr int32_t kOutputZeroPoint = -128;
+
+#define INPUTS v0
+#define MASK_IF_POSITIVE v4
+#define MASK_IF_ZERO v8
+#define NEG_ABS_INPUT v12
+ int i = 0;
+ do {
+ int count_this_iter = std::min(32L, input_size - i);
+
+ // Load inputs, widen to 32-bits and apply zero-point.
+ vld_b_l_xx(INPUTS, input_data + i, count_this_iter);
+ vaddw_h_vx(INPUTS, INPUTS, 0);
+ vaddw_w_vx(v2, v1, -input_zero_point);
+ vaddw_w_vx(INPUTS, v0, -input_zero_point);
+
+ // MultiplyByQuantizedMultiplier
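+  // As in the gemmlowp reference:
+  //   x = RoundingDivideByPOT(
+  //         SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+  //                                           input_multiplier),
+  //         right_shift)
+  // vdmulh_w_r is the rounding doubling high multiply; vsha_w_r is the
+  // rounding arithmetic shift.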
+ vmul_w_vx_m(v16, INPUTS, (1 << LEFT_SHIFT(input_left_shift)));
+ vdmulh_w_r_vx_m(v16, v16, input_multiplier);
+ vsha_w_r_vx_m(v16, v16, RIGHT_SHIFT(input_left_shift));
+
+ // Start gemmlowp::logistic
+ // Compute a mask of positive inputs
+ constexpr int32_t kInt32_AllOnes = ~0L;
+ vgt_w_vx_m(v20, v16, 0);
+ vdup_w_x_m(v4, kInt32_AllOnes);
+ vsel_w_vx_m(MASK_IF_POSITIVE, v20, 0);
+
+ // Compute a mask of zero inputs
+ veq_w_vx_m(v20, v16, 0);
+ vdup_w_x_m(v8, -1);
+ vsel_w_vx_m(MASK_IF_ZERO, v20, 0);
+
+ // Calculate absolute values of inputs, and negate
+ vabsd_w_vx_m(NEG_ABS_INPUT, v16, 0);
+ vrsub_w_vx_m(NEG_ABS_INPUT, NEG_ABS_INPUT, 0);
+
+ // Start gemmlowp::exp_on_negative_values
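+  // Range reduction, as in the gemmlowp reference: the negative Q4.27 input a
+  // is split into a_mod = (a & (2^25 - 1)) - 2^25, which lies in [-1/4, 0),
+  // and a non-negative remainder = a_mod - a that is a multiple of 1/4.
+  // exp(a) = exp(a_mod) * exp(-remainder); the exp(-remainder) factors are
+  // applied by the barrel-shifter stages below.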
+ constexpr int32_t kQ4_OneQuarter = 0x02000000;
+ constexpr int32_t kQ4_OneQuarterMinusOne = 0x01FFFFFF;
+ vand_w_vx_m(v56, NEG_ABS_INPUT, kQ4_OneQuarterMinusOne);
+ vsub_w_vx_m(v56, v56, kQ4_OneQuarter);
+
+ // remainders -- live until after barrel shifters
+ vsub_w_vv_m(v60, v56, NEG_ABS_INPUT);
+
+ // Start gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
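+  // exp(a) on [-1/4, 0) is evaluated (in Q0.31, after rescaling a up from
+  // Q4.27) around -1/8, as in gemmlowp:
+  //   x = a + 1/8
+  //   exp(a) ~= exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24)
+  // kQ4_ConstantTerm is exp(-1/8) and kQ4_ConstantOneOverThree is 1/3, both
+  // as Q0.31 values.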
+ vsha_w_r_vx_m(v56, v56, -4);
+ constexpr int32_t kQ4_OneEighth = 0x10000000;
+ vdup_w_x_m(v20, kQ4_OneEighth);
+ vadd_w_vv_m(v56, v56, v20);
+
+ vdmulh_w_r_vv_m(v16, v56, v56); // x2
+ vdmulh_w_r_vv_m(v24, v56, v16);
+ vdmulh_w_r_vv_m(v20, v16, v16);
+ vsha_w_r_vx_m(v20, v20, 2);
+
+ constexpr int32_t kQ4_ConstantTerm = 0x70f5a894;
+ constexpr int32_t kQ4_ConstantOneOverThree = 0x2aaaaaab;
+ vadd_w_vv_m(v20, v20, v24); // x4_over_4 + x3
+ vdmulh_w_r_vx_m(v20, v20,
+ kQ4_ConstantOneOverThree); // _ * constant_1_over_3
+ vadd_w_vv_m(v20, v20, v16); // _ + x2
+ vsha_w_r_vx_m(v20, v20, 1); // SaturatingRoundingMultiplyByPOT<-1>(_)
+
+ vadd_w_vv_m(v20, v56, v20); // x + x4_over_24...
+ vdmulh_w_r_vx_m(v20, v20, kQ4_ConstantTerm); // constant_term * _
+ vadd_w_vx_m(v20, v20, kQ4_ConstantTerm);
+ // End gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
+
+#define BARREL_SHIFTER(shift, multiplier) \
+ { \
+ vand_w_vx_m(v28, v60, 1 << shift); \
+ vne_w_vx_m(v28, v28, 0); \
+ vdmulh_w_r_vx_m(v24, v20, multiplier); \
+ vsel_w_vv_m(v24, v28, v20); \
+ vmv_v_m(v20, v24); \
+ }
+
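+  // Each stage multiplies the running result by exp(-2^k) (Q0.31) when the
+  // corresponding bit of the remainder (v60) is set, mirroring gemmlowp's
+  // GEMMLOWP_EXP_BARREL_SHIFTER: shift 25 -> exp(-1/4), 26 -> exp(-1/2),
+  // 27 -> exp(-1), 28 -> exp(-2), 29 -> exp(-4), 30 -> exp(-8).  The final
+  // stage (exp(-16)) tests bit 0, which is never set here since the low 25
+  // bits of the remainder are always zero.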
+ BARREL_SHIFTER(25, 0x63afbe7b);
+ BARREL_SHIFTER(26, 0x4da2cbf2);
+ BARREL_SHIFTER(27, 0x2f16ac6c);
+ BARREL_SHIFTER(28, 0x1152aaa4);
+ BARREL_SHIFTER(29, 0x02582ab7);
+ BARREL_SHIFTER(30, 0x000afe11);
+ BARREL_SHIFTER(0, 0x000000f2);
+#undef BARREL_SHIFTER
+
+ constexpr int32_t kResultF_One = 0x7fffffff;
+ vne_w_vx_m(v56, NEG_ABS_INPUT, 0);
+ vsel_w_vx_m(v24, v56, kResultF_One);
+ // End gemmlowp::exp_on_negative_values
+
+ // Begin gemmlowp::one_over_one_plus_x_for_x_in_0_1
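+  // v24 now holds e = exp(-|x|) in Q0.31.  1 / (1 + e) is computed as in
+  // gemmlowp: form the half-denominator d = (e + 1) / 2, here as
+  // (e >> 1) + 0x40000000, seed the reciprocal with x0 = 48/17 - (32/17) * d
+  // in Q2.29, then refine it with Newton-Raphson iterations (DIVISION below).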
+ constexpr int32_t kF2_Constant48Over17 = 0x5a5a5a5a;
+ constexpr int32_t kF2_ConstantNeg32Over17 = 0xc3c3c3c4;
+ constexpr int32_t kF0_OneHalf = 0x40000000;
+ vshl_w_vx_m(v24, v24, 1); // x0 >> 1
+ vadd_w_vx_m(v24, v24, kF0_OneHalf); // _ + ((x1 + 1) >> 1)
+ vmv_v_m(v20, v24); // half_denominators
+
+ vdmulh_w_r_vx_m(v24, v24,
+ kF2_ConstantNeg32Over17); // half_denominator * -32/17
+ vadd_w_vx_m(v24, v24, kF2_Constant48Over17); // _ + 48/17
+
+ constexpr int32_t kF2_One = 0x20000000;
+
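+  // One Newton-Raphson step for y ~= 1/d: y <- y + y * (1 - d * y), in Q2.29.
+  // The mask/select sequence computes (1 - d * y) by handling the d * y <= 1
+  // and d * y > 1 cases separately; the shift by -2 (left shift by 2)
+  // rescales the product back to Q2.29.  Three iterations match the gemmlowp
+  // reference.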
+#define DIVISION() \
+ { \
+ vdmulh_w_r_vv_m(v28, v24, v20); \
+ vmv_v_m(v36, v28); \
+ vgt_w_vx_m(v32, v28, kF2_One); \
+ vsel_w_vx_m(v28, v32, kF2_One); \
+ vdup_w_x_m(v32, kF2_One); \
+ vsub_w_vv_m(v40, v28, v32); \
+ vdup_w_x_m(v32, 0xffffffff); \
+ vsub_w_vv_m(v40, v32, v40); \
+ vadd_w_vx_m(v40, v40, 1); \
+ vle_w_vx_m(v32, v36, kF2_One); \
+ vsel_w_vx_m(v36, v32, kF2_One); \
+ vdup_w_x_m(v32, kF2_One); \
+ vsub_w_vv_m(v36, v32, v36); \
+ vor_vv_m(v40, v36, v40); \
+ vdmulh_w_r_vv_m(v40, v40, v24); \
+ vsha_w_r_vx_m(v40, v40, -2); \
+ vadd_w_vv_m(v40, v40, v24); \
+ vmv_v_m(v24, v40); \
+ }
+
+ DIVISION();
+ DIVISION();
+ DIVISION();
+#undef DIVISION
+
+ vsll_w_vx_m(v40, v40, 1); // result_if_positive
+ // End gemmlowp::one_over_one_plus_x_for_x_in_0_1
+
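+  // logistic(x) for x < 0 is 1 - logistic(-x), so build result_if_negative as
+  // (1.0 - result_if_positive) in Q0.31 and pick between the two with the
+  // positive-input mask; inputs that were exactly zero map to 1/2 just below.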
+ vgt_w_vx_m(v32, v40, kResultF_One);
+ vsel_w_vx_m(v44, v32, kResultF_One); // values >1
+ vdup_w_x_m(v32, kResultF_One);
+ vsub_w_vv_m(v44, v32, v44);
+ vdup_w_x_m(v32, 0xffffffff);
+ vsub_w_vv_m(v44, v32, v44);
+ vadd_w_vx_m(v44, v44, 1);
+ vle_w_vx_m(v36, v40, kResultF_One);
+ vsel_w_vx_m(v32, v36, kResultF_One);
+ vdup_w_x_m(v36, kResultF_One);
+ vsub_w_vv_m(v32, v36, v40);
+ vor_vv_m(v44, v44, v32); // result_if_negative
+
+ vsel_w_vv_m(v40, MASK_IF_POSITIVE, v44);
+ vmv_v_m(v56, v40);
+
+ constexpr int32_t kResultF_OneHalf = 0x40000000;
+ vdup_w_x_m(v48, kResultF_OneHalf); // 1/2
+ vsel_w_vv_m(v48, MASK_IF_ZERO, v56);
+ vmv_v_m(v16, v48);
+ // End gemmlowp::logistic
+
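+  // Final rescale to int8, mirroring the scalar reference: the Q0.31 logistic
+  // value is rounding-shifted right by 23 (giving [0, 256)), offset by the
+  // output zero point of -128 and clamped to [-128, 127].  Lanes whose
+  // zero-point-corrected input lies at or beyond +/-input_range_radius are
+  // forced straight to -128 / 127, and everything is narrowed back to int8
+  // with a saturating pack before the store.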
+ vle_w_vx_m(v48, INPUTS, -input_range_radius);
+ vge_w_vx_m(v56, INPUTS, input_range_radius);
+ vor_vv_m(v12, v48, v56);
+ vne_w_vx_m(v12, v12, 1);
+
+ vmul_w_vx_m(v48, v48, static_cast<int32_t>(kMinInt8));
+ vmul_w_vx_m(v8, v56, static_cast<int32_t>(kMaxInt8));
+
+ vmul_w_vv_m(v16, v16, v12);
+ vsha_w_r_vx_m(v16, v16, 23);
+ vadd_w_vx_m(v16, v16, kOutputZeroPoint);
+ vmax_w_vx_m(v16, v16, kMinInt8);
+ vmin_w_vx_m(v16, v16, kMaxInt8);
+ vmul_w_vv_m(v12, v12, v16);
+ vor_vv_m(v48, v48, v56);
+ vor_vv_m(v48, v48, v8);
+ vor_vv_m(v48, v48, v12);
+ vsraqs_b_vx(v48, v48, 0);
+ vst_b_l_xx(v48, output_data + i, count_this_iter);
+
+ i += count_this_iter;
+ } while (i < input_size);
+}
+} // namespace kelvin::opt
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h
index 76d5218..4fb9c51 100644
--- a/tflm/opt/opt.h
+++ b/tflm/opt/opt.h
@@ -105,6 +105,10 @@
const tflite::RuntimeShape& input_shape,
const int8_t* input_data,
const tflite::RuntimeShape& output_shape, int8_t* output_data);
+void LogisticS8(int32_t input_zero_point, int32_t input_range_radius,
+ int32_t input_multiplier, int32_t input_left_shift,
+ int32_t input_size, const int8_t* input_data,
+ int8_t* output_data);
} // namespace kelvin::opt