[otbn,dv] Refactor how the RIG ends programs

We want to add some generators that end programs with something other
than an ECALL (by generating errors).

Generalise things so that a generator class can say it does so (with
the new ends_program class variable) and change some variable names to
reflect this.

Also, change how we generate the top-level program. We now decide how
long the "head" and "tail" should be, respectively, and then generate
them separately. This makes some of the logic rather simpler,
especially when we start adding snippet generators that cause errors.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/rig/rig/gens/ecall.py b/hw/ip/otbn/dv/rig/rig/gens/ecall.py
index a5ab5f4..cd81c34 100644
--- a/hw/ip/otbn/dv/rig/rig/gens/ecall.py
+++ b/hw/ip/otbn/dv/rig/rig/gens/ecall.py
@@ -15,6 +15,9 @@
 
 class ECall(SnippetGen):
     '''A generator that makes a snippet with a single ECALL instruction'''
+
+    ends_program = True
+
     def __init__(self, cfg: Config, insns_file: InsnsFile) -> None:
         super().__init__()
 
diff --git a/hw/ip/otbn/dv/rig/rig/rig.py b/hw/ip/otbn/dv/rig/rig/rig.py
index a6fc573..ac4f2f3 100644
--- a/hw/ip/otbn/dv/rig/rig/rig.py
+++ b/hw/ip/otbn/dv/rig/rig/rig.py
@@ -58,5 +58,5 @@
         raise RuntimeError('Failed to initialise snippet generators: {}'
                            .format(err)) from None
 
-    snippet, end_addr = gens.gen_rest(model, program)
+    snippet, end_addr = gens.gen_program(model, program)
     return init_data, snippet, end_addr
diff --git a/hw/ip/otbn/dv/rig/rig/snippet_gen.py b/hw/ip/otbn/dv/rig/rig/snippet_gen.py
index 5ba4cd8..34b2f26 100644
--- a/hw/ip/otbn/dv/rig/rig/snippet_gen.py
+++ b/hw/ip/otbn/dv/rig/rig/snippet_gen.py
@@ -45,6 +45,10 @@
     binary.
 
     '''
+    # A class-level variable that is set for generators that will end the
+    # program (with an ECALL or an error)
+    ends_program: bool = False
+
     def __init__(self) -> None:
         self.disabled = False
 
diff --git a/hw/ip/otbn/dv/rig/rig/snippet_gens.py b/hw/ip/otbn/dv/rig/rig/snippet_gens.py
index b215c8b..e50229d 100644
--- a/hw/ip/otbn/dv/rig/rig/snippet_gens.py
+++ b/hw/ip/otbn/dv/rig/rig/snippet_gens.py
@@ -31,13 +31,15 @@
     ]
 
     def __init__(self, cfg: Config, insns_file: InsnsFile) -> None:
-        self.generators = []  # type: List[Tuple[SnippetGen, float]]
+        self._cont_generators = []  # type: List[Tuple[SnippetGen, float]]
+        self._end_generators = []   # type: List[Tuple[SnippetGen, float]]
 
         # Grab an ECall generator. We'll use it in self.gens to append an ECALL
         # instruction if necessary.
         ecall = None
 
         used_names = set()
+
         for cls in SnippetGens._CLASSES:
             cls_name = cls.__name__
             weight = cfg.gen_weights.values.get(cls_name)
@@ -48,21 +50,27 @@
             gen = cls(cfg, insns_file)
             if isinstance(gen, ECall):
                 ecall = gen
+
                 # The ECall generator mustn't disable itself
                 assert not gen.disabled
 
+                # It also shouldn't be disabled by a config
+                if weight == 0:
+                    raise ValueError(f'Config at {cfg.path} gives zero weight '
+                                     f'to the ECall generator.')
+
             assert cls_name not in used_names
             used_names.add(cls_name)
 
             if weight > 0 and not gen.disabled:
-                self.generators.append((gen, weight))
+                pr = (gen, weight)
+                if cls.ends_program:
+                    self._end_generators.append(pr)
+                else:
+                    self._cont_generators.append(pr)
 
-        # Check that at least one generator has positive weight and wasn't
-        # disabled
-        if not self.generators:
-            raise ValueError('Config at {} disables or gives zero '
-                             'weight to all generators.'
-                             .format(cfg.path))
+        # self._end_generators should include ECall, at least.
+        assert self._end_generators
 
         # Check that we used all the names in cfg.gen_weights
         unused_names = set(cfg.gen_weights.values.keys()) - used_names
@@ -79,40 +87,39 @@
     def gen(self,
             model: Model,
             program: Program,
-            ecall: bool) -> Optional[GenRet]:
+            end: bool) -> Optional[GenRet]:
         '''Pick a snippet and update model, program with its contents.
 
         This assumes that program.get_insn_space_at(model.pc) > 0.
 
         Normally returns a GenRet tuple with the same meanings as Snippet.gen.
