[sw/otbn] Create a static check for OTBN loops.

Checks for common mistakes that can cause LOOP errors or loop stack
issues.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/hw/ip/otbn/util/check_call_stack.py b/hw/ip/otbn/util/check_call_stack.py
index cf4e263..c063162 100755
--- a/hw/ip/otbn/util/check_call_stack.py
+++ b/hw/ip/otbn/util/check_call_stack.py
@@ -7,6 +7,7 @@
 import sys
 from typing import Dict, Tuple
 
+from shared.check import CheckResult
 from shared.decode import OTBNProgram, decode_elf
 from shared.insn_yaml import Insn
 from shared.operand import RegOperandType
@@ -28,32 +29,33 @@
     return True
 
 
-def check_call_stack(program: OTBNProgram) -> Tuple[bool, str]:
+def check_call_stack(program: OTBNProgram) -> CheckResult:
     '''Check that the special register x1 is used safely.
 
     If x1 is used for purposes unrelated to the call stack, it can trigger a
     CALL_STACK error. This check errors if x1 is used for any other instruction
     than `jal` or `jalr`.
     '''
-    for pc in program.insns:
-        insn = program.get_insn(pc)
-        operands = program.get_operands(pc)
+    out = CheckResult()
+    for pc, (insn, operands) in program.insns.items():
         if not _check_call_stack_insn(insn, operands):
-            return (False, 'check_call_stack: FAIL at PC {:#x}: {} {}'.format(
-                pc, insn.mnemonic, operands))
-    return (True, 'check_call_stack: PASS')
+            out.err(
+                'Potentially dangerous use of the call stack register x1 at '
+                'PC {:#x}: {}'.format(pc, insn.disassemble(pc, operands)))
+    out.set_prefix('check_call_stack: ')
+    return out
 
 
 def main() -> int:
     parser = argparse.ArgumentParser()
     parser.add_argument('elf', help=('The .elf file to check.'))
+    parser.add_argument('-v', '--verbose', action='store_true')
     args = parser.parse_args()
     program = decode_elf(args.elf)
-    ok, msg = check_call_stack(program)
-    print(msg)
-    if not ok:
-        return 1
-    return 0
+    result = check_call_stack(program)
+    if args.verbose or result.has_errors() or result.has_warnings():
+        print(result.report())
+    return 1 if result.has_errors() else 0
 
 
 if __name__ == "__main__":
