[otbn] A disassembly wrapper around riscv32-unknown-elf-objdump

This doesn't support some of the clever formatting flags that objdump
has (mainly because I've never yet used them), but it does support the
standard syntax. Output looks like this:

  00000000 <d0inv>:
     0:   8000702b                bn.xor  w0, w0, w0
     4:   0010402b                bn.addi w0, w0, 1
     8:   00006e8b                bn.mov  w29, w0
     c:   00d4107b                loopi   256, 13
    10:   01de103b                bn.mulqacc.Z w28.0, w29.0, 0
    14:   03de203b                bn.mulqacc w28.1, w29.0, 1
    18:   89de20bb                bn.mulqacc.so w1.L, w28.0, w29.1, 1
    1c:   05de003b                bn.mulqacc w28.2, w29.0, 0
    20:   0bde003b                bn.mulqacc w28.1, w29.1, 0
    24:   11de003b                bn.mulqacc w28.0, w29.2, 0
    28:   07de203b                bn.mulqacc w28.3, w29.0, 1
    2c:   0dde203b                bn.mulqacc w28.2, w29.1, 1
    30:   13de203b                bn.mulqacc w28.1, w29.2, 1
    34:   b9de20bb                bn.mulqacc.so w1.U, w28.0, w29.3, 1
    38:   0000e0ab                bn.and  w1, w1, w0
    3c:   801eeeab                bn.or   w29, w29, w1
    40:   0000002b                bn.add  w0, w0, w0
    44:   01df9eab                bn.sub  w29, w31, w29
    48:   00008067                ret

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/Makefile b/hw/ip/otbn/util/Makefile
index f5078c2..d691180 100644
--- a/hw/ip/otbn/util/Makefile
+++ b/hw/ip/otbn/util/Makefile
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 pylibs := insn_yaml.py
-pyscripts := yaml_to_doc.py otbn-as otbn-ld
+pyscripts := yaml_to_doc.py otbn-as otbn-ld otbn-objdump
 
 .PHONY: all
 all: lint
diff --git a/hw/ip/otbn/util/insn_yaml.py b/hw/ip/otbn/util/insn_yaml.py
index a0d1b06..a8b02c4 100644
--- a/hw/ip/otbn/util/insn_yaml.py
+++ b/hw/ip/otbn/util/insn_yaml.py
@@ -238,6 +238,17 @@
         assert bits_taken == self.width
         return ret
 
+    def decode(self, raw: int) -> int:
+        '''Collect the bits of raw picked out by self.ranges
+
+        Ranges are taken in order, each slice going below the bits so far.
+        '''
+        acc = 0
+        for msb, lsb in self.ranges:
+            nbits = msb - lsb + 1
+            acc = (acc << nbits) | ((raw >> lsb) & ((1 << nbits) - 1))
+        return acc
+
 
 class BoolLiteral:
     '''Represents a boolean literal, with possible 'x characters
@@ -681,6 +692,15 @@
         '''
         return None
 
+    def render_val(self, value: int) -> str:
+        '''Return the text used for this value in disassembly.
+
+        This base version writes the number in decimal. Subclasses override
+        it where needed: a register operand, say, wants 3 shown as "x3".
+
+        '''
+        return '{}'.format(value)
+
 
 class RegOperandType(OperandType):
     '''A class representing a register operand type'''
@@ -720,6 +740,16 @@
 
         return idx
 
+    def render_val(self, value: int) -> str:
+        '''Render a register index, with the prefix for its register type'''
+        type_fmt = RegOperandType.TYPE_FMTS.get(self.reg_type)
+        assert type_fmt is not None
+        _, prefix = type_fmt
+
+        if prefix is not None:
+            return '{}{}'.format(prefix, value)
+        return super().render_val(value)
+
 
 class ImmOperandType(OperandType):
     '''A class representing an immediate operand type'''
@@ -768,6 +798,18 @@
                          'Supported values: {}.'
                          .format(as_str, known_vals))
 
+    def render_val(self, value: int) -> str:
+        # A value outside the list of items can't be named, but we still
+        # have to return *something*. This only feeds disassembly, so give
+        # back a string that is obviously bogus rather than failing.
+        #
+        # Good tools can still hit this on a bad binary whenever the number
+        # of items in the enum isn't a power of 2.
+        if 0 <= value < len(self.items):
+            return self.items[value]
+
+        return '???'
+
 
 class OptionOperandType(ImmOperandType):
     '''A class representing an option operand type'''
@@ -790,6 +832,11 @@
                          'If specified, it should have been {!r}.'
                          .format(as_str, self.option))
 
