[otbn] Tidy up how we deal with fuel in RIG

We used to have quite a bit of complicated juggling, where every
generator made sure to leave enough fuel for a following ECALL
instruction. Stop doing that: it's way more complicated than it needs
to be. Generators still have to make sure they don't paint themselves
in a corner (by running out of actual space), but they can now exhaust
fuel without worrying.

This makes it much easier to do recursive generation. It's needed in
particular for sane loop generation, where we want to constrain
model.fuel dramatically when we're generating loop bodies.

While we're changing this, we also refactor the methods in SnippetGens
a little, to reflect what's going on a bit more obviously in the types.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/rig/gens/branch.py b/hw/ip/otbn/util/rig/gens/branch.py
index 3d10bf9..72b4742 100644
--- a/hw/ip/otbn/util/rig/gens/branch.py
+++ b/hw/ip/otbn/util/rig/gens/branch.py
@@ -11,7 +11,7 @@
 from .jump import Jump
 from ..program import ProgInsn, Program
 from ..model import Model
-from ..snippet import BranchSnippet, ProgSnippet
+from ..snippet import BranchSnippet, ProgSnippet, Snippet
 from ..snippet_gen import GenCont, GenRet, SnippetGen
 
 
@@ -58,8 +58,8 @@
         #
         # We'll need at least 4 instructions' space for a proper branch: the
         # branch instruction, the fall-through instruction, the branch target
-        # (which will jump back if necessary), and an eventual ECALL)
-        if program.get_insn_space_left() < 4:
+        # (which will jump back if necessary), and a continuation.
+        if program.space < 4:
             fall_thru = True
         else:
             fall_thru = random.random() < 0.01
@@ -72,12 +72,6 @@
             model: Model,
             program: Program) -> Optional[GenRet]:
 
-        if model.fuel <= 1:
-            # The shortest possible branch sequence (branch to PC + 4) takes an
-            # instruction and needs at least one instruction afterwards for the
-            # ECALL, so don't generate anything if fuel is less than 2.
-            return None
-
         # Return None if this is the last instruction in the current gap
         # because we need to either jump or do an ECALL to avoid getting stuck
         # (just like the StraightLineInsn generator)
@@ -139,6 +133,7 @@
                             (1, (0.5, 1.0))]
         fuel_frac = self.pick_from_weighted_ranges(fuel_frac_ranges)
         assert 0 <= fuel_frac <= 1
+        assert 0 < model.fuel
         branch_fuel = max(1, int(0.5 + fuel_frac * model.fuel))
 
         # Similarly, decide how much of our remaining space to give the code
@@ -147,14 +142,20 @@
         space_frac_ranges = fuel_frac_ranges
         space_frac = self.pick_from_weighted_ranges(space_frac_ranges)
         assert 0 <= space_frac <= 1
-        # Subtract 2: one for the branch instruction and one for an eventual
-        # ECALL. We checked earlier we had at least 4 instructions' space left,
-        # so there should always be at least 2 instructions' space left
-        # afterwards.
-        max_space_for_branches = program.get_insn_space_left() - 2
-        assert max_space_for_branches >= 2
-        branch_space = max(1, int(space_frac * (max_space_for_branches / 2)))
-        assert 2 * branch_space <= max_space_for_branches
+        # Subtract 2: one for the branch instruction and one for the start of
+        # whatever happens after we re-merge. We checked earlier we had at
+        # least 4 instructions' space left, so there should always be at least
+        # 2 instructions' space left afterwards.
+        #
+        # Divide that by two (to apportion it between the branches) and then
+        # subtract one more. This last subtraction is to allow the jump to
+        # recombine.
+        space_for_branches = program.space - 2
+        assert space_for_branches >= 2
+        total_branch_space = max(1, int(space_frac * (space_for_branches / 2)))
+        assert 1 <= 2 * total_branch_space <= space_for_branches
+        branch_space = total_branch_space - 1
+        assert 0 <= branch_space
 
         # Make an updated copy of program that includes the branch instruction.
         # Similarly, take a copy of the model and update it as if we've fallen
@@ -173,60 +174,66 @@
         model0.pc += 4
         prog0.add_insns(tgt_addr, [branch_insn])
 
