[aes] Improve DOM S-Box

First experiments on FPGA revealed that the previous design suffered from
substantial leakage resulting from mixing new input data and randomness
with previous intermediate results. Basically, the output of the DOM S-Box
is only valid in the 5th cycle. Before that, outputs shouldn't toggle and
internal nodes should only toggle when the corresponding stage is actually
evaluated, i.e. if the data produced during that cycle contributes to the
output.

To counter those issues, this commit adds pipeline registers to the DOM
multiplier primitives (optional according to paper) and inserts additional
register stages inside the GF(2^8) and GF(2^4) inverters to only present
new input data to the multipliers when they actually need it. Similarly,
fresh randomness is only presented to the multipliers when it is actually
needed. This requires local PRD buffering but at the same time allows to
reduce the width of the masking PRNG by a factor of 4. Also, the PRNG is
operated during 4 instead of just 1 cycle per round which is beneficial
in terms of SCA resistance.

Signed-off-by: Pirmin Vogel <vogelpi@lowrisc.org>
diff --git a/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv b/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
index 5ab986a..88ff912 100644
--- a/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
+++ b/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
@@ -51,66 +51,77 @@
   );
 
   // Mask Generation
-  parameter int unsigned WidthPRDSBoxCanrightMasked        = 8;
-  parameter int unsigned WidthPRDSBoxCanrightMaskedNoreuse = 18;
-  parameter int unsigned WidthPRDSBoxDOM                   = 28;
+  logic  [7:0] masked_stimulus;
+  logic  [7:0] in_mask;
 
-  logic                      [7:0] masked_stimulus;
-  logic                      [7:0] in_mask;
+  logic  [7:0] masked_response [NUM_SBOX_IMPLS_MASKED];
+  logic  [7:0] out_mask [NUM_SBOX_IMPLS_MASKED];
 
-  logic                      [7:0] masked_response [NUM_SBOX_IMPLS_MASKED];
-  logic                      [7:0] out_mask [NUM_SBOX_IMPLS_MASKED];
+  logic [31:0] mask;
+  logic [23:0] unused_mask;
 
-  logic                     [63:0] tmp;
-  logic [63-(WidthPRDSBoxDOM+8):0] unused_tmp;
-  logic      [WidthPRDSBoxDOM-1:0] prd_masking;
-
-  always_ff @(posedge clk_i or negedge rst_ni) begin : reg_tmp
+  always_ff @(posedge clk_i or negedge rst_ni) begin : reg_mask
     if (!rst_ni) begin
-      tmp <= 64'hAAAFF;
+      mask <= 32'hAAFF;
     end else if (dom_done) begin
-      tmp <= {$random, $random};
+      mask <= $random;
     end
   end
-  assign in_mask     = tmp[7:0];
-  assign prd_masking = tmp[8 +: WidthPRDSBoxDOM];
-  assign unused_tmp  = tmp[63:WidthPRDSBoxDOM+8];
+  assign in_mask     = mask[7:0];
+  assign unused_mask = mask[31:8];
 
   assign masked_stimulus = stimulus ^ in_mask;
 
