[aes] Prepare distribution of PRD for domain-oriented masking (DOM)

Unlike for the masked Canright S-Boxes, the output masks for SubBytes are
not known beforehand when using DOM. Instead, the output masks are
generated from the PRD input inside the individual S-Boxes.

In order to simplify the integration of different masking schemes, this
commit thus aligns how the masking PRD is distributed when using the masked
Canright S-Boxes. We no longer separate the masking PRNG output into output
masks and PRD for SubBytes at the cipher core level but feed the full
per-S-Box PRNG output into each S-Box. The "generation" of the output mask
is performed inside the S-Box and the output mask becomes an output port
also for the masked Canright S-Boxes.

Signed-off-by: Pirmin Vogel <vogelpi@lowrisc.org>
diff --git a/hw/ip/aes/pre_dv/aes_sbox_lec/aes_sbox_masked_wrapper.sv b/hw/ip/aes/pre_dv/aes_sbox_lec/aes_sbox_masked_wrapper.sv
index e544b7f..e50a944 100644
--- a/hw/ip/aes/pre_dv/aes_sbox_lec/aes_sbox_masked_wrapper.sv
+++ b/hw/ip/aes/pre_dv/aes_sbox_lec/aes_sbox_masked_wrapper.sv
@@ -10,25 +10,24 @@
   output logic [7:0]        data_o
 );
 
-  logic [7:0] in_data_m, out_data_m;
-  logic [7:0] in_mask, out_mask;
-  logic [9:0] prd_masking;
+  logic  [7:0] in_data_m, out_data_m;
+  logic  [7:0] in_mask, out_mask;
+  logic [17:0] prd_masking;
 
   // The mask inputs are tied to constant values.
   assign in_mask     = 8'hAA;
-  assign out_mask    = 8'h55;
-  assign prd_masking = 10'h2AA;
+  assign prd_masking = 18'h2AAAA;
 
   // Mask input data
   assign in_data_m = data_i ^ in_mask;
 
   aes_sbox_masked aes_sbox_masked (
-    .op_i          ( op_i        ),
-    .data_i        ( in_data_m   ),
-    .in_mask_i     ( in_mask     ),
-    .out_mask_i    ( out_mask    ),
-    .prd_masking_i ( prd_masking ),
-    .data_o        ( out_data_m  )
+    .op_i   ( op_i        ),
+    .data_i ( in_data_m   ),
+    .mask_i ( in_mask     ),
+    .prd_i  ( prd_masking ),
+    .data_o ( out_data_m  ),
+    .mask_o ( out_mask    )
   );
 
   // Unmask output data
diff --git a/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv b/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
index 5a05745..b996cf6 100644
--- a/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
+++ b/hw/ip/aes/pre_dv/aes_sbox_tb/rtl/aes_sbox_tb.sv
@@ -51,13 +51,18 @@
   );
 
   // Mask Generation
+  parameter int unsigned WidthPRDSBoxCanrightMasked        = 8;
+  parameter int unsigned WidthPRDSBoxCanrightMaskedNoreuse = 18;
+
   logic              [7:0] masked_stimulus;
   logic              [7:0] in_mask;
-  logic              [7:0] out_mask;
-  logic [WidthPRDSBox-1:0] prd_masking;
+
   logic              [7:0] masked_response [NUM_SBOX_IMPLS_MASKED];
-  logic             [31:0] tmp;
-  logic              [5:0] unused_tmp;
+  logic              [7:0] out_mask [NUM_SBOX_IMPLS_MASKED];
+
+  logic                                       [31:0] tmp;
+  logic [31-(WidthPRDSBoxCanrightMaskedNoreuse+8):0] unused_tmp;
+  logic      [WidthPRDSBoxCanrightMaskedNoreuse-1:0] prd_masking;
 
   always_ff @(posedge clk_i or negedge rst_ni) begin : reg_tmp
     if (!rst_ni) begin
@@ -67,34 +72,34 @@
     end
   end
   assign in_mask     = tmp[7:0];
-  assign out_mask    = tmp[15:8];
-  assign prd_masking = tmp[WidthPRDSBox-1+16:16];
-  assign unused_tmp  = tmp[31:WidthPRDSBox+16];
+  assign prd_masking = tmp[8 +: WidthPRDSBoxCanrightMaskedNoreuse];
+  assign unused_tmp  = tmp[31:WidthPRDSBoxCanrightMaskedNoreuse+8];
 
   assign masked_stimulus = stimulus ^ in_mask;
 
   // Instantiate Masked SBox Implementations
   aes_sbox_canright_masked_noreuse aes_sbox_canright_masked_noreuse (
-    .op_i          ( op                 ),
-    .data_i        ( masked_stimulus    ),
-    .in_mask_i     ( in_mask            ),
-    .out_mask_i    ( out_mask           ),
-    .prd_masking_i ( prd_masking        ),
-    .data_o        ( masked_response[0] )
+    .op_i   ( op                                                 ),
+    .data_i ( masked_stimulus                                    ),
+    .mask_i ( in_mask                                            ),
+    .prd_i  ( prd_masking[WidthPRDSBoxCanrightMaskedNoreuse-1:0] ),
+    .data_o ( masked_response[0]                                 ),
+    .mask_o ( out_mask[0]                                        )
   );
 
   aes_sbox_canright_masked aes_sbox_canright_masked (
-    .op_i       ( op                 ),
-    .data_i     ( masked_stimulus    ),
-    .in_mask_i  ( in_mask            ),
-    .out_mask_i ( out_mask           ),
-    .data_o     ( masked_response[1] )
+    .op_i   ( op                                          ),
+    .data_i ( masked_stimulus                             ),
+    .mask_i ( in_mask                                     ),
+    .prd_i  ( prd_masking[WidthPRDSBoxCanrightMasked-1:0] ),
+    .data_o ( masked_response[1]                          ),
+    .mask_o ( out_mask[1]                                 )
   );
 
   // Unmask responses
   always_comb begin : unmask_resp
     for (int i=0; i<NUM_SBOX_IMPLS_MASKED; i++) begin
