[otbn,dv] Make load/store error handling explicit in ISS

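Checking load/store addresses explicitly in the instruction definitions
(instead of setting an error flag inside the dmem model) should read
better in the generated documentation, and it also makes multiple-error
behaviour clearer. This is also a required staging commit for sorting
out our stalling behaviour on errors.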

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/otbnsim/sim/dmem.py b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
index 4e09186..44f47e7 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/dmem.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/dmem.py
@@ -7,7 +7,6 @@
 
 from shared.mem_layout import get_memory_layout
 
-from .err_bits import BAD_DATA_ADDR
 from .trace import Trace
 
 
@@ -61,8 +60,6 @@
         self.data = [uninit] * num_words
         self.trace = []  # type: List[TraceDmemStore]
 
-        self.err_flag = False
-
         self._load_begun = False
         self._load_ready = False
 
@@ -143,38 +140,44 @@
             u32s += self._get_u32s(idx)
         return struct.pack('<{}I'.format(len(u32s)), *u32s)
 
+    def is_valid_256b_addr(self, addr: int) -> bool:
+        '''Return true if this is a valid address for a BN.LID/BN.SID'''
+        assert addr >= 0
+        if addr & 31:
+            return False
+
+        word_addr = addr // 32
+        if word_addr >= len(self.data):
+            return False
+
+        return True
+
     def load_u256(self, addr: int) -> int:
         '''Read a u256 little-endian value from an aligned address'''
         assert addr >= 0
+        assert self.is_valid_256b_addr(addr)
 
-        if addr & 31:
-            self.err_flag = True
-            return 0
-
-        word_addr = addr // 32
-
-        if word_addr >= len(self.data):
-            self.err_flag = True
-            return 0
-
-        return self.data[word_addr]
+        return self.data[addr // 32]
 
     def store_u256(self, addr: int, value: int) -> None:
         '''Write a u256 little-endian value to an aligned address'''
         assert addr >= 0
         assert 0 <= value < (1 << 256)
-
-        if addr & 31:
-            self.err_flag = True
-            return
-
-        word_addr = addr // 32
-        if word_addr >= len(self.data):
-            self.err_flag = True
-            return
+        assert self.is_valid_256b_addr(addr)
 
         self.trace.append(TraceDmemStore(addr, value, True))
 
+    def is_valid_32b_addr(self, addr: int) -> bool:
+        '''Return true if this is a valid address for a LW/SW instruction'''
+        assert addr >= 0
+        if addr & 3:
+            return False
+
+        if (addr + 3) // 32 >= len(self.data):
+            return False
+
+        return True
+
     def load_u32(self, addr: int) -> int:
         '''Read a 32-bit value from memory.
 
@@ -182,14 +185,7 @@
         32-bit integer.
 
         '''
-        assert addr >= 0
-        if addr & 3:
-            self.err_flag = True
-            return 0
-
-        if (addr + 3) // 32 >= len(self.data):
-            self.err_flag = True
-            return 0
+        assert self.is_valid_32b_addr(addr)
 
         idx32 = addr // 4
         idxW = idx32 // 8
@@ -205,20 +201,10 @@
         '''
         assert addr >= 0
         assert 0 <= value <= (1 << 32) - 1
-
-        if addr & 3:
-            self.err_flag = True
-            return
-
-        if (addr + 3) // 32 >= len(self.data):
-            self.err_flag = True
-            return
+        assert self.is_valid_32b_addr(addr)
 
         self.trace.append(TraceDmemStore(addr, value, False))
 
-    def err_bits(self) -> int:
-        return BAD_DATA_ADDR if self.err_flag else 0
-
     def changes(self) -> Sequence[Trace]:
         return self.trace
 
@@ -244,8 +230,6 @@
         self._set_u32s(idxW, u32s)
 
     def commit(self, stalled: bool) -> None:
-        assert not self.err_flag
-
         if self._load_begun:
             self._load_begun = False
             self._load_ready = True
@@ -258,7 +242,6 @@
 
     def abort(self) -> None:
         self.trace = []
-        self.err_flag = False
 
     def in_progress_load_complete(self) -> bool:
         '''Returns true if a previously started load has completed'''
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
index 02590d8..0afc83f 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/insn.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -10,7 +10,7 @@
                   RV32ImmShift, insn_for_mnemonic, logical_byte_shift,
                   extract_quarter_word)
 from .state import OTBNState
-from .err_bits import ILLEGAL_INSN
+from .err_bits import BAD_DATA_ADDR, ILLEGAL_INSN
 
 
 class ADD(RV32RegReg):
@@ -184,8 +184,12 @@
     def execute(self, state: OTBNState) -> None:
         base = state.gprs.get_reg(self.grs1).read_unsigned()
         addr = (base + self.offset) & ((1 << 32) - 1)
-        result = state.dmem.load_u32(addr)
-        state.gprs.get_reg(self.grd).write_unsigned(result)
+
+        if not state.dmem.is_valid_32b_addr(addr):
+            state.stop_at_end_of_cycle(BAD_DATA_ADDR)
+        else:
+            result = state.dmem.load_u32(addr)
+            state.gprs.get_reg(self.grd).write_unsigned(result)
 
 
 class SW(OTBNInsn):
@@ -201,7 +205,11 @@
         base = state.gprs.get_reg(self.grs1).read_unsigned()
         addr = (base + self.offset) & ((1 << 32) - 1)
         value = state.gprs.get_reg(self.grs2).read_unsigned()
-        state.dmem.store_u32(addr, value)
+
+        if not state.dmem.is_valid_32b_addr(addr):
+            state.stop_at_end_of_cycle(BAD_DATA_ADDR)
+        else:
+            state.dmem.store_u32(addr, value)
 
 
 class BEQ(OTBNInsn):
@@ -899,6 +907,8 @@
 
         if grd_val > 31:
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
+        elif not state.dmem.is_valid_256b_addr(addr):
+            state.stop_at_end_of_cycle(BAD_DATA_ADDR)
         else:
             wrd = grd_val & 0x1f
             value = state.dmem.load_u256(addr)
@@ -935,6 +945,8 @@
 
         if grs2_val > 31:
             state.stop_at_end_of_cycle(ILLEGAL_INSN)
+        elif not state.dmem.is_valid_256b_addr(addr):
+            state.stop_at_end_of_cycle(BAD_DATA_ADDR)
         else:
             wrs = grs2_val & 0x1f
             wrs_val = state.wdrs.get_reg(wrs).read_unsigned()
diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py
index 80af4b4..a1f8fdb 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/state.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/state.py
@@ -254,9 +254,7 @@
         self.loop_step()
         self.gprs.post_insn()
 
-        self._err_bits |= (self.gprs.err_bits() |
-                           self.dmem.err_bits() |
-                           self.loop_stack.err_bits())
+        self._err_bits |= self.gprs.err_bits() | self.loop_stack.err_bits()
         if self._err_bits:
             self.pending_halt = True