[prim_prince] Add option to instantiate a registers half-way

Signed-off-by: Michael Schaffner <msf@opentitan.org>
diff --git a/hw/ip/prim/rtl/prim_prince.sv b/hw/ip/prim/rtl/prim_prince.sv
index bd5e50d..655835f 100644
--- a/hw/ip/prim/rtl/prim_prince.sv
+++ b/hw/ip/prim/rtl/prim_prince.sv
@@ -2,15 +2,14 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 //
-// This module is an implementation of the 64bit PRINCE block cipher. It is a
-// fully unrolled combinational implementation with configurable number of
-// rounds. Due to the reflective construction of this cipher, the same circuit
-// can be used for encryption and decryption, as described below. Further, the
-// primitive supports a 32bit block cipher flavor which is not specified in the
-// original paper. It should be noted, however, that the 32bit version is
-// **not** secure and must not be used in a setting where cryptographic cipher
-// strength is required. The 32bit variant is only intended to be used as a
-// lightweight data scrambling device.
+// This module is an implementation of the 64bit PRINCE block cipher. It is a fully unrolled
+// combinational implementation with configurable number of rounds. Optionally, registers for the
+// data and key states can be enabled, if this is required. Due to the reflective construction of
+// this cipher, the same circuit can be used for encryption and decryption, as described below.
+// Further, the primitive supports a 32bit block cipher flavor which is not specified in the
+// original paper. It should be noted, however, that the 32bit version is **not** secure and must
+// not be used in a setting where cryptographic cipher strength is required. The 32bit variant is
+// only intended to be used as a lightweight data scrambling device.
 //
 // See also: prim_present, prim_cipher_pkg
 //
@@ -33,11 +32,20 @@
   parameter int NumRoundsHalf = 5,
   // This primitive uses the new key schedule proposed in https://eprint.iacr.org/2014/656.pdf
   // Setting this parameter to 1 falls back to the original key schedule.
-  parameter bit UseOldKeySched = 1'b0
+  parameter bit UseOldKeySched = 1'b0,
+  // This instantiates a data register halfway in the primitive.
+  parameter bit HalfwayDataReg = 1'b0,
+  // This instantiates a key register halfway in the primitive.
+  parameter bit HalfwayKeyReg = 1'b0
 ) (
+  input                        clk_i,
+  input                        rst_ni,
+
+  input                        valid_i,
   input        [DataWidth-1:0] data_i,
   input        [KeyWidth-1:0]  key_i,
-  input                        dec_i, // set to 1 for decryption
+  input                        dec_i,   // set to 1 for decryption
+  output logic                 valid_o,
   output logic [DataWidth-1:0] data_o
 );
 
@@ -45,26 +53,46 @@
   // key expansion //
   ///////////////////
 
-  logic [DataWidth-1:0] k0, k0_prime, k1, k0_new;
-
+  logic [DataWidth-1:0] k0, k0_prime_d, k1_d, k0_new_d, k0_prime_q, k1_q, k0_new_q;
   always_comb begin : p_key_expansion
-    k0       = key_i[DataWidth-1:0];
-    k0_prime = {k0[0], k0[DataWidth-1:2], k0[DataWidth-1] ^ k0[1]};
-    k1       = key_i[2*DataWidth-1 : DataWidth];
+    k0         = key_i[DataWidth-1:0];
+    k0_prime_d = {k0[0], k0[DataWidth-1:2], k0[DataWidth-1] ^ k0[1]};
+    k1_d       = key_i[2*DataWidth-1 : DataWidth];
 
     // modify key for decryption
     if (dec_i) begin
-      k0       = k0_prime;
-      k0_prime = key_i[DataWidth-1:0];
-      k1       ^= prim_cipher_pkg::PRINCE_ALPHA_CONST[DataWidth-1:0];
+      k0          = k0_prime_d;
+      k0_prime_d  = key_i[DataWidth-1:0];
+      k1_d       ^= prim_cipher_pkg::PRINCE_ALPHA_CONST[DataWidth-1:0];
     end
   end
 
   if (UseOldKeySched) begin : gen_legacy_keyschedule
-    assign k0_new = k1;
+    assign k0_new_d = k1_d;
   end else begin : gen_new_keyschedule
     // improved keyschedule proposed by https://eprint.iacr.org/2014/656.pdf
-    assign k0_new = k0;
+    assign k0_new_d = k0;
+  end
+
+  if (HalfwayKeyReg) begin : gen_key_reg
+    always_ff @(posedge clk_i or negedge rst_ni) begin : p_key_reg
+      if (!rst_ni) begin
+        k1_q       <= '0;
+        k0_prime_q <= '0;
+        k0_new_q   <= '0;
+      end else begin
+        if (valid_i) begin
+          k1_q       <= k1_d;
+          k0_prime_q <= k0_prime_d;
+          k0_new_q   <= k0_new_d;
+        end
+      end
+    end
+  end else begin : gen_no_key_reg
+    // just pass the key through in this case
+    assign k1_q       = k1_d;
+    assign k0_prime_q = k0_prime_d;
+    assign k0_new_q   = k0_new_d;
   end
 
   //////////////
