[otbn] Fix use of carry flag for borrow in otbnsim

For subb and cmpb the carry flag acts as a borrow flag, so A - B - c
must be computed (where c is the carry flag). This introduces a
separate 'subtract_with_borrow' function to the model to deal with this.
Documentation is also updated to match.

Signed-off-by: Greg Chadwick <gac@lowrisc.org>
diff --git a/hw/ip/otbn/data/bignum-insns.yml b/hw/ip/otbn/data/bignum-insns.yml
index 5d637b1..4f7ef65 100644
--- a/hw/ip/otbn/data/bignum-insns.yml
+++ b/hw/ip/otbn/data/bignum-insns.yml
@@ -368,7 +368,7 @@
     st = DecodeShiftType(shift_type)
   operation: |
     b_shifted = ShiftReg(b, st, sb)
-    (result, flags_out) = AddWithCarry(a, -b_shifted, "0")
+    (result, flags_out) = SubtractWithBorrow(a, b_shifted, 0)
 
     WDR[d] = result
     FLAGS[flag_group] = flags_out
@@ -393,7 +393,7 @@
   decode: *bn-sub-decode
   operation: |
     b_shifted = ShiftReg(b, st, sb)
-    (result, flags_out) = AddWithCarry(a, -b_shifted, ~FLAGS[flag_group].C)
+    (result, flags_out) = SubtractWithBorrow(a, b_shifted, FLAGS[flag_group].C)
 
     WDR[d] = result
     FLAGS[flag_group] = flags_out
@@ -430,7 +430,7 @@
     fg = DecodeFlagGroup(flag_group)
     i = ZeroExtend(imm, WLEN)
   operation: |
-    (result, flags_out) = AddWithCarry(a, -i, "0")
+    (result, flags_out) = SubtractWithBorrow(a, i, 0)
 
     WDR[d] = result
     FLAGS[flag_group] = flags_out
@@ -457,7 +457,7 @@
     a = UInt(wrs1)
     b = UInt(wrs2)
   operation: |
-    (result, ) = AddWithCarry(a, -b, "0")
+    (result, ) = SubtractWithBorrow(a, b, 0)
 
     if result < 0:
       result = MOD + result
@@ -709,7 +709,7 @@
     st = DecodeShiftType(shift_type)
   operation: |
     b_shifted = ShiftReg(b, st, sb)
-    (, flags_out) = AddWithCarry(a, -b_shifted, "0")
+    (, flags_out) = SubtractWithBorrow(a, b_shifted, 0)
 
     FLAGS[flag_group] = flags_out
   encoding:
@@ -731,7 +731,7 @@
     This instruction is identical to BN.SUBB, except that no result register is written.
   decode: *bn-cmp-decode
   operation: |
-    (, flags_out) = AddWithCarry(a, -b, ~FLAGS[flag_group].C)
+    (, flags_out) = SubtractWithBorrow(a, b, FLAGS[flag_group].C)
 
     FLAGS[flag_group] = flags_out
   encoding:
diff --git a/hw/ip/otbn/doc/_index.md b/hw/ip/otbn/doc/_index.md
index c8cb298..ab0af87 100644
--- a/hw/ip/otbn/doc/_index.md
+++ b/hw/ip/otbn/doc/_index.md
@@ -437,6 +437,17 @@
 
   return (result[WLEN-1:0], flags_out)
 
+def SubtractWithBorrow(a: Bits(WLEN), b: Bits(WLEN), borrow_in: Bits(1)) -> (Bits(WLEN), FlagGroup):
+  result: Bits[WLEN+1] = a - b - borrow_in
+
+  flags_out = FlagGroup()
+  flags_out.C = result[WLEN]
+  flags_out.L = result[0]
+  flags_out.M = result[WLEN-1]
+  flags_out.Z = (result[WLEN-1:0] == 0)
+
+  return (result[WLEN-1:0], flags_out)
+
 def DecodeHalfWordSelect(hwsel: Bits(1)) -> HalfWord:
   if hwsel == 0:
     return HalfWord.LOWER
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 387de72..60ccb83 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -507,7 +507,7 @@
         a = int(state.wreg[self.wrs1])
         b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
                              self.shift_bytes)
