[prim] Add Write Mask port

This commit adds write mask ports in the prim_sram_arbiter. With the
update, the arbiter now can support Bit Enable SRAM macro too.

Signed-off-by: Eunchan Kim <eunchan@opentitan.org>
diff --git a/hw/ip/prim/rtl/prim_sram_arbiter.sv b/hw/ip/prim/rtl/prim_sram_arbiter.sv
index 50d2762..acb9b8a 100644
--- a/hw/ip/prim/rtl/prim_sram_arbiter.sv
+++ b/hw/ip/prim/rtl/prim_sram_arbiter.sv
@@ -12,18 +12,20 @@
 `include "prim_assert.sv"
 
 module prim_sram_arbiter #(
-  parameter int N  = 4,
-  parameter int SramDw = 32,
-  parameter int SramAw = 12,
-  parameter ArbiterImpl = "PPC"
+  parameter int unsigned N  = 4,
+  parameter int unsigned SramDw = 32,
+  parameter int unsigned SramAw = 12,
+  parameter ArbiterImpl = "PPC",
+  parameter bit EnMask = 1'b 0 // Disable wmask if 0
 ) (
   input clk_i,
   input rst_ni,
 
   input        [     N-1:0] req_i,
   input        [SramAw-1:0] req_addr_i [N],
-  input                     req_write_i[N],
+  input        [     N-1:0] req_write_i,
   input        [SramDw-1:0] req_wdata_i[N],
+  input        [SramDw-1:0] req_wmask_i[N],
   output logic [     N-1:0] gnt_o,
 
   output logic [     N-1:0] rsp_rvalid_o,      // Pulse
@@ -35,30 +37,50 @@
   output logic [SramAw-1:0] sram_addr_o,
   output logic              sram_write_o,
   output logic [SramDw-1:0] sram_wdata_o,
+  output logic [SramDw-1:0] sram_wmask_o,
   input                     sram_rvalid_i,
   input        [SramDw-1:0] sram_rdata_i,
   input        [1:0]        sram_rerror_i
 );
 
-
   typedef struct packed {
     logic write;
     logic [SramAw-1:0] addr;
     logic [SramDw-1:0] wdata;
+    logic [SramDw-1:0] wmask;
   } req_t;
 
-  localparam int ARB_DW = $bits(req_t);
-
   req_t req_packed [N];
 
   for (genvar i = 0 ; i < N ; i++) begin : gen_reqs
-    assign req_packed[i] = {req_write_i[i], req_addr_i[i], req_wdata_i[i]};
+    assign req_packed[i] = {
+      req_write_i[i],
+      req_addr_i [i],
+      req_wdata_i[i],
+      (EnMask) ? req_wmask_i[i] : {SramDw{1'b1}}
+    };
   end
 
+  localparam int ARB_DW = $bits(req_t);
+
   req_t sram_packed;
   assign sram_write_o = sram_packed.write;
   assign sram_addr_o  = sram_packed.addr;
   assign sram_wdata_o = sram_packed.wdata;
+  assign sram_wmask_o = (EnMask) ? sram_packed.wmask : {SramDw{1'b1}};
+
+  if (EnMask == 1'b 0) begin : g_unused
+    logic unused_wmask;
+
+    always_comb begin
+      unused_wmask = 1'b 1;
+      for (int unsigned i = 0 ; i < N ; i++) begin
+        unused_wmask ^= ^req_wmask_i[i];
+      end
+      unused_wmask ^= ^sram_packed.wmask;
+    end
+  end
+
 
   if (ArbiterImpl == "PPC") begin : gen_arb_ppc
     prim_arbiter_ppc #(
diff --git a/hw/ip/spi_device/rtl/spi_device_pkg.sv b/hw/ip/spi_device/rtl/spi_device_pkg.sv
index bb8c4f3..96089e3 100644
--- a/hw/ip/spi_device/rtl/spi_device_pkg.sv
+++ b/hw/ip/spi_device/rtl/spi_device_pkg.sv
@@ -367,6 +367,16 @@
     return result;
   endfunction : sram_strb2mask
 
+  function automatic logic [SramStrbW-1:0] sram_mask2strb(
+    logic [SramDw-1:0] mask
+  );
+    logic [SramStrbW-1:0] result;
+    for (int unsigned i = 0 ; i < SramStrbW ; i++) begin
+      result[i] = &mask[8*i+:8];
+    end
+    return result;
+  endfunction : sram_mask2strb
+
   // Calculate each space's base and size
   parameter sram_addr_t SramReadBufferIdx  = sram_addr_t'(0);
   parameter sram_addr_t SramReadBufferSize = sram_addr_t'(SramMsgDepth);
diff --git a/hw/ip/spi_device/rtl/spi_fwmode.sv b/hw/ip/spi_device/rtl/spi_fwmode.sv
index 52822ef..253aca7 100644
--- a/hw/ip/spi_device/rtl/spi_fwmode.sv
+++ b/hw/ip/spi_device/rtl/spi_fwmode.sv
@@ -108,8 +108,9 @@
 
   logic        [1:0] fwm_sram_req;
   logic [SramAw-1:0] fwm_sram_addr  [2];
-  logic              fwm_sram_write [2];
+  logic        [1:0] fwm_sram_write;
   logic [SramDw-1:0] fwm_sram_wdata [2];
+  logic [SramDw-1:0] fwm_sram_wmask [2];
   logic        [1:0] fwm_sram_gnt;
   logic        [1:0] fwm_sram_rvalid;    // RXF doesn't use
   logic [SramDw-1:0] fwm_sram_rdata [2]; // RXF doesn't use
@@ -221,6 +222,7 @@
     .sram_rdata  (fwm_sram_rdata [FwModeRxFifo]),
     .sram_error  (fwm_sram_error [FwModeRxFifo])
   );
+  assign fwm_sram_wmask [FwModeRxFifo] = '1;
 
   // TX Fifo control (SRAM read request --> FIFO write)
   spi_fwm_txf_ctrl #(
@@ -252,9 +254,15 @@
     .sram_rdata  (fwm_sram_rdata [FwModeTxFifo]),
     .sram_error  (fwm_sram_error [FwModeTxFifo])
   );
+  assign fwm_sram_wmask [FwModeTxFifo] = '1;
 
   // Arbiter for FIFOs : Connecting between SRAM Ctrls and SRAM interface
-  assign fwm_wstrb_o = '1;
+  logic [SramDw-1:0] fwm_wmask;
+
+  assign fwm_wstrb_o = sram_mask2strb(fwm_wmask);
+
+  // TODO: Assume other 7bits in a byte are same to the first bit
+
   prim_sram_arbiter #(
     .N            (2),  // RXF, TXF
     .SramDw       (SramDw),
@@ -267,6 +275,7 @@
     .req_addr_i   (fwm_sram_addr),
     .req_write_i  (fwm_sram_write),
     .req_wdata_i  (fwm_sram_wdata),
+    .req_wmask_i  (fwm_sram_wmask),
     .gnt_o        (fwm_sram_gnt),
 
     .rsp_rvalid_o (fwm_sram_rvalid),
@@ -277,6 +286,7 @@
     .sram_addr_o  (fwm_addr_o),
     .sram_write_o (fwm_write_o),
     .sram_wdata_o (fwm_wdata_o),
+    .sram_wmask_o (fwm_wmask),
 
     .sram_rvalid_i(fwm_rvalid_i),
     .sram_rdata_i (fwm_rdata_i),