diff --git a/hw/ip/otbn/util/check_loop.py b/hw/ip/otbn/util/check_loop.py
new file mode 100755
index 0000000..e5b44f1
--- /dev/null
+++ b/hw/ip/otbn/util/check_loop.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import sys
+from typing import Dict, List, Tuple
+
+from shared.check import CheckResult
+from shared.decode import OTBNProgram, decode_elf
+from shared.insn_yaml import Insn
+
+
+class CodeSection:
+    '''A continuous part of a program's code (represented as a PC range).
+
+    The code section is considered to include both the start and end PC.
+    '''
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+
+    def __contains__(self, pc):
+        return self.start <= pc and pc <= self.end
+
+    def __repr__(self):
+        return '{:#x}-0x{:#x}'.format(self.start, self.end)
+
+
+def _get_pcs_for_mnemonics(program: OTBNProgram,
+                           mnems: List[str]) -> List[int]:
+    '''Gets all PCs in the program at which the given instruction is present.'''
+    return [
+        pc for (pc, (insn, _)) in program.insns.items()
+        if insn.mnemonic in mnems
+    ]
+
+
+def _get_branches(program: OTBNProgram) -> List[int]:
+    '''Gets the PCs of all branch instructions (BEQ and BNE) in the program.'''
+    return _get_pcs_for_mnemonics(program, ['bne', 'beq'])
+
+
+def _get_loop_starts(program: OTBNProgram) -> List[int]:
+    '''Gets the start PCs of all loops (LOOP and LOOPI) in the program.'''
+    return _get_pcs_for_mnemonics(program, ['loop', 'loopi'])
+
+
+def _get_loops(program: OTBNProgram) -> List[CodeSection]:
+    '''Gets the PC ranges of all loops (LOOP and LOOPI) in the program.'''
+    loop_starts = _get_loop_starts(program)
+    loops = []
+    for pc in loop_starts:
+        insn = program.get_insn(pc)
+        operands = program.get_operands(pc)
+        end_pc = pc + operands['bodysize'] * 4
+        loops.append(CodeSection(pc + 4, end_pc))
+    return loops
+
+
+def _check_loop_iterations(program: OTBNProgram,
+                           loops: List[CodeSection]) -> CheckResult:
+    '''Checks number of iterations for loopi.
+
+    If the number of iterations is 0, this check fails; `loopi` requires at least
+    one iteration and will raise a LOOP error otherwise. The `loop` instruction
+    also has this requirement, but since the number of loop iterations comes
+    from a register it's harder to check statically and is not considered here.
+    '''
+    out = CheckResult()
+    for loop in loops:
+        insn = program.get_insn(loop.start)
+        operands = program.get_operands(loop.start)
+        if insn.mnemonic == 'loopi' and operands['iterations'] <= 0:
+            out.err(
+                'Bad number of loop iterations ({}) at PC {:#x}: {}'.format(
+                    operands['iterations'], loop.start,
+                    insn.disassemble(loop.start, operands)))
+    return out
+
+
+def _check_loop_end_insns(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Checks that loops do not end in control flow instructions.
+
+    Such instructions can cause LOOP software errors during execution.
+    '''
+    out = CheckResult()
+    for loop in loops:
+        loop_end_insn = program.get_insn(loop.end)
+        if not loop_end_insn.straight_line:
+            out.err('Control flow instruction ({}) at end of loop at PC {:#x} '
+                    '(loop starting at PC {:#x})'.format(
+                        loop_end_insn.mnemonic, loop.end, loop.start))
+    return out
+
+
+def _check_loop_inclusion(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Checks that inner loops are fully contained within outer loops.
+
+    When a loop starts within the body of another loop, it must be the case
+    that the inner loop's final instruction occurs before the outer loop's.
+    '''
+    out = CheckResult()
+    for loop in loops:
+        for other in loops:
+            if other.start in loop and not other.end in loop:
+                out.err('Inner loop ends after outer loop (inner loop {}, '
+                        'outer loop {})'.format(other, loop))
+
+    return out
+
+
+def _check_loop_branching(program: OTBNProgram,
+                          loops: List[CodeSection]) -> CheckResult:
+    '''Checks that there are no branches into or out of loop bodies.
+
+    Branches within the same loop body are permitted (but not branches from an
+    inner loop to an outer loop, as this counts as branching out of the inner
+    loop). Because this isn't necessarily a fatal issue (for instance, it's
+    possible the branched-to code will always return to the loop), this check
+    returns warnings rather than errors.
+
+    A `jal` instruction with a register other than `x1` as the first operand is
+    treated the same as a branch and not permitted to cross the loop-body
+    boundary.
+    '''
+    out = CheckResult()
+
+    # Check all bne and beq instructions, as well as `jal` instructions with
+    # first operands other than x1 (unconditional branch)
+    to_check = _get_branches(program)
+    for pc in _get_pcs_for_mnemonics(program, ['jal']):
+        operands = program.get_operands(pc)
+        if operands['grd'] != 1:
+            to_check.append(pc)
+
+    for pc in to_check:
+        operands = program.get_operands(pc)
+        branch_addr = operands['offset'] & ((1 << 32) - 1)
+
+        # Get the loop bodies the branch is inside, if any
+        current_loops = []
+        for loop in loops:
+            if pc in loop:
+                current_loops.append(loop)
+
+        # Check that we're not branching out of any loop bodies
+        for loop in current_loops:
+            if branch_addr not in loop:
+                insn = program.get_insn(pc)
+                out.warn(
+                    'Branch out of loop at PC {:#x} (loop from PC {:#x} to PC '
+                    '{:#x}, branch {} to PC {:#x}). This might cause problems '
+                    'with the loop stack and surprising behavior.'.format(
+                        pc, loop.start, loop.end, insn.mnemonic, branch_addr))
+
+        # Check that we're not branching *into* a loop body that the branch
+        # instruction is not already in
+        for loop in loops:
+            if (branch_addr in loop) and (loop not in current_loops):
+                out.warn(
+                    'Branch into loop at PC {:#x} (loop from PC {:#x} to PC '
+                    '{:#x}, branch {} to PC {:#x}). This might cause problems '
+                    'with the loop stack and surprising behavior.'.format(
+                        pc, loop.start, loop.end, insn.mnemonic, branch_addr))
+
+    return out
+
+
+def _check_loop_stack(program: OTBNProgram,
+                      loops: List[CodeSection]) -> CheckResult:
+    '''Checks that loops will likely be properly cleared from loop stack.
+
+    The checks here are based on the OTBN hardware IP documentation on loop
+    nesting. From the docs:
+
+        To avoid polluting the loop stack and avoid surprising behaviour, the
+        programmer must ensure that:
+
+        * Even if there are branches and jumps within a loop body, the final
+          instruction of the loop body gets executed exactly once per iteration.
+        * Nested loops have distinct end addresses.
+        * The end instruction of an outer loop is not executed before an inner loop
+          finishes.
+
+    In order to avoid simulating the control flow of the entire program to
+    check the first and third conditions, this check takes a conservative,
+    simplistic approach and simply warns about all branching into or out of
+    loop bodies, including jumps that don't use the call stack (e.g. `jal x0,
+    <addr>`). Branching to locations within the same loop body is permitted.
+
+    The second condition in the list, distinct end addresses, is checked
+    separately.
+    '''
+    out = CheckResult()
+    out += _check_loop_branching(program, loops)
+
+    # Check that loops have unique end addresses
+    end_addrs = []
+    for loop in loops:
+        if loop.end in end_addrs:
+            out.err(
+                'Loop starting at PC {:#x} shares a final instruction with '
+                'another loop; consider adding a NOP instruction.'.format(
+                    loop.start))
+
+    return out
+
+
+def check_loop(program: OTBNProgram) -> CheckResult:
+    '''Check that loops are properly formed.
+
+    Performs three checks to rule out certain classes of loop errors and
+    undefined behavior:
+    1. For loopi instructions, check that the number of iterations is > 0.
+    2. Ensure that loops do not end in control-flow instructions such as jal or
+       bne, which will raise LOOP errors.
+    3. Checks that there is no branching into or out of loop bodies.
+    4. For nested loops, the inner loop is completely contained within the
+       outer loop.
+    '''
+    loops = _get_loops(program)
+    out = CheckResult()
+    out += _check_loop_iterations(program, loops)
+    out += _check_loop_end_insns(program, loops)
+    out += _check_loop_stack(program, loops)
+    out += _check_loop_inclusion(program, loops)
+    out.set_prefix('check_loop: ')
+    return out
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('elf', help=('The .elf file to check.'))
+    parser.add_argument('-v', '--verbose', action='store_true')
+    args = parser.parse_args()
+    program = decode_elf(args.elf)
+    result = check_loop(program)
+    if args.verbose or result.has_errors() or result.has_warnings():
+        print(result.report())
+    return 1 if result.has_errors() else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/hw/ip/otbn/util/shared/check.py b/hw/ip/otbn/util/shared/check.py
new file mode 100644
index 0000000..315b8be
--- /dev/null
+++ b/hw/ip/otbn/util/shared/check.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+
+class CheckResult:
+    '''A class to record the results of static checks.
+
+    Can record any number of errors and warnings. Combine two check results
+    with +, e.g.:
+
+        out = CheckResult()
+        out += first_check()
+        out += second_check()
+        out.warn('A warning')
+        print(out.report()) # prints warnings/errors from both checks and "A warning"
+    '''
+    def __init__(self):
+        self.errors = []
+        self.warnings = []
+        self.prefix = ''
+
+    def warn(self, msg):
+        '''Add a warning.'''
+        self.warnings.append(msg)
+
+    def err(self, msg):
+        '''Add an error.'''
+        self.errors.append(msg)
+
+    def __add__(self, other):
+        '''Combines both operands' errors/warnings in a new CheckResult.'''
+        if not isinstance(other, CheckResult):
+            raise ValueError(
+                'Cannot add {} (of type {}) to {} (of type CheckResult)'.
+                format(other, type(other), self))
+        out = CheckResult()
+        out.warnings = self.warnings + other.warnings
+        out.errors = self.errors + other.errors
+        return out
+
+    def set_prefix(self, prefix):
+        '''Add a prefix to the printouts for this check.'''
+        self.prefix = prefix
+
+    def has_errors(self):
+        return len(self.errors) != 0
+
+    def has_warnings(self):
+        return len(self.warnings) != 0
+
+    def report(self):
+        '''Show a message to represent the results of the check.'''
+        if not self.has_warnings() and not self.has_errors():
+            return '{}PASS'.format(self.prefix)
+        warn_strs = [
+            '{}WARN: {}'.format(self.prefix, w) for w in self.warnings
+        ]
+        err_strs = ['{}ERROR: {}'.format(self.prefix, e) for e in self.errors]
+        return '\n'.join(warn_strs + err_strs)
diff --git a/hw/ip/otbn/util/shared/decode.py b/hw/ip/otbn/util/shared/decode.py
index 3826551..fc2b82a 100644
--- a/hw/ip/otbn/util/shared/decode.py
+++ b/hw/ip/otbn/util/shared/decode.py
@@ -3,8 +3,8 @@
 # Licensed under the Apache License, Version 2.0, see LICENSE for details.
 # SPDX-License-Identifier: Apache-2.0
 
-import sys
 import struct
+import sys
 from typing import Dict, Tuple
 
 from shared.elf import read_elf
@@ -25,11 +25,11 @@
         self.data = data  # addr -> data (32b word)
 
         self.insns = {}
-        for pc in insns.keys():
-            opcode = insns[pc]
+        for pc, opcode in insns.items():
             mnem = INSNS_FILE.mnem_for_word(opcode)
             if mnem is None:
-                raise ValueError('No legal decoding for mnemonic: {}'.format(mnem))
+                raise ValueError(
+                    'No legal decoding for mnemonic: {}'.format(mnem))
             insn = INSNS_FILE.mnemonic_to_insn.get(mnem)
             enc_vals = insn.encoding.extract_operands(opcode)
             op_vals = insn.enc_vals_to_op_vals(pc, enc_vals)