+  // PRD Generation
+  parameter int unsigned WidthPRDSBoxCanrightMasked        = 8;
+  parameter int unsigned WidthPRDSBoxCanrightMaskedNoreuse = 18;
+  parameter int unsigned WidthPRDSBoxDOM                   = 8;
+
+  logic                                   [31:0] prd;
+  logic [31-WidthPRDSBoxCanrightMaskedNoreuse:0] unused_prd;
+
+  always_ff @(posedge clk_i or negedge rst_ni) begin : reg_prd
+    if (!rst_ni) begin
+      prd <= 32'h4321;
+    end else begin
+      prd <= {$random};
+    end
+  end
+  assign unused_prd = prd[31:WidthPRDSBoxCanrightMaskedNoreuse];
+
   // Instantiate Masked SBox Implementations
   aes_sbox_canright_masked_noreuse aes_sbox_canright_masked_noreuse (
-    .op_i   ( op                                                 ),
-    .data_i ( masked_stimulus                                    ),
-    .mask_i ( in_mask                                            ),
-    .prd_i  ( prd_masking[WidthPRDSBoxCanrightMaskedNoreuse-1:0] ),
-    .data_o ( masked_response[0]                                 ),
-    .mask_o ( out_mask[0]                                        )
+    .op_i   ( op                                         ),
+    .data_i ( masked_stimulus                            ),
+    .mask_i ( in_mask                                    ),
+    .prd_i  ( prd[WidthPRDSBoxCanrightMaskedNoreuse-1:0] ),
+    .data_o ( masked_response[0]                         ),
+    .mask_o ( out_mask[0]                                )
   );
 
   aes_sbox_canright_masked aes_sbox_canright_masked (
-    .op_i   ( op                                          ),
-    .data_i ( masked_stimulus                             ),
-    .mask_i ( in_mask                                     ),
-    .prd_i  ( prd_masking[WidthPRDSBoxCanrightMasked-1:0] ),
-    .data_o ( masked_response[1]                          ),
-    .mask_o ( out_mask[1]                                 )
+    .op_i   ( op                                  ),
+    .data_i ( masked_stimulus                     ),
+    .mask_i ( in_mask                             ),
+    .prd_i  ( prd[WidthPRDSBoxCanrightMasked-1:0] ),
+    .data_o ( masked_response[1]                  ),
+    .mask_o ( out_mask[1]                         )
   );
 
   // Instantiate DOM SBox Implementation
   logic dom_done;
   aes_sbox_dom aes_sbox_dom (
-    .clk_i     ( clk_i                            ),
-    .rst_ni    ( rst_ni                           ),
-    .en_i      ( 1'b1                             ),
-    .out_req_o ( dom_done                         ),
-    .out_ack_i ( 1'b1                             ),
-    .op_i      ( op                               ),
-    .data_i    ( masked_stimulus                  ),
-    .mask_i    ( in_mask                          ),
-    .prd_i     ( prd_masking[WidthPRDSBoxDOM-1:0] ),
-    .data_o    ( masked_response[2]               ),
-    .mask_o    ( out_mask[2]                      )
+    .clk_i     ( clk_i                    ),
+    .rst_ni    ( rst_ni                   ),
+    .en_i      ( 1'b1                     ),
+    .out_req_o ( dom_done                 ),
+    .out_ack_i ( 1'b1                     ),
+    .op_i      ( op                       ),
+    .data_i    ( masked_stimulus          ),
+    .mask_i    ( in_mask                  ),
+    .prd_i     ( prd[WidthPRDSBoxDOM-1:0] ),
+    .data_o    ( masked_response[2]       ),
+    .mask_o    ( out_mask[2]              )
   );
 
   // Unmask responses
diff --git a/hw/ip/aes/rtl/aes_cipher_control.sv b/hw/ip/aes/rtl/aes_cipher_control.sv
index 5e6e37d..95b4532 100644
--- a/hw/ip/aes/rtl/aes_cipher_control.sv
+++ b/hw/ip/aes/rtl/aes_cipher_control.sv
@@ -8,9 +8,10 @@
 
 `include "prim_assert.sv"
 
-module aes_cipher_control
+module aes_cipher_control import aes_pkg::*;
 #(
-  parameter bit Masking = 0
+  parameter bit         Masking  = 0,
+  parameter sbox_impl_e SBoxImpl = SBoxImplLut
 ) (
   input  logic                    clk_i,
   input  logic                    rst_ni,
@@ -66,8 +67,6 @@
   output aes_pkg::round_key_sel_e round_key_sel_o
 );
 
-  import aes_pkg::*;
-
   // Types
   // $ ./sparse-fsm-encode.py -d 3 -m 7 -n 6 \
   //      -s 31468618 --language=sv
@@ -111,6 +110,7 @@
   logic       key_clear_d, key_clear_q;
   logic       data_out_clear_d, data_out_clear_q;
   logic       prng_reseed_done_d, prng_reseed_done_q;
+  logic       advance;
 
   // cfg_valid_i is used for gating assertions only.
   logic       unused_cfg_valid;
@@ -155,6 +155,7 @@
     key_clear_d          = key_clear_q;
     data_out_clear_d     = data_out_clear_q;
     prng_reseed_done_d   = prng_reseed_done_q | prng_reseed_ack_i;
+    advance              = 1'b0;
 
     // Alert
     alert_o              = 1'b0;
@@ -225,18 +226,20 @@
         prng_reseed_done_d = 1'b0;
 
         // AES-256 has two round keys available right from beginning. Pseudo-random data is
-        // required by KeyExpand only, and only if it is actually advancing.
+        // required by KeyExpand only.
         if (key_len_i != AES_256) begin
           // Advance in sync with KeyExpand. Based on the S-Box implementation, it can take
-          // multiple cycles to finish. Wait for handshake.
+          // multiple cycles to finish. Wait for handshake. The DOM S-Boxes take fresh PRD
+          // in every cycle except the last.
+          advance         = key_expand_out_req_i;
+          prng_update_o   = (SBoxImpl == SBoxImplDom) ? ~advance : Masking;
           key_expand_en_o = 1'b1;
-          if (key_expand_out_req_i) begin
+          if (advance) begin
             key_expand_out_ack_o = 1'b1;
             state_we_o           = ~dec_key_gen_q;
             key_full_we_o        = 1'b1;
             rnd_ctr_d            = rnd_ctr_q     + 4'b0001;
             rnd_ctr_rem_d        = rnd_ctr_rem_q - 4'b0001;
-            prng_update_o        = Masking;
             aes_cipher_ctrl_ns   = ROUND;
           end
         end else begin
@@ -265,10 +268,13 @@
         round_key_sel_o = (op_i == CIPH_FWD) ? ROUND_KEY_DIRECT : ROUND_KEY_MIXED;
 
         // Advance in sync with SubBytes and KeyExpand. Based on the S-Box implementation, both can
-        // take multiple cycles to finish. Wait for handshake.
+        // take multiple cycles to finish. Wait for handshake. Make the masking PRNG advance every
+        // cycle. The DOM S-Boxes take fresh PRD in every cycle except the last.
+        advance         = (dec_key_gen_q | sub_bytes_out_req_i) & key_expand_out_req_i;
+        prng_update_o   = (SBoxImpl == SBoxImplDom) ? ~advance : Masking;
         sub_bytes_en_o  = ~dec_key_gen_q;
         key_expand_en_o = 1'b1;
-        if ((dec_key_gen_q || sub_bytes_out_req_i) && key_expand_out_req_i) begin
+        if (advance) begin
           sub_bytes_out_ack_o  = ~dec_key_gen_q;
           key_expand_out_ack_o = 1'b1;
 
@@ -279,11 +285,6 @@
           rnd_ctr_d     = rnd_ctr_q     + 4'b0001;
           rnd_ctr_rem_d = rnd_ctr_rem_q - 4'b0001;
 
-          // Make the masking PRNG advance once per round only. Updating it while waiting for key
-          // expand would cause the SBoxes to be re-evaluated, thereby creating additional SCA
-          // leakage.
-          prng_update_o = Masking;
-
           // Are we doing the last regular round?
           if (rnd_ctr_q == num_rounds_regular) begin
             aes_cipher_ctrl_ns = FINISH;
@@ -330,18 +331,20 @@
         // - the masking PRNG has been reseeded (if masking is used), and
         // - all mux selector signals are valid (don't release data in that case of errors).
         // Perform both handshakes simultaneously.
+        advance        = dec_key_gen_q | sub_bytes_out_req_i;
         sub_bytes_en_o = ~dec_key_gen_q;
-        out_valid_o    = (dec_key_gen_q | sub_bytes_out_req_i) & (Masking == prng_reseed_done_q) &
-            ~mux_sel_err_i;
+        out_valid_o    = advance & (Masking == prng_reseed_done_q) & ~mux_sel_err_i;
+        // When using DOM S-Boxes, make the masking PRNG advance every cycle until the output is
+        // ready. For other S-Boxes, make it advance once only. Updating it while being stalled
+        // would cause non-DOM S-Boxes to be re-evaluated, thereby creating additional SCA leakage.
+        prng_update_o  = (SBoxImpl == SBoxImplDom) ? ~advance                  :
+                          Masking                  ? out_valid_o & out_ready_i : 1'b0;
         if (out_valid_o && out_ready_i) begin
           sub_bytes_out_ack_o = ~dec_key_gen_q;
 
           // Clear the state.
           state_we_o          = 1'b1;
           crypt_d             = 1'b0;
-          // Make the masking PRNG advance once only. Updating it while being stalled would
-          // cause the SBoxes to be re-evaluated, thereby creating additional SCA leakage.
-          prng_update_o       = Masking;
           // If we were generating the decryption key and didn't get the handshake in the last
           // regular round, we should clear dec_key_gen now.
           dec_key_gen_d       = 1'b0;
diff --git a/hw/ip/aes/rtl/aes_cipher_core.sv b/hw/ip/aes/rtl/aes_cipher_core.sv
index 3711140..8a213c8 100644
--- a/hw/ip/aes/rtl/aes_cipher_core.sv
+++ b/hw/ip/aes/rtl/aes_cipher_core.sv
@@ -480,7 +480,8 @@
 
   // Control
   aes_cipher_control #(
-    .Masking ( Masking )
+    .Masking  ( Masking  ),
+    .SBoxImpl ( SBoxImpl )
   ) u_aes_cipher_control (
     .clk_i                ( clk_i               ),
     .rst_ni               ( rst_ni              ),
diff --git a/hw/ip/aes/rtl/aes_pkg.sv b/hw/ip/aes/rtl/aes_pkg.sv
index 86158d1..f6640f1 100644
--- a/hw/ip/aes/rtl/aes_pkg.sv
+++ b/hw/ip/aes/rtl/aes_pkg.sv
@@ -8,14 +8,14 @@
 
 // Widths of signals carrying pseudo-random data for clearing and masking and purposes
 parameter int unsigned WidthPRDClearing = 64;
-parameter int unsigned WidthPRDSBox     = 28; // Number PRD bits per S-Box. This includes the
+parameter int unsigned WidthPRDSBox     = 8;  // Number PRD bits per S-Box. This includes the
                                               // 8 bits for the output mask when using any of the
                                               // masked Canright S-Box implementations.
 parameter int unsigned WidthPRDData     = 16*WidthPRDSBox; // 16 S-Boxes for the data path
 parameter int unsigned WidthPRDKey      = 4*WidthPRDSBox;  // 4 S-Boxes for the key expand
 parameter int unsigned WidthPRDMasking  = WidthPRDData + WidthPRDKey;
 
-parameter int unsigned ChunkSizePRDMasking = WidthPRDMasking/10;
+parameter int unsigned ChunkSizePRDMasking = WidthPRDMasking/5;
 
 // Clearing PRNG default LFSR seed and permutation
 // These LFSR parameters have been generated with
@@ -33,22 +33,18 @@
 // We use a single seed that is split down into chunks internally. All LFSR chunks use the same
 // permutation.
 // These LFSR parameters have been generated with
-// $ util/design/gen-lfsr-seed.py --width 560 --seed 31468618 --prefix "Masking"
-parameter int MaskingLfsrWidth = 560; // = WidthPRDMasking = WidthPRDSBox * (16 + 4)
+// $ util/design/gen-lfsr-seed.py --width 160 --seed 31468618 --prefix "Masking"
+parameter int MaskingLfsrWidth = 160; // = WidthPRDMasking = WidthPRDSBox * (16 + 4)
 typedef logic [MaskingLfsrWidth-1:0] masking_lfsr_seed_t;
-parameter masking_lfsr_seed_t RndCnstMaskingLfsrSeedDefault = {
-  280'h53813d65392c83c01ea5d8be84f1e258891711849a075a71f35fe9b31605f9077a6b75,
-  280'h8a442031e1c4616ea343ec153282a30c132b5723c5a4cf4743b3c7c32d580f74f1713a
-};
+parameter masking_lfsr_seed_t RndCnstMaskingLfsrSeedDefault =
+  160'hc132b5723c5a4cf4743b3c7c32d580f74f1713a;
 
 // These LFSR parameters have been generated with
-// $ util/design/gen-lfsr-seed.py --width 56 --seed 31468618 --prefix "MskgChunk"
-parameter int MskgChunkLfsrWidth = 56; // = ChunkSizePRDMasking = WidthPRDMasking/10
+// $ util/design/gen-lfsr-seed.py --width 32 --seed 31468618 --prefix "MskgChunk"
+parameter int MskgChunkLfsrWidth = 32; // = ChunkSizePRDMasking = WidthPRDMasking/5
 typedef logic [MskgChunkLfsrWidth-1:0][$clog2(MskgChunkLfsrWidth)-1:0] mskg_chunk_lfsr_perm_t;
-parameter mskg_chunk_lfsr_perm_t RndCnstMskgChunkLfsrPermDefault = {
-  80'h61e8c17eab6c959af0bc,
-  256'h09e6cf18694b61b24c75f40902f5395b9a35c8a82b726450f80459d31b143211
-};
+parameter mskg_chunk_lfsr_perm_t RndCnstMskgChunkLfsrPermDefault =
+  160'heb3749dc187e7434d7f62a3d251e1c5b8cd10491;
 
 typedef enum integer {
   SBoxImplLut,                   // Unmasked LUT-based S-Box
diff --git a/hw/ip/aes/rtl/aes_sbox.sv b/hw/ip/aes/rtl/aes_sbox.sv
index ac0e2cd..92e4dd2 100644
--- a/hw/ip/aes/rtl/aes_sbox.sv
+++ b/hw/ip/aes/rtl/aes_sbox.sv
@@ -4,6 +4,8 @@
 //
 // AES SBox
 
+`include "prim_assert.sv"
+
 module aes_sbox import aes_pkg::*;
 #(
   parameter sbox_impl_e SBoxImpl = SBoxImplLut
@@ -59,20 +61,28 @@
   end else begin : gen_sbox_masked
 
     if (SBoxImpl == SBoxImplDom) begin : gen_sbox_dom
+      // Tie off unused inputs.
+      if (WidthPRDSBox > 8) begin : gen_unused_prd
+        logic [WidthPRDSBox-1-8:0] unused_prd;
+        assign unused_prd = prd_i[WidthPRDSBox-1:8];
+      end
+
       aes_sbox_dom u_aes_sbox (
-        .clk_i      ( clk_i       ),
-        .rst_ni     ( rst_ni      ),
-        .en_i       ( en_i        ),
-        .out_req_o  ( out_req_o   ),
-        .out_ack_i  ( out_ack_i   ),
-        .op_i       ( op_i        ),
-        .data_i     ( data_i      ),
-        .mask_i     ( mask_i      ),
-        .prd_i      ( prd_i[27:0] ),
-        .data_o     ( data_o      ),
-        .mask_o     ( mask_o      )
+        .clk_i      ( clk_i      ),
+        .rst_ni     ( rst_ni     ),
+        .en_i       ( en_i       ),
+        .out_req_o  ( out_req_o  ),
+        .out_ack_i  ( out_ack_i  ),
+        .op_i       ( op_i       ),
+        .data_i     ( data_i     ),
+        .mask_i     ( mask_i     ),
+        .prd_i      ( prd_i[7:0] ),
+        .data_o     ( data_o     ),
+        .mask_o     ( mask_o     )
       );
 
+      `ASSERT_INIT(AesWidthPRDSBox, WidthPRDSBox == 8)
+
     end else if (SBoxImpl == SBoxImplCanrightMaskedNoreuse) begin : gen_sbox_canright_masked_noreuse
       // Tie off unused inputs.
       logic unused_clk;
@@ -93,6 +103,8 @@
         .mask_o ( mask_o      )
       );
 
+      `ASSERT_INIT(AesWidthPRDSBox, WidthPRDSBox == 18)
+
     end else begin : gen_sbox_canright_masked // SBoxImpl == SBoxImplCanrightMasked
       // Tie off unused inputs.
       logic  unused_clk;
@@ -113,6 +125,7 @@
         .mask_o ( mask_o     )
       );
 
+      `ASSERT_INIT(AesWidthPRDSBox, WidthPRDSBox == 8)
     end
   end
 
diff --git a/hw/ip/aes/rtl/aes_sbox_dom.sv b/hw/ip/aes/rtl/aes_sbox_dom.sv
index 2f4fdec..510aef6 100644
--- a/hw/ip/aes/rtl/aes_sbox_dom.sv
+++ b/hw/ip/aes/rtl/aes_sbox_dom.sv
@@ -32,16 +32,27 @@
 
 `include "prim_assert.sv"
 
-// DOM-indep GF(2^N) multiplier, unpipelined, first-order masked.
+// Packed struct for pseudo-random data (PRD) distribution. Stages 1, 3 and 4 require 8 bits each.
+// Stage 2 requires just 4 bits.
+typedef struct packed {
+  logic [7:0] prd_1;
+  logic [3:0] prd_2;
+  logic [7:0] prd_3;
+  logic [7:0] prd_4;
+} prd_t;
+
+// DOM-indep GF(2^N) multiplier, first-order masked.
 // Computes (a_q ^ b_q) = (a_x ^ b_x) * (a_y ^ b_y), i.e. q = x * y using first-order
 // domain-oriented masking. The sharings of x and y are required to be uniformly random and
 // independent from each other.
 // See Fig. 2 in [1].
 module aes_dom_indep_mul_gf2pn #(
-  parameter int unsigned NPower = 4
+  parameter int unsigned NPower   = 4,
+  parameter bit          Pipeline = 1'b0
 ) (
   input  logic              clk_i,
   input  logic              rst_ni,
+  input  logic              we_i,
   input  logic [NPower-1:0] a_x,    // Share a of x
   input  logic [NPower-1:0] a_y,    // Share a of y
   input  logic [NPower-1:0] b_x,    // Share b of x
@@ -57,14 +68,14 @@
   // Calculation //
   /////////////////
   // Inner-domain terms
-  (* keep = "true" *) logic [NPower-1:0] mul_ax_ay, mul_bx_by;
+  (* keep = "true" *) logic [NPower-1:0] mul_ax_ay_d, mul_bx_by_d;
   if (NPower == 4) begin : gen_inner_mul_gf2p4
-    assign mul_ax_ay = aes_mul_gf2p4(a_x, a_y);
-    assign mul_bx_by = aes_mul_gf2p4(b_x, b_y);
+    assign mul_ax_ay_d = aes_mul_gf2p4(a_x, a_y);
+    assign mul_bx_by_d = aes_mul_gf2p4(b_x, b_y);
 
   end else begin : gen_inner_mul_gf2p2
-    assign mul_ax_ay = aes_mul_gf2p2(a_x, a_y);
-    assign mul_bx_by = aes_mul_gf2p2(b_x, b_y);
+    assign mul_ax_ay_d = aes_mul_gf2p2(a_x, a_y);
+    assign mul_bx_by_d = aes_mul_gf2p2(b_x, b_y);
   end
 
   // Cross-domain terms
@@ -82,8 +93,8 @@
   // Resharing //
   ///////////////
   // Resharing of cross-domain terms
-  (* keep = "true" *) logic [NPower-1:0] aq_z0_d, aq_z0_q;
-  (* keep = "true" *) logic [NPower-1:0] bq_z0_d, bq_z0_q;
+  logic [NPower-1:0] aq_z0_d, bq_z0_d;
+  (* keep = "true" *) logic [NPower-1:0] aq_z0_q, bq_z0_q;
   assign aq_z0_d = z_0 ^ mul_ax_by;
   assign bq_z0_d = z_0 ^ mul_ay_bx;
 
@@ -92,12 +103,46 @@
     if (!rst_ni) begin
       aq_z0_q <= '0;
       bq_z0_q <= '0;
-    end else begin
+    end else if (we_i) begin
       aq_z0_q <= aq_z0_d;
       bq_z0_q <= bq_z0_d;
     end
   end
 
+  /////////////////////////
+  // Optional Pipelining //
+  /////////////////////////
+  logic [NPower-1:0] mul_ax_ay, mul_bx_by;
+
+  if (Pipeline == 1'b1) begin : gen_pipeline
+    // Add pipeline registers on inner-domain terms prior to integration. This allows accepting new
+    // input data every clock cycle and prevents SCA leakage occurring due to the integration of
+    // reshared cross-domain terms with inner-domain terms derived from different input data.
+
+    (* keep = "true" *) logic [NPower-1:0] mul_ax_ay_q, mul_bx_by_q;
+    always_ff @(posedge clk_i or negedge rst_ni) begin
+      if (!rst_ni) begin
+        mul_ax_ay_q <= '0;
+        mul_bx_by_q <= '0;
+      end else if (we_i) begin
+        mul_ax_ay_q <= mul_ax_ay_d;
+        mul_bx_by_q <= mul_bx_by_d;
+      end
+    end
+
+    assign mul_ax_ay = mul_ax_ay_q;
+    assign mul_bx_by = mul_bx_by_q;
+
+  end else begin : gen_no_pipeline
+    // Do not add the optional pipeline registers on the inner-domain terms. This allows to save
+    // some area in case the multiplier does not need to accept new data in every cycle. However,
+    // this can cause SCA leakage as during the clock cycle in which new data arrives, the new
+    // inner-domain terms are integrated with the previous, reshared cross-domain terms.
+
+    assign mul_ax_ay = mul_ax_ay_d;
+    assign mul_bx_by = mul_bx_by_d;
+  end
+
   /////////////////
   // Integration //
   /////////////////
@@ -109,17 +154,19 @@
 
 endmodule
 
-// DOM-dep GF(2^N) multiplier, unpipelined, first-order masked.
+// DOM-dep GF(2^N) multiplier, first-order masked.
 // Computes (a_q ^ b_q) = (a_x ^ b_x) * (a_y ^ b_y), i.e. q = x * y using first-order
 // domain-oriented masking. The sharings of x and y are NOT required to be independent from each
 // other. This is the un-optimized version consuming 3 times N bits of randomness for blinding and
 // resharing. It is not used in the design but we keep it for reference.
 // See Fig. 4 and Formulas 8 - 11 in [1].
 module aes_dom_dep_mul_gf2pn_unopt #(
-  parameter int unsigned NPower = 4
+  parameter int unsigned NPower   = 4,
+  parameter bit          Pipeline = 1'b0
 ) (
   input  logic              clk_i,
   input  logic              rst_ni,
+  input  logic              we_i,
   input  logic [NPower-1:0] a_x,    // Share a of x
   input  logic [NPower-1:0] a_y,    // Share a of y
   input  logic [NPower-1:0] b_x,    // Share b of x
@@ -137,8 +184,8 @@
   // Blinding //
   //////////////
   // Blinding of y by z.
-  (* keep = "true" *) logic [NPower-1:0] a_yz_d, a_yz_q;
-  (* keep = "true" *) logic [NPower-1:0] b_yz_d, b_yz_q;
+  logic [NPower-1:0] a_yz_d, b_yz_d;
+  (* keep = "true" *) logic [NPower-1:0] a_yz_q, b_yz_q;
   assign a_yz_d = a_y ^ a_z;
   assign b_yz_d = b_y ^ b_z;
 
@@ -147,7 +194,7 @@
     if (!rst_ni) begin
       a_yz_q <= '0;
       b_yz_q <= '0;
-    end else begin
+    end else if (we_i) begin
       a_yz_q <= a_yz_d;
       b_yz_q <= b_yz_d;
     end
@@ -158,10 +205,12 @@
   ////////////////
   logic [NPower-1:0] a_mul_x_z, b_mul_x_z;
   aes_dom_indep_mul_gf2pn #(
-    .NPower ( NPower )
+    .NPower   ( NPower   ),
+    .Pipeline ( Pipeline )
   ) aes_dom_indep_mul_gf2pn (
     .clk_i  ( clk_i     ),
     .rst_ni ( rst_ni    ),
+    .we_i   ( we_i      ),
     .a_x    ( a_x       ), // Share a of x
     .a_y    ( a_z       ), // Share a of z
     .b_x    ( b_x       ), // Share b of x
@@ -171,6 +220,40 @@
     .b_q    ( b_mul_x_z )  // Share b of x * z
   );
 
+  /////////////////////////
+  // Optional Pipelining //
+  /////////////////////////
+  logic [NPower-1:0] a_x_calc, b_x_calc;
+
+  if (Pipeline == 1'b1) begin : gen_pipeline
+    // Add pipeline registers for input x. This allows accepting new input data every clock cycle
+    // and prevents SCA leakage occurring due to the multiplication of input x with b belonging to
+    // different clock cycles.
+
+    (* keep = "true" *) logic [NPower-1:0] a_x_q, b_x_q;
+    always_ff @(posedge clk_i or negedge rst_ni) begin
+      if (!rst_ni) begin
+        a_x_q <= '0;
+        b_x_q <= '0;
+      end else if (we_i) begin
+        a_x_q <= a_x;
+        b_x_q <= b_x;
+      end
+    end
+
+    assign a_x_calc = a_x_q;
+    assign b_x_calc = b_x_q;
+
+  end else begin : gen_no_pipeline
+    // Do not add the optional pipeline registers for input x. This allows to save some area in
+    // case the multiplier does not need to accept new data in every cycle. However, this can cause
+    // SCA leakage as during the clock cycle in which new data arrives, the new x input is
+    // multiplied with the previous b.
+
+    assign a_x_calc = a_x;
+    assign b_x_calc = b_x;
+  end
+
   /////////////////
   // Calculation //
   /////////////////
@@ -180,12 +263,12 @@
 
   logic [NPower-1:0] a_mul_ax_b, b_mul_bx_b;
   if (NPower == 4) begin : gen_mul_gf2p4
-    assign a_mul_ax_b = aes_mul_gf2p4(a_x, b);
-    assign b_mul_bx_b = aes_mul_gf2p4(b_x, b);
+    assign a_mul_ax_b = aes_mul_gf2p4(a_x_calc, b);
+    assign b_mul_bx_b = aes_mul_gf2p4(b_x_calc, b);
 
   end else begin : gen_mul_gf2p2
-    assign a_mul_ax_b = aes_mul_gf2p2(a_x, b);
-    assign b_mul_bx_b = aes_mul_gf2p2(b_x, b);
+    assign a_mul_ax_b = aes_mul_gf2p2(a_x_calc, b);
+    assign b_mul_bx_b = aes_mul_gf2p2(b_x_calc, b);
   end
 
   /////////////////
@@ -199,7 +282,7 @@
 
 endmodule
 
-// DOM-dep GF(2^N) multiplier, unpipelined, first-order masked.
+// DOM-dep GF(2^N) multiplier, first-order masked.
 // Computes (a_q ^ b_q) = (a_x ^ b_x) * (a_y ^ b_y), i.e. q = x * y using first-order
 // domain-oriented masking. The sharings of x and y are NOT required to be independent from each
 // other. This is the optimized version consuming 2 instead of 3 times N bits of randomness for
@@ -207,13 +290,17 @@
 // See Formula 12 in [1].
 module aes_dom_dep_mul_gf2pn #(
   parameter int unsigned NPower      = 4,
-  parameter bit          PreDOMIndep = 1'b0 // 1'b0: Not followed by a DOM-indep multiplier, this
-                                            //       enables additional area optimizations
-                                            // 1'b1: Directly followed by a DOM-indep multiplier,
-                                            //       this is the version discussed in [1].
+  parameter bit          Pipeline    = 1'b0,
+  parameter bit          PreDomIndep = 1'b0 // 1'b0: Not followed by an un-pipelined DOM-indep
+                                            //       multiplier, this enables additional area
+                                            //       optimizations
+                                            // 1'b1: Directly followed by an un-pipelined
+                                            //       DOM-indep multiplier, this is the version
+                                            //       discussed in [1].
 ) (
   input  logic              clk_i,
   input  logic              rst_ni,
+  input  logic              we_i,
   input  logic [NPower-1:0] a_x,    // Share a of x
   input  logic [NPower-1:0] a_y,    // Share a of y
   input  logic [NPower-1:0] b_x,    // Share b of x
@@ -230,8 +317,8 @@
   // Blinding //
   //////////////
   // Blinding of y by z_0.
-  (* keep = "true" *) logic [NPower-1:0] a_yz0_d, a_yz0_q;
-  (* keep = "true" *) logic [NPower-1:0] b_yz0_d, b_yz0_q;
+  logic [NPower-1:0] a_yz0_d, b_yz0_d;
+  (* keep = "true" *) logic [NPower-1:0] a_yz0_q, b_yz0_q;
   assign a_yz0_d = a_y ^ z_0;
   assign b_yz0_d = b_y ^ z_0;
 
@@ -240,7 +327,7 @@
     if (!rst_ni) begin
       a_yz0_q <= '0;
       b_yz0_q <= '0;
-    end else begin
+    end else if (we_i) begin
       a_yz0_q <= a_yz0_d;
       b_yz0_q <= b_yz0_d;
     end
@@ -266,8 +353,8 @@
   end
 
   // Resharing
-  (* keep = "true" *) logic [NPower-1:0] axz0_z1_d, axz0_z1_q;
-  (* keep = "true" *) logic [NPower-1:0] bxz0_z1_d, bxz0_z1_q;
+  logic [NPower-1:0] axz0_z1_d, bxz0_z1_d;
+  (* keep = "true" *) logic [NPower-1:0] axz0_z1_q, bxz0_z1_q;
   assign axz0_z1_d = mul_ax_z0 ^ z_1;
   assign bxz0_z1_d = mul_bx_z0 ^ z_1;
 
@@ -276,12 +363,56 @@
     if (!rst_ni) begin
       axz0_z1_q <= '0;
       bxz0_z1_q <= '0;
-    end else begin
+    end else if (we_i) begin
       axz0_z1_q <= axz0_z1_d;
       bxz0_z1_q <= bxz0_z1_d;
     end
   end
 
+  /////////////////////////
+  // Optional Pipelining //
+  /////////////////////////
+  logic [NPower-1:0] a_x_calc, b_x_calc, a_y_calc, b_y_calc;
+
+  if (Pipeline == 1'b1 && PreDomIndep != 1'b1) begin : gen_pipeline
+    // Add pipeline registers for inputs x and y. This allows accepting new input data every clock
+    // cycle and prevents SCA leakage occurring due to the multiplication of inputs x and y with
+    // d_b belonging to different clock cycles.
+    //
+    // The PreDomIndep variant has the required pipeline registers built in already.
+
+    (* keep = "true" *) logic [NPower-1:0] a_x_q, b_x_q, a_y_q, b_y_q;
+    always_ff @(posedge clk_i or negedge rst_ni) begin
+      if (!rst_ni) begin
+        a_x_q <= '0;
+        b_x_q <= '0;
+        a_y_q <= '0;
+        b_y_q <= '0;
+      end else if (we_i) begin
+        a_x_q <= a_x;
+        b_x_q <= b_x;
+        a_y_q <= a_y;
+        b_y_q <= b_y;
+      end
+    end
+
+    assign a_x_calc = a_x_q;
+    assign b_x_calc = b_x_q;
+    assign a_y_calc = a_y_q;
+    assign b_y_calc = b_y_q;
+
+  end else begin : gen_no_pipeline
+    // Do not add the optional pipeline registers for inputs x and y. This allows to save some area
+    // in case the multiplier does not need to accept new data in every cycle. However, this can
+    // cause SCA leakage as during the clock cycle in which new data arrives, the new x and y
+    // inputs are multiplied with the previous d_b.
+
+    assign a_x_calc = a_x;
+    assign b_x_calc = b_x;
+    assign a_y_calc = a_y;
+    assign b_y_calc = b_y;
+  end
+
   ///////////////////////////////
   // Calculation & Integration //
   ///////////////////////////////
@@ -295,22 +426,22 @@
   // is only suitable for first-order masking.
   // See Formula 12 in [1].
 
-  if (PreDOMIndep == 1'b1) begin : gen_pre_dom_indep
-    // This DOM-dep multiplier is directly followed by a DOM-indep multiplier without an additional
-    // pipeline stage in between. To prevent SCA leakage in the DOM-indep multiplier, the d_y and
-    // _D_y_z0 parts of d_b need to be individually multiplied with input x and then the results
-    // need to be integrated (summed up) on a per-domain basis.
+  if (PreDomIndep == 1'b1) begin : gen_pre_dom_indep
+    // This DOM-dep multiplier is directly followed by an un-pipelined DOM-indep multiplier. To
+    // prevent SCA leakage in the un-pipelined DOM-indep multiplier, the d_y and _D_y_z0 parts of
+    // d_b need to be individually multiplied with input x and then the results need to be
+    // integrated (summed up) on a per-domain basis.
 
     // d_y part: Inner-domain terms of x * y
-    (* keep = "true" *) logic [NPower-1:0] mul_ax_ay_d, mul_ax_ay_q;
-    (* keep = "true" *) logic [NPower-1:0] mul_bx_by_d, mul_bx_by_q;
+    logic [NPower-1:0] mul_ax_ay_d, mul_bx_by_d;
+    (* keep = "true" *) logic [NPower-1:0] mul_ax_ay_q, mul_bx_by_q;
     if (NPower == 4) begin : gen_inner_mul_gf2p4
-      assign mul_ax_ay_d = aes_mul_gf2p4(a_x, a_y);
-      assign mul_bx_by_d = aes_mul_gf2p4(b_x, b_y);
+      assign mul_ax_ay_d = aes_mul_gf2p4(a_x_calc, a_y_calc);
+      assign mul_bx_by_d = aes_mul_gf2p4(b_x_calc, b_y_calc);
 
     end else begin : gen_inner_mul_gf2p2
-      assign mul_ax_ay_d = aes_mul_gf2p2(a_x, a_y);
-      assign mul_bx_by_d = aes_mul_gf2p2(b_x, b_y);
+      assign mul_ax_ay_d = aes_mul_gf2p2(a_x_calc, a_y_calc);
+      assign mul_bx_by_d = aes_mul_gf2p2(b_x_calc, b_y_calc);
     end
 
     // Registers
@@ -318,22 +449,21 @@
       if (!rst_ni) begin
         mul_ax_ay_q <= '0;
         mul_bx_by_q <= '0;
-      end else begin
+      end else if (we_i) begin
         mul_ax_ay_q <= mul_ax_ay_d;
         mul_bx_by_q <= mul_bx_by_d;
       end
     end
 
     // Input Registers
-    (* keep = "true" *) logic [NPower-1:0] a_x_q;
-    (* keep = "true" *) logic [NPower-1:0] b_x_q;
+    (* keep = "true" *) logic [NPower-1:0] a_x_q, b_x_q;
     always_ff @(posedge clk_i or negedge rst_ni) begin
       if (!rst_ni) begin
         a_x_q <= '0;
         b_x_q <= '0;
-      end else begin
-        a_x_q <= a_x;
-        b_x_q <= b_x;
+      end else if (we_i) begin
+        a_x_q <= a_x_calc;
+        b_x_q <= b_x_calc;
       end
     end
 
@@ -354,23 +484,23 @@
     assign b_q = bxz0_z1_q ^ mul_bx_by_q ^ mul_bx_ayz0;
 
   end else begin : gen_not_pre_dom_indep
-    // This DOM-dep multiplier is not directly followed by a DOM-indep multiplier. As a result,
-    // the the d_y and _D_y_z0 parts of d_b can be summed up prior to the multiplication with input
-    // x which helps saving 2 GF multipliers and 4 registers (NPower flops each).
+    // This DOM-dep multiplier is not directly followed by an un-pipelined DOM-indep multiplier. As
+    // a result, the the d_y and _D_y_z0 parts of d_b can be summed up prior to the multiplication
+    // with input x which allows saving 2 GF multipliers.
 
     // Sum up d_y and _D_y_z0.
     (* keep = "true" *) logic [NPower-1:0] a_b, b_b;
-    assign a_b = a_y ^ b_yz0_q;
-    assign b_b = b_y ^ a_yz0_q;
+    assign a_b = a_y_calc ^ b_yz0_q;
+    assign b_b = b_y_calc ^ a_yz0_q;
 
     // GF multiplications
     (* keep = "true" *) logic [NPower-1:0] a_mul_ax_b, b_mul_bx_b;
     if (NPower == 4) begin : gen_mul_gf2p4
-      assign a_mul_ax_b = aes_mul_gf2p4(a_x, a_b);
-      assign b_mul_bx_b = aes_mul_gf2p4(b_x, b_b);
+      assign a_mul_ax_b = aes_mul_gf2p4(a_x_calc, a_b);
+      assign b_mul_bx_b = aes_mul_gf2p4(b_x_calc, b_b);
     end else begin : gen_mul_gf2p2
-      assign a_mul_ax_b = aes_mul_gf2p2(a_x, a_b);
-      assign b_mul_bx_b = aes_mul_gf2p2(b_x, b_b);
+      assign a_mul_ax_b = aes_mul_gf2p2(a_x_calc, a_b);
+      assign b_mul_bx_b = aes_mul_gf2p2(b_x_calc, b_b);
     end
 
     // Integration
@@ -388,50 +518,56 @@
 module aes_dom_inverse_gf2p4 (
   input  logic        clk_i,
   input  logic        rst_ni,
+  input  logic  [1:0] we_i,
   input  logic  [3:0] a_gamma,
   input  logic  [3:0] b_gamma,
-  input  logic [11:0] prd,
+  input  logic  [3:0] prd_2,
+  input  logic  [7:0] prd_3,
   output logic  [3:0] a_gamma_inv,
   output logic  [3:0] b_gamma_inv
 );
 
   import aes_sbox_canright_pkg::*;
 
-  // Distribute the randomness for the various multiplers.
-  logic [3:0] z_2;
-  logic [3:0] z_3_1;
-  logic [3:0] z_3_0;
-  assign z_2   = prd[3:0];
-  assign z_3_0 = prd[7:4];
-  assign z_3_1 = prd[11:8];
-
   /////////////
   // Stage 2 //
   /////////////
   // Formula 13 in [2].
 
   logic [1:0] a_gamma1, a_gamma0, b_gamma1, b_gamma0, a_gamma1_gamma0, b_gamma1_gamma0;
-  (* keep = "true" *) logic [1:0] a_gamma_ss, b_gamma_ss;
   assign a_gamma1 = a_gamma[3:2];
   assign a_gamma0 = a_gamma[1:0];
   assign b_gamma1 = b_gamma[3:2];
   assign b_gamma0 = b_gamma[1:0];
 
-  assign a_gamma_ss = aes_scale_omega2_gf2p2(aes_square_gf2p2(a_gamma1 ^ a_gamma0));
-  assign b_gamma_ss = aes_scale_omega2_gf2p2(aes_square_gf2p2(b_gamma1 ^ b_gamma0));
+  logic [1:0] a_gamma_ss_d, b_gamma_ss_d;
+  (* keep = "true" *) logic [1:0] a_gamma_ss_q, b_gamma_ss_q;
+  assign a_gamma_ss_d = aes_scale_omega2_gf2p2(aes_square_gf2p2(a_gamma1 ^ a_gamma0));
+  assign b_gamma_ss_d = aes_scale_omega2_gf2p2(aes_square_gf2p2(b_gamma1 ^ b_gamma0));
+  always_ff @(posedge clk_i or negedge rst_ni) begin
+    if (!rst_ni) begin
+      a_gamma_ss_q <= '0;
+      b_gamma_ss_q <= '0;
+    end else if (we_i[0]) begin
+      a_gamma_ss_q <= a_gamma_ss_d;
+      b_gamma_ss_q <= b_gamma_ss_d;
+    end
+  end
 
   aes_dom_dep_mul_gf2pn #(
     .NPower      ( 2    ),
-    .PreDOMIndep ( 1'b0 )
+    .Pipeline    ( 1'b1 ),
+    .PreDomIndep ( 1'b0 )
   ) aes_dom_mul_gamma1_gamma0 (
     .clk_i  ( clk_i           ),
     .rst_ni ( rst_ni          ),
+    .we_i   ( we_i[0]         ),
     .a_x    ( a_gamma1        ), // Share a of x
     .a_y    ( a_gamma0        ), // Share a of y
     .b_x    ( b_gamma1        ), // Share b of x
     .b_y    ( b_gamma0        ), // Share b of y
-    .z_0    ( z_2[1:0]        ), // Randomness for blinding
-    .z_1    ( z_2[3:2]        ), // Randomness for resharing
+    .z_0    ( prd_2[1:0]      ), // Randomness for blinding
+    .z_1    ( prd_2[3:2]      ), // Randomness for resharing
     .a_q    ( a_gamma1_gamma0 ), // Share a of q
     .b_q    ( b_gamma1_gamma0 )  // Share b of q
   );
@@ -442,39 +578,58 @@
 
   // Formulas 14 and 15 in [2].
   (* keep = "true" *) logic [1:0] a_omega, b_omega;
-  assign a_omega = aes_square_gf2p2(a_gamma1_gamma0 ^ a_gamma_ss);
-  assign b_omega = aes_square_gf2p2(b_gamma1_gamma0 ^ b_gamma_ss);
+  assign a_omega = aes_square_gf2p2(a_gamma1_gamma0 ^ a_gamma_ss_q);
+  assign b_omega = aes_square_gf2p2(b_gamma1_gamma0 ^ b_gamma_ss_q);
 
   // Formulas 16 and 17 in [2].
 
+  (* keep = "true" *) logic [1:0] a_gamma1_q, a_gamma0_q, b_gamma1_q, b_gamma0_q;
+  always_ff @(posedge clk_i or negedge rst_ni) begin
+    if (!rst_ni) begin
+      a_gamma1_q <= '0;
+      a_gamma0_q <= '0;
+      b_gamma1_q <= '0;
+      b_gamma0_q <= '0;
+    end else if (we_i[0]) begin
+      a_gamma1_q <= a_gamma1;
+      a_gamma0_q <= a_gamma0;
+      b_gamma1_q <= b_gamma1;
+      b_gamma0_q <= b_gamma0;
+    end
+  end
+
   aes_dom_dep_mul_gf2pn #(
     .NPower      ( 2    ),
-    .PreDOMIndep ( 1'b1 )
+    .Pipeline    ( 1'b1 ),
+    .PreDomIndep ( 1'b0 )
   ) aes_dom_mul_omega_gamma1 (
     .clk_i  ( clk_i            ),
     .rst_ni ( rst_ni           ),
-    .a_x    ( a_gamma1         ), // Share a of x
+    .we_i   ( we_i[1]          ),
+    .a_x    ( a_gamma1_q       ), // Share a of x
     .a_y    ( a_omega          ), // Share a of y
-    .b_x    ( b_gamma1         ), // Share b of x
+    .b_x    ( b_gamma1_q       ), // Share b of x
     .b_y    ( b_omega          ), // Share b of y
-    .z_0    ( z_3_1[1:0]       ), // Randomness for blinding
-    .z_1    ( z_3_1[3:2]       ), // Randomness for resharing
+    .z_0    ( prd_3[5:4]       ), // Randomness for blinding
+    .z_1    ( prd_3[7:6]       ), // Randomness for resharing
     .a_q    ( a_gamma_inv[1:0] ), // Share a of q
     .b_q    ( b_gamma_inv[1:0] )  // Share b of q
   );
 
   aes_dom_dep_mul_gf2pn #(
     .NPower      ( 2    ),
-    .PreDOMIndep ( 1'b1 )
+    .Pipeline    ( 1'b1 ),
+    .PreDomIndep ( 1'b0 )
   ) aes_dom_mul_omega_gamma0 (
     .clk_i  ( clk_i            ),
     .rst_ni ( rst_ni           ),
+    .we_i   ( we_i[1]          ),
     .a_x    ( a_omega          ), // Share a of x
-    .a_y    ( a_gamma0         ), // Share a of y
+    .a_y    ( a_gamma0_q       ), // Share a of y
     .b_x    ( b_omega          ), // Share b of x
-    .b_y    ( b_gamma0         ), // Share b of y
-    .z_0    ( z_3_0[1:0]       ), // Randomness for blinding
-    .z_1    ( z_3_0[3:2]       ), // Randomness for resharing
+    .b_y    ( b_gamma0_q       ), // Share b of y
+    .z_0    ( prd_3[1:0]       ), // Randomness for blinding
+    .z_1    ( prd_3[3:2]       ), // Randomness for resharing
     .a_q    ( a_gamma_inv[3:2] ), // Share a of q
     .b_q    ( b_gamma_inv[3:2] )  // Share b of q
   );
@@ -486,58 +641,62 @@
 module aes_dom_inverse_gf2p8 (
   input  logic        clk_i,
   input  logic        rst_ni,
+  input  logic  [3:0] we_i,
   input  logic  [7:0] a_y,     // input data masked by b_y
   input  logic  [7:0] b_y,     // input mask
-  input  logic [27:0] prd,     // pseudo-random data, e.g. for intermediate masks
+  input  prd_t        prd,     // pseudo-random data, e.g. for intermediate masks
   output logic  [7:0] a_y_inv, // output data masked by b_y_inv
   output logic  [7:0] b_y_inv  // output mask
 );
 
   import aes_sbox_canright_pkg::*;
 
-  // Distribute the randomness for the various stages.
-  logic  [7:0] z_1;
-  logic [11:0] z_23;
-  logic  [3:0] z_4_0;
-  logic  [3:0] z_4_1;
-  assign z_1   = prd[7:0];
-  assign z_23  = prd[19:8];
-  assign z_4_0 = prd[23:20];
-  assign z_4_1 = prd[27:24];
-
   /////////////
   // Stage 1 //
   /////////////
   // Formula 12 in [2].
 
   logic [3:0] a_y1, a_y0, b_y1, b_y0, a_y1_y0, b_y1_y0;
-  (* keep = "true" *) logic [3:0] a_y_ss, b_y_ss, a_gamma, b_gamma;
+  (* keep = "true" *) logic [3:0] a_gamma, b_gamma;
   assign a_y1 = a_y[7:4];
   assign a_y0 = a_y[3:0];
   assign b_y1 = b_y[7:4];
   assign b_y0 = b_y[3:0];
 
-  assign a_y_ss = aes_square_scale_gf2p4_gf2p2(a_y1 ^ a_y0);
-  assign b_y_ss = aes_square_scale_gf2p4_gf2p2(b_y1 ^ b_y0);
+  logic [3:0] a_y_ss_d, b_y_ss_d;
+  (* keep = "true" *) logic [3:0] a_y_ss_q, b_y_ss_q;
+  assign a_y_ss_d = aes_square_scale_gf2p4_gf2p2(a_y1 ^ a_y0);
+  assign b_y_ss_d = aes_square_scale_gf2p4_gf2p2(b_y1 ^ b_y0);
+  always_ff @(posedge clk_i or negedge rst_ni) begin
+    if (!rst_ni) begin
+      a_y_ss_q <= '0;
+      b_y_ss_q <= '0;
+    end else if (we_i[0]) begin
+      a_y_ss_q <= a_y_ss_d;
+      b_y_ss_q <= b_y_ss_d;
+    end
+  end
 
   aes_dom_dep_mul_gf2pn #(
     .NPower      ( 4    ),
-    .PreDOMIndep ( 1'b0 )
+    .Pipeline    ( 1'b1 ),
+    .PreDomIndep ( 1'b0 )
   ) aes_dom_mul_y1_y0 (
-    .clk_i  ( clk_i    ),
-    .rst_ni ( rst_ni   ),
-    .a_x    ( a_y1     ), // Share a of x
-    .a_y    ( a_y0     ), // Share a of y
-    .b_x    ( b_y1     ), // Share b of x
-    .b_y    ( b_y0     ), // Share b of y
-    .z_0    ( z_1[3:0] ), // Randomness for blinding
-    .z_1    ( z_1[7:4] ), // Randomness for resharing
-    .a_q    ( a_y1_y0  ), // Share a of q
-    .b_q    ( b_y1_y0  )  // Share b of q
+    .clk_i  ( clk_i          ),
+    .rst_ni ( rst_ni         ),
+    .we_i   ( we_i[0]        ),
+    .a_x    ( a_y1           ), // Share a of x
+    .a_y    ( a_y0           ), // Share a of y
+    .b_x    ( b_y1           ), // Share b of x
+    .b_y    ( b_y0           ), // Share b of y
+    .z_0    ( prd.prd_1[3:0] ), // Randomness for blinding
+    .z_1    ( prd.prd_1[7:4] ), // Randomness for resharing
+    .a_q    ( a_y1_y0        ), // Share a of q
+    .b_q    ( b_y1_y0        )  // Share b of q
   );
 
-  assign a_gamma = a_y_ss ^ a_y1_y0;
-  assign b_gamma = b_y_ss ^ b_y1_y0;
+  assign a_gamma = a_y_ss_q ^ a_y1_y0;
+  assign b_gamma = b_y_ss_q ^ b_y1_y0;
 
   ////////////////////
   // Stages 2 and 3 //
@@ -547,13 +706,15 @@
 
   // a_gamma is masked by b_gamma, a_gamma_inv is masked by b_gamma_inv.
   aes_dom_inverse_gf2p4 aes_dom_inverse_gf2p4 (
-    .clk_i       ( clk_i   ),
-    .rst_ni      ( rst_ni  ),
-    .a_gamma     ( a_gamma ),
-    .b_gamma     ( b_gamma ),
-    .prd         ( z_23    ),
-    .a_gamma_inv ( a_theta ),
-    .b_gamma_inv ( b_theta )
+    .clk_i       ( clk_i     ),
+    .rst_ni      ( rst_ni    ),
+    .we_i        ( we_i[2:1] ),
+    .a_gamma     ( a_gamma   ),
+    .b_gamma     ( b_gamma   ),
+    .prd_2       ( prd.prd_2 ),
+    .prd_3       ( prd.prd_3 ),
+    .a_gamma_inv ( a_theta   ),
+    .b_gamma_inv ( b_theta   )
   );
 
   /////////////
@@ -561,32 +722,51 @@
   /////////////
   // Formulas 18 and 19 in [2].
 
+  (* keep = "true" *) logic [3:0] a_y1_q, a_y0_q, b_y1_q, b_y0_q;
+  always_ff @(posedge clk_i or negedge rst_ni) begin
+    if (!rst_ni) begin
+      a_y1_q <= '0;
+      a_y0_q <= '0;
+      b_y1_q <= '0;
+      b_y0_q <= '0;
+    end else if (we_i[2]) begin
+      a_y1_q <= a_y1;
+      a_y0_q <= a_y0;
+      b_y1_q <= b_y1;
+      b_y0_q <= b_y0;
+    end
+  end
+
   aes_dom_indep_mul_gf2pn #(
-    .NPower ( 4 )
+    .NPower   ( 4    ),
+    .Pipeline ( 1'b1 )
   ) aes_dom_mul_theta_y1 (
-    .clk_i  ( clk_i        ),
-    .rst_ni ( rst_ni       ),
-    .a_x    ( a_y1         ), // Share a of x
-    .a_y    ( a_theta      ), // Share a of y
-    .b_x    ( b_y1         ), // Share b of x
-    .b_y    ( b_theta      ), // Share b of y
-    .z_0    ( z_4_1        ), // Randomness for resharing
-    .a_q    ( a_y_inv[3:0] ), // Share a of q
-    .b_q    ( b_y_inv[3:0] )  // Share b of q
+    .clk_i  ( clk_i          ),
+    .rst_ni ( rst_ni         ),
+    .we_i   ( we_i[3]        ),
+    .a_x    ( a_y1_q         ), // Share a of x
+    .a_y    ( a_theta        ), // Share a of y
+    .b_x    ( b_y1_q         ), // Share b of x
+    .b_y    ( b_theta        ), // Share b of y
+    .z_0    ( prd.prd_4[7:4] ), // Randomness for resharing
+    .a_q    ( a_y_inv[3:0]   ), // Share a of q
+    .b_q    ( b_y_inv[3:0]   )  // Share b of q
   );
 
   aes_dom_indep_mul_gf2pn #(
-    .NPower ( 4 )
+    .NPower   ( 4    ),
+    .Pipeline ( 1'b1 )
   ) aes_dom_mul_theta_y0 (
-    .clk_i  ( clk_i        ),
-    .rst_ni ( rst_ni       ),
-    .a_x    ( a_theta      ), // Share a of x
-    .a_y    ( a_y0         ), // Share a of y
-    .b_x    ( b_theta      ), // Share b of x
-    .b_y    ( b_y0         ), // Share b of y
-    .z_0    ( z_4_0        ), // Randomness for resharing
-    .a_q    ( a_y_inv[7:4] ), // Share a of q
-    .b_q    ( b_y_inv[7:4] )  // Share b of q
+    .clk_i  ( clk_i          ),
+    .rst_ni ( rst_ni         ),
+    .we_i   ( we_i[3]        ),
+    .a_x    ( a_theta        ), // Share a of x
+    .a_y    ( a_y0_q         ), // Share a of y
+    .b_x    ( b_theta        ), // Share b of x
+    .b_y    ( b_y0_q         ), // Share b of y
+    .z_0    ( prd.prd_4[3:0] ), // Randomness for resharing
+    .a_q    ( a_y_inv[7:4]   ), // Share a of q
+    .b_q    ( b_y_inv[7:4]   )  // Share b of q
   );
 
 endmodule
@@ -600,7 +780,8 @@
   input  aes_pkg::ciph_op_e op_i,
   input  logic        [7:0] data_i, // masked, the actual input data is data_i ^ mask_i
   input  logic        [7:0] mask_i, // input mask
-  input  logic       [27:0] prd_i,  // pseudo-random data for remasking
+  input  logic        [7:0] prd_i,  // pseudo-random data for remasking, in total we need 28 bits
+                                    // of PRD per evaluation, but at most 8 bits per cycle
   output logic        [7:0] data_o, // masked, the actual output data is data_o ^ mask_o
   output logic        [7:0] mask_o  // output mask
 );
@@ -610,6 +791,8 @@
 
   logic [7:0] in_data_basis_x, out_data_basis_x;
   logic [7:0] in_mask_basis_x, out_mask_basis_x;
+  logic [3:0] we;
+  prd_t       prd_d, prd_q;
 
   // Convert data to normal basis X.
   assign in_data_basis_x = (op_i == CIPH_FWD) ? aes_mvm(data_i, A2X) :
@@ -624,9 +807,10 @@
   aes_dom_inverse_gf2p8 aes_dom_inverse_gf2p8 (
     .clk_i   ( clk_i            ),
     .rst_ni  ( rst_ni           ),
+    .we_i    ( we               ),
     .a_y     ( in_data_basis_x  ), // input
     .b_y     ( in_mask_basis_x  ), // input
-    .prd     ( prd_i            ), // input
+    .prd     ( prd_d            ), // input
     .a_y_inv ( out_data_basis_x ), // output
     .b_y_inv ( out_mask_basis_x )  // output
   );
@@ -642,9 +826,9 @@
 
   // Counter register
   logic [2:0] count_d, count_q;
-  assign count_d = (out_req_o && out_ack_i) ? '0               :
-                   out_req_o                ? count_q          :
-                   en_i                     ? count_q + 3'b001 : count_q;
+  assign count_d = (out_req_o && out_ack_i) ? '0             :
+                   out_req_o                ? count_q        :
+                   en_i                     ? count_q + 3'd1 : count_q;
   always_ff @(posedge clk_i or negedge rst_ni) begin : reg_count
     if (!rst_ni) begin
       count_q <= '0;
@@ -652,6 +836,44 @@
       count_q <= count_d;
     end
   end
-  assign out_req_o = en_i & count_q == 3'b100;
+  assign out_req_o = en_i & count_q == 3'd4;
+
+  // Write enable signals for internal registers
+  assign we[0] = en_i & count_q == 3'd0;
+  assign we[1] = en_i & count_q == 3'd1;
+  assign we[2] = en_i & count_q == 3'd2;
+  assign we[3] = en_i & count_q == 3'd3;
+
+  // Buffer and forward PRD for the individual stages. We get 8 bits per cycle from the PRNG.
+  // Stage 1, 3 and 4 require 8 bits each. Stage 2 requires just 4 bits.
+  always_comb begin : iv_mux
+    unique case (we)
+      4'b0000: prd_d = prd_q;
+      4'b0001: prd_d = '{prd_1: prd_i,
+                         prd_2: prd_q.prd_2,
+                         prd_3: prd_q.prd_3,
+                         prd_4: prd_q.prd_4};
+      4'b0010: prd_d = '{prd_1: prd_q.prd_1,
+                         prd_2: prd_i[3:0],
+                         prd_3: prd_q.prd_3,
+                         prd_4: prd_q.prd_4};
+      4'b0100: prd_d = '{prd_1: prd_q.prd_1,
+                         prd_2: prd_q.prd_2,
+                         prd_3: prd_i,
+                         prd_4: prd_q.prd_4};
+      4'b1000: prd_d = '{prd_1: prd_q.prd_1,
+                         prd_2: prd_q.prd_2,
+                         prd_3: prd_q.prd_3,
+                         prd_4: prd_i};
+      default: prd_d = prd_q;
+    endcase
+  end
+  always_ff @(posedge clk_i or negedge rst_ni) begin : reg_prd
+    if (!rst_ni) begin
+      prd_q <= '0;
+    end else if (|we) begin
+      prd_q <= prd_d;
+    end
+  end
 
 endmodule