Refactor Bru

- Use ChiselEnum
- Combine intermediate state into one register
- Use Chisel MuxCase/MuxLookup for readability.

Change-Id: I889cfc661409a2a86f3fdc79693574836171f078
diff --git a/hdl/chisel/src/kelvin/scalar/Bru.scala b/hdl/chisel/src/kelvin/scalar/Bru.scala
index a83d157..f76f257 100644
--- a/hdl/chisel/src/kelvin/scalar/Bru.scala
+++ b/hdl/chisel/src/kelvin/scalar/Bru.scala
@@ -24,34 +24,32 @@
   }
 }
 
-case class BruOp() {
-  val JAL  = 0
-  val JALR = 1
-  val BEQ  = 2
-  val BNE  = 3
-  val BLT  = 4
-  val BGE  = 5
-  val BLTU = 6
-  val BGEU = 7
-  val EBREAK = 8
-  val ECALL = 9
-  val EEXIT = 10
-  val EYIELD = 11
-  val ECTXSW = 12
-  val MPAUSE = 13
-  val MRET = 14
-  val FENCEI = 15
-  val UNDEF = 16
-  val Entries = 17
+object BruOp extends ChiselEnum {
+  val JAL  = Value
+  val JALR = Value
+  val BEQ  = Value
+  val BNE  = Value
+  val BLT  = Value
+  val BGE  = Value
+  val BLTU = Value
+  val BGEU = Value
+  val EBREAK = Value
+  val ECALL = Value
+  val EEXIT = Value
+  val EYIELD = Value
+  val ECTXSW = Value
+  val MPAUSE = Value
+  val MRET = Value
+  val FENCEI = Value
+  val UNDEF = Value
 }
 
-class BruIO(p: Parameters) extends Bundle {
-  val valid = Input(Bool())
-  val fwd = Input(Bool())
-  val op = Input(UInt(new BruOp().Entries.W))
-  val pc = Input(UInt(p.programCounterBits.W))
-  val target = Input(UInt(p.programCounterBits.W))
-  val link = Input(UInt(5.W))
+class BruCmd(p: Parameters) extends Bundle {
+  val fwd = Bool()
+  val op = BruOp()
+  val pc = UInt(p.programCounterBits.W)
+  val target = UInt(p.programCounterBits.W)
+  val link = UInt(5.W)
 }
 
 class BranchTakenIO(p: Parameters) extends Bundle {
@@ -59,10 +57,34 @@
   val value = Output(UInt(p.programCounterBits.W))
 }
 
+class BranchState(p: Parameters) extends Bundle {
+  val fwd = Bool()
+  val op = BruOp()
+  val target = UInt(p.programCounterBits.W)
+  val linkValid = Bool()
+  val linkAddr = UInt(5.W)
+  val linkData = UInt(p.programCounterBits.W)
+  val pcEx = UInt(32.W)
+}
+
+object BranchState {
+  def default(p: Parameters): BranchState = {
+    val result = Wire(new BranchState(p))
+    result.fwd := false.B
+    result.op := BruOp.JAL
+    result.target := 0.U
+    result.linkValid := false.B
+    result.linkAddr := 0.U
+    result.linkData := 0.U
+    result.pcEx := 0.U
+    result
+  }
+}
+
 class Bru(p: Parameters) extends Module {
   val io = IO(new Bundle {
     // Decode cycle.
-    val req = new BruIO(p)
+    val req = Flipped(Valid(new BruCmd(p)))
 
     // Execute cycle.
     val csr = new CsrBruIO(p)
@@ -75,60 +97,44 @@
     val iflush = Output(Bool())
   })
 
-  val branch = new BruOp()
-
+  // Interlock
   val interlock = RegInit(false.B)
+  interlock := io.req.valid && io.req.bits.op.isOneOf(
+      BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD, BruOp.ECTXSW,
+      BruOp.MPAUSE, BruOp.MRET)
+  io.interlock := interlock
 
-  val readRs = RegInit(false.B)
-  val fwd = RegInit(false.B)
-  val op = RegInit(0.U(branch.Entries.W))
-  val target = Reg(UInt(p.programCounterBits.W))
-  val linkValid = RegInit(false.B)
-  val linkAddr = Reg(UInt(5.W))
-  val linkData = Reg(UInt(p.programCounterBits.W))
-  val pcEx = Reg(UInt(32.W))
-
-  linkValid := io.req.valid && io.req.link =/= 0.U &&
-               (io.req.op(branch.JAL) || io.req.op(branch.JALR))
-
-  op := Mux(io.req.valid, io.req.op, 0.U)
-  fwd := io.req.valid && io.req.fwd
-
-  readRs := Mux(io.req.valid,
-            io.req.op(branch.BEQ)  || io.req.op(branch.BNE) ||
-            io.req.op(branch.BLT)  || io.req.op(branch.BGE) ||
-            io.req.op(branch.BLTU) || io.req.op(branch.BGEU), false.B)
-
+  // Assign state
   val mode = io.csr.out.mode  // (0) machine, (1) user
 
-  val pcDe  = io.req.pc
-  val pc4De = io.req.pc + 4.U
+  val pcDe  = io.req.bits.pc
+  val pc4De = io.req.bits.pc + 4.U
 
-  when (io.req.valid) {
-    val mret = io.req.op(branch.MRET) && !mode
-    val call = io.req.op(branch.MRET) && mode ||
-               io.req.op(branch.EBREAK) ||
-               io.req.op(branch.ECALL) ||
-               io.req.op(branch.EEXIT) ||
-               io.req.op(branch.EYIELD) ||
-               io.req.op(branch.ECTXSW) ||
-               io.req.op(branch.MPAUSE)
-    target := Mux(mret, io.csr.out.mepc,
-              Mux(call, io.csr.out.mtvec,
-              Mux(io.req.fwd || io.req.op(branch.FENCEI), pc4De,
-              Mux(io.req.op(branch.JALR), io.target.data,
-                  io.req.target))))
-    linkAddr := io.req.link
-    linkData := pc4De
-    pcEx := pcDe
-  }
+  val mret = (io.req.bits.op === BruOp.MRET) && !mode
+  val call = ((io.req.bits.op === BruOp.MRET) && mode) ||
+      io.req.bits.op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT,
+                             BruOp.EYIELD, BruOp.ECTXSW, BruOp.MPAUSE)
 
-  interlock := io.req.valid && (io.req.op(branch.EBREAK) ||
-                 io.req.op(branch.ECALL) || io.req.op(branch.EEXIT) ||
-                 io.req.op(branch.EYIELD) || io.req.op(branch.ECTXSW) ||
-                 io.req.op(branch.MPAUSE) || io.req.op(branch.MRET))
+  val stateReg = RegInit(MakeValid(false.B, BranchState.default(p)))
+  val nextState = Wire(new BranchState(p))
+  nextState.linkValid := io.req.valid && (io.req.bits.link =/= 0.U) &&
+               (io.req.bits.op.isOneOf(BruOp.JAL, BruOp.JALR))
 
-  io.interlock := interlock
+  nextState.op := io.req.bits.op
+  nextState.fwd := io.req.valid && io.req.bits.fwd
+
+  nextState.linkAddr := io.req.bits.link
+  nextState.linkData := pc4De
+  nextState.pcEx := pcDe
+
+  nextState.target := MuxCase(io.req.bits.target, Seq(
+      mret -> io.csr.out.mepc,
+      call -> io.csr.out.mepc,
+      (io.req.bits.fwd || (io.req.bits.op === BruOp.FENCEI)) -> pc4De,
+      (io.req.bits.op === BruOp.JALR) -> io.target.data,
+  ))
+  stateReg.valid := io.req.valid
+  stateReg.bits := nextState
 
   // This mux sits on the critical path.
   // val rs1 = Mux(readRs, io.rs1.data, 0.U)
@@ -143,94 +149,83 @@
   val ltu = rs1 < rs2
   val geu = !ltu
 
-  io.taken.valid := op(branch.EBREAK) && mode ||
-                    op(branch.ECALL)  && mode ||
-                    op(branch.EEXIT)  && mode ||
-                    op(branch.EYIELD) && mode ||
-                    op(branch.ECTXSW) && mode ||
-                    op(branch.MRET)   && !mode ||
-                    op(branch.MRET)   && mode ||  // fault
-                    op(branch.MPAUSE) && mode ||  // fault
-                    op(branch.FENCEI) ||
-                    (op(branch.JAL) ||
-                     op(branch.JALR) ||
-                     op(branch.BEQ)  && eq  ||
-                     op(branch.BNE)  && neq ||
-                     op(branch.BLT)  && lt  ||
-                     op(branch.BGE)  && ge  ||
-                     op(branch.BLTU) && ltu ||
-                     op(branch.BGEU) && geu) =/= fwd
+  val op = stateReg.bits.op
 
-  io.taken.value := target
+  io.taken.valid := stateReg.valid && MuxLookup(op, false.B)(Seq(
+    BruOp.EBREAK -> mode,
+    BruOp.ECALL  -> mode,
+    BruOp.EEXIT  -> mode,
+    BruOp.EYIELD -> mode,
+    BruOp.ECTXSW -> mode,
+    BruOp.MPAUSE -> mode,  // fault
+    BruOp.MRET   -> true.B,  // fault if user mode.
+    BruOp.FENCEI -> true.B,
+    BruOp.JAL    -> (true.B =/= stateReg.bits.fwd),
+    BruOp.JALR   -> (true.B =/= stateReg.bits.fwd),
+    BruOp.BEQ    -> (eq  =/= stateReg.bits.fwd),
+    BruOp.BNE    -> (neq =/= stateReg.bits.fwd),
+    BruOp.BLT    -> (lt  =/= stateReg.bits.fwd),
+    BruOp.BGE    -> (ge  =/= stateReg.bits.fwd),
+    BruOp.BLTU   -> (ltu =/= stateReg.bits.fwd),
+    BruOp.BGEU   -> (geu =/= stateReg.bits.fwd),
+  ))
+  io.taken.value := stateReg.bits.target
 
-  io.rd.valid := linkValid
-  io.rd.addr := linkAddr
-  io.rd.data := linkData
+  io.rd.valid := stateReg.valid && stateReg.bits.linkValid
+  io.rd.addr := stateReg.bits.linkAddr
+  io.rd.data := stateReg.bits.linkData
 
   // Undefined Fault.
-  val undefFault = op(branch.UNDEF)
+  val undefFault = stateReg.valid && (op === BruOp.UNDEF)
 
   // Usage Fault.
-  val usageFault = op(branch.EBREAK) && !mode ||
-                   op(branch.ECALL)  && !mode ||
-                   op(branch.EEXIT)  && !mode ||
-                   op(branch.EYIELD) && !mode ||
-                   op(branch.ECTXSW) && !mode ||
-                   op(branch.MPAUSE) && mode ||
-                   op(branch.MRET)   && mode
+  val usageFault = stateReg.valid && Mux(
+      mode, op.isOneOf(BruOp.MPAUSE, BruOp.MRET),
+            op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+                       BruOp.ECTXSW))
 
-  io.csr.in.mode.valid := op(branch.EBREAK) && mode ||
-                          op(branch.ECALL)  && mode ||
-                          op(branch.EEXIT)  && mode ||
-                          op(branch.EYIELD) && mode ||
-                          op(branch.ECTXSW) && mode ||
-                          op(branch.MPAUSE) && mode ||  // fault
-                          op(branch.MRET)   && mode ||  // fault
-                          op(branch.MRET)   && !mode
-  io.csr.in.mode.bits := MuxOR(op(branch.MRET) && !mode, true.B)
+  io.csr.in.mode.valid := stateReg.valid && Mux(
+      mode, op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+                       BruOp.ECTXSW, BruOp.MPAUSE, BruOp.MRET),
+            (op === BruOp.MRET))
+  io.csr.in.mode.bits := ((op === BruOp.MRET) && !mode)
 
