[util,bazel] integrate flash scrambling with bazel

This integrates the flash scrambling script with Bazel to enable
pre-scrambling flash images before they are backdoor loaded in DV and
Verilator simulations.

A key feature to enable such is to pass the OTP VMEM image to the flash
scrambling script so the flash scrambling key seeds, and enablement
flag, can be read out and decoded for use.

While this commit reads out the scrambling key seeds, they still must be
processed to produce the actual scrambling keys (as is done in HW).
Additionally, the scrambling enablement flag must be read to determine
whether or not to enabling scrambling within the pre-preprocessing
script. These will both happen in a follow up.

Signed-off-by: Timothy Trippel <ttrippel@google.com>
diff --git a/rules/opentitan.bzl b/rules/opentitan.bzl
index 2714c04..9a68c64 100644
--- a/rules/opentitan.bzl
+++ b/rules/opentitan.bzl
@@ -440,24 +440,34 @@
 )
 
 def _scramble_flash_vmem_impl(ctx):
+    # Declare outputs.
     outputs = []
     scrambled_vmem = ctx.actions.declare_file("{}.scr.vmem".format(
         # Remove ".vmem" from file basename.
         ctx.file.vmem.basename.replace("." + ctx.file.vmem.extension, ""),
     ))
     outputs.append(scrambled_vmem)
