Replace `int64_t` with `WideType`
PiperOrigin-RevId: 559161023
diff --git a/sim/BUILD b/sim/BUILD
index 53942c9..8f3bb8e 100644
--- a/sim/BUILD
+++ b/sim/BUILD
@@ -48,6 +48,7 @@
"@com_google_mpact-riscv//riscv:riscv_state",
"@com_google_mpact-sim//mpact/sim/generic:core",
"@com_google_mpact-sim//mpact/sim/generic:instruction",
+ "@com_google_mpact-sim//mpact/sim/generic:type_helpers",
],
)
diff --git a/sim/kelvin_vector_instructions.cc b/sim/kelvin_vector_instructions.cc
index d914bed..fce79e7 100644
--- a/sim/kelvin_vector_instructions.cc
+++ b/sim/kelvin_vector_instructions.cc
@@ -1,7 +1,6 @@
#include "sim/kelvin_vector_instructions.h"
#include <algorithm>
-#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
@@ -15,6 +14,7 @@
#include "riscv/riscv_register.h"
#include "mpact/sim/generic/data_buffer.h"
#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
namespace kelvin::sim {
@@ -554,15 +554,9 @@
// Halving addition with optional rounding bit.
template <typename T>
T KelvinVHaddHelper(bool round, T vs1, T vs2) {
- if (std::is_signed<T>::value) {
- return static_cast<T>((static_cast<int64_t>(vs1) +
- static_cast<int64_t>(vs2) + (round ? 1 : 0)) >>
- 1);
- } else {
- return static_cast<T>((static_cast<uint64_t>(vs1) +
- static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >>
- 1);
- }
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ return static_cast<T>(
+ (static_cast<WT>(vs1) + static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1);
}
template <typename T>
@@ -581,15 +575,9 @@
// Halving subtraction with optional rounding bit.
template <typename T>
T KelvinVHsubHelper(bool round, T vs1, T vs2) {
- if (std::is_signed<T>::value) {
- return static_cast<T>((static_cast<int64_t>(vs1) -
- static_cast<int64_t>(vs2) + (round ? 1 : 0)) >>
- 1);
- } else {
- return static_cast<T>((static_cast<uint64_t>(vs1) -
- static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >>
- 1);
- }
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ return static_cast<T>(
+ (static_cast<WT>(vs1) - static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1);
}
template <typename T>
@@ -746,10 +734,11 @@
// result.
template <typename T>
T KelvinVShiftHelper(bool round, T vs1, T vs2) {
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
if (std::is_signed<T>::value == true) {
constexpr int n = sizeof(T) * 8;
int shamt = vs2;
- int64_t s = vs1;
+ WT s = vs1;
if (!vs1) {
return 0;
} else if (vs1 < 0 && shamt >= n) {
@@ -757,14 +746,17 @@
} else if (vs1 > 0 && shamt >= n) {
s = 0;
} else if (shamt > 0) {
- s = (static_cast<int64_t>(vs1) + (round ? (1ll << (shamt - 1)) : 0)) >>
+ s = (static_cast<WT>(vs1) +
+ (round ? static_cast<WT>(1ll << (shamt - 1)) : 0)) >>
shamt;
} else { // shamt < 0
using UT = typename std::make_unsigned<T>::type;
UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n);
CHECK_LE(ushamt, n);
CHECK_GE(ushamt, 0);
- s = static_cast<int64_t>(static_cast<uint64_t>(vs1) << ushamt);
+ // Use unsigned WideType to prevent undefined negative shift.
+ using UWT = typename mpact::sim::riscv::WideType<UT>::type;
+ s = static_cast<WT>(static_cast<UWT>(vs1) << ushamt);
}
T neg_max = std::numeric_limits<T>::min();
T pos_max = std::numeric_limits<T>::max();
@@ -777,18 +769,19 @@
constexpr int n = sizeof(T) * 8;
// Shift can be positive/negative.
int shamt = static_cast<typename std::make_signed<T>::type>(vs2);
- uint64_t s = vs1;
+ WT s = vs1;
if (!vs1) {
return 0;
} else if (shamt > n) {
s = 0;
} else if (shamt > 0) {
- s = (static_cast<uint64_t>(vs1) + (round ? (1ull << (shamt - 1)) : 0)) >>
+ s = (static_cast<WT>(vs1) +
+ (round ? static_cast<WT>(1ull << (shamt - 1)) : 0)) >>
shamt;
} else {
using UT = typename std::make_unsigned<T>::type;
UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n);
- s = static_cast<uint64_t>(vs1) << (ushamt);
+ s = static_cast<WT>(vs1) << (ushamt);
}
T pos_max = std::numeric_limits<T>::max();
bool pos_sat = vs1 && (shamt < -n || s > pos_max);
@@ -891,10 +884,10 @@
static_assert(2 * sizeof(Td) == sizeof(Ts) || 4 * sizeof(Td) == sizeof(Ts));
constexpr int src_bits = sizeof(Ts) * 8;
vs2 &= (src_bits - 1);
-
- int64_t res =
- (static_cast<int64_t>(vs1) + (vs2 && round ? (1ll << (vs2 - 1)) : 0)) >>
- vs2;
+ using WTs = typename mpact::sim::riscv::WideType<Ts>::type;
+ WTs res = (static_cast<WTs>(vs1) +
+ (vs2 && round ? static_cast<WTs>(1ll << (vs2 - 1)) : 0)) >>
+ vs2;
bool neg_sat = res < std::numeric_limits<Td>::min();
bool pos_sat = res > std::numeric_limits<Td>::max();
@@ -923,16 +916,12 @@
// Multiplication of vector elements.
template <typename T>
void KelvinVMul(bool scalar, bool strip_mine, Instruction *inst) {
- KelvinBinaryVectorOp(inst, scalar, strip_mine,
- std::function<T(T, T)>([](T vs1, T vs2) -> T {
- if (std::is_signed<T>::value) {
- return static_cast<T>(static_cast<int64_t>(vs1) *
- static_cast<int64_t>(vs2));
- } else {
- return static_cast<T>(static_cast<uint64_t>(vs1) *
- static_cast<uint64_t>(vs2));
- }
- }));
+ KelvinBinaryVectorOp(
+ inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T {
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+
+ return static_cast<T>(static_cast<WT>(vs1) * static_cast<WT>(vs2));
+ }));
}
template void KelvinVMul<int8_t>(bool, bool, Instruction *);
template void KelvinVMul<int16_t>(bool, bool, Instruction *);
@@ -943,19 +932,16 @@
void KelvinVMuls(bool scalar, bool strip_mine, Instruction *inst) {
KelvinBinaryVectorOp(
inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T {
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
if (std::is_signed<T>::value) {
- int64_t result =
- static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
result = std::max(
- static_cast<int64_t>(std::numeric_limits<T>::min()),
- std::min(static_cast<int64_t>(std::numeric_limits<T>::max()),
- result));
+ static_cast<WT>(std::numeric_limits<T>::min()),
+ std::min(static_cast<WT>(std::numeric_limits<T>::max()), result));
return result;
} else {
- uint64_t result =
- static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
- result = std::min(
- static_cast<uint64_t>(std::numeric_limits<T>::max()), result);
+ result =
+ std::min(static_cast<WT>(std::numeric_limits<T>::max()), result);
return result;
}
}));
@@ -984,16 +970,12 @@
// Returns high half.
template <typename T>
T KelvinVMulhHelper(bool round, T vs1, T vs2) {
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
constexpr int n = sizeof(T) * 8;
- if (std::is_signed<T>::value) {
- int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
- result += round ? 1ll << (n - 1) : 0;
- return static_cast<uint64_t>(result) >> n;
- } else {
- uint64_t result = static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
- result += round ? 1ull << (n - 1) : 0;
- return result >> n;
- }
+
+ WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
+ result += round ? static_cast<WT>(1ll << (n - 1)) : 0;
+ return result >> n;
}
template <typename T>
@@ -1014,11 +996,12 @@
template <typename T>
T KelvinVDmulhHelper(bool round, bool round_neg, T vs1, T vs2) {
constexpr int n = sizeof(T) * 8;
- int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
if (round) {
- int64_t rnd = 0x40000000ll >> (32 - n);
+ WT rnd = static_cast<WT>(0x40000000ll >> (32 - n));
if (result < 0 && round_neg) {
- rnd = (-0x40000000ll) >> (32 - n);
+ rnd = static_cast<WT>((-0x40000000ll) >> (32 - n));
}
result += rnd;
}
@@ -1046,8 +1029,9 @@
KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>(
inst, scalar, strip_mine,
std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T {
- return static_cast<int64_t>(vd) +
- static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ return static_cast<WT>(vd) +
+ static_cast<WT>(vs1) * static_cast<WT>(vs2);
}));
}
template void KelvinVMacc<int8_t>(bool, bool, Instruction *);
@@ -1060,8 +1044,9 @@
KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>(
inst, scalar, strip_mine,
std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T {
- return static_cast<int64_t>(vs1) +
- static_cast<int64_t>(vd) * static_cast<int64_t>(vs2);
+ using WT = typename mpact::sim::riscv::WideType<T>::type;
+ return static_cast<WT>(vs1) +
+ static_cast<WT>(vd) * static_cast<WT>(vs2);
}));
}
template void KelvinVMadd<int8_t>(bool, bool, Instruction *);