[otbn,dv] Add a generator for instructions using x1 more heavily

This is intended to generate things like "add x1, x1, x1" in more
situations. In particular, it can do so when the call stack is full.
That isn't possible for StraightLineInsn, because that picks operands
one at a time (this instruction is valid for a full call stack, but
"add x1, x0, x0" is not).

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/rig/rig/configs/base.yml b/hw/ip/otbn/dv/rig/rig/configs/base.yml
index 3f00fd8..51c9fba 100644
--- a/hw/ip/otbn/dv/rig/rig/configs/base.yml
+++ b/hw/ip/otbn/dv/rig/rig/configs/base.yml
@@ -5,6 +5,7 @@
 gen-weights:
   # Generators that can continue the program
   Branch: 0.1
+  CallStackRW: 0.1
   Jump: 0.1
   Loop: 0.1
   LoopDupEnd: 0.01
diff --git a/hw/ip/otbn/dv/rig/rig/gens/call_stack_rw.py b/hw/ip/otbn/dv/rig/rig/gens/call_stack_rw.py
new file mode 100644
index 0000000..90119c7
--- /dev/null
+++ b/hw/ip/otbn/dv/rig/rig/gens/call_stack_rw.py
@@ -0,0 +1,132 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+from typing import Optional
+
+from shared.insn_yaml import InsnsFile
+from shared.operand import RegOperandType
+
+from ..config import Config
+from ..program import ProgInsn, Program
+from ..model import Model
+from ..snippet import ProgSnippet
+from ..snippet_gen import GenCont, GenRet, SnippetGen
+
+
+class CallStackRW(SnippetGen):
+    '''A snippet generator that tries to exercise the call stack
+
+    We already have code that can exercise the call stack (e.g.
+    StraightLineInsn), but there are certain things that it will never do (pop
+    & push when the stack is full, for example) and other multiple uses of x1
+    aren't particularly frequent. Generate more here!
+
+    '''
+
+    def __init__(self, cfg: Config, insns_file: InsnsFile) -> None:
+        super().__init__()
+
+        # Grab instructions like "add" or "sub", which take two GPRs as inputs
+        # and write one GPR as an output.
+        self.insns = []
+        self.indices = []
+        self.weights = []
+
+        for insn in insns_file.insns:
+            gpr_dsts = []
+            gpr_srcs = []
+
+            for idx, op in enumerate(insn.operands):
+                if not isinstance(op.op_type, RegOperandType):
+                    continue
+                is_gpr = op.op_type.reg_type == 'gpr'
+                is_dst = op.op_type.is_dest()
+
+                if is_gpr:
+                    if is_dst:
+                        gpr_dsts.append(idx)
+                    else:
+                        gpr_srcs.append(idx)
+
+            if len(gpr_dsts) == 1 and len(gpr_srcs) == 2 and insn.lsu is None:
+                weight = cfg.insn_weights.get(insn.mnemonic)
+                if weight > 0:
+                    self.insns.append(insn)
+                    self.indices.append((gpr_dsts[0], gpr_srcs[0], gpr_srcs[1]))
+                    self.weights.append(weight)
+
+        if not self.insns:
+            # All the weights for the instructions we can use are zero
+            self.disabled = True
+
+    def gen(self,
+            cont: GenCont,
+            model: Model,
+            program: Program) -> Optional[GenRet]:
+        # We can't read or write x1 when it's marked const.
+        if model.is_const('gpr', 1):
+            return None
+
+        # Make sure we don't get paint ourselves into a corner
+        if program.get_insn_space_at(model.pc) <= 1:
+            return None
+
+        # Pick an instruction
+        idx = random.choices(range(len(self.weights)), weights=self.weights)[0]
+        grd_idx, grs1_idx, grs2_idx = self.indices[idx]
+        insn = self.insns[idx]
+
+        # This instruction will have one GPR dest and two GPR sources. It might
+        # also have some immediate values. Decide how to fill it out.
+        # Interesting patterns are
+        #
+        #  (0)  add   x1, x1, ??
+        #  (1)  add   x1, ??, x1
+        #  (2)  add   x1, x1, x1
+        #  (3)  add   ??, x1, x1
+        #  (4)  add   x1, ??, ??
+        #
+        # We can generate any of 0-3 as long as the call stack is nonempty
+        # (which we've already guaranteed with pick_weight). We can generate 4
+        # if the stack is not full.
+        min_idx = 4 if model.call_stack.empty() else 0
+        max_idx = 3 if model.call_stack.full() else 4
+        if min_idx > max_idx:
+            # This is possible (call stack both empty and full) because we
+            # might not have full knowledge of the state of the call stack.
+            return None
+
+        flavour = random.randint(min_idx, max_idx)
+        x1_grd = flavour != 3
+        x1_grs1 = flavour not in [1, 4]
+        x1_grs2 = flavour not in [0, 4]
+
+        op_vals = []
+        for idx, operand in enumerate(insn.operands):
+            use_x1 = ((x1_grd and idx == grd_idx) or
+                      (x1_grs1 and idx == grs1_idx) or
+                      (x1_grs2 and idx == grs2_idx))
+            if use_x1:
+                op_vals.append(1)
+            else:
+                # Make sure we don't use x1 when we're not expecting to
+                if not isinstance(operand.op_type, RegOperandType):
+                    weights = None
+                else:
+                    weights = {1: 0.0}
+
+                enc_op_val = model.pick_operand_value(operand.op_type, weights)
+                if enc_op_val is None:
+                    return None
+                op_vals.append(enc_op_val)
+
+        prog_insn = ProgInsn(insn, op_vals, None)
+        snippet = ProgSnippet(model.pc, [prog_insn])
+        snippet.insert_into_program(program)
+
+        model.update_for_insn(prog_insn)
+        model.pc += 4
+
+        return (snippet, False, model)
diff --git a/hw/ip/otbn/dv/rig/rig/model.py b/hw/ip/otbn/dv/rig/rig/model.py
index ac9dab8..8a35a4b 100644
--- a/hw/ip/otbn/dv/rig/rig/model.py
+++ b/hw/ip/otbn/dv/rig/rig/model.py
@@ -209,7 +209,7 @@
         # entry of None means an entry with an architectural value, but where
         # we don't actually know what it is (usually a result of some
         # arithmetic operation that got written to x1).
