[util] implement flash scrambling key derivation

The address and data flash scrambling keys are computed from seeds
stored in the OTP, and netlist constants embedded in the RTL. This
implements the same flash scrambling key derivation scheme implemented
in the RTL so flash VMEM images can be pre-scrabled for backdoor loading
in DV.

Signed-off-by: Timothy Trippel <ttrippel@google.com>
diff --git a/util/design/gen-flash-img.py b/util/design/gen-flash-img.py
index 90b580a..754dbac 100755
--- a/util/design/gen-flash-img.py
+++ b/util/design/gen-flash-img.py
@@ -15,6 +15,7 @@
 import re
 import sys
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 from typing import List
 
@@ -74,6 +75,24 @@
 # ------------------------------------------------------------------------------
 
 
+class FlashScramblingKeyType(Enum):
+    ADDRESS = 1
+    DATA = 2
+
+
+# Flash scrambling key computation parameters.
+KEY_TYPE_2_IV = {
+    FlashScramblingKeyType.ADDRESS: 0x97883548F536F544,
+    FlashScramblingKeyType.DATA: 0xC5F5C1D8AEF35040,
+}
+KEY_TYPE_2_FINALIZATION_CONST = {
+    FlashScramblingKeyType.ADDRESS: 0x39AED01B4B2277312E9480868216A281,
+    FlashScramblingKeyType.DATA: 0x1D888AC88259C44AAB06CB4A4C65A7EA,
+}
+FLASH_KEY_COMPUTATION_KEY_SIZE = OTP_FLASH_ADDR_KEY_SEED_SIZE // 2
+FLASH_KEY_COMPUTATION_KEY_MASK = (2**FLASH_KEY_COMPUTATION_KEY_SIZE) - 1
+
+
 @dataclass
 class FlashScramblingConfigs:
     scrambling_enabled: bool = False
@@ -150,11 +169,36 @@
         OTP_SECRET1_BLOCK_SIZE)
 
 
+def _compute_flash_scrambling_key(scrambling_configs: FlashScramblingConfigs,
+                                  key_type: FlashScramblingKeyType) -> int:
+    if key_type == FlashScramblingKeyType.ADDRESS:
+        key_seed = scrambling_configs.addr_key_seed
+    else:
+        key_seed = scrambling_configs.data_key_seed
+    full_key = 0
+    for i in range(2):
+        round_1_present_key = (key_seed >>
+                               (FLASH_KEY_COMPUTATION_KEY_SIZE *
+                                i)) & FLASH_KEY_COMPUTATION_KEY_MASK
+        key_half = 0
+        for j in range(2):
+            if j == 0:
+                cipher = Present(round_1_present_key)
+                key_half = cipher.encrypt(
+                    KEY_TYPE_2_IV[key_type]) ^ KEY_TYPE_2_IV[key_type]
+            else:
+                cipher = Present(KEY_TYPE_2_FINALIZATION_CONST[key_type])
+                key_half = cipher.encrypt(key_half) ^ key_half
+        full_key |= key_half << (64 * i)
+    return full_key
+
+
 def _compute_flash_scrambling_keys(
         scrambling_configs: FlashScramblingConfigs) -> None:
-    # TODO: implement key computation
-    scrambling_configs.addr_key = 0
-    scrambling_configs.data_key = 0
+    scrambling_configs.addr_key = _compute_flash_scrambling_key(
+        scrambling_configs, FlashScramblingKeyType.ADDRESS)
+    scrambling_configs.data_key = _compute_flash_scrambling_key(
+        scrambling_configs, FlashScramblingKeyType.DATA)
 
 
 def _reformat_flash_vmem(