diff --git a/hw/ip/aes/data/aes.hjson b/hw/ip/aes/data/aes.hjson
index 9171216..589789d 100644
--- a/hw/ip/aes/data/aes.hjson
+++ b/hw/ip/aes/data/aes.hjson
@@ -30,13 +30,15 @@
       desc: '''
         Initial Key Registers. Loaded into the internal Full Key register upon
         starting encryption/decryption of the next block. Can only be updated
-        when the AES unit is idle. All keys registers must be updated when the
-        key is changed, regardless of key length (write 0 for unusued bits).
+        when the AES unit is idle. If the AES unit is non-idle, writes to these
+        registers are ignored. All key registers must be updated when the
+        key is changed, regardless of key length (write 0 for unused bits).
       '''
       count: "NumRegsKey",
       cname: "KEY",
       swaccess: "wo",
       hwaccess: "hrw",
+      hwext:    "true",
       hwqe:     "true",
       fields: [
         { bits: "31:0", name: "key", desc: "Initial Key" }
diff --git a/hw/ip/aes/rtl/aes_control.sv b/hw/ip/aes/rtl/aes_control.sv
index 701193e..0b68b5e 100644
--- a/hw/ip/aes/rtl/aes_control.sv
+++ b/hw/ip/aes/rtl/aes_control.sv
@@ -32,6 +32,8 @@
 
   // Control outputs key expand data path
   output aes_pkg::mode_e          key_expand_mode_o,
+  output aes_pkg::key_init_sel_e  key_init_sel_o,
+  output logic [7:0]              key_init_we_o,
   output aes_pkg::key_full_sel_e  key_full_sel_o,
   output logic                    key_full_we_o,
   output aes_pkg::key_dec_sel_e   key_dec_sel_o,
@@ -43,7 +45,6 @@
   output aes_pkg::round_key_sel_e round_key_sel_o,
 
   // Key/data registers
-  output logic                    key_we_o,
   output logic                    data_in_we_o,
   output logic                    data_out_we_o,
 
@@ -82,6 +83,7 @@
   logic       data_in_new;
   logic       data_in_load;
 
+  logic       key_init_clear;
   logic [7:0] key_init_new_d, key_init_new_q;
   logic       key_init_new;
   logic       dec_key_gen;
@@ -112,6 +114,8 @@
     add_rk_sel_o = ADD_RK_ROUND;
 
     // Key expand data path
+    key_init_sel_o     = KEY_INIT_INPUT;
+    key_init_we_o      = 8'h00;
     key_full_sel_o     = KEY_FULL_ROUND;
     key_full_we_o      = 1'b0;
     key_dec_sel_o      = KEY_DEC_EXPAND;
@@ -136,7 +140,6 @@
     // Key, data I/O register control
     dec_key_gen   = 1'b0;
     data_in_load  = 1'b0;
-    key_we_o      = 1'b0;
     data_in_we_o  = 1'b0;
     data_out_we_o = 1'b0;
 
@@ -191,6 +194,8 @@
 
           aes_ctrl_ns = CLEAR;
         end
+
+        key_init_we_o = idle_o ? key_init_qe_i : 8'h00;
       end
 
       INIT: begin
@@ -353,7 +358,8 @@
 
       CLEAR: begin
         if (key_clear_i) begin
-          key_we_o       = 1'b1;
+          key_init_sel_o = KEY_INIT_CLEAR;
+          key_init_we_o  = 8'hFF;
           key_full_sel_o = KEY_FULL_CLEAR;
           key_full_we_o  = 1'b1;
           key_dec_sel_o  = KEY_DEC_CLEAR;
@@ -398,7 +404,11 @@
 
   // Detect new key, new input, output read
   // Edge detectors are cleared by the FSM
-  assign key_init_new_d = (dec_key_gen | key_we_o) ? '0 : (key_init_new_q | key_init_qe_i);
+  // Only count writes that actually reach the Initial Key registers. Writes
+  // received while the unit is non-idle are dropped (key_init_we_o stays low),
+  // so they must not mark the corresponding key word as updated.
+  assign key_init_clear = (key_init_sel_o == KEY_INIT_CLEAR) & (&key_init_we_o);
+  assign key_init_new_d = (dec_key_gen | key_init_clear) ? '0 : (key_init_new_q | key_init_we_o);
   assign key_init_new   = &key_init_new_d;
 
   assign data_in_new_d = (data_in_load | data_in_we_o) ? '0 : (data_in_new_q | data_in_qe_i);
diff --git a/hw/ip/aes/rtl/aes_core.sv b/hw/ip/aes/rtl/aes_core.sv
index 2f1a565..ec5af2f 100644
--- a/hw/ip/aes/rtl/aes_core.sv
+++ b/hw/ip/aes/rtl/aes_core.sv
@@ -22,7 +22,6 @@
   logic     [3:0][31:0] data_in;
   logic     [3:0]       data_in_qe;
   logic                 data_in_we;
-  logic                 key_we;
   logic     [7:0][31:0] key_init;
   logic     [7:0]       key_init_qe;
 
@@ -42,6 +41,10 @@
   logic [3:0][3:0][7:0] add_round_key_out;
   add_rk_sel_e          add_round_key_in_sel;
 
+  logic     [7:0][31:0] key_init_d;
+  logic     [7:0][31:0] key_init_q;
+  logic     [7:0]       key_init_we;
+  key_init_sel_e        key_init_sel;
   logic     [7:0][31:0] key_full_d;
   logic     [7:0][31:0] key_full_q;
   logic                 key_full_we;
@@ -176,10 +179,31 @@
   // Key //
   /////////
 
+  // Initial Key registers: select between the incoming key words and an all-zero clear value
+  always_comb begin : key_init_mux
+    unique case (key_init_sel)
+      KEY_INIT_INPUT: key_init_d = key_init;
+      KEY_INIT_CLEAR: key_init_d = '0;
+      default:        key_init_d = 'X;
+    endcase
+  end
+
+  // One write enable per 32-bit key word: only the selected words take the
+  // value of key_init_d, all other words hold their current value.
+  always_ff @(posedge clk_i or negedge rst_ni) begin : key_init_reg
+    if (!rst_ni) begin
+      key_init_q <= '0;
+    end else begin
+      for (int w = 0; w < 8; w++) begin
+        if (key_init_we[w]) key_init_q[w] <= key_init_d[w];
+      end
+    end
+  end
+
   // Full Key registers
   always_comb begin : key_full_mux
     unique case (key_full_sel)
-      KEY_FULL_ENC_INIT: key_full_d = key_init;
+      KEY_FULL_ENC_INIT: key_full_d = key_init_q;
       KEY_FULL_DEC_INIT: key_full_d = key_dec_q;
       KEY_FULL_ROUND:    key_full_d = key_expand_out;
       KEY_FULL_CLEAR:    key_full_d = '0;
@@ -281,7 +305,10 @@
     .state_sel_o            ( state_sel                          ),
     .state_we_o             ( state_we                           ),
     .add_rk_sel_o           ( add_round_key_in_sel               ),
+
     .key_expand_mode_o      ( key_expand_mode                    ),
+    .key_init_sel_o         ( key_init_sel                       ),
+    .key_init_we_o          ( key_init_we                        ),
     .key_full_sel_o         ( key_full_sel                       ),
     .key_full_we_o          ( key_full_we                        ),
     .key_dec_sel_o          ( key_dec_sel                        ),
@@ -292,7 +319,6 @@
     .key_words_sel_o        ( key_words_sel                      ),
     .round_key_sel_o        ( round_key_sel                      ),
 
-    .key_we_o               ( key_we                             ),
     .data_in_we_o           ( data_in_we                         ),
     .data_out_we_o          ( data_out_we                        ),
 
@@ -315,14 +341,7 @@
     .stall_we_o             ( hw2reg.status.stall.de             )
   );
 