@@ -77,7 +105,7 @@
   // pre-round XOR
   always_comb begin : p_pre_round_xor
     data_state[0] = data_i ^ k0;
-    data_state[0] ^= k1;
+    data_state[0] ^= k1_d;
     data_state[0] ^= prim_cipher_pkg::PRINCE_ROUND_CONST[0][DataWidth-1:0];
   end
 
@@ -105,38 +133,58 @@
     assign data_state_xor = data_state_round ^
                             prim_cipher_pkg::PRINCE_ROUND_CONST[k][DataWidth-1:0];
     // improved keyschedule proposed by https://eprint.iacr.org/2014/656.pdf
-    if (k % 2 == 1) assign data_state[k]  = data_state_xor ^ k0_new;
-    else            assign data_state[k]  = data_state_xor ^ k1;
+    if (k % 2 == 1) assign data_state[k]  = data_state_xor ^ k0_new_d;
+    else            assign data_state[k]  = data_state_xor ^ k1_d;
   end
 
   // middle part
-  logic [DataWidth-1:0] data_state_middle;
+  logic [DataWidth-1:0] data_state_middle_d, data_state_middle_q, data_state_middle;
   if (DataWidth == 64) begin : gen_middle_d64
     always_comb begin : p_middle_d64
-      data_state_middle = prim_cipher_pkg::sbox4_64bit(data_state[NumRoundsHalf],
+      data_state_middle_d = prim_cipher_pkg::sbox4_64bit(data_state[NumRoundsHalf],
           prim_cipher_pkg::PRINCE_SBOX4);
-      data_state_middle = prim_cipher_pkg::prince_mult_prime_64bit(data_state_middle);
+      data_state_middle = prim_cipher_pkg::prince_mult_prime_64bit(data_state_middle_q);
       data_state_middle = prim_cipher_pkg::sbox4_64bit(data_state_middle,
           prim_cipher_pkg::PRINCE_SBOX4_INV);
     end
   end else begin : gen_middle_d32
     always_comb begin : p_middle_d32
-      data_state_middle = prim_cipher_pkg::sbox4_32bit(data_state_middle[NumRoundsHalf],
+      data_state_middle_d = prim_cipher_pkg::sbox4_32bit(data_state_middle[NumRoundsHalf],
           prim_cipher_pkg::PRINCE_SBOX4);
-      data_state_middle = prim_cipher_pkg::prince_mult_prime_32bit(data_state_middle);
+      data_state_middle = prim_cipher_pkg::prince_mult_prime_32bit(data_state_middle_q);
       data_state_middle = prim_cipher_pkg::sbox4_32bit(data_state_middle,
           prim_cipher_pkg::PRINCE_SBOX4_INV);
     end
   end
 
+  if (HalfwayDataReg) begin : gen_data_reg
+    logic valid_q;
+    always_ff @(posedge clk_i or negedge rst_ni) begin : p_data_reg
+      if (!rst_ni) begin
+        valid_q <= 1'b0;
+        data_state_middle_q <= '0;
+      end else begin
+        valid_q <= valid_i;
+        if (valid_i) begin
+          data_state_middle_q <= data_state_middle_d;
+        end
+      end
+    end
+    assign valid_o = valid_q;
+  end else begin : gen_no_data_reg
+    // just pass data through in this case
+    assign data_state_middle_q = data_state_middle_d;
+    assign valid_o = valid_i;
+  end
+
   assign data_state[NumRoundsHalf+1] = data_state_middle;
 
   // backward pass
   for (genvar k = 1; k <= NumRoundsHalf; k++) begin : gen_bwd_pass
     logic [DataWidth-1:0] data_state_xor0, data_state_xor1;
     // improved keyschedule proposed by https://eprint.iacr.org/2014/656.pdf
-    if (k % 2 == 1) assign data_state_xor0 = data_state[NumRoundsHalf+k] ^ k0_new;
-    else            assign data_state_xor0 = data_state[NumRoundsHalf+k] ^ k1;
+    if (k % 2 == 1) assign data_state_xor0 = data_state[NumRoundsHalf+k] ^ k0_new_q;
+    else            assign data_state_xor0 = data_state[NumRoundsHalf+k] ^ k1_q;
     // the construction is reflective, hence the subtraction with NumRoundsHalf
     assign data_state_xor1 = data_state_xor0 ^
                              prim_cipher_pkg::PRINCE_ROUND_CONST[10-NumRoundsHalf+k][DataWidth-1:0];
@@ -165,8 +213,8 @@
   always_comb begin : p_post_round_xor
     data_o  = data_state[2*NumRoundsHalf+1] ^
               prim_cipher_pkg::PRINCE_ROUND_CONST[11][DataWidth-1:0];
-    data_o ^= k1;
-    data_o ^= k0_prime;
+    data_o ^= k1_q;
+    data_o ^= k0_prime_q;
   end
 
   ////////////////