refactor(hdl): Make CoreAxiCSR bus-width agnostic

This change refactors the CoreAxiCSR module to be bus-width agnostic.
This is a precursor to adding TileLink support, which has a different
bus width than AXI.

Change-Id: I83424b4b587a5bd6833d9a5e9db862a613af2210
diff --git a/hdl/chisel/src/kelvin/CoreAxiCSR.scala b/hdl/chisel/src/kelvin/CoreAxiCSR.scala
index 3aac4a8..9639090 100644
--- a/hdl/chisel/src/kelvin/CoreAxiCSR.scala
+++ b/hdl/chisel/src/kelvin/CoreAxiCSR.scala
@@ -49,6 +49,8 @@
   val resetReg = RegInit(3.U(p.fetchAddrBits.W))
   val pcStartReg = RegInit(0.U(p.fetchAddrBits.W))
   val statusReg = RegInit(0.U(p.fetchAddrBits.W))
+
+  // Debug module registers, conditionally present.
   val debugReqAddrReg = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
   val debugReqDataReg = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
   val debugReqOpReg = Option.when(p.useDebugModule)(RegInit(DmReqOp.NOP.asUInt))
@@ -57,20 +59,25 @@
   val writeAddr = io.fabric.writeDataAddr.bits
   val writeData = io.fabric.writeDataBits
 
+  // Debug module handling logic.
   val rsp_queue = if (p.useDebugModule) {
+    // Queue for debug responses.
     val queue = Module(new Queue(new DebugModuleRspIO(p), 1))
     queue.io.enq <> io.debug.get.rsp
 
+    // Pulse valid signal for a single cycle on a write to the op register.
     val req_valid_pulse = RegInit(false.B)
     val write_to_op_reg = writeEn && writeAddr === CoreCsrAddrs.DbgReqOp
     req_valid_pulse := Mux(write_to_op_reg && io.debug.get.req.ready, true.B, false.B)
     io.debug.get.req.valid := req_valid_pulse
 
+    // Wire up debug request signals.
     io.debug.get.req.bits.address := debugReqAddrReg.get
     io.debug.get.req.bits.data := debugReqDataReg.get
     val (req_op, req_op_valid) = DmReqOp.safe(debugReqOpReg.get)
     io.debug.get.req.bits.op := Mux(req_op_valid, req_op, DmReqOp.NOP)
 
+    // Dequeue from the response queue when the status register is written to.
     val write_to_status_reg = writeEn && writeAddr === CoreCsrAddrs.DbgStatus
     queue.io.deq.ready := write_to_status_reg
     Some(queue)
@@ -78,53 +85,74 @@
     None
   }
 