-        If the chosen snippet would generate an ECALL and ecall is False, this
-        instead returns None (and leaves model and program unchanged).
+        If the chosen snippet would cause execution to halt but end is False,
+        this instead returns None (and leaves model and program unchanged).
 
         '''
         # If we've run out of fuel, stop immediately.
         if not model.fuel:
             return None
 
+        generators = self._end_generators if end else self._cont_generators
+
         real_weights = []
-        pos_weights = 0
-        for generator, weight in self.generators:
-            if isinstance(generator, ECall) and not ecall:
-                real_weight = 0.0
-            else:
-                weight_mult = generator.pick_weight(model, program)
-                real_weight = weight * weight_mult
+        num_pos_weights = 0
+        for generator, weight in generators:
+            weight_mult = generator.pick_weight(model, program)
+            real_weight = weight * weight_mult
 
             assert real_weight >= 0
             if real_weight > 0:
-                pos_weights += 1
+                num_pos_weights += 1
 
             real_weights.append(real_weight)
 
-        while pos_weights > 0:
+        while num_pos_weights > 0:
             # Pick a generator based on the weights in real_weights.
-            idx = random.choices(range(len(self.generators)),
+            idx = random.choices(range(len(generators)),
                                  weights=real_weights)[0]
-            generator, _ = self.generators[idx]
+            generator, _ = generators[idx]
 
             # Note that there should always be at least one positive weight in
             # real_weights. random.choices doesn't check that: if you pass all
@@ -129,7 +136,7 @@
             # If gen_res is None, the generator failed. Set that weight to zero
             # and try again.
             real_weights[idx] = 0.0
-            pos_weights -= 1
+            num_pos_weights -= 1
 
         # We ran out of generators with positive weight. Give up.
         return None
@@ -142,39 +149,38 @@
     def _gens(self,
               model: Model,
               program: Program,
-              ecall: bool) -> Tuple[List[Snippet], Model]:
+              end: bool) -> Tuple[List[Snippet], Model]:
         '''Generate some snippets to continue program.
 
-        This will try to run down model.fuel and program.size. If ecall is
-        True, it will eventually generate an ECALL instruction. If ecall is
-        False then instead of generating the ECALL instruction, it will instead
-        stop (leaving model.pc where the ECALL instruction would have been
-        inserted).
+        This will try to run down model.fuel and program.size. If end is True,
+        it will eventually generate an ECALL instruction or cause an error. If
+        end is False then, instead of doing that, it will instead stop (leaving
+        model.pc at the place where the next instruction should be inserted).
 
         '''
         children = []  # type: List[Snippet]
         while True:
-            must_stop = False
-            # If we've run out of space and ecall is False, we stop
-            # immediately. If ecall is True, we need to generate one last ECALL
+            gen_ecall = False
+            # If we've run out of space and end is False, we stop immediately.
+            # If end is True, we need to bail out now! Generate one last ECALL
             # instruction.
             if not program.space:
-                if ecall:
-                    must_stop = True
+                if end:
+                    gen_ecall = True
                 else:
                     break
 
             old_fuel = model.fuel
             gen_res = None
-            if not must_stop:
-                gen_res = self.gen(model, program, ecall)
+            if not gen_ecall:
+                gen_res = self.gen(model, program, end)
 
             if gen_res is None:
-                # We failed to generate another snippet. If ecall is False,
+                # We failed to generate another snippet. If end is False,
                 # that's fine: we've probably run out of fuel and should stop.
-                # If ecall is True, that's bad news: we can't just leave the
+                # If end is True, that's bad news: we can't just leave the
                 # program unfinished. In that case, force an ECALL instruction.
-                if ecall:
+                if end:
                     children.append(self._gen_ecall(model.pc, program))
                 break
 
@@ -182,7 +188,7 @@
             children.append(snippet)
 
             if done:
-                assert ecall
+                assert end
                 break
 
             assert model.fuel < old_fuel
@@ -205,16 +211,38 @@
         snippet = Snippet.merge_list(snippets) if snippets else None
         return (snippet, next_model)
 
-    def gen_rest(self, model: Model, program: Program) -> Tuple[Snippet, int]:
-        '''Generate the rest of the program, ending with an ECALL
+    def gen_program(self,
+                    model: Model,
+                    program: Program) -> Tuple[Snippet, int]:
+        '''Generate an entire program
 
-        Returns a pair (snippet, end_addr) where snippet is the Snippet object
-        representing the program and end_addr is the address of the final
-        instruction that should be executed.
+        This is the top-level entry point used by gen_program in rig.py.
 
         '''
-        snippets, next_model = self._gens(model, program, True)
+        # The basic strategy is that we split the length of program execution
+        # (given by model.fuel) into a head and a tail. We generate the "head
+        # part", trying not to generate an ECALL or error. Following that, we
+        # generate the "tail part", explicitly aiming to end in an ECALL or
+        # error.
+
+        head_fuel_min = max(1, int(model.fuel * 0.9))
+        head_fuel_max = model.fuel - 1
+        head_fuel = random.randint(head_fuel_min, head_fuel_max)
+        assert 1 <= head_fuel < model.fuel
+
+        tail_fuel = model.fuel - head_fuel
+
+        model.fuel = head_fuel
+        head, model = self.gens(model, program)
+        # Add the rest of the fuel to the tank
+        model.fuel += tail_fuel
+
+        tail_snippets, end_model = self._gens(model, program, True)
         # _gens() always generates at least one snippet (containing the final
-        # ECALL). The model it returns points at the ECALL instruction's PC.
-        assert snippets
-        return (Snippet.merge_list(snippets), next_model.pc)
+        # ECALL or an instruction that causes an error). The model it returns
+        # points at that final instruction's PC.
+        assert tail_snippets
+
+        snippets = tail_snippets if head is None else [head] + tail_snippets
+
+        return (Snippet.merge_list(snippets), end_model.pc)