Replace `int64_t` with `WideType` PiperOrigin-RevId: 559161023
diff --git a/sim/BUILD b/sim/BUILD index 53942c9..8f3bb8e 100644 --- a/sim/BUILD +++ b/sim/BUILD
@@ -48,6 +48,7 @@ "@com_google_mpact-riscv//riscv:riscv_state", "@com_google_mpact-sim//mpact/sim/generic:core", "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", ], )
diff --git a/sim/kelvin_vector_instructions.cc b/sim/kelvin_vector_instructions.cc index d914bed..fce79e7 100644 --- a/sim/kelvin_vector_instructions.cc +++ b/sim/kelvin_vector_instructions.cc
@@ -1,7 +1,6 @@ #include "sim/kelvin_vector_instructions.h" #include <algorithm> -#include <cstdint> #include <cstdlib> #include <functional> #include <limits> @@ -15,6 +14,7 @@ #include "riscv/riscv_register.h" #include "mpact/sim/generic/data_buffer.h" #include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" namespace kelvin::sim { @@ -554,15 +554,9 @@ // Halving addition with optional rounding bit. template <typename T> T KelvinVHaddHelper(bool round, T vs1, T vs2) { - if (std::is_signed<T>::value) { - return static_cast<T>((static_cast<int64_t>(vs1) + - static_cast<int64_t>(vs2) + (round ? 1 : 0)) >> - 1); - } else { - return static_cast<T>((static_cast<uint64_t>(vs1) + - static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >> - 1); - } + using WT = typename mpact::sim::riscv::WideType<T>::type; + return static_cast<T>( + (static_cast<WT>(vs1) + static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1); } template <typename T> @@ -581,15 +575,9 @@ // Halving subtraction with optional rounding bit. template <typename T> T KelvinVHsubHelper(bool round, T vs1, T vs2) { - if (std::is_signed<T>::value) { - return static_cast<T>((static_cast<int64_t>(vs1) - - static_cast<int64_t>(vs2) + (round ? 1 : 0)) >> - 1); - } else { - return static_cast<T>((static_cast<uint64_t>(vs1) - - static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >> - 1); - } + using WT = typename mpact::sim::riscv::WideType<T>::type; + return static_cast<T>( + (static_cast<WT>(vs1) - static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1); } template <typename T> @@ -746,10 +734,11 @@ // result. template <typename T> T KelvinVShiftHelper(bool round, T vs1, T vs2) { + using WT = typename mpact::sim::riscv::WideType<T>::type; if (std::is_signed<T>::value == true) { constexpr int n = sizeof(T) * 8; int shamt = vs2; - int64_t s = vs1; + WT s = vs1; if (!vs1) { return 0; } else if (vs1 < 0 && shamt >= n) { @@ -757,14 +746,17 @@ } else if (vs1 > 0 && shamt >= n) { s = 0; } else if (shamt > 0) { - s = (static_cast<int64_t>(vs1) + (round ? (1ll << (shamt - 1)) : 0)) >> + s = (static_cast<WT>(vs1) + + (round ? static_cast<WT>(1ll << (shamt - 1)) : 0)) >> shamt; } else { // shamt < 0 using UT = typename std::make_unsigned<T>::type; UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n); CHECK_LE(ushamt, n); CHECK_GE(ushamt, 0); - s = static_cast<int64_t>(static_cast<uint64_t>(vs1) << ushamt); + // Use unsigned WideType to prevent undefined negative shift. + using UWT = typename mpact::sim::riscv::WideType<UT>::type; + s = static_cast<WT>(static_cast<UWT>(vs1) << ushamt); } T neg_max = std::numeric_limits<T>::min(); T pos_max = std::numeric_limits<T>::max(); @@ -777,18 +769,19 @@ constexpr int n = sizeof(T) * 8; // Shift can be positive/negative. int shamt = static_cast<typename std::make_signed<T>::type>(vs2); - uint64_t s = vs1; + WT s = vs1; if (!vs1) { return 0; } else if (shamt > n) { s = 0; } else if (shamt > 0) { - s = (static_cast<uint64_t>(vs1) + (round ? (1ull << (shamt - 1)) : 0)) >> + s = (static_cast<WT>(vs1) + + (round ? static_cast<WT>(1ull << (shamt - 1)) : 0)) >> shamt; } else { using UT = typename std::make_unsigned<T>::type; UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n); - s = static_cast<uint64_t>(vs1) << (ushamt); + s = static_cast<WT>(vs1) << (ushamt); } T pos_max = std::numeric_limits<T>::max(); bool pos_sat = vs1 && (shamt < -n || s > pos_max); @@ -891,10 +884,10 @@ static_assert(2 * sizeof(Td) == sizeof(Ts) || 4 * sizeof(Td) == sizeof(Ts)); constexpr int src_bits = sizeof(Ts) * 8; vs2 &= (src_bits - 1); - - int64_t res = - (static_cast<int64_t>(vs1) + (vs2 && round ? (1ll << (vs2 - 1)) : 0)) >> - vs2; + using WTs = typename mpact::sim::riscv::WideType<Ts>::type; + WTs res = (static_cast<WTs>(vs1) + + (vs2 && round ? static_cast<WTs>(1ll << (vs2 - 1)) : 0)) >> + vs2; bool neg_sat = res < std::numeric_limits<Td>::min(); bool pos_sat = res > std::numeric_limits<Td>::max(); @@ -923,16 +916,12 @@ // Multiplication of vector elements. template <typename T> void KelvinVMul(bool scalar, bool strip_mine, Instruction *inst) { - KelvinBinaryVectorOp(inst, scalar, strip_mine, - std::function<T(T, T)>([](T vs1, T vs2) -> T { - if (std::is_signed<T>::value) { - return static_cast<T>(static_cast<int64_t>(vs1) * - static_cast<int64_t>(vs2)); - } else { - return static_cast<T>(static_cast<uint64_t>(vs1) * - static_cast<uint64_t>(vs2)); - } - })); + KelvinBinaryVectorOp( + inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T { + using WT = typename mpact::sim::riscv::WideType<T>::type; + + return static_cast<T>(static_cast<WT>(vs1) * static_cast<WT>(vs2)); + })); } template void KelvinVMul<int8_t>(bool, bool, Instruction *); template void KelvinVMul<int16_t>(bool, bool, Instruction *); @@ -943,19 +932,16 @@ void KelvinVMuls(bool scalar, bool strip_mine, Instruction *inst) { KelvinBinaryVectorOp( inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T { + using WT = typename mpact::sim::riscv::WideType<T>::type; + WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2); if (std::is_signed<T>::value) { - int64_t result = - static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2); result = std::max( - static_cast<int64_t>(std::numeric_limits<T>::min()), - std::min(static_cast<int64_t>(std::numeric_limits<T>::max()), - result)); + static_cast<WT>(std::numeric_limits<T>::min()), + std::min(static_cast<WT>(std::numeric_limits<T>::max()), result)); return result; } else { - uint64_t result = - static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2); - result = std::min( - static_cast<uint64_t>(std::numeric_limits<T>::max()), result); + result = + std::min(static_cast<WT>(std::numeric_limits<T>::max()), result); return result; } })); @@ -984,16 +970,12 @@ // Returns high half. template <typename T> T KelvinVMulhHelper(bool round, T vs1, T vs2) { + using WT = typename mpact::sim::riscv::WideType<T>::type; constexpr int n = sizeof(T) * 8; - if (std::is_signed<T>::value) { - int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2); - result += round ? 1ll << (n - 1) : 0; - return static_cast<uint64_t>(result) >> n; - } else { - uint64_t result = static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2); - result += round ? 1ull << (n - 1) : 0; - return result >> n; - } + + WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2); + result += round ? static_cast<WT>(1ll << (n - 1)) : 0; + return result >> n; } template <typename T> @@ -1014,11 +996,12 @@ template <typename T> T KelvinVDmulhHelper(bool round, bool round_neg, T vs1, T vs2) { constexpr int n = sizeof(T) * 8; - int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2); + using WT = typename mpact::sim::riscv::WideType<T>::type; + WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2); if (round) { - int64_t rnd = 0x40000000ll >> (32 - n); + WT rnd = static_cast<WT>(0x40000000ll >> (32 - n)); if (result < 0 && round_neg) { - rnd = (-0x40000000ll) >> (32 - n); + rnd = static_cast<WT>((-0x40000000ll) >> (32 - n)); } result += rnd; } @@ -1046,8 +1029,9 @@ KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>( inst, scalar, strip_mine, std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T { - return static_cast<int64_t>(vd) + - static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2); + using WT = typename mpact::sim::riscv::WideType<T>::type; + return static_cast<WT>(vd) + + static_cast<WT>(vs1) * static_cast<WT>(vs2); })); } template void KelvinVMacc<int8_t>(bool, bool, Instruction *); @@ -1060,8 +1044,9 @@ KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>( inst, scalar, strip_mine, std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T { - return static_cast<int64_t>(vs1) + - static_cast<int64_t>(vd) * static_cast<int64_t>(vs2); + using WT = typename mpact::sim::riscv::WideType<T>::type; + return static_cast<WT>(vs1) + + static_cast<WT>(vd) * static_cast<WT>(vs2); })); } template void KelvinVMadd<int8_t>(bool, bool, Instruction *);