[otbn] Add some OTBN instructions to otbnsim

Add the variant, model and basic instruction support.

We also dump an execution trace to a file (/tmp/otbn_XXXXX/trace) and
print out the name of the trace directory, which should make debugging
rather easier.

As a sanity check, run it with the smoke test example as follows:

    mkdir -p build-bin/otbn/smoke_test
    hw/ip/otbn/util/build.sh \
      hw/ip/otbn/dv/smoke/smoke_test.S \
      build-bin/otbn/smoke_test/smoke
    hw/ip/otbn/dv/otbnsim/standalone.py -v \
      build-bin/otbn/smoke_test/smoke.elf

The trace looks like this:

    lui x1, 855019                      | [x1 = d0beb000]
    addi x1, x1, 1299                   | [x1 = d0beb513]
    lui x2, 658409                      | [x2 = a0be9000]
    addi x2, x2, 282                    | [x2 = a0be911a]
    add x3, x1, x2                      | [x3 = 717d462d, x1 = 00000000]
    sub x4, x2, x1                      | [x4 = a0be911a]
    or x5, x1, x2                       | [x5 = a0be911a]
    and x6, x1, x2                      | [x6 = 00000000]
    xor x7, x1, x2                      | [x7 = a0be911a]
    ori x8, x1, 291                     | [x8 = 00000123]
    andi x9, x1, 1980                   | [x9 = 00000000]
    xori x10, x1, 1146                  | [x10 = 0000047a]
    slli x11, x1, 0x0a                  | [x11 = 00000000]
    srli x12, x1, 0x0d                  | [x12 = 00000000]
    srai x13, x1, 0x07                  | [x13 = 00000000]
    sll x14, x1, x2                     | [x14 = 00000000]
    srl x15, x1, x2                     | [x15 = 00000000]
    sra x16, x1, x2                     | [x16 = 00000000]

Notice the updates to x1: the code is using it just like a normal
register, but we know better(!) and are treating it as a hardware call
stack.

This also behaves "as expected" when running the loop.S code snippet.
Generate that with

   make -C hw/ip/otbn/util asm-sanity
   hw/ip/otbn/dv/otbnsim/standalone.py -v \
     build-bin/otbn/util/code-snippets/loop

Finally, note that we bump the required version of riscv-model. This
is because there's an API break between 0.6.2 and 0.6.4 (the module
exporting TerminateException changes from riscvmodel.isa to
riscvmodel.model).

Signed-off-by: Stefan Wallentowitz <stefan.wallentowitz@gi-de.com>
Co-authored-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/model/otbn_model.cc b/hw/ip/otbn/dv/model/otbn_model.cc
index 244c6df..d15a7fc 100644
--- a/hw/ip/otbn/dv/model/otbn_model.cc
+++ b/hw/ip/otbn/dv/model/otbn_model.cc
@@ -232,6 +232,7 @@
   char ifname[] = "/tmp/otbn_XXXXXX/imem";
   char dfname[] = "/tmp/otbn_XXXXXX/dmem";
   char cfname[] = "/tmp/otbn_XXXXXX/cycles";