-  io.csr.in.mepc.valid := op(branch.EBREAK) && mode ||
-                          op(branch.ECALL)  && mode ||
-                          op(branch.EEXIT)  && mode ||
-                          op(branch.EYIELD) && mode ||
-                          op(branch.ECTXSW) && mode ||
-                          op(branch.MPAUSE) && mode ||  // fault
-                          op(branch.MRET)   && mode     // fault
-  io.csr.in.mepc.bits := Mux(op(branch.EYIELD), linkData, pcEx)
+  io.csr.in.mepc.valid := stateReg.valid && mode &&
+      op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+                 BruOp.ECTXSW, BruOp.MPAUSE, BruOp.MRET)
+  io.csr.in.mepc.bits := Mux(op === BruOp.EYIELD, stateReg.bits.linkData,
+                                                  stateReg.bits.pcEx)
 
-  io.csr.in.mcause.valid := undefFault || usageFault ||
-                            op(branch.EBREAK) && mode ||
-                            op(branch.ECALL)  && mode ||
-                            op(branch.EEXIT)  && mode ||
-                            op(branch.EYIELD) && mode ||
-                            op(branch.ECTXSW) && mode
+  io.csr.in.mcause.valid := stateReg.valid && (undefFault || usageFault ||
+      (mode && op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+                          BruOp.ECTXSW)))
 
   val faultMsb = 1.U << 31