-        (result, flags) = state.add_with_carry(a, -b_shifted, 0)
+        (result, flags) = state.subtract_with_borrow(a, b_shifted, 0)
         state.wreg[self.wrd] = result
         state.flags[self.flag_group] = flags
 
@@ -525,12 +525,14 @@
         self.flag_group = op_vals['flag_group']
 
     def execute(self, state: OTBNState) -> None:
+        assert (state.flags[self.flag_group].C == 0 or
+                state.flags[self.flag_group].C == 1)
+
         a = int(state.wreg[self.wrs1])
         b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
                              self.shift_bytes)
-        (result,
-         flags) = state.add_with_carry(a, -b_shifted,
-                                       1 - state.flags[self.flag_group].C)
+        flag_c = state.flags[self.flag_group].C
+        (result, flags) = state.subtract_with_borrow(a, b_shifted, flag_c)
         state.wreg[self.wrd] = result
         state.flags[self.flag_group] = flags
 
@@ -548,7 +550,7 @@
     def execute(self, state: OTBNState) -> None:
         a = int(state.wreg[self.wrs])
         b = int(self.imm)
-        (result, flags) = state.add_with_carry(a, -b, 0)
+        (result, flags) = state.subtract_with_borrow(a, b, 0)
         state.wreg[self.wrd] = result
         state.flags[self.flag_group] = flags
 
@@ -565,7 +567,7 @@
     def execute(self, state: OTBNState) -> None:
         a = int(state.wreg[self.wrs1])
         b = int(state.wreg[self.wrs2])
-        result, _ = state.add_with_carry(a, -b, 0)
+        result, _ = state.subtract_with_borrow(a, b, 0)
         if result < 0:
             result += state.mod
         state.wreg[self.wrd] = result
@@ -698,7 +700,7 @@
         a = int(state.wreg[self.wrs1])
         b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
                              self.shift_bytes)
-        (_, flags) = state.add_with_carry(a, -b_shifted, 0)
+        (_, flags) = state.subtract_with_borrow(a, b_shifted, 0)
         state.flags[self.flag_group] = flags
 
 
@@ -714,11 +716,13 @@
         self.flag_group = op_vals['flag_group']
 
     def execute(self, state: OTBNState) -> None:
+        assert (state.flags[self.flag_group].C == 0 or
+                state.flags[self.flag_group].C == 1)
         a = int(state.wreg[self.wrs1])
         b_shifted = ShiftReg(int(state.wreg[self.wrs2]), self.shift_type,
                              self.shift_bytes)
-        carry_flag = 1 - state.flags[self.flag_group].C
-        (_, flags) = state.add_with_carry(a, -b_shifted, carry_flag)
+        flag_c = state.flags[self.flag_group].C
+        (_, flags) = state.subtract_with_borrow(a, b_shifted, flag_c)
         state.flags[self.flag_group] = flags
 
 
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py
index 86bdd3a..e93e82d 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/state.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -407,6 +407,14 @@
 
         return (carryless_result, FlagReg.mlz_for_result(C, carryless_result))
 
+    @staticmethod
+    def subtract_with_borrow(a: int, b: int, borrow_in: int) -> Tuple[int, FlagReg]:
+        result = a - b - borrow_in
+        carryless_result = result & ((1 << 256) - 1)
+        C = bool((result >> 256) & 1)
+
+        return (carryless_result, FlagReg.mlz_for_result(C, carryless_result))
+
     def update_mlz_flags(self, fg: int, result: int) -> None:
         '''Update M, L, Z flags for the given result'''
         self.flags[fg] = FlagReg.mlz_for_result(self.flags[fg].C, result)