[otbn] Fix rounding in bn.addm in ISS
The add_with_carry function truncates its result to 256 bits, which is
what you normally want, but isn't what you want for BN.ADDM, where you
might overflow but then fix it by subtracting MOD.
I've updated the documentation to match. This now matches the RTL.
Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/data/bignum-insns.yml b/hw/ip/otbn/data/bignum-insns.yml
index 0924e7e..a61ffe1 100644
--- a/hw/ip/otbn/data/bignum-insns.yml
+++ b/hw/ip/otbn/data/bignum-insns.yml
@@ -134,22 +134,27 @@
synopsis: Pseudo-Modulo Add
operands: [wrd, wrs1, wrs2]
doc: |
- Adds two WDR values, subtracts the value of the MOD WSR once if
- the result is equal or larger than MOD, and writes the result to
- the destination WDR. This operation is a modulo addition if the
- sum of the two input registers is smaller than twice the value
- of the MOD WSR. Flags are not used or saved.
+ Add two WDR values, modulo the MOD WSR.
+
+ The values in `<wrs1>` and `<wrs2>` are summed to get an intermediate result (of width `WLEN + 1`).
+ If this result is greater than MOD then MOD is subtracted from it.
+ The result is then truncated to 256 bits and stored in `<wrd>`.
+
+ This operation correctly implements addition modulo MOD, providing that the intermediate result is less than `2 * MOD`.
+ The intermediate result is small enough if both inputs are less than `MOD`.
+
+ Flags are not used or saved.
decode: |
d = UInt(wrd)
a = UInt(wrs1)
b = UInt(wrs2)
operation: |
- (result, ) = AddWithCarry(a, b, "0")
+ result = a + b
if result >= MOD:
result = result - MOD
- WDR[d] = result
+ WDR[d] = result & ((1 << 256) - 1)
encoding:
scheme: bnam
mapping:
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 335880a..4f048f1 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -421,12 +421,13 @@
def execute(self, state: OTBNState) -> None:
a = state.wdrs.get_reg(self.wrs1).read_unsigned()
b = state.wdrs.get_reg(self.wrs2).read_unsigned()
+ result = a + b
- (result, _) = state.add_with_carry(a, b, 0)
mod_val = state.wsrs.MOD.read_unsigned()
if result >= mod_val:
result -= mod_val
+ result = result & ((1 << 256) - 1)
state.wdrs.get_reg(self.wrd).write_unsigned(result)