+  val readAddr = io.fabric.readDataAddr.bits
+  // Align the read address to the AXI data bus width.
+  val alignedAddr = readAddr & ~((p.axi2DataBytes - 1).U(readAddr.getWidth.W))
+
+  val kRegWidthBits = 32
+  val kRegWidthBytes = kRegWidthBits / 8
+  val kCsrBaseAddr = 0x100
+
+  val regsPerBus = p.axi2DataBits / kRegWidthBits
+  val readData = Wire(Vec(regsPerBus, UInt(kRegWidthBits.W)))
+  for (i <- 0 until regsPerBus) {
+    readData(i) := 0.U
+  }
+
+  // Map of core control registers.
+  val coreRegMap = Map(
+    0x0 -> resetReg,
+    0x4 -> pcStartReg,
+    0x8 -> statusReg,
+  )
+
+  // Map of Kelvin's internal CSRs.
+  val csrRegs = io.kelvin_csr.value
+  val csrRegMap = (0 until p.csrOutCount).map { i =>
+    (kCsrBaseAddr + i * kRegWidthBytes) -> csrRegs(i)
+  }.toMap
+
+  // Map of debug registers, conditionally present.
   val debugReadMap = if (p.useDebugModule) {
     val debugStatusReg = Cat(rsp_queue.get.io.deq.valid, io.debug.get.req.ready)
-    Seq(
-      CoreCsrAddrs.DbgReqAddr -> Cat(0.U(96.W), debugReqAddrReg.get),
-      CoreCsrAddrs.DbgReqData -> Cat(0.U(64.W), debugReqDataReg.get, 0.U(32.W)),
-      CoreCsrAddrs.DbgReqOp   -> Cat(0.U(32.W), debugReqOpReg.get, 0.U(64.W)),
-      CoreCsrAddrs.DbgRspData -> Cat(rsp_queue.get.io.deq.bits.data, 0.U(96.W)),
-      CoreCsrAddrs.DbgRspOp   -> Cat(0.U(96.W), rsp_queue.get.io.deq.bits.op.asUInt),
-      CoreCsrAddrs.DbgStatus  -> Cat(0.U(64.W), debugStatusReg, 0.U(32.W)),
+    val regs = Seq(
+      CoreCsrAddrs.DbgReqAddr -> debugReqAddrReg.get,
+      CoreCsrAddrs.DbgReqData -> debugReqDataReg.get,
+      CoreCsrAddrs.DbgReqOp   -> debugReqOpReg.get,
+      CoreCsrAddrs.DbgRspData -> rsp_queue.get.io.deq.bits.data,
+      CoreCsrAddrs.DbgRspOp   -> rsp_queue.get.io.deq.bits.op.asUInt,
+      CoreCsrAddrs.DbgStatus  -> debugStatusReg,
     )
+    regs.map { case (k, v) => k.litValue.toInt -> v }.toMap
   } else {
-    Seq()
+    Map[Int, Data]()
   }
 
-  val readData =
-    MuxLookup(io.fabric.readDataAddr.bits, 0.U)(Seq(
-      0x0.U -> Cat(0.U(96.W), resetReg),
-      0x4.U -> Cat(0.U(64.W), pcStartReg, 0.U(32.W)),
-      0x8.U -> Cat(0.U(32.W), statusReg, 0.U(64.W)),
-    ) ++ debugReadMap
-      ++ ((0 until p.csrOutCount).map(
-      x => ((0x100 + 4*x).U -> (io.kelvin_csr.value(x) << (32 * (x % 4)).U))
-    )))
+  // Combine all register maps.
+  val allReadRegs = coreRegMap ++ csrRegMap ++ debugReadMap
 
-  val debugReadValidMap = if (p.useDebugModule) {
-    Seq(
-      CoreCsrAddrs.DbgReqAddr -> true.B,
-      CoreCsrAddrs.DbgReqData -> true.B,
-      CoreCsrAddrs.DbgReqOp   -> true.B,
-      CoreCsrAddrs.DbgRspData -> true.B,
-      CoreCsrAddrs.DbgRspOp   -> true.B,
-      CoreCsrAddrs.DbgStatus  -> true.B,
-    )
-  } else {
-    Seq()
+  // Group registers by their aligned base address to prevent multiple writers.
+  val groupedRegs = allReadRegs.groupBy { case (offset, _) =>
+    offset & ~(p.axi2DataBytes - 1)
   }
 
-  val readDataValid =
-    MuxLookup(io.fabric.readDataAddr.bits, false.B)(Seq(
-      0x0.U -> true.B,
-      0x4.U -> true.B,
-      0x8.U -> true.B,
-    ) ++ debugReadValidMap
-      ++ ((0 until p.csrOutCount).map(x => ((0x100 + 4*x).U -> true.B))))
+  // Generate read logic for all registers.
+  for ((base, regs) <- groupedRegs) {
+    when(alignedAddr === base.U) {
+      for ((offset, reg) <- regs) {
+        // Place the register value into the correct 32-bit lane of the output bus.
+        readData((offset % p.axi2DataBytes) / kRegWidthBytes) := reg
+      }
+    }
+  }
 
-  // Delay reads by one cycle
-  val readDataNext = Pipe(readDataValid, readData, 1)
+  // A read is valid if it hits any of the registers in our map.
+  val readDataValid = MuxLookup(readAddr, false.B)(
+    allReadRegs.keys.map(addr => (addr.U -> true.B)).toSeq
+  )
+
+  // Delay reads by one cycle for timing.
+  val readDataNext = Pipe(readDataValid, readData.asUInt, 1)
   io.fabric.readData := readDataNext
 
   io.reset := resetReg(0)
@@ -132,7 +160,7 @@
   io.pcStart := pcStartReg
   statusReg := Cat(io.fault, io.halted)
 
-  // Register writes
+  // Register write logic.
   resetReg := Mux(writeEn && writeAddr === 0x0.U, writeData(31,0), resetReg)
   pcStartReg := Mux(writeEn && writeAddr === 0x4.U, writeData(63,32), pcStartReg)
   if (p.useDebugModule) {
@@ -141,21 +169,26 @@
     debugReqOpReg.get := Mux(writeEn && writeAddr === CoreCsrAddrs.DbgReqOp, writeData(95,64), debugReqOpReg.get)
   }
 
+  // Map of valid write addresses for the debug module.
   val debugWriteValidMap = if (p.useDebugModule) {
-    Seq(
-      CoreCsrAddrs.DbgReqAddr -> true.B,
-      CoreCsrAddrs.DbgReqData -> true.B,
-      CoreCsrAddrs.DbgReqOp   -> true.B,
-      CoreCsrAddrs.DbgStatus  -> true.B,
+    Map(
+      CoreCsrAddrs.DbgReqAddr.litValue.toInt -> true.B,
+      CoreCsrAddrs.DbgReqData.litValue.toInt -> true.B,
+      CoreCsrAddrs.DbgReqOp.litValue.toInt   -> true.B,
+      CoreCsrAddrs.DbgStatus.litValue.toInt  -> true.B,
     )
   } else {
-    Seq()
+    Map[Int, Bool]()
   }
 
-  io.fabric.writeResp := writeEn && MuxLookup(writeAddr, false.B)(Seq(
-    0x0.U -> true.B,
-    0x4.U -> true.B,
-  ) ++ debugWriteValidMap)
+  val allWriteRegs = Map(
+    0x0 -> true.B,
+    0x4 -> true.B,
+  ) ++ debugWriteValidMap
+
+  io.fabric.writeResp := writeEn && MuxLookup(writeAddr, false.B)(
+    allWriteRegs.map { case (k, v) => k.U -> v }.toSeq
+  )
 }
 
 class CoreAxiCSR(p: Parameters,
diff --git a/hdl/chisel/src/kelvin/CoreAxiCSRTest.scala b/hdl/chisel/src/kelvin/CoreAxiCSRTest.scala
index b2b187a..8715c09 100644
--- a/hdl/chisel/src/kelvin/CoreAxiCSRTest.scala
+++ b/hdl/chisel/src/kelvin/CoreAxiCSRTest.scala
@@ -126,7 +126,7 @@
       while (dut.io.axi.read.data.valid.peek().litValue != 1) {
         dut.clock.step()
       }
-      dut.io.axi.read.data.bits.data.expect(BigInt(0x20000000) << 32)
+      assert((dut.io.axi.read.data.bits.data.peek().litValue >> 32) == 0x20000000)
       dut.io.axi.read.data.bits.last.expect(1)
       dut.io.axi.read.data.bits.resp.expect(0)