[otbn] Make stalling more explicit in instruction definitions

This is prompted by the fact that we couldn't properly model things
like faulting calls to BN.LID. We need to do some of the instruction
body on the first cycle (to spot that an address is bogus). However,
we can't run everything on the first cycle. For example, accesses of
the RND WSR/CSR can't do their work until the data is available, and
this data isn't supplied by the Python simulator.

Solve the problem by splitting the execution over multiple calls to
the execute() method, using Python's "yield" statement to cause the
stall.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/doc/isa.md b/hw/ip/otbn/doc/isa.md
index 0c267e2..3d6dc8a 100644
--- a/hw/ip/otbn/doc/isa.md
+++ b/hw/ip/otbn/doc/isa.md
@@ -61,6 +61,9 @@
 Memory stores are represented as `DMEM.store_u32(addr, value)` and `DMEM.store_u256(addr, value)`.
 In all cases, memory values are interpreted as unsigned integers and, as for register accesses, the instruction descriptions are written to ensure that any value stored to memory is representable.
 
+Some instructions can stall for one or more cycles (those instructions that access memory, CSRs or WSRs).
+To represent this precisely in both the pseudo-code and the simulator reference model, such instructions execute a `yield` statement to stall the processor for a cycle.
+
 There are a few other helper functions, defined here to avoid having to inline their bodies into each instruction.
 ```python3
 def from_2s_complement(n: int) -> int:
diff --git a/hw/ip/otbn/dv/otbnsim/sim/decode.py b/hw/ip/otbn/dv/otbnsim/sim/decode.py
index e956d66..3bcafb2 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/decode.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/decode.py
@@ -5,7 +5,7 @@
 '''Code to load instruction words into a simulator'''
 
 import struct
-from typing import List
+from typing import List, Optional, Iterator
 
 from .err_bits import ILLEGAL_INSN
 from .isa import INSNS_FILE, DecodeError, OTBNInsn
@@ -35,8 +35,9 @@
         # disassembling the underlying DummyInsn.
         self._disasm = (pc, '?? 0x{:08x}'.format(raw))
 
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
         state.stop_at_end_of_cycle(ILLEGAL_INSN)
+        return None
 
 
 def _decode_word(pc: int, word: int) -> OTBNInsn:
diff --git a/hw/ip/otbn/dv/otbnsim/sim/dmem.py b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
index 44f47e7..f4fd374 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/dmem.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
@@ -60,9 +60,6 @@
         self.data = [uninit] * num_words
         self.trace = []  # type: List[TraceDmemStore]
 
-        self._load_begun = False
-        self._load_ready = False
-
     def _get_u32s(self, idx: int) -> List[int]:
         '''Return the value at idx as 8 uint32's
 
@@ -229,23 +226,10 @@
         # And write back
         self._set_u32s(idxW, u32s)
 
-    def commit(self, stalled: bool) -> None:
-        if self._load_begun:
-            self._load_begun = False
-            self._load_ready = True
-        else:
-            self._load_ready = False
-
+    def commit(self) -> None:
         for item in self.trace:
             self._commit_store(item)
         self.trace = []
 
     def abort(self) -> None:
         self.trace = []
-
-    def in_progress_load_complete(self) -> bool:
-        '''Returns true if a previously started load has completed'''
-        return self._load_ready
-
-    def begin_load(self) -> None:
-        self._load_begun = True
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 0afc83f..6e6b253 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -2,11 +2,11 @@
 # Licensed under the Apache License, Version 2.0, see LICENSE for details.
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict
+from typing import Dict, Iterator, Optional
 
 from sim import err_bits
 from .flags import FlagReg
-from .isa import (DecodeError, OTBNInsn, OTBNLDInsn, RV32RegReg, RV32RegImm,
+from .isa import (DecodeError, OTBNInsn, RV32RegReg, RV32RegImm,
                   RV32ImmShift, insn_for_mnemonic, logical_byte_shift,
                   extract_quarter_word)
 from .state import OTBNState
@@ -172,7 +172,7 @@
         state.gprs.get_reg(self.grd).write_unsigned(result)
 
 
-class LW(OTBNLDInsn):
+class LW(OTBNInsn):
     insn = insn_for_mnemonic('lw', 3)
 
     def __init__(self, raw: int, op_vals: Dict[str, int]):
@@ -181,15 +181,26 @@
         self.offset = op_vals['offset']
         self.grs1 = op_vals['grs1']
 
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
+        # LW executes over two cycles. On the first cycle, we read the base
+        # address, compute the load address and check it for correctness, then
+        # perform the load itself, returning the result.
+        #
+        # On the second cycle, we write the result to the destination register.
+
         base = state.gprs.get_reg(self.grs1).read_unsigned()
         addr = (base + self.offset) & ((1 << 32) - 1)
 
         if not state.dmem.is_valid_32b_addr(addr):
             state.stop_at_end_of_cycle(BAD_DATA_ADDR)
-        else:
-            result = state.dmem.load_u32(addr)
-            state.gprs.get_reg(self.grd).write_unsigned(result)
+            return
+
+        result = state.dmem.load_u32(addr)
+
+        # Stall for a single cycle for memory to respond
+        yield
+
+        state.gprs.get_reg(self.grd).write_unsigned(result)
 
 
 class SW(OTBNInsn):
@@ -208,8 +219,9 @@
 
         if not state.dmem.is_valid_32b_addr(addr):
             state.stop_at_end_of_cycle(BAD_DATA_ADDR)
-        else:
-            state.dmem.store_u32(addr, value)
+            return
+
+        state.dmem.store_u32(addr, value)
 
 
 class BEQ(OTBNInsn):
@@ -289,24 +301,25 @@
         self.csr = op_vals['csr']
         self.grs1 = op_vals['grs1']
 
-    def pre_execute(self, state: OTBNState) -> bool:
-        if self.csr == 0xfc0:
-            # Will return False if RND value not available, causing instruction
-            # to stall
-            return state.wsrs.RND.request_value()
-
-        return True
-
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
         if not state.csrs.check_idx(self.csr):
             # Invalid CSR index. Stop with an illegal instruction error.
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
             return
 
-        old_val = state.read_csr(self.csr)
         bits_to_set = state.gprs.get_reg(self.grs1).read_unsigned()
-        new_val = old_val | bits_to_set
 
+        if self.csr == 0xfc0:
+            # A read from RND. If a RND value is not available, request_value()
+            # initiates or continues an EDN request and returns False. If a RND
+            # value is available, it returns True.
+            while not state.wsrs.RND.request_value():
+                # There's a pending EDN request. Stall for a cycle.
+                yield
+
+        # The CSR is ready. Write its old value to grd, then update the CSR.
+        old_val = state.read_csr(self.csr)
+        new_val = old_val | bits_to_set
         state.gprs.get_reg(self.grd).write_unsigned(old_val)
         if self.grs1 != 0:
             state.write_csr(self.csr, new_val)
@@ -321,15 +334,7 @@
         self.csr = op_vals['csr']
         self.grs1 = op_vals['grs1']
 
-    def pre_execute(self, state: OTBNState) -> bool:
-        if self.csr == 0xfc0 and self.grd != 0:
-            # Will return False if RND value not available, causing instruction
-            # to stall
-            return state.wsrs.RND.request_value()
-
-        return True
-
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
         if not state.csrs.check_idx(self.csr):
             # Invalid CSR index. Stop with an illegal instruction error.
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
@@ -337,6 +342,17 @@
 
         new_val = state.gprs.get_reg(self.grs1).read_unsigned()
 
+        if self.csr == 0xfc0 and self.grd != 0:
+            # A read from RND. If a RND value is not available, request_value()
+            # initiates or continues an EDN request and returns False. If a RND
+            # value is available, it returns True.
+            while not state.wsrs.RND.request_value():
+                # There's a pending EDN request. Stall for a cycle.
+                yield
+
+        # At this point, the CSR is either ready or unneeded. Read it if
+        # necessary and write to grd, then overwrite with new_val.
+
         if self.grd != 0:
             old_val = state.read_csr(self.csr)
             state.gprs.get_reg(self.grd).write_unsigned(old_val)
@@ -886,7 +902,7 @@
         state.set_flags(self.flag_group, flags)
 
 
-class BNLID(OTBNLDInsn):
+class BNLID(OTBNInsn):
     insn = insn_for_mnemonic('bn.lid', 5)
 
     def __init__(self, raw: int, op_vals: Dict[str, int]):
@@ -900,27 +916,39 @@
         if self.grd_inc and self.grs1_inc:
             raise DecodeError('grd_inc and grs1_inc both set')
 
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
+        # BN.LID executes over two cycles. On the first cycle, we read the base
+        # address, compute the load address and check it for correctness,
+        # increment any GPRs, then perform the load itself. On the second
+        # cycle, update the WDR with the result.
+
         grs1_val = state.gprs.get_reg(self.grs1).read_unsigned()
         addr = (grs1_val + self.offset) & ((1 << 32) - 1)
         grd_val = state.gprs.get_reg(self.grd).read_unsigned()
 
         if grd_val > 31:
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
-        elif not state.dmem.is_valid_256b_addr(addr):
+            return
+
+        if not state.dmem.is_valid_256b_addr(addr):
             state.stop_at_end_of_cycle(BAD_DATA_ADDR)
-        else:
-            wrd = grd_val & 0x1f
-            value = state.dmem.load_u256(addr)
-            state.wdrs.get_reg(wrd).write_unsigned(value)
+            return
 
-            if self.grd_inc:
-                new_grd_val = grd_val + 1
-                state.gprs.get_reg(self.grd).write_unsigned(new_grd_val)
+        wrd = grd_val & 0x1f
+        value = state.dmem.load_u256(addr)
 
-            if self.grs1_inc:
-                new_grs1_val = (grs1_val + 32) & ((1 << 32) - 1)
-                state.gprs.get_reg(self.grs1).write_unsigned(new_grs1_val)
+        if self.grd_inc:
+            new_grd_val = grd_val + 1
+            state.gprs.get_reg(self.grd).write_unsigned(new_grd_val)
+
+        if self.grs1_inc:
+            new_grs1_val = (grs1_val + 32) & ((1 << 32) - 1)
+            state.gprs.get_reg(self.grs1).write_unsigned(new_grs1_val)
+
+        # Stall for a single cycle for memory to respond
+        yield
+
+        state.wdrs.get_reg(wrd).write_unsigned(value)
 
 
 class BNSID(OTBNInsn):
@@ -945,21 +973,24 @@
 
         if grs2_val > 31:
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
-        elif not state.dmem.is_valid_256b_addr(addr):
+            return
+
+        if not state.dmem.is_valid_256b_addr(addr):
             state.stop_at_end_of_cycle(BAD_DATA_ADDR)
-        else:
-            wrs = grs2_val & 0x1f
-            wrs_val = state.wdrs.get_reg(wrs).read_unsigned()
+            return
 
-            state.dmem.store_u256(addr, wrs_val)
+        wrs = grs2_val & 0x1f
+        wrs_val = state.wdrs.get_reg(wrs).read_unsigned()
 
-            if self.grs1_inc:
-                new_grs1_val = (grs1_val + 32) & ((1 << 32) - 1)
-                state.gprs.get_reg(self.grs1).write_unsigned(new_grs1_val)
+        state.dmem.store_u256(addr, wrs_val)
 
-            if self.grs2_inc:
-                new_grs2_val = grs2_val + 1
-                state.gprs.get_reg(self.grs2).write_unsigned(new_grs2_val)
+        if self.grs1_inc:
+            new_grs1_val = (grs1_val + 32) & ((1 << 32) - 1)
+            state.gprs.get_reg(self.grs1).write_unsigned(new_grs1_val)
+
+        if self.grs2_inc:
+            new_grs2_val = grs2_val + 1
+            state.gprs.get_reg(self.grs2).write_unsigned(new_grs2_val)
 
 
 class BNMOV(OTBNInsn):
@@ -1020,20 +1051,23 @@
         self.wrd = op_vals['wrd']
         self.wsr = op_vals['wsr']
 
-    def pre_execute(self, state: OTBNState) -> bool:
-        if self.wsr == 0x1:
-            # Will return False if RND value not available, causing instruction
-            # to stall
-            return state.wsrs.RND.request_value()
-
-        return True
-
-    def execute(self, state: OTBNState) -> None:
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
+        # The first, and possibly only, cycle of execution.
         if not state.wsrs.check_idx(self.wsr):
             # Invalid WSR index. Stop with an illegal instruction error.
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
             return
 
+        if self.wsr == 0x1:
+            # A read from RND. If a RND value is not available, request_value()
+            # initiates or continues an EDN request and returns False. If a RND
+            # value is available, it returns True.
+            while not state.wsrs.RND.request_value():
+                # There's a pending EDN request. Stall for a cycle.
+                yield
+
+        # At this point, the WSR is ready. Read it, and update wrd with the
+        # result.
         val = state.wsrs.read_at_idx(self.wsr)
         state.wdrs.get_reg(self.wrd).write_unsigned(val)
 
@@ -1047,11 +1081,6 @@
         self.wrs = op_vals['wrs']
 
     def execute(self, state: OTBNState) -> None:
-        if not state.wsrs.check_idx(self.wsr):
-            # Invalid WSR index. Stop with an illegal instruction error.
-            state.stop_at_end_of_cycle(ILLEGAL_INSN)
-            return
-
         val = state.wdrs.get_reg(self.wrs).read_unsigned()
         state.wsrs.write_at_idx(self.wsr, val)
 
diff --git a/hw/ip/otbn/dv/otbnsim/sim/isa.py b/hw/ip/otbn/dv/otbnsim/sim/isa.py
index bbf4557..89eb3fd 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/isa.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/isa.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import sys
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple
 
 from shared.insn_yaml import Insn, DummyInsn, load_insns_yaml
 
@@ -79,15 +79,13 @@
         # it can't hurt to check).
         self._disasm = None  # type: Optional[Tuple[int, str]]
 
-    def pre_execute(self, state: OTBNState) -> bool:
-        '''Performs any actions required before instruction can execute.
+    def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
+        '''Execute the instruction
 
-        Return True if instruction is clear to execute. Returning False will
-        stall the simulator for a step.
+        This may yield (returning an iterator object) if the instruction has
+        stalled the processor and will take multiple cycles.
+
         '''
-        return True
-
-    def execute(self, state: OTBNState) -> None:
         raise NotImplementedError('OTBNInsn.execute')
 
     def disassemble(self, pc: int) -> str:
@@ -108,20 +106,6 @@
         return (1 << 32) + value if value < 0 else value
 
 
-class OTBNLDInsn(OTBNInsn):
-    '''A general class for any load instruction providing appropriate stalls'''
-
-    def pre_execute(self, state: OTBNState) -> bool:
-        if state.dmem.in_progress_load_complete():
-            # Load has been started and now complete, execution can proceed
-            return True
-
-        # Load not complete so begin a new load
-        state.dmem.begin_load()
-        # Load stalls when it begins
-        return False
-
-
 class RV32RegReg(OTBNInsn):
     '''A general class for register-register insns from the RV32I ISA'''
     def __init__(self, raw: int, op_vals: Dict[str, int]):
diff --git a/hw/ip/otbn/dv/otbnsim/sim/sim.py b/hw/ip/otbn/dv/otbnsim/sim/sim.py
index be61a83..cfc4d7e 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/sim.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/sim.py
@@ -2,7 +2,7 @@
 # Licensed under the Apache License, Version 2.0, see LICENSE for details.
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Tuple
+from typing import Iterator, List, Optional, Tuple
 
 from .isa import OTBNInsn
 from .state import OTBNState
@@ -18,6 +18,7 @@
         self.state = OTBNState()
         self.program = []  # type: List[OTBNInsn]
         self.stats = None  # type: Optional[ExecutionStats]
+        self._execute_generator = None  # type: Optional[Iterator[None]]
 
     def load_program(self, program: List[OTBNInsn]) -> None:
         self.program = program.copy()
@@ -32,6 +33,7 @@
 
         '''
         self.stats = ExecutionStats(self.program)
+        self._execute_generator = None
         self.state.start()
 
     def run(self, verbose: bool, collect_stats: bool) -> int:
@@ -58,7 +60,7 @@
     def step(self,
              verbose: bool,
              collect_stats: bool) -> Tuple[Optional[OTBNInsn], List[Trace]]:
-        '''Run a single instruction.
+        '''Run a single cycle.
 
         Returns the instruction, together with a list of the architectural
         changes that have happened. If the model isn't currently running,
@@ -87,8 +89,27 @@
 
         sim_stalled = self.state.non_insn_stall
         if not sim_stalled:
-            # Instruction can stall sim by returning False from `pre_execute`
-            sim_stalled = not insn.pre_execute(self.state)
+            if self._execute_generator is None:
+                # This is the first cycle for an instruction. Run any setup for
+                # the state object and then start running the instruction
+                # itself.
+                self.state.pre_insn(insn.affects_control)
+
+                # Either execute the instruction directly (if it is a
+                # single-cycle instruction without a `yield` in execute()), or
+                # return a generator for multi-cycle instructions. Note that
+                # this doesn't consume the first yielded value.
+                self._execute_generator = insn.execute(self.state)
+
+            if self._execute_generator is not None:
+                # This is a cycle for a multi-cycle instruction (which possibly
+                # started just above)
+                try:
+                    next(self._execute_generator)
+                except StopIteration:
+                    self._execute_generator = None
+
+            sim_stalled = (self._execute_generator is not None)
 
         if sim_stalled:
             self.state.commit(sim_stalled=True)
@@ -98,8 +119,7 @@
             if collect_stats:
                 self.stats.record_stall()
         else:
-            self.state.pre_insn(insn.affects_control)
-            insn.execute(self.state)
+            assert self._execute_generator is None
             self.state.post_insn()
 
             if collect_stats:
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py
index a1f8fdb..43586ec 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/state.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -35,8 +35,8 @@
         self.dmem = Dmem()
 
         # Stalling support: Instructions can indicate they should stall by
-        # returning false from OTBNInsn.pre_execute. For non instruction related
-        # stalls setting self.non_insn_stall will produce a stall.
+        # yielding in OTBNInsn.execute. For non instruction related stalls,
+        # setting self.non_insn_stall will produce a stall.
         #
         # As a special case, we stall until the URND reseed is completed then
         # stall for one more cycle before fetching the first instruction (to
@@ -127,13 +127,12 @@
             self.non_insn_stall = False
             self.ext_regs.commit()
 
-        self.dmem.commit(sim_stalled)
-
-        # If we're stalled, there's nothing more to do: we only commit when we
-        # finish our stall cycles.
+        # If we're stalled, there's nothing more to do: we only commit the rest
+        # of the architectural state when we finish our stall cycles.
         if sim_stalled:
             return
 
+        self.dmem.commit()
         self.gprs.commit()
         self.pc = self.get_next_pc()
         self._pc_next_override = None