-  io.csr.in.mcause.bits := Mux(undefFault, 2.U  | faultMsb,
-                           Mux(usageFault, 16.U | faultMsb,
-                             MuxOR(op(branch.EBREAK), 1.U) |
-                             MuxOR(op(branch.ECALL),  2.U) |
-                             MuxOR(op(branch.EEXIT),  3.U) |
-                             MuxOR(op(branch.EYIELD), 4.U) |
-                             MuxOR(op(branch.ECTXSW), 5.U)))
+  io.csr.in.mcause.bits := MuxCase(0.U, Seq(
+      undefFault        -> (2.U  | faultMsb),
+      usageFault        -> (16.U | faultMsb),
+      (op === BruOp.EBREAK) -> 1.U,
+      (op === BruOp.ECALL)  -> 2.U,
+      (op === BruOp.EEXIT)  -> 3.U,
+      (op === BruOp.EYIELD) -> 4.U,
+      (op === BruOp.ECTXSW) -> 5.U,
+  ))
 
-  io.csr.in.mtval.valid := undefFault || usageFault
-  io.csr.in.mtval.bits := pcEx
+  io.csr.in.mtval.valid := stateReg.valid && (undefFault || usageFault)
+  io.csr.in.mtval.bits := stateReg.bits.pcEx
 
