[rom_ctrl] Add cSHAKE hash to generated ROM images

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/rom_ctrl/util/scramble_image.py b/hw/ip/rom_ctrl/util/scramble_image.py
index b5bbf45..adc4e6f 100755
--- a/hw/ip/rom_ctrl/util/scramble_image.py
+++ b/hw/ip/rom_ctrl/util/scramble_image.py
@@ -9,6 +9,7 @@
 from typing import Dict, List
 
 import hjson  # type: ignore
+from Crypto.Hash import cSHAKE256
 
 from mem import MemChunk, MemFile
 
@@ -420,6 +421,50 @@
 
         return MemFile(mem.width, [MemChunk(0, scrambled)])
 
+    def add_hash(self, scr_mem: MemFile) -> None:
+        '''Calculate and insert a cSHAKE256 hash for scr_mem
+
+        This reads all the scrambled data in logical order, except for the last
+        8 words. It then calculates the resulting cSHAKE hash and finally
+        inserts that hash (unscrambled) in as the top 8 words.
+
+        '''
+        # We only support flat memories of the correct length
+        assert len(scr_mem.chunks) == 1
+        assert scr_mem.chunks[0].base_addr == 0
+        assert len(scr_mem.chunks[0].words) == self.rom_size_words
+
+        scr_chunk = scr_mem.chunks[0]
+
+        data_nonce_width = 64 - self._addr_width
+        subst_perm_rounds = 2
+        addr_scr_nonce = self.nonce >> data_nonce_width
+
+        bytes_per_word = 32 // 8
+        num_digest_words = 256 // 32
+
+        # Read out the scrambled data
+        to_hash = b''
+        for log_addr in range(self.rom_size_words - num_digest_words):
+            phy_addr = subst_perm_enc(log_addr, addr_scr_nonce,
+                                      self._addr_width, subst_perm_rounds)
+            scr_word = scr_chunk.words[phy_addr]
+            to_hash += scr_word.to_bytes(64 // 8, byteorder='little')
+
+        # Hash it
+        hash_obj = cSHAKE256.new(data=to_hash,
+                                 custom='ROM_CTRL'.encode('UTF-8'))
+        digest_bytes = hash_obj.read(bytes_per_word * num_digest_words)
+        digest256 = int.from_bytes(digest_bytes, byteorder='little')
+
+        # Insert the hash back into scr_mem
+        for digest_idx in range(num_digest_words):
+            log_addr = self.rom_size_words - num_digest_words + digest_idx
+            phy_addr = subst_perm_enc(log_addr, addr_scr_nonce,
+                                      self._addr_width, subst_perm_rounds)
+            digest_word = (digest256 >> (32 * digest_idx)) & ((1 << 32) - 1)
+            scr_chunk.words[phy_addr] = digest_word
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
@@ -448,7 +493,8 @@
     # Scramble the memory
     scr_mem = scrambler.scramble40(clr_flat)
 
-    # TODO: Calculate and insert the expected hash here.
+    # Insert the expected hash here to the top 8 words
+    scrambler.add_hash(scr_mem)
 
     scr_mem.write_vmem(args.outfile)