+  char tfname[] = "/tmp/otbn_XXXXXX/trace";
 
   if (mkdtemp(dir) == nullptr) {
     std::cerr << "Cannot create temporary directory" << std::endl;
@@ -241,6 +242,7 @@
   std::memcpy(ifname, dir, strlen(dir));
   std::memcpy(dfname, dir, strlen(dir));
   std::memcpy(cfname, dir, strlen(dir));
+  std::memcpy(tfname, dir, strlen(dir));
 
   try {
     dump_memory(dfname, dmem_scope, dmem_words, 32);
@@ -260,7 +262,7 @@
 
   std::ostringstream cmd;
   cmd << model_path << " " << imem_words << " " << ifname << " " << dmem_words
-      << " " << dfname << " " << cfname;
+      << " " << dfname << " " << cfname << " " << tfname;
 
   if (std::system(cmd.str().c_str()) != 0) {
     std::cerr << "Failed to execute model (cmd was: '" << cmd.str() << "').\n";
diff --git a/hw/ip/otbn/dv/otbnsim/Makefile b/hw/ip/otbn/dv/otbnsim/Makefile
index da59b61..80d6600 100644
--- a/hw/ip/otbn/dv/otbnsim/Makefile
+++ b/hw/ip/otbn/dv/otbnsim/Makefile
@@ -14,12 +14,12 @@
 	mkdir -p $@
 
 py-scripts := otbnsim.py standalone.py
-py-files   := $(wildcard *.py otbnsim/*.py)
+py-files   := $(wildcard *.py sim/*.py)
 py-libs    := $(filter-out $(py-scripts),$(py-files))
 
 lint-stamps := $(foreach scr,$(py-scripts),$(build-dir)/$(scr).stamp)
 
-$(lint-stamps): $(build-dir)/%.stamp: % $(py-libs)
+$(lint-stamps): $(build-dir)/%.stamp: % $(py-libs) | $(build-dir)
 	env MYPYPATH="$$MYPYPATH:../../util" mypy --strict $< $(py-libs)
 	touch $@
 
diff --git a/hw/ip/otbn/dv/otbnsim/otbnsim.py b/hw/ip/otbn/dv/otbnsim/otbnsim.py
index ea20fd0..6960ba0 100755
--- a/hw/ip/otbn/dv/otbnsim/otbnsim.py
+++ b/hw/ip/otbn/dv/otbnsim/otbnsim.py
@@ -10,10 +10,10 @@
 import sys
 
 from riscvmodel.sim import Simulator  # type: ignore
-from riscvmodel.model import Model  # type: ignore
-from riscvmodel.variant import RV32I  # type: ignore
 
 from sim.decode import decode_file
+from sim.model import OTBNModel
+from sim.variant import RV32IXotbn
 
 
 def main() -> int:
@@ -23,11 +23,12 @@
     parser.add_argument("dmem_words", type=int)
     parser.add_argument("dmem_file")
     parser.add_argument("cycles_file")
+    parser.add_argument("trace_file")
 
     args = parser.parse_args()
-    sim = Simulator(Model(RV32I))
+    sim = Simulator(OTBNModel(verbose=args.trace_file))
 
-    sim.load_program(decode_file(args.imem_file, RV32I))
+    sim.load_program(decode_file(args.imem_file, RV32IXotbn))
     with open(args.dmem_file, "rb") as f:
         sim.load_data(f.read())
 
diff --git a/hw/ip/otbn/dv/otbnsim/sim/decode.py b/hw/ip/otbn/dv/otbnsim/sim/decode.py
index bb4ce3c..0049ff3 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/decode.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/decode.py
@@ -16,6 +16,12 @@
 from riscvmodel.model import Model  # type: ignore
 from riscvmodel.variant import Variant, RV32I  # type: ignore
 
+# The riscvmodel decoder works by introspection, checking all the instruction
+# classes that have been defined so far. This implicit approach only works if
+# we make absolutely sure that we *have* loaded the instruction definitions we
+# use. So we do this useless import to ensure it.
+from .insn import InstructionLOOP  # noqa: F401
+
 # A subclass of Instruction
 _InsnSubclass = TypeVar('_InsnSubclass', bound=Instruction)
 
diff --git a/hw/ip/otbn/dv/otbnsim/sim/elf.py b/hw/ip/otbn/dv/otbnsim/sim/elf.py
index 69b6c72..ff45542 100644
--- a/hw/ip/otbn/dv/otbnsim/sim/elf.py
+++ b/hw/ip/otbn/dv/otbnsim/sim/elf.py
@@ -14,6 +14,7 @@
 from shared.mem_layout import get_memory_layout
 
 from .decode import decode_bytes
+from .variant import RV32IXotbn
 
 _SegList = List[Tuple[int, bytes]]
 
@@ -135,7 +136,7 @@
                            'not a multiple of 4.'
                            .format(path, len(imem_bytes)))
 
-    imem_insns = decode_bytes(imem_bytes, RV32I)
+    imem_insns = decode_bytes(imem_bytes, RV32IXotbn)
 
     sim.load_program(imem_insns)
     sim.load_data(dmem_bytes)
diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py
new file mode 100644
index 0000000..634c781
--- /dev/null
+++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py
@@ -0,0 +1,406 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from riscvmodel.insn import isa  # type: ignore
+
+from .isa import (InstructionBNAFType,
+                  InstructionBNAIType,
+                  InstructionBNAMType,
+                  InstructionBNANType,
+                  InstructionBNAQType,
+                  InstructionBNAType,
+                  InstructionBNCSType,
+                  InstructionBNCType,
+                  InstructionBNISType,
+                  InstructionBNIType,
+                  InstructionBNMVRType,
+                  InstructionBNMVType,
+                  InstructionBNRType,
+                  InstructionBNSType,
+                  InstructionLIType,
+                  InstructionLType,
+                  ShiftReg)
+from .model import OTBNModel
+from .variant import RV32IXotbn
+
+
+@isa("loop", RV32IXotbn, opcode=0b1111011, funct2=0b00)
+class InstructionLOOP(InstructionLType):
+    """
+    Loop (indirect)
+
+    Repeat a sequence of code multiple times. The number of iterations is a GPR
+    value. The length of the loop is given as immediate.
+
+    Alternative assembly notation: The size of the loop body is given by the
+    number of instructions in the parentheses.
+
+    LOOP <grs> (
+      # loop body
+    )
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.rs1 is not None
+        model.state.loop_start(int(model.state.intreg[self.rs1]),
+                               int(self.bodysize))
+
+
+@isa("loopi", RV32IXotbn, opcode=0b1111011, funct2=0b01)
+class InstructionLOOPI(InstructionLIType):
+    """
+    Loop Immediate
+
+    Repeat a sequence of code multiple times. The number of iterations is given
+    as an immediate, as is the length of the loop. The number of iterations must
+    be larger than zero.
+
+    Alternative assembly notation. The size of the loop body is given by the
+    number of instructions in the parentheses.
+
+    LOOPI <iterations> (
+      # loop body
+    )
+
+    """
+    def execute(self, model: OTBNModel) -> None:
+        model.state.loop_start(int(self.iter), int(self.bodysize))
+
+
+@isa("bn.add", RV32IXotbn, opcode=0b0101011, funct3=0b000)
+class InstructionBNADD(InstructionBNAFType):
+    """
+    Add
+
+    Adds two WDR values, writes the result to the destination WDR and updates
+    flags. The content of the second source WDR can be shifted by an immediate
+    before it is consumed by the operation.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1].unsigned())
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2].unsigned()),
+                             self.shift_type, self.shift_bytes)
+        (result, flags) = model.add_with_carry(a, b_shifted, 0)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.addc", RV32IXotbn, opcode=0b0101011, funct3=0b010)
+class InstructionBNADDC(InstructionBNAFType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1].unsigned())
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2].unsigned()),
+                             self.shift_type, self.shift_bytes)
+        (result, flags) = model.add_with_carry(a, b_shifted,
+                                               model.state.flags[self.fg].C)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.addi", RV32IXotbn, opcode=0b0101011, funct3=0b100, funct30=0)
+class InstructionBNADDI(InstructionBNAIType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1].unsigned())
+        b = int(self.imm)
+        (result, flags) = model.add_with_carry(a, b, 0)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.addm", RV32IXotbn, opcode=0b0101011, funct3=0b101, funct30=0)
+class InstructionBNADDM(InstructionBNAMType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1].unsigned())
+        b = int(model.state.wreg[self.wrs2].unsigned())
+        (result, _) = model.add_with_carry(a, b, 0)
+        if result >= int(model.state.mod):
+            result -= int(model.state.mod)
+        model.state.wreg[self.wrd] = result
+
+
+@isa("bn.mulqacc", RV32IXotbn, opcode=0b0111011)
+class InstructionBNMULQACC(InstructionBNAQType):
+    """
+    Quarter-word Multiply and Accumulate
+
+    Multiplies two WLEN/4 WDR values and adds the result to an accumulator after
+    shifting it. Optionally shifts some/all of the resulting accumulator value
+    out to a destination WDR.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.wrs1 is not None
+        assert self.wrs2 is not None
+        assert self.wrs1_qwsel is not None
+        assert self.wrs2_qwsel is not None
+        assert self.acc_shift_imm is not None
+
+        a_qw = model.get_wr_quarterword(self.wrs1, self.wrs1_qwsel)
+        b_qw = model.get_wr_quarterword(self.wrs2, self.wrs2_qwsel)
+
+        mul_res = a_qw * b_qw
+
+        acc = int(model.state.single_regs['acc'])
+
+        if (self.zero_acc):
+            acc = 0
+
+        acc += (mul_res << (self.acc_shift_imm * 64))
+
+        if self.wb_variant > 0:
+            if self.wb_variant == 1:
+                model.set_wr_halfword(self.wrd, acc, self.wrd_hwsel)
+                acc = acc >> 128
+            elif self.wb_variant == 2:
+                model.state.wreg[self.wrd].set(acc)
+
+        model.state.single_regs['acc'].update(acc)
+
+
+@isa("bn.sub", RV32IXotbn, opcode=0b0101011, funct3=0b001)
+class InstructionBNSUB(InstructionBNAFType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2]), self.shift_type,
+                             self.shift_bytes)
+        (result, flags) = model.add_with_carry(a, -b_shifted, 0)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.subb", RV32IXotbn, opcode=0b0101011, funct3=0b011)
+class InstructionBNSUBB(InstructionBNAFType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2]), self.shift_type,
+                             self.shift_bytes)
+        (result,
+         flags) = model.add_with_carry(a, -b_shifted,
+                                       1 - model.state.flags[self.fg].C)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.subi", RV32IXotbn, opcode=0b0101011, funct3=0b100, funct30=1)
+class InstructionBNSUBI(InstructionBNAIType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b = int(self.imm)
+        (result, flags) = model.add_with_carry(a, -b, 0)
+        model.state.wreg[self.wrd] = result
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.subm", RV32IXotbn, opcode=0b0101011, funct3=0b101, funct30=1)
+class InstructionBNSUBM(InstructionBNAMType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b = int(model.state.wreg[self.wrs2])
+        result, _ = model.add_with_carry(a, -b, 0)
+        if result >= model.state.mod:
+            result -= model.state.mod
+        model.state.wreg[self.wrd] = result
+
+
+@isa("bn.and", RV32IXotbn, opcode=0b0101011, funct3=0b110, funct31=0)
+class InstructionBNAND(InstructionBNAType):
+    """
+    Bitwise AND
+
+    Performs a bitwise and operation. Takes the values stored in registers
+    referenced by wrs1 and wrs2 and stores the result in the register referenced
+    by wrd. The content of the second source register can be shifted by an
+    immediate before it is consumed by the operation.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.shift_type is not None
+
+        b_shifted = ShiftReg(model.state.wreg[self.wrs2],
+                             self.shift_type, self.shift_bytes)
+        a = model.state.wreg[self.wrs1]
+        model.state.wreg[self.wrd] = a & b_shifted
+
+
+@isa("bn.or", RV32IXotbn, opcode=0b0101011, funct3=0b110, funct31=1)
+class InstructionBNOR(InstructionBNAType):
+    """
+    Bitwise OR
+
+    Performs a bitwise or operation. Takes the values stored in WDRs referenced
+    by wrs1 and wrs2 and stores the result in the WDR referenced by wrd. The
+    content of the second source WDR can be shifted by an immediate before it is
+    consumed by the operation.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.shift_type is not None
+
+        b_shifted = ShiftReg(model.state.wreg[self.wrs2],
+                             self.shift_type, self.shift_bytes)
+        a = model.state.wreg[self.wrs1]
+        model.state.wreg[self.wrd] = a | b_shifted
+
+
+@isa("bn.not", RV32IXotbn, opcode=0b0101011, funct3=0b111, funct31=0)
+class InstructionBNNOT(InstructionBNANType):
+    """
+    Bitwise NOT
+
+    Negates the value in <wrs>, storing the result into <wrd>. The source value
+    can be shifted by an immediate before it is consumed by the operation.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        b_shifted = model.state.wreg[self.wrs1]
+        model.state.wreg[self.wrd] = ~b_shifted
+
+
+@isa("bn.xor", RV32IXotbn, opcode=0b0101011, funct3=0b111, funct31=1)
+class InstructionBNXOR(InstructionBNAType):
+    """
+    Bitwise XOR.
+
+    Performs a bitwise xor operation. Takes the values stored in WDRs referenced
+    by wrs1 and wrs2 and stores the result in the WDR referenced by wrd. The
+    content of the second source WDR can be shifted by an immediate before it is
+    consumed by the operation.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        b_shifted = model.state.wreg[self.wrs2]
+        a = model.state.wreg[self.wrs1]
+        model.state.wreg[self.wrd] = a ^ b_shifted
+
+
+@isa("bn.rshi", RV32IXotbn, opcode=0b1111011, funct2=0b11)
+class InstructionBNRSHI(InstructionBNRType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b = int(model.state.wreg[self.wrs2])
+        shift_bit = int(self.imm)
+        model.state.wreg[self.wrd] = (((a << 256) | b) >> shift_bit) & (
+            (1 << 256) - 1)
+
+
+@isa("bn.sel", RV32IXotbn, opcode=0b0001011, funct3=0b000)
+class InstructionBNSEL(InstructionBNSType):
+    def execute(self, model: OTBNModel) -> None:
+        # self.flag gives a number (0-3), which we need to convert to a flag
+        # name for use with BitflagRegister.
+        assert self.flag is not None
+        assert 0 <= self.flag <= 3
+        flag_name = ['C', 'L', 'M', 'Z'][self.flag]
+
+        flag_is_set = model.state.flags[self.fg].get(flag_name)
+        val = model.state.wreg[self.wrs1 if flag_is_set else self.wrs2]
+        model.state.wreg[self.wrd] = val
+
+
+@isa("bn.cmp", RV32IXotbn, opcode=0b0001011, funct3=0b001)
+class InstructionBNCMP(InstructionBNCType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2]), self.shift_type,
+                             self.shift_bytes)
+        (_, flags) = model.add_with_carry(a, -b_shifted, 0)
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.cmpb", RV32IXotbn, opcode=0b0001011, funct3=0b011)
+class InstructionBNCMPB(InstructionBNCType):
+    def execute(self, model: OTBNModel) -> None:
+        a = int(model.state.wreg[self.wrs1])
+        b_shifted = ShiftReg(int(model.state.wreg[self.wrs2]), self.shift_type,
+                             self.shift_bytes)
+        (_, flags) = model.add_with_carry(a, -b_shifted,
+                                          1 - model.state.flags[self.fg].C)
+        model.state.flags[self.fg] = flags
+
+
+@isa("bn.lid", RV32IXotbn, opcode=0b0001011, funct3=0b100)
+class InstructionBNLID(InstructionBNIType):
+    """
+    Load Word (indirect source, indirect destination)
+
+    Calculates a byte memory address by adding the offset to the value in the
+    GPR grs1. The value from this memory address is then copied into the WDR
+    pointed to by the value in GPR grd.
+
+    After the operation, either the value in the GPR grs1, or the value in grd
+    can be optionally incremented.
+
+    If grs1_inc is set, the value in grs1 is incremented by the value WLEN/8
+    (one word). If grd_inc is set, the value in grd is incremented by the value
+    1.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.rs is not None
+        assert self.rd is not None
+
+        addr = int(model.state.intreg[self.rs] + int(self.imm) * 32)
+        wrd = int(model.state.intreg[self.rd])
+        word = model.load_wlen_word_from_memory(addr)
+        model.state.wreg[wrd] = word
+
+
+@isa("bn.sid", RV32IXotbn, opcode=0b0001011, funct3=0b101)
+class InstructionBNSID(InstructionBNISType):
+    """
+    Store Word (indirect source, indirect destination)
+
+    Calculates a byte memory address by adding the offset to the value in the
+    GPR grs1. The value from the WDR pointed to by grs2 is then copied into the
+    memory.
+
+    After the operation, either the value in the GPR grs1, or the value in grs2
+    can be optionally incremented.
+
+    If grs1_inc is set, the value in grs1 is incremented by the value WLEN/8
+    (one word). If grs2_inc is set, the value in grs2 is incremented by the
+    value 1.
+    """
+    def execute(self, model: OTBNModel) -> None:
+        assert self.rs2 is not None
+        assert self.rs1 is not None
+
+        addr = int(model.state.intreg[self.rs2] + int(self.imm) * 32)
+        wrs = int(model.state.intreg[self.rs1])
+        word = int(model.state.wreg[wrs])
+        model.store_wlen_word_to_memory(addr, word)
+
+
+@isa("bn.mov", RV32IXotbn, opcode=0b0001011, funct3=0b110, funct31=0)
+class InstructionBNMOV(InstructionBNMVType):
+    def execute(self, model: OTBNModel) -> None:
+        model.state.wreg[self.wrd] = model.state.wreg[self.wrs]
+
+
+@isa("bn.movr", RV32IXotbn, opcode=0b0001011, funct3=0b110, funct31=1)
+class InstructionBNMOVR(InstructionBNMVRType):
+    def execute(self, model: OTBNModel) -> None:
+        assert self.rd is not None
+        assert self.rs is not None
+        wrd = int(model.state.intreg[self.rd])
+        wrs = int(model.state.intreg[self.rs])
+        model.state.wreg[wrd] = model.state.wreg[wrs]
+        if self.rd_inc:
+            model.state.intreg[self.rd] += 1
+        if self.rs_inc:
+            model.state.intreg[self.rs] += 1
+
+
+@isa("bn.wsrrs", RV32IXotbn, opcode=0b0001011, funct3=0b111, funct31=0)
+class InstructionBNWSRRS(InstructionBNCSType):
+    """
+    Atomic Read and Set Bits in WSR
+    """
+    def execute(self, model: OTBNModel) -> None:
+        csr = model.state.wcsr_read(self.wsr)
+        model.state.wreg[self.wrd] = model.state.wreg[self.wrs] & csr
+
+
+@isa("bn.wsrrw", RV32IXotbn, opcode=0b0001011, funct3=0b111, funct31=1)
+class InstructionBNWSRRW(InstructionBNCSType):
+    def execute(self, model: OTBNModel) -> None:
+        index = int(self.wsr)
+        old_val = model.state.wcsr_read(index)
+        new_val = model.state.wreg[self.wrs]
+
+        model.state.wcsr_write(index, new_val)
+        model.state.wreg[self.wrd] = old_val
diff --git a/hw/ip/otbn/dv/otbnsim/sim/isa.py b/hw/ip/otbn/dv/otbnsim/sim/isa.py
new file mode 100644
index 0000000..34f19a4
--- /dev/null
+++ b/hw/ip/otbn/dv/otbnsim/sim/isa.py
@@ -0,0 +1,509 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import IntEnum
+from abc import ABCMeta
+from typing import Optional, cast
+
+from riscvmodel.isa import (Instruction,  # type: ignore
+                            InstructionFunct3Type, Field)
+from riscvmodel.types import Immediate  # type: ignore
+
+
+class InstructionFunct2Type(Instruction):  # type: ignore
+    field_funct2 = Field(name="funct2",
+                         base=12,
+                         size=2,
+                         description="",
+                         static=True)
+
+
+class InstructionFunct31Type(Instruction):  # type: ignore
+    field_funct31 = Field(name="funct31",
+                          base=31,
+                          size=1,
+                          description="",
+                          static=True)
+
+
+class InstructionFunct30Type(Instruction):  # type: ignore
+    field_funct30 = Field(name="funct30",
+                          base=30,
+                          size=1,
+                          description="",
+                          static=True)
+
+
+class InstructionW3Type(Instruction):  # type: ignore
+    field_wrd = Field(name="wrd", base=7, size=5)
+    field_wrs1 = Field(name="wrs1", base=15, size=5)
+    field_wrs2 = Field(name="wrs2", base=20, size=5)
+
+
+class InstructionFGType(Instruction):  # type: ignore
+    field_fg = Field(
+        name="fg",
+        base=31,
+        size=1,
+        description="Flag group to use. Defaults to 0.\n\nValid range: 0..1")
+
+
+class InstructionShiftType(Instruction):  # type: ignore
+    field_shift_type = Field(
+        name="shift_type",
+        base=30,
+        size=1,
+        description="The direction of an optional shift applied to <wrs2>.")
+    field_shift_bytes = Field(
+        name="shift_bytes",
+        base=25,
+        size=5,
+        description=
+        "Number of bytes by which to shift <wrs2>. Defaults to 0.\n\nValid range: 0..31"
+    )
+
+
+class InstructionLType(InstructionFunct2Type):
+    isa_format_id = "L"
+
+    field_rs1 = Field(name="rs1", base=15, size=5, description="")
+    field_bodysize = Field(name="bodysize", base=20, size=12, description="")
+
+    def __init__(self,
+                 rs1: Optional[int] = None,
+                 bodysize: Optional[int] = None):
+        super().__init__()
+        self.rs1 = rs1
+        self.bodysize = Immediate(bits=12, signed=False, init=bodysize)
+
+    def __str__(self) -> str:
+        return "{} x{}, {}".format(self.mnemonic, self.rs1, self.bodysize)
+
+
+class InstructionLIType(InstructionFunct2Type):
+    isa_format_id = "LI"
+
+    field_iter = Field(name="iter", base=7, size=5, description="")
+    field_bodysize = Field(name="bodysize", base=20, size=12, description="")
+
+    def __init__(self,
+                 iter: Optional[int] = None,
+                 bodysize: Optional[int] = None):
+        super().__init__()
+        self.iter = Immediate(bits=10, signed=False, init=iter)
+        self.bodysize = Immediate(bits=12, signed=False, init=bodysize)
+
+    def __str__(self) -> str:
+        return "{} {}, {}".format(self.mnemonic, self.iter, self.bodysize)
+
+
+class ShiftType(IntEnum):
+    LSL = 0  # logical shift left
+    LSR = 1  # logical shift right
+
+
+def ShiftReg(reg: int, shift_type: int, shift_bytes: Immediate) -> int:
+    assert 0 <= int(shift_bytes)
+    shift_bits = int(shift_bytes << 3)
+
+    return (reg << shift_bits
+            if shift_type == ShiftType.LSL
+            else reg >> shift_bits)
+
+
+class InstructionBNAType(InstructionFunct3Type,  # type: ignore
+                         InstructionFunct31Type,
+                         InstructionShiftType,
+                         InstructionW3Type,
+                         metaclass=ABCMeta):
+    """
+    :param wrd: Name of the destination WDR
+    :param wrs1: Name of the first source WDR
+    :param wrs2: Name of the second source WDR
+    :param shift_type: The direction of an optional shift applied to <wrs2>.
+    :param shift_bytes: Number of bytes by which to shift <wrs2>. Defaults to 0. Valid range: 0..31.
+    """
+    isa_format_id = "BNA"
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 shift_bytes: Optional[int] = None,
+                 shift_type: ShiftType = ShiftType.LSL):
+        super().__init__()
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.shift_bytes = Immediate(bits=5,
+                                     signed=False,
+                                     init=shift_bytes)
+        self.shift_type = shift_type
+
+    def __str__(self) -> str:
+        asm = ("{} w{}, w{}, w{}"
+               .format(self.mnemonic, self.wrd, self.wrs1, self.wrs2))
+        if int(self.shift_bytes) > 0:
+            asm += ", {} {}B".format(
+                "<<" if self.shift_type == ShiftType.LSL else ">>",
+                self.shift_bytes)
+        return asm
+
+
+class InstructionBNANType(InstructionFunct3Type,  # type: ignore
+                          InstructionFunct31Type,
+                          InstructionShiftType):
+    isa_format_id = "BNAN"
+
+    field_wrd = Field(name="wrd", base=7, size=5, description="")
+    field_wrs1 = Field(name="wrs1", base=20, size=5, description="")
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 shift_bytes: int = 0,
+                 shift_type: int = ShiftType.LSL):
+        super().__init__()
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.shift_bytes = Immediate(bits=5, signed=False, init=shift_bytes)
+        self.shift_type = shift_type
+
+    def __str__(self) -> str:
+        return "{} w{}, w{}".format(self.mnemonic, self.wrd, self.wrs1)
+
+
+class InstructionBNCSType(InstructionFunct3Type,  # type: ignore
+                          InstructionFunct31Type):
+    isa_format_id = "BNCS"
+
+    field_wrd = Field(name="wrd", base=7, size=5)
+    field_wrs = Field(name="wrs", base=15, size=5)
+    field_wsr = Field(name="wsr", base=20, size=8)
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wsr: Optional[int] = None,
+                 wrs: Optional[int] = None):
+        super().__init__()
+        self.wrd = wrd
+        self.wsr = Immediate(bits=8, signed=False, init=wsr)
+        self.wsr = wrs
+
+    def __str__(self) -> str:
+        return ("{} w{}, w{}, {}"
+                .format(self.mnemonic, self.wrd, self.wsr, self.wrs))
+
+
+class InstructionBNAFType(InstructionFunct3Type,  # type: ignore
+                          InstructionW3Type,
+                          InstructionFGType,
+                          InstructionShiftType):
+    isa_format_id = "BNAF"
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 shift_bytes: int = 0,
+                 shift_type: int = ShiftType.LSL,
+                 fg: int = 0):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.shift_bytes = Immediate(bits=5, signed=False, init=shift_bytes)
+        self.shift_type = shift_type
+        self.fg = fg
+
+    def __str__(self) -> str:
+        shift = "{} {}B".format(
+            "<<" if self.shift_type == ShiftType.LSL else ">>",
+            self.shift_bytes)
+        return "{} w{}, w{}, w{}{}, FG{}".format(self.mnemonic, self.wrd,
+                                                 self.wrs1, self.wrs2, shift,
+                                                 self.fg)
+
+
+class InstructionBNAIType(InstructionFunct3Type,  # type: ignore
+                          InstructionFGType,
+                          InstructionFunct30Type):
+    isa_format_id = "BNAI"
+
+    field_wrd = Field(name="wrd", base=7, size=5)
+    field_wrs1 = Field(name="wrs1", base=15, size=5)
+    field_imm = Field(name="imm", base=20, size=10)
+    field_fg = Field(name="fg", base=30, size=1)
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 imm: Optional[int] = None,
+                 fg: int = 0):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.imm = Immediate(bits=10, signed=False, init=imm)
+        self.fg = fg
+
+
+class InstructionBNAMType(InstructionFunct3Type,  # type: ignore
+                          InstructionW3Type,
+                          InstructionFunct30Type):
+    isa_format_id = "BNAM"
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs1 = wrs1
+
+    def __str__(self) -> str:
+        return "{} w{}, w{}, w{}".format(self.mnemonic, self.wrd, self.wrs1,
+                                         self.wrs2)
+
+
+class InstructionBNAQType(InstructionW3Type):
+    isa_format_id = "BNAQ"
+
+    field_wb_variant = Field(name="wb_variant",
+                             base=30,
+                             size=2,
+                             description="""
+Result writeback instruction variant. If no writeback variant is chosen, no
+destination register is written, and the multiplication result is only stored in
+the accumulator.
+
+Valid values:
+
+* .S0 (value 0): Shift out the lower half-word of the value stored in the
+  accumulator to a WLEN/2-sized half-word of the destination WDR. The
+  destination half-word is selected by the wrd_hwsel field.
+
+* .W0 (value 1): Write the value stored in the accumulator to the destination
+  WDR.""")
+    field_zero_acc = Field(name="zero_acc",
+                           base=12,
+                           size=1,
+                           description="""
+Zero the accumulator before accumulating the multiply result.
+
+To specify, use the literal syntax .Z""")
+    field_wrd_hwsel = Field(name="wrd_hwsel", base=29, size=1, description="")
+    field_wrs1_qwsel = Field(name="wrs1_qwsel",
+                             base=27,
+                             size=2,
+                             description="")
+    field_wrs2_qwsel = Field(name="wrs2_qwsel",
+                             base=25,
+                             size=2,
+                             description="")
+    field_acc_shift_imm = Field(name="acc_shift_imm",
+                                base=13,
+                                size=2,
+                                description="")
+
+    def __init__(self,
+                 wrd: int = 0,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 wb_variant: int = 0,
+                 zero_acc: bool = False,
+                 wrd_hwsel: int = 0,
+                 wrs1_qwsel: Optional[int] = None,
+                 wrs2_qwsel: Optional[int] = None,
+                 acc_shift_imm: Optional[int] = None):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.wb_variant = wb_variant
+        self.zero_acc = zero_acc
+        self.wrd_hwsel = wrd_hwsel
+        self.wrs1_qwsel = wrs1_qwsel
+        self.wrs2_qwsel = wrs2_qwsel
+        self.acc_shift_imm = acc_shift_imm
+
+    def __str__(self) -> str:
+        istr = cast(str, self.mnemonic)
+        if self.wb_variant > 0:
+            istr += ".so" if self.wb_variant == 1 else ".wo"
+        if self.zero_acc:
+            istr += ".z"
+        istr += " "
+        if self.wb_variant > 0:
+            istr += "w{}".format(self.wrd)
+            if self.wb_variant == 1:
+                istr += ".u" if self.wrd_hwsel == 1 else ".l"
+            istr += ", "
+        istr += "w{}.{}, ".format(self.wrs1, self.wrs1_qwsel)
+        istr += "w{}.{}, ".format(self.wrs2, self.wrs2_qwsel)
+        istr += ('??' if self.acc_shift_imm is None
+                 else str(self.acc_shift_imm * 64))
+        return istr
+
+
+class InstructionBNRType(InstructionW3Type, InstructionFunct2Type):
+    isa_format_id = "BNR"
+
+    field_imm = Field(name="imm", base=[14, 25], size=[1, 7], description="")
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 imm: Optional[int] = None):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.imm = Immediate(bits=8, signed=False, init=imm)
+
+    def __str__(self) -> str:
+        return "{} w{}, w{}, w{} >> {}".format(self.mnemonic, self.wrd,
+                                               self.wrs1, self.wrs2, self.imm)
+
+
+class InstructionBNSType(InstructionW3Type,
+                         InstructionFunct3Type,  # type: ignore
+                         InstructionFGType):
+    isa_format_id = "BNS"
+
+    field_flag = Field(name="flag", base=25, size=2, description="")
+
+    def __init__(self,
+                 wrd: Optional[int] = None,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 fg: int = 0,
+                 flag: Optional[int] = None):
+        self.wrd = wrd
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.fg = fg
+        self.flag = flag
+
+    def __str__(self) -> str:
+        return "{} w{}, w{}, w{}{}, FG{}".format(self.mnemonic, self.wrd,
+                                                 self.wrs1, self.wrs2, self.fg,
+                                                 self.flag)
+
+
+class InstructionBNCType(InstructionFunct3Type,  # type: ignore
+                         InstructionFGType,
+                         InstructionShiftType):
+    isa_format_id = "BNC"
+
+    field_wrs1 = Field(name="wrs1", base=15, size=5)
+    field_wrs2 = Field(name="wrs2", base=20, size=5)
+
+    def __init__(self,
+                 wrs1: Optional[int] = None,
+                 wrs2: Optional[int] = None,
+                 shift_bytes: int = 0,
+                 shift_type: int = ShiftType.LSL,
+                 fg: int = 0):
+        self.wrs1 = wrs1
+        self.wrs2 = wrs2
+        self.shift_bytes = Immediate(bits=5, signed=False, init=shift_bytes)
+        self.shift_type = shift_type
+        self.fg = fg
+
+    def __str__(self) -> str:
+        shift = "{} {}B".format(
+            "<<" if self.shift_type == ShiftType.LSL else ">>",
+            self.shift_bytes)
+        return "{} w{}, w{}{}, FG{}".format(self.mnemonic, self.wrs1,
+                                            self.wrs2, shift, self.fg)
+
+
+class InstructionBNIType(InstructionFunct3Type):  # type: ignore
+    isa_format_id = "BNI"
+
+    field_rd = Field(name="rd", base=7, size=5)
+    field_rs = Field(name="rs", base=15, size=5)
+    field_imm = Field(name="imm", base=22, size=10)
+    field_rd_inc = Field(name="rd_inc", base=20, size=1)
+    field_rs_inc = Field(name="rs_inc", base=21, size=1)
+
+    def __init__(self,
+                 rd: Optional[int] = None,
+                 rs: Optional[int] = None,
+                 imm: Optional[int] = None,
+                 rd_inc: bool = False,
+                 rs_inc: bool = False):
+        self.rd = rd
+        self.rs = rs
+        self.imm = Immediate(bits=10, signed=True)
+        self.rd_inc = rd_inc
+        self.rs_inc = rs_inc
+
+    def __str__(self) -> str:
+        return ("{} x{}, {}(x{})"
+                .format(self.mnemonic, self.rd, self.imm, self.rs))
+
+
+class InstructionBNISType(InstructionFunct3Type):  # type: ignore
+    isa_format_id = "BNIS"
+
+    field_rs1 = Field(name="rs1", base=7, size=5)
+    field_rs2 = Field(name="rs2", base=15, size=5)
+    field_imm = Field(name="imm", base=22, size=10)
+    field_rs1_inc = Field(name="rs1_inc", base=20, size=1)
+    field_rs2_inc = Field(name="rs2_inc", base=21, size=1)
+
+    def __init__(self,
+                 rs1: Optional[int] = None,
+                 rs2: Optional[int] = None,
+                 imm: Optional[int] = None,
+                 rs1_inc: bool = False,
+                 rs2_inc: bool = False):
+        self.rs1 = rs1
+        self.rs2 = rs2
+        self.imm = Immediate(bits=10, signed=True)
+        self.rs1_inc = rs1_inc
+        self.rs2_inc = rs2_inc
+
+    def __str__(self) -> str:
+        return ("{} x{}, {}(x{})"
+                .format(self.mnemonic, self.rs1, self.imm, self.rs2))
+
+
+class InstructionBNMVType(InstructionFunct3Type,  # type: ignore
+                          InstructionFunct31Type):
+    isa_format_id = "BNMV"
+
+    field_wrd = Field(name="wrd", base=7, size=5)
+    field_wrs = Field(name="wrs", base=15, size=5)
+
+    def __init__(self, wrd: Optional[int] = None, wrs: Optional[int] = None):
+        self.wrd = wrd
+        self.wrs = wrs
+
+    def __str__(self) -> str:
+        return "{} w{}, w{}".format(self.mnemonic, self.wrd, self.wrs)
+
+
+class InstructionBNMVRType(InstructionFunct3Type,  # type: ignore
+                           InstructionFunct31Type):
+    isa_format_id = "BNMVR"
+
+    field_rd = Field(name="rd", base=7, size=5)
+    field_rs = Field(name="rs", base=15, size=5)
+    field_rd_inc = Field(name="rd_inc", base=20, size=1)
+    field_rs_inc = Field(name="rs_inc", base=21, size=1)
+
+    def __init__(self,
+                 rd: Optional[int] = None,
+                 rs: Optional[int] = None,
+                 rd_inc: bool = False,
+                 rs_inc: bool = False):
+        self.rd = rd
+        self.rs = rs
+        self.rd_inc = rd_inc
+        self.rs_inc = rs_inc
+
+    def __str__(self) -> str:
+        dpp = "++" if self.rd_inc else ""
+        spp = "++" if self.rs_inc else ""
+        return "{} x{}{}, x{}{}".format(self.mnemonic, self.rd, dpp, self.rs,
+                                        spp)
diff --git a/hw/ip/otbn/dv/otbnsim/sim/model.py b/hw/ip/otbn/dv/otbnsim/sim/model.py
new file mode 100644
index 0000000..ddef158
--- /dev/null
+++ b/hw/ip/otbn/dv/otbnsim/sim/model.py
@@ -0,0 +1,368 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from random import getrandbits
+from typing import List, Optional, Tuple, cast
+
+from attrdict import AttrDict  # type: ignore
+
+from riscvmodel.model import (Model, State,  # type: ignore
+                              Environment, TerminateException)
+from riscvmodel.isa import Instruction  # type: ignore
+from riscvmodel.types import (RegisterFile, Register,  # type: ignore
+                              SingleRegister, Trace, BitflagRegister)
+
+from .variant import RV32IXotbn
+
+
+class TraceCallStackPush(Trace):  # type: ignore
+    def __init__(self, value: int):
+        self.value = value
+
+    def __str__(self) -> str:
+        return "RAS push {:08x}".format(self.value)
+
+
+class TraceCallStackPop(Trace):  # type: ignore
+    def __init__(self, value: int):
+        self.value = value
+
+    def __str__(self) -> str:
+        return "RAS pop {:08x}".format(self.value)
+
+
+class TraceLoopStart(Trace):  # type: ignore
+    def __init__(self, iterations: int, bodysize: int):
+        self.iterations = iterations
+        self.bodysize = bodysize
+
+    def __str__(self) -> str:
+        return "Start LOOP, {} iterations, bodysize: {}".format(
+            self.iterations, self.bodysize)
+
+
+class TraceLoopIteration(Trace):  # type: ignore
+    def __init__(self, iteration: int, total: int):
+        self.iteration = iteration
+        self.total = total
+
+    def __str__(self) -> str:
+        return "LOOP iteration {}/{}".format(self.iteration, self.total)
+
+
+class OTBNIntRegisterFile(RegisterFile):  # type: ignore
+    def __init__(self) -> None:
+        super().__init__(num=32, bits=32, immutable={0: 0})
+
+        # The call stack for x1 and its pending updates
+        self.callstack = []  # type: List[int]
+        self.have_read_callstack = False
+        self.callstack_push_val = None  # type: Optional[int]
+
+    def __setitem__(self, key: int, value: int) -> None:
+        # Special handling for the callstack in x1
+        if key == 1:
+            self.callstack_push_val = value
+            return
+
+        # Otherwise, use the base class implementation
+        super().__setitem__(key, value)
+
+    def __getitem__(self, key: int) -> int:
+        # Special handling for the callstack in x1
+        if key == 1:
+            self.have_read_callstack = True
+
+        return cast(int, super().__getitem__(key))
+
+    def post_insn(self) -> None:
+        '''Update the x1 call stack after an instruction executes
+
+        This needs to run after execution (which sets up callstack_push_val and
+        have_read_callstack) but before we print the instruction in
+        State.issue, because any changes to x1 need to be reflected there.
+
+        '''
+        cs_changed = False
+        if self.have_read_callstack:
+            if self.callstack:
+                self.callstack.pop()
+                cs_changed = True
+
+        if self.callstack_push_val is not None:
+            self.callstack.append(self.callstack_push_val)
+            cs_changed = True
+
+        # Update self.regs[1] so that it always points at the top of the stack.
+        # If the stack is empty, set it to zero (we need to decide what happens
+        # in this case: see issue #3239)
+        if cs_changed:
+            cs_val = 0
+            if self.callstack:
+                cs_val = self.callstack[0]
+
+            super().__setitem__(1, cs_val)
+
+        self.have_read_callstack = False
+        self.callstack_push_val = None
+
+
+class LoopLevel:
+    '''An object representing a level in the current loop stack
+
+    start_addr is the first instruction inside the loop (the instruction
+    following the loop instruction). insn_count is the number of instructions
+    in the loop (and must be positive). restarts is one less than the number of
+    iterations, and must be positive.
+
+    '''
+    def __init__(self, start_addr: int, insn_count: int, restarts: int):
+        assert 0 <= start_addr
+        assert 0 < insn_count
+        assert 0 < restarts
+
+        self.loop_count = 1 + restarts
+        self.restarts_left = restarts
+        self.start_addr = start_addr
+        self.match_addr = start_addr + 4 * insn_count
+
+
+class LoopStack:
+    '''An object representing the loop stack
+
+    An entry on the loop stack represents a possible back edge: the
+    restarts_left counter tracks the number of these back edges. The entry is
+    removed when the counter gets to zero.
+
+    '''
+    def __init__(self) -> None:
+        self.stack = []  # type: List[LoopLevel]
+        self.trace = []  # type: List[Trace]
+
+    def start_loop(self,
+                   next_addr: int,
+                   insn_count: int,
+                   loop_count: int) -> Optional[int]:
+        '''Start a loop.
+
+        Adds the loop to the stack and returns the next PC if it's not
+        straight-line. If the loop count is one, this acts as a NOP (and
+        doesn't change the stack). If the loop count is zero, this doesn't
+        change the stack but the next PC will be the match address.
+
+        '''
+        assert 0 <= next_addr
+        assert 0 < insn_count
+        assert 0 <= loop_count
+
+        self.trace.append(TraceLoopStart(loop_count, insn_count))
+
+        if loop_count == 0:
+            return next_addr + 4 * insn_count
+
+        if loop_count > 1:
+            self.stack.append(LoopLevel(next_addr, insn_count, loop_count - 1))
+
+        return None
+
+    def step(self, cur_pc: int) -> int:
+        '''Calculate the next PC and update loop stack'''
+        next_pc = cur_pc + 4
+        if self.stack:
+            top = self.stack[-1]
+            if next_pc == top.match_addr:
+                assert top.restarts_left > 0
+                top.restarts_left -= 1
+
+                if not top.restarts_left:
+                    self.stack.pop()
+
+                # 1-based iteration number
+                idx = top.loop_count - top.restarts_left
+                self.trace.append(TraceLoopIteration(idx, top.loop_count))
+
+                return top.start_addr
+
+        return next_pc
+
+    def changes(self) -> List[Trace]:
+        return self.trace
+
+    def commit(self) -> None:
+        self.trace = []
+
+
+class FlagGroups:
+    def __init__(self) -> None:
+        self.groups = {
+            0: BitflagRegister(["C", "L", "M", "Z"], prefix = "FG0."),
+            1: BitflagRegister(["C", "L", "M", "Z"], prefix = "FG1.")
+        }
+
+    def __getitem__(self, key: int) -> BitflagRegister:
+        assert 0 <= key <= 1
+        return self.groups[key]
+
+    def __setitem__(self, key: int, value: int) -> None:
+        assert 0 <= key <= 1
+        self.groups[key].set(value)
+
+    def changes(self) -> List[Trace]:
+        return cast(List[Trace],
+                    self.groups[0].changes() + self.groups[1].changes())
+
+    def commit(self) -> None:
+        self.groups[0].commit()
+        self.groups[1].commit()
+
+
+class OTBNState(State):  # type: ignore
+    def __init__(self) -> None:
+        super().__init__(RV32IXotbn)
+
+        # Hack: this matches the superclass constructor, but you need it to
+        # explain to mypy what self.pc is (because mypy can't peek into
+        # riscvmodel without throwing up lots of errors)
+        self.pc = Register(32)
+
+        self.intreg = OTBNIntRegisterFile()
+        self.wreg = RegisterFile(num=32, bits=256, immutable={}, prefix="w")
+        self.single_regs = {
+            'acc': SingleRegister(256, "ACC"),
+            'mod': SingleRegister(256, "MOD")
+        }
+        self.flags = FlagGroups()
+        self.loop_stack = LoopStack()
+
+    def csr_read(self, index: int) -> int:
+        if index == 0x7C0:
+            return int(self.wreg)
+        elif 0x7D0 <= index <= 0x7D7:
+            bit_shift = 32 * (index - 0x7D0)
+            mask32 = (1 << 32) - 1
+            return (int(self.mod) >> bit_shift) & mask32
+        elif index == 0xFC0:
+            return getrandbits(32)
+        return cast(int, super().csr_read(self, index))
+
+    def wcsr_read(self, index: int) -> int:
+        assert 0 <= index <= 2
+        if index == 0:
+            return int(self.mod)
+        elif index == 1:
+            return getrandbits(256)
+        else:
+            assert index == 2
+            return int(self.single_regs['acc'])
+
+    def wcsr_write(self, index: int, value: int) -> None:
+        if index == 0:
+            self.mod = value
+
+    def loop_start(self, iterations: int, bodysize: int) -> None:
+        next_pc = int(self.pc) + 4
+        skip_pc = self.loop_stack.start_loop(next_pc, bodysize, iterations)
+        if skip_pc is not None:
+            self.pc_update.set(skip_pc)
+
+    def loop_step(self) -> None:
+        self.pc_update.set(self.loop_stack.step(int(self.pc)))
+
+    def changes(self) -> List[Trace]:
+        c = cast(List[Trace], super().changes())
+        c += self.loop_stack.changes()
+        c += self.wreg.changes()
+        c += self.flags.changes()
+        for name, reg in sorted(self.single_regs.items()):
+            c += reg.changes()
+        return c
+
+    def commit(self) -> None:
+        super().commit()
+        self.loop_stack.commit()
+        self.wreg.commit()
+        self.flags.commit()
+        for reg in self.single_regs.values():
+            reg.commit()
+
+
+class OTBNEnvironment(Environment):  # type: ignore
+    def call(self, state: OTBNState) -> None:
+        raise TerminateException(0)
+
+
+class OTBNModel(Model):  # type: ignore
+    def __init__(self, verbose: bool):
+        super().__init__(RV32IXotbn,
+                         environment=OTBNEnvironment(),
+                         verbose=verbose,
+                         asm_width=35)
+        self.state = OTBNState()
+
+    def get_wr_quarterword(self, wridx: int, qwsel: int) -> int:
+        assert 0 <= wridx <= 31
+        assert 0 <= qwsel <= 3
+        mask = (1 << 64) - 1
+        return (int(self.state.wreg[wridx]) >> (qwsel * 64)) & mask
+
+    def set_wr_halfword(self, wridx: int, value: int, hwsel: int) -> None:
+        assert 0 <= wridx <= 31
+        assert (value >> 128) == 0
+        assert 0 <= hwsel <= 1
+
+        mask = ((1 << 128) - 1) << (0 if hwsel else 128)
+        curr = int(self.state.wreg[wridx]) & mask
+        valpos = value << 128 if hwsel else value
+        self.state.wreg[wridx].set(curr | valpos)
+
+    def load_wlen_word_from_memory(self, addr: int) -> int:
+        assert 0 <= addr
+
+        word = 0
+        for byte_off in range(0, 32, 4):
+            bit_off = byte_off * 8
+            word += cast(int, self.state.memory.lw(addr + byte_off)) << bit_off
+        return word
+
+    def store_wlen_word_to_memory(self, addr: int, word: int) -> None:
+        assert 0 <= addr
+        assert 0 <= word
+        assert (word >> 256) == 0
+
+        mask32 = (1 << 32) - 1
+        for byte_off in range(0, 32, 4):
+            bit_off = byte_off * 8
+            self.state.memory.sw(addr + byte_off, (word >> bit_off) & mask32)
+
+    @staticmethod
+    def add_with_carry(a: int, b: int, carry_in: int) -> Tuple[int, int]:
+        result = a + b + carry_in
+
+        flags_out = AttrDict({"C": (result >> 256) & 1,
+                              "L": result & 1,
+                              "M": (result >> 255) & 1,
+                              "Z": 1 if result == 0 else 0})
+
+        return (result & ((1 << 256) - 1), flags_out)
+
+    def issue(self, insn: Instruction) -> List[Trace]:
+        '''An overridden version of riscvmodel's Model.issue
+
+        We have to override this to allow the loop stack to jump in between
+        instruction execution and calculating the trace of changes.
+
+        '''
+        self.state.pc += 4
+        insn.execute(self)
+
+        self.state.loop_step()
+        self.state.intreg.post_insn()
+
+        trace = self.state.changes()
+        if self.verbose is not False:
+            self.verbose_file.write(self.asm_tpl.
+                                    format(str(insn),
+                                           ", ".join([str(t) for t in trace])))
+        self.state.commit()
+        return trace
diff --git a/hw/ip/otbn/dv/otbnsim/sim/variant.py b/hw/ip/otbn/dv/otbnsim/sim/variant.py
new file mode 100644
index 0000000..6fb4436
--- /dev/null
+++ b/hw/ip/otbn/dv/otbnsim/sim/variant.py
@@ -0,0 +1,8 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from riscvmodel.variant import Extension, Variant  # type: ignore
+
+RV32IXotbn = Variant("RV32IXotbn", custext=[Extension(
+    name="Xotbn", description="OpenTitan BigNum Extension", implies=["Zicsr"])])
diff --git a/hw/ip/otbn/dv/otbnsim/standalone.py b/hw/ip/otbn/dv/otbnsim/standalone.py
index dfbf641..c1b71d0 100755
--- a/hw/ip/otbn/dv/otbnsim/standalone.py
+++ b/hw/ip/otbn/dv/otbnsim/standalone.py
@@ -7,10 +7,9 @@
 import sys
 
 from riscvmodel.sim import Simulator  # type: ignore
-from riscvmodel.model import Model  # type: ignore
-from riscvmodel.variant import RV32I  # type: ignore
 
 from sim.elf import load_elf
+from sim.model import OTBNModel
 
 
 def main() -> int:
@@ -21,7 +20,8 @@
 
     args = parser.parse_args()
 
-    sim = Simulator(Model(RV32I, verbose=args.verbose))
+    model = OTBNModel(verbose=args.verbose)
+    sim = Simulator(model)
     load_elf(sim, args.elf)
 
     sim.run()
diff --git a/python-requirements.txt b/python-requirements.txt
index 0551115..406cb97 100644
--- a/python-requirements.txt
+++ b/python-requirements.txt
@@ -27,7 +27,7 @@
 yapf
 
 # Used by OTBN simulator
-riscv-model >= 0.4.1
+riscv-model >= 0.6.4
 
 # Development version with OT-specific changes
 git+https://github.com/lowRISC/fusesoc.git@ot#egg=fusesoc >= 1.11.0
diff --git a/util/otbnsim/README.md b/util/otbnsim/README.md
new file mode 100644
index 0000000..0ce8a1d
--- /dev/null
+++ b/util/otbnsim/README.md
@@ -0,0 +1,129 @@
+# OpenTitan Big Number Python Model
+
+## Generate documentation
+
+```console
+$ python -m otbnsim.doc
+++++ Instruction Formats
+{'id': 'L',
+ 'fields': [{'name': 'bodysize',
+             'base': 20,
+             'size': 12,
+             'description': '',
+             'static': False,
+             'value': None},
+            {'name': 'funct3',
+             'base': 12,
+             'size': 3,
+             'description': '',
+             'static': True,
+             'value': None},
+
+[...]
+
+++++ Instructions
+{'format': 'L',
+ 'description': 'Loop (indirect)\n'
+                '\n'
+                'Repeat a sequence of code multiple times. The number of '
+                'iterations is a GPR\n'
+                'value. The length of the loop is given as immediate.\n'
+                '\n'
+                'Alternative assembly notation: The size of the loop body is '
+                'given by the\n'
+                'number of instructions in the parentheses.\n'
+                '\n'
+                'LOOP <grs> (\n'
+                '  # loop body\n'
+                ')',
+ 'asm_signature': 'loop <iter>, <bodysize>',
+ 'code': '        model.state.loop_start(int(reg[rs1]), int(bodysize))'}
+
+[...]
+```
+
+## Assembler
+
+Assemble to verify everything works:
+
+```console
+$ otbn-asm << EOF
+> LOOPI 8 (
+>   addi x2, x2, 1
+> )
+> EOF
+loopi 8, 1
+addi x2, x2, 1
+```
+
+Produce different output format, C structs:
+
+```console
+$ otbn-asm -O carray test.S
+static const uint32_t code [] = {
+    0x0010140b, // loopi 8, 1
+    0x00110113, // addi x2, x2, 1
+};
+```
+
+Finally, generate a binary, write to output file:
+
+```console
+$ otbn-asm -O binary test.S test.bin
+$ hexdump test.bin
+0000000 140b 0010 0113 0011
+0000008
+```
+
+## Run standalone test
+
+```console
+$ python -m otbnsim.standalone test.S
+loopi 8, 1                          | [Start LOOP, 8 iterations, bodysize: 1]
+addi x2, x2, 1                      | [x2 = 00000001, pc = 00000004, LOOP iteration 1/8]
+addi x2, x2, 1                      | [x2 = 00000002, pc = 00000004, LOOP iteration 2/8]
+addi x2, x2, 1                      | [x2 = 00000003, pc = 00000004, LOOP iteration 3/8]
+addi x2, x2, 1                      | [x2 = 00000004, pc = 00000004, LOOP iteration 4/8]
+addi x2, x2, 1                      | [x2 = 00000005, pc = 00000004, LOOP iteration 5/8]
+addi x2, x2, 1                      | [x2 = 00000006, pc = 00000004, LOOP iteration 6/8]
+addi x2, x2, 1                      | [x2 = 00000007, pc = 00000004, LOOP iteration 7/8]
+addi x2, x2, 1                      | [x2 = 00000008, LOOP iteration 8/8]
+```
+
+## Run pytest
+
+```console
+$ pytest
+```
+
+Grab model trace for debugging
+
+```console
+$ pytest --model-verbose
+```
+
+## Get program from database of test programs
+
+Test programs are available in a python module. Assembly code are stored in this
+file and they can be generated as with the assembler.
+
+Examples:
+
+- Produce assembly code
+
+```console
+$ python test/programs.py mul_256x256
+```
+
+Produce C array
+
+
+```console
+$ python test/programs.py -O carray mul_256x256
+```
+
+Produce binary file
+
+```console
+$ python test/programs.py -O binary mul_256x256 program.bin
+```
diff --git a/util/otbnsim/otbnsim/model.py b/util/otbnsim/otbnsim/model.py
new file mode 100644
index 0000000..4b71410
--- /dev/null
+++ b/util/otbnsim/otbnsim/model.py
@@ -0,0 +1,251 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from random import getrandbits
+from collections import deque
+from attrdict import AttrDict
+
+from riscvmodel.model import Model, State, Environment
+from riscvmodel.isa import TerminateException
+from riscvmodel.types import RegisterFile, Register, SingleRegister, TraceRegister, TraceIntegerRegister, Trace, TracePC, BitflagRegister
+
+from .variant import RV32IXotbn
+
+
+class TraceCallStackPush(Trace):
+    def __init__(self, value):
+        self.value = value
+
+    def __str__(self):
+        return "RAS push {:08x}".format(self.value)
+
+
+class TraceCallStackPop(Trace):
+    def __init__(self, value):
+        self.value = value
+
+    def __str__(self):
+        return "RAS pop {:08x}".format(self.value)
+
+
+class TraceLoopStart(Trace):
+    def __init__(self, iterations, bodysize):
+        self.iterations = iterations
+        self.bodysize = bodysize
+
+    def __str__(self):
+        return "Start LOOP, {} iterations, bodysize: {}".format(
+            self.iterations, self.bodysize)
+
+
+class TraceLoopIteration(Trace):
+    def __init__(self, iter, total):
+        self.iter = iter
+        self.total = total
+
+    def __str__(self):
+        return "LOOP iteration {}/{}".format(self.iter, self.total)
+
+
+class OTBNIntRegisterFile(RegisterFile):
+    def __init__(self, num: int, bits: int, immutable: list = {}):
+        super().__init__(num, bits, immutable)
+        self.callstack = deque()
+        self.cs_update = []
+
+    def __setitem__(self, key, value):
+        if key == 1:
+            self.cs_update.append(TraceCallStackPush(value))
+        elif not self.regs[key].immutable:
+            reg = Register(self.bits)
+            reg.set(value)
+            self.regs_updates.append(TraceIntegerRegister(key, reg))
+
+    def __getitem__(self, key):
+        if key == 1:
+            return self.callstack.popleft()
+        return self.regs[key]
+
+    def commit(self):
+        for cs in self.cs_update:
+            self.callstack.appendleft(cs.value)
+        self.cs_update.clear()
+        super().commit()
+
+class FlagGroups:
+    def __init__(self):
+        super().__init__()
+        self.groups = { 0: BitflagRegister(["C", "L", "M", "Z"], prefix = "FG0."), 1: BitflagRegister(["C", "L", "M", "Z"], prefix = "FG1.") }
+
+    def __getitem__(self, key):
+        return self.groups[key]
+
+    def __setitem__(self, key, value):
+        self.groups[key].set(value)
+
+    def changes(self):
+        return self.groups[0].changes() + self.groups[1].changes()
+
+    def commit(self):
+        self.groups[0].commit()
+        self.groups[1].commit()
+
+class OTBNState(State):
+    def __init__(self):
+        super().__init__(RV32IXotbn)
+        self.intreg = OTBNIntRegisterFile(32, 32, {0: 0})
+        self.wreg = RegisterFile(32, 256, {}, prefix="w")
+        self.single_regs = {}
+        self.single_regs["acc"] = SingleRegister(256, "ACC")
+        self.single_regs["mod"] = SingleRegister(256, "MOD")
+        self.flags = FlagGroups()
+
+        self.loop_trace = []
+        self.loop = deque()
+
+    def __setattr__(self, name, value):
+        if name in self.single_regs:
+            self.single_regs[name].update(value)
+        super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self.single_regs:
+            return self.single_regs[name]
+        return super().__getattribute__(name)
+
+    def csr_read(self, index):
+        if index == 0x7C0:
+            return int(self.wreg)
+        elif index == 0x7D0:
+            return (int(self.mod) >> 0) & 0xffffffff
+        elif index == 0x7D1:
+            return (int(self.mod) >> 32) & 0xffffffff
+        elif index == 0x7D2:
+            return (int(self.mod) >> 64) & 0xffffffff
+        elif index == 0x7D3:
+            return (int(self.mod) >> 96) & 0xffffffff
+        elif index == 0x7D4:
+            return (int(self.mod) >> 128) & 0xffffffff
+        elif index == 0x7D5:
+            return (int(self.mod) >> 160) & 0xffffffff
+        elif index == 0x7D6:
+            return (int(self.mod) >> 192) & 0xffffffff
+        elif index == 0x7D7:
+            return (int(self.mod) >> 224) & 0xffffffff
+        elif index == 0xFC0:
+            return getrandbits(32)
+        return super().csr_read(self, index)
+
+    def wcsr_read(self, index):
+        if index == 0:
+            return int(self.mod)
+        elif index == 1:
+            return getrandbits(256)
+        elif index == 2:
+            return int(self.acc)
+
+    def wcsr_write(self, index, value):
+        old = None
+        if index == 0:
+            old = int(self.mod)
+            self.mod = value
+        return old
+
+    def loop_start(self, iterations, bodysize):
+        self.loop.appendleft({
+            "iterations": iterations,
+            "bodysize": bodysize,
+            "count_iterations": 0,
+            "count_instructions": -1
+        })
+        self.loop_trace.append(TraceLoopStart(iterations, bodysize))
+
+    def changes(self):
+        c = super().changes()
+        if len(self.loop) > 0:
+            if self.loop[0][
+                    "count_instructions"] == self.loop[0]["bodysize"] - 1:
+                self.loop_trace.append(
+                    TraceLoopIteration(self.loop[0]["count_iterations"] + 1,
+                                       self.loop[0]["iterations"]))
+                if self.loop[0][
+                        "count_iterations"] == self.loop[0]["iterations"] - 1:
+                    self.loop.popleft()
+                else:
+                    pc = self.pc - (self.loop[0]["bodysize"] - 1) * 4
+                    self.pc_update = pc
+                    c.append(TracePC(pc))
+                    self.loop[0]["count_iterations"] += 1
+                    self.loop[0]["count_instructions"] = 0
+            else:
+                self.loop[0]["count_instructions"] += 1
+        c += self.loop_trace
+        c += self.wreg.changes()
+        c += self.flags.changes()
+        for reg in self.single_regs:
+            c += self.single_regs[reg].changes()
+        return c
+
+    def commit(self):
+        super().commit()
+        self.wreg.commit()
+        self.flags.commit()
+        self.loop_trace.clear()
+        for reg in self.single_regs:
+            self.single_regs[reg].commit()
+
+class OTBNEnvironment(Environment):
+    def call(self, state: OTBNState):
+        raise TerminateException(0)
+
+
+class OTBNModel(Model):
+    def __init__(self, *, verbose=False):
+        super().__init__(RV32IXotbn,
+                         environment=OTBNEnvironment(),
+                         verbose=verbose,
+                         asm_width=35)
+        self.state = OTBNState()
+
+    def get_wr_quarterword(self, wridx, qwsel):
+        return (int(self.state.wreg[wridx]) >>
+                (qwsel * 64)) & 0xffffffffffffffff
+
+    def set_wr_halfword(self, wridx, value, hwsel):
+        mask = ((1 << 128) - 1) << (128 if hwsel == 0 else 0)
+        curr = int(self.state.wreg[wridx]) & mask
+        valpos = (value & ((1 << 128) - 1)) << (128 if hwsel == 1 else 0)
+        self.state.wreg[wridx].set(curr | valpos)
+
+    def load_wlen_word_from_memory(self, addr):
+        word = self.state.memory.lw(addr)
+        word += self.state.memory.lw(addr + 4) << 32
+        word += self.state.memory.lw(addr + 8) << 64
+        word += self.state.memory.lw(addr + 12) << 96
+        word += self.state.memory.lw(addr + 16) << 128
+        word += self.state.memory.lw(addr + 20) << 160
+        word += self.state.memory.lw(addr + 24) << 192
+        word += self.state.memory.lw(addr + 28) << 224
+        return word
+
+    def store_wlen_word_to_memory(self, addr, word):
+        self.state.memory.sw(addr, word & 0xffffffff)
+        self.state.memory.sw(addr + 4, (word >> 32) & 0xffffffff)
+        self.state.memory.sw(addr + 8, (word >> 64) & 0xffffffff)
+        self.state.memory.sw(addr + 12, (word >> 96) & 0xffffffff)
+        self.state.memory.sw(addr + 16, (word >> 128) & 0xffffffff)
+        self.state.memory.sw(addr + 20, (word >> 160) & 0xffffffff)
+        self.state.memory.sw(addr + 24, (word >> 192) & 0xffffffff)
+        self.state.memory.sw(addr + 28, (word >> 224) & 0xffffffff)
+
+    @staticmethod
+    def add_with_carry(a, b, carry_in):
+        result = a + b + carry_in
+
+        flags_out = AttrDict({"C": (result >> 256) & 1,
+                              "L": result & 1,
+                              "M": (result >> 255) & 1,
+                              "Z": 1 if result == 0 else 0})
+
+        return (result & ((1 << 256) - 1), flags_out)
diff --git a/util/otbnsim/otbnsim/standalone.py b/util/otbnsim/otbnsim/standalone.py
new file mode 100644
index 0000000..b3de6d0
--- /dev/null
+++ b/util/otbnsim/otbnsim/standalone.py
@@ -0,0 +1,35 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+from riscvmodel.sim import Simulator
+from .model import OTBNModel
+from .variant import RV32IXotbn
+from .asm import parse
+
+import argparse
+import sys
+
+
+def run(program, data=[], *, verbose=True):
+    sim = Simulator(OTBNModel(verbose=verbose))
+    sim.load_program(program)
+    sim.load_data(data)
+    sim.run()
+    return sim.dump_data()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("program",
+                        nargs='?',
+                        type=argparse.FileType('r'),
+                        default=sys.stdin)
+    parser.add_argument("data", nargs='?', type=argparse.FileType('rb'))
+    args = parser.parse_args()
+
+    run(parse(args.program.read()), args.data.read() if args.data else [])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/util/otbnsim/setup.py b/util/otbnsim/setup.py
new file mode 100644
index 0000000..b68eb9c
--- /dev/null
+++ b/util/otbnsim/setup.py
@@ -0,0 +1,17 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+# Install this one as package
+
+from setuptools import setup, find_packages
+
+setup(name="otbnsim",
+      packages=find_packages(),
+      install_requires=["riscv-model>=0.6.2", "lark-parser", "attrdict"],
+      entry_points={
+          "console_scripts": [
+              "otbn-python-model = otbnsim.main:main",
+              "otbn-asm = otbnsim.asm:main",
+          ],
+      })
diff --git a/util/otbnsim/test/programs.py b/util/otbnsim/test/programs.py
new file mode 100644
index 0000000..feb20f7
--- /dev/null
+++ b/util/otbnsim/test/programs.py
@@ -0,0 +1,93 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import sys
+
+from otbnsim.asm import parse, output
+
+# Prolog to load the memory content into w registers
+# w0 <= mem[255:0]
+# w1 <= mem[511:256]
+# w2 <= mem[767:512]
+# w3 <= mem[1023:768]
+w04_prolog = """
+addi x4, x0, 0
+bn.lid x4, 0(x0)
+addi x4, x0, 1
+bn.lid x4, 1(x0)
+addi x4, x0, 2
+bn.lid x4, 2(x0)
+addi x4, x0, 3
+bn.lid x4, 3(x0)
+"""
+
+# Epilog to write w0-w4 registers into memory
+# mem[255:0] <= w0
+# mem[511:256] <= w1
+# mem[767:512] <= w2
+# mem[1023:768] <= w3
+w04_epilog = """
+addi x4, x0, 0
+bn.sid x4, 0(x0)
+addi x4, x0, 1
+bn.sid x4, 1(x0)
+addi x4, x0, 2
+bn.sid x4, 2(x0)
+addi x4, x0, 3
+bn.sid x4, 3(x0)
+"""
+
+code_mul_256x256 = """
+BN.MULQACC.Z w0.0, w1.0, 0
+BN.MULQACC w0.1, w1.0, 64
+BN.MULQACC.SO w2.l, w0.0, w1.1, 64
+BN.MULQACC w0.2, w1.0, 0
+BN.MULQACC w0.1, w1.1, 0
+BN.MULQACC w0.0, w1.2, 0
+BN.MULQACC w0.3, w1.0, 64
+BN.MULQACC w0.2, w1.1, 64
+BN.MULQACC w0.1, w1.2, 64
+BN.MULQACC.SO w2.u, w0.0, w1.3, 64
+BN.MULQACC w0.3, w1.1, 0
+BN.MULQACC w0.2, w1.2, 0
+BN.MULQACC w0.1, w1.3, 0
+BN.MULQACC w0.3, w1.2, 64
+BN.MULQACC.SO w3.l, w0.2, w1.3, 64
+BN.MULQACC.SO w3.u, w0.3, w1.3, 0
+"""
+
+code_random = """
+ADDI x5, x0, 0
+ADDI x6, x0, 6
+BN.XOR w5, w5, w5
+BN.NOT w5, w5
+LOOPI 4(
+    BN.WSRRS w6, w5, 2
+    BN.MOVR x5+, x6
+)
+"""
+
+if __name__ == "__main__":
+    codes = [code[5:] for code in dir() if code.startswith("code_")]
+    parser = argparse.ArgumentParser()
+    parser.add_argument("test", type=str, choices=codes)
+    parser.add_argument("outfile",
+                        nargs="?",
+                        type=argparse.FileType('wb'),
+                        default=sys.stdout)
+    parser.add_argument("-s",
+                        "--standalone",
+                        action="store_true",
+                        help="Generate standalone (w0-w4 calling) code")
+    parser.add_argument("-O",
+                        "--output-format",
+                        choices=["asm", "binary", "carray"],
+                        default="asm")
+
+    args = parser.parse_args()
+    code = globals()["code_" + args.test]
+    if args.standalone:
+        code = w04_prolog + code + w04_epilog
+    output(parse(code), args.outfile, args.output_format)