Replace `int64_t` with `WideType`

PiperOrigin-RevId: 559161023
diff --git a/sim/BUILD b/sim/BUILD
index 53942c9..8f3bb8e 100644
--- a/sim/BUILD
+++ b/sim/BUILD
@@ -48,6 +48,7 @@
         "@com_google_mpact-riscv//riscv:riscv_state",
         "@com_google_mpact-sim//mpact/sim/generic:core",
         "@com_google_mpact-sim//mpact/sim/generic:instruction",
+        "@com_google_mpact-sim//mpact/sim/generic:type_helpers",
     ],
 )
 
diff --git a/sim/kelvin_vector_instructions.cc b/sim/kelvin_vector_instructions.cc
index d914bed..fce79e7 100644
--- a/sim/kelvin_vector_instructions.cc
+++ b/sim/kelvin_vector_instructions.cc
@@ -1,7 +1,6 @@
 #include "sim/kelvin_vector_instructions.h"
 
 #include <algorithm>
-#include <cstdint>
 #include <cstdlib>
 #include <functional>
 #include <limits>
@@ -15,6 +14,7 @@
 #include "riscv/riscv_register.h"
 #include "mpact/sim/generic/data_buffer.h"
 #include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
 
 namespace kelvin::sim {
 
@@ -554,15 +554,9 @@
 // Halving addition with optional rounding bit.
 template <typename T>
 T KelvinVHaddHelper(bool round, T vs1, T vs2) {
-  if (std::is_signed<T>::value) {
-    return static_cast<T>((static_cast<int64_t>(vs1) +
-                           static_cast<int64_t>(vs2) + (round ? 1 : 0)) >>
-                          1);
-  } else {
-    return static_cast<T>((static_cast<uint64_t>(vs1) +
-                           static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >>
-                          1);
-  }
+  using WT = typename mpact::sim::riscv::WideType<T>::type;
+  return static_cast<T>(
+      (static_cast<WT>(vs1) + static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1);
 }
 
 template <typename T>
@@ -581,15 +575,9 @@
 // Halving subtraction with optional rounding bit.
 template <typename T>
 T KelvinVHsubHelper(bool round, T vs1, T vs2) {
-  if (std::is_signed<T>::value) {
-    return static_cast<T>((static_cast<int64_t>(vs1) -
-                           static_cast<int64_t>(vs2) + (round ? 1 : 0)) >>
-                          1);
-  } else {
-    return static_cast<T>((static_cast<uint64_t>(vs1) -
-                           static_cast<uint64_t>(vs2) + (round ? 1 : 0)) >>
-                          1);
-  }
+  using WT = typename mpact::sim::riscv::WideType<T>::type;
+  return static_cast<T>(
+      (static_cast<WT>(vs1) - static_cast<WT>(vs2) + (round ? 1 : 0)) >> 1);
 }
 
 template <typename T>
@@ -746,10 +734,11 @@
 // result.
 template <typename T>
 T KelvinVShiftHelper(bool round, T vs1, T vs2) {
+  using WT = typename mpact::sim::riscv::WideType<T>::type;
   if (std::is_signed<T>::value == true) {
     constexpr int n = sizeof(T) * 8;
     int shamt = vs2;
-    int64_t s = vs1;
+    WT s = vs1;
     if (!vs1) {
       return 0;
     } else if (vs1 < 0 && shamt >= n) {
@@ -757,14 +746,17 @@
     } else if (vs1 > 0 && shamt >= n) {
       s = 0;
     } else if (shamt > 0) {
-      s = (static_cast<int64_t>(vs1) + (round ? (1ll << (shamt - 1)) : 0)) >>
+      s = (static_cast<WT>(vs1) +
+           (round ? static_cast<WT>(1ll << (shamt - 1)) : 0)) >>
           shamt;
     } else {  // shamt < 0
       using UT = typename std::make_unsigned<T>::type;
       UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n);
       CHECK_LE(ushamt, n);
       CHECK_GE(ushamt, 0);
-      s = static_cast<int64_t>(static_cast<uint64_t>(vs1) << ushamt);
+      // Use unsigned WideType to prevent undefined negative shift.
+      using UWT = typename mpact::sim::riscv::WideType<UT>::type;
+      s = static_cast<WT>(static_cast<UWT>(vs1) << ushamt);
     }
     T neg_max = std::numeric_limits<T>::min();
     T pos_max = std::numeric_limits<T>::max();
@@ -777,18 +769,19 @@
     constexpr int n = sizeof(T) * 8;
     // Shift can be positive/negative.
     int shamt = static_cast<typename std::make_signed<T>::type>(vs2);
-    uint64_t s = vs1;
+    WT s = vs1;
     if (!vs1) {
       return 0;
     } else if (shamt > n) {
       s = 0;
     } else if (shamt > 0) {
-      s = (static_cast<uint64_t>(vs1) + (round ? (1ull << (shamt - 1)) : 0)) >>
+      s = (static_cast<WT>(vs1) +
+           (round ? static_cast<WT>(1ull << (shamt - 1)) : 0)) >>
           shamt;
     } else {
       using UT = typename std::make_unsigned<T>::type;
       UT ushamt = static_cast<UT>(-shamt <= n ? -shamt : n);
-      s = static_cast<uint64_t>(vs1) << (ushamt);
+      s = static_cast<WT>(vs1) << (ushamt);
     }
     T pos_max = std::numeric_limits<T>::max();
     bool pos_sat = vs1 && (shamt < -n || s > pos_max);
@@ -891,10 +884,10 @@
   static_assert(2 * sizeof(Td) == sizeof(Ts) || 4 * sizeof(Td) == sizeof(Ts));
   constexpr int src_bits = sizeof(Ts) * 8;
   vs2 &= (src_bits - 1);
-
-  int64_t res =
-      (static_cast<int64_t>(vs1) + (vs2 && round ? (1ll << (vs2 - 1)) : 0)) >>
-      vs2;
+  using WTs = typename mpact::sim::riscv::WideType<Ts>::type;
+  WTs res = (static_cast<WTs>(vs1) +
+             (vs2 && round ? static_cast<WTs>(1ll << (vs2 - 1)) : 0)) >>
+            vs2;
 
   bool neg_sat = res < std::numeric_limits<Td>::min();
   bool pos_sat = res > std::numeric_limits<Td>::max();
@@ -923,16 +916,12 @@
 // Multiplication of vector elements.
 template <typename T>
 void KelvinVMul(bool scalar, bool strip_mine, Instruction *inst) {
-  KelvinBinaryVectorOp(inst, scalar, strip_mine,
-                       std::function<T(T, T)>([](T vs1, T vs2) -> T {
-                         if (std::is_signed<T>::value) {
-                           return static_cast<T>(static_cast<int64_t>(vs1) *
-                                                 static_cast<int64_t>(vs2));
-                         } else {
-                           return static_cast<T>(static_cast<uint64_t>(vs1) *
-                                                 static_cast<uint64_t>(vs2));
-                         }
-                       }));
+  KelvinBinaryVectorOp(
+      inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T {
+        using WT = typename mpact::sim::riscv::WideType<T>::type;
+
+        return static_cast<T>(static_cast<WT>(vs1) * static_cast<WT>(vs2));
+      }));
 }
 template void KelvinVMul<int8_t>(bool, bool, Instruction *);
 template void KelvinVMul<int16_t>(bool, bool, Instruction *);
@@ -943,19 +932,16 @@
 void KelvinVMuls(bool scalar, bool strip_mine, Instruction *inst) {
   KelvinBinaryVectorOp(
       inst, scalar, strip_mine, std::function<T(T, T)>([](T vs1, T vs2) -> T {
+        using WT = typename mpact::sim::riscv::WideType<T>::type;
+        WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
         if (std::is_signed<T>::value) {
-          int64_t result =
-              static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
           result = std::max(
-              static_cast<int64_t>(std::numeric_limits<T>::min()),
-              std::min(static_cast<int64_t>(std::numeric_limits<T>::max()),
-                       result));
+              static_cast<WT>(std::numeric_limits<T>::min()),
+              std::min(static_cast<WT>(std::numeric_limits<T>::max()), result));
           return result;
         } else {
-          uint64_t result =
-              static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
-          result = std::min(
-              static_cast<uint64_t>(std::numeric_limits<T>::max()), result);
+          result =
+              std::min(static_cast<WT>(std::numeric_limits<T>::max()), result);
           return result;
         }
       }));
@@ -984,16 +970,12 @@
 // Returns high half.
 template <typename T>
 T KelvinVMulhHelper(bool round, T vs1, T vs2) {
+  using WT = typename mpact::sim::riscv::WideType<T>::type;
   constexpr int n = sizeof(T) * 8;
-  if (std::is_signed<T>::value) {
-    int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
-    result += round ? 1ll << (n - 1) : 0;
-    return static_cast<uint64_t>(result) >> n;
-  } else {
-    uint64_t result = static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
-    result += round ? 1ull << (n - 1) : 0;
-    return result >> n;
-  }
+
+  WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
+  result += round ? static_cast<WT>(1ll << (n - 1)) : 0;
+  return result >> n;
 }
 
 template <typename T>
@@ -1014,11 +996,12 @@
 template <typename T>
 T KelvinVDmulhHelper(bool round, bool round_neg, T vs1, T vs2) {
   constexpr int n = sizeof(T) * 8;
-  int64_t result = static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
+  using WT = typename mpact::sim::riscv::WideType<T>::type;
+  WT result = static_cast<WT>(vs1) * static_cast<WT>(vs2);
   if (round) {
-    int64_t rnd = 0x40000000ll >> (32 - n);
+    WT rnd = static_cast<WT>(0x40000000ll >> (32 - n));
     if (result < 0 && round_neg) {
-      rnd = (-0x40000000ll) >> (32 - n);
+      rnd = static_cast<WT>((-0x40000000ll) >> (32 - n));
     }
     result += rnd;
   }
@@ -1046,8 +1029,9 @@
   KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>(
       inst, scalar, strip_mine,
       std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T {
-        return static_cast<int64_t>(vd) +
-               static_cast<int64_t>(vs1) * static_cast<int64_t>(vs2);
+        using WT = typename mpact::sim::riscv::WideType<T>::type;
+        return static_cast<WT>(vd) +
+               static_cast<WT>(vs1) * static_cast<WT>(vs2);
       }));
 }
 template void KelvinVMacc<int8_t>(bool, bool, Instruction *);
@@ -1060,8 +1044,9 @@
   KelvinBinaryVectorOp<false /* halftype */, false /* widen_dst */, T, T, T, T>(
       inst, scalar, strip_mine,
       std::function<T(T, T, T)>([](T vd, T vs1, T vs2) -> T {
-        return static_cast<int64_t>(vs1) +
-               static_cast<int64_t>(vd) * static_cast<int64_t>(vs2);
+        using WT = typename mpact::sim::riscv::WideType<T>::type;
+        return static_cast<WT>(vs1) +
+               static_cast<WT>(vd) * static_cast<WT>(vs2);
       }));
 }
 template void KelvinVMadd<int8_t>(bool, bool, Instruction *);