[rom_ctrl] A script to scramble a ROM image

This takes a 32-bit ELF file as input (something like the
boot_rom_sim_verilator.elf that we already create as part of the SW
build).

It pads the file up to the expected ROM size with pseudo-random data,
adds ECC checksum bits, extending to 40 bits in width and then
scrambles the result, writing out a VMEM file at the end.

This is not yet a complete solution because we don't calculate the
expected digest for the top 8 bytes (at the moment, they are all
zero). We'll come back to that once we have a proper model of the
interaction between rom_ctrl and the KMAC block.

As well as code to load a 32-bit ELF and write a 40-bit VMEM file,
this commit also has code to load VMEM files because the initial
version of this patch consumed e.g. boot_rom_sim_verilator.32.vmem.
The problem is that the build process has already zero-padded this
file internally, which we don't want, so we have to load
the (segmented) ELF instead. However, the VMEM loading code works and
might be useful soon, so I've left it in for now.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/ci/scripts/mypy.sh b/ci/scripts/mypy.sh
index 8a31952..231fb62 100755
--- a/ci/scripts/mypy.sh
+++ b/ci/scripts/mypy.sh
@@ -11,6 +11,7 @@
     hw/ip/otbn/dv/rig
     hw/ip/otbn/dv/otbnsim
     hw/ip/otbn/util
+    hw/ip/rom_ctrl/util
 )
 
 retcode=0