+
+    # Build arguments / inputs to `gen-flash-img.py` script.
+    arguments = [
+        "--in-flash-vmem",
+        ctx.file.vmem.path,
+        "--out-flash-vmem",
+        scrambled_vmem.path,
+    ]
+    inputs = [
+        ctx.file.vmem,
+        ctx.executable._tool,
+    ]
+    if ctx.file.otp:
+        arguments.extend(["--in-otp-vmem", ctx.file.otp.path])
+        inputs.append(ctx.file.otp)
+
+    # Run the action script.
     ctx.actions.run(
-        outputs = [scrambled_vmem],
-        inputs = [
-            ctx.file.vmem,
-            ctx.executable._tool,
-        ],
-        arguments = [
-            "--infile",
-            ctx.file.vmem.path,
-            "--outfile",
-            scrambled_vmem.path,
-        ],
+        outputs = outputs,
+        inputs = inputs,
+        arguments = arguments,
         executable = ctx.executable._tool,
     )
     return [DefaultInfo(
@@ -468,6 +478,7 @@
 scramble_flash_vmem = rv_rule(
     implementation = _scramble_flash_vmem_impl,
     attrs = {
+        "otp": attr.label(allow_single_file = True),
         "vmem": attr.label(allow_single_file = True),
         "_tool": attr.label(
             default = "@//util/design:gen-flash-img",
@@ -912,6 +923,7 @@
         platform = OPENTITAN_PLATFORM,
         signing_keys = DEFAULT_SIGNING_KEYS,
         signed = True,
+        sim_otp = None,
         testonly = True,
         manifest = "//sw/device/silicon_creator/rom_ext:manifest_standard",
         **kwargs):
@@ -928,6 +940,8 @@
       @param platform: The target platform for the artifacts.
       @param signing_keys: The signing keys for to sign each BIN file with.
       @param signed: Whether or not to emit signed binary/VMEM files.
+      @param sim_otp: OTP image that contains flash scrambling keys / enablement flag
+                      (only relevant for VMEM files built for sim targets).
       @param manifest: Partially populated manifest to set boot stage/slot configs.
       @param **kwargs: Arguments to forward to `opentitan_binary`.
     Emits rules:
@@ -1004,7 +1018,7 @@
                         word_size = 64,  # Backdoor-load VMEM image uses 64-bit words
                     )
 
-                    # Scramble signed VMEM64.
+                    # Scramble / compute ECC for signed VMEM64.
                     scr_signed_vmem_name = "{}_scr_vmem64_signed_{}".format(
                         devname,
                         key_name,
@@ -1012,6 +1026,7 @@
                     dev_targets.append(":" + scr_signed_vmem_name)
                     scramble_flash_vmem(
                         name = scr_signed_vmem_name,
+                        otp = sim_otp,
                         vmem = signed_vmem_name,
                         platform = platform,
                         testonly = testonly,
@@ -1030,11 +1045,12 @@
                 word_size = 64,  # Backdoor-load VMEM image uses 64-bit words
             )
 
-            # Scramble VMEM64.
+            # Scramble / compute ECC for VMEM64.
             scr_vmem_name = "{}_scr_vmem64".format(devname)
             dev_targets.append(":" + scr_vmem_name)
             scramble_flash_vmem(
                 name = scr_vmem_name,
+                otp = sim_otp,
                 vmem = vmem_name,
                 platform = platform,
                 testonly = testonly,
diff --git a/rules/opentitan_test.bzl b/rules/opentitan_test.bzl
index 619a2c5..03870ed 100644
--- a/rules/opentitan_test.bzl
+++ b/rules/opentitan_test.bzl
@@ -379,10 +379,22 @@
         if slot not in _FLASH_SLOTS:
             fail("Invalid slot: {}. Valid slots are: silicon_creator_{a,b,virtual}".format(slot))
         deps += _FLASH_SLOTS[slot]
+
+        # Get OTP image for sim targets. We need to pass the OTP image to the
+        # flash scrambling script since it contains the seeds to derive the
+        # scrambling keys. No need to worry about flash image scrambling for
+        # FPGA targets as the flash is loaded through bootstrap (i.e., the front
+        # door), unlike the sim targets which load via backdoor.
+        sim_otp_ = None
+        if "sim_dv" in target_params:
+            sim_otp_ = target_params["sim_dv"]["otp"]
+        elif "sim_verilator" in target_params:
+            sim_otp_ = target_params["sim_verilator"]["otp"]
         ot_flash_binary = name + "_prog"
         opentitan_flash_binary(
             name = ot_flash_binary,
             signed = signed,
+            sim_otp = sim_otp_,
             deps = deps,
             devices = devices_to_build_for,
             manifest = manifest,
diff --git a/sw/device/tests/BUILD b/sw/device/tests/BUILD
index 581cc1b..6153b52 100644
--- a/sw/device/tests/BUILD
+++ b/sw/device/tests/BUILD
@@ -2138,6 +2138,9 @@
             "--bootstrap=\"$(location {flash})\"",
         ],
     ),
+    dv = dv_params(
+        otp = ":power_virus_systemtest_otp_img_rma",
+    ),
     targets = [
         # TODO(#14814): add more targets
         "cw310_test_rom",
diff --git a/util/design/BUILD b/util/design/BUILD
index 4ecae87..dd28211 100644
--- a/util/design/BUILD
+++ b/util/design/BUILD
@@ -24,6 +24,7 @@
     srcs = ["gen-flash-img.py"],
     deps = [
         ":secded_gen",
+        "//util/design/lib:present",
         requirement("pyfinite"),
     ],
 )
diff --git a/util/design/gen-flash-img.py b/util/design/gen-flash-img.py
index 8e83039..90b580a 100755
--- a/util/design/gen-flash-img.py
+++ b/util/design/gen-flash-img.py
@@ -12,86 +12,167 @@
 
 import argparse
 import functools
-import logging
 import re
 import sys
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List
 
 from pyfinite import ffield
+from util.design.lib.Present import Present
 
 import prince
 import secded_gen
 
+# Fixed OTP data / scrambling parameters.
+OTP_WORD_SIZE = 16  # bits
+OTP_SECRET1_RE = re.compile(r"SECRET1")
+OTP_SECRET1_BLOCK_SIZE = 64  # bits
+OTP_SECRET1_PRESENT_KEY = 0x5703C3EB2BB563689E00A67814EFBDE8
+OTP_SECRET1_PRESENT_KEY_LENGTH = 128  # bits
+OTP_SECRET1_PRESENT_NUM_ROUNDS = 32
+OTP_FLASH_ADDR_KEY_SEED_SIZE = 256  # bits
+OTP_FLASH_DATA_KEY_SEED_SIZE = 256  # bits
+
+# Computed OTP data / scrambling parameters.
+# ------------------------------------------------------------------------------
+# DO NOT EDIT: edit fixed parameters above instead.
+# ------------------------------------------------------------------------------
+OTP_SECRET1_PRESENT_CIPHER = Present(OTP_SECRET1_PRESENT_KEY,
+                                     rounds=OTP_SECRET1_PRESENT_NUM_ROUNDS,
+                                     keylen=OTP_SECRET1_PRESENT_KEY_LENGTH)
+OTP_SECRET1_FLASH_ADDR_KEY_SEED_START = 0
+OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
+                                        OTP_SECRET1_BLOCK_SIZE)
+OTP_SECRET1_FLASH_DATA_KEY_SEED_START = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
+                                         OTP_SECRET1_BLOCK_SIZE)
+OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP = OTP_SECRET1_FLASH_DATA_KEY_SEED_START + (
+    OTP_FLASH_DATA_KEY_SEED_SIZE // OTP_SECRET1_BLOCK_SIZE)
+# ------------------------------------------------------------------------------
+
+# Flash data / scrambling parameters.
+FLASH_ADDR_KEY_SIZE = 128  # bits
+FLASH_DATA_KEY_SIZE = 128  # bits
 FLASH_WORD_SIZE = 64  # bits
 FLASH_ADDR_SIZE = 16  # bits
 FLASH_INTEGRITY_ECC_SIZE = 4  # bits
 FLASH_RELIABILITY_ECC_SIZE = 8  # bits
-GF_OPERAND_B_MASK = (2**FLASH_WORD_SIZE) - 1
-GF_OPERAND_A_MASK = (GF_OPERAND_B_MASK &
-                     ~(0xffff <<
-                       (FLASH_WORD_SIZE - FLASH_ADDR_SIZE))) << FLASH_WORD_SIZE
-PRINCE_NUM_HALF_ROUNDS = 5
+FLASH_PRINCE_NUM_HALF_ROUNDS = 5
 
+# Computed flash data / scrambling parameters.
+# ------------------------------------------------------------------------------
+# DO NOT EDIT: edit fixed parameters above instead.
+# ------------------------------------------------------------------------------
+FLASH_GF_OPERAND_B_MASK = (2**FLASH_WORD_SIZE) - 1
+FLASH_GF_OPERAND_A_MASK = (
+    FLASH_GF_OPERAND_B_MASK &
+    ~(0xffff << (FLASH_WORD_SIZE - FLASH_ADDR_SIZE))) << FLASH_WORD_SIZE
 # Create GF(2^64) with irreducible_polynomial = x^64 + x^4 + x^3 + x + 1
-GF_2_64 = ffield.FField(64,
-                        gen=((0x1 << 64) | (0x1 << 4) | (0x1 << 3) |
-                             (0x1 << 1) | 0x1))
+FLASH_GF_2_64 = ffield.FField(64,
+                              gen=((0x1 << 64) | (0x1 << 4) | (0x1 << 3) |
+                                   (0x1 << 1) | 0x1))
+# ------------------------------------------------------------------------------
+
+
+@dataclass
+class FlashScramblingConfigs:
+    scrambling_enabled: bool = False
+    addr_key_seed: int = None
+    data_key_seed: int = None
+    addr_key: int = None
+    data_key: int = None
 
 
 @functools.lru_cache(maxsize=None)
 def _xex_scramble(data: int, word_addr: int, flash_addr_key: int,
                   flash_data_key: int) -> int:
-    operand_a = ((flash_addr_key & GF_OPERAND_A_MASK) >>
+    operand_a = ((flash_addr_key & FLASH_GF_OPERAND_A_MASK) >>
                  (FLASH_WORD_SIZE - FLASH_ADDR_SIZE)) | word_addr
-    operand_b = flash_addr_key & GF_OPERAND_B_MASK
-    mask = GF_2_64.Multiply(operand_a, operand_b)
+    operand_b = flash_addr_key & FLASH_GF_OPERAND_B_MASK
+    mask = FLASH_GF_2_64.Multiply(operand_a, operand_b)
     masked_data = data ^ mask
     return prince.prince(masked_data, flash_data_key,
-                         PRINCE_NUM_HALF_ROUNDS) ^ mask
+                         FLASH_PRINCE_NUM_HALF_ROUNDS) ^ mask
 
 
-def main(argv: List[str]):
-    # Parse command line args.
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--infile",
-                        "-i",
-                        type=str,
-                        help="Input VMEM file to reformat.")
-    parser.add_argument("--outfile", "-o", type=str, help="Output VMEM file.")
-    parser.add_argument("--scramble",
-                        "-s",
-                        action="store_true",
-                        help="Whether to scramble data or not.")
-    parser.add_argument("--address-key",
-                        type=str,
-                        help="Flash address scrambling key.")
-    parser.add_argument("--data-key",
-                        type=str,
-                        help="Flash address scrambling key.")
-    args = parser.parse_args(argv)
+def _convert_array_2_int(data_array: List[int],
+                         data_size: int,
+                         little_endian=True) -> int:
+    """Converts array of data blocks to an int."""
+    reformatted_data = 0
+    if not little_endian:
+        data_array.reverse()
+    for i, data in enumerate(data_array):
+        reformatted_data |= (data << (i * data_size))
+    return reformatted_data
 
-    # Validate command line args.
-    if args.scramble:
-        if args.address_key is None or args.data_key is None:
-            logging.error(
-                "Address and data keys must be provided when scrambling is"
-                "requested.")
 
-    # Open original VMEM for processing.
+def _get_flash_scrambling_configs(otp_vmem_file: str,
+                                  configs: FlashScramblingConfigs) -> None:
+    # Open OTP VMEM file and read into memory, skipping comment lines.
     try:
-        orig_vmem = Path(args.infile).read_text()
+        otp_vmem = Path(otp_vmem_file).read_text()
     except IOError:
-        raise Exception(f"Unable to open {args.infile}")
+        raise Exception(f"Unable to open {otp_vmem}")
+    otp_vmem_lines = re.findall(r"^@.*$", otp_vmem, flags=re.MULTILINE)
 
-    # Search for lines that contain data, skipping other comment lines.
-    orig_vmem_lines = re.findall(r"^@.*$", orig_vmem, flags=re.MULTILINE)
+    # TODO: Retrieve partition with the flash scrambling enablement flags.
+    configs.scrambling_enabled = False
+
+    # Retrieve SECRET1 partition which contains the flash scrambling key seeds,
+    # stripping ECC bits from each data word when processing.
+    data_blocks_64bit = []
+    data_block_64bit = 0
+    idx = 0
+    for line in otp_vmem_lines:
+        if OTP_SECRET1_RE.search(line):
+            otp_data_word_w_ecc = int(line.split()[1], 16)
+            otp_data_word = otp_data_word_w_ecc & (2**OTP_WORD_SIZE - 1)
+            data_block_64bit |= otp_data_word << (idx * OTP_WORD_SIZE)
+            idx += 1
+            if idx == (64 // OTP_WORD_SIZE):
+                data_blocks_64bit.append(data_block_64bit)
+                data_block_64bit = 0
+                idx = 0
+
+    # Descramble SECRET1 partition data blocks and extract flash scrambling key
+    # seeds. The SECRET1 partition layout looks like:
+    # {FLASH_ADDR_KEY_SEED, FLASH_DATA_KEY_SEED, SRAM_DATA_KEY_SEED, DIGEST}
+    descrambled_data_blocks = list(
+        map(OTP_SECRET1_PRESENT_CIPHER.decrypt, data_blocks_64bit))
+    configs.addr_key_seed = _convert_array_2_int(
+        descrambled_data_blocks[OTP_SECRET1_FLASH_ADDR_KEY_SEED_START:
+                                OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP],
+        OTP_SECRET1_BLOCK_SIZE)
+    configs.data_key_seed = _convert_array_2_int(
+        descrambled_data_blocks[OTP_SECRET1_FLASH_DATA_KEY_SEED_START:
+                                OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP],
+        OTP_SECRET1_BLOCK_SIZE)
+
+
+def _compute_flash_scrambling_keys(
+        scrambling_configs: FlashScramblingConfigs) -> None:
+    # TODO: implement key computation
+    scrambling_configs.addr_key = 0
+    scrambling_configs.data_key = 0
+
+
+def _reformat_flash_vmem(
+        flash_vmem_file: str,
+        scrambling_configs: FlashScramblingConfigs) -> List[str]:
+    # Open (raw) flash VMEM file and read into memory, skipping comment lines.
+    try:
+        flash_vmem = Path(flash_vmem_file).read_text()
+    except IOError:
+        raise Exception(f"Unable to open {flash_vmem_file}")
+    flash_vmem_lines = re.findall(r"^@.*$", flash_vmem, flags=re.MULTILINE)
 
     # Load project SECDED configuration.
-    config = secded_gen.load_secded_config()
+    ecc_configs = secded_gen.load_secded_config()
 
+    # Add integrity/reliability ECC, and potentially scramble, each flash word.
     reformatted_vmem_lines = []
-    for line in orig_vmem_lines:
+    for line in flash_vmem_lines:
         line_items = line.split()
         reformatted_line = ""
         address = None
@@ -106,18 +187,19 @@
                 data = int(item, 16)
                 # `data_w_intg_ecc` will be in format {ECC bits, data bits}.
                 data_w_intg_ecc, _ = secded_gen.ecc_encode(
-                    config, "hamming", FLASH_WORD_SIZE, data)
+                    ecc_configs, "hamming", FLASH_WORD_SIZE, data)
                 # Due to storage constraints the first nibble of ECC is dropped.
-                data_w_intg_ecc &= 0xFFFFFFFFFFFFFFFFF
-                if args.scramble:
+                data_w_intg_ecc &= 0xF_FFFF_FFFF_FFFF_FFFF
+                if scrambling_configs.scrambling_enabled:
                     intg_ecc = data_w_intg_ecc & (0xF << FLASH_WORD_SIZE)
-                    data = _xex_scramble(data, address, args.flash_addr_key,
-                                         args.flash_data_key)
+                    data = _xex_scramble(data, address,
+                                         scrambling_configs.addr_key,
+                                         scrambling_configs.data_key)
                     data_w_intg_ecc = intg_ecc | data
                 # `data_w_full_ecc` will be in format {reliablity ECC bits,
                 # integrity ECC bits, data bits}.
                 data_w_full_ecc, _ = secded_gen.ecc_encode(
-                    config, "hamming",
+                    ecc_configs, "hamming",
                     FLASH_WORD_SIZE + FLASH_INTEGRITY_ECC_SIZE,
                     data_w_intg_ecc)
                 reformatted_line += f" {data_w_full_ecc:x}"
@@ -125,8 +207,37 @@
         # Append reformatted line to what will be the new output VMEM file.
         reformatted_vmem_lines.append(reformatted_line)
 
+    return reformatted_vmem_lines
+
+
+def main(argv: List[str]):
+    # Parse command line args.
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--in-flash-vmem",
+                        type=str,
+                        help="Input VMEM file to reformat.")
+    parser.add_argument("--in-otp-vmem",
+                        type=str,
+                        help="Input OTP (VMEM) file to retrieve data from.")
+    parser.add_argument("--out-flash-vmem", type=str, help="Output VMEM file.")
+    args = parser.parse_args(argv)
+
+    # Read flash scrambling configurations (including: enablement, address and
+    # data key seeds) directly from OTP VMEM file.
+    scrambling_configs = FlashScramblingConfigs()
+    if args.in_otp_vmem:
+        _get_flash_scrambling_configs(args.in_otp_vmem, scrambling_configs)
+
+    # Compute flash scrambling keys from seeds.
+    if scrambling_configs.scrambling_enabled:
+        _compute_flash_scrambling_keys(scrambling_configs)
+
+    # Reformat flash VMEM file to add integrity/reliablity ECC and scrambling.
+    reformatted_vmem_lines = _reformat_flash_vmem(args.in_flash_vmem,
+                                                  scrambling_configs)
+
     # Write re-formatted output file.
-    with open(args.outfile, "w") as of:
+    with open(args.out_flash_vmem, "w") as of:
         of.writelines(reformatted_vmem_lines)