[rom_ctrl] Add cSHAKE hash to generated ROM images
Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/rom_ctrl/util/scramble_image.py b/hw/ip/rom_ctrl/util/scramble_image.py
index b5bbf45..adc4e6f 100755
--- a/hw/ip/rom_ctrl/util/scramble_image.py
+++ b/hw/ip/rom_ctrl/util/scramble_image.py
@@ -9,6 +9,7 @@
from typing import Dict, List
import hjson # type: ignore
+from Crypto.Hash import cSHAKE256
from mem import MemChunk, MemFile
@@ -420,6 +421,50 @@
return MemFile(mem.width, [MemChunk(0, scrambled)])
+ def add_hash(self, scr_mem: MemFile) -> None:
+ '''Calculate and insert a cSHAKE256 hash for scr_mem
+
+ This reads all the scrambled data in logical order, except for the last
+ 8 words. It then calculates the resulting cSHAKE hash and finally
+ inserts that hash (unscrambled) in as the top 8 words.
+
+ '''
+ # We only support flat memories of the correct length
+ assert len(scr_mem.chunks) == 1
+ assert scr_mem.chunks[0].base_addr == 0
+ assert len(scr_mem.chunks[0].words) == self.rom_size_words
+
+ scr_chunk = scr_mem.chunks[0]
+
+ data_nonce_width = 64 - self._addr_width
+ subst_perm_rounds = 2
+ addr_scr_nonce = self.nonce >> data_nonce_width
+
+ bytes_per_word = 32 // 8
+ num_digest_words = 256 // 32
+
+ # Read out the scrambled data
+ to_hash = b''
+ for log_addr in range(self.rom_size_words - num_digest_words):
+ phy_addr = subst_perm_enc(log_addr, addr_scr_nonce,
+ self._addr_width, subst_perm_rounds)
+ scr_word = scr_chunk.words[phy_addr]
+ to_hash += scr_word.to_bytes(64 // 8, byteorder='little')
+
+ # Hash it
+ hash_obj = cSHAKE256.new(data=to_hash,
+ custom='ROM_CTRL'.encode('UTF-8'))
+ digest_bytes = hash_obj.read(bytes_per_word * num_digest_words)
+ digest256 = int.from_bytes(digest_bytes, byteorder='little')
+
+ # Insert the hash back into scr_mem
+ for digest_idx in range(num_digest_words):
+ log_addr = self.rom_size_words - num_digest_words + digest_idx
+ phy_addr = subst_perm_enc(log_addr, addr_scr_nonce,
+ self._addr_width, subst_perm_rounds)
+ digest_word = (digest256 >> (32 * digest_idx)) & ((1 << 32) - 1)
+ scr_chunk.words[phy_addr] = digest_word
+
def main() -> None:
parser = argparse.ArgumentParser()
@@ -448,7 +493,8 @@
# Scramble the memory
scr_mem = scrambler.scramble40(clr_flat)
- # TODO: Calculate and insert the expected hash here.
+ # Insert the expected hash here to the top 8 words
+ scrambler.add_hash(scr_mem)
scr_mem.write_vmem(args.outfile)