-        self._call_stack = CallStack()
+        self.call_stack = CallStack()
 
         # The loop stack.
         self.loop_stack = LoopStack()
@@ -249,7 +249,7 @@
         for entry in self._const_stack:
             ret._const_stack.append({n: regs.copy()
                                      for n, regs in entry.items()})
-        ret._call_stack = self._call_stack.copy()
+        ret.call_stack = self.call_stack.copy()
         ret.loop_stack = self.loop_stack.copy()
         ret._known_mem = {n: mem.copy()
                           for n, mem in self._known_mem.items()}
@@ -318,7 +318,7 @@
 
         assert self._const_stack == other._const_stack
 
-        self._call_stack.merge(other._call_stack)
+        self.call_stack.merge(other.call_stack)
         self.loop_stack.merge(other.loop_stack)
 
         for mem_type, self_mem in self._known_mem.items():
@@ -336,7 +336,7 @@
         if reg_type == 'gpr' and idx == 1:
             # We shouldn't ever read from x1 if it is marked constant
             assert not self.is_const('gpr', 1)
-            self._call_stack.pop()
+            self.call_stack.pop()
 
     def write_reg(self,
                   reg_type: str,
@@ -364,7 +364,7 @@
 
             if idx == 1:
                 # Special-case writes to x1
-                self._call_stack.write(value, update)
+                self.call_stack.write(value, update)
                 return
 
         self._known_regs.setdefault(reg_type, {})[idx] = value
@@ -372,7 +372,7 @@
     def get_reg(self, reg_type: str, idx: int) -> Optional[int]:
         '''Get a register value, if known.'''
         if reg_type == 'gpr' and idx == 1:
-            return self._call_stack.peek()
+            return self.call_stack.peek()
 
         return self._known_regs.setdefault(reg_type, {}).get(idx)
 
@@ -446,9 +446,9 @@
         # when the stack is full, because we only do one operand at a time.
         if op_type.reg_type == 'gpr':
             can_use_x1 = not self.is_const('gpr', 1)
-            if is_src and self._call_stack.empty():
+            if is_src and self.call_stack.empty():
                 can_use_x1 = False
-            if is_dst and self._call_stack.full():
+            if is_dst and self.call_stack.full():
                 can_use_x1 = False
 
             # Since x1 isn't tracked in known_regs, we add it here if wanted
@@ -533,8 +533,8 @@
         # not None and can be read iff it isn't marked constant.
         if reg_type == 'gpr':
             assert 1 not in known_regs
-            if not self._call_stack.empty():
-                x1 = self._call_stack.peek()
+            if not self.call_stack.empty():
+                x1 = self.call_stack.peek()
                 if x1 is not None:
                     if not self.is_const('gpr', 1):
                         ret.append((1, x1))
@@ -550,7 +550,7 @@
         # stack is not empty.
         if reg_type == 'gpr':
             assert 1 not in arch_regs
-            if not self._call_stack.empty():
+            if not self.call_stack.empty():
                 if not self.is_const('gpr', 1):
                     arch_regs.append(1)
 
@@ -603,7 +603,7 @@
 
             # x1 (the call stack) has different handling
             if reg_idx == 1:
-                self._call_stack.forget_value()
+                self.call_stack.forget_value()
                 return
 
         # Set the value in known_regs to None, but only if the register already
diff --git a/hw/ip/otbn/dv/rig/rig/snippet_gens.py b/hw/ip/otbn/dv/rig/rig/snippet_gens.py
index 868989e..67709c4 100644
--- a/hw/ip/otbn/dv/rig/rig/snippet_gens.py
+++ b/hw/ip/otbn/dv/rig/rig/snippet_gens.py
@@ -14,6 +14,7 @@
 from .snippet_gen import GenRet, SnippetGen
 
 from .gens.branch import Branch
+from .gens.call_stack_rw import CallStackRW
 from .gens.ecall import ECall
 from .gens.jump import Jump
 from .gens.loop import Loop
@@ -33,6 +34,7 @@
     '''A collection of snippet generators'''
     _CLASSES = [
         Branch,
+        CallStackRW,
         Jump,
         Loop,
         LoopDupEnd,