-  io.iflush := op(branch.FENCEI)
+  io.iflush := stateReg.valid && (op === BruOp.FENCEI)
 
   // Pipeline will be halted.
-  io.csr.in.halt := op(branch.MPAUSE) && !mode || io.csr.in.fault
-  io.csr.in.fault := undefFault && !mode || usageFault && !mode
+  io.csr.in.halt := (stateReg.valid && (op === BruOp.MPAUSE) && !mode) ||
+                    io.csr.in.fault
+  io.csr.in.fault := (undefFault && !mode) || (usageFault && !mode)
 
   // Assertions.
-  val valid = RegInit(false.B)
-  val ignore = op(branch.JAL) || op(branch.JALR) || op(branch.EBREAK) ||
-               op(branch.ECALL) || op(branch.EEXIT) || op(branch.EYIELD) ||
-               op(branch.ECTXSW) || op(branch.MPAUSE) || op(branch.MRET) ||
-               op(branch.FENCEI) || op(branch.UNDEF)
+  val ignore = op.isOneOf(BruOp.JAL, BruOp.JALR, BruOp.EBREAK, BruOp.ECALL,
+                          BruOp.EEXIT, BruOp.EYIELD, BruOp.ECTXSW, BruOp.MPAUSE,
+                          BruOp.MRET, BruOp.FENCEI, BruOp.UNDEF)
 
-  valid := io.req.valid
-  assert(!(valid && !io.rs1.valid) || ignore)
-  assert(!(valid && !io.rs2.valid) || ignore)
+  assert(!(stateReg.valid && !io.rs1.valid) || ignore)
+  assert(!(stateReg.valid && !io.rs2.valid) || ignore)
 }
diff --git a/hdl/chisel/src/kelvin/scalar/Decode.scala b/hdl/chisel/src/kelvin/scalar/Decode.scala
index 206c516..a6d131f 100644
--- a/hdl/chisel/src/kelvin/scalar/Decode.scala
+++ b/hdl/chisel/src/kelvin/scalar/Decode.scala
@@ -191,7 +191,7 @@
     val alu = Valid(new AluCmd)
 
     // Branch interface.
-    val bru = Flipped(new BruIO(p))
+    val bru = Valid(new BruCmd(p))
 
     // CSR interface.
     val csr = Valid(new CsrCmd)