-        model0.fuel = branch_fuel
-        prog0.constrain_space(branch_space)
+        # Pass branch_fuel - 1 here to give ourselves space for a jump to
+        # branch 1 (if we decide to do it that way around)
+        model0.fuel = branch_fuel - 1
+        prog0.space = branch_space
 
-        ret0 = cont(model0, prog0)
-        if ret0 is None:
-            return None
+        snippet0, model0 = cont(model0, prog0)
+        # If snippet0 is None then we didn't manage to generate anything, but
+        # that's fine. model0 will be unchanged and we just have an empty
+        # sequence on the fall-through side of the branch.
 
-        snippet0, model0 = ret0
-
-        # We successfully generated the fall-through branch. Now we want to
-        # generate the other side. Make another copy of program and insert the
-        # instructions from snippet0 into it. Add the bogus instruction at
-        # model.pc, as above. Also add a bogus instruction at model0.pc: this
-        # represents "the next thing" that happens at the end of the first
-        # branch, and we mustn't allow snippet1 to use that space.
+        # Now generate the other side. Make another copy of program and insert
+        # any instructions from snippet0 into it. Add the bogus instruction at
+        # model.pc, as above. Also add a bogus instruction at model0.pc (if it
+        # doesn't equal model.pc): this represents "the next thing" that
+        # happens at the end of the first branch, and we mustn't allow snippet1
+        # to use that space.
         prog1 = program.copy()
-        snippet0.insert_into_program(prog1)
         prog1.add_insns(model.pc, [branch_insn])
-        prog1.add_insns(model0.pc, [branch_insn])
+        if snippet0 is None:
+            assert model0.pc == model.pc + 4
+        else:
+            snippet0.insert_into_program(prog1)
+            prog1.add_insns(model0.pc, [branch_insn])
 
         model1 = model.copy()
         model1.update_for_insn(branch_insn)
         model1.pc = tgt_addr
 
-        prog1.constrain_space(branch_space)
-        model1.fuel = branch_fuel
+        prog1.space = branch_space
+        model1.fuel = branch_fuel - 1
 
-        ret1 = cont(model1, prog1)
-        if ret1 is None:
-            return None
+        snippet1, model1 = cont(model1, prog1)
+        # If snippet1 is None, we didn't manage to generate anything here
+        # either. As before, that's fine.
 
-        snippet1, model1 = ret1
-
-        # We've managed to generate both sides of the branch. All that's left
-        # to do is fix up the execution paths to converge again. To do this, we
-        # need to add a jump to one side or the other. (Alternatively, we could
-        # jump from both to another address, but this shouldn't provide any
-        # extra coverage, so there's not much point)
+        # We've now generated both sides of the branch. All that's left to do
+        # is fix up the execution paths to converge again. To do this, we need
+        # to add a jump to one side or the other. (Alternatively, we could jump
+        # from both to another address, but this shouldn't provide any extra
+        # coverage, so there's not much point)
         if random.random() < 0.5:
-            # Add the jump to go from branch 0 to branch 1
+            # Add the jump to go from branch 0 to branch 1.
+            model0.fuel += 1
+            prog0.space += 1
             jump_ret = self.jump_gen.gen_tgt(model0, prog0, model1.pc)
             if jump_ret is None:
                 return None
 
             jmp_snippet, model0 = jump_ret
-            snippet0 = snippet0.merge(jmp_snippet)
+            snippet0 = Snippet.cons_option(snippet0, jmp_snippet)
         else:
             # Add the jump to go from branch 1 to branch 0
+            model1.fuel += 1
+            prog1.space += 1
             jump_ret = self.jump_gen.gen_tgt(model1, prog1, model0.pc)
             if jump_ret is None:
                 return None
 
             jmp_snippet, model1 = jump_ret
-            snippet1 = snippet1.merge(jmp_snippet)
+            snippet1 = Snippet.cons_option(snippet1, jmp_snippet)
 
         assert model0.pc == model1.pc
         model0.merge(model1)