-      responses[NUM_SBOX_IMPLS+i] = masked_response[i] ^ out_mask;
+      responses[NUM_SBOX_IMPLS+i] = masked_response[i] ^ out_mask[i];
     end
   end
 
diff --git a/hw/ip/aes/rtl/aes_cipher_core.sv b/hw/ip/aes/rtl/aes_cipher_core.sv
index 71b0798..00c4050 100644
--- a/hw/ip/aes/rtl/aes_cipher_core.sv
+++ b/hw/ip/aes/rtl/aes_cipher_core.sv
@@ -268,14 +268,16 @@
     assign unused_prd_masking_rsd_req = prd_masking_rsd_req;
     assign prd_masking_rsd_ack        = 1'b0;
 
+    logic [3:0][3:0][7:0] unused_sb_out_mask;
+    assign unused_sb_out_mask = sb_out_mask;
+
   end else begin : gen_masks
     // The input mask is the mask share of the state.
     assign sb_in_mask  = state_q[1];
 
     // The masking PRNG generates:
-    // - the SubBytes output masks,
-    // - additional randomness required by SubBytes, as well as
-    // - the randomness required by the key expand module.
+    // - the pseudo-random data (PRD) required by SubBytes,
+    // - the PRD required by the key expand module (has 4 S-Boxes internally).
     aes_prng_masking #(
       .Width                ( WidthPRDMasking          ),
       .ChunkSize            ( ChunkSizePRDMasking      ),
@@ -296,32 +298,55 @@
     );
   end
 
-  // Extract SubBytes output masks and additional randomness on a row basis. We have:
-  // prd_masking = { prd_key_expand, ... , sb_prd[4], sb_out_mask[4], sb_prd[0], sb_out_mask[0] }
-  localparam int unsigned WidthPRDRow = 4*(8+WidthPRDSBox);
-  for (genvar i = 0; i < 4; i++) begin : gen_sb_prd
-    assign sb_out_mask[i]    = aes_sb_out_mask_get(prd_masking[i * WidthPRDRow +: WidthPRDRow]);
-    assign data_in_mask_o[i] = aes_sb_out_mask_get(prd_masking[i * WidthPRDRow +: WidthPRDRow]);
-    assign prd_sub_bytes[i]  =      aes_sb_prd_get(prd_masking[i * WidthPRDRow +: WidthPRDRow]);
-  end
-  // Extract randomness for key expand module.
+  // Extract randomness for key expand module and SubBytes.
+  //
+  // The masking PRNG output has the following shape:
+  // prd_masking = { prd_key_expand, prd_sub_bytes }
   assign prd_key_expand = prd_masking[WidthPRDMasking-1 -: WidthPRDKey];
+  assign prd_sub_bytes  = prd_masking[WidthPRDData-1 -: WidthPRDData];
+
+  // Extract randomness for masking the input data.
+  //
+  // The masking PRNG is used for generating both the PRD for the S-Boxes/SubBytes operation as
+  // well as for the input data masks. When using any of the masked Canright S-Box implementations,
+  // it is important that the SubBytes input masks (generated by the PRNG in Round X-1) and the
+  // SubBytes output masks (generated by the PRNG in Round X) are independent. Inside the PRNG,
+  // this is achieved by using multiple, separately re-seeded LFSR chunks and by selecting the
+  // separate LFSR chunks in alternating fashion. Since the input data masks become the SubBytes
+  // input masks in the first round, we select the same 8 bit lanes for the input data masks which
+  // are also used to form the SubBytes output mask for the masked Canright S-Box implementations,
+  // i.e., the 8 LSBs of the per S-Box PRD. In particular, we have:
+  //
+  // prd_masking = { prd_key_expand, ... , sb_prd[4], sb_out_mask[4], sb_prd[0], sb_out_mask[0] }
+  //
+  // Where sb_out_mask[x] contains the SubBytes output mask for byte x (when using a masked
+  // Canright S-Box implementation) and sb_prd[x] contains additional PRD consumed by SubBytes for
+  // byte x.
+  //
+  // When using a masked S-Box implementation other than Canright, we still select the 8 LSBs of
+  // the per-S-Box PRD to form the input data mask of the corresponding byte. We do this to
+  // distribute the input data masks over all LFSR chunks of the masking PRNG. We do the extraction
+  // on a row basis.
+  localparam int unsigned WidthPRDRow = 4*WidthPRDSBox;
+  for (genvar i = 0; i < 4; i++) begin : gen_in_mask
+    assign data_in_mask_o[i] = aes_prd_get_lsbs(prd_masking[i * WidthPRDRow +: WidthPRDRow]);
+  end
 
   // Cipher data path
   aes_sub_bytes #(
     .SBoxImpl ( SBoxImpl )
   ) u_aes_sub_bytes (
-    .clk_i         ( clk_i             ),
-    .rst_ni        ( rst_ni            ),
-    .en_i          ( sub_bytes_en      ),
-    .out_req_o     ( sub_bytes_out_req ),
-    .out_ack_i     ( sub_bytes_out_ack ),
-    .op_i          ( op_i              ),
-    .data_i        ( state_q[0]        ),
-    .in_mask_i     ( sb_in_mask        ),
-    .out_mask_i    ( sb_out_mask       ),
-    .prd_masking_i ( prd_sub_bytes     ),
-    .data_o        ( sub_bytes_out     )
+    .clk_i     ( clk_i             ),
+    .rst_ni    ( rst_ni            ),
+    .en_i      ( sub_bytes_en      ),
+    .out_req_o ( sub_bytes_out_req ),
+    .out_ack_i ( sub_bytes_out_ack ),
+    .op_i      ( op_i              ),
+    .data_i    ( state_q[0]        ),
+    .mask_i    ( sb_in_mask        ),
+    .prd_i     ( prd_sub_bytes     ),
+    .data_o    ( sub_bytes_out     ),
+    .mask_o    ( sb_out_mask       )
   );
 
   for (genvar s = 0; s < NumShares; s++) begin : gen_shares_shift_mix
