[util] check flash scrambling enablement before processing VMEM

The flash scrambling enablement flag is set in OTP and read by ROM to
enable flash scrambling. If this flag is set, then the flash VMEM
processing script should scramble the flash data.

Signed-off-by: Timothy Trippel <ttrippel@google.com>
diff --git a/util/design/gen-flash-img.py b/util/design/gen-flash-img.py
index 754dbac..c8e1eba 100755
--- a/util/design/gen-flash-img.py
+++ b/util/design/gen-flash-img.py
@@ -25,8 +25,13 @@
 import prince
 import secded_gen
 
+MUBI4_TRUE = 0x6
+
 # Fixed OTP data / scrambling parameters.
 OTP_WORD_SIZE = 16  # bits
+OTP_FLASH_DATA_DEFAULT_CFG_RE = re.compile(
+    r"CREATOR_SW_CFG: CREATOR_SW_CFG_FLASH_DATA_DEFAULT_CFG")
+OTP_FLASH_DATA_DEFAULT_CFG_BLOCK_SIZE = 32  # bits
 OTP_SECRET1_RE = re.compile(r"SECRET1")
 OTP_SECRET1_BLOCK_SIZE = 64  # bits
 OTP_SECRET1_PRESENT_KEY = 0x5703C3EB2BB563689E00A67814EFBDE8
@@ -34,22 +39,7 @@
 OTP_SECRET1_PRESENT_NUM_ROUNDS = 32
 OTP_FLASH_ADDR_KEY_SEED_SIZE = 256  # bits
 OTP_FLASH_DATA_KEY_SEED_SIZE = 256  # bits
-
-# Computed OTP data / scrambling parameters.
-# ------------------------------------------------------------------------------
-# DO NOT EDIT: edit fixed parameters above instead.
-# ------------------------------------------------------------------------------
-OTP_SECRET1_PRESENT_CIPHER = Present(OTP_SECRET1_PRESENT_KEY,
-                                     rounds=OTP_SECRET1_PRESENT_NUM_ROUNDS,
-                                     keylen=OTP_SECRET1_PRESENT_KEY_LENGTH)
 OTP_SECRET1_FLASH_ADDR_KEY_SEED_START = 0
-OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
-                                        OTP_SECRET1_BLOCK_SIZE)
-OTP_SECRET1_FLASH_DATA_KEY_SEED_START = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
-                                         OTP_SECRET1_BLOCK_SIZE)
-OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP = OTP_SECRET1_FLASH_DATA_KEY_SEED_START + (
-    OTP_FLASH_DATA_KEY_SEED_SIZE // OTP_SECRET1_BLOCK_SIZE)
-# ------------------------------------------------------------------------------
 
 # Flash data / scrambling parameters.
 FLASH_ADDR_KEY_SIZE = 128  # bits
@@ -60,20 +50,6 @@
 FLASH_RELIABILITY_ECC_SIZE = 8  # bits
 FLASH_PRINCE_NUM_HALF_ROUNDS = 5
 
-# Computed flash data / scrambling parameters.
-# ------------------------------------------------------------------------------
-# DO NOT EDIT: edit fixed parameters above instead.
-# ------------------------------------------------------------------------------
-FLASH_GF_OPERAND_B_MASK = (2**FLASH_WORD_SIZE) - 1
-FLASH_GF_OPERAND_A_MASK = (
-    FLASH_GF_OPERAND_B_MASK &
-    ~(0xffff << (FLASH_WORD_SIZE - FLASH_ADDR_SIZE))) << FLASH_WORD_SIZE
-# Create GF(2^64) with irreducible_polynomial = x^64 + x^4 + x^3 + x + 1
-FLASH_GF_2_64 = ffield.FField(64,
-                              gen=((0x1 << 64) | (0x1 << 4) | (0x1 << 3) |
-                                   (0x1 << 1) | 0x1))
-# ------------------------------------------------------------------------------
-
 
 class FlashScramblingKeyType(Enum):
     ADDRESS = 1
@@ -81,6 +57,21 @@
 
 
 # Flash scrambling key computation parameters.
+# ------------------------------------------------------------------------------
+# DO NOT EDIT: edit fixed parameters above instead.
+# ------------------------------------------------------------------------------
+# Computed OTP data / scrambling parameters.
+OTP_SECRET1_PRESENT_CIPHER = Present(OTP_SECRET1_PRESENT_KEY,
+                                     rounds=OTP_SECRET1_PRESENT_NUM_ROUNDS,
+                                     keylen=OTP_SECRET1_PRESENT_KEY_LENGTH)
+OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
+                                        OTP_SECRET1_BLOCK_SIZE)
+OTP_SECRET1_FLASH_DATA_KEY_SEED_START = (OTP_FLASH_ADDR_KEY_SEED_SIZE //
+                                         OTP_SECRET1_BLOCK_SIZE)
+OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP = OTP_SECRET1_FLASH_DATA_KEY_SEED_START + (
+    OTP_FLASH_DATA_KEY_SEED_SIZE // OTP_SECRET1_BLOCK_SIZE)
+
+# Computed flash data / scrambling parameters.
 KEY_TYPE_2_IV = {
     FlashScramblingKeyType.ADDRESS: 0x97883548F536F544,
     FlashScramblingKeyType.DATA: 0xC5F5C1D8AEF35040,
@@ -91,6 +82,20 @@
 }
 FLASH_KEY_COMPUTATION_KEY_SIZE = OTP_FLASH_ADDR_KEY_SEED_SIZE // 2
 FLASH_KEY_COMPUTATION_KEY_MASK = (2**FLASH_KEY_COMPUTATION_KEY_SIZE) - 1
+FLASH_GF_OPERAND_B_MASK = (2**FLASH_WORD_SIZE) - 1
+FLASH_GF_OPERAND_A_MASK = (
+    FLASH_GF_OPERAND_B_MASK &
+    ~(0xffff << (FLASH_WORD_SIZE - FLASH_ADDR_SIZE))) << FLASH_WORD_SIZE
+# Create GF(2^64) with irreducible_polynomial = x^64 + x^4 + x^3 + x + 1
+FLASH_GF_2_64 = ffield.FField(64,
+                              gen=((0x1 << 64) | (0x1 << 4) | (0x1 << 3) |
+                                   (0x1 << 1) | 0x1))
+
+# Format string for generating new VMEM file.
+FLASH_VMEM_WORD_SIZE = (FLASH_WORD_SIZE + FLASH_INTEGRITY_ECC_SIZE +
+                        FLASH_RELIABILITY_ECC_SIZE)
+VMEM_FORMAT_STR = " {:0" + f"{FLASH_VMEM_WORD_SIZE // 4}" + "X}"
+# ------------------------------------------------------------------------------
 
 
 @dataclass
@@ -132,40 +137,61 @@
     try:
         otp_vmem = Path(otp_vmem_file).read_text()
     except IOError:
-        raise Exception(f"Unable to open {otp_vmem}")
+        raise Exception(f"Unable to open {otp_vmem_file}")
     otp_vmem_lines = re.findall(r"^@.*$", otp_vmem, flags=re.MULTILINE)
 
-    # TODO: Retrieve partition with the flash scrambling enablement flags.
-    configs.scrambling_enabled = False
-
-    # Retrieve SECRET1 partition which contains the flash scrambling key seeds,
-    # stripping ECC bits from each data word when processing.
-    data_blocks_64bit = []
-    data_block_64bit = 0
+    # Retrieve OTP data from the following partitions:
+    # - CREATOR_SW_CFG: which contains the flash scramble enablement flag, and
+    # - SECRET1: which contains the flash scrambling key seeds.
+    # Note, we strip ECC bits from each data word when processing.
+    flash_data_default_cfg = None
+    secret1_data_blocks = []
+    otp_data_block = 0
     idx = 0
     for line in otp_vmem_lines:
-        if OTP_SECRET1_RE.search(line):
+        if (OTP_FLASH_DATA_DEFAULT_CFG_RE.search(line) or
+                OTP_SECRET1_RE.search(line)):
             otp_data_word_w_ecc = int(line.split()[1], 16)
             otp_data_word = otp_data_word_w_ecc & (2**OTP_WORD_SIZE - 1)
-            data_block_64bit |= otp_data_word << (idx * OTP_WORD_SIZE)
+            otp_data_block |= otp_data_word << (idx * OTP_WORD_SIZE)
             idx += 1
-            if idx == (64 // OTP_WORD_SIZE):
-                data_blocks_64bit.append(data_block_64bit)
-                data_block_64bit = 0
-                idx = 0
+            if OTP_FLASH_DATA_DEFAULT_CFG_RE.search(line):
+                if idx == (OTP_FLASH_DATA_DEFAULT_CFG_BLOCK_SIZE //
+                           OTP_WORD_SIZE):
+                    flash_data_default_cfg = otp_data_block & 0xff
+                    # If flash data scrambling is disabled, then we can return
+                    # early to save execution time.
+                    if flash_data_default_cfg != MUBI4_TRUE:
+                        configs.scrambling_enabled = False
+                        return
+                    configs.scrambling_enabled = True
+                    otp_data_block = 0
+                    idx = 0
+            if OTP_SECRET1_RE.search(line):
+                if idx == (OTP_SECRET1_BLOCK_SIZE // OTP_WORD_SIZE):
+                    secret1_data_blocks.append(otp_data_block)
+                    otp_data_block = 0
+                    idx = 0
+
+    # Check we found the data we were looking for in the OTP image.
+    if flash_data_default_cfg is None:
+        raise RuntimeError(
+            "Cannot read flash scrambling enablement state from OTP.")
+    if not secret1_data_blocks:
+        raise RuntimeError("Cannot read flash scrambling key seeds from OTP.")
 
     # Descramble SECRET1 partition data blocks and extract flash scrambling key
     # seeds. The SECRET1 partition layout looks like:
     # {FLASH_ADDR_KEY_SEED, FLASH_DATA_KEY_SEED, SRAM_DATA_KEY_SEED, DIGEST}
-    descrambled_data_blocks = list(
-        map(OTP_SECRET1_PRESENT_CIPHER.decrypt, data_blocks_64bit))
+    descrambled_secret1_blocks = list(
+        map(OTP_SECRET1_PRESENT_CIPHER.decrypt, secret1_data_blocks))
     configs.addr_key_seed = _convert_array_2_int(
-        descrambled_data_blocks[OTP_SECRET1_FLASH_ADDR_KEY_SEED_START:
-                                OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP],
+        descrambled_secret1_blocks[OTP_SECRET1_FLASH_ADDR_KEY_SEED_START:
+                                   OTP_SECRET1_FLASH_ADDR_KEY_SEED_STOP],
         OTP_SECRET1_BLOCK_SIZE)
     configs.data_key_seed = _convert_array_2_int(
-        descrambled_data_blocks[OTP_SECRET1_FLASH_DATA_KEY_SEED_START:
-                                OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP],
+        descrambled_secret1_blocks[OTP_SECRET1_FLASH_DATA_KEY_SEED_START:
+                                   OTP_SECRET1_FLASH_DATA_KEY_SEED_STOP],
         OTP_SECRET1_BLOCK_SIZE)
 
 
@@ -246,7 +272,8 @@
                     ecc_configs, "hamming",
                     FLASH_WORD_SIZE + FLASH_INTEGRITY_ECC_SIZE,
                     data_w_intg_ecc)
-                reformatted_line += f" {data_w_full_ecc:x}"
+                reformatted_line += str.format(VMEM_FORMAT_STR,
+                                               data_w_full_ecc)
 
         # Append reformatted line to what will be the new output VMEM file.
         reformatted_vmem_lines.append(reformatted_line)
@@ -282,7 +309,7 @@
 
     # Write re-formatted output file.
     with open(args.out_flash_vmem, "w") as of:
-        of.writelines(reformatted_vmem_lines)
+        of.write("\n".join(reformatted_vmem_lines))
 
 
 if __name__ == "__main__":