diff --git a/hw/ip/otbn/util/rig/gens/ecall.py b/hw/ip/otbn/util/rig/gens/ecall.py
index 047baff..5ead6af 100644
--- a/hw/ip/otbn/util/rig/gens/ecall.py
+++ b/hw/ip/otbn/util/rig/gens/ecall.py
@@ -8,7 +8,7 @@
 
 from ..program import ProgInsn, Program
 from ..model import Model
-from ..snippet import ProgSnippet
+from ..snippet import ProgSnippet, Snippet
 from ..snippet_gen import GenCont, GenRet, SnippetGen
 
 
@@ -27,25 +27,25 @@
 
         self.insn = ProgInsn(ecall_insn, [], None)
 
+    def gen_at(self, pc: int, program: Program) -> Snippet:
+        '''Generate an ECALL instruction at pc and insert into program'''
+        snippet = ProgSnippet(pc, [self.insn])
+        snippet.insert_into_program(program)
+        return snippet
+
     def gen(self,
             cont: GenCont,
             model: Model,
             program: Program) -> Optional[GenRet]:
-        snippet = ProgSnippet(model.pc, [self.insn])
-        snippet.insert_into_program(program)
-        return (snippet, None)
+        return (self.gen_at(model.pc, program), None)
 
     def pick_weight(self,
                     model: Model,
                     program: Program) -> float:
         # Choose small weights when we've got lots of room and large ones when
         # we haven't.
-        fuel = model.fuel
-        space = program.get_insn_space_left()
-        assert fuel > 0
-        assert space > 0
-
-        room = min(fuel, space)
+        room = min(model.fuel, program.space)
+        assert 0 < room
         return (1e-10 if room > 5
                 else 0.1 if room > 1
                 else 1e10)
diff --git a/hw/ip/otbn/util/rig/model.py b/hw/ip/otbn/util/rig/model.py
index f476a06..6f4ee30 100644
--- a/hw/ip/otbn/util/rig/model.py
+++ b/hw/ip/otbn/util/rig/model.py
@@ -788,8 +788,9 @@
             self.touch_mem(mem_type, addr, insn.lsu.idx_width)
 
     def consume_fuel(self) -> None:
-        '''Consume one item of fuel, but bottom out at fuel == 1'''
-        self.fuel = max(1, self.fuel - 1)
+        '''Consume one item of fuel'''
+        assert self.fuel > 0
+        self.fuel -= 1
 
     def update_for_insn(self, prog_insn: ProgInsn) -> None:
         # If this is a sufficiently simple operation that we understand the
diff --git a/hw/ip/otbn/util/rig/program.py b/hw/ip/otbn/util/rig/program.py
index 0988968..4337919 100644
--- a/hw/ip/otbn/util/rig/program.py
+++ b/hw/ip/otbn/util/rig/program.py
@@ -167,7 +167,7 @@
         # The number of instructions' space available. If we aren't below any
         # branches, this is the space available in imem. When we're branching,
         # this might be less.
