Refactor Bru
- Use ChiselEnum
- Combine intermediate state into one register
- Use Chisel MuxCase/MuxLookup for readability.
Change-Id: I889cfc661409a2a86f3fdc79693574836171f078
diff --git a/hdl/chisel/src/kelvin/scalar/Bru.scala b/hdl/chisel/src/kelvin/scalar/Bru.scala
index a83d157..f76f257 100644
--- a/hdl/chisel/src/kelvin/scalar/Bru.scala
+++ b/hdl/chisel/src/kelvin/scalar/Bru.scala
@@ -24,34 +24,32 @@
}
}
-case class BruOp() {
- val JAL = 0
- val JALR = 1
- val BEQ = 2
- val BNE = 3
- val BLT = 4
- val BGE = 5
- val BLTU = 6
- val BGEU = 7
- val EBREAK = 8
- val ECALL = 9
- val EEXIT = 10
- val EYIELD = 11
- val ECTXSW = 12
- val MPAUSE = 13
- val MRET = 14
- val FENCEI = 15
- val UNDEF = 16
- val Entries = 17
+object BruOp extends ChiselEnum {
+ val JAL = Value
+ val JALR = Value
+ val BEQ = Value
+ val BNE = Value
+ val BLT = Value
+ val BGE = Value
+ val BLTU = Value
+ val BGEU = Value
+ val EBREAK = Value
+ val ECALL = Value
+ val EEXIT = Value
+ val EYIELD = Value
+ val ECTXSW = Value
+ val MPAUSE = Value
+ val MRET = Value
+ val FENCEI = Value
+ val UNDEF = Value
}
-class BruIO(p: Parameters) extends Bundle {
- val valid = Input(Bool())
- val fwd = Input(Bool())
- val op = Input(UInt(new BruOp().Entries.W))
- val pc = Input(UInt(p.programCounterBits.W))
- val target = Input(UInt(p.programCounterBits.W))
- val link = Input(UInt(5.W))
+class BruCmd(p: Parameters) extends Bundle {
+ val fwd = Bool()
+ val op = BruOp()
+ val pc = UInt(p.programCounterBits.W)
+ val target = UInt(p.programCounterBits.W)
+ val link = UInt(5.W)
}
class BranchTakenIO(p: Parameters) extends Bundle {
@@ -59,10 +57,34 @@
val value = Output(UInt(p.programCounterBits.W))
}
+class BranchState(p: Parameters) extends Bundle {
+ val fwd = Bool()
+ val op = BruOp()
+ val target = UInt(p.programCounterBits.W)
+ val linkValid = Bool()
+ val linkAddr = UInt(5.W)
+ val linkData = UInt(p.programCounterBits.W)
+ val pcEx = UInt(32.W)
+}
+
+object BranchState {
+ def default(p: Parameters): BranchState = {
+ val result = Wire(new BranchState(p))
+ result.fwd := false.B
+ result.op := BruOp.JAL
+ result.target := 0.U
+ result.linkValid := false.B
+ result.linkAddr := 0.U
+ result.linkData := 0.U
+ result.pcEx := 0.U
+ result
+ }
+}
+
class Bru(p: Parameters) extends Module {
val io = IO(new Bundle {
// Decode cycle.
- val req = new BruIO(p)
+ val req = Flipped(Valid(new BruCmd(p)))
// Execute cycle.
val csr = new CsrBruIO(p)
@@ -75,60 +97,44 @@
val iflush = Output(Bool())
})
- val branch = new BruOp()
-
+ // Interlock
val interlock = RegInit(false.B)
+ interlock := io.req.valid && io.req.bits.op.isOneOf(
+ BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD, BruOp.ECTXSW,
+ BruOp.MPAUSE, BruOp.MRET)
+ io.interlock := interlock
- val readRs = RegInit(false.B)
- val fwd = RegInit(false.B)
- val op = RegInit(0.U(branch.Entries.W))
- val target = Reg(UInt(p.programCounterBits.W))
- val linkValid = RegInit(false.B)
- val linkAddr = Reg(UInt(5.W))
- val linkData = Reg(UInt(p.programCounterBits.W))
- val pcEx = Reg(UInt(32.W))
-
- linkValid := io.req.valid && io.req.link =/= 0.U &&
- (io.req.op(branch.JAL) || io.req.op(branch.JALR))
-
- op := Mux(io.req.valid, io.req.op, 0.U)
- fwd := io.req.valid && io.req.fwd
-
- readRs := Mux(io.req.valid,
- io.req.op(branch.BEQ) || io.req.op(branch.BNE) ||
- io.req.op(branch.BLT) || io.req.op(branch.BGE) ||
- io.req.op(branch.BLTU) || io.req.op(branch.BGEU), false.B)
-
+ // Assign state
val mode = io.csr.out.mode // (0) machine, (1) user
- val pcDe = io.req.pc
- val pc4De = io.req.pc + 4.U
+ val pcDe = io.req.bits.pc
+ val pc4De = io.req.bits.pc + 4.U
- when (io.req.valid) {
- val mret = io.req.op(branch.MRET) && !mode
- val call = io.req.op(branch.MRET) && mode ||
- io.req.op(branch.EBREAK) ||
- io.req.op(branch.ECALL) ||
- io.req.op(branch.EEXIT) ||
- io.req.op(branch.EYIELD) ||
- io.req.op(branch.ECTXSW) ||
- io.req.op(branch.MPAUSE)
- target := Mux(mret, io.csr.out.mepc,
- Mux(call, io.csr.out.mtvec,
- Mux(io.req.fwd || io.req.op(branch.FENCEI), pc4De,
- Mux(io.req.op(branch.JALR), io.target.data,
- io.req.target))))
- linkAddr := io.req.link
- linkData := pc4De
- pcEx := pcDe
- }
+ val mret = (io.req.bits.op === BruOp.MRET) && !mode
+ val call = ((io.req.bits.op === BruOp.MRET) && mode) ||
+ io.req.bits.op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT,
+ BruOp.EYIELD, BruOp.ECTXSW, BruOp.MPAUSE)
- interlock := io.req.valid && (io.req.op(branch.EBREAK) ||
- io.req.op(branch.ECALL) || io.req.op(branch.EEXIT) ||
- io.req.op(branch.EYIELD) || io.req.op(branch.ECTXSW) ||
- io.req.op(branch.MPAUSE) || io.req.op(branch.MRET))
+ val stateReg = RegInit(MakeValid(false.B, BranchState.default(p)))
+ val nextState = Wire(new BranchState(p))
+ nextState.linkValid := io.req.valid && (io.req.bits.link =/= 0.U) &&
+ (io.req.bits.op.isOneOf(BruOp.JAL, BruOp.JALR))
- io.interlock := interlock
+ nextState.op := io.req.bits.op
+ nextState.fwd := io.req.valid && io.req.bits.fwd
+
+ nextState.linkAddr := io.req.bits.link
+ nextState.linkData := pc4De
+ nextState.pcEx := pcDe
+
+ nextState.target := MuxCase(io.req.bits.target, Seq(
+ mret -> io.csr.out.mepc,
+ call -> io.csr.out.mepc,
+ (io.req.bits.fwd || (io.req.bits.op === BruOp.FENCEI)) -> pc4De,
+ (io.req.bits.op === BruOp.JALR) -> io.target.data,
+ ))
+ stateReg.valid := io.req.valid
+ stateReg.bits := nextState
// This mux sits on the critical path.
// val rs1 = Mux(readRs, io.rs1.data, 0.U)
@@ -143,94 +149,83 @@
val ltu = rs1 < rs2
val geu = !ltu
- io.taken.valid := op(branch.EBREAK) && mode ||
- op(branch.ECALL) && mode ||
- op(branch.EEXIT) && mode ||
- op(branch.EYIELD) && mode ||
- op(branch.ECTXSW) && mode ||
- op(branch.MRET) && !mode ||
- op(branch.MRET) && mode || // fault
- op(branch.MPAUSE) && mode || // fault
- op(branch.FENCEI) ||
- (op(branch.JAL) ||
- op(branch.JALR) ||
- op(branch.BEQ) && eq ||
- op(branch.BNE) && neq ||
- op(branch.BLT) && lt ||
- op(branch.BGE) && ge ||
- op(branch.BLTU) && ltu ||
- op(branch.BGEU) && geu) =/= fwd
+ val op = stateReg.bits.op
- io.taken.value := target
+ io.taken.valid := stateReg.valid && MuxLookup(op, false.B)(Seq(
+ BruOp.EBREAK -> mode,
+ BruOp.ECALL -> mode,
+ BruOp.EEXIT -> mode,
+ BruOp.EYIELD -> mode,
+ BruOp.ECTXSW -> mode,
+ BruOp.MPAUSE -> mode, // fault
+ BruOp.MRET -> true.B, // fault if user mode.
+ BruOp.FENCEI -> true.B,
+ BruOp.JAL -> (true.B =/= stateReg.bits.fwd),
+ BruOp.JALR -> (true.B =/= stateReg.bits.fwd),
+ BruOp.BEQ -> (eq =/= stateReg.bits.fwd),
+ BruOp.BNE -> (neq =/= stateReg.bits.fwd),
+ BruOp.BLT -> (lt =/= stateReg.bits.fwd),
+ BruOp.BGE -> (ge =/= stateReg.bits.fwd),
+ BruOp.BLTU -> (ltu =/= stateReg.bits.fwd),
+ BruOp.BGEU -> (geu =/= stateReg.bits.fwd),
+ ))
+ io.taken.value := stateReg.bits.target
- io.rd.valid := linkValid
- io.rd.addr := linkAddr
- io.rd.data := linkData
+ io.rd.valid := stateReg.valid && stateReg.bits.linkValid
+ io.rd.addr := stateReg.bits.linkAddr
+ io.rd.data := stateReg.bits.linkData
// Undefined Fault.
- val undefFault = op(branch.UNDEF)
+ val undefFault = stateReg.valid && (op === BruOp.UNDEF)
// Usage Fault.
- val usageFault = op(branch.EBREAK) && !mode ||
- op(branch.ECALL) && !mode ||
- op(branch.EEXIT) && !mode ||
- op(branch.EYIELD) && !mode ||
- op(branch.ECTXSW) && !mode ||
- op(branch.MPAUSE) && mode ||
- op(branch.MRET) && mode
+ val usageFault = stateReg.valid && Mux(
+ mode, op.isOneOf(BruOp.MPAUSE, BruOp.MRET),
+ op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+ BruOp.ECTXSW))
- io.csr.in.mode.valid := op(branch.EBREAK) && mode ||
- op(branch.ECALL) && mode ||
- op(branch.EEXIT) && mode ||
- op(branch.EYIELD) && mode ||
- op(branch.ECTXSW) && mode ||
- op(branch.MPAUSE) && mode || // fault
- op(branch.MRET) && mode || // fault
- op(branch.MRET) && !mode
- io.csr.in.mode.bits := MuxOR(op(branch.MRET) && !mode, true.B)
+ io.csr.in.mode.valid := stateReg.valid && Mux(
+ mode, op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+ BruOp.ECTXSW, BruOp.MPAUSE, BruOp.MRET),
+ (op === BruOp.MRET))
+ io.csr.in.mode.bits := ((op === BruOp.MRET) && !mode)
- io.csr.in.mepc.valid := op(branch.EBREAK) && mode ||
- op(branch.ECALL) && mode ||
- op(branch.EEXIT) && mode ||
- op(branch.EYIELD) && mode ||
- op(branch.ECTXSW) && mode ||
- op(branch.MPAUSE) && mode || // fault
- op(branch.MRET) && mode // fault
- io.csr.in.mepc.bits := Mux(op(branch.EYIELD), linkData, pcEx)
+ io.csr.in.mepc.valid := stateReg.valid && mode &&
+ op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+ BruOp.ECTXSW, BruOp.MPAUSE, BruOp.MRET)
+ io.csr.in.mepc.bits := Mux(op === BruOp.EYIELD, stateReg.bits.linkData,
+ stateReg.bits.pcEx)
- io.csr.in.mcause.valid := undefFault || usageFault ||
- op(branch.EBREAK) && mode ||
- op(branch.ECALL) && mode ||
- op(branch.EEXIT) && mode ||
- op(branch.EYIELD) && mode ||
- op(branch.ECTXSW) && mode
+ io.csr.in.mcause.valid := stateReg.valid && (undefFault || usageFault ||
+ (mode && op.isOneOf(BruOp.EBREAK, BruOp.ECALL, BruOp.EEXIT, BruOp.EYIELD,
+ BruOp.ECTXSW)))
val faultMsb = 1.U << 31
- io.csr.in.mcause.bits := Mux(undefFault, 2.U | faultMsb,
- Mux(usageFault, 16.U | faultMsb,
- MuxOR(op(branch.EBREAK), 1.U) |
- MuxOR(op(branch.ECALL), 2.U) |
- MuxOR(op(branch.EEXIT), 3.U) |
- MuxOR(op(branch.EYIELD), 4.U) |
- MuxOR(op(branch.ECTXSW), 5.U)))
+ io.csr.in.mcause.bits := MuxCase(0.U, Seq(
+ undefFault -> (2.U | faultMsb),
+ usageFault -> (16.U | faultMsb),
+ (op === BruOp.EBREAK) -> 1.U,
+ (op === BruOp.ECALL) -> 2.U,
+ (op === BruOp.EEXIT) -> 3.U,
+ (op === BruOp.EYIELD) -> 4.U,
+ (op === BruOp.ECTXSW) -> 5.U,
+ ))
- io.csr.in.mtval.valid := undefFault || usageFault
- io.csr.in.mtval.bits := pcEx
+ io.csr.in.mtval.valid := stateReg.valid && (undefFault || usageFault)
+ io.csr.in.mtval.bits := stateReg.bits.pcEx
- io.iflush := op(branch.FENCEI)
+ io.iflush := stateReg.valid && (op === BruOp.FENCEI)
// Pipeline will be halted.
- io.csr.in.halt := op(branch.MPAUSE) && !mode || io.csr.in.fault
- io.csr.in.fault := undefFault && !mode || usageFault && !mode
+ io.csr.in.halt := (stateReg.valid && (op === BruOp.MPAUSE) && !mode) ||
+ io.csr.in.fault
+ io.csr.in.fault := (undefFault && !mode) || (usageFault && !mode)
// Assertions.
- val valid = RegInit(false.B)
- val ignore = op(branch.JAL) || op(branch.JALR) || op(branch.EBREAK) ||
- op(branch.ECALL) || op(branch.EEXIT) || op(branch.EYIELD) ||
- op(branch.ECTXSW) || op(branch.MPAUSE) || op(branch.MRET) ||
- op(branch.FENCEI) || op(branch.UNDEF)
+ val ignore = op.isOneOf(BruOp.JAL, BruOp.JALR, BruOp.EBREAK, BruOp.ECALL,
+ BruOp.EEXIT, BruOp.EYIELD, BruOp.ECTXSW, BruOp.MPAUSE,
+ BruOp.MRET, BruOp.FENCEI, BruOp.UNDEF)
- valid := io.req.valid
- assert(!(valid && !io.rs1.valid) || ignore)
- assert(!(valid && !io.rs2.valid) || ignore)
+ assert(!(stateReg.valid && !io.rs1.valid) || ignore)
+ assert(!(stateReg.valid && !io.rs2.valid) || ignore)
}
diff --git a/hdl/chisel/src/kelvin/scalar/Decode.scala b/hdl/chisel/src/kelvin/scalar/Decode.scala
index 206c516..a6d131f 100644
--- a/hdl/chisel/src/kelvin/scalar/Decode.scala
+++ b/hdl/chisel/src/kelvin/scalar/Decode.scala
@@ -191,7 +191,7 @@
val alu = Valid(new AluCmd)
// Branch interface.
- val bru = Flipped(new BruIO(p))
+ val bru = Valid(new BruCmd(p))
// CSR interface.
val csr = Valid(new CsrCmd)
@@ -306,33 +306,31 @@
io.alu.bits.op := alu.bits
// Branch conditional opcode.
- val bru = new BruOp()
- val bruOp = Wire(Vec(bru.Entries, Bool()))
- val bruValid = WiredOR(io.bru.op) // used without decodeEn
- io.bru.valid := decodeEn && bruValid
- io.bru.fwd := io.inst.brchFwd
- io.bru.op := bruOp.asUInt
- io.bru.pc := io.inst.addr
- io.bru.target := io.inst.addr + Mux(io.inst.inst(2), d.immjal, d.immbr)
- io.bru.link := rdAddr
-
- bruOp(bru.JAL) := d.jal
- bruOp(bru.JALR) := d.jalr
- bruOp(bru.BEQ) := d.beq
- bruOp(bru.BNE) := d.bne
- bruOp(bru.BLT) := d.blt
- bruOp(bru.BGE) := d.bge
- bruOp(bru.BLTU) := d.bltu
- bruOp(bru.BGEU) := d.bgeu
- bruOp(bru.EBREAK) := d.ebreak
- bruOp(bru.ECALL) := d.ecall
- bruOp(bru.EEXIT) := d.eexit
- bruOp(bru.EYIELD) := d.eyield
- bruOp(bru.ECTXSW) := d.ectxsw
- bruOp(bru.MPAUSE) := d.mpause
- bruOp(bru.MRET) := d.mret
- bruOp(bru.FENCEI) := d.fencei
- bruOp(bru.UNDEF) := d.undef
+ val bru = MuxCase(MakeValid(false.B, BruOp.JAL), Seq(
+ d.jal -> MakeValid(true.B, BruOp.JAL),
+ d.jalr -> MakeValid(true.B, BruOp.JALR),
+ d.beq -> MakeValid(true.B, BruOp.BEQ),
+ d.bne -> MakeValid(true.B, BruOp.BNE),
+ d.blt -> MakeValid(true.B, BruOp.BLT),
+ d.bge -> MakeValid(true.B, BruOp.BGE),
+ d.bltu -> MakeValid(true.B, BruOp.BLTU),
+ d.bgeu -> MakeValid(true.B, BruOp.BGEU),
+ d.ebreak -> MakeValid(true.B, BruOp.EBREAK),
+ d.ecall -> MakeValid(true.B, BruOp.ECALL),
+ d.eexit -> MakeValid(true.B, BruOp.EEXIT),
+ d.eyield -> MakeValid(true.B, BruOp.EYIELD),
+ d.ectxsw -> MakeValid(true.B, BruOp.ECTXSW),
+ d.mpause -> MakeValid(true.B, BruOp.MPAUSE),
+ d.mret -> MakeValid(true.B, BruOp.MRET),
+ d.fencei -> MakeValid(true.B, BruOp.FENCEI),
+ d.undef -> MakeValid(true.B, BruOp.UNDEF),
+ ))
+ io.bru.valid := decodeEn && bru.valid
+ io.bru.bits.fwd := io.inst.brchFwd
+ io.bru.bits.op := bru.bits
+ io.bru.bits.pc := io.inst.addr
+ io.bru.bits.target := io.inst.addr + Mux(io.inst.inst(2), d.immjal, d.immbr)
+ io.bru.bits.link := rdAddr
// CSR opcode.
val csr = MuxCase(MakeValid(false.B, CsrOp.CSRRW), Seq(
@@ -440,7 +438,7 @@
alu.valid || csr.valid || mlu.valid || dvu.valid && io.dvu.ready ||
lsu.valid && d.isLoad() ||
d.getvl || d.getmaxvl || vldst_wb ||
- bruValid && (bruOp(bru.JAL) || bruOp(bru.JALR)) && rdAddr =/= 0.U
+ bru.valid && (bru.bits.isOneOf(BruOp.JAL, BruOp.JALR)) && rdAddr =/= 0.U
// val scoreboard_spec = Mux(rdMark_valid || d.io.vst, UIntToOH(rdAddr, 32), 0.U) // TODO: why was d.io.vst included?
val scoreboard_spec = Mux(rdMark_valid, UIntToOH(rdAddr, 32), 0.U)