[util] refactor flash img processing script to enable scrambling

This refactors the flash image pre-processing script to enable
pre-scrambling a flash VMEM file so that it can be backdoor loaded in DV
sims when scrambling is enabled.

Extra command line arguments can now be passed to the script to turn on
the data scrambling feature. When scrambling is not enabled, the script
will simply reformat the input VMEM file to add integrity and
reliability ECC bits, same as before.

Signed-off-by: Timothy Trippel <ttrippel@google.com>
diff --git a/rules/opentitan.bzl b/rules/opentitan.bzl
index 073da62..2714c04 100644
--- a/rules/opentitan.bzl
+++ b/rules/opentitan.bzl
@@ -453,7 +453,9 @@
             ctx.executable._tool,
         ],
         arguments = [
+            "--infile",
             ctx.file.vmem.path,
+            "--outfile",
             scrambled_vmem.path,
         ],
         executable = ctx.executable._tool,
diff --git a/util/design/BUILD b/util/design/BUILD
index ba7cd00..4ecae87 100644
--- a/util/design/BUILD
+++ b/util/design/BUILD
@@ -22,7 +22,10 @@
 py_binary(
     name = "gen-flash-img",
     srcs = ["gen-flash-img.py"],
-    deps = [":secded_gen"],
+    deps = [
+        ":secded_gen",
+        requirement("pyfinite"),
+    ],
 )
 
 py_binary(
diff --git a/util/design/gen-flash-img.py b/util/design/gen-flash-img.py
index 64bc765..8e83039 100755
--- a/util/design/gen-flash-img.py
+++ b/util/design/gen-flash-img.py
@@ -2,79 +2,133 @@
 # Copyright lowRISC contributors.
 # Licensed under the Apache License, Version 2.0, see LICENSE for details.
 # SPDX-License-Identifier: Apache-2.0
-r"""Takes a compiled vmem image and processes it for flash.
-    Long term this should include both layers of ECC and scrambling,
-    The first version has only the truncated plaintext ECC.
+r"""Takes a compiled VMEM image and processes it for loading into flash.
+
+    Specifically, this takes a raw flash image and adds both layers of ECC
+    (integrity and reliablity), and optionally, scrambles the data using the
+    same XEX scrambling scheme used in the flash controller. This enables
+    backdoor loading the flash on simulation platforms (e.g., DV and Verilator).
 """
 
 import argparse
-import math
+import functools
+import logging
 import re
+import sys
 from pathlib import Path
-from typing import Dict, Any
+from typing import List
 
+from pyfinite import ffield
+
+import prince
 import secded_gen
 
+FLASH_WORD_SIZE = 64  # bits
+FLASH_ADDR_SIZE = 16  # bits
+FLASH_INTEGRITY_ECC_SIZE = 4  # bits
+FLASH_RELIABILITY_ECC_SIZE = 8  # bits
+GF_OPERAND_B_MASK = (2**FLASH_WORD_SIZE) - 1
+GF_OPERAND_A_MASK = (GF_OPERAND_B_MASK &
+                     ~(0xffff <<
+                       (FLASH_WORD_SIZE - FLASH_ADDR_SIZE))) << FLASH_WORD_SIZE
+PRINCE_NUM_HALF_ROUNDS = 5
 
-def _add_intg_ecc(config: Dict[str, Any], in_val: int) -> str:
-    result, m = secded_gen.ecc_encode(config, "hamming", 64, in_val)
-
-    m_nibbles = math.ceil(m / 4)
-    result = format(result, '0' + str(16 + m_nibbles) + 'x')
-
-    # due to lack of storage space, the first nibble of the ECC is truncated
-    return result[1:]
+# Create GF(2^64) with irreducible_polynomial = x^64 + x^4 + x^3 + x + 1
+GF_2_64 = ffield.FField(64,
+                        gen=((0x1 << 64) | (0x1 << 4) | (0x1 << 3) |
+                             (0x1 << 1) | 0x1))
 
 
-def _add_reliability_ecc(config: Dict[str, Any], in_val: int) -> str:
-    result, m = secded_gen.ecc_encode(config, "hamming", 68, in_val)
-
-    m_nibbles = math.ceil((68 + m) / 4)
-    result = format(result, '0' + str(m_nibbles) + 'x')
-
-    # return full result
-    return result
+@functools.lru_cache(maxsize=None)
+def _xex_scramble(data: int, word_addr: int, flash_addr_key: int,
+                  flash_data_key: int) -> int:
+    operand_a = ((flash_addr_key & GF_OPERAND_A_MASK) >>
+                 (FLASH_WORD_SIZE - FLASH_ADDR_SIZE)) | word_addr
+    operand_b = flash_addr_key & GF_OPERAND_B_MASK
+    mask = GF_2_64.Multiply(operand_a, operand_b)
+    masked_data = data ^ mask
+    return prince.prince(masked_data, flash_data_key,
+                         PRINCE_NUM_HALF_ROUNDS) ^ mask
 
 
-def main():
+def main(argv: List[str]):
+    # Parse command line args.
     parser = argparse.ArgumentParser()
-    parser.add_argument('infile')
-    parser.add_argument('outfile')
-    args = parser.parse_args()
+    parser.add_argument("--infile",
+                        "-i",
+                        type=str,
+                        help="Input VMEM file to reformat.")
+    parser.add_argument("--outfile", "-o", type=str, help="Output VMEM file.")
+    parser.add_argument("--scramble",
+                        "-s",
+                        action="store_true",
+                        help="Whether to scramble data or not.")
+    parser.add_argument("--address-key",
+                        type=str,
+                        help="Flash address scrambling key.")
+    parser.add_argument("--data-key",
+                        type=str,
+                        help="Flash address scrambling key.")
+    args = parser.parse_args(argv)
 
-    # open original vmem and extract relevant content
+    # Validate command line args.
+    if args.scramble:
+        if args.address_key is None or args.data_key is None:
+            logging.error(
+                "Address and data keys must be provided when scrambling is"
+                "requested.")
+
+    # Open original VMEM for processing.
     try:
-        vmem_orig = Path(args.infile).read_text()
+        orig_vmem = Path(args.infile).read_text()
     except IOError:
         raise Exception(f"Unable to open {args.infile}")
 
-    # search only for lines that contain data, skip all other comments
-    result = re.findall(r"^@.*$", vmem_orig, flags=re.MULTILINE)
+    # Search for lines that contain data, skipping other comment lines.
+    orig_vmem_lines = re.findall(r"^@.*$", orig_vmem, flags=re.MULTILINE)
 
+    # Load project SECDED configuration.
     config = secded_gen.load_secded_config()
 
-    output = []
-    for line in result:
-        items = line.split()
-        result = ""
-        for item in items:
+    reformatted_vmem_lines = []
+    for line in orig_vmem_lines:
+        line_items = line.split()
+        reformatted_line = ""
+        address = None
+        data = None
+        for item in line_items:
+            # Process the address first.
             if re.match(r"^@", item):
-                result += item
+                reformatted_line += item
+                address = int(item.lstrip("@"), 16)
+            # Process the data words.
             else:
-                data_w_intg_ecc = _add_intg_ecc(config, int(item, 16))
-                full_ecc = _add_reliability_ecc(config, int(data_w_intg_ecc, 16))
-                result += f' {full_ecc}'
+                data = int(item, 16)
+                # `data_w_intg_ecc` will be in format {ECC bits, data bits}.
+                data_w_intg_ecc, _ = secded_gen.ecc_encode(
+                    config, "hamming", FLASH_WORD_SIZE, data)
+                # Due to storage constraints the first nibble of ECC is dropped.
+                data_w_intg_ecc &= 0xFFFFFFFFFFFFFFFFF
+                if args.scramble:
+                    intg_ecc = data_w_intg_ecc & (0xF << FLASH_WORD_SIZE)
+                    data = _xex_scramble(data, address, args.flash_addr_key,
+                                         args.flash_data_key)
+                    data_w_intg_ecc = intg_ecc | data
+                # `data_w_full_ecc` will be in format {reliablity ECC bits,
+                # integrity ECC bits, data bits}.
+                data_w_full_ecc, _ = secded_gen.ecc_encode(
+                    config, "hamming",
+                    FLASH_WORD_SIZE + FLASH_INTEGRITY_ECC_SIZE,
+                    data_w_intg_ecc)
+                reformatted_line += f" {data_w_full_ecc:x}"
 
-        # add processed element to output
-        output.append(result)
+        # Append reformatted line to what will be the new output VMEM file.
+        reformatted_vmem_lines.append(reformatted_line)
 
-    # open output file
-    outfile = open(args.outfile, "w")
-
-    # write processed content
-    for entry in output:
-        outfile.write(entry + "\n")
+    # Write re-formatted output file.
+    with open(args.outfile, "w") as of:
+        of.writelines(reformatted_vmem_lines)
 
 
 if __name__ == "__main__":
-    main()
+    main(sys.argv[1:])