diff --git a/hw/ip/rom_ctrl/data/rom_ctrl.hjson b/hw/ip/rom_ctrl/data/rom_ctrl.hjson
index 42ed849..6adcdb3 100644
--- a/hw/ip/rom_ctrl/data/rom_ctrl.hjson
+++ b/hw/ip/rom_ctrl/data/rom_ctrl.hjson
@@ -122,6 +122,9 @@
 
     rom: [
       // ROM size (given as `items` below) must be a power of two.
+      //
+      // NOTE: This number is replicated in ../util/scramble_image.py: keep the
+      // two in sync.
       { window: {
           name: "ROM"
           items: "4096" // 16 KiB
diff --git a/hw/ip/rom_ctrl/util/Makefile b/hw/ip/rom_ctrl/util/Makefile
new file mode 100644
index 0000000..99116fd
--- /dev/null
+++ b/hw/ip/rom_ctrl/util/Makefile
@@ -0,0 +1,30 @@
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+.PHONY: all
+all: lint asm-check
+
+# We need a directory to build stuff and use the "rom_ctrl/util" namespace
+# in the top-level build-out directory.
+repo-top := ../../../..
+build-dir := $(repo-top)/build-out/otbn/util
+lint-build-dir := $(build-dir)/lint
+
+$(build-dir) $(lint-build-dir):
+	mkdir -p $@
+
+pyscripts := scramble_image.py
+pylibs := $(filter-out $(pyscripts),$(wildcard *.py))
+
+lint-stamps := $(foreach s,$(pyscripts),$(lint-build-dir)/$(s).stamp)
+$(lint-build-dir)/%.stamp: % $(pylibs) | $(lint-build-dir)
+	mypy --strict $< $(pylibs)
+	touch $@
+
+.PHONY: lint
+lint: $(lint-stamps)
+
+.PHONY: clean
+clean:
+	rm -rf $(build-dir)
diff --git a/hw/ip/rom_ctrl/util/mem.py b/hw/ip/rom_ctrl/util/mem.py
new file mode 100644
index 0000000..2f55169
--- /dev/null
+++ b/hw/ip/rom_ctrl/util/mem.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+import re
+import subprocess
+import tempfile
+from typing import BinaryIO, IO, List, Optional, TextIO, Tuple
+
+from elftools.elf.elffile import ELFFile  # type: ignore
+
+
+def red_xor32(word: int) -> int:
+    '''Reduction XOR for a uint32'''
+    word = (word & 0xffff) ^ (word >> 16)
+    word = (word & 0xff) ^ (word >> 8)
+    word = (word & 0xf) ^ (word >> 4)
+    word = (word & 0x3) ^ (word >> 2)
+    return (word & 0x1) ^ (word >> 1)
+
+
+def add_ecc32(word: int) -> int:
+    '''Add Hsiao (39,32) ECC bits to a 32-bit unsigned word'''
+    assert 0 <= word < (1 << 32)
+    b0 = red_xor32(word ^ 0x00850e56a2) << 32
+    b1 = red_xor32(word ^ 0x002e534c61) << 33
+    b2 = red_xor32(word ^ 0x000901a9fe) << 34
+    b3 = red_xor32(word ^ 0x007079a702) << 35
+    b4 = red_xor32(word ^ 0x00caba900d) << 36
+    b5 = red_xor32(word ^ 0x00d3c44b18) << 37
+    b6 = red_xor32(word ^ 0x0034a430d5) << 38
+    return word | b0 | b1 | b2 | b3 | b4 | b5 | b6
+
+
+class MemChunk:
+    def __init__(self, base_addr: int, words: List[int]):
+        '''A contiguous list of words starting at base_addr'''
+        self.base_addr = base_addr
+        self.words = words
+
+    def __str__(self) -> str:
+        return ('MemChunk(@{:#x}, words_len={})'
+                .format(self.base_addr, len(self.words)))
+
+    def next_addr(self) -> int:
+        '''Get the address directly above the chunk'''
+        return self.base_addr + len(self.words)
+
+    def write_vmem(self, width: int, outfile: TextIO) -> None:
+        '''Write this chunk as one or more lines to outfile
+
+        width is the maximum width of a word in bits.
+
+        '''
+        addr_chars = max(8, (self.next_addr().bit_length() + 3) // 4)
+        word_chars = (width + 3) // 4
+
+        # Try to wrap at 79 characters. To do this, pick a number of words so
+        # that addr_chars + num_words * (word_chars + 1) fits (note that we
+        # gain a character by adding a @ on the front of the address, but lose
+        # it again by omitting the trailing space after the last word).
+        nwords_on_line = max(1, (79 - addr_chars) // (1 + word_chars))
+        for start_idx in range(0, len(self.words), nwords_on_line):
+            line_addr = self.base_addr + start_idx
+            toks = [f'@{line_addr:0{addr_chars}X}']
+            for word in self.words[start_idx:start_idx + nwords_on_line]:
+                toks.append(f'{word:0{word_chars}X}')
+            outfile.write(' '.join(toks) + '\n')
+
+    def add_ecc32(self) -> None:
+        '''Add ECC32 integrity bits
+
+        This extends the input words (which are assumed to be 32-bit) by 7
+        bits, to make 39-bit words.
+
+        '''
+        self.words = [add_ecc32(w) for w in self.words]
+
+
+class MemFile:
+    def __init__(self, width: int, chunks: List[MemChunk]):
+        self.width = width
+        self.chunks = chunks
+
+    def __str__(self) -> str:
+        return ('MemFile(width={}, chunks_len={})'
+                .format(self.width, len(self.chunks)))
+
+    @staticmethod
+    def _parse_line(width: int, line: str) -> Tuple[int, List[int]]:
+        '''Parse a line from a preprocessed vmem file
+
+        Returns a pair (addr, words) where addr is the address at the start of
+        the line and words is a list of the words that have been found, parsed
+        to unsigned numbers. Assumes that line has at least one non-whitespace
+        character. Each word is checked to make sure that it fits in width
+        bits.
+
+        '''
+        tokens = line.split()
+        assert tokens
+
+        addr_match = re.match(r'@([0-9a-fA-F]+)$', tokens[0])
+        if addr_match is None:
+            raise ValueError('Bad line format: first token is {!r}, '
+                             'which is not in the right format for an address.'
+                             .format(tokens[0]))
+        addr = int(addr_match.group(1), 16)
+
+        words = []
+        for idx, word_tok in enumerate(tokens[1:]):
+            try:
+                word = int(word_tok, 16)
+            except ValueError:
+                raise ValueError('Word {} of the line is invalid: '
+                                 '{!r} is not a hex number.'
+                                 .format(idx + 1, word_tok)) from None
+
+            if word < 0 or word >> width:
+                raise ValueError('Word {} of the line is {!r}, which '
+                                 'does not fit in an unsigned {}-bit number.'
+                                 .format(idx + 1, word_tok, width))
+            words.append(word)
+
+        return (addr, words)
+
+    @staticmethod
+    def _load_preproc(width: int, infile: IO[str]) -> 'MemFile':
+        '''Load a pre-processed file'''
+        chunks = []
+        next_chunk = None  # type: Optional[MemChunk]
+        for line in infile:
+            # If the line is empty or whitespace, skip it.
+            if not line or line.isspace():
+                continue
+
+            line_addr, line_words = MemFile._parse_line(width, line)
+
+            # If there aren't actually any words on the line, skip it.
+            if not line_words:
+                continue
+
+            if next_chunk is None:
+                next_chunk = MemChunk(line_addr, line_words)
+                continue
+
+            # Glue the line onto the current chunk if there's no gap
+            chunk_end = next_chunk.next_addr()
+            if line_addr < chunk_end:
+                raise ValueError("Cannot read data starting at {:#x}: "
+                                 "we're already at {:#x}, so this would "
+                                 "go backwards."
+                                 .format(line_addr, chunk_end))
+            if line_addr == chunk_end:
+                next_chunk.words += line_words
+                continue
+
+            # If we're here, there's a gap between the current chunk and
+            # line_addr.
+            chunks.append(next_chunk)
+            next_chunk = MemChunk(line_addr, line_words)
+
+        if next_chunk is not None:
+            chunks.append(next_chunk)
+
+        return MemFile(width, chunks)
+
+    @staticmethod
+    def load_vmem(width: int, infile: TextIO) -> 'MemFile':
+        '''Read a VMEM file
+
+        This assumes that all words fit in the given width.
+
+        '''
+        with tempfile.TemporaryFile('w+') as tmp:
+            # First, run cpp as a subprocess to strip out any comments. These
+            # are allowed by the vmem format as described in srec_vmem(5) and
+            # tokenising them is hard: get the C preprocessor to do the work
+            # for us! The -P argument tells cpp not to generate linemarkers
+            subprocess.run(['cpp', '-P'], stdin=infile, stdout=tmp, check=True)
+            tmp.seek(0)
+            return MemFile._load_preproc(width, tmp)
+
+    @staticmethod
+    def load_elf32(infile: BinaryIO, base_addr: int) -> 'MemFile':
+        '''Read a little-endian 32-bit ELF file'''
+        elf_file = ELFFile(infile)
+        segments = []  # type: List[Tuple[int, int, bytes]]
+        for segment in elf_file.iter_segments():
+            seg_type = segment['p_type']
+
+            # We're only interested in nonempty PT_LOAD segments
+            if seg_type != 'PT_LOAD' or segment['p_memsz'] == 0:
+                continue
+
+            seg_lma = segment['p_paddr']
+            seg_end = seg_lma + segment['p_memsz']
+
+            # We re-map the addresses relative to base_addr: check that no
+            # segment starts before it.
+            if seg_lma < base_addr:
+                raise ValueError('ELF file contains a segment starting at '
+                                 '{:#x}, so cannot be loaded relative to base '
+                                 'address {:#x}.'
+                                 .format(seg_lma, base_addr))
+
+            segments.append((seg_lma - base_addr,
+                             seg_end - base_addr, segment.data()))
+
+        # Sort the segments by base address
+        segments.sort(key=lambda t: t[0])
+
+        # Make sure that they don't overlap
+        prev_lma = 0
+        next_addr = 0
+        for lma, end, data in segments:
+            if lma < next_addr:
+                raise ValueError('ELF file contains overlapping segments with '
+                                 'address ranges {:#x}..{:#x} and '
+                                 '{:#x}..{:#x}.'
+                                 .format(prev_lma, next_addr - 1, lma, end))
+            prev_lma = lma
+            next_addr = end + 1
+
+        # Merge any adjacent segments, bridging any sub-word gaps. This doesn't
+        # do any other right padding: we'll do that on the final pass that
+        # converts to 32-bit words.
+        merged_segments = []  # type: List[Tuple[int, int, bytes]]
+        next_word = 0
+        for lma, end, data in segments:
+            # Round the LMA down to the previous word boundary. The non-overlap
+            # check above should ensure that this is never actually less than
+            # next_word.
+            lma_word = lma // 4
+            assert next_word <= lma_word
+
+            # If there isn't an aligned whole word between the two segments,
+            # bridge the gap
+            if merged_segments and next_word == lma_word:
+                last_lma_word, last_end, last_data = merged_segments[-1]
+                if last_end < lma:
+                    # The largest gap here is be something like last_end = 1;
+                    # lma = 7, which has size 2*4 - 1 - 1 = 6.
+                    assert lma - last_end <= 6
+                    last_data += bytes(lma - last_end)
+                merged_segments[-1] = (last_lma_word, end, last_data + data)
+            else:
+                # Pad on the left if necessary to ensure that lma is 32-bit
+                # aligned.
+                if lma % 4:
+                    merged_segments.append((lma_word, end, bytes(lma % 4) + data))
+                else:
+                    merged_segments.append((lma_word, end, data))
+
+            next_word = 1 + (end // 4)
+
+        # Assemble the bytes in each segment into little-endian 32-bit words.
+        # Zero-extend any partial word at the end of a segment. Because of the
+        # merging in the previous pass, we know this won't cause any overlaps.
+        chunks = []  # type: List[MemChunk]
+        for lma_word, _, data in merged_segments:
+            words = []
+            word = 0
+            for idx, byte in enumerate(data):
+                shift = 8 * (idx % 4)
+                word |= byte << shift
+                if idx % 4 == 3:
+                    words.append(word)
+                    word = 0
+            # idx here will be the index of the last byte. If data ended with a
+            # partial word, idx will be something other than 3 mod 4.
+            if idx % 4 != 3:
+                words.append(word)
+
+            chunks.append(MemChunk(lma_word, words))
+
+        return MemFile(32, chunks)
+
+    def next_addr(self) -> int:
+        '''Get the address directly above the top of the MemFile'''
+        return 0 if not self.chunks else self.chunks[-1].next_addr()
+
+    def write_vmem(self, outfile: TextIO) -> None:
+        '''Write data to a VMEM file'''
+        for chunk in self.chunks:
+            chunk.write_vmem(self.width, outfile)
+
+    def flatten(self, size: int, rnd_seed: int) -> 'MemFile':
+        '''Flatten into a single chunk, padding with pseudo-random data
+
+        As well as padding between the chunks, this expands the result up to
+        size words by adding padding after the last chunk if necessary.
+
+        '''
+        assert self.next_addr() <= size
+
+        old_rnd_state = random.getstate()
+        random.seed(rnd_seed)
+
+        try:
+            acc = MemChunk(0, [])
+            # Add each chunk
+            for chunk in self.chunks:
+                acc_end = acc.next_addr()
+                assert acc_end <= chunk.base_addr
+
+                # If there's a gap before the chunk, insert some random bits
+                padding_len = chunk.base_addr - acc_end
+                if padding_len:
+                    acc.words += [random.getrandbits(32)
+                                  for _ in range(padding_len)]
+
+                assert acc.next_addr() == chunk.base_addr
+                acc.words += chunk.words
+
+            acc_end = acc.next_addr()
+            assert acc_end == self.next_addr()
+
+            # If there's a gap after the last chunk, insert some more random
+            # bits
+            padding_len = size - acc_end
+            if padding_len:
+                acc.words += [random.getrandbits(32)
+                              for _ in range(padding_len)]
+
+            assert acc.next_addr() == size
+
+            return MemFile(self.width, [acc])
+        finally:
+            random.setstate(old_rnd_state)
+
+    def add_ecc32(self) -> None:
+        '''Add ECC32 integrity bits
+
+        This extends the input words (which are assumed to be 32-bit) by 7
+        bits, to make 39-bit words.
+
+        '''
+        assert self.width <= 32
+        for chunk in self.chunks:
+            chunk.add_ecc32()
+        self.width = 39
diff --git a/hw/ip/rom_ctrl/util/scramble_image.py b/hw/ip/rom_ctrl/util/scramble_image.py
new file mode 100755
index 0000000..49b14c9
--- /dev/null
+++ b/hw/ip/rom_ctrl/util/scramble_image.py
@@ -0,0 +1,457 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+'''Script for scrambling a ROM image'''
+
+import argparse
+from typing import Dict, List
+
+import hjson  # type: ignore
+
+from mem import MemChunk, MemFile
+
+ROM_BASE_WORD = 0x8000 // 4
+ROM_SIZE_WORDS = 4096
+
+PRINCE_SBOX4 = [
+    0xb, 0xf, 0x3, 0x2,
+    0xa, 0xc, 0x9, 0x1,
+    0x6, 0x7, 0x8, 0x0,
+    0xe, 0x5, 0xd, 0x4
+]
+
+PRINCE_SBOX4_INV = [
+    0xb, 0x7, 0x3, 0x2,
+    0xf, 0xd, 0x8, 0x9,
+    0xa, 0x6, 0x4, 0x0,
+    0x5, 0xe, 0xc, 0x1
+]
+
+PRESENT_SBOX4 = [
+    0xc, 0x5, 0x6, 0xb,
+    0x9, 0x0, 0xa, 0xd,
+    0x3, 0xe, 0xf, 0x8,
+    0x4, 0x7, 0x1, 0x2
+]
+
+PRESENT_SBOX4_INV = [
+    0x5, 0xe, 0xf, 0x8,
+    0xc, 0x1, 0x2, 0xd,
+    0xb, 0x4, 0x6, 0x3,
+    0x0, 0x7, 0x9, 0xa
+]
+
+PRINCE_SHIFT_ROWS64 = [
+    0x4, 0x9, 0xe, 0x3,
+    0x8, 0xd, 0x2, 0x7,
+    0xc, 0x1, 0x6, 0xb,
+    0x0, 0x5, 0xa, 0xf
+]
+
+PRINCE_SHIFT_ROWS64_INV = [
+    0xc, 0x9, 0x6, 0x3,
+    0x0, 0xd, 0xa, 0x7,
+    0x4, 0x1, 0xe, 0xb,
+    0x8, 0x5, 0x2, 0xf
+]
+
+PRINCE_ROUND_CONSTS = [
+    0x0000000000000000,
+    0x13198a2e03707344,
+    0xa4093822299f31d0,
+    0x082efa98ec4e6c89,
+    0x452821e638d01377,
+    0xbe5466cf34e90c6c,
+    0x7ef84f78fd955cb1,
+    0x85840851f1ac43aa,
+    0xc882d32f25323c54,
+    0x64a51195e0e3610d,
+    0xd3b5a399ca0c2399,
+    0xc0ac29b7c97c50dd
+]
+
+PRINCE_SHIFT_ROWS_CONSTS = [0x7bde, 0xbde7, 0xde7b, 0xe7bd]
+
+_UDict = Dict[object, object]
+
+
+def sbox(data: int, width: int, coeffs: List[int]) -> int:
+    assert 0 <= width
+    assert 0 <= data < (1 << width)
+
+    full_mask = (1 << width) - 1
+    sbox_mask = (1 << (4 * (width // 4))) - 1
+
+    ret = data & (full_mask & ~sbox_mask)
+    for i in range(width // 4):
+        nibble = (data >> (4 * i)) & 0xf
+        sb_nibble = coeffs[nibble]
+        ret |= sb_nibble << (4 * i)
+
+    return ret
+
+
+def subst_perm_enc(data: int, key: int, width: int, num_rounds: int) -> int:
+    '''A model of prim_subst_perm in encrypt mode'''
+    assert 0 <= width
+    assert 0 <= data < (1 << width)
+    assert 0 <= key < (1 << width)
+
+    full_mask = (1 << width) - 1
+    bfly_mask = (1 << (2 * (width // 2))) - 1
+
+    for rnd in range(num_rounds):
+        data_xor = data ^ key
+
+        # SBox layer
+        data_sbox = sbox(data_xor, width, PRESENT_SBOX4)
+
+        # Reverse the vector
+        data_rev = 0
+        for i in range(width):
+            bit = (data_sbox >> i) & 1
+            data_rev |= bit << (width - 1 - i)
+
+        # Butterfly
+        data_bfly = data_rev & (full_mask & ~bfly_mask)
+        for i in range(width // 2):
+            # data_bfly[i] = data_rev[2i]
+            bit = (data_rev >> (2 * i)) & 1
+            data_bfly |= bit << i
+            # data_bfly[width/2 + i] = data_rev[2i+1]
+            bit = (data_rev >> (2 * i + 1)) & 1
+            data_bfly |= bit << (width // 2 + i)
+
+        data = data_bfly
+
+    return data ^ key
+
+
+def subst_perm_dec(data: int, key: int, width: int, num_rounds: int) -> int:
+    '''A model of prim_subst_perm in decrypt mode'''
+    assert 0 <= width
+    assert 0 <= data < (1 << width)
+    assert 0 <= key < (1 << width)
+
+    full_mask = (1 << width) - 1
+    bfly_mask = (1 << (2 * (width // 2))) - 1
+
+    for rnd in range(num_rounds):
+        data_xor = data ^ key
+
+        # Butterfly
+        data_bfly = data_xor & (full_mask & ~bfly_mask)
+        for i in range(width // 2):
+            # data_bfly[2i] = data_xor[i]
+            bit = (data_xor >> i) & 1
+            data_bfly |= bit << (2 * i)
+            # data_bfly[2i+1] = data_xor[i + width // 2]
+            bit = (data_xor >> (i + width // 2)) & 1
+            data_bfly |= bit << (2 * i + 1)
+
+        # Reverse the vector
+        data_rev = 0
+        for i in range(width):
+            bit = (data_bfly >> i) & 1
+            data_rev |= bit << (width - 1 - i)
+
+        # Inverse SBox layer
+        data = sbox(data_rev, width, PRESENT_SBOX4_INV)
+
+    return data ^ key
+
+
+def prince_nibble_red16(data: int) -> int:
+    assert 0 <= data < (1 << 16)
+    nib0 = (data >> 0) & 0xf
+    nib1 = (data >> 4) & 0xf
+    nib2 = (data >> 8) & 0xf
+    nib3 = (data >> 12) & 0xf
+    return nib0 ^ nib1 ^ nib2 ^ nib3
+
+
+def prince_mult_prime(data: int) -> int:
+    assert 0 <= data < (1 << 64)
+    ret = 0
+    for blk_idx in range(4):
+        data_hw = (data >> (16 * blk_idx)) & 0xffff
+        start_sr_idx = 0 if blk_idx in [0, 3] else 1
+        for nibble_idx in range(4):
+            sr_idx = (start_sr_idx + 3 - nibble_idx) % 4
+            sr_const = PRINCE_SHIFT_ROWS_CONSTS[sr_idx]
+            nibble = prince_nibble_red16(data_hw & sr_const)
+            ret |= nibble << (16 * blk_idx + 4 * nibble_idx)
+    return ret
+
+
+def prince_shiftrows(data: int, inv: bool) -> int:
+    assert 0 <= data < (1 << 64)
+    shifts = PRINCE_SHIFT_ROWS64_INV if inv else PRINCE_SHIFT_ROWS64
+
+    ret = 0
+    for nibble_idx in range(64 // 4):
+        src_nibble_idx = shifts[nibble_idx]
+        src_nibble = (data >> (4 * src_nibble_idx)) & 0xf
+        ret |= src_nibble << (4 * nibble_idx)
+    return ret
+
+
+def prince_fwd_round(rc: int, key: int, data: int) -> int:
+    assert 0 <= rc < (1 << 64)
+    assert 0 <= key < (1 << 64)
+    assert 0 <= data < (1 << 64)
+
+    data = sbox(data, 64, PRINCE_SBOX4)
+    data = prince_mult_prime(data)
+    data = prince_shiftrows(data, False)
+    data ^= rc
+    data ^= key
+    return data
+
+
+def prince_inv_round(rc: int, key: int, data: int) -> int:
+    assert 0 <= rc < (1 << 64)
+    assert 0 <= key < (1 << 64)
+    assert 0 <= data < (1 << 64)
+
+    data ^= key
+    data ^= rc
+    data = prince_shiftrows(data, True)
+    data = prince_mult_prime(data)
+    data = sbox(data, 64, PRINCE_SBOX4_INV)
+    return data
+
+
+def prince(data: int, key: int, num_rounds_half: int) -> int:
+    '''Run the PRINCE cipher
+
+    This uses the new keyschedule proposed by Dinur in "Cryptanalytic
+    Time-Memory-Data Tradeoffs for FX-Constructions with Applications to PRINCE
+    and PRIDE".
+
+    '''
+    assert 0 <= data < (1 << 64)
+    assert 0 <= key < (1 << 128)
+    assert 0 <= num_rounds_half <= 5
+
+    # TODO: This matches the RTL in prim_prince.sv, but seems to be the other
+    #       way around in the original paper.
+    k1 = key & ((1 << 64) - 1)
+    k0 = key >> 64
+
+    k0_rot1 = ((k0 & 1) << 63) | (k0 >> 1)
+    k0_prime = k0_rot1 ^ (k0 >> 63)
+
+    data ^= k0
+    data ^= k1
+    data ^= PRINCE_ROUND_CONSTS[0]
+
+    for hri in range(num_rounds_half):
+        round_idx = 1 + hri
+        rc = PRINCE_ROUND_CONSTS[round_idx]
+        rk = k0 if round_idx & 1 else k1
+        data = prince_fwd_round(rc, rk, data)
+
+    data = sbox(data, 64, PRINCE_SBOX4)
+    data = prince_mult_prime(data)
+    data = sbox(data, 64, PRINCE_SBOX4_INV)
+
+    for hri in range(num_rounds_half):
+        round_idx = 11 - num_rounds_half + hri
+        rc = PRINCE_ROUND_CONSTS[round_idx]
+        rk = k1 if round_idx & 1 else k0
+        data = prince_inv_round(rc, rk, data)
+
+    data ^= PRINCE_ROUND_CONSTS[11]
+    data ^= k1
+
+    data ^= k0_prime
+
+    return data
+
+
+class Scrambler:
+    def __init__(self, nonce: int, key: int, rom_size_words: int):
+        assert 0 <= nonce < (1 << 64)
+        assert 0 <= key < (1 << 128)
+        assert 0 < rom_size_words < (1 << 64)
+
+        self.nonce = nonce
+        self.key = key
+        self.rom_size_words = rom_size_words
+
+        self._addr_width = (rom_size_words - 1).bit_length()
+
+    @staticmethod
+    def _get_rom_ctrl(modules: List[object]) -> _UDict:
+        rom_ctrls = []  # type: List[_UDict]
+        for entry in modules:
+            assert isinstance(entry, dict)
+            entry_type = entry.get('type')
+            assert isinstance(entry_type, str)
+
+            if entry_type == 'rom_ctrl':
+                rom_ctrls.append(entry)
+
+        assert len(rom_ctrls) == 1
+        return rom_ctrls[0]
+
+    @staticmethod
+    def _get_params(module: _UDict) -> Dict[str, _UDict]:
+        params = module.get('param_list')
+        assert isinstance(params, list)
+
+        named_params = {}  # type: Dict[str, _UDict]
+        for param in params:
+            name = param.get('name')
+            assert isinstance(name, str)
+            assert name not in named_params
+            named_params[name] = param
+
+        return named_params
+
+    @staticmethod
+    def _get_param_value(params: Dict[str, _UDict],
+                         name: str,
+                         width: int) -> int:
+        param = params.get(name)
+        assert isinstance(param, dict)
+
+        default = param.get('default')
+        assert isinstance(default, str)
+        int_val = int(default, 0)
+        assert 0 <= int_val < (1 << width)
+        return int_val
+
+    @staticmethod
+    def from_hjson_path(path: str, rom_size_words: int) -> 'Scrambler':
+        assert 0 < rom_size_words
+
+        with open(path) as handle:
+            top = hjson.load(handle, use_decimal=True)
+
+        assert isinstance(top, dict)
+        modules = top.get('module')
+        assert isinstance(modules, list)
+
+        rom_ctrl = Scrambler._get_rom_ctrl(modules)
+
+        params = Scrambler._get_params(rom_ctrl)
+        nonce = Scrambler._get_param_value(params, 'RndCnstScrNonce', 64)
+        key = Scrambler._get_param_value(params, 'RndCnstScrKey', 128)
+
+        return Scrambler(nonce, key, rom_size_words)
+
+    def flatten(self, mem: MemFile) -> MemFile:
+        '''Flatten and pad mem up to the correct size
+
+        This adds 8 trailing zero words as space to store the expected hash.
+        These are (obviously!) not the right hash: we inject them properly
+        later.
+
+        '''
+        digest_size_words = 8
+        initial_len = self.rom_size_words - digest_size_words
+        seed = self.key + self.nonce
+
+        flattened = mem.flatten(initial_len, seed)
+        assert len(flattened.chunks) == 1
+        assert len(flattened.chunks[0].words) == initial_len
+
+        # Add the 8 trailing zero words. We do it here, rather than passing
+        # rom_size_words to mem.flatten, to make sure that we see the error if
+        # mem is too big.
+        flattened.chunks[0].words += [0] * digest_size_words
+
+        return flattened
+
+    def scramble40(self, mem: MemFile) -> MemFile:
+        assert len(mem.chunks) == 1
+        assert len(mem.chunks[0].words) == self.rom_size_words
+
+        word_width = 40
+
+        # Write addr_sp, data_sp for the S&P networks for address and data,
+        # respectively. Write clr[i] for unscrambled data word i and scr[i] for
+        # scrambled data word i. We need to construct scr[0], scr[1], ...,
+        # scr[self.rom_size_words].
+        #
+        # Then, for all i, we have:
+        #
+        #   clr[i] = PRINCE(i) ^ data_sp(scr[addr_sp(i)])
+        #
+        # Change coordinates by evaluating at addr_sp_inv(i):
+        #
+        #   clr[addr_sp_inv(i)] = PRINCE(addr_sp_inv(i)) ^ data_sp(scr[i])
+        #
+        # so
+        #
+        #   scr[i] = data_sp_inv(clr[addr_sp_inv(i)] ^ PRINCE(addr_sp_inv(i)))
+        subst_perm_rounds = 2
+        num_rounds_half = 2
+
+        assert word_width <= 64
+        word_mask = (1 << word_width) - 1
+
+        data_scr_nonce = self.nonce >> self._addr_width
+        addr_scr_nonce = self.nonce & ((1 << self._addr_width) - 1)
+
+        scrambled = []
+        for phy_addr in range(self.rom_size_words):
+            log_addr = subst_perm_dec(phy_addr, addr_scr_nonce,
+                                      self._addr_width, subst_perm_rounds)
+            assert 0 <= log_addr < self.rom_size_words
+
+            to_scramble = (data_scr_nonce << self._addr_width) | log_addr
+            keystream = prince(to_scramble, self.key, num_rounds_half)
+
+            keystream_trunc = keystream & word_mask
+            clr_data = mem.chunks[0].words[log_addr]
+            assert 0 <= clr_data < word_mask
+
+            sp_scr_data = keystream_trunc ^ clr_data
+            scr_data = subst_perm_enc(sp_scr_data, 0, word_width, subst_perm_rounds)
+
+            assert 0 <= scr_data < word_mask
+
+            scrambled.append(scr_data)
+
+        return MemFile(mem.width, [MemChunk(0, scrambled)])
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('hjson')
+    parser.add_argument('infile', type=argparse.FileType('rb'))
+    parser.add_argument('outfile', type=argparse.FileType('w'))
+
+    args = parser.parse_args()
+    scrambler = Scrambler.from_hjson_path(args.hjson, ROM_SIZE_WORDS)
+
+    # Load the input ELF file
+    clr_mem = MemFile.load_elf32(args.infile, 4 * ROM_BASE_WORD)
+
+    # Flatten the file, padding with pseudo-random data and ensuring it's
+    # exactly scrambler.rom_size_words words long.
+    clr_flat = scrambler.flatten(clr_mem)
+
+    # Extend from 32 bits to 39 by adding Hsiao (39,32) ECC bits.
+    clr_flat.add_ecc32()
+
+    # Zero-extend the cleartext memory by one more bit (this is the size we
+    # actually use in the physical ROM)
+    assert clr_flat.width == 39
+    clr_flat.width = 40
+
+    # Scramble the memory
+    scr_mem = scrambler.scramble40(clr_flat)
+
+    # TODO: Calculate and insert the expected hash here.
+
+    scr_mem.write_vmem(args.outfile)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/sw/device/boot_rom/meson.build b/sw/device/boot_rom/meson.build
index 9c78f11..a3ff68a 100644
--- a/sw/device/boot_rom/meson.build
+++ b/sw/device/boot_rom/meson.build
@@ -75,6 +75,15 @@
     build_by_default: true,
   )
 
+  boot_rom_scrambled = custom_target(
+    'boot_rom_scrambled_' + device_name,
+    command: scramble_image_command,
+    depend_files: scramble_image_depend_files,
+    input: boot_rom_elf,
+    output: scramble_image_outputs,
+    build_by_default: true,
+  )
+
   boot_rom_sim_dv_logs = []
   if device_name == 'sim_dv'
     boot_rom_sim_dv_logs = custom_target(
diff --git a/sw/device/meson.build b/sw/device/meson.build
index 2211605..7535682 100644
--- a/sw/device/meson.build
+++ b/sw/device/meson.build
@@ -62,6 +62,29 @@
   meson.source_root() / 'util/design/gen-otp-img.py',
 ]
 
+# Generates a scrambled version of a ROM image from an ELF
+#
+# TODO: This is currently top_earlgrey-specific. That's fine for now, because
+#       top_earlgrey is the only top-level with a rom_ctrl block, but we'll
+#       need to make this more generic if we support more top-levels.
+scramble_image_hjson = [
+  meson.source_root() / 'hw/top_earlgrey/data/autogen/top_earlgrey.gen.hjson'
+]
+scramble_image_outputs = [
+  '@BASENAME@.scr.40.vmem',
+]
+scramble_image_command = [
+    prog_python,
+    meson.source_root() / 'hw/ip/rom_ctrl/util/scramble_image.py',
+    scramble_image_hjson,
+    '@INPUT@',
+    '@OUTPUT@',
+]
+scramble_image_depend_files = [
+    meson.source_root() / 'hw/ip/rom_ctrl/util/scramble_image.py',
+    scramble_image_hjson
+]
+
 subdir('boot_rom')
 subdir('otp_img')
 subdir('silicon_creator')
diff --git a/sw/device/silicon_creator/mask_rom/meson.build b/sw/device/silicon_creator/mask_rom/meson.build
index 1c662f8..9483add 100644
--- a/sw/device/silicon_creator/mask_rom/meson.build
+++ b/sw/device/silicon_creator/mask_rom/meson.build
@@ -55,11 +55,24 @@
     build_by_default: true,
   )
 
+  mask_rom_scrambled = custom_target(
+    'mask_rom_scrambled_' + device_name,
+    command: scramble_image_command,
+    depend_files: scramble_image_depend_files,
+    input: mask_rom_elf,
+    output: scramble_image_outputs,
+    build_by_default: true,
+  )
+
   custom_target(
     'mask_rom_export_' + device_name,
     command: export_target_command,
     depend_files: [export_target_depend_files,],
-    input: [mask_rom_elf, mask_rom_embedded],
+    input: [
+      mask_rom_elf,
+      mask_rom_embedded,
+      mask_rom_scrambled,
+    ],
     output: 'mask_rom_export_' + device_name,
     build_always_stale: true,
     build_by_default: true,