[otbn] Allow multiple errors in a single cycle in ISS

This means we can't just raise an exception any more, which makes
things a little more fiddly. Instead, any component of the state that
might raise an error has an "errs" list, to which it appends an error
if it sees one. These lists get concatenated in a tree, in much the
same way as we gather up changes.

After executing an instruction, the code checks whether there were any
errors. If not, it commits pending changes. If there is an error, it
figures out the correct value of ERR_BITS and stops.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/otbnsim/sim/alert.py b/hw/ip/otbn/dv/otbnsim/sim/alert.py
index e30a834..a32ff48 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/alert.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/alert.py
@@ -17,11 +17,11 @@
 ERR_CODE_FATAL_REG = 1 << 7
 
 
-class Alert(Exception):
-    '''An exception raised to signal that the program did something wrong
+class Alert:
+    '''An object describing something the program did wrong
 
     This maps onto alerts in the implementation. The err_code value is the
-    value that should be written to the ERR_CODE external register.
+    value that should be OR'd into the ERR_BITS external register.
 
     '''
     # Subclasses should override this class field or the error_code method
@@ -33,7 +33,7 @@
 
 
 class BadAddrError(Alert):
-    '''Raised when loading or storing or setting PC with a bad address'''
+    '''Generated when loading or storing or setting PC with a bad address'''
 
     def __init__(self, operation: str, addr: int, what: str):
         assert operation in ['pc',
@@ -54,7 +54,7 @@
 
 
 class LoopError(Alert):
-    '''Raised when doing something wrong with a LOOP/LOOPI'''
+    '''Generated when doing something wrong with a LOOP/LOOPI'''
 
     err_code = ERR_CODE_LOOP
 
@@ -63,3 +63,15 @@
 
     def __str__(self) -> str:
         return 'Loop error: {}'.format(self.what)
+
+
+class IllegalInsnError(Alert):
+    '''Generated on a bad instruction'''
+    err_code = ERR_CODE_ILLEGAL_INSN
+
+    def __init__(self, word: int, msg: str):
+        self.word = word
+        self.msg = msg
+
+    def __str__(self) -> str:
+        return ('Illegal instruction {:#010x}: {}'.format(self.word, self.msg))
diff --git a/hw/ip/otbn/dv/otbnsim/sim/decode.py b/hw/ip/otbn/dv/otbnsim/sim/decode.py
index 60eabef..146ec46 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/decode.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/decode.py
@@ -7,7 +7,7 @@
 import struct
 from typing import List, Optional, Tuple, Type
 
-from .alert import Alert, ERR_CODE_ILLEGAL_INSN
+from .alert import IllegalInsnError
 from .isa import DecodeError, OTBNInsn
 from .insn import INSN_CLASSES
 from .state import OTBNState
@@ -18,18 +18,6 @@
 _MaskTuple = Tuple[int, int, Type[OTBNInsn]]
 
 
-class IllegalInsnError(Alert):
-    '''Raised on a bad instruction'''
-    err_code = ERR_CODE_ILLEGAL_INSN
-
-    def __init__(self, word: int, msg: str):
-        self.word = word
-        self.msg = msg
-
-    def __str__(self) -> str:
-        return ('Illegal instruction {:#010x}: {}'.format(self.word, self.msg))
-
-
 class IllegalInsn(OTBNInsn):
     '''A catch-all subclass of Instruction for bad data
 
@@ -51,7 +39,7 @@
         self._disasm = (pc, '?? 0x{:08x}'.format(raw))
 
     def execute(self, state: OTBNState) -> None:
-        raise IllegalInsnError(self.raw, self.msg)
+        state.on_error(IllegalInsnError(self.raw, self.msg))
 
 
 MASK_TUPLES = None  # type: Optional[List[_MaskTuple]]
diff --git a/hw/ip/otbn/dv/otbnsim/sim/dmem.py b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
index b0a7d9d..f2becd2 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/dmem.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
@@ -61,6 +61,8 @@
         self.data = [uninit] * num_words
         self.trace = []  # type: List[TraceDmemStore]
 