@@ -405,19 +430,19 @@
     .Masking      ( Masking      ),
     .SBoxImpl     ( SBoxImpl     )
   ) u_aes_key_expand (
-    .clk_i         ( clk_i              ),
-    .rst_ni        ( rst_ni             ),
-    .cfg_valid_i   ( cfg_valid_i        ),
-    .op_i          ( key_expand_op      ),
-    .en_i          ( key_expand_en      ),
-    .out_req_o     ( key_expand_out_req ),
-    .out_ack_i     ( key_expand_out_ack ),
-    .clear_i       ( key_expand_clear   ),
-    .round_i       ( key_expand_round   ),
-    .key_len_i     ( key_len_i          ),
-    .key_i         ( key_full_q         ),
-    .key_o         ( key_expand_out     ),
-    .prd_masking_i ( prd_key_expand     )
+    .clk_i       ( clk_i              ),
+    .rst_ni      ( rst_ni             ),
+    .cfg_valid_i ( cfg_valid_i        ),
+    .op_i        ( key_expand_op      ),
+    .en_i        ( key_expand_en      ),
+    .out_req_o   ( key_expand_out_req ),
+    .out_ack_i   ( key_expand_out_ack ),
+    .clear_i     ( key_expand_clear   ),
+    .round_i     ( key_expand_round   ),
+    .key_len_i   ( key_len_i          ),
+    .key_i       ( key_full_q         ),
+    .key_o       ( key_expand_out     ),
+    .prd_i       ( prd_key_expand     )
   );
 
   for (genvar s = 0; s < NumShares; s++) begin : gen_shares_round_key
@@ -617,13 +642,46 @@
        SBoxImpl == SBoxImplCanright)))
 
   // Make sure the output of the masking PRNG is properly extracted without creating overlaps
-  // of masks and PRD distributed to the individual S-Boxes.
-  logic [WidthPRDMasking-1:0] unused_prd_masking;
-  for (genvar i = 0; i < 4; i++) begin : gen_unused_prd_masking
-    assign unused_prd_masking[i * WidthPRDRow +: WidthPRDRow] =
-        aes_sb_out_mask_prd_concat(sb_out_mask[i], prd_sub_bytes[i]);
+  // in the data input masks, or between the PRD fed to the key expand module and SubBytes.
+  if (WidthPRDSBox > 8) begin : gen_prd_extract_assert
+    // For one row of the state matrix, extract the WidthPRDSBox-8 MSBs of the per-S-Box PRD from the
+    // PRNG output.
+    function automatic logic [3:0][(WidthPRDSBox-8)-1:0] aes_prd_get_msbs(
+      logic [(4*WidthPRDSBox)-1:0] in
+    );
+      logic [3:0][(WidthPRDSBox-8)-1:0] prd_msbs;
+      for (int i=0; i<4; i++) begin
+        prd_msbs[i] = in[(i*WidthPRDSBox) + 8 +: (WidthPRDSBox-8)];
+      end
+      return prd_msbs;
+    endfunction
+
+    // For one row of the state matrix, undo the extraction of LSBs and MSBs of the per-S-Box PRD from
+    // the PRNG output. This can be used to verify proper extraction (no overlap of output masks and PRD
+    // for masked Canright S-Box implementations, no unused PRNG output).
+    function automatic logic [4*WidthPRDSBox-1:0] aes_prd_concat_bits(
+      logic [3:0]                 [7:0] prd_lsbs,
+      logic [3:0][(WidthPRDSBox-8)-1:0] prd_msbs
+    );
+      logic [(4*WidthPRDSBox)-1:0] prd;
+      for (int i=0; i<4; i++) begin
+        prd[(i*WidthPRDSBox) +: WidthPRDSBox] = {prd_msbs[i], prd_lsbs[i]};
+      end
+      return prd;
+    endfunction
+
+    // Check for correct extraction of masking PRNG output without overlaps.
+    logic            [WidthPRDMasking-1:0] unused_prd_masking;
+    logic [3:0][3:0][(WidthPRDSBox-8)-1:0] unused_prd_msbs;
+    for (genvar i = 0; i < 4; i++) begin : gen_unused_prd_msbs
+      assign unused_prd_msbs[i] = aes_prd_get_msbs(prd_masking[i * WidthPRDRow +: WidthPRDRow]);
+    end
+    for (genvar i = 0; i < 4; i++) begin : gen_unused_prd_masking
+      assign unused_prd_masking[i * WidthPRDRow +: WidthPRDRow] =
+          aes_prd_concat_bits(data_in_mask_o[i], unused_prd_msbs[i]);
+    end
+    assign unused_prd_masking[WidthPRDMasking-1 -: WidthPRDKey] = prd_key_expand;
+    `ASSERT(AesMskgPrdExtraction, prd_masking == unused_prd_masking)
   end
-  assign unused_prd_masking[WidthPRDMasking-1 -: WidthPRDKey] = prd_key_expand;
-  `ASSERT(AesMskgPrdExtraction, prd_masking == unused_prd_masking)
 
 endmodule
diff --git a/hw/ip/aes/rtl/aes_key_expand.sv b/hw/ip/aes/rtl/aes_key_expand.sv
index 9a8e26b..b326e03 100644
--- a/hw/ip/aes/rtl/aes_key_expand.sv
+++ b/hw/ip/aes/rtl/aes_key_expand.sv
@@ -26,33 +26,32 @@
   input  key_len_e               key_len_i,
   input  logic       [7:0][31:0] key_i [NumShares],
   output logic       [7:0][31:0] key_o [NumShares],
-  input  logic [WidthPRDKey-1:0] prd_masking_i
+  input  logic [WidthPRDKey-1:0] prd_i
 );
 
-  logic                [7:0] rcon_d, rcon_q;
-  logic                      rcon_we;
-  logic                      use_rcon;
+  logic       [7:0] rcon_d, rcon_q;
+  logic             rcon_we;
+  logic             use_rcon;
 
-  logic                [3:0] rnd;
-  logic                [3:0] rnd_type;
+  logic       [3:0] rnd;
+  logic       [3:0] rnd_type;
 
-  logic               [31:0] spec_in_128 [NumShares];
-  logic               [31:0] spec_in_192 [NumShares];
-  logic               [31:0] rot_word_in [NumShares];
-  logic               [31:0] rot_word_out [NumShares];
-  logic                      use_rot_word;
-  logic               [31:0] sub_word_in, sub_word_out;
-  logic                [3:0] sub_word_out_req;
-  logic               [31:0] sw_in_mask, sw_out_mask;
-  logic [4*WidthPRDSBox-1:0] sw_prd;
-  logic                [7:0] rcon_add_in, rcon_add_out;
-  logic               [31:0] rcon_added;
+  logic      [31:0] spec_in_128 [NumShares];
+  logic      [31:0] spec_in_192 [NumShares];
+  logic      [31:0] rot_word_in [NumShares];
+  logic      [31:0] rot_word_out [NumShares];
+  logic             use_rot_word;
+  logic      [31:0] sub_word_in, sub_word_out;
+  logic       [3:0] sub_word_out_req;
+  logic      [31:0] sw_in_mask, sw_out_mask;
+  logic       [7:0] rcon_add_in, rcon_add_out;
+  logic      [31:0] rcon_added;
 
-  logic               [31:0] irregular [NumShares];
-  logic          [7:0][31:0] regular [NumShares];
+  logic      [31:0] irregular [NumShares];
+  logic [7:0][31:0] regular [NumShares];
 
   // cfg_valid_i is used for gating assertions only.
-  logic                      unused_cfg_valid;
+  logic                     unused_cfg_valid;
   assign unused_cfg_valid = cfg_valid_i;
 
   // Get a shorter reference.
@@ -189,30 +188,32 @@
   if (!Masking) begin : gen_no_sw_in_mask
     // The mask share is ignored anyway, it can be 0.
     assign sw_in_mask  = '0;
+
+    // Tie-off unused signals.
+    logic [31:0] unused_sw_out_mask;
+    assign unused_sw_out_mask = sw_out_mask;
+
   end else begin : gen_sw_in_mask
     // The input mask is the mask share of rot_word_in/out.
     assign sw_in_mask = use_rot_word ? rot_word_out[1] : rot_word_in[1];
   end
 
-  assign sw_out_mask = aes_sb_out_mask_get(prd_masking_i);
-  assign sw_prd      = aes_sb_prd_get(prd_masking_i);
-
   // SubWord - individually substitute bytes.
   for (genvar i = 0; i < 4; i++) begin : gen_sbox
     aes_sbox #(
       .SBoxImpl ( SBoxImpl )
     ) u_aes_sbox_i (
-      .clk_i         ( clk_i                                  ),
-      .rst_ni        ( rst_ni                                 ),
-      .en_i          ( en_i                                   ),
-      .out_req_o     ( sub_word_out_req[i]                    ),
-      .out_ack_i     ( out_ack_i                              ),
-      .op_i          ( CIPH_FWD                               ),
-      .data_i        ( sub_word_in[8*i +: 8]                  ),
-      .in_mask_i     ( sw_in_mask[8*i +: 8]                   ),
-      .out_mask_i    ( sw_out_mask[8*i +: 8]                  ),
-      .prd_masking_i ( sw_prd[WidthPRDSBox*i +: WidthPRDSBox] ),
-      .data_o        ( sub_word_out[8*i +: 8]                 )
+      .clk_i     ( clk_i                                 ),
+      .rst_ni    ( rst_ni                                ),
+      .en_i      ( en_i                                  ),
+      .out_req_o ( sub_word_out_req[i]                   ),
+      .out_ack_i ( out_ack_i                             ),
+      .op_i      ( CIPH_FWD                              ),
+      .data_i    ( sub_word_in[8*i +: 8]                 ),
+      .mask_i    ( sw_in_mask[8*i +: 8]                  ),
+      .prd_i     ( prd_i[WidthPRDSBox*i +: WidthPRDSBox] ),
+      .data_o    ( sub_word_out[8*i +: 8]                ),
+      .mask_o    ( sw_out_mask[8*i +: 8]                 )
     );
   end
 
@@ -400,10 +401,4 @@
       AES_256
       })
 
-  // Make sure the output of the masking PRNG is properly extracted without creating overlaps
-  // of masks and PRD distributed to the individual S-Boxes.
-  logic [WidthPRDKey-1:0] unused_prd_masking;
-  assign unused_prd_masking = aes_sb_out_mask_prd_concat(sw_out_mask, sw_prd);
-  `ASSERT(AesMskgPrdExtraction, prd_masking_i == unused_prd_masking)
-
 endmodule
diff --git a/hw/ip/aes/rtl/aes_pkg.sv b/hw/ip/aes/rtl/aes_pkg.sv
index e374e1a..d1c743a 100644
--- a/hw/ip/aes/rtl/aes_pkg.sv
+++ b/hw/ip/aes/rtl/aes_pkg.sv
@@ -8,10 +8,11 @@
 
 // Widths of signals carrying pseudo-random data for clearing and masking and purposes
 parameter int unsigned WidthPRDClearing = 64;
-parameter int unsigned WidthPRDSBox     = 10; // Number PRD bits per S-Box, not incl. the 8 bits
-                                              // for the output mask
-parameter int unsigned WidthPRDData     = 16*(8+WidthPRDSBox); // 16 S-Boxes for the data path
-parameter int unsigned WidthPRDKey      = 4*(8+WidthPRDSBox);  // 4 S-Boxes for the key expand
+parameter int unsigned WidthPRDSBox     = 18; // Number PRD bits per S-Box. This includes the
+                                              // 8 bits for the output mask when using any of the
+                                              // masked Canright S-Box implementations.
+parameter int unsigned WidthPRDData     = 16*WidthPRDSBox; // 16 S-Boxes for the data path
+parameter int unsigned WidthPRDKey      = 4*WidthPRDSBox;  // 4 S-Boxes for the key expand
 parameter int unsigned WidthPRDMasking  = WidthPRDData + WidthPRDKey;
 
 parameter int unsigned ChunkSizePRDMasking = WidthPRDMasking/10;
@@ -33,7 +34,7 @@
 // permutation.
 // These LFSR parameters have been generated with
 // $ util/design/gen-lfsr-seed.py --width 360 --seed 31468618 --prefix "Masking"
-parameter int MaskingLfsrWidth = 360;
+parameter int MaskingLfsrWidth = 360; // = WidthPRDMasking = WidthPRDSBox * (16 + 4)
 typedef logic [MaskingLfsrWidth-1:0] masking_lfsr_seed_t;
 parameter masking_lfsr_seed_t RndCnstMaskingLfsrSeedDefault = {
   180'h5ae9b31605f9077a6b758a442031e1c4616ea343ec153,
@@ -42,7 +43,7 @@
 
 // These LFSR parameters have been generated with
 // $ util/design/gen-lfsr-seed.py --width 36 --seed 31468618 --prefix "MskgChunk"
-parameter int MskgChunkLfsrWidth = 36;
+parameter int MskgChunkLfsrWidth = 36; // = ChunkSizePRDMasking = WidthPRDMasking/10
 typedef logic [MskgChunkLfsrWidth-1:0][$clog2(MskgChunkLfsrWidth)-1:0] mskg_chunk_lfsr_perm_t;
 parameter mskg_chunk_lfsr_perm_t RndCnstMskgChunkLfsrPermDefault =
     216'h6587da04c59c02125750f35e7634e08951122874022ce19b143211;
@@ -374,42 +375,41 @@
   return vec_c;
 endfunction
 
-// Functions for extracting  SubBytes output masks and additional pseudo-random data (PRD) from the
-// output of the masking PRNG on a row basis. We have:
+// Function for extracting LSBs of the per-S-Box pseudo-random data (PRD) from the output of the
+// masking PRNG.
+//
+// The masking PRNG is used for generating both the PRD for the S-Boxes/SubBytes operation as
+// well as for the input data masks. When using any of the masked Canright S-Box implementations,
+// it is important that the SubBytes input masks (generated by the PRNG in Round X-1) and the
+// SubBytes output masks (generated by the PRNG in Round X) are independent. Inside the PRNG,
+// this is achieved by using multiple, separately re-seeded LFSR chunks and by selecting the
+// separate LFSR chunks in alternating fashion. Since the input data masks become the SubBytes
+// input masks in the first round, we select the same 8 bit lanes for the input data masks which
+// are also used to form the SubBytes output mask for the masked Canright S-Box implementations,
+// i.e., the 8 LSBs of the per S-Box PRD. In particular, we have:
+//
 // prng_output = { prd_key_expand, ... , sb_prd[4], sb_out_mask[4], sb_prd[0], sb_out_mask[0] }
+//
+// Where sb_out_mask[x] contains the SubBytes output mask for byte x (when using a masked
+// Canright S-Box implementation) and sb_prd[x] contains additional PRD consumed by SubBytes for
+// byte x.
+//
+// When using a masked S-Box implementation other than Canright, we still select the 8 LSBs of
+// the per-S-Box PRD to form the input data mask of the corresponding byte. We do this to
+// distribute the input data masks over all LFSR chunks of the masking PRNG.
 
-// Extract one row of output masks for SubBytes from PRNG output. The output mask is in the LSBs of
-// each segment.
-function automatic logic [3:0][7:0] aes_sb_out_mask_get(logic [4*(8+WidthPRDSBox)-1:0] in);
-  logic [3:0][7:0] sb_out_mask;
-  for (int i=0; i<4; i++) begin
-    sb_out_mask[i] = in[i*(8+WidthPRDSBox) +: 8];
-  end
-  return sb_out_mask;
-endfunction
-
-// Extract one row of PRD for SubBytes from PRNG output. The PRD part is in the MSBs of each
-// segment.
-function automatic logic [3:0][WidthPRDSBox-1:0] aes_sb_prd_get(logic [4*(8+WidthPRDSBox)-1:0] in);
-  logic [3:0][WidthPRDSBox-1:0] sb_prd;
-  for (int i=0; i<4; i++) begin
-    sb_prd[i] = in[i*(8+WidthPRDSBox)+8 +: WidthPRDSBox];
-  end
-  return sb_prd;
-endfunction
-
-// Undo extraction of output masks and PRD for SubBytes for one row. This can be used to verify
-// proper extraction (no overlap of masks and PRD, no unused PRNG output).
-function automatic logic [4*(8+WidthPRDSBox)-1:0] aes_sb_out_mask_prd_concat(
-  logic              [3:0][7:0] sb_out_mask,
-  logic [3:0][WidthPRDSBox-1:0] sb_prd
+// For one row of the state matrix, extract the 8 LSBs of the per-S-Box PRD from the PRNG output.
+// These bits are used as:
+// - input data masks, and
+// - SubBytes output mask when using a masked Canright S-Box implementation.
+function automatic logic [3:0][7:0] aes_prd_get_lsbs(
+  logic [(4*WidthPRDSBox)-1:0] in
 );
-  logic [4*(8+WidthPRDSBox)-1:0] sb_out_mask_prd;
+  logic [3:0][7:0] prd_lsbs;
   for (int i=0; i<4; i++) begin
-    sb_out_mask_prd[i*(8+WidthPRDSBox)   +: 8]            = sb_out_mask[i];
-    sb_out_mask_prd[i*(8+WidthPRDSBox)+8 +: WidthPRDSBox] = sb_prd[i];
+    prd_lsbs[i] = in[i*WidthPRDSBox +: 8];
   end
-  return sb_out_mask_prd;
+  return prd_lsbs;
 endfunction
 
 endpackage
diff --git a/hw/ip/aes/rtl/aes_sbox.sv b/hw/ip/aes/rtl/aes_sbox.sv
index c0e7cec..b3dc17b 100644
--- a/hw/ip/aes/rtl/aes_sbox.sv
+++ b/hw/ip/aes/rtl/aes_sbox.sv
@@ -15,10 +15,10 @@
   input  logic                    out_ack_i,
   input  ciph_op_e                op_i,
   input  logic              [7:0] data_i,
-  input  logic              [7:0] in_mask_i,
-  input  logic              [7:0] out_mask_i,
-  input  logic [WidthPRDSBox-1:0] prd_masking_i,
-  output logic              [7:0] data_o
+  input  logic              [7:0] mask_i,
+  input  logic [WidthPRDSBox-1:0] prd_i,
+  output logic              [7:0] data_o,
+  output logic              [7:0] mask_o
 );
 
   import aes_pkg::*;
@@ -31,26 +31,30 @@
     // Tie off unused inputs.
     logic                    unused_clk;
     logic                    unused_rst;
-    logic             [15:0] unused_masks;
+    logic              [7:0] unused_mask;
     logic [WidthPRDSBox-1:0] unused_prd;
-    assign unused_clk   = clk_i;
-    assign unused_rst   = rst_ni;
-    assign unused_masks = {in_mask_i, out_mask_i};
-    assign unused_prd   = prd_masking_i;
+    assign unused_clk  = clk_i;
+    assign unused_rst  = rst_ni;
+    assign unused_mask = mask_i;
+    assign unused_prd  = prd_i;
 
     if (SBoxImpl == SBoxImplCanright) begin : gen_sbox_canright
       aes_sbox_canright u_aes_sbox (
-        .op_i,
-        .data_i,
-        .data_o
+        .op_i   ( op_i   ),
+        .data_i ( data_i ),
+        .data_o ( data_o )
       );
+
     end else begin : gen_sbox_lut // SBoxImpl == SBoxImplLut
       aes_sbox_lut u_aes_sbox (
-        .op_i,
-        .data_i,
-        .data_o
+        .op_i   ( op_i   ),
+        .data_i ( data_i ),
+        .data_o ( data_o )
       );
     end
+
+    assign mask_o = '0;
+
   end else begin : gen_sbox_masked
 
     if (SBoxImpl == SBoxImplCanrightMaskedNoreuse) begin : gen_sbox_canright_masked_noreuse
@@ -59,31 +63,40 @@
       logic unused_rst;
       assign unused_clk = clk_i;
       assign unused_rst = rst_ni;
+      if (WidthPRDSBox > 18) begin : gen_unused_prd
+        logic [WidthPRDSBox-1-18:0] unused_prd;
+        assign unused_prd = prd_i[WidthPRDSBox-1:18];
+      end
 
       aes_sbox_canright_masked_noreuse u_aes_sbox (
-        .op_i,
-        .data_i,
-        .in_mask_i,
-        .out_mask_i,
-        .prd_masking_i,
-        .data_o
+        .op_i   ( op_i        ),
+        .data_i ( data_i      ),
+        .mask_i ( mask_i      ),
+        .prd_i  ( prd_i[17:0] ),
+        .data_o ( data_o      ),
+        .mask_o ( mask_o      )
       );
+
     end else begin : gen_sbox_canright_masked // SBoxImpl == SBoxImplCanrightMasked
       // Tie off unused inputs.
-      logic                    unused_clk;
-      logic                    unused_rst;
-      logic [WidthPRDSBox-1:0] unused_prd;
+      logic  unused_clk;
+      logic  unused_rst;
       assign unused_clk = clk_i;
       assign unused_rst = rst_ni;
-      assign unused_prd = prd_masking_i;
+      if (WidthPRDSBox > 8) begin : gen_unused_prd
+        logic [WidthPRDSBox-1-8:0] unused_prd;
+        assign unused_prd = prd_i[WidthPRDSBox-1:8];
+      end
 
       aes_sbox_canright_masked u_aes_sbox (
-        .op_i,
-        .data_i,
-        .in_mask_i,
-        .out_mask_i,
-        .data_o
+        .op_i   ( op_i       ),
+        .data_i ( data_i     ),
+        .mask_i ( mask_i     ),
+        .prd_i  ( prd_i[7:0] ),
+        .data_o ( data_o     ),
+        .mask_o ( mask_o     )
       );
+
     end
   end
 
diff --git a/hw/ip/aes/rtl/aes_sbox_canright_masked.sv b/hw/ip/aes/rtl/aes_sbox_canright_masked.sv
index 13e0fe2..4fd4e5b 100644
--- a/hw/ip/aes/rtl/aes_sbox_canright_masked.sv
+++ b/hw/ip/aes/rtl/aes_sbox_canright_masked.sv
@@ -245,10 +245,11 @@
 
 module aes_sbox_canright_masked (
   input  aes_pkg::ciph_op_e op_i,
-  input  logic [7:0]        data_i,     // masked, the actual input data is data_i ^ in_mask_i
-  input  logic [7:0]        in_mask_i,  // input mask, independent from actual input data
-  input  logic [7:0]        out_mask_i, // output mask, independent from input mask
-  output logic [7:0]        data_o      // masked, the actual output data is data_o ^ out_mask_i
+  input  logic [7:0]        data_i, // masked, the actual input data is data_i ^ mask_i
+  input  logic [7:0]        mask_i, // input mask, independent from actual input data
+  input  logic [7:0]        prd_i,  // pseudo-random data for remasking, independent of input mask
+  output logic [7:0]        data_o, // masked, the actual output data is data_o ^ mask_o
+  output logic [7:0]        mask_o  // output mask
 );
 
   import aes_pkg::*;
@@ -265,14 +266,18 @@
   assign in_data_basis_x = (op_i == CIPH_FWD) ? aes_mvm(data_i, A2X) :
                                                 aes_mvm(data_i ^ 8'h63, S2X);
 
+  // For the masked Canright SBox, the output mask directly corresponds to the pseduo-random data
+  // provided as input.
+  assign mask_o = prd_i;
+
   // Convert masks to normal basis X.
   // The addition of constant 8'h63 following the affine transformation is skipped.
-  assign in_mask_basis_x  = (op_i == CIPH_FWD) ? aes_mvm(in_mask_i, A2X) :
-                                                 aes_mvm(in_mask_i, S2X);
+  assign in_mask_basis_x  = (op_i == CIPH_FWD) ? aes_mvm(mask_i, A2X) :
+                                                 aes_mvm(mask_i, S2X);
 
   // The output mask is converted in the opposite direction.
-  assign out_mask_basis_x = (op_i == CIPH_INV) ? aes_mvm(out_mask_i, A2X) :
-                                                 aes_mvm(out_mask_i, S2X);
+  assign out_mask_basis_x = (op_i == CIPH_INV) ? aes_mvm(mask_o, A2X) :
+                                                 aes_mvm(mask_o, S2X);
 
   // Do the inversion in normal basis X.
   aes_masked_inverse_gf2p8 aes_masked_inverse_gf2p8 (
diff --git a/hw/ip/aes/rtl/aes_sbox_canright_masked_noreuse.sv b/hw/ip/aes/rtl/aes_sbox_canright_masked_noreuse.sv
index 26fd7ba..72578b2 100644
--- a/hw/ip/aes/rtl/aes_sbox_canright_masked_noreuse.sv
+++ b/hw/ip/aes/rtl/aes_sbox_canright_masked_noreuse.sv
@@ -246,11 +246,12 @@
 
 module aes_sbox_canright_masked_noreuse (
   input  aes_pkg::ciph_op_e op_i,
-  input  logic [7:0]        data_i,        // masked, the actual input data is data_i ^ in_mask_i
-  input  logic [7:0]        in_mask_i,     // input mask, independent from actual input data
-  input  logic [7:0]        out_mask_i,    // output mask, independent from input mask
-  input  logic [9:0]        prd_masking_i, // pseudo-random data, e.g. for intermediate masks
-  output logic [7:0]        data_o         // masked, the actual output data is data_o ^ out_mask_i
+  input  logic        [7:0] data_i, // masked, the actual input data is data_i ^ mask_i
+  input  logic        [7:0] mask_i, // input mask, independent from actual input data
+  input  logic       [17:0] prd_i,  // pseudo-random data, for remasking and for intermediate
+                                    // masks, must be independent of input mask
+  output logic        [7:0] data_o, // masked, the actual output data is data_o ^ mask_o
+  output logic        [7:0] mask_o  // output mask
 );
 
   import aes_pkg::*;
@@ -267,21 +268,29 @@
   assign in_data_basis_x = (op_i == CIPH_FWD) ? aes_mvm(data_i, A2X) :
                                                 aes_mvm(data_i ^ 8'h63, S2X);
 
+  // For the masked Canright SBox with no re-use, the output mask directly corresponds to the
+  // LSBs of the pseduo-random data provided as input.
+  assign mask_o = prd_i[7:0];
+
+  // The remaining bits are used for intermediate masks.
+  logic [9:0] prd_masking;
+  assign prd_masking = prd_i[17:8];
+
   // Convert masks to normal basis X.
   // The addition of constant 8'h63 following the affine transformation is skipped.
-  assign in_mask_basis_x  = (op_i == CIPH_FWD) ? aes_mvm(in_mask_i, A2X) :
-                                                 aes_mvm(in_mask_i, S2X);
+  assign in_mask_basis_x  = (op_i == CIPH_FWD) ? aes_mvm(mask_i, A2X) :
+                                                 aes_mvm(mask_i, S2X);
 
   // The output mask is converted in the opposite direction.
-  assign out_mask_basis_x = (op_i == CIPH_INV) ? aes_mvm(out_mask_i, A2X) :
-                                                 aes_mvm(out_mask_i, S2X);
+  assign out_mask_basis_x = (op_i == CIPH_INV) ? aes_mvm(mask_o, A2X) :
+                                                 aes_mvm(mask_o, S2X);
 
   // Do the inversion in normal basis X.
   aes_masked_inverse_gf2p8_noreuse aes_masked_inverse_gf2p8 (
     .a     ( in_data_basis_x  ), // input
     .m     ( in_mask_basis_x  ), // input
     .n     ( out_mask_basis_x ), // input
-    .prd   ( prd_masking_i    ), // input
+    .prd   ( prd_masking      ), // input
     .a_inv ( out_data_basis_x )  // output
   );
 
diff --git a/hw/ip/aes/rtl/aes_sub_bytes.sv b/hw/ip/aes/rtl/aes_sub_bytes.sv
index 24c7ff5..d554197 100644
--- a/hw/ip/aes/rtl/aes_sub_bytes.sv
+++ b/hw/ip/aes/rtl/aes_sub_bytes.sv
@@ -15,10 +15,10 @@
   input  logic                              out_ack_i,
   input  ciph_op_e                          op_i,
   input  logic              [3:0][3:0][7:0] data_i,
-  input  logic              [3:0][3:0][7:0] in_mask_i,
-  input  logic              [3:0][3:0][7:0] out_mask_i,
-  input  logic [3:0][3:0][WidthPRDSBox-1:0] prd_masking_i,
-  output logic              [3:0][3:0][7:0] data_o
+  input  logic              [3:0][3:0][7:0] mask_i,
+  input  logic [3:0][3:0][WidthPRDSBox-1:0] prd_i,
+  output logic              [3:0][3:0][7:0] data_o,
+  output logic              [3:0][3:0][7:0] mask_o
 );
 
   logic [3:0][3:0] out_req;
@@ -32,17 +32,17 @@
       aes_sbox #(
         .SBoxImpl ( SBoxImpl )
       ) u_aes_sbox_ij (
-        .clk_i         ( clk_i               ),
-        .rst_ni        ( rst_ni              ),
-        .en_i          ( en_i                ),
-        .out_req_o     ( out_req[i][j]       ),
-        .out_ack_i     ( out_ack_i           ),
-        .op_i          ( op_i                ),
-        .data_i        ( data_i[i][j]        ),
-        .in_mask_i     ( in_mask_i[i][j]     ),
-        .out_mask_i    ( out_mask_i[i][j]    ),
-        .prd_masking_i ( prd_masking_i[i][j] ),
-        .data_o        ( data_o[i][j]        )
+        .clk_i     ( clk_i         ),
+        .rst_ni    ( rst_ni        ),
+        .en_i      ( en_i          ),
+        .out_req_o ( out_req[i][j] ),
+        .out_ack_i ( out_ack_i     ),
+        .op_i      ( op_i          ),
+        .data_i    ( data_i[i][j]  ),
+        .mask_i    ( mask_i[i][j]  ),
+        .prd_i     ( prd_i [i][j]  ),
+        .data_o    ( data_o[i][j]  ),
+        .mask_o    ( mask_o[i][j]  )
       );
     end
   end