+    def render_val(self, value: int) -> str:
+        # A 1-bit field: when set the option text appears, when clear it doesn't.
+        assert value in (0, 1)
+        return '' if not value else self.option
+
 
 def parse_operand_type(fmt: str) -> OperandType:
     '''Make sense of the operand type syntax'''
@@ -953,6 +1000,20 @@
         # parsing.
         return r'([^ ,+\-]+|[+\-]+)\s*'
 
+    def render_vals(self,
+                    op_vals: Dict[str, int],
+                    operands: Dict[str, Operand]) -> str:
+        '''Return the assembly text for this token
+
+        A literal token is its own text. Otherwise the token names an
+        operand, whose value is rendered by that operand's type.
+        '''
+        if not self.is_literal:
+            name = self.text
+            assert name in op_vals and name in operands
+            return operands[name].op_type.render_val(op_vals[name])
+        return self.text
+
 
 class SyntaxHunk:
     '''An object representing a hunk of syntax that might be optional'''
@@ -1047,6 +1108,28 @@
         # (one-or-more) to it.
         return '(?:{})?'.format(body) if self.is_optional else body
 
+    def render_vals(self,
+                    op_vals: Dict[str, int],
+                    operands: Dict[str, Operand]) -> str:
+        '''Return the assembly text for this hunk given operand values
+
+        An optional hunk is left out entirely (giving the empty string) when
+        every operand it mentions has value zero. Otherwise the result is
+        the text of each token in the hunk, joined together.
+
+        '''
+        if self.is_optional:
+            # The hunk only needs printing if some operand is nonzero.
+            wanted = any(op_vals[name] != 0 for name in self.op_list)
+            if not wanted:
+                return ''
+
+        pieces = []
+        for token in self.tokens:
+            pieces.append(token.render_vals(op_vals, operands))
+
+        return ''.join(pieces)
+
 
 class InsnSyntax:
     '''A class representing the syntax of an instruction
@@ -1185,6 +1268,15 @@
 
         return (pattern, op_to_grp)
 
+    def render_vals(self,
+                    op_vals: Dict[str, int],
+                    operands: Dict[str, Operand]) -> str:
+        '''Render every hunk of the syntax in turn and concatenate them
+
+        '''
+        return ''.join(hunk.render_vals(op_vals, operands)
+                       for hunk in self.hunks)
+
 
 class EncodingField:
     '''A single element of an encoding's mapping'''