+        self._errs = []  # type: List[BadAddrError]
+
     def _get_u32s(self, idx: int) -> List[int]:
         '''Return the value at idx as 8 uint32's
 
@@ -143,14 +145,16 @@
         assert addr >= 0
 
         if addr & 31:
-            raise BadAddrError('wide load', addr,
-                               'address is not 32-byte aligned')
+            self._errs.append(BadAddrError('wide load', addr,
+                                           'address is not 32-byte aligned'))
+            return 0
 
         word_addr = addr // 32
 
         if word_addr >= len(self.data):
-            raise BadAddrError('wide load', addr,
-                               'address is above the top of dmem')
+            self._errs.append(BadAddrError('wide load', addr,
+                                           'address is above the top of dmem'))
+            return 0
 
         return self.data[word_addr]
 
@@ -160,13 +164,15 @@
         assert 0 <= value < (1 << 256)
 
         if addr & 31:
-            raise BadAddrError('wide store', addr,
-                               'address is not 32-byte aligned')
+            self._errs.append(BadAddrError('wide store', addr,
+                                           'address is not 32-byte aligned'))
+            return
 
         word_addr = addr // 32
         if word_addr >= len(self.data):
-            raise BadAddrError('wide store', addr,
-                               'address is above the top of dmem')
+            self._errs.append(BadAddrError('wide store', addr,
+                                           'address is above the top of dmem'))
+            return
 
         self.trace.append(TraceDmemStore(addr, value, True))
 
@@ -179,11 +185,13 @@
         '''
         assert addr >= 0
         if addr & 3:
-            raise BadAddrError('narrow load', addr,
-                               'address is not 4-byte aligned')
+            self._errs.append(BadAddrError('narrow load', addr,
+                                           'address is not 4-byte aligned'))
+            return 0
         if (addr + 3) // 32 >= len(self.data):
-            raise BadAddrError('narrow load', addr,
-                               'address is above the top of dmem')
+            self._errs.append(BadAddrError('narrow load', addr,
+                                           'address is above the top of dmem'))
+            return 0
 
         idx32 = addr // 4
         idxW = idx32 // 8
@@ -201,14 +209,19 @@
         assert 0 <= value <= (1 << 32) - 1
 
         if addr & 3:
-            raise BadAddrError('narrow load', addr,
-                               'address is not 4-byte aligned')
+            self._errs.append(BadAddrError('narrow load', addr,
+                                           'address is not 4-byte aligned'))
+            return
         if (addr + 3) // 32 >= len(self.data):
-            raise BadAddrError('narrow load', addr,
-                               'address is above the top of dmem')
+            self._errs.append(BadAddrError('narrow load', addr,
+                                           'address is above the top of dmem'))
+            return
 
         self.trace.append(TraceDmemStore(addr, value, False))
 
+    def errors(self) -> List[BadAddrError]:
+        return self._errs
+
     def changes(self) -> Sequence[Trace]:
         return self.trace
 
@@ -234,9 +247,11 @@
         self._set_u32s(idxW, u32s)
 
     def commit(self) -> None:
+        assert not self._errs
         for item in self.trace:
             self._commit_store(item)
         self.trace = []
 
     def abort(self) -> None:
         self.trace = []
+        self._errs = []
diff --git a/hw/ip/otbn/dv/otbnsim/sim/gpr.py b/hw/ip/otbn/dv/otbnsim/sim/gpr.py
index b08243c..e70cebb 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/gpr.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/gpr.py
@@ -27,10 +27,11 @@
     # The depth of the x1 call stack
     stack_depth = 8
 
-    def __init__(self, parent: RegFile):
+    def __init__(self, parent: 'GPRs'):
         super().__init__(parent, 1, 32, 0)
         self.stack = []  # type: List[int]
         self.saw_read = False
+        self.gpr_parent = parent
 
     # We overload read_unsigned here, to handle the read-sensitive behaviour
     # without needing the base class to deal with it.
@@ -41,13 +42,19 @@
             return self.stack[-1] if self.stack else 0xcafef00d
 
         if not self.stack:
