[otbn] Add support for "const values" to RIG
This is going to be important for generating loop bodies. The trick is
that we want to say something like "don't touch x3, x4 or x10 in the
loop body" to make sure that they still have the same values each
iteration. This way, they can be used as base addresses for memory
accesses or jumps, without having to worry that they'll get trashed on
the following iteration.
This also adds support for "forgetting values" in the model. We'll
generate a loop body by iterating over all known registers and either
marking them const or forgetting their exact values. This guarantees
that we don't use a trashable register as a base address.
Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/rig/gens/straight_line_insn.py b/hw/ip/otbn/util/rig/gens/straight_line_insn.py
index b6dc721..4b20d1f 100644
--- a/hw/ip/otbn/util/rig/gens/straight_line_insn.py
+++ b/hw/ip/otbn/util/rig/gens/straight_line_insn.py
@@ -147,20 +147,21 @@
return random.choices(valid_gprs)[0]
- def _pick_inc_vals(self) -> Tuple[int, int]:
+ def _pick_inc_vals(self,
+ idx0: int,
+ idx1: int,
+ model: Model) -> Tuple[int, int]:
'''Pick two values in 0, 1 that aren't both 1
These are appropriate to use as the increment flags for
- BN.LID/BN.SID/BN.MOVR.
+ BN.LID/BN.SID/BN.MOVR. idx0 and idx1 are the indices of the GPRs in
+ question.
'''
- idx = random.randint(0, 2)
- if idx == 0:
- return (0, 0)
- elif idx == 1:
- return (1, 0)
- else:
- return (0, 1)
+ options = [(0, 0), (1, 0), (0, 1)]
+ wt10 = 0.0 if model.is_const('gpr', idx0) else 1.0
+ wt01 = 0.0 if model.is_const('gpr', idx1) else 1.0
+ return random.choices(options, weights=[1.0, wt10, wt01])[0]
def _fill_bn_xid(self, insn: Insn, model: Model) -> Optional[ProgInsn]:
'''Fill out a BN.LID or BN.SID instruction'''
@@ -258,23 +259,24 @@
assert offset_rng[0] <= imm_val <= offset_rng[1]
assert list(reg_indices.keys()) == ['grs1']
- grs1_val = reg_indices['grs1']
+ grs1_idx = reg_indices['grs1']
offset_val = offset.op_type.op_val_to_enc_val(imm_val, model.pc)
assert offset_val is not None
# Do we increment the GPRs? We can increment up to one of them.
- grs1_inc_val, wdr_gpr_inc_val = self._pick_inc_vals()
+ grs1_inc_val, wdr_gpr_inc_val = \
+ self._pick_inc_vals(grs1_idx, wdr_gpr_idx, model)
# Finally, package up the operands properly for the instruction we're
# building.
if is_load:
# bn.lid: grd, grs1, offset, grs1_inc, grd_inc
- enc_vals = [wdr_gpr_idx, grs1_val, offset_val,
+ enc_vals = [wdr_gpr_idx, grs1_idx, offset_val,
grs1_inc_val, wdr_gpr_inc_val]
else:
# bn.sid: grs1, grs2, offset, grs1_inc, grs2_inc
- enc_vals = [grs1_val, wdr_gpr_idx, offset_val,
+ enc_vals = [grs1_idx, wdr_gpr_idx, offset_val,
grs1_inc_val, wdr_gpr_inc_val]
return ProgInsn(insn, enc_vals, ('dmem', addr))
@@ -314,7 +316,7 @@
# defines the destination WDR)
grd_idx = self._pick_gpr_with_arch_val(model)
- grd_inc_val, grs_inc_val = self._pick_inc_vals()
+ grd_inc_val, grs_inc_val = self._pick_inc_vals(grd_idx, grs_idx, model)
return ProgInsn(insn,
[grd_idx, grs_idx, grd_inc_val, grs_inc_val],
diff --git a/hw/ip/otbn/util/rig/model.py b/hw/ip/otbn/util/rig/model.py
index ffbc634..f476a06 100644
--- a/hw/ip/otbn/util/rig/model.py
+++ b/hw/ip/otbn/util/rig/model.py
@@ -83,6 +83,10 @@
self._max_depth += 1
self._elts_at_top.append(value)
+ def forget_value(self) -> None:
+ '''Replace any known values with None'''
+ self._elts_at_top = [None] * len(self._elts_at_top)
+
class Model:
'''An abstract model of the processor and memories
@@ -114,6 +118,16 @@
# Set x0 (the zeros register)
self._known_regs['gpr'] = {0: 0}
+ # Registers that must be kept constant. This is used for things like
+ # loop bodies, where we want to allow some registers to have known
+ # values (so we can use them as e.g. base addresses) and need to make
+ # sure not to clobber them.
+ self._const_regs = {} # type: Dict[str, Set[int]]
+
+ # To allow a caller to set _const_regs and unset again afterwards, we
+ # have a "const stack". See push_const and pop_const for usage.
+ self._const_stack = [] # type: List[Dict[str, Set[int]]]
+
# A call stack, representing the contents of x1. The top of the stack
# is at the end (position -1), to match Python's list.pop function. A
# entry of None means an entry with an architectural value, but where
@@ -151,21 +165,22 @@
ret.fuel = self.fuel
ret._known_regs = {n: regs.copy()
for n, regs in self._known_regs.items()}
+ ret._const_regs = {n: regs.copy()
+ for n, regs in self._const_regs.items()}
+ for entry in self._const_stack:
+ ret._const_stack.append({n: regs.copy()
+ for n, regs in entry.items()})
ret._call_stack = self._call_stack.copy()
ret._known_mem = {n: mem.copy()
for n, mem in self._known_mem.items()}
return ret
- def merge(self, other: 'Model') -> None:
- '''Merge in values from another model'''
- assert self.initial_fuel == other.initial_fuel
- self.fuel = min(self.fuel, other.fuel)
- assert self.dmem_size == other.dmem_size
-
- reg_types = self._known_regs.keys() | other._known_regs.keys()
- for reg_type in reg_types:
+ def _merge_known_regs(self,
+ other: Dict[str, Dict[int, Optional[int]]]) -> None:
+ '''Merge known registers from another model'''
+ for reg_type in self._known_regs.keys() | other.keys():
sregs = self._known_regs.get(reg_type)
- oregs = other._known_regs.get(reg_type)
+ oregs = other.get(reg_type)
if sregs is None:
# If sregs is None, we have no registers that are known to have
# architectural values.
@@ -206,6 +221,23 @@
self._known_regs[reg_type] = merged
+ def _merge_const_regs(self, other: Dict[str, Set[int]]) -> None:
+ '''Merge constant registers from another model'''
+ for reg_type in self._const_regs.keys() | other.keys():
+ cr = self._const_regs.setdefault(reg_type, set())
+ cr |= other.get(reg_type, set())
+
+ def merge(self, other: 'Model') -> None:
+ '''Merge in values from another model'''
+ assert self.initial_fuel == other.initial_fuel
+ self.fuel = min(self.fuel, other.fuel)
+ assert self.dmem_size == other.dmem_size
+
+ self._merge_known_regs(other._known_regs)
+ self._merge_const_regs(other._const_regs)
+
+ assert self._const_stack == other._const_stack
+
self._call_stack.merge(other._call_stack)
for mem_type, self_mem in self._known_mem.items():
@@ -221,6 +253,8 @@
'''
if reg_type == 'gpr' and idx == 1:
+ # We shouldn't ever read from x1 if it is marked constant
+ assert not self.is_const('gpr', 1)
self._call_stack.pop()
def write_reg(self,
@@ -240,6 +274,8 @@
registers, but matters for x1.
'''
+ assert not self.is_const(reg_type, idx)
+
if reg_type == 'gpr':
if idx == 0:
# Ignore writes to x0
@@ -307,52 +343,65 @@
known_list = list(known_regs)
if op_type.reg_type == 'gpr':
- # Add x1 if to the list of known registers (if it has an
- # architectural value). This won't appear in known_regs,
- # because we don't track x1 there.
+ # Add x1 if to the list of known registers if it has an
+ # architectural value and isn't marked constant. This won't
+ # appear in known_regs, because we don't track x1 there.
assert 1 not in known_regs
- if not self._call_stack.empty():
+ if not (self._call_stack.empty() or self.is_const('gpr', 1)):
known_list.append(1)
return random.choice(known_list)
- # This operand isn't treated as a source. Pick any register, but "roll
- # again" if we pick x1 and the call stack is full.
+ # This operand isn't treated as a source. Generate a list of allowed
+ # registers (everything but constant registers, plus x1 if the call
+ # stack is full) and then pick from it.
assert op_type.width is not None
- while True:
- idx = random.getrandbits(op_type.width)
- if ((idx == 1 and
- op_type.reg_type == 'gpr' and
- self._call_stack.full())):
- continue
- return idx
+ const_regs = self._const_regs.get(op_type.reg_type, set())
+ all_regs = set(range(1 << op_type.width))
+ good_regs = all_regs - const_regs
+ if op_type.reg_type == 'gpr' and self._call_stack.full():
+ good_regs.discard(1)
+ return random.choice(list(good_regs))
+
+ def all_regs_with_known_vals(self) -> Dict[str, List[Tuple[int, int]]]:
+ '''Like regs_with_known_vals, but returns all reg types'''
+ ret = {} # type: Dict[str, List[Tuple[int, int]]]
+ for rt in self._known_regs.keys():
+ kv = self.regs_with_known_vals(rt)
+ if kv:
+ ret[rt] = kv
+ return ret
def regs_with_known_vals(self, reg_type: str) -> List[Tuple[int, int]]:
- '''Find registers whose values are known
+ '''Find registers whose values are known and can be read
Returns a list of pairs (idx, value) where idx is the register index
and value is its value.
'''
+ known_regs = self._known_regs.get(reg_type)
+ if known_regs is None:
+ return []
+
ret = []
- known_regs = self._known_regs.setdefault(reg_type, {})
for reg_idx, reg_val in known_regs.items():
if reg_val is not None:
ret.append((reg_idx, reg_val))
# Handle x1, which has a known value iff the top of the call stack is
- # not None
+ # not None and can be read iff it isn't marked constant.
if reg_type == 'gpr':
assert 1 not in known_regs
if not self._call_stack.empty():
x1 = self._call_stack.peek()
if x1 is not None:
- ret.append((1, x1))
+ if not self.is_const('gpr', 1):
+ ret.append((1, x1))
return ret
def regs_with_architectural_vals(self, reg_type: str) -> List[int]:
- '''List registers that have an architectural value'''
+ '''List registers that have an architectural value and can be read'''
known_regs = self._known_regs.setdefault(reg_type, {})
arch_regs = list(known_regs.keys())
@@ -361,10 +410,67 @@
if reg_type == 'gpr':
assert 1 not in arch_regs
if not self._call_stack.empty():
- arch_regs.append(1)
+ if not self.is_const('gpr', 1):
+ arch_regs.append(1)
return arch_regs
+ def push_const(self) -> int:
+ '''Snapshot the current _const_regs state and return a token
+
+ This token should be passed to pop_const (to catch errors from
+ unbalanced push/pop pairs)
+
+ '''
+ snapshot = {n: regs.copy() for n, regs in self._const_regs.items()}
+ self._const_stack.append(snapshot)
+ return len(self._const_stack)
+
+ def pop_const(self, token: int) -> None:
+ '''Pop an entry from the _const_regs snapshot stack'''
+ assert token >= 1
+ assert len(self._const_stack) == token
+ self._const_regs = self._const_stack.pop()
+
+ def mark_const(self, reg_type: str, reg_idx: int) -> None:
+ '''Mark a register as constant
+
+ The model will no longer pick it as a destination operand or allow it
+ to be changed.
+
+ '''
+ # Marking x0 as constant has no effect (since it is a real constant
+ # register)
+ if reg_idx == 0 and reg_type == 'gpr':
+ return
+
+ self._const_regs.setdefault(reg_type, set()).add(reg_idx)
+
+ def is_const(self, reg_type: str, reg_idx: int) -> bool:
+ '''Return true if this register is marked as constant'''
+ cr = self._const_regs.get(reg_type)
+ if cr is None:
+ return False
+ return reg_idx in cr
+
+ def forget_value(self, reg_type: str, reg_idx: int) -> None:
+ '''If the given register has a known value, forget it.'''
+ if reg_type == 'gpr':
+ # We always know the value of x0
+ if reg_idx == 0:
+ return
+
+ # x1 (the call stack) has different handling
+ if reg_idx == 1:
+ self._call_stack.forget_value()
+ return
+
+ # Set the value in known_regs to None, but only if the register already
+ # has an architectural value.
+ kr = self._known_regs.setdefault(reg_type, {})
+ if reg_idx in kr:
+ kr[reg_idx] = None
+
def pick_lsu_target(self,
mem_type: str,
loads_value: bool,