/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crt/kelvin.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tflm/opt/util.h"
namespace kelvin::opt {
void LogisticS8(int32_t input_zero_point, int32_t input_range_radius,
int32_t input_multiplier, int32_t input_left_shift,
int32_t input_size, const int8_t* input_data,
int8_t* output_data) {
static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
static constexpr int32_t kOutputZeroPoint = -128;
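// Register-group aliases for the loop below; the *_m instruction forms
// operate on groups of four consecutive vector registers, so each alias
// names the first register of a four-register group.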
#define INPUTS v0
#define MASK_IF_POSITIVE v4
#define MASK_IF_ZERO v8
#define NEG_ABS_INPUT v12
int i = 0;
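// Process up to 32 int8 values per iteration (one 32-byte vector load).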
do {
int count_this_iter = std::min<int32_t>(32, input_size - i);
// Load inputs, widen to 32-bits and apply zero-point.
vld_b_l_xx(INPUTS, input_data + i, count_this_iter);
vaddw_h_vx(INPUTS, INPUTS, 0);
vaddw_w_vx(v2, v1, -input_zero_point);
vaddw_w_vx(INPUTS, v0, -input_zero_point);
// MultiplyByQuantizedMultiplier
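// Pre-shift left, rounding doubling multiply-high by the multiplier, then
// rounding shift right, mirroring TFLM's MultiplyByQuantizedMultiplier.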
vmul_w_vx_m(v16, INPUTS, (1 << LEFT_SHIFT(input_left_shift)));
vdmulh_w_r_vx_m(v16, v16, input_multiplier);
vsha_w_r_vx_m(v16, v16, RIGHT_SHIFT(input_left_shift));
// Start gemmlowp::logistic
// Compute a mask of positive inputs
constexpr int32_t kInt32_AllOnes = ~0L;
vgt_w_vx_m(v20, v16, 0);
vdup_w_x_m(MASK_IF_POSITIVE, kInt32_AllOnes);
vsel_w_vx_m(MASK_IF_POSITIVE, v20, 0);
// Compute a mask of zero inputs
veq_w_vx_m(v20, v16, 0);
vdup_w_x_m(MASK_IF_ZERO, kInt32_AllOnes);
vsel_w_vx_m(MASK_IF_ZERO, v20, 0);
// Calculate absolute values of inputs, and negate
vabsd_w_vx_m(NEG_ABS_INPUT, v16, 0);
vrsub_w_vx_m(NEG_ABS_INPUT, NEG_ABS_INPUT, 0);
// Start gemmlowp::exp_on_negative_values
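// The negated absolute input is in Q4.27 (4 integer bits). Split it into a
// fractional part in [-1/4, 0), whose exponential is approximated by a
// polynomial, and whole multiples of 1/4, folded in by the barrel shifters.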
constexpr int32_t kQ4_OneQuarter = 0x02000000;
constexpr int32_t kQ4_OneQuarterMinusOne = 0x01FFFFFF;
vand_w_vx_m(v56, NEG_ABS_INPUT, kQ4_OneQuarterMinusOne);
vsub_w_vx_m(v56, v56, kQ4_OneQuarter);
// remainders -- live until after barrel shifters
vsub_w_vv_m(v60, v56, NEG_ABS_INPUT);
// Start gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
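// Rescale the fractional part to Q0.31, set x = a + 1/8 and evaluate the
// 4th-order Taylor polynomial exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24).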
vsha_w_r_vx_m(v56, v56, -4);
constexpr int32_t kQ4_OneEighth = 0x10000000;
vdup_w_x_m(v20, kQ4_OneEighth);
vadd_w_vv_m(v56, v56, v20);
vdmulh_w_r_vv_m(v16, v56, v56); // x2
vdmulh_w_r_vv_m(v24, v56, v16);
vdmulh_w_r_vv_m(v20, v16, v16);
vsha_w_r_vx_m(v20, v20, 2);
constexpr int32_t kQ4_ConstantTerm = 0x70f5a894;
constexpr int32_t kQ4_ConstantOneOverThree = 0x2aaaaaab;
vadd_w_vv_m(v20, v20, v24); // x4_over_4 + x3
vdmulh_w_r_vx_m(v20, v20,
kQ4_ConstantOneOverThree); // _ * constant_1_over_3
vadd_w_vv_m(v20, v20, v16); // _ + x2
vsha_w_r_vx_m(v20, v20, 1); // SaturatingRoundingMultiplyByPOT<-1>(_)
vadd_w_vv_m(v20, v56, v20); // x + x4_over_24...
vdmulh_w_r_vx_m(v20, v20, kQ4_ConstantTerm); // constant_term * _
vadd_w_vx_m(v20, v20, kQ4_ConstantTerm);
// End gemmlowp::exp_on_interval_between_negative_one_quarter_and_0_excl
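// BARREL_SHIFTER(bit, multiplier) tests one bit of the remainder (v60) and,
// where it is set, scales the running result (v20) by the given Q0.31
// factor, mirroring GEMMLOWP_EXP_BARREL_SHIFTER.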
#define BARREL_SHIFTER(shift, multiplier) \
{ \
vand_w_vx_m(v28, v60, 1 << shift); \
vne_w_vx_m(v28, v28, 0); \
vdmulh_w_r_vx_m(v24, v20, multiplier); \
vsel_w_vv_m(v24, v28, v20); \
vmv_v_m(v20, v24); \
}
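// The multipliers are exp(-1/4), exp(-1/2), exp(-1), exp(-2), exp(-4),
// exp(-8) and exp(-16) in Q0.31: remainder bit k selects exp(-2^(k - 27)).
// The final exp(-16) entry tests bit 0, which is never set here (the
// remainder is a multiple of 2^25), so it has no effect.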
BARREL_SHIFTER(25, 0x63afbe7b);
BARREL_SHIFTER(26, 0x4da2cbf2);
BARREL_SHIFTER(27, 0x2f16ac6c);
BARREL_SHIFTER(28, 0x1152aaa4);
BARREL_SHIFTER(29, 0x02582ab7);
BARREL_SHIFTER(30, 0x000afe11);
BARREL_SHIFTER(0, 0x000000f2);
#undef BARREL_SHIFTER
constexpr int32_t kResultF_One = 0x7fffffff;
vne_w_vx_m(v56, NEG_ABS_INPUT, 0);
vsel_w_vx_m(v24, v56, kResultF_One);
// End gemmlowp::exp_on_negative_values
// Begin gemmlowp::one_over_one_plus_x_for_x_in_0_1
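// Newton-Raphson reciprocal: half_denominator = (1 + exp(-|x|)) / 2 lies in
// [1/2, 1). Start from x0 = 48/17 - (32/17) * half_denominator in Q2.29 and
// refine three times (DIVISION below); x converges to 1 / half_denominator,
// and the final left shift re-expresses x / 2 = 1 / (1 + exp(-|x|)) in Q0.31.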
constexpr int32_t kF2_Constant48Over17 = 0x5a5a5a5a;
constexpr int32_t kF2_ConstantNeg32Over17 = 0xc3c3c3c4;
constexpr int32_t kF0_OneHalf = 0x40000000;
vshl_w_vx_m(v24, v24, 1); // x0 >> 1
vadd_w_vx_m(v24, v24, kF0_OneHalf); // _ + ((x1 + 1) >> 1)
vmv_v_m(v20, v24); // half_denominators
vdmulh_w_r_vx_m(v24, v24,
kF2_ConstantNeg32Over17); // half_denominator * -32/17
vadd_w_vx_m(v24, v24, kF2_Constant48Over17); // _ + 48/17
constexpr int32_t kF2_One = 0x20000000;
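// One Newton-Raphson step: x <- x + x * (1 - half_denominator * x). The
// (1 - product) term is formed by handling the product-above-one and
// product-below-one cases separately and OR-ing the two disjoint results.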
#define DIVISION() \
{ \
vdmulh_w_r_vv_m(v28, v24, v20); \
vmv_v_m(v36, v28); \
vgt_w_vx_m(v32, v28, kF2_One); \
vsel_w_vx_m(v28, v32, kF2_One); \
vdup_w_x_m(v32, kF2_One); \
vsub_w_vv_m(v40, v28, v32); \
vdup_w_x_m(v32, 0xffffffff); \
vsub_w_vv_m(v40, v32, v40); \
vadd_w_vx_m(v40, v40, 1); \
vle_w_vx_m(v32, v36, kF2_One); \
vsel_w_vx_m(v36, v32, kF2_One); \
vdup_w_x_m(v32, kF2_One); \
vsub_w_vv_m(v36, v32, v36); \
vor_vv_m(v40, v36, v40); \
vdmulh_w_r_vv_m(v40, v40, v24); \
vsha_w_r_vx_m(v40, v40, -2); \
vadd_w_vv_m(v40, v40, v24); \
vmv_v_m(v24, v40); \
}
DIVISION();
DIVISION();
DIVISION();
#undef DIVISION
vsll_w_vx_m(v40, v40, 1); // result_if_positive
// End gemmlowp::one_over_one_plus_x_for_x_in_0_1
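// For negative inputs the result is 1 - result_if_positive; the subtraction
// from 1.0 (0x7fffffff) below uses the same split construction as the
// (1 - product) term in DIVISION.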
vgt_w_vx_m(v32, v40, kResultF_One);
vsel_w_vx_m(v44, v32, kResultF_One); // values >1
vdup_w_x_m(v32, kResultF_One);
vsub_w_vv_m(v44, v32, v44);
vdup_w_x_m(v32, 0xffffffff);
vsub_w_vv_m(v44, v32, v44);
vadd_w_vx_m(v44, v44, 1);
vle_w_vx_m(v36, v40, kResultF_One);
vsel_w_vx_m(v32, v36, kResultF_One);
vdup_w_x_m(v36, kResultF_One);
vsub_w_vv_m(v32, v36, v40);
vor_vv_m(v44, v44, v32); // result_if_negative
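// Choose result_if_positive or result_if_negative by the sign of the input;
// inputs equal to zero map to exactly 1/2.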
vsel_w_vv_m(v40, MASK_IF_POSITIVE, v44);
vmv_v_m(v56, v40);
constexpr int32_t kResultF_OneHalf = 0x40000000;
vdup_w_x_m(v48, kResultF_OneHalf); // 1/2
vsel_w_vv_m(v48, MASK_IF_ZERO, v56);
vmv_v_m(v16, v48);
// End gemmlowp::logistic
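// Epilogue: inputs at or below -input_range_radius saturate to -128 and
// inputs at or above +input_range_radius saturate to 127. In-range lanes are
// rescaled from Q0.31 to int8 (rounding shift by 23, add the output zero
// point, clamp), the three cases are merged with masks (v8 and v12 are
// reused as scratch here), and the 32-bit lanes are packed back to bytes.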
vle_w_vx_m(v48, INPUTS, -input_range_radius);
vge_w_vx_m(v56, INPUTS, input_range_radius);
vor_vv_m(v12, v48, v56);
vne_w_vx_m(v12, v12, 1);
vmul_w_vx_m(v48, v48, static_cast<int32_t>(kMinInt8));
vmul_w_vx_m(v8, v56, static_cast<int32_t>(kMaxInt8));
vmul_w_vv_m(v16, v16, v12);
vsha_w_r_vx_m(v16, v16, 23);
vadd_w_vx_m(v16, v16, kOutputZeroPoint);
vmax_w_vx_m(v16, v16, kMinInt8);
vmin_w_vx_m(v16, v16, kMaxInt8);
vmul_w_vv_m(v12, v12, v16);
vor_vv_m(v48, v48, v56);
vor_vv_m(v48, v48, v8);
vor_vv_m(v48, v48, v12);
vsraqs_b_vx(v48, v48, 0);
vst_b_l_xx(v48, output_data + i, count_this_iter);
i += count_this_iter;
} while (i < input_size);
#undef INPUTS
#undef MASK_IF_POSITIVE
#undef MASK_IF_ZERO
#undef NEG_ABS_INPUT
}
} // namespace kelvin::opt