[sw/otbn] Create a static check for OTBN loops.
Checks for common mistakes that can cause LOOP errors or loop stack
issues.
Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/hw/ip/otbn/util/check_call_stack.py b/hw/ip/otbn/util/check_call_stack.py
index cf4e263..c063162 100755
--- a/hw/ip/otbn/util/check_call_stack.py
+++ b/hw/ip/otbn/util/check_call_stack.py
@@ -7,6 +7,7 @@
import sys
from typing import Dict, Tuple
+from shared.check import CheckResult
from shared.decode import OTBNProgram, decode_elf
from shared.insn_yaml import Insn
from shared.operand import RegOperandType
@@ -28,32 +29,33 @@
return True
-def check_call_stack(program: OTBNProgram) -> Tuple[bool, str]:
+def check_call_stack(program: OTBNProgram) -> CheckResult:
'''Check that the special register x1 is used safely.
If x1 is used for purposes unrelated to the call stack, it can trigger a
CALL_STACK error. This check errors if x1 is used for any other instruction
than `jal` or `jalr`.
'''
- for pc in program.insns:
- insn = program.get_insn(pc)
- operands = program.get_operands(pc)
+ out = CheckResult()
+ # Keep scanning after a failure so that every offending instruction gets
+ # its own error, rather than stopping at the first one.
+ for pc, (insn, operands) in program.insns.items():
if not _check_call_stack_insn(insn, operands):
- return (False, 'check_call_stack: FAIL at PC {:#x}: {} {}'.format(
- pc, insn.mnemonic, operands))
- return (True, 'check_call_stack: PASS')
+ out.err(
+ 'Potentially dangerous use of the call stack register x1 at '
+ 'PC {:#x}: {}'.format(pc, insn.disassemble(pc, operands)))
+ # The prefix is put in front of every line of the report.
+ out.set_prefix('check_call_stack: ')
+ return out
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('elf', help=('The .elf file to check.'))
+ parser.add_argument('-v', '--verbose', action='store_true')
args = parser.parse_args()
program = decode_elf(args.elf)
- ok, msg = check_call_stack(program)
- print(msg)
- if not ok:
- return 1
- return 0
+ result = check_call_stack(program)
+ # Stay quiet on a clean pass unless --verbose was requested.
+ if args.verbose or result.has_errors() or result.has_warnings():
+ print(result.report())
+ # Only errors give a non-zero exit status; warnings alone still return 0.
+ return 1 if result.has_errors() else 0
if __name__ == "__main__":
diff --git a/hw/ip/otbn/util/check_loop.py b/hw/ip/otbn/util/check_loop.py
new file mode 100755
index 0000000..e5b44f1
--- /dev/null
+++ b/hw/ip/otbn/util/check_loop.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import sys
+from typing import Dict, List, Tuple
+
+from shared.check import CheckResult
+from shared.decode import OTBNProgram, decode_elf
+from shared.insn_yaml import Insn
+
+
+class CodeSection:
+    '''A contiguous part of a program's code (represented as a PC range).
+
+    The code section is considered to include both the start and end PC, so
+    `pc in section` is true for `start`, for `end` and for every PC in
+    between.
+    '''
+    def __init__(self, start: int, end: int):
+        self.start = start  # first PC of the section (inclusive)
+        self.end = end  # last PC of the section (inclusive)
+
+    def __contains__(self, pc: int) -> bool:
+        return self.start <= pc <= self.end
+
+    def __repr__(self) -> str:
+        # The '#' flag already emits the '0x' prefix, so it must not be
+        # repeated in the literal part of the format string.
+        return '{:#x}-{:#x}'.format(self.start, self.end)
+
+
+def _get_pcs_for_mnemonics(program: OTBNProgram,
+                           mnems: List[str]) -> List[int]:
+    '''Returns every PC whose instruction has one of the given mnemonics.
+
+    The PCs come back in the iteration order of `program.insns`.
+    '''
+    matching_pcs = []
+    for addr, (decoded, _) in program.insns.items():
+        if decoded.mnemonic in mnems:
+            matching_pcs.append(addr)
+    return matching_pcs
+
+
+def _get_branches(program: OTBNProgram) -> List[int]:
+    '''Returns the PCs of every branch instruction (BEQ and BNE).'''
+    branch_mnemonics = ['bne', 'beq']
+    return _get_pcs_for_mnemonics(program, branch_mnemonics)
+
+
+def _get_loop_starts(program: OTBNProgram) -> List[int]:
+    '''Returns the PCs of every loop instruction (LOOP and LOOPI).'''
+    loop_mnemonics = ['loop', 'loopi']
+    return _get_pcs_for_mnemonics(program, loop_mnemonics)
+
+
+def _get_loops(program: OTBNProgram) -> List[CodeSection]:
+    '''Gets the PC ranges of all loop bodies (LOOP and LOOPI) in the program.
+
+    Each returned section covers the loop *body* only: it starts at the PC
+    directly after the LOOP/LOOPI instruction (pc + 4) and ends at the last
+    instruction of the body, both inclusive. The loop instruction itself is
+    therefore at `section.start - 4`.
+    '''
+    loops = []
+    for pc in _get_loop_starts(program):
+        # bodysize is scaled by 4 to turn an instruction count into a PC
+        # offset from the loop instruction to the last body instruction.
+        bodysize = program.get_operands(pc)['bodysize']
+        loops.append(CodeSection(pc + 4, pc + bodysize * 4))
+    return loops
+
+
+def _check_loop_iterations(program: OTBNProgram,
+                           loops: List[CodeSection]) -> CheckResult:
+    '''Checks number of iterations for loopi.
+
+    If the number of iterations is 0, this check fails; `loopi` requires at least
+    one iteration and will raise a LOOP error otherwise. The `loop` instruction
+    also has this requirement, but since the number of loop iterations comes
+    from a register it's harder to check statically and is not considered here.
+
+    `loops` holds loop *bodies*, which begin at the PC directly after the
+    loop instruction, so the instruction inspected for each entry is the one
+    at `loop.start - 4`.
+    '''
+    out = CheckResult()
+    for loop in loops:
+        # Step back from the first body instruction to the LOOP/LOOPI
+        # instruction that owns this body; looking at loop.start itself would
+        # inspect the first instruction of the body instead.
+        loop_pc = loop.start - 4
+        insn = program.get_insn(loop_pc)
+        operands = program.get_operands(loop_pc)
+        if insn.mnemonic == 'loopi' and operands['iterations'] <= 0:
+            out.err(
+                'Bad number of loop iterations ({}) at PC {:#x}: {}'.format(
+                    operands['iterations'], loop_pc,
+                    insn.disassemble(loop_pc, operands)))
+    return out
+
+
+def _check_loop_end_insns(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Flags loops whose final body instruction is a control flow instruction.
+
+    Ending a loop body on such an instruction can cause LOOP software errors
+    during execution, so each offending loop is reported as an error.
+    '''
+    result = CheckResult()
+    for section in loops:
+        final_insn = program.get_insn(section.end)
+        if final_insn.straight_line:
+            continue
+        message = ('Control flow instruction ({}) at end of loop at PC {:#x} '
+                   '(loop starting at PC {:#x})')
+        result.err(
+            message.format(final_insn.mnemonic, section.end, section.start))
+    return result
+
+
+def _check_loop_inclusion(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Checks that nested loops are fully contained in their outer loops.
+
+    Every ordered pair of loop bodies is examined. If one body starts inside
+    another, its final instruction must also lie inside that other body;
+    otherwise an error is recorded.
+    '''
+    result = CheckResult()
+    for outer in loops:
+        for inner in loops:
+            if inner.start not in outer:
+                # Not nested inside `outer`, so there is nothing to check.
+                continue
+            if inner.end in outer:
+                continue
+            result.err('Inner loop ends after outer loop (inner loop {}, '
+                       'outer loop {})'.format(inner, outer))
+
+    return result
+
+
+def _check_loop_branching(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Checks that there are no branches into or out of loop bodies.
+
+    Branches within the same loop body are permitted (but not branches from an
+    inner loop to an outer loop, as this counts as branching out of the inner
+    loop). Because this isn't necessarily a fatal issue (for instance, it's
+    possible the branched-to code will always return to the loop), this check
+    returns warnings rather than errors.
+
+    A `jal` instruction with a register other than `x1` as the first operand is
+    treated the same as a branch and not permitted to cross the loop-body
+    boundary.
+    '''
+    out = CheckResult()
+
+    # Check all bne and beq instructions, as well as `jal` instructions with
+    # first operands other than x1 (unconditional branch)
+    to_check = _get_branches(program)
+    for pc in _get_pcs_for_mnemonics(program, ['jal']):
+        if program.get_operands(pc)['grd'] != 1:
+            to_check.append(pc)
+
+    for pc in to_check:
+        # Fetch the instruction up front: both warnings below name its
+        # mnemonic, and the branch-into-loop warning can fire for a branch
+        # that is not inside any loop body at all.
+        insn = program.get_insn(pc)
+        operands = program.get_operands(pc)
+        # The 'offset' operand is used as the branch target; keep only its
+        # low 32 bits.
+        branch_addr = operands['offset'] & ((1 << 32) - 1)
+
+        # Get the loop bodies the branch is inside, if any
+        current_loops = [loop for loop in loops if pc in loop]
+
+        # Check that we're not branching out of any loop bodies
+        for loop in current_loops:
+            if branch_addr not in loop:
+                out.warn(
+                    'Branch out of loop at PC {:#x} (loop from PC {:#x} to PC '
+                    '{:#x}, branch {} to PC {:#x}). This might cause problems '
+                    'with the loop stack and surprising behavior.'.format(
+                        pc, loop.start, loop.end, insn.mnemonic, branch_addr))
+
+        # Check that we're not branching *into* a loop body that the branch
+        # instruction is not already in
+        for loop in loops:
+            if (branch_addr in loop) and (loop not in current_loops):
+                out.warn(
+                    'Branch into loop at PC {:#x} (loop from PC {:#x} to PC '
+                    '{:#x}, branch {} to PC {:#x}). This might cause problems '
+                    'with the loop stack and surprising behavior.'.format(
+                        pc, loop.start, loop.end, insn.mnemonic, branch_addr))
+
+    return out
+
+
+def _check_loop_stack(program: OTBNProgram,
+                      loops: List[CodeSection]) -> CheckResult:
+    '''Checks that loops will likely be properly cleared from loop stack.
+
+    The checks here are based on the OTBN hardware IP documentation on loop
+    nesting. From the docs:
+
+    To avoid polluting the loop stack and avoid surprising behaviour, the
+    programmer must ensure that:
+
+    * Even if there are branches and jumps within a loop body, the final
+      instruction of the loop body gets executed exactly once per iteration.
+    * Nested loops have distinct end addresses.
+    * The end instruction of an outer loop is not executed before an inner loop
+      finishes.
+
+    In order to avoid simulating the control flow of the entire program to
+    check the first and third conditions, this check takes a conservative,
+    simplistic approach and simply warns about all branching into or out of
+    loop bodies, including jumps that don't use the call stack (e.g. `jal x0,
+    <addr>`). Branching to locations within the same loop body is permitted.
+
+    The second condition in the list, distinct end addresses, is checked
+    separately and reported as an error for every loop whose final
+    instruction was already claimed by an earlier loop in `loops`.
+    '''
+    out = CheckResult()
+    out += _check_loop_branching(program, loops)
+
+    # Check that loops have unique end addresses. Each end address has to be
+    # recorded after it is examined, otherwise no duplicate is ever seen.
+    seen_end_addrs = set()
+    for loop in loops:
+        if loop.end in seen_end_addrs:
+            out.err(
+                'Loop starting at PC {:#x} shares a final instruction with '
+                'another loop; consider adding a NOP instruction.'.format(
+                    loop.start))
+        seen_end_addrs.add(loop.end)
+
+    return out
+
+
+def check_loop(program: OTBNProgram) -> CheckResult:
+    '''Check that loops are properly formed.
+
+    Runs the following checks to rule out certain classes of loop errors and
+    undefined behavior, and combines their results:
+      1. For loopi instructions, the number of iterations is > 0.
+      2. Loops do not end in control-flow instructions such as jal or bne,
+         which will raise LOOP errors.
+      3. There is no branching into or out of loop bodies, and loops do not
+         share a final instruction.
+      4. For nested loops, the inner loop is completely contained within the
+         outer loop.
+
+    Every message in the returned result is reported with the prefix
+    'check_loop: '.
+    '''
+    loop_bodies = _get_loops(program)
+    # The order of this tuple fixes the order of the messages in the report.
+    checks = (
+        _check_loop_iterations,
+        _check_loop_end_insns,
+        _check_loop_stack,
+        _check_loop_inclusion,
+    )
+    combined = CheckResult()
+    for run_check in checks:
+        combined = combined + run_check(program, loop_bodies)
+    combined.set_prefix('check_loop: ')
+    return combined
+
+
+def main() -> int:
+    '''Runs the loop checks on the ELF file named on the command line.
+
+    The report is printed when there are errors or warnings, or when
+    --verbose is given. Returns 1 if any errors were found and 0 otherwise.
+    '''
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument('elf', help=('The .elf file to check.'))
+    arg_parser.add_argument('-v', '--verbose', action='store_true')
+    cli_args = arg_parser.parse_args()
+
+    outcome = check_loop(decode_elf(cli_args.elf))
+    found_errors = outcome.has_errors()
+    if cli_args.verbose or found_errors or outcome.has_warnings():
+        print(outcome.report())
+    # Warnings alone do not make the run fail.
+    return 1 if found_errors else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/hw/ip/otbn/util/shared/check.py b/hw/ip/otbn/util/shared/check.py
new file mode 100644
index 0000000..315b8be
--- /dev/null
+++ b/hw/ip/otbn/util/shared/check.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+
+class CheckResult:
+    '''Accumulates the errors and warnings produced by static checks.
+
+    Any number of errors and warnings can be recorded. Two results can be
+    combined with +, which gives a new result holding the messages of both,
+    e.g.:
+
+      out = CheckResult()
+      out += first_check()
+      out += second_check()
+      out.warn('A warning')
+      print(out.report())  # warnings/errors from both checks and "A warning"
+
+    The result of + always starts with an empty prefix, so set_prefix()
+    should be called after combining.
+    '''
+    def __init__(self) -> None:
+        self.errors = []  # error messages, in the order they were added
+        self.warnings = []  # warning messages, in the order they were added
+        self.prefix = ''  # put in front of every line of the report
+
+    def warn(self, msg: str) -> None:
+        '''Record a warning message.'''
+        self.warnings.append(msg)
+
+    def err(self, msg: str) -> None:
+        '''Record an error message.'''
+        self.errors.append(msg)
+
+    def __add__(self, other: 'CheckResult') -> 'CheckResult':
+        '''Returns a new CheckResult with the messages of both operands.
+
+        Neither operand is modified. Raises ValueError if other is not a
+        CheckResult.
+        '''
+        if isinstance(other, CheckResult):
+            combined = CheckResult()
+            combined.errors = self.errors + other.errors
+            combined.warnings = self.warnings + other.warnings
+            return combined
+        raise ValueError(
+            'Cannot add {} (of type {}) to {} (of type CheckResult)'.
+            format(other, type(other), self))
+
+    def set_prefix(self, prefix: str) -> None:
+        '''Set the text put in front of every line of the report.'''
+        self.prefix = prefix
+
+    def has_errors(self) -> bool:
+        '''Returns True if at least one error has been recorded.'''
+        return bool(self.errors)
+
+    def has_warnings(self) -> bool:
+        '''Returns True if at least one warning has been recorded.'''
+        return bool(self.warnings)
+
+    def report(self) -> str:
+        '''Returns a printable summary of the check results.
+
+        With nothing recorded this is a single PASS line. Otherwise there is
+        one line per message, all warnings first and then all errors.
+        '''
+        if not (self.has_warnings() or self.has_errors()):
+            return '{}PASS'.format(self.prefix)
+        lines = []
+        for warning in self.warnings:
+            lines.append('{}WARN: {}'.format(self.prefix, warning))
+        for error in self.errors:
+            lines.append('{}ERROR: {}'.format(self.prefix, error))
+        return '\n'.join(lines)
diff --git a/hw/ip/otbn/util/shared/decode.py b/hw/ip/otbn/util/shared/decode.py
index 3826551..fc2b82a 100644
--- a/hw/ip/otbn/util/shared/decode.py
+++ b/hw/ip/otbn/util/shared/decode.py
@@ -3,8 +3,8 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
-import sys
import struct
+import sys
from typing import Dict, Tuple
from shared.elf import read_elf
@@ -25,11 +25,11 @@
self.data = data # addr -> data (32b word)
self.insns = {}
- for pc in insns.keys():
- opcode = insns[pc]
+ for pc, opcode in insns.items():
mnem = INSNS_FILE.mnem_for_word(opcode)
if mnem is None:
- raise ValueError('No legal decoding for mnemonic: {}'.format(mnem))
+ raise ValueError(
+ 'No legal decoding for mnemonic: {}'.format(mnem))
insn = INSNS_FILE.mnemonic_to_insn.get(mnem)
enc_vals = insn.encoding.extract_operands(opcode)
op_vals = insn.enc_vals_to_op_vals(pc, enc_vals)