-  // Key and input data register clear
-  always_comb begin : key_reg_clear
-    for (int i=0; i<8; i++) begin
-      hw2reg.key[i].d  = '0;
-      hw2reg.key[i].de = key_we;
-    end
-  end
-
+  // Input data register clear
   always_comb begin : data_in_reg_clear
     for (int i=0; i<4; i++) begin
       hw2reg.data_in[i].d  = '0;
@@ -346,6 +365,12 @@
   end
 
   // Outputs
+  // Drive the d inputs of the key registers in the register file from the
+  // internal Initial Key registers, one 32-bit word per register.
+  always_comb begin : key_reg_put
+    for (int w = 0; w < 8; w++) hw2reg.key[w].d = key_init_q[w];
+  end
+
   always_comb begin : data_out_put
     for (int i=0; i<4; i++) begin
       hw2reg.data_out[i].d = data_out_q[i];
diff --git a/hw/ip/aes/rtl/aes_pkg.sv b/hw/ip/aes/rtl/aes_pkg.sv
index 004bb5d..ebaf949 100644
--- a/hw/ip/aes/rtl/aes_pkg.sv
+++ b/hw/ip/aes/rtl/aes_pkg.sv
@@ -29,6 +29,11 @@
   ADD_RK_FINAL
 } add_rk_sel_e;
 