-            raise CallStackError(False)
+            self.gpr_parent.errs.append(CallStackError(False))
+            return 0
 
         # Mark that we've read something (so that we pop from the stack as part
         # of commit) and return the top of the stack.
         self.saw_read = True
         return self.stack[-1]
 
+    def post_insn(self) -> None:
+        if self._next_uval is not None:
+            if not self.saw_read and len(self.stack) == 8:
+                self.gpr_parent.errs.append(CallStackError(True))
+
     def commit(self) -> None:
         if self.saw_read:
             assert self.stack
@@ -55,8 +62,9 @@
             self.saw_read = False
 
         if self._next_uval is not None:
-            if len(self.stack) == 8:
-                raise CallStackError(True)
+            # We should already have checked that we won't overflow the call
+            # stack in post_insn().
+            assert len(self.stack) <= 8
             self.stack.append(self._next_uval)
 
         super().commit()
@@ -72,6 +80,7 @@
     def __init__(self) -> None:
         super().__init__('x', 32, 32)
         self._x1 = CallStackReg(self)
+        self.errs = []  # type: List[CallStackError]
 
     def get_reg(self, idx: int) -> Reg:
         if idx == 0:
@@ -89,10 +98,18 @@
         '''Get the call stack, bottom-first.'''
         return self._x1.stack
 
+    def post_insn(self) -> None:
+        return self._x1.post_insn()
+
+    def errors(self) -> List[CallStackError]:
+        return self.errs
+
     def commit(self) -> None:
-        self._x1.commit()
         super().commit()
+        assert not self.errs
+        self._x1.commit()
 
     def abort(self) -> None:
-        self._x1.abort()
         super().abort()
+        self._x1.abort()
+        self.errs = []
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 594abee..2f2cc61 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -4,7 +4,7 @@
 
 from typing import Dict
 
-from .alert import LoopError
+from .alert import ERR_CODE_NO_ERROR, LoopError
 from .flags import FlagReg
 from .isa import (DecodeError, OTBNInsn, RV32RegReg, RV32RegImm, RV32ImmShift,
                   insn_for_mnemonic, logical_byte_shift)
@@ -315,7 +315,7 @@
 
     def execute(self, state: OTBNState) -> None:
         # Set INTR_STATE.done and STATUS, reflecting the fact we've stopped.
-        state.stop(None)
+        state._stop(ERR_CODE_NO_ERROR)
 
 
 class LOOP(OTBNInsn):
@@ -330,8 +330,9 @@
     def execute(self, state: OTBNState) -> None:
         num_iters = state.gprs.get_reg(self.grs).read_unsigned()
         if num_iters == 0:
-            raise LoopError('loop count in x{} was zero'
-                            .format(self.grs))
+            state.on_error(LoopError('loop count in x{} was zero'
+                                     .format(self.grs)))
+            return
 
         state.loop_start(num_iters, self.bodysize)
 
diff --git a/hw/ip/otbn/dv/otbnsim/sim/loop.py b/hw/ip/otbn/dv/otbnsim/sim/loop.py
index f14102f..4e8e564 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/loop.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/loop.py
@@ -59,6 +59,7 @@
     def __init__(self) -> None:
         self.stack = []  # type: List[LoopLevel]
         self.trace = []  # type: List[Trace]
