[otbn] Fix use of carry flag for borrow in otbnsim When carry is set, for subb and cmpb it acts as a borrow flag so A - B - c must be computed (where c is the carry flag). This introduces a seperate 'subtract_with_borrow' function to the model to deal with this. Documentation is also updated to match. Signed-off-by: Greg Chadwick <gac@lowrisc.org>
diff --git a/hw/ip/otbn/data/bignum-insns.yml b/hw/ip/otbn/data/bignum-insns.yml index 5d637b1..4f7ef65 100644 --- a/hw/ip/otbn/data/bignum-insns.yml +++ b/hw/ip/otbn/data/bignum-insns.yml
@@ -368,7 +368,7 @@ st = DecodeShiftType(shift_type) operation: | b_shifted = ShiftReg(b, st, sb) - (result, flags_out) = AddWithCarry(a, -b_shifted, "0") + (result, flags_out) = SubtractWithBorrow(a, b_shifted, 0) WDR[d] = result FLAGS[flag_group] = flags_out @@ -393,7 +393,7 @@ decode: *bn-sub-decode operation: | b_shifted = ShiftReg(b, st, sb) - (result, flags_out) = AddWithCarry(a, -b_shifted, ~FLAGS[flag_group].C) + (result, flags_out) = SubtractWithBorrow(a, b_shifted, FLAGS[flag_group].C) WDR[d] = result FLAGS[flag_group] = flags_out @@ -430,7 +430,7 @@ fg = DecodeFlagGroup(flag_group) i = ZeroExtend(imm, WLEN) operation: | - (result, flags_out) = AddWithCarry(a, -i, "0") + (result, flags_out) = SubtractWithBorrow(a, i, 0) WDR[d] = result FLAGS[flag_group] = flags_out @@ -457,7 +457,7 @@ a = UInt(wrs1) b = UInt(wrs2) operation: | - (result, ) = AddWithCarry(a, -b, "0") + (result, ) = SubtractWithBorrow(a, b, 0) if result < 0: result = MOD + result @@ -709,7 +709,7 @@ st = DecodeShiftType(shift_type) operation: | b_shifted = ShiftReg(b, st, sb) - (, flags_out) = AddWithCarry(a, -b_shifted, "0") + (, flags_out) = SubtractWithBorrow(a, b_shifted, 0) FLAGS[flag_group] = flags_out encoding: @@ -731,7 +731,7 @@ This instruction is identical to BN.SUBB, except that no result register is written. decode: *bn-cmp-decode operation: | - (, flags_out) = AddWithCarry(a, -b, ~FLAGS[flag_group].C) + (, flags_out) = SubtractWithBorrow(a, b, FLAGS[flag_group].C) FLAGS[flag_group] = flags_out encoding:
diff --git a/hw/ip/otbn/doc/_index.md b/hw/ip/otbn/doc/_index.md index c8cb298..ab0af87 100644 --- a/hw/ip/otbn/doc/_index.md +++ b/hw/ip/otbn/doc/_index.md
@@ -437,6 +437,17 @@ return (result[WLEN-1:0], flags_out) +def SubtractWithBorrow(a: Bits(WLEN), b: Bits(WLEN), borrow_in: Bits(1)) -> (Bits(WLEN), FlagGroup): + result: Bits[WLEN+1] = a - b - borrow_in + + flags_out = FlagGroup() + flags_out.C = result[WLEN] + flags_out.L = result[0] + flags_out.M = result[WLEN-1] + flags_out.Z = (result[WLEN-1:0] == 0) + + return (result[WLEN-1:0], flags_out) + def DecodeHalfWordSelect(hwsel: Bits(1)) -> HalfWord: if hwsel == 0: return HalfWord.LOWER
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py index 387de72..60ccb83 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/insn.py +++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -507,7 +507,7 @@ a = int(state.wreg[self.wrs1]) b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type, self.shift_bytes) - (result, flags) = state.add_with_carry(a, -b_shifted, 0) + (result, flags) = state.subtract_with_borrow(a, b_shifted, 0) state.wreg[self.wrd] = result state.flags[self.flag_group] = flags @@ -525,12 +525,14 @@ self.flag_group = op_vals['flag_group'] def execute(self, state: OTBNState) -> None: + assert (state.flags[self.flag_group].C == 0 or + state.flags[self.flag_group].C == 1) + a = int(state.wreg[self.wrs1]) b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type, self.shift_bytes) - (result, - flags) = state.add_with_carry(a, -b_shifted, - 1 - state.flags[self.flag_group].C) + flag_c = state.flags[self.flag_group].C + (result, flags) = state.subtract_with_borrow(a, b_shifted, flag_c) state.wreg[self.wrd] = result state.flags[self.flag_group] = flags @@ -548,7 +550,7 @@ def execute(self, state: OTBNState) -> None: a = int(state.wreg[self.wrs]) b = int(self.imm) - (result, flags) = state.add_with_carry(a, -b, 0) + (result, flags) = state.subtract_with_borrow(a, b, 0) state.wreg[self.wrd] = result state.flags[self.flag_group] = flags @@ -565,7 +567,7 @@ def execute(self, state: OTBNState) -> None: a = int(state.wreg[self.wrs1]) b = int(state.wreg[self.wrs2]) - result, _ = state.add_with_carry(a, -b, 0) + result, _ = state.subtract_with_borrow(a, b, 0) if result < 0: result += state.mod state.wreg[self.wrd] = result @@ -698,7 +700,7 @@ a = int(state.wreg[self.wrs1]) b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type, self.shift_bytes) - (_, flags) = state.add_with_carry(a, -b_shifted, 0) + (_, flags) = state.subtract_with_borrow(a, b_shifted, 0) state.flags[self.flag_group] = flags @@ -714,11 +716,13 @@ self.flag_group = op_vals['flag_group'] def execute(self, state: OTBNState) -> None: + assert (state.flags[self.flag_group].C == 0 or + state.flags[self.flag_group].C == 1) a = int(state.wreg[self.wrs1]) b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type, self.shift_bytes) - carry_flag = 1 - state.flags[self.flag_group].C - (_, flags) = state.add_with_carry(a, -b_shifted, carry_flag) + flag_c = state.flags[self.flag_group].C + (_, flags) = state.subtract_with_borrow(a, b_shifted, flag_c) state.flags[self.flag_group] = flags
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py index 86bdd3a..e93e82d 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/state.py +++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -407,6 +407,14 @@ return (carryless_result, FlagReg.mlz_for_result(C, carryless_result)) + @staticmethod + def subtract_with_borrow(a: int, b: int, borrow_in: int) -> Tuple[int, FlagReg]: + result = a - b - borrow_in + carryless_result = result & ((1 << 256) - 1) + C = bool((result >> 256) & 1) + + return (carryless_result, FlagReg.mlz_for_result(C, carryless_result)) + def update_mlz_flags(self, fg: int, result: int) -> None: '''Update M, L, Z flags for the given result''' self.flags[fg] = FlagReg.mlz_for_result(self.flags[fg].C, result)