+typedef enum logic {
+  KEY_INIT_INPUT,
+  KEY_INIT_CLEAR
+} key_init_sel_e;
+
 typedef enum logic [1:0] {
   KEY_FULL_ENC_INIT,
   KEY_FULL_DEC_INIT,
diff --git a/hw/ip/aes/rtl/aes_reg_pkg.sv b/hw/ip/aes/rtl/aes_reg_pkg.sv
index 9d6587d..9bccb69 100644
--- a/hw/ip/aes/rtl/aes_reg_pkg.sv
+++ b/hw/ip/aes/rtl/aes_reg_pkg.sv
@@ -65,7 +65,6 @@
 
   typedef struct packed {
     logic [31:0] d;
-    logic        de;
   } aes_hw2reg_key_mreg_t;
 
   typedef struct packed {
@@ -138,7 +137,7 @@
   // Internal design logic to register //
   ///////////////////////////////////////
   typedef struct packed {
-    aes_hw2reg_key_mreg_t [7:0] key; // [543:280]
+    aes_hw2reg_key_mreg_t [7:0] key; // [535:280]
     aes_hw2reg_data_in_mreg_t [3:0] data_in; // [279:148]
     aes_hw2reg_data_out_mreg_t [3:0] data_out; // [147:20]
     aes_hw2reg_ctrl_reg_t ctrl; // [19:10]
diff --git a/hw/ip/aes/rtl/aes_reg_top.sv b/hw/ip/aes/rtl/aes_reg_top.sv
index 646a3eb..d02d28e 100644
--- a/hw/ip/aes/rtl/aes_reg_top.sv
+++ b/hw/ip/aes/rtl/aes_reg_top.sv
@@ -133,210 +133,130 @@
   // Register instances
 
   // Subregister 0 of Multireg key
-  // R[key0]: V(False)
+  // R[key0]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key0 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key0_we),
     .wd     (key0_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[0].de),
-    .d      (hw2reg.key[0].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[0].d),
+    .qre    (),
     .qe     (reg2hw.key[0].qe),
     .q      (reg2hw.key[0].q ),
-
     .qs     ()
   );
 
   // Subregister 1 of Multireg key
-  // R[key1]: V(False)
+  // R[key1]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key1 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key1_we),
     .wd     (key1_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[1].de),
-    .d      (hw2reg.key[1].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[1].d),
+    .qre    (),
     .qe     (reg2hw.key[1].qe),
     .q      (reg2hw.key[1].q ),
-
     .qs     ()
   );
 
   // Subregister 2 of Multireg key
-  // R[key2]: V(False)
+  // R[key2]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key2 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key2_we),
     .wd     (key2_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[2].de),
-    .d      (hw2reg.key[2].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[2].d),
+    .qre    (),
     .qe     (reg2hw.key[2].qe),
     .q      (reg2hw.key[2].q ),
-
     .qs     ()
   );
 
   // Subregister 3 of Multireg key
-  // R[key3]: V(False)
+  // R[key3]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key3 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key3_we),
     .wd     (key3_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[3].de),
-    .d      (hw2reg.key[3].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[3].d),
+    .qre    (),
     .qe     (reg2hw.key[3].qe),
     .q      (reg2hw.key[3].q ),
-
     .qs     ()
   );
 
   // Subregister 4 of Multireg key
-  // R[key4]: V(False)
+  // R[key4]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key4 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key4_we),
     .wd     (key4_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[4].de),
-    .d      (hw2reg.key[4].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[4].d),
+    .qre    (),
     .qe     (reg2hw.key[4].qe),
     .q      (reg2hw.key[4].q ),
-
     .qs     ()
   );
 
   // Subregister 5 of Multireg key
-  // R[key5]: V(False)
+  // R[key5]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key5 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key5_we),
     .wd     (key5_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[5].de),
-    .d      (hw2reg.key[5].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[5].d),
+    .qre    (),
     .qe     (reg2hw.key[5].qe),
     .q      (reg2hw.key[5].q ),
-
     .qs     ()
   );
 
   // Subregister 6 of Multireg key
-  // R[key6]: V(False)
+  // R[key6]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key6 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key6_we),
     .wd     (key6_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[6].de),
-    .d      (hw2reg.key[6].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[6].d),
+    .qre    (),
     .qe     (reg2hw.key[6].qe),
     .q      (reg2hw.key[6].q ),
-
     .qs     ()
   );
 
   // Subregister 7 of Multireg key
-  // R[key7]: V(False)
+  // R[key7]: V(True)
 
-  prim_subreg #(
-    .DW      (32),
-    .SWACCESS("WO"),
-    .RESVAL  (32'h0)
+  prim_subreg_ext #(
+    .DW    (32)
   ) u_key7 (
-    .clk_i   (clk_i    ),
-    .rst_ni  (rst_ni  ),
-
-    // from register interface
+    .re     (1'b0),
     .we     (key7_we),
     .wd     (key7_wd),
-
-    // from internal hardware
-    .de     (hw2reg.key[7].de),
-    .d      (hw2reg.key[7].d ),
-
-    // to internal hardware
+    .d      (hw2reg.key[7].d),
+    .qre    (),
     .qe     (reg2hw.key[7].qe),
     .q      (reg2hw.key[7].q ),
-
     .qs     ()
   );
 