diff --git a/hw/ip/otbn/util/otbn-objdump b/hw/ip/otbn/util/otbn-objdump
new file mode 100755
index 0000000..49d326d
--- /dev/null
+++ b/hw/ip/otbn/util/otbn-objdump
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+'''A wrapper around riscv32-unknown-elf-objdump for OTBN'''
+
+import os
+import re
+import subprocess
+import sys
+from typing import Dict, List, Optional, Tuple
+
+from insn_yaml import Encoding, Insn, InsnsFile, load_file
+
+
+def snoop_disasm_flags(argv: List[str]) -> bool:
+    '''Return true if argv contains one of objdump's disassembly flags
+
+    These are -d, -D, --disassemble and --disassemble-all, together with
+    the --disassemble=symbol form.
+
+    '''
+    plain_flags = ('-d', '-D', '--disassemble', '--disassemble-all')
+    return any(arg in plain_flags or
+               arg.startswith('--disassemble=')
+               for arg in argv)
+
+
+def get_insn(raw: int, masks: List[Tuple[int, int, Insn]]) -> Optional[Insn]:
+    '''Find the instruction whose encoding matches this raw value
+
+    masks is a list of tuples (m0, m1, insn) as returned by get_insn_masks,
+    where m0 holds the bits that must be zero and m1 the bits that must be
+    one. Returns None if no tuple matches.
+
+    '''
+    match = None
+    for must_be_0, must_be_1, candidate in masks:
+        # A candidate fits only if no must-be-zero bit is set and no
+        # must-be-one bit is clear.
+        fits = not (raw & must_be_0) and not (~raw & must_be_1)
+        if not fits:
+            continue
+
+        # The code in insn_yaml should already have checked that there is
+        # only one match, but it can't hurt to be careful.
+        assert match is None
+        match = candidate
+    return match
+
+
+def extract_operands(raw: int, encoding: Encoding) -> Dict[str, int]:
+    '''Decode every operand field of the encoded instruction
+
+    Returns a dictionary from operand name to decoded value. A field of
+    the encoding belongs to an operand when its value is a string (the
+    operand's name); fixed fields have some other value and are skipped.
+
+    '''
+    operand_fields = [field for field in encoding.fields.values()
+                      if isinstance(field.value, str)]
+    return {field.value: field.scheme_field.bits.decode(raw)
+            for field in operand_fields}
+
+
+# OTBN instructions are 32 bits wide, so there's just one "word" in the second
+# column. The stuff that gets passed through looks like this:
+#
+#    84:   8006640b                0x8006640b
+#
+# We don't use a back-ref for the second copy of the data, because if the raw
+# part has leading zeros, they don't appear there. For example:
+#
+#   6d0:   0000418b                0x418b
+#
+_RAW_INSN_RE = re.compile(r'([\s]*[0-9a-f]+:[\s]+([0-9a-f]{8})[\s]+)'
+                          r'0x[0-9a-f]+\s*$')
+
+
+def transform_disasm_line(line: str,
+                          masks: List[Tuple[int, int, Insn]]) -> str:
+    '''Replace the raw word on a disassembly line with OTBN assembly
+
+    Lines that don't hold a known raw instruction are returned as-is.
+    '''
+    hit = _RAW_INSN_RE.match(line)
+    if hit is None:
+        return line
+
+    # Group 2 is exactly 8 hex characters, so its value fits in a u32.
+    word = int(hit.group(2), 16)
+    assert 0 <= word < (1 << 32)
+
+    insn = get_insn(word, masks)
+    if insn is None:
+        # No match for this instruction pattern. Leave as-is.
+        return line
+
+    # get_insn_masks only adds instructions that have both an encoding and
+    # a syntax, so neither of the asserts below should fire.
+    assert insn.encoding is not None
+    op_vals = extract_operands(word, insn.encoding)
+    assert insn.syntax is not None
+
+    # The syntax object renders the fields; glued ops get no separator.
+    ops = insn.syntax.render_vals(op_vals, insn.name_to_operand)
+    sep = '' if insn.glued_ops else ' '
+    return '{}{:7}{}{}'.format(hit.group(1), insn.mnemonic, sep, ops)
+
+
+def get_insn_masks(insns_file: InsnsFile) -> List[Tuple[int, int, Insn]]:
+    '''Build the zeros/ones masks used to recognise known instructions
+
+    Each element of the result is (m0, m1, insn). No check is made here
+    that the masks are unambiguous: insn_yaml is supposed to have done
+    that already, and there is a belt-and-braces check for each
+    instruction when the masks get used.
+
+    '''
+    masks = []
+    for insn in insns_file.insns:
+        # Instructions lacking an encoding or a syntax are left out.
+        if insn.encoding is not None and insn.syntax is not None:
+            zeros, ones = insn.encoding.get_masks()
+            # Encoding.get_masks sets bits that are 'x', so take the
+            # difference each way to clear them.
+            masks.append((zeros & ~ones, ones & ~zeros, insn))
+
+    return masks
+
+
+def main() -> int:
+    '''Run objdump on our arguments, rewriting OTBN disassembly lines'''
+    args = sys.argv[1:]
+    objdump_name = 'riscv32-unknown-elf-objdump'
+    cmd = [objdump_name] + args
+    try:
+        if not snoop_disasm_flags(args):
+            return subprocess.run(cmd).returncode
+
+        # Capture stdout only: objdump's stderr stays connected to ours, so
+        # its warnings and error messages still reach the user.
+        proc = subprocess.run(cmd, stdout=subprocess.PIPE,
+                              universal_newlines=True)
+    except FileNotFoundError:
+        sys.stderr.write('Unknown command: {!r}. '
+                         '(is it installed and on your PATH?)\n'
+                         .format(objdump_name))
+        return 127
+
+    if proc.returncode:
+        # Dump any lines that objdump wrote before it died
+        sys.stdout.write(proc.stdout)
+        return proc.returncode
+
+    insns_yml = os.path.normpath(os.path.join(os.path.dirname(__file__),
+                                              '..', 'data', 'insns.yml'))
+    try:
+        insns_file = load_file(insns_yml)
+    except RuntimeError as err:
+        sys.stderr.write('{}\n'.format(err))
+        return 1
+    insn_masks = get_insn_masks(insns_file)
+
+    # Skip the empty piece after the final newline: it'd add a blank line.
+    lines = proc.stdout.split('\n')
+    for line in (lines[:-1] if lines[-1] == '' else lines):
+        sys.stdout.write(transform_disasm_line(line, insn_masks) + '\n')
+    return 0
+
+
+if __name__ == '__main__':
+    sys.exit(main())