-        self._space = self.imem_size // 4
+        self.space = self.imem_size // 4
 
     def copy(self) -> 'Program':
         '''Return a a shallow copy of the program
@@ -180,7 +180,7 @@
                       self.dmem_lma, self.dmem_size)
         ret._sections = {base: section.copy()
                          for base, section in self._sections.items()}
-        ret._space = self._space
+        ret.space = self.space
         return ret
 
     def add_insns(self, addr: int, insns: List[ProgInsn]) -> None:
@@ -188,8 +188,8 @@
         assert addr & 3 == 0
         assert addr <= self.imem_size
 
-        assert len(insns) <= self._space
-        self._space -= len(insns)
+        assert len(insns) <= self.space
+        self.space -= len(insns)
 
         sec_top = addr + 4 * len(insns)
 
@@ -498,7 +498,12 @@
         return tgts[0]
 
     def get_insn_space_at(self, addr: int) -> int:
-        '''Return how many instructions there is space for, starting at addr'''
+        '''Return how many instructions there is space for, starting at addr
+
+        Note that this doesn't take the global space constraint into
+        account.
+
+        '''
         space = self.imem_size - addr
         if space <= 0:
             return 0
@@ -511,12 +516,3 @@
                     return 0
 
         return max(0, space // 4)
-
-    def get_insn_space_left(self) -> int:
-        '''Return how many more instructions there is space for'''
-        return self._space
-
-    def constrain_space(self, space: int) -> None:
-        '''Constrain the amount of space available'''
-        assert space <= self._space
-        self._space = space
diff --git a/hw/ip/otbn/util/rig/rig.py b/hw/ip/otbn/util/rig/rig.py
index e490f83..5f6ecb6 100644
--- a/hw/ip/otbn/util/rig/rig.py
+++ b/hw/ip/otbn/util/rig/rig.py
@@ -50,8 +50,5 @@
     for addr in init_data.keys():
         model.touch_mem('dmem', addr, 4)
 
-    ret = SnippetGens(insns_file).gens(model, program, True)
-    assert ret is not None
-    snippet, _ = ret
-
+    snippet = SnippetGens(insns_file).gen_rest(model, program)
     return init_data, snippet
diff --git a/hw/ip/otbn/util/rig/snippet.py b/hw/ip/otbn/util/rig/snippet.py
index f283794..744320f 100644
--- a/hw/ip/otbn/util/rig/snippet.py
+++ b/hw/ip/otbn/util/rig/snippet.py
@@ -221,8 +221,9 @@
     def __init__(self,
                  addr: int,
                  branch_insn: ProgInsn,
-                 snippet0: Snippet,
-                 snippet1: Snippet):
+                 snippet0: Optional[Snippet],
+                 snippet1: Optional[Snippet]):
+        assert snippet0 is not None or snippet1 is not None
         self.addr = addr
         self.branch_insn = branch_insn
         self.snippet0 = snippet0
@@ -230,16 +231,18 @@
 
     def insert_into_program(self, program: Program) -> None:
         program.add_insns(self.addr, [self.branch_insn])
-        self.snippet0.insert_into_program(program)
+        if self.snippet0 is not None:
+            self.snippet0.insert_into_program(program)
         if self.snippet1 is not None:
             self.snippet1.insert_into_program(program)
 
     def to_json(self) -> object:
+        js0 = None if self.snippet0 is None else self.snippet0.to_json()
         js1 = None if self.snippet1 is None else self.snippet1.to_json()
         return ['BS',
                 self.addr,
                 self.branch_insn.to_json(),
-                self.snippet0.to_json(),
+                js0,
                 js1]
 
     @staticmethod
@@ -259,7 +262,13 @@
         bi_where = 'branch instruction for snippet {}'.format(idx)
         branch_insn = ProgInsn.from_json(insns_file, bi_where, j_branch_insn)
 
-        snippet0 = Snippet.from_json(insns_file, idx + [0], j_snippet0)
-        snippet1 = Snippet.from_json(insns_file, idx + [1], j_snippet1)
+        snippet0 = (None if j_snippet0 is None
+                    else Snippet.from_json(insns_file, idx + [0], j_snippet0))
+        snippet1 = (None if j_snippet1 is None
+                    else Snippet.from_json(insns_file, idx + [1], j_snippet1))
+
+        if snippet0 is None and snippet1 is None:
+            raise ValueError('Both sides of branch snippet {} are None.'
+                             .format(idx))
 
         return BranchSnippet(addr, branch_insn, snippet0, snippet1)
diff --git a/hw/ip/otbn/util/rig/snippet_gen.py b/hw/ip/otbn/util/rig/snippet_gen.py
index 50d736d..5d175ba 100644
--- a/hw/ip/otbn/util/rig/snippet_gen.py
+++ b/hw/ip/otbn/util/rig/snippet_gen.py
@@ -17,10 +17,6 @@
 from .model import Model
 from .snippet import Snippet
 
-# A continuation type that allows a generator to recursively generate some more
-# stuff.
-GenCont = Callable[[Model, Program], Optional[Tuple[Snippet, Model]]]
-
 # The return type of a single generator. This is a tuple (snippet, model).
 # snippet is a generated snippet. If the program is done (i.e. every execution
 # ends with ecall) then model is None. Otherwise it is a Model object
@@ -28,6 +24,14 @@
 # snippet(s).
 GenRet = Tuple[Snippet, Optional[Model]]
 
+# The return type of repeated generator calls. If the snippet is None, no
+# generators managed to generate anything.
+GensRet = Tuple[Optional[Snippet], Model]
+
+# A continuation type that allows a generator to recursively generate some more
+# stuff.
+GenCont = Callable[[Model, Program], GensRet]
+
 
 class SnippetGen:
     '''A parameterised sequence of instructions