@@ -306,33 +306,31 @@
   io.alu.bits.op := alu.bits
 
   // Branch conditional opcode.
-  val bru = new BruOp()
-  val bruOp = Wire(Vec(bru.Entries, Bool()))
-  val bruValid = WiredOR(io.bru.op)  // used without decodeEn
-  io.bru.valid := decodeEn && bruValid
-  io.bru.fwd := io.inst.brchFwd
-  io.bru.op := bruOp.asUInt
-  io.bru.pc := io.inst.addr
-  io.bru.target := io.inst.addr + Mux(io.inst.inst(2), d.immjal, d.immbr)
-  io.bru.link := rdAddr
-
-  bruOp(bru.JAL)  := d.jal
-  bruOp(bru.JALR) := d.jalr
-  bruOp(bru.BEQ)  := d.beq
-  bruOp(bru.BNE)  := d.bne
-  bruOp(bru.BLT)  := d.blt
-  bruOp(bru.BGE)  := d.bge
-  bruOp(bru.BLTU) := d.bltu
-  bruOp(bru.BGEU) := d.bgeu
-  bruOp(bru.EBREAK) := d.ebreak
-  bruOp(bru.ECALL)  := d.ecall
-  bruOp(bru.EEXIT)  := d.eexit
-  bruOp(bru.EYIELD) := d.eyield
-  bruOp(bru.ECTXSW) := d.ectxsw
-  bruOp(bru.MPAUSE) := d.mpause
-  bruOp(bru.MRET)   := d.mret
-  bruOp(bru.FENCEI) := d.fencei
-  bruOp(bru.UNDEF)  := d.undef
+  val bru = MuxCase(MakeValid(false.B, BruOp.JAL), Seq(
+    d.jal    -> MakeValid(true.B, BruOp.JAL),
+    d.jalr   -> MakeValid(true.B, BruOp.JALR),
+    d.beq    -> MakeValid(true.B, BruOp.BEQ),
+    d.bne    -> MakeValid(true.B, BruOp.BNE),
+    d.blt    -> MakeValid(true.B, BruOp.BLT),
+    d.bge    -> MakeValid(true.B, BruOp.BGE),
+    d.bltu   -> MakeValid(true.B, BruOp.BLTU),
+    d.bgeu   -> MakeValid(true.B, BruOp.BGEU),
+    d.ebreak -> MakeValid(true.B, BruOp.EBREAK),
+    d.ecall  -> MakeValid(true.B, BruOp.ECALL),
+    d.eexit  -> MakeValid(true.B, BruOp.EEXIT),
+    d.eyield -> MakeValid(true.B, BruOp.EYIELD),
+    d.ectxsw -> MakeValid(true.B, BruOp.ECTXSW),
+    d.mpause -> MakeValid(true.B, BruOp.MPAUSE),
+    d.mret   -> MakeValid(true.B, BruOp.MRET),
+    d.fencei -> MakeValid(true.B, BruOp.FENCEI),
+    d.undef  -> MakeValid(true.B, BruOp.UNDEF),
+  ))
+  io.bru.valid := decodeEn && bru.valid
+  io.bru.bits.fwd := io.inst.brchFwd
+  io.bru.bits.op := bru.bits
+  io.bru.bits.pc := io.inst.addr
+  io.bru.bits.target := io.inst.addr + Mux(io.inst.inst(2), d.immjal, d.immbr)
+  io.bru.bits.link := rdAddr
 
   // CSR opcode.
   val csr = MuxCase(MakeValid(false.B, CsrOp.CSRRW), Seq(
@@ -440,7 +438,7 @@
       alu.valid || csr.valid || mlu.valid || dvu.valid && io.dvu.ready ||
       lsu.valid && d.isLoad() ||
       d.getvl || d.getmaxvl || vldst_wb ||
-      bruValid && (bruOp(bru.JAL) || bruOp(bru.JALR)) && rdAddr =/= 0.U
+      bru.valid && (bru.bits.isOneOf(BruOp.JAL, BruOp.JALR)) && rdAddr =/= 0.U
 
   // val scoreboard_spec = Mux(rdMark_valid || d.io.vst, UIntToOH(rdAddr, 32), 0.U)  // TODO: why was d.io.vst included?
   val scoreboard_spec = Mux(rdMark_valid, UIntToOH(rdAddr, 32), 0.U)