[otbn] Fix use of carry flag for borrow in otbnsim
When carry is set, for subb and cmpb it acts as a borrow flag so A - B -
c must be computed (where c is the carry flag). This introduces a
seperate 'subtract_with_borrow' function to the model to deal with this.
Documentation is also updated to match.
Signed-off-by: Greg Chadwick <gac@lowrisc.org>
diff --git a/hw/ip/otbn/data/bignum-insns.yml b/hw/ip/otbn/data/bignum-insns.yml
index 5d637b1..4f7ef65 100644
--- a/hw/ip/otbn/data/bignum-insns.yml
+++ b/hw/ip/otbn/data/bignum-insns.yml
@@ -368,7 +368,7 @@
st = DecodeShiftType(shift_type)
operation: |
b_shifted = ShiftReg(b, st, sb)
- (result, flags_out) = AddWithCarry(a, -b_shifted, "0")
+ (result, flags_out) = SubtractWithBorrow(a, b_shifted, 0)
WDR[d] = result
FLAGS[flag_group] = flags_out
@@ -393,7 +393,7 @@
decode: *bn-sub-decode
operation: |
b_shifted = ShiftReg(b, st, sb)
- (result, flags_out) = AddWithCarry(a, -b_shifted, ~FLAGS[flag_group].C)
+ (result, flags_out) = SubtractWithBorrow(a, b_shifted, FLAGS[flag_group].C)
WDR[d] = result
FLAGS[flag_group] = flags_out
@@ -430,7 +430,7 @@
fg = DecodeFlagGroup(flag_group)
i = ZeroExtend(imm, WLEN)
operation: |
- (result, flags_out) = AddWithCarry(a, -i, "0")
+ (result, flags_out) = SubtractWithBorrow(a, i, 0)
WDR[d] = result
FLAGS[flag_group] = flags_out
@@ -457,7 +457,7 @@
a = UInt(wrs1)
b = UInt(wrs2)
operation: |
- (result, ) = AddWithCarry(a, -b, "0")
+ (result, ) = SubtractWithBorrow(a, b, 0)
if result < 0:
result = MOD + result
@@ -709,7 +709,7 @@
st = DecodeShiftType(shift_type)
operation: |
b_shifted = ShiftReg(b, st, sb)
- (, flags_out) = AddWithCarry(a, -b_shifted, "0")
+ (, flags_out) = SubtractWithBorrow(a, b_shifted, 0)
FLAGS[flag_group] = flags_out
encoding:
@@ -731,7 +731,7 @@
This instruction is identical to BN.SUBB, except that no result register is written.
decode: *bn-cmp-decode
operation: |
- (, flags_out) = AddWithCarry(a, -b, ~FLAGS[flag_group].C)
+ (, flags_out) = SubtractWithBorrow(a, b, FLAGS[flag_group].C)
FLAGS[flag_group] = flags_out
encoding:
diff --git a/hw/ip/otbn/doc/_index.md b/hw/ip/otbn/doc/_index.md
index c8cb298..ab0af87 100644
--- a/hw/ip/otbn/doc/_index.md
+++ b/hw/ip/otbn/doc/_index.md
@@ -437,6 +437,17 @@
return (result[WLEN-1:0], flags_out)
+def SubtractWithBorrow(a: Bits(WLEN), b: Bits(WLEN), borrow_in: Bits(1)) -> (Bits(WLEN), FlagGroup):
+ result: Bits[WLEN+1] = a - b - borrow_in
+
+ flags_out = FlagGroup()
+ flags_out.C = result[WLEN]
+ flags_out.L = result[0]
+ flags_out.M = result[WLEN-1]
+ flags_out.Z = (result[WLEN-1:0] == 0)
+
+ return (result[WLEN-1:0], flags_out)
+
def DecodeHalfWordSelect(hwsel: Bits(1)) -> HalfWord:
if hwsel == 0:
return HalfWord.LOWER
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 387de72..60ccb83 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -507,7 +507,7 @@
a = int(state.wreg[self.wrs1])
b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
self.shift_bytes)
- (result, flags) = state.add_with_carry(a, -b_shifted, 0)
+ (result, flags) = state.subtract_with_borrow(a, b_shifted, 0)
state.wreg[self.wrd] = result
state.flags[self.flag_group] = flags
@@ -525,12 +525,14 @@
self.flag_group = op_vals['flag_group']
def execute(self, state: OTBNState) -> None:
+ assert (state.flags[self.flag_group].C == 0 or
+ state.flags[self.flag_group].C == 1)
+
a = int(state.wreg[self.wrs1])
b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
self.shift_bytes)
- (result,
- flags) = state.add_with_carry(a, -b_shifted,
- 1 - state.flags[self.flag_group].C)
+ flag_c = state.flags[self.flag_group].C
+ (result, flags) = state.subtract_with_borrow(a, b_shifted, flag_c)
state.wreg[self.wrd] = result
state.flags[self.flag_group] = flags
@@ -548,7 +550,7 @@
def execute(self, state: OTBNState) -> None:
a = int(state.wreg[self.wrs])
b = int(self.imm)
- (result, flags) = state.add_with_carry(a, -b, 0)
+ (result, flags) = state.subtract_with_borrow(a, b, 0)
state.wreg[self.wrd] = result
state.flags[self.flag_group] = flags
@@ -565,7 +567,7 @@
def execute(self, state: OTBNState) -> None:
a = int(state.wreg[self.wrs1])
b = int(state.wreg[self.wrs2])
- result, _ = state.add_with_carry(a, -b, 0)
+ result, _ = state.subtract_with_borrow(a, b, 0)
if result < 0:
result += state.mod
state.wreg[self.wrd] = result
@@ -698,7 +700,7 @@
a = int(state.wreg[self.wrs1])
b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
self.shift_bytes)
- (_, flags) = state.add_with_carry(a, -b_shifted, 0)
+ (_, flags) = state.subtract_with_borrow(a, b_shifted, 0)
state.flags[self.flag_group] = flags
@@ -714,11 +716,13 @@
self.flag_group = op_vals['flag_group']
def execute(self, state: OTBNState) -> None:
+ assert (state.flags[self.flag_group].C == 0 or
+ state.flags[self.flag_group].C == 1)
a = int(state.wreg[self.wrs1])
b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
self.shift_bytes)
- carry_flag = 1 - state.flags[self.flag_group].C
- (_, flags) = state.add_with_carry(a, -b_shifted, carry_flag)
+ flag_c = state.flags[self.flag_group].C
+ (_, flags) = state.subtract_with_borrow(a, b_shifted, flag_c)
state.flags[self.flag_group] = flags
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py
index 86bdd3a..e93e82d 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/state.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -407,6 +407,14 @@
return (carryless_result, FlagReg.mlz_for_result(C, carryless_result))
+ @staticmethod
+ def subtract_with_borrow(a: int, b: int, borrow_in: int) -> Tuple[int, FlagReg]:
+ result = a - b - borrow_in
+ carryless_result = result & ((1 << 256) - 1)
+ C = bool((result >> 256) & 1)
+
+ return (carryless_result, FlagReg.mlz_for_result(C, carryless_result))
+
def update_mlz_flags(self, fg: int, result: int) -> None:
'''Update M, L, Z flags for the given result'''
self.flags[fg] = FlagReg.mlz_for_result(self.flags[fg].C, result)