@@ -46,12 +50,9 @@
         and returns a GenRet tuple. See comment above the type definition for
         more information.
 
-        On failure, leaves program and model unchanged and returns None. There
-        should always be at least one snippet generator with positive weight
-        (see pick_weight below) that succeeds unconditionally. This will be the
-        ecall generator. Failure is interpreted as "this snippet won't work
-        with the current program state", but the generator may be retried
-        later.
+        On failure, leaves program and model unchanged and returns None.
+        Failure is interpreted as "this snippet won't work with the current
+        program state", but the generator may be retried later.
 
         The cont argument is a continuation, used to call out to more
         generators in order to do recursive generation. It takes a (mutable)
@@ -59,6 +60,9 @@
         through the generated code don't terminate with an ECALL but instead
         end up at the resulting model.pc.
 
+        This will only be called when model.fuel > 0 and
+        program.get_insn_space_at(model.pc) > 0.
+
         '''
         raise NotImplementedError('gen not implemented by subclass')
 
@@ -73,10 +77,13 @@
         small, for example).
 
         It can also be used to alter weights depending on where we are in the
-        program. For example, a generator that generates ecall to end the
+        program. For example, a generator that generates ECALL to end the
         program could decrease its weight when size is large, to avoid
         generating tiny programs by accident.
 
+        This will only be called when model.fuel > 0 and
+        program.get_insn_space_at(model.pc) > 0.
+
         The default implementation always returns 1.0.
 
         '''
diff --git a/hw/ip/otbn/util/rig/snippet_gens.py b/hw/ip/otbn/util/rig/snippet_gens.py
index 3293eff..52d2d84 100644
--- a/hw/ip/otbn/util/rig/snippet_gens.py
+++ b/hw/ip/otbn/util/rig/snippet_gens.py
@@ -32,64 +32,84 @@
         for cls, weight in SnippetGens._WEIGHTED_CLASSES:
             self.generators.append((cls(insns_file), weight))
 
+        # Grab an ECall generator. We'll use it in self.gens to append an ECALL
+        # instruction if necessary.
+        ecall = None
+        for gen, _ in self.generators:
+            if isinstance(gen, ECall):
+                ecall = gen
+                break
+        assert ecall is not None
+        assert isinstance(ecall, ECall)
+        self.ecall = ecall
+
     def gen(self,
             model: Model,
             program: Program,
             ecall: bool) -> Optional[GenRet]:
         '''Pick a snippet and update model, program with its contents.
 
+        This assumes that program.get_insn_space_at(model.pc) > 0.
+
         Normally returns a GenRet tuple with the same meanings as Snippet.gen.
         If the chosen snippet would generate an ECALL and ecall is False, this
         instead returns None (and leaves model and program unchanged).
 
         '''
+        # If we've run out of fuel, stop immediately.
+        if not model.fuel:
+            return None
+
         real_weights = []
+        pos_weights = 0
         for generator, weight in self.generators:
-            weight_mult = generator.pick_weight(model, program)
-            real_weights.append(weight * weight_mult)
+            if isinstance(generator, ECall) and not ecall:
+                real_weight = 0.0
+            else:
+                weight_mult = generator.pick_weight(model, program)
+                real_weight = weight * weight_mult
 
-        # Define a continuation (which basically just calls self.gens()) to
-        # pass to each generator. This allows recursive generation and avoids
-        # needing circular imports to get the types right.
-        def cont(md: Model, prg: Program) -> Optional[Tuple[Snippet, Model]]:
-            ret = self.gens(md, prg, False)
-            if ret is None:
-                return None
-            snippet, model = ret
-            # We should always have a Model returned (because the ecall
-            # argument was False)
-            assert model is not None
-            return (snippet, model)
+            assert real_weight >= 0
+            if real_weight > 0:
+                pos_weights += 1
 
-        while True:
+            real_weights.append(real_weight)
+
+        while pos_weights > 0:
             # Pick a generator based on the weights in real_weights.
             idx = random.choices(range(len(self.generators)),
                                  weights=real_weights)[0]
             generator, _ = self.generators[idx]
 
-            # Note that there should always be at least one non-zero weight in
+            # Note that there should always be at least one positive weight in
             # real_weights. random.choices doesn't check that: if you pass all
             # weights equal to zero, it always picks the last element. Since
             # that would cause an infinite loop, add a check here to make sure
             # that the choice we made had positive weight.
             assert real_weights[idx] > 0
 
-            if isinstance(generator, ECall) and not ecall:
-                return None
-
             # Run the generator to generate a snippet
-            gen_res = generator.gen(cont, model, program)
+            gen_res = generator.gen(self.gens, model, program)
             if gen_res is not None:
                 return gen_res
 
             # If gen_res is None, the generator failed. Set that weight to zero
             # and try again.
             real_weights[idx] = 0.0
+            pos_weights -= 1
 
-    def gens(self,
-             model: Model,
-             program: Program,
-             ecall: bool) -> Optional[GenRet]:
+        # We ran out of generators with positive weight. Give up.
+        return None
+
+    def _gen_ecall(self, pc: int, program: Program) -> Snippet:
+        '''Generate an ECALL instruction at pc, ignoring notions of fuel'''
+        assert program.get_insn_space_at(pc) > 0
+        return self.ecall.gen_at(pc, program)
+
+    def _gens(self,
+              model: Model,
+              program: Program,
+              ecall: bool) -> Tuple[List[Snippet], Optional[Model]]:
         '''Generate some snippets to continue program.
 
         This will try to run down model.fuel and program.size. If ecall is
@@ -103,23 +123,65 @@
         next_model = model  # type: Optional[Model]
         while True:
             assert next_model is not None
+
+            must_stop = False
+            # If we've run out of space and ecall is False, we stop
+            # immediately. If ecall is True, we need to generate one last ECALL
+            # instruction.
+            if not program.space:
+                if ecall:
+                    must_stop = True
+                else:
+                    break
+
             old_fuel = next_model.fuel
-            gr = self.gen(next_model, program, ecall)
-            if gr is None:
-                assert ecall is False
+            gen_res = None
+            if not must_stop:
+                gen_res = self.gen(next_model, program, ecall)
+
+            if gen_res is None:
+                # We failed to generate another snippet. If ecall is False,
+                # that's fine: we've probably run out of fuel and should stop.
+                # If ecall is True, that's bad news: we can't just leave the
+                # program unfinished. In that case, force an ECALL instruction.
+                if ecall:
+                    children.append(self._gen_ecall(next_model.pc, program))
+                    next_model = None
                 break
 
-            snippet, next_model = gr
+            snippet, next_model = gen_res
             children.append(snippet)
 
             if next_model is None:
+                assert ecall
                 break
 
             assert next_model.fuel < old_fuel
 
-        if not children:
-            assert ecall is False
-            return None
+        return (children, next_model)
 
-        snippet = Snippet.merge_list(children)
+    def gens(self,
+             model: Model,
+             program: Program) -> Tuple[Optional[Snippet], Model]:
+        '''Try to generate snippets to continue program
+
+        This will try to run down model.fuel and program.size. When it runs out
+        of one or the other, it stops and returns any snippet it generated plus
+        the updated model.
+
+        '''
+        snippets, next_model = self._gens(model, program, False)
+        # _gens() only sets next_model to None if ecall is True.
+        assert next_model is not None
+        snippet = Snippet.merge_list(snippets) if snippets else None
         return (snippet, next_model)
+
+    def gen_rest(self, model: Model, program: Program) -> Snippet:
+        '''Generate the rest of the program, ending with an ECALL'''
+        snippets, next_model = self._gens(model, program, True)
+        # Since _gens() has finished with an ECALL, it always returns None for
+        # next_model. It also always generates at least one snippet (containing
+        # the ECALL).
+        assert next_model is None
+        assert snippets
+        return Snippet.merge_list(snippets)