+        self.errs = []  # type: List[LoopError]
 
     def start_loop(self,
                    next_addr: int,
@@ -77,7 +78,7 @@
         assert 0 < loop_count
 
         if len(self.stack) == LoopStack.stack_depth:
-            raise LoopError('loop stack overflow')
+            self.errs.append(LoopError('loop stack overflow'))
 
         self.trace.append(TraceLoopStart(loop_count, insn_count))
         self.stack.append(LoopLevel(next_addr, insn_count, loop_count - 1))
@@ -91,7 +92,8 @@
                 # Make sure that it isn't a jump, branch or another loop
                 # instruction.
                 if insn_affects_control:
-                    raise LoopError('control instruction at end of loop')
+                    self.errs.append(LoopError('control instruction '
+                                               'at end of loop'))
 
     def step(self, next_pc: int) -> Optional[int]:
         '''Update loop stack. If we should loop, return new PC'''
@@ -116,11 +118,16 @@
 
         return None
 
+    def errors(self) -> List[LoopError]:
+        return self.errs
+
     def changes(self) -> List[Trace]:
         return self.trace
 
     def commit(self) -> None:
+        assert not self.errs
         self.trace = []
 
     def abort(self) -> None:
         self.trace = []
+        self.errs = []
diff --git a/hw/ip/otbn/dv/otbnsim/sim/sim.py b/hw/ip/otbn/dv/otbnsim/sim/sim.py
index cc58681..05df7f1 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/sim.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/sim.py
@@ -48,46 +48,41 @@
         was_stalled = self.state.stalled
         pc_before = self.state.pc
 
-        try:
-            if was_stalled:
-                insn = None
-                changes = []
-            else:
-                word_pc = int(self.state.pc) >> 2
-                if word_pc >= len(self.program):
-                    raise RuntimeError('Trying to execute instruction at address '
-                                       '{:#x}, but the program is only {:#x} '
-                                       'bytes ({} instructions) long. Since there '
-                                       'are no architectural contents of the '
-                                       'memory here, we have to stop.'
-                                       .format(int(self.state.pc),
-                                               4 * len(self.program),
-                                               len(self.program)))
-                insn = self.program[word_pc]
-
-                if insn.insn.cycles > 1:
-                    self.state.add_stall_cycles(insn.insn.cycles - 1)
-
-                self.state.pre_insn(insn.affects_control)
-                insn.execute(self.state)
-                self.state.post_insn()
-
-                changes = self.state.changes()
-
+        if was_stalled:
+            insn = None
+            changes = []
             self.state.commit()
+        else:
+            word_pc = int(self.state.pc) >> 2
+            if word_pc >= len(self.program):
+                raise RuntimeError('Trying to execute instruction at address '
+                                   '{:#x}, but the program is only {:#x} '
+                                   'bytes ({} instructions) long. Since there '
+                                   'are no architectural contents of the '
+                                   'memory here, we have to stop.'
+                                   .format(int(self.state.pc),
+                                           4 * len(self.program),
+                                           len(self.program)))
+            insn = self.program[word_pc]
 
-        except Alert as alert:
-            # Roll back any pending state changes: we ensure that a faulting
-            # instruction doesn't actually do anything.
-            self.state.abort()
+            if insn.insn.cycles > 1:
+                self.state.add_stall_cycles(insn.insn.cycles - 1)
 
-            # We've rolled back any changes, but need to actually generate an
-            # "alert". To do that, we tell the state to set an appropriate
-            # error code in the external ERR_CODE register and clear the busy
-            # flag. These register changes get reflected in the returned list
-            # of trace items, that we propagate up.
-            self.state.stop(alert.error_code())
-            changes = self.state.changes()
+            self.state.pre_insn(insn.affects_control)
+            insn.execute(self.state)
+            self.state.post_insn()
+
+            errors = self.state.errors()
+            if errors:
+                # Roll back any pending state changes, ensuring that a faulting
+                # instruction doesn't actually do anything. Also generate a
+                # change that sets an appropriate error bits in the external
+                # ERR_CODE register and clears the busy flag.
+                self.state.die(errors)
+                changes = self.state.changes()
+            else:
+                changes = self.state.changes()
+                self.state.commit()
 
         if verbose:
             disasm = ('(stall)' if insn is None
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py
index 690b7a8..077d4ce 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/state.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -6,7 +6,7 @@
 
 from shared.mem_layout import get_memory_layout
 
-from .alert import BadAddrError
+from .alert import Alert, BadAddrError
 from .csr import CSRFile
 from .dmem import Dmem
 from .ext_regs import OTBNExtRegs
@@ -51,6 +51,8 @@
         self.ext_regs = OTBNExtRegs()
         self.running = False
 
+        self.errs = []  # type: List[Alert]
+
     def add_stall_cycles(self, num_cycles: int) -> None:
         '''Add stall cycles before the next insn completes'''
         assert num_cycles >= 0
@@ -65,6 +67,14 @@
         if back_pc is not None:
             self.pc_next = back_pc
 
+    def errors(self) -> List[Alert]:
+        c = []  # type: List[Alert]
+        c += self.errs
+        c += self.gprs.errors()
+        c += self.dmem.errors()
+        c += self.loop_stack.errors()
+        return c
+
     def changes(self) -> List[Trace]:
         c = []  # type: List[Trace]
         c += self.gprs.changes()
@@ -113,7 +123,7 @@
         self.csrs.flags.commit()
         self.wdrs.commit()
 
-    def abort(self) -> None:
+    def _abort(self) -> None:
         '''Abort any pending state changes'''
         # This should only be called when an instruction's execution goes
         # wrong. If self._stalls is positive, the bad execution caused those
@@ -136,6 +146,29 @@
         self._start_stall = True
         self.stalled = True
 
+    def _stop(self, err_bits: int) -> None:
+        '''Set flags to stop the processor.
+
+        err_bits is the value written to the ERR_BITS register.
+
+        '''
+        # INTR_STATE is the interrupt state register. Bit 0 (which is being
+        # set) is the 'done' flag.
+        self.ext_regs.set_bits('INTR_STATE', 1 << 0)
+        # STATUS is a status register. Bit 0 (being cleared) is the 'busy' flag
+        self.ext_regs.clear_bits('STATUS', 1 << 0)
+
+        self.ext_regs.write('ERR_BITS', err_bits, True)
+        self.running = False
+
+    def die(self, alerts: List[Alert]) -> None:
+        err_bits = 0
+        for alert in alerts:
+            err_bits |= alert.error_code()
+
+        self._abort()
+        self._stop(err_bits)
+
     def get_quarter_word_unsigned(self, idx: int, qwsel: int) -> int:
         '''Select a 64-bit quarter of a wide register.
 
@@ -221,7 +254,7 @@
     def check_jump_dest(self) -> None:
         '''Check whether self.pc_next is a valid jump/branch target
 
-        If not, raises a BadAddrError.
+        If not, generates a BadAddrError.
 
         '''
         if self.pc_next is None:
@@ -233,18 +266,19 @@
 
         # Check the new PC is word-aligned
         if self.pc_next & 3:
-            raise BadAddrError('pc', self.pc_next,
-                               'address is not 4-byte aligned')
+            self.errs.append(BadAddrError('pc', self.pc_next,
+                                          'address is not 4-byte aligned'))
 
         # Check the new PC lies in instruction memory
         if self.pc_next >= self.imem_size:
-            raise BadAddrError('pc', self.pc_next,
-                               'address lies above the top of imem')
+            self.errs.append(BadAddrError('pc', self.pc_next,
+                                          'address lies above the top of imem'))
 
     def post_insn(self) -> None:
         '''Update state after running an instruction but before commit'''
         self.check_jump_dest()
         self.loop_step()
+        self.gprs.post_insn()
 
     def read_csr(self, idx: int) -> int:
         '''Read the CSR with index idx as an unsigned 32-bit number'''
@@ -258,20 +292,6 @@
         '''Return the current call stack, bottom-first'''
         return self.gprs.peek_call_stack()
 
-    def stop(self, err_code: Optional[int]) -> None:
-        '''Set flags to stop the processor.
-
-        If err_code is not None, it is the value to write to the ERR_BITS
-        register.
-
-        '''
-        # INTR_STATE is the interrupt state register. Bit 0 (which is being
-        # set) is the 'done' flag.
-        self.ext_regs.set_bits('INTR_STATE', 1 << 0)
-        # STATUS is a status register. Bit 0 (being cleared) is the 'busy' flag
-        self.ext_regs.clear_bits('STATUS', 1 << 0)
-
-        if err_code is not None:
-            self.ext_regs.write('ERR_BITS', err_code, True)
-
-        self.running = False
+    def on_error(self, error: Alert) -> None:
+        '''Add a pending error that will be reported at the end of the cycle'''
+        self.errs.append(error)