[otbn] Teach RIG to generate BEQ and BNE instructions
The branching program flow means that we have to teach the RIG some
new tricks. Firstly, we need to add "recursive" generation (so that
the generator for a branch can generate stuff in each side below it).
We do this by passing a continuation of type GenCont to the generator.
Secondly, we need to tighten up the snippet type a bit. Branches are
represented as branches in the expression tree (as you'd expect), and
we also needed to add a sequencing type, SeqSnippet (so that you can
have a list of instructions followed by a branch, for example).
We always generate branches that converge (so there will only ever be
one ECALL in the generated program). To glue the branches back
together, we generate a jump from the end of one to the end of the
other. This needs slight changes in the Jump generator so that we can
specify the target address that we want in this case.
The final big change is that we need to add support for "φ nodes" to
the model, to cope with when execution joins back together. This is
the "merge" method added to Model, CallStack and KnownMem.
Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/otbn-rig b/hw/ip/otbn/util/otbn-rig
index 0f0516b..94145d8 100755
--- a/hw/ip/otbn/util/otbn-rig
+++ b/hw/ip/otbn/util/otbn-rig
@@ -19,7 +19,7 @@
sys.path.append(os.path.dirname(__file__))
from rig.init_data import InitData # noqa: E402
-from rig.rig import gen_program, snippets_to_program # noqa: E402
+from rig.rig import gen_program # noqa: E402
from rig.snippet import Snippet # noqa: E402
@@ -41,12 +41,12 @@
return 1
# Run the generator
- init_data, snippets = gen_program(args.start_addr, args.size, insns_file)
+ init_data, snippet = gen_program(args.start_addr, args.size, insns_file)
# Write out the data and snippets to a JSON file
ser_data = init_data.as_json()
- ser_snippets = [snippet.to_json() for snippet in snippets]
- ser = [ser_data, ser_snippets]
+ ser_snippet = snippet.to_json()
+ ser = [ser_data, ser_snippet]
try:
if args.output == '-':
json.dump(ser, sys.stdout)
@@ -86,17 +86,16 @@
if not (isinstance(json_data, list) and len(json_data) == 2):
raise ValueError('Top-level structure should be a length 2 list.')
- json_init_data, json_snippets = json_data
+ json_init_data, json_snippet = json_data
init_data = InitData.read(json_init_data)
- snippets = [Snippet.from_json(insns_file, idx, x)
- for idx, x in enumerate(json_snippets)]
+ snippet = Snippet.from_json(insns_file, [], json_snippet)
except ValueError as err:
print('Failed to parse snippets from {!r}: {}'
.format(args.snippets.name, err),
file=sys.stderr)
return 1
- program = snippets_to_program(snippets)
+ program = snippet.to_program()
dsegs = init_data.as_segs()
# Dump the assembly output, and the linker script too if we're writing to
diff --git a/hw/ip/otbn/util/rig/gens/branch.py b/hw/ip/otbn/util/rig/gens/branch.py
new file mode 100644
index 0000000..f3705e2
--- /dev/null
+++ b/hw/ip/otbn/util/rig/gens/branch.py
@@ -0,0 +1,238 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+from typing import Optional, Sequence, Tuple
+
+from shared.insn_yaml import InsnsFile
+from shared.operand import ImmOperandType, RegOperandType
+
+from .jump import Jump
+from ..program import ProgInsn, Program
+from ..model import Model
+from ..snippet import BranchSnippet, ProgSnippet, SeqSnippet
+from ..snippet_gen import GenCont, GenRet, SnippetGen
+
+
+class Branch(SnippetGen):
+ '''A generator that makes a snippet with a BEQ or BNE branch'''
+ def __init__(self, insns_file: InsnsFile) -> None:
+ self.jump_gen = Jump(insns_file)
+ self.beq = self._get_named_insn(insns_file, 'beq')
+ self.bne = self._get_named_insn(insns_file, 'bne')
+
+ # beq and bne expect operands: grs1, grs2, offset
+ for insn in [self.beq, self.bne]:
+ if not (len(insn.operands) == 3 and
+ isinstance(insn.operands[0].op_type, RegOperandType) and
+ insn.operands[0].op_type.reg_type == 'gpr' and
+ not insn.operands[0].op_type.is_dest() and
+ isinstance(insn.operands[1].op_type, RegOperandType) and
+ insn.operands[1].op_type.reg_type == 'gpr' and
+ not insn.operands[1].op_type.is_dest() and
+ isinstance(insn.operands[2].op_type, ImmOperandType) and
+ insn.operands[2].op_type.signed):
+ raise RuntimeError('{} instruction from instructions file is not '
+ 'the shape expected by the Branch generator.'
+ .format(insn.mnemonic))
+
+ _FloatRng = Tuple[float, float]
+ _WeightedFloatRng = Tuple[float, _FloatRng]
+
+ @staticmethod
+ def pick_from_weighted_ranges(r: Sequence[_WeightedFloatRng]) -> float:
+ ff0, ff1 = random.choices([rng for _, rng in r],
+ weights=[w for w, _ in r])[0]
+ return random.uniform(ff0, ff1)
+
+ def _pick_tgt_addr(self,
+ pc: int,
+ off_min: int,
+ off_max: int,
+ program: Program) -> Optional[int]:
+ '''Pick the target address for a branch'''
+ # Make sure we can cover the case where we branch conditionally to
+ # PC+4, which is probably quite unlikely otherwise, by giving at 1%
+ # chance.
+ #
+ # We'll need at least 4 instructions' space for a proper branch: the
+ # branch instruction, the fall-through instruction, the branch target
+ # (which will jump back if necessary), and an eventual ECALL)
+ if program.get_insn_space_left() < 4:
+ fall_thru = True
+ else:
+ fall_thru = random.random() < 0.01
+
+ return (pc + 4 if fall_thru else
+ program.pick_branch_target(pc, 1, off_min, off_max))
+
+ def gen(self,
+ cont: GenCont,
+ model: Model,
+ program: Program) -> Optional[GenRet]:
+
+ if model.fuel <= 1:
+ # The shortest possible branch sequence (branch to PC + 4) takes an
+ # instruction and needs at least one instruction afterwards for the
+ # ECALL, so don't generate anything if fuel is less than 2.
+ return None
+
+ # Return None if this is the last instruction in the current gap
+ # because we need to either jump or do an ECALL to avoid getting stuck
+ # (just like the StraightLineInsn generator)
+ if program.get_insn_space_at(model.pc) <= 1:
+ return None
+
+ # Decide whether to generate BEQ or BNE. In the future, we'll load
+ # this weighting from somewhere else.
+ beq_weight = 1.0
+ bne_weight = 1.0
+ sum_weights = beq_weight + bne_weight
+ is_beq = random.random() < beq_weight / sum_weights
+
+ insn = self.beq if is_beq else self.bne
+ grs1_op, grs2_op, off_op = insn.operands
+
+ assert isinstance(off_op.op_type, ImmOperandType)
+
+ # Calculate the range of target addresses we can encode (this includes
+ # any PC-relative adjustment)
+ off_rng = off_op.op_type.get_op_val_range(model.pc)
+ assert off_rng is not None
+ off_min, off_max = off_rng
+
+ # Pick the source GPRs that we're comparing.
+ assert isinstance(grs1_op.op_type, RegOperandType)
+ assert isinstance(grs2_op.op_type, RegOperandType)
+ grs1 = model.pick_reg_operand_value(grs1_op.op_type)
+ grs2 = model.pick_reg_operand_value(grs2_op.op_type)
+ if grs1 is None or grs2 is None:
+ return None
+
+ tgt_addr = self._pick_tgt_addr(model.pc, off_min, off_max, program)
+ if tgt_addr is None:
+ return None
+
+ assert off_min <= tgt_addr <= off_max
+
+ off_enc = off_op.op_type.op_val_to_enc_val(tgt_addr, model.pc)
+ assert off_enc is not None
+
+ branch_insn = ProgInsn(insn, [grs1, grs2, off_enc], None)
+
+ if tgt_addr == model.pc + 4:
+ # If tgt_addr equals model.pc + 4, this actually behaves like a
+ # straight-line instruction! Add the branch instruction, update the
+ # model, and return.
+ psnip = ProgSnippet(model.pc, [branch_insn])
+ psnip.insert_into_program(program)
+ model.update_for_insn(branch_insn)
+ model.pc += 4
+ return (psnip, model)
+
+ # Decide how much of our remaining fuel to give the code below the
+ # branch. Each side gets the same amount because only one side appears
+ # in the instruction stream.
+ fuel_frac_ranges = [(1, (0, 0.1)),
+ (10, (0.1, 0.5)),
+ (1, (0.5, 1.0))]
+ fuel_frac = self.pick_from_weighted_ranges(fuel_frac_ranges)
+ assert 0 <= fuel_frac <= 1
+ branch_fuel = max(1, int(0.5 + fuel_frac * model.fuel))
+
+ # Similarly, decide how much of our remaining space to give the code
+ # below the branch. Unlike with the fuel, we halve the result for each
+ # side (since each side of the branch consumes instruction space)
+ space_frac_ranges = fuel_frac_ranges
+ space_frac = self.pick_from_weighted_ranges(space_frac_ranges)
+ assert 0 <= space_frac <= 1
+ # Subtract 2: one for the branch instruction and one for an eventual
+ # ECALL. We checked earlier we had at least 4 instructions' space left,
+ # so there should always be at least 2 instructions' space left
+ # afterwards.
+ max_space_for_branches = program.get_insn_space_left() - 2
+ assert max_space_for_branches >= 2
+ branch_space = max(1, int(space_frac * (max_space_for_branches / 2)))
+ assert 2 * branch_space <= max_space_for_branches
+
+ # Make an updated copy of program that includes the branch instruction.
+ # Similarly, take a copy of the model and update it as if we've fallen
+ # through the branch instruction. Note that we can't just modify
+ # program or model here because generation might fail.
+ #
+ # Insert a bogus instruction at tgt_addr into prog0. This represents
+ # the first instruction on the other side of the branch: we need to do
+ # this to avoid both sides of the branch trying to put an instruction
+ # there.
+ prog0 = program.copy()
+ prog0.add_insns(model.pc, [branch_insn])
+ model0 = model.copy()
+ model0.update_for_insn(branch_insn)
+
+ model0.pc += 4
+ prog0.add_insns(tgt_addr, [branch_insn])
+
+ model0.fuel = branch_fuel
+ prog0.constrain_space(branch_space)
+
+ ret0 = cont(model0, prog0)
+ if ret0 is None:
+ return None
+
+ snippet0, model0 = ret0
+
+ # We successfully generated the fall-through branch. Now we want to
+ # generate the other side. Make another copy of program and insert the
+ # instructions from snippet0 into it. Add the bogus instruction at
+ # model.pc, as above. Also add a bogus instruction at model0.pc: this
+ # represents "the next thing" that happens at the end of the first
+ # branch, and we mustn't allow snippet1 to use that space.
+ prog1 = program.copy()
+ snippet0.insert_into_program(prog1)
+ prog1.add_insns(model.pc, [branch_insn])
+ prog1.add_insns(model0.pc, [branch_insn])
+
+ model1 = model.copy()
+ model1.update_for_insn(branch_insn)
+ model1.pc = tgt_addr
+
+ prog1.constrain_space(branch_space)
+ model1.fuel = branch_fuel
+
+ ret1 = cont(model1, prog1)
+ if ret1 is None:
+ return None
+
+ snippet1, model1 = ret1
+
+ # We've managed to generate both sides of the branch. All that's left
+ # to do is fix up the execution paths to converge again. To do this, we
+ # need to add a jump to one side or the other. (Alternatively, we could
+ # jump from both to another address, but this shouldn't provide any
+ # extra coverage, so there's not much point)
+ if random.random() < 0.5:
+ # Add the jump to go from branch 0 to branch 1
+ jump_ret = self.jump_gen.gen_tgt(model0, prog0, model1.pc)
+ if jump_ret is None:
+ return None
+
+ jmp_snippet, model0 = jump_ret
+ if not snippet0.merge(jmp_snippet):
+ snippet0 = SeqSnippet([snippet0, jmp_snippet])
+ else:
+ # Add the jump to go from branch 1 to branch 0
+ jump_ret = self.jump_gen.gen_tgt(model1, prog1, model0.pc)
+ if jump_ret is None:
+ return None
+
+ jmp_snippet, model1 = jump_ret
+ if not snippet1.merge(jmp_snippet):
+ snippet1 = SeqSnippet([snippet1, jmp_snippet])
+
+ assert model0.pc == model1.pc
+ model0.merge(model1)
+
+ snippet = BranchSnippet(model.pc, branch_insn, snippet0, snippet1)
+ snippet.insert_into_program(program)
+ return (snippet, model0)
diff --git a/hw/ip/otbn/util/rig/gens/ecall.py b/hw/ip/otbn/util/rig/gens/ecall.py
index 42085a1..047baff 100644
--- a/hw/ip/otbn/util/rig/gens/ecall.py
+++ b/hw/ip/otbn/util/rig/gens/ecall.py
@@ -2,14 +2,14 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
from shared.insn_yaml import InsnsFile
from ..program import ProgInsn, Program
from ..model import Model
-from ..snippet import Snippet
-from ..snippet_gen import SnippetGen
+from ..snippet import ProgSnippet
+from ..snippet_gen import GenCont, GenRet, SnippetGen
class ECall(SnippetGen):
@@ -28,20 +28,24 @@
self.insn = ProgInsn(ecall_insn, [], None)
def gen(self,
- size: int,
+ cont: GenCont,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
- snippet = Snippet([(model.pc, [self.insn])])
+ program: Program) -> Optional[GenRet]:
+ snippet = ProgSnippet(model.pc, [self.insn])
snippet.insert_into_program(program)
- return (snippet, True, 0)
+ return (snippet, None)
def pick_weight(self,
- size: int,
model: Model,
program: Program) -> float:
- # Choose small weights when size is large and large ones when it's
- # small.
- assert size > 0
- return (1e-10 if size > 5
- else 0.1 if size > 1
+ # Choose small weights when we've got lots of room and large ones when
+ # we haven't.
+ fuel = model.fuel
+ space = program.get_insn_space_left()
+ assert fuel > 0
+ assert space > 0
+
+ room = min(fuel, space)
+ return (1e-10 if room > 5
+ else 0.1 if room > 1
else 1e10)
diff --git a/hw/ip/otbn/util/rig/gens/jump.py b/hw/ip/otbn/util/rig/gens/jump.py
index b8fbbda..68ac7fd 100644
--- a/hw/ip/otbn/util/rig/gens/jump.py
+++ b/hw/ip/otbn/util/rig/gens/jump.py
@@ -10,8 +10,8 @@
from ..program import ProgInsn, Program
from ..model import Model
-from ..snippet import Snippet
-from ..snippet_gen import SnippetGen
+from ..snippet import ProgSnippet, Snippet
+from ..snippet_gen import GenCont, GenRet, SnippetGen
class Jump(SnippetGen):
@@ -49,9 +49,15 @@
self.jalr = jalr
def gen(self,
- size: int,
+ cont: GenCont,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
+ program: Program) -> Optional[GenRet]:
+ return self.gen_tgt(model, program, None)
+
+ def gen_tgt(self,
+ model: Model,
+ program: Program,
+ tgt_addr: Optional[int]) -> Optional[Tuple[Snippet, Model]]:
# Decide whether to generate JALR or JAL. In the future, we'll load
# this weighting from somewhere else.
@@ -64,17 +70,23 @@
# wrapper will disable us entirely this time around.
is_jalr = random.random() < jalr_weight / sum_weights
if is_jalr:
- ret = self.gen_jalr(size, model, program)
- if ret is not None:
- return ret
+ ret = self.gen_jalr(model, program, tgt_addr)
+ else:
+ ret = self.gen_jal(model, program, tgt_addr)
- return self.gen_jal(size, model, program)
+ if ret is None:
+ return None
+ else:
+ snippet, new_model = ret
+ assert new_model is not None
+ return (snippet, new_model)
def _pick_jump(self,
base_addr: int,
imm_optype: ImmOperandType,
model: Model,
- program: Program) -> Optional[Tuple[int, int, int]]:
+ program: Program,
+ tgt_addr: Optional[int]) -> Optional[Tuple[int, int, int]]:
'''Pick target and link register for a jump instruction
For a JALR instruction, base_addr is the address stored in the register
@@ -102,6 +114,15 @@
tgt_min = imm_min + base_addr
tgt_max = imm_max + base_addr
+ # If there is a desired target, check it's representable. If not,
+ # return None. Otherwise, narrow the range to just that.
+ if tgt_addr is not None:
+ if tgt_min <= tgt_addr <= tgt_max:
+ tgt_min = tgt_addr
+ tgt_max = tgt_addr
+ else:
+ return None
+
# Pick a branch target. "1" here is the minimum number of instructions
# that must fit. One is enough (we'll just end up generating another
# branch immediately)
@@ -128,12 +149,11 @@
prog_insn: ProgInsn,
link_reg_idx: int,
new_pc: int,
- size: int,
model: Model,
- program: Program) -> Tuple[Snippet, bool, int]:
+ program: Program) -> GenRet:
'''Generate a 1-instruction snippet for prog_insn; finish generation'''
# Generate our one-instruction snippet and add it to the program
- snippet = Snippet([(model.pc, [prog_insn])])
+ snippet = ProgSnippet(model.pc, [prog_insn])
snippet.insert_into_program(program)
# Update the model with the instruction
@@ -152,31 +172,30 @@
# And update the PC, which is now tgt
model.pc = new_pc
- return (snippet, False, size - 1)
+ return (snippet, model)
def gen_jal(self,
- size: int,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
+ program: Program,
+ tgt_addr: Optional[int]) -> Optional[GenRet]:
'''Generate a random JAL instruction'''
assert len(self.jal.operands) == 2
offset_optype = self.jal.operands[1].op_type
assert isinstance(offset_optype, ImmOperandType)
- jmp_data = self._pick_jump(0, offset_optype, model, program)
+ jmp_data = self._pick_jump(0, offset_optype, model, program, tgt_addr)
if jmp_data is None:
return None
tgt, enc_offset, link_reg_idx = jmp_data
prog_insn = ProgInsn(self.jal, [link_reg_idx, enc_offset], None)
- return self._add_snippet(prog_insn, link_reg_idx, tgt,
- size, model, program)
+ return self._add_snippet(prog_insn, link_reg_idx, tgt, model, program)
def gen_jalr(self,
- size: int,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
+ program: Program,
+ tgt_addr: Optional[int]) -> Optional[GenRet]:
'''Generate a random JALR instruction'''
assert len(self.jalr.operands) == 3
@@ -192,7 +211,8 @@
base_reg_idx, base_reg_val = random.choice(known_regs)
- jmp_data = self._pick_jump(base_reg_val, offset_optype, model, program)
+ jmp_data = self._pick_jump(base_reg_val, offset_optype,
+ model, program, tgt_addr)
if jmp_data is None:
return None
@@ -202,4 +222,4 @@
[link_reg_idx, base_reg_idx, enc_offset],
None)
return self._add_snippet(prog_insn, link_reg_idx, tgt,
- size, model, program)
+ model, program)
diff --git a/hw/ip/otbn/util/rig/gens/straight_line_insn.py b/hw/ip/otbn/util/rig/gens/straight_line_insn.py
index 1d8e838..b6dc721 100644
--- a/hw/ip/otbn/util/rig/gens/straight_line_insn.py
+++ b/hw/ip/otbn/util/rig/gens/straight_line_insn.py
@@ -11,8 +11,8 @@
from ..program import ProgInsn, Program
from ..model import Model
-from ..snippet import Snippet
-from ..snippet_gen import SnippetGen
+from ..snippet import ProgSnippet
+from ..snippet_gen import GenCont, GenRet, SnippetGen
class StraightLineInsn(SnippetGen):
@@ -32,9 +32,9 @@
self.insns.append(insn)
def gen(self,
- size: int,
+ cont: GenCont,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
+ program: Program) -> Optional[GenRet]:
# Return None if this is the last instruction in the current gap
# because we need to either jump or do an ECALL to avoid getting stuck.
@@ -68,14 +68,14 @@
# Success! We have generated an instruction. Put it in a snippet and
# add that to the program
- snippet = Snippet([(model.pc, [prog_insn])])
+ snippet = ProgSnippet(model.pc, [prog_insn])
snippet.insert_into_program(program)
# Then update the model with the instruction and update the model PC
model.update_for_insn(prog_insn)
model.pc += 4
- return (snippet, False, size - 1)
+ return (snippet, model)
def fill_insn(self, insn: Insn, model: Model) -> Optional[ProgInsn]:
'''Try to fill out an instruction
diff --git a/hw/ip/otbn/util/rig/model.py b/hw/ip/otbn/util/rig/model.py
index 89f7dfe..130ea63 100644
--- a/hw/ip/otbn/util/rig/model.py
+++ b/hw/ip/otbn/util/rig/model.py
@@ -37,6 +37,28 @@
return (r, s, t)
+def intersect_ranges(a: List[Tuple[int, int]],
+ b: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+ ret = []
+ paired = ([(r, False) for r in a] + [(r, True) for r in b])
+ arng = None # type: Optional[Tuple[int, int]]
+ brng = None # type: Optional[Tuple[int, int]]
+ for (lo, hi), is_b in sorted(paired):
+ if is_b:
+ if arng is not None:
+ a0, a1 = arng
+ if a0 <= hi and lo <= a1:
+ ret.append((max(a0, lo), min(a1, hi)))
+ brng = (lo, hi)
+ else:
+ if brng is not None:
+ b0, b1 = brng
+ if b0 <= hi and lo <= b1:
+ ret.append((max(lo, b0), min(hi, b1)))
+ arng = (lo, hi)
+ return ret
+
+
class KnownMem:
'''A representation of what memory/CSRs have architectural values'''
def __init__(self, top_addr: int):
@@ -47,6 +69,18 @@
# then each byte in the address range {lo..hi - 1} has a known value.
self.known_ranges = [] # type: List[Tuple[int, int]]
+ def copy(self) -> 'KnownMem':
+ '''Return a shallow copy of the object'''
+ ret = KnownMem(self.top_addr)
+ ret.known_ranges = self.known_ranges.copy()
+ return ret
+
+ def merge(self, other: 'KnownMem') -> None:
+ '''Merge in values from another KnownMem object'''
+ assert self.top_addr == other.top_addr
+ self.known_ranges = intersect_ranges(self.known_ranges,
+ other.known_ranges)
+
def touch_range(self, base: int, width: int) -> None:
'''Mark {base .. base + width - 1} as known'''
assert 0 <= width
@@ -314,6 +348,82 @@
return addr
+class CallStack:
+ '''An abstract model of the x1 call stack'''
+ def __init__(self) -> None:
+ self._min_depth = 0
+ self._max_depth = 0
+ self._elts_at_top = [] # type: List[Optional[int]]
+
+ def copy(self) -> 'CallStack':
+ '''Return a deep copy of the call stack'''
+ ret = CallStack()
+ ret._min_depth = self._min_depth
+ ret._max_depth = self._max_depth
+ ret._elts_at_top = self._elts_at_top.copy()
+ return ret
+
+ def merge(self, other: 'CallStack') -> None:
+ self._min_depth = min(self._min_depth, other._min_depth)
+ self._max_depth = max(self._max_depth, other._max_depth)
+ new_top = []
+ for a, b in zip(reversed(self._elts_at_top),
+ reversed(other._elts_at_top)):
+ if a == b:
+ new_top.append(a)
+ else:
+ break
+ new_top.reverse()
+ self._elts_at_top = new_top
+ assert self._min_depth <= self._max_depth
+ assert len(self._elts_at_top) <= self._max_depth
+
+ def empty(self) -> bool:
+ assert 0 <= self._min_depth
+ return self._min_depth == 0
+
+ def full(self) -> bool:
+ assert self._max_depth <= 8
+ return self._max_depth == 8
+
+ def pop(self) -> None:
+ assert 0 < self._min_depth
+ self._min_depth -= 1
+ self._max_depth -= 1
+ if self._elts_at_top:
+ self._elts_at_top.pop()
+
+ def peek(self) -> Optional[int]:
+ assert 0 < self._min_depth
+ ret = self._elts_at_top[-1] if self._elts_at_top else None
+ return self._elts_at_top[-1] if self._elts_at_top else None
+
+ def write(self, value: Optional[int], update: bool) -> None:
+ '''Write a value to the call stack.
+
+ The update flag works as described for Model.write_reg
+
+ '''
+ if update:
+ # If we're updating a write to x1, check that the new value refines
+ # the top of the call stack.
+ assert self._min_depth > 0
+ if self._elts_at_top:
+ assert self._elts_at_top[-1] in [None, value]
+ self._elts_at_top[-1] = value
+ else:
+ self._elts_at_top.append(value)
+ else:
+ assert not self.full()
+ self._min_depth += 1
+ self._max_depth += 1
+ self._elts_at_top.append(value)
+
+ def XXXshow(self) -> str:
+ return ('CallStack({}, {}, {})'
+ .format(self._min_depth, self._max_depth, self._elts_at_top))
+
+
class Model:
'''An abstract model of the processor and memories
@@ -322,7 +432,11 @@
following the instruction stream to this point.
'''
- def __init__(self, dmem_size: int, reset_addr: int) -> None:
+ def __init__(self, dmem_size: int, reset_addr: int, fuel: int) -> None:
+ assert fuel >= 0
+ self.initial_fuel = fuel
+ self.fuel = fuel
+
self.dmem_size = dmem_size
# Known values for registers. This is a dictionary mapping register
@@ -345,7 +459,7 @@
# entry of None means an entry with an architectural value, but where
# we don't actually know what it is (usually a result of some
# arithmetic operation that got written to x1).
- self._call_stack = [] # type: List[Optional[int]]
+ self._call_stack = CallStack()
# Known values for memory, keyed by memory type ('dmem', 'csr', 'wsr').
csrs = KnownMem(4096)
@@ -371,6 +485,74 @@
# generating)
self.pc = reset_addr
+ def copy(self) -> 'Model':
+ '''Return a deep copy of the model'''
+ ret = Model(self.dmem_size, self.pc, self.initial_fuel)
+ ret.fuel = self.fuel
+ ret._known_regs = {n: regs.copy()
+ for n, regs in self._known_regs.items()}
+ ret._call_stack = self._call_stack.copy()
+ ret._known_mem = {n: mem.copy()
+ for n, mem in self._known_mem.items()}
+ return ret
+
+ def merge(self, other: 'Model') -> None:
+ '''Merge in values from another model'''
+ assert self.initial_fuel == other.initial_fuel
+ self.fuel = min(self.fuel, other.fuel)
+ assert self.dmem_size == other.dmem_size
+
+ reg_types = self._known_regs.keys() | other._known_regs.keys()
+ for reg_type in reg_types:
+ sregs = self._known_regs.get(reg_type)
+ oregs = other._known_regs.get(reg_type)
+ if sregs is None:
+ # If sregs is None, we have no registers that are known to have
+ # architectural values.
+ continue
+ if oregs is None:
+ # If oregs is None, other has no registers with architectural
+ # values. Thus the merged model shouldn't have any either.
+ del self._known_regs[reg_type]
+ continue
+
+ # Both register files have at least some architectural values.
+ # Build a new, merged version.
+ merged = {} # type: Dict[int, Optional[int]]
+ for reg_name, svalue in sregs.items():
+ ovalue = oregs.get(reg_name, 'missing')
+ if ovalue == 'missing':
+ # The register is missing from oregs. This means it might
+ # not have an architectural value, so we should skip it
+ # from sregs too.
+ pass
+ elif ovalue is None:
+ # The register has an architectural value in other, but not
+ # one we know. Make sure it's unknown here too.
+ merged[reg_name] = None
+ else:
+ assert isinstance(ovalue, int)
+ if svalue is None:
+ # The register has an unknown architectural value in
+ # self and a known value in other. So we don't know its
+ # value (but it is still architecturally specified): no
+ # change.
+ merged[reg_name] = None
+ else:
+ # self and other both have a known value for the
+ # register. Do they match? If so, take that value.
+ # Otherwise, make it unknown.
+ merged[reg_name] = None if svalue != ovalue else svalue
+
+ self._known_regs[reg_type] = merged
+
+ self._call_stack.merge(other._call_stack)
+
+ for mem_type, self_mem in self._known_mem.items():
+ self_mem.merge(other._known_mem[mem_type])
+
+ assert self.pc == other.pc
+
def read_reg(self, reg_type: str, idx: int) -> None:
'''Update the model for a read of the given register
@@ -379,7 +561,6 @@
'''
if reg_type == 'gpr' and idx == 1:
- assert self._call_stack
self._call_stack.pop()
def write_reg(self,
@@ -406,12 +587,7 @@
if idx == 1:
# Special-case writes to x1
- if update:
- assert self._call_stack
- assert self._call_stack[-1] in [None, value]
- self._call_stack[-1] = value
- else:
- self._call_stack.append(value)
+ self._call_stack.write(value, update)
return
self._known_regs.setdefault(reg_type, {})[idx] = value
@@ -419,7 +595,7 @@
def get_reg(self, reg_type: str, idx: int) -> Optional[int]:
'''Get a register value, if known.'''
if reg_type == 'gpr' and idx == 1:
- return self._call_stack[-1] if self._call_stack else None
+ return self._call_stack.peek()
return self._known_regs.setdefault(reg_type, {}).get(idx)
@@ -475,7 +651,7 @@
# architectural value). This won't appear in known_regs,
# because we don't track x1 there.
assert 1 not in known_regs
- if self._call_stack:
+ if not self._call_stack.empty():
known_list.append(1)
return random.choice(known_list)
@@ -487,7 +663,7 @@
idx = random.getrandbits(op_type.width)
if ((idx == 1 and
op_type.reg_type == 'gpr' and
- len(self._call_stack) >= 8)):
+ self._call_stack.full())):
continue
return idx
@@ -508,8 +684,10 @@
# not None
if reg_type == 'gpr':
assert 1 not in known_regs
- if self._call_stack and self._call_stack[-1] is not None:
- ret.append((1, self._call_stack[-1]))
+ if not self._call_stack.empty():
+ x1 = self._call_stack.peek()
+ if x1 is not None:
+ ret.append((1, x1))
return ret
@@ -522,7 +700,7 @@
# stack is not empty.
if reg_type == 'gpr':
assert 1 not in arch_regs
- if self._call_stack:
+ if not self._call_stack.empty():
arch_regs.append(1)
return arch_regs
@@ -843,6 +1021,10 @@
mem_type, addr = prog_insn.lsu_info
self.touch_mem(mem_type, addr, insn.lsu.idx_width)
+ def consume_fuel(self) -> None:
+ '''Consume one item of fuel, but bottom out at fuel == 1'''
+ self.fuel = max(1, self.fuel - 1)
+
def update_for_insn(self, prog_insn: ProgInsn) -> None:
# If this is a sufficiently simple operation that we understand the
# result, or a complicated instruction where we have to do something
@@ -859,3 +1041,5 @@
updater(prog_insn)
else:
self._generic_update_for_insn(prog_insn)
+
+ self.consume_fuel()
diff --git a/hw/ip/otbn/util/rig/program.py b/hw/ip/otbn/util/rig/program.py
index 45b2d0d..0988968 100644
--- a/hw/ip/otbn/util/rig/program.py
+++ b/hw/ip/otbn/util/rig/program.py
@@ -164,11 +164,33 @@
# size 4N bytes.
self._sections = {} # type: Dict[int, List[ProgInsn]]
+ # The number of instructions' space available. If we aren't below any
+ # branches, this is the space available in imem. When we're branching,
+ # this might be less.
+ self._space = self.imem_size // 4
+
+ def copy(self) -> 'Program':
+ '''Return a a shallow copy of the program
+
+ This is a shallow copy, so shares ProgInsn instances, but it can be
+ modified by adding instructions without affecting the original.
+
+ '''
+ ret = Program(self.imem_lma, self.imem_size,
+ self.dmem_lma, self.dmem_size)
+ ret._sections = {base: section.copy()
+ for base, section in self._sections.items()}
+ ret._space = self._space
+ return ret
+
def add_insns(self, addr: int, insns: List[ProgInsn]) -> None:
'''Add a sequence of instructions, starting at addr'''
assert addr & 3 == 0
assert addr <= self.imem_size
+ assert len(insns) <= self._space
+ self._space -= len(insns)
+
sec_top = addr + 4 * len(insns)
# This linear search is a bit naff, but I doubt it will have a
@@ -489,3 +511,12 @@
return 0
return max(0, space // 4)
+
+ def get_insn_space_left(self) -> int:
+ '''Return how many more instructions there is space for'''
+ return self._space
+
+ def constrain_space(self, space: int) -> None:
+ '''Constrain the amount of space available'''
+ assert space <= self._space
+ self._space = space
diff --git a/hw/ip/otbn/util/rig/rig.py b/hw/ip/otbn/util/rig/rig.py
index 913e626..e490f83 100644
--- a/hw/ip/otbn/util/rig/rig.py
+++ b/hw/ip/otbn/util/rig/rig.py
@@ -2,7 +2,7 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
-from typing import Dict, List, Tuple
+from typing import Tuple
from shared.insn_yaml import InsnsFile
from shared.mem_layout import get_memory_layout
@@ -15,12 +15,12 @@
def gen_program(start_addr: int,
- size: int,
- insns_file: InsnsFile) -> Tuple[InitData, List[Snippet]]:
+ fuel: int,
+ insns_file: InsnsFile) -> Tuple[InitData, Snippet]:
'''Generate a random program for OTBN
start_addr is the reset address (the value that should be programmed into
- the START_ADDR register). size gives a rough upper bound for the number of
+ the START_ADDR register). fuel gives a rough upper bound for the number of
instructions that will be executed by the generated program.
Returns (init_data, snippets, program). init_data is a dict mapping (4-byte
@@ -41,7 +41,7 @@
assert start_addr & 3 == 0
program = Program(imem_lma, imem_size, dmem_lma, dmem_size)
- model = Model(dmem_size, start_addr)
+ model = Model(dmem_size, start_addr, fuel)
# Generate some initialised data to start with. Otherwise, it takes a while
# before we start issuing loads (because we need stores to happen first).
@@ -50,34 +50,8 @@
for addr in init_data.keys():
model.touch_mem('dmem', addr, 4)
- generators = SnippetGens(insns_file)
- snippets = []
+ ret = SnippetGens(insns_file).gens(model, program, True)
+ assert ret is not None
+ snippet, _ = ret
- while size > 0:
- snippet, done, new_size = generators.gen(size, model, program)
- snippets.append(snippet)
- if done:
- break
-
- # Each new snippet should consume some of size to guarantee
- # termination.
- assert new_size < size
- size = new_size
-
- return init_data, snippets
-
-
-def snippets_to_program(snippets: List[Snippet]) -> Program:
- '''Write a series of disjoint snippets to make a program'''
- # Find the size of the memory that we can access. Both memories start
- # at address 0: a strict Harvard architecture. (mems[x][0] is the LMA
- # for memory x, not the VMA)
- mems = get_memory_layout()
- imem_lma, imem_size = mems['IMEM']
- dmem_lma, dmem_size = mems['DMEM']
- program = Program(imem_lma, imem_size, dmem_lma, dmem_size)
-
- for snippet in snippets:
- snippet.insert_into_program(program)
-
- return program
+ return init_data, snippet
diff --git a/hw/ip/otbn/util/rig/snippet.py b/hw/ip/otbn/util/rig/snippet.py
index 60f1c34..4a53ad3 100644
--- a/hw/ip/otbn/util/rig/snippet.py
+++ b/hw/ip/otbn/util/rig/snippet.py
@@ -2,30 +2,16 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
-from typing import List, Tuple
+from typing import List, Optional
from shared.insn_yaml import InsnsFile
+from shared.mem_layout import get_memory_layout
from .program import ProgInsn, Program
class Snippet:
- '''A collection of instructions, generated as part of a random program.
-
- parts is a list of pairs (addr, insns), where insns is a nonempty list of
- instructions and addr is the address of its first element. The entry point
- for the snippet is the address of the first part.
-
- '''
- def __init__(self,
- parts: List[Tuple[int, List[ProgInsn]]]):
- assert parts
- for idx, (addr, insns) in enumerate(parts):
- assert addr >= 0
- assert addr & 3 == 0
-
- self.parts = parts
-
+ '''A collection of instructions, generated as part of a random program.'''
def insert_into_program(self, program: Program) -> None:
'''Insert this snippet into the given program
@@ -33,66 +19,213 @@
instructions in the program.
'''
- for addr, insns in self.parts:
- program.add_insns(addr, insns)
+ raise NotImplementedError()
def to_json(self) -> object:
'''Serialize to an object that can be written as JSON'''
- lst = []
- for addr, insns in self.parts:
- lst.append((addr, [i.to_json() for i in insns]))
- return lst
+ raise NotImplementedError()
+
+ @staticmethod
+ def _addr_from_json(where: str, json: object) -> int:
+ '''Read an instruction address from a parsed json object'''
+
+ # The address should be an aligned non-negative integer and insns
+ # should itself be a list (of serialized Insn objects).
+ if not isinstance(json, int):
+ raise ValueError('First coordinate of {} is not an integer.'
+ .format(where))
+ if json < 0:
+ raise ValueError('Address of {} is {}, but should be non-negative.'
+ .format(where, json))
+ if json & 3:
+ raise ValueError('Address of {} is {}, '
+ 'but should be 4-byte aligned.'
+ .format(where, json))
+ return json
+
+ @staticmethod
+ def _from_json_lst(insns_file: InsnsFile,
+ idx: List[int],
+ json: List[object]) -> 'Snippet':
+ raise NotImplementedError()
@staticmethod
def from_json(insns_file: InsnsFile,
- idx: int,
+ idx: List[int],
json: object) -> 'Snippet':
- '''The inverse of to_json.
+ '''The inverse of to_json'''
+ if not (isinstance(json, list) and json):
+ raise ValueError('Snippet {} is not a nonempty list.'.format(idx))
- idx is the 0-based number of the snippet in the file, just used for
- error messages.
+ key = json[0]
+ if not isinstance(key, str):
+ raise ValueError('Key for snippet {} is not a string.'.format(idx))
+
+ if key == 'PS':
+ return ProgSnippet._from_json_lst(insns_file, idx, json[1:])
+ elif key == 'BS':
+ return BranchSnippet._from_json_lst(insns_file, idx, json[1:])
+ elif key == 'SS':
+ return SeqSnippet._from_json_lst(insns_file, idx, json[1:])
+ else:
+ raise ValueError('Snippet {} has unknown key {!r}.'
+ .format(idx, key))
+
+ def merge(self, snippet: 'Snippet') -> bool:
+ '''Merge snippet after this one and return True if possible.
+
+ If not possible, leaves self unchanged and returns False.
'''
- if not isinstance(json, list):
- raise ValueError('Object for snippet {} is not a list.'
- .format(idx))
+ return False
- parts = []
- for idx1, part in enumerate(json):
- # Each element should be a pair: (addr, insns). This will have come
- # out as a list (since tuples serialize as lists).
- if not (isinstance(part, list) and len(part) == 2):
- raise ValueError('Part {} for snippet {} is not a pair.'
- .format(idx1, idx))
+ def to_program(self) -> Program:
+ '''Write a series of disjoint snippets to make a program'''
+ # Find the size of the memory that we can access. Both memories start
+ # at address 0: a strict Harvard architecture. (mems[x][0] is the LMA
+ # for memory x, not the VMA)
+ mems = get_memory_layout()
+ imem_lma, imem_size = mems['IMEM']
+ dmem_lma, dmem_size = mems['DMEM']
+ program = Program(imem_lma, imem_size, dmem_lma, dmem_size)
+ self.insert_into_program(program)
+ return program
- addr, insns_json = part
- # The address should be an aligned non-negative integer and insns
- # should itself be a list (of serialized Insn objects).
- if not isinstance(addr, int):
- raise ValueError('First coordinate of part {} for snippet {} '
- 'is not an integer.'
- .format(idx1, idx))
- if addr < 0:
- raise ValueError('Address of part {} for snippet {} is {}, '
- 'but should be non-negative.'
- .format(idx1, idx, addr))
- if addr & 3:
- raise ValueError('Address of part {} for snippet {} is {}, '
- 'but should be 4-byte aligned.'
- .format(idx1, idx, addr))
+class ProgSnippet(Snippet):
+ '''A sequence of instructions that are executed in order'''
+ def __init__(self, addr: int, insns: List[ProgInsn]):
+ assert addr >= 0
+ assert addr & 3 == 0
- if not isinstance(insns_json, list):
- raise ValueError('Second coordinate of part {} for snippet {} '
- 'is not a list.'
- .format(idx1, idx))
+ self.addr = addr
+ self.insns = insns
- insns = []
- for insn_idx, insn_json in enumerate(insns_json):
- where = ('In snippet {}, part {}, instruction {}'
- .format(idx, idx1, insn_idx))
- insns.append(ProgInsn.from_json(insns_file, where, insn_json))
+ def insert_into_program(self, program: Program) -> None:
+ program.add_insns(self.addr, self.insns)
- parts.append((addr, insns))
+ def to_json(self) -> object:
+ '''Serialize to an object that can be written as JSON'''
+ return ['PS', self.addr, [i.to_json() for i in self.insns]]
- return Snippet(parts)
+ @staticmethod
+ def _from_json_lst(insns_file: InsnsFile,
+ idx: List[int],
+ json: List[object]) -> Snippet:
+ '''The inverse of to_json.'''
+ # Each element should be a pair: (addr, insns).
+ if len(json) != 2:
+ raise ValueError('Snippet {} has {} arguments; '
+ 'expected 2 for a ProgSnippet.'
+ .format(idx, len(json)))
+ j_addr, j_insns = json
+
+ where = 'snippet {}'.format(idx)
+ addr = Snippet._addr_from_json(where, j_addr)
+
+ if not isinstance(j_insns, list):
+ raise ValueError('Second coordinate of {} is not a list.'
+ .format(where))
+
+ insns = []
+ for insn_idx, insn_json in enumerate(j_insns):
+ pi_where = ('In snippet {}, instruction {}'
+ .format(idx, insn_idx))
+ pi = ProgInsn.from_json(insns_file, pi_where, insn_json)
+ insns.append(pi)
+
+ return ProgSnippet(addr, insns)
+
+ def merge(self, snippet: Snippet) -> bool:
+ if not isinstance(snippet, ProgSnippet):
+ return False
+
+ next_addr = self.addr + 4 * len(self.insns)
+ if snippet.addr != next_addr:
+ return False
+
+ self.insns += snippet.insns
+ return True
+
+
+class SeqSnippet(Snippet):
+ '''A nonempty sequence of snippets that run one after another'''
+ def __init__(self, children: List[Snippet]):
+ assert children
+ self.children = children
+
+ def insert_into_program(self, program: Program) -> None:
+ for child in self.children:
+ child.insert_into_program(program)
+
+ def to_json(self) -> object:
+ ret = ['SS'] # type: List[object]
+ ret += [c.to_json() for c in self.children]
+ return ret
+
+ @staticmethod
+ def _from_json_lst(insns_file: InsnsFile,
+ idx: List[int],
+ json: List[object]) -> Snippet:
+ if len(json) == 0:
+ raise ValueError('List at {} for SeqSnippet is empty.'.format(idx))
+
+ children = []
+ for i, item in enumerate(json):
+ children.append(Snippet.from_json(insns_file, idx + [i], item))
+ return SeqSnippet(children)
+
+
+class BranchSnippet(Snippet):
+ '''A snippet representing a branch
+
+ branch_insn is the first instruction that runs, at address addr, then
+ either snippet0 or snippet1 will run. The program will complete in either
+ case.
+
+ '''
+ def __init__(self,
+ addr: int,
+ branch_insn: ProgInsn,
+ snippet0: Snippet,
+ snippet1: Snippet):
+ self.addr = addr
+ self.branch_insn = branch_insn
+ self.snippet0 = snippet0
+ self.snippet1 = snippet1
+
+ def insert_into_program(self, program: Program) -> None:
+ program.add_insns(self.addr, [self.branch_insn])
+ self.snippet0.insert_into_program(program)
+ if self.snippet1 is not None:
+ self.snippet1.insert_into_program(program)
+
+ def to_json(self) -> object:
+ js1 = None if self.snippet1 is None else self.snippet1.to_json()
+ return ['BS',
+ self.addr,
+ self.branch_insn.to_json(),
+ self.snippet0.to_json(),
+ js1]
+
+ @staticmethod
+ def _from_json_lst(insns_file: InsnsFile,
+ idx: List[int],
+ json: List[object]) -> Snippet:
+ if len(json) != 4:
+ raise ValueError('List for snippet {} is of the wrong '
+ 'length for a BranchSnippet ({}, not 4)'
+ .format(idx, len(json)))
+
+ j_addr, j_branch_insn, j_snippet0, j_snippet1 = json
+
+ addr_where = 'address for snippet {}'.format(idx)
+ addr = Snippet._addr_from_json(addr_where, j_addr)
+
+ bi_where = 'branch instruction for snippet {}'.format(idx)
+ branch_insn = ProgInsn.from_json(insns_file, bi_where, j_branch_insn)
+
+ snippet0 = Snippet.from_json(insns_file, idx + [0], j_snippet0)
+ snippet1 = Snippet.from_json(insns_file, idx + [1], j_snippet1)
+
+ return BranchSnippet(addr, branch_insn, snippet0, snippet1)
diff --git a/hw/ip/otbn/util/rig/snippet_gen.py b/hw/ip/otbn/util/rig/snippet_gen.py
index 73281c2..50d736d 100644
--- a/hw/ip/otbn/util/rig/snippet_gen.py
+++ b/hw/ip/otbn/util/rig/snippet_gen.py
@@ -9,7 +9,7 @@
'''
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
from shared.insn_yaml import Insn, InsnsFile
@@ -17,6 +17,17 @@
from .model import Model
from .snippet import Snippet
+# A continuation type that allows a generator to recursively generate some more
+# stuff.
+GenCont = Callable[[Model, Program], Optional[Tuple[Snippet, Model]]]
+
+# The return type of a single generator. This is a tuple (snippet, model).
+# snippet is a generated snippet. If the program is done (i.e. every execution
+# ends with ecall) then model is None. Otherwise it is a Model object
+# representing the state of the processor after executing the code in the
+# snippet(s).
+GenRet = Tuple[Snippet, Optional[Model]]
+
class SnippetGen:
'''A parameterised sequence of instructions
@@ -26,21 +37,14 @@
'''
def gen(self,
- size: int,
+ cont: GenCont,
model: Model,
- program: Program) -> Optional[Tuple[Snippet, bool, int]]:
+ program: Program) -> Optional[GenRet]:
'''Try to generate instructions for this type of snippet.
- size is always positive and gives an upper bound on the number of
- instructions in the dynamic instruction stream that this should
- generate. For example, a loop of 10 instructions that goes around 10
- times would consume 100 from size.
-
On success, inserts the instructions into program, updates the model,
- and returns a tuple (snippet, done, new_size). snippet is the generated
- snippet. done is true if the program is finished (if snippet ends with
- ecall) and is false otherwise. new_size is the size left after the
- generated snippet.
+ and returns a GenRet tuple. See comment above the type definition for
+ more information.
On failure, leaves program and model unchanged and returns None. There
should always be at least one snippet generator with positive weight
@@ -49,19 +53,24 @@
with the current program state", but the generator may be retried
later.
+ The cont argument is a continuation, used to call out to more
+ generators in order to do recursive generation. It takes a (mutable)
+ model and program and picks a sequence of instructions. The paths
+ through the generated code don't terminate with an ECALL but instead
+ end up at the resulting model.pc.
+
'''
raise NotImplementedError('gen not implemented by subclass')
def pick_weight(self,
- size: int,
model: Model,
program: Program) -> float:
'''Pick a weight by which to multiply this generator's default weight
This is called for each generator before we start trying to generate a
snippet for a given program and model state. This can be used to
- disable a generator when we know it won't work (if size is too small, for
- example).
+ disable a generator when we know it won't work (if model.fuel is too
+ small, for example).
It can also be used to alter weights depending on where we are in the
program. For example, a generator that generates ecall to end the
diff --git a/hw/ip/otbn/util/rig/snippet_gens.py b/hw/ip/otbn/util/rig/snippet_gens.py
index 70e719f..cd0f625 100644
--- a/hw/ip/otbn/util/rig/snippet_gens.py
+++ b/hw/ip/otbn/util/rig/snippet_gens.py
@@ -3,26 +3,28 @@
# SPDX-License-Identifier: Apache-2.0
import random
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from shared.insn_yaml import InsnsFile
from .program import Program
from .model import Model
-from .snippet import Snippet
-from .snippet_gen import SnippetGen
+from .snippet import SeqSnippet, Snippet
+from .snippet_gen import GenRet, SnippetGen
+from .gens.branch import Branch
from .gens.ecall import ECall
-from .gens.straight_line_insn import StraightLineInsn
from .gens.jump import Jump
+from .gens.straight_line_insn import StraightLineInsn
class SnippetGens:
'''A collection of snippet generators'''
_WEIGHTED_CLASSES = [
+ (Branch, 0.1),
(ECall, 1.0),
- (StraightLineInsn, 1.0),
- (Jump, 0.1)
+ (Jump, 0.1),
+ (StraightLineInsn, 1.0)
]
def __init__(self, insns_file: InsnsFile) -> None:
@@ -31,24 +33,34 @@
self.generators.append((cls(insns_file), weight))
def gen(self,
- size: int,
model: Model,
- program: Program) -> Tuple[Snippet, bool, int]:
+ program: Program,
+ ecall: bool) -> Optional[GenRet]:
'''Pick a snippet and update model, program with its contents.
- Returns a pair (snippet, done, new_size) with the same meanings as
- Snippet.gen, except that new_size is clamped to be at least 1 if done
- is false. This avoids snippets having to special-case to make sure they
- aren't chosen when size is near zero. The end result might be a
- slightly longer instruction stream than we intended, but it shouldn't
- be much bigger.
+ Normally returns a GenRet tuple with the same meanings as Snippet.gen.
+ If the chosen snippet would generate an ECALL and ecall is False, this
+ instead returns None (and leaves model and program unchanged).
'''
real_weights = []
for generator, weight in self.generators:
- weight_mult = generator.pick_weight(size, model, program)
+ weight_mult = generator.pick_weight(model, program)
real_weights.append(weight * weight_mult)
+ # Define a continuation (which basically just calls self.gens()) to
+ # pass to each generator. This allows recursive generation and avoids
+ # needing circular imports to get the types right.
+ def cont(md: Model, prg: Program) -> Optional[Tuple[Snippet, Model]]:
+ ret = self.gens(md, prg, False)
+ if ret is None:
+ return None
+ snippet, model = ret
+ # We should always have a Model returned (because the ecall
+ # argument was False)
+ assert model is not None
+ return (snippet, model)
+
while True:
# Pick a generator based on the weights in real_weights.
idx = random.choices(range(len(self.generators)),
@@ -62,15 +74,56 @@
# that the choice we made had positive weight.
assert real_weights[idx] > 0
- # Run the generator to generate a snippet
- gen_res = generator.gen(size, model, program)
- if gen_res is not None:
- snippet, done, new_size = gen_res
- if not done:
- new_size = max(new_size, 1)
+ if isinstance(generator, ECall) and not ecall:
+ return None
- return (snippet, done, new_size)
+ # Run the generator to generate a snippet
+ gen_res = generator.gen(cont, model, program)
+ if gen_res is not None:
+ return gen_res
# If gen_res is None, the generator failed. Set that weight to zero
# and try again.
real_weights[idx] = 0.0
+
+ def gens(self,
+ model: Model,
+ program: Program,
+ ecall: bool) -> Optional[GenRet]:
+ '''Generate some snippets to continue program.
+
+ This will try to run down model.fuel and program.size. If ecall is
+ True, it will eventually generate an ECALL instruction. If ecall is
+ False then instead of generating the ECALL instruction, it will instead
+ stop (leaving model.pc where the ECALL instruction would have been
+ inserted).
+
+ '''
+ children = [] # type: List[Snippet]
+ next_model = model # type: Optional[Model]
+ while True:
+ assert next_model is not None
+ old_fuel = next_model.fuel
+ gr = self.gen(next_model, program, ecall)
+ if gr is None:
+ assert ecall is False
+ break
+
+ snippet, next_model = gr
+
+ # Merge adjacent program snippets if possible. Otherwise, add a new
+ # one.
+ if not children or not children[-1].merge(snippet):
+ children.append(snippet)
+
+ if next_model is None:
+ break
+
+ assert next_model.fuel < old_fuel
+
+ if not children:
+ assert ecall is False
+ return None
+
+ snippet = children[0] if len(children) == 1 else SeqSnippet(children)
+ return (snippet, next_model)