blob: 80a4f264fa81ee817c73b7eaa0606e72dcd158b3 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package kelvin
import chisel3._
import chisel3.util._
import kelvin.float.{CsrFloatIO}
/** CSR-side interface to the RVV vector core.
  *
  * The current values of vstart/vxrm/vxsat are supplied by the RVV core as
  * inputs; updates produced by CSR instructions are sent back as `Valid`
  * outputs.
  *
  * @param p core parameters; `p.rvvVlen` sizes the vstart fields.
  */
class CsrRvvIO(p: Parameters) extends Bundle {
  // Bit width of the vstart fields.
  private val vstartBits = log2Ceil(p.rvvVlen)

  // Current values, from the RVV core to the CSR unit.
  val vstart = Input(UInt(vstartBits.W))
  val vxrm = Input(UInt(2.W))
  val vxsat = Input(Bool())

  // CSR-instruction writes, from the CSR unit to the RVV core.
  val vstart_write = Output(Valid(UInt(vstartBits.W)))
  val vxrm_write = Output(Valid(UInt(2.W)))
  val vxsat_write = Output(Valid(Bool()))
}
/** Factory for the [[Csr]] module. */
object Csr {
  /** Instantiates a [[Csr]] module configured by `p`. */
  def apply(p: Parameters): Csr = Module(new Csr(p))
}
/** 12-bit addresses of the CSRs implemented by [[Csr]].
  *
  * Entries are listed in increasing address order, which ChiselEnum requires.
  */
object CsrAddress extends ChiselEnum {
  // Floating-point CSRs.
  val FFLAGS    = Value("h001".U(12.W))
  val FRM       = Value("h002".U(12.W))
  val FCSR      = Value("h003".U(12.W))
  // Vector (RVV) CSRs.
  val VSTART    = Value("h008".U(12.W))
  val VXSAT     = Value("h009".U(12.W))
  val VXRM      = Value("h00a".U(12.W))
  // Machine trap setup and handling.
  val MSTATUS   = Value("h300".U(12.W))
  val MISA      = Value("h301".U(12.W))
  val MIE       = Value("h304".U(12.W))
  val MTVEC     = Value("h305".U(12.W))
  val MSCRATCH  = Value("h340".U(12.W))
  val MEPC      = Value("h341".U(12.W))
  val MCAUSE    = Value("h342".U(12.W))
  val MTVAL     = Value("h343".U(12.W))
  // Trigger CSRs.
  val TSELECT   = Value("h7a0".U(12.W))
  val TDATA1    = Value("h7a1".U(12.W))
  val TDATA2    = Value("h7a2".U(12.W))
  val TINFO     = Value("h7a4".U(12.W))
  // Debug-mode CSRs.
  val DCSR      = Value("h7b0".U(12.W))
  val DPC       = Value("h7b1".U(12.W))
  val DSCRATCH0 = Value("h7b2".U(12.W))
  val DSCRATCH1 = Value("h7b3".U(12.W))
  // Custom read/write machine CSRs.
  val MCONTEXT0 = Value("h7c0".U(12.W))
  val MCONTEXT1 = Value("h7c1".U(12.W))
  val MCONTEXT2 = Value("h7c2".U(12.W))
  val MCONTEXT3 = Value("h7c3".U(12.W))
  val MCONTEXT4 = Value("h7c4".U(12.W))
  val MCONTEXT5 = Value("h7c5".U(12.W))
  val MCONTEXT6 = Value("h7c6".U(12.W))
  val MCONTEXT7 = Value("h7c7".U(12.W))
  val MPC       = Value("h7e0".U(12.W))
  val MSP       = Value("h7e1".U(12.W))
  // Machine counters (low and high halves).
  val MCYCLE    = Value("hb00".U(12.W))
  val MINSTRET  = Value("hb02".U(12.W))
  val MCYCLEH   = Value("hb80".U(12.W))
  val MINSTRETH = Value("hb82".U(12.W))
  // Vector register length in bytes.
  val VLENB     = Value("hc22".U(12.W))
  // Machine information registers.
  val MVENDORID = Value("hf11".U(12.W))
  val MARCHID   = Value("hf12".U(12.W))
  val MIMPID    = Value("hf13".U(12.W))
  val MHARTID   = Value("hf14".U(12.W))
  // Kelvin-specific read-only CSRs.
  val KISA      = Value("hfc0".U(12.W))
  val KSCM0     = Value("hfc4".U(12.W))
  val KSCM1     = Value("hfc8".U(12.W))
  val KSCM2     = Value("hfcc".U(12.W))
  val KSCM3     = Value("hfd0".U(12.W))
  val KSCM4     = Value("hfd4".U(12.W))
}
/** Operating mode tracked by [[Csr]], encoded in 2 bits. */
object CsrMode extends ChiselEnum {
  val Machine = Value("b00".U(2.W))
  val User    = Value("b01".U(2.W))
  val Debug   = Value("b10".U(2.W))
}
/** Debug control and status register (dcsr).
  *
  * For details, see The RISC-V Debug Specification v1.0, chapter 4.9.1.
  * The bundle stores only the defined fields. The reserved zero bits of the
  * 32-bit architectural layout are inserted by [[asWord]], so this bundle is
  * narrower than 32 bits.
  */
class Dcsr extends Bundle {
  val debugver = UInt(4.W)
  val extcause = UInt(3.W)
  val cetrig = Bool()
  val pelp = Bool()
  val ebreakvs = Bool()
  val ebreakvu = Bool()
  val ebreakm = Bool()
  val ebreaks = Bool()
  val ebreaku = Bool()
  val stepie = Bool()
  val stopcount = Bool()
  val stoptime = Bool()
  val cause = UInt(3.W)
  val v = Bool()
  val mprven = Bool()
  val nmip = Bool()
  val step = Bool()
  val prv = UInt(2.W)

  /** Packs the fields into the 32-bit architectural dcsr word. */
  def asWord: UInt = {
    val word = Cat(
      debugver,   // [31:28]
      0.U(1.W),   // [27]    reserved
      extcause,   // [26:24]
      0.U(4.W),   // [23:20] reserved
      cetrig,     // [19]
      pelp,       // [18]
      ebreakvs,   // [17]
      ebreakvu,   // [16]
      ebreakm,    // [15]
      0.U(1.W),   // [14]    reserved
      ebreaks,    // [13]
      ebreaku,    // [12]
      stepie,     // [11]
      stopcount,  // [10]
      stoptime,   // [9]
      cause,      // [8:6]
      v,          // [5]
      mprven,     // [4]
      nmip,       // [3]
      step,       // [2]
      prv,        // [1:0]
    )
    // Elaboration-time check that the layout adds up to one CSR word.
    assert(word.getWidth == 32)
    word
  }
}
/** Trigger data register 1 (tdata1).
  *
  * For details, see The RISC-V Debug Specification v1.0, chapter 5.7.2.
  */
class Tdata1 extends Bundle {
  val type_ = UInt(4.W)  // [31:28] trigger type
  val dmode = Bool()     // [27]
  val data = UInt(27.W)  // [26:0]  type-specific data

  /** Packs the fields into the 32-bit tdata1 word. */
  def asWord: UInt = {
    val word = Cat(type_, dmode, data)
    // Elaboration-time check that the layout adds up to one CSR word.
    assert(word.getWidth == 32)
    word
  }

  /** True when the configured trigger type is 6. */
  def isTrigger6: Bool = type_ === 6.U(4.W)
}
/** Per-cycle retirement counts that [[Csr]] accumulates into minstret.
  *
  * The vector-core counts are present only when `p.enableVector` is set.
  */
class CsrCounters(p: Parameters) extends Bundle {
  val rfwriteCount = UInt(3.W)
  val storeCount = UInt(2.W)
  val branchCount = UInt(1.W)
  val vrfwriteCount = Option.when(p.enableVector)(UInt(3.W))
  val vstoreCount = Option.when(p.enableVector)(UInt(2.W))
}
/** Interface between the branch unit and the CSR unit.
  *
  * Directions are declared from the branch unit's side; [[Csr]] instantiates
  * it flipped.
  *
  * `in` carries mode/trap-CSR updates and halt/fault/wfi requests into the
  * CSR unit. `out` returns the current mode, mepc and mtvec.
  */
class CsrBruIO(p: Parameters) extends Bundle {
  val in = new Bundle {
    val mode = Valid(CsrMode())
    val mcause = Valid(UInt(32.W))
    val mepc = Valid(UInt(32.W))
    val mtval = Valid(UInt(32.W))
    val halt = Output(Bool())
    val fault = Output(Bool())
    val wfi = Output(Bool())
  }
  val out = new Bundle {
    val mode = Input(CsrMode())
    val mepc = Input(UInt(32.W))
    val mtvec = Input(UInt(32.W))
  }

  /** Drives every `out` field with a default: Machine mode, zero mepc/mtvec. */
  def defaults(): Unit = {
    out.mtvec := 0.U
    out.mepc := 0.U
    out.mode := CsrMode.Machine
  }
}
/** Control and status register (CSR) unit.
  *
  * A CSR command arrives on `io.req` in the decode cycle and is registered.
  * In the following (execute) cycle, the addressed CSR is read onto `io.rd`
  * and updated from `io.rs1` according to the CSRRW/CSRRS/CSRRC op.
  *
  * The unit also holds:
  *  - the halt/fault/WFI pipeline-control state;
  *  - the trap CSRs updated by the branch unit (`io.bru`);
  *  - the mcycle/minstret counters;
  *  - the debug-mode and trigger CSRs, when `p.useDebugModule` is set.
  *
  * @param p core parameters selecting the optional float, RVV, vector and
  *          debug features.
  */
class Csr(p: Parameters) extends Module {
  val io = IO(new Bundle {
    // Reset and shutdown.
    val csr = new CsrInOutIO(p)
    // Decode cycle.
    val req = Flipped(Valid(new CsrCmd))
    // Execute cycle.
    val rs1 = Flipped(new RegfileReadDataIO)
    val rd = Valid(Flipped(new RegfileWriteDataIO))
    val bru = Flipped(new CsrBruIO(p))
    val float = Option.when(p.enableFloat) { Flipped(new CsrFloatIO(p)) }
    val rvv = Option.when(p.enableRvv) { new CsrRvvIO(p) }
    // Vector core.
    val vcore = (if (p.enableVector) {
      Some(Input(new Bundle { val undef = Bool() }))
    } else { None })
    val counters = Input(new CsrCounters(p))
    // Pipeline Control.
    val halted = Output(Bool())
    val fault = Output(Bool())
    val wfi = Output(Bool())
    val irq = Input(Bool())
    val dm = Option.when(p.useDebugModule)(new Bundle {
      val debug_req = Input(Bool())
      val resume_req = Input(Bool())
      val debug_mode = Output(Bool())
      val single_step = Output(Bool())
      val dcsr_step = Output(Bool())
      val next_pc = Input(UInt(32.W))
    })
    val trace = Option.when(p.useRetirementBuffer)(Output(new CsrTraceIO(p)))
  })

  /** Legalizes a software write to tdata1.
    *
    * Only trigger type 6 is accepted. Any other type is replaced by 15, and
    * the data field is gated by `MuxOR` on the type check, presumably to zero.
    * dmode is taken from bit 27 as written.
    */
  def LegalizeTdata1(wdata: UInt): Tdata1 = {
    assert(wdata.getWidth == 32)
    val newWdata = Wire(new Tdata1)
    val newType = wdata(31,28)
    val newTypeTrigger6 = (newType === 6.U(4.W))
    newWdata.type_ := Mux(newTypeTrigger6, newType, 15.U(4.W))
    newWdata.data := MuxOR(newTypeTrigger6, wdata(26,0))
    newWdata.dmode := wdata(27)
    newWdata
  }

  /** Legalizes a software write to dcsr.
    *
    * The [[Dcsr]] bundle omits the reserved bits ([27], [23:20], [14]) of the
    * architectural layout. `wdata.asTypeOf(new Dcsr)` would therefore
    * misalign every field above bit 13, so each field is extracted at its
    * architectural position (Debug Spec v1.0, 4.9.1).
    *
    * The read-only fields (debugver, extcause, cause, nmip) keep their
    * `current` values. The WARL field prv is restricted to the supported
    * levels: 0 (U) or 3 (M).
    *
    * @param wdata   the 32-bit value being written
    * @param current the present dcsr contents
    */
  def LegalizeDcsr(wdata: UInt, current: Dcsr): Dcsr = {
    assert(wdata.getWidth == 32)
    val newDcsr = Wire(new Dcsr)
    newDcsr.debugver := current.debugver
    newDcsr.extcause := current.extcause
    newDcsr.cetrig := wdata(19)
    newDcsr.pelp := wdata(18)
    newDcsr.ebreakvs := wdata(17)
    newDcsr.ebreakvu := wdata(16)
    newDcsr.ebreakm := wdata(15)
    newDcsr.ebreaks := wdata(13)
    newDcsr.ebreaku := wdata(12)
    newDcsr.stepie := wdata(11)
    newDcsr.stopcount := wdata(10)
    newDcsr.stoptime := wdata(9)
    newDcsr.cause := current.cause
    newDcsr.v := wdata(5)
    newDcsr.mprven := wdata(4)
    newDcsr.nmip := current.nmip
    newDcsr.step := wdata(2)
    newDcsr.prv := Mux(wdata(1, 0) === 0.U, 0.U(2.W), 3.U(2.W))
    newDcsr
  }

  // Control registers.
  // The decode-cycle request is registered so the access below lines up with
  // the execute-cycle io.rs1 / io.rd.
  val req = Pipe(io.req)

  // Pipeline Control.
  val halted = RegInit(false.B)
  val fault = RegInit(false.B)
  val wfi = RegInit(false.B)

  // Machine(0)/User(1)/Debug(2) Mode.
  val mode = RegInit(CsrMode.Machine)

  // CSRs parallel loaded when(reset).
  val mpc = RegInit(0.U(32.W))
  val msp = RegInit(0.U(32.W))
  val mcause = RegInit(0.U(32.W))
  val mtval = RegInit(0.U(32.W))
  val mcontext0 = RegInit(0.U(32.W))
  val mcontext1 = RegInit(0.U(32.W))
  val mcontext2 = RegInit(0.U(32.W))
  val mcontext3 = RegInit(0.U(32.W))
  val mcontext4 = RegInit(0.U(32.W))
  val mcontext5 = RegInit(0.U(32.W))
  val mcontext6 = RegInit(0.U(32.W))
  val mcontext7 = RegInit(0.U(32.W))

  // Debug mode CSRs
  val dcsr = Option.when(p.useDebugModule)(RegInit(0.U.asTypeOf(new Dcsr)))
  val dpc = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
  val dscratch0 = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
  val dscratch1 = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))

  // Trigger CSRs
  val tselect = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
  val tdata1 = Option.when(p.useDebugModule)(RegInit(0.U.asTypeOf(new Tdata1)))
  val tdata2 = Option.when(p.useDebugModule)(RegInit(0.U(32.W)))
  /* For details, see The RISC-V Debug Specification v1.0, chapter 5.7.5 */
  val tinfo = Option.when(p.useDebugModule)(RegInit(0x01000040.U(32.W)))

  // CSRs with initialization.
  val fflags = RegInit(0.U(5.W))
  val frm = RegInit(0.U(3.W))
  val mie = RegInit(0.U(1.W))
  val mtvec = RegInit(0.U(32.W))
  val mscratch = RegInit(0.U(32.W))
  val mepc = RegInit(0.U(32.W))
  val mpp = RegInit(0.U(2.W))
  val mhartid = RegInit(p.hartId.U(32.W))
  val mcycle = RegInit(0.U(64.W))
  val minstret = RegInit(0.U(64.W))

  // 32-bit MXLEN, I,M,X extensions
  val misa = RegInit(((
    0x40001100 |
    (if (p.enableVector) { 1 << 23 /* 'X' */ } else { 0 }) |
    (if (p.enableRvv) { 1 << 21 /* 'V' */ } else { 0 }) |
    (if (p.enableFloat) { 1 << 5 /* 'F' */ } else { 0 })
  ).U)(32.W))

  // Kelvin-specific ISA register.
  val kisa = RegInit(0.U(32.W))

  // SCM Revision (spread over 5 indices)
  val kscm = RegInit(((new ScmInfo).revision).U(160.W))

  // 0x426 - Google's Vendor ID
  val mvendorid = RegInit(0x426.U(32.W))

  // Unimplemented -- explicitly return zero.
  val marchid = RegInit(0.U(1.W))
  val mimpid = RegInit(0.U(1.W))

  // fcsr is a view of frm and fflags, not separate state.
  val fcsr = Cat(frm, fflags)

  // Decode the Index.
  val (csr_address, csr_address_valid) = CsrAddress.safe(req.bits.index)
  assert(!(req.valid && !csr_address_valid))

  val fflagsEn = csr_address === CsrAddress.FFLAGS
  val frmEn = csr_address === CsrAddress.FRM
  val fcsrEn = csr_address === CsrAddress.FCSR
  val vstartEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VSTART }
  val vxrmEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VXRM }
  val vxsatEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VXSAT }
  val mstatusEn = csr_address === CsrAddress.MSTATUS
  val misaEn = csr_address === CsrAddress.MISA
  val mieEn = csr_address === CsrAddress.MIE
  val mtvecEn = csr_address === CsrAddress.MTVEC
  val mscratchEn = csr_address === CsrAddress.MSCRATCH
  val mepcEn = csr_address === CsrAddress.MEPC
  val mcauseEn = csr_address === CsrAddress.MCAUSE
  val mtvalEn = csr_address === CsrAddress.MTVAL
  // Debug CSRs.
  val tselectEn = Option.when(p.useDebugModule)(csr_address === CsrAddress.TSELECT)
  val tdata1En = Option.when(p.useDebugModule)(csr_address === CsrAddress.TDATA1)
  val tdata2En = Option.when(p.useDebugModule)(csr_address === CsrAddress.TDATA2)
  val tinfoEn = Option.when(p.useDebugModule)(csr_address === CsrAddress.TINFO)
  val dcsrEn = Option.when(p.useDebugModule)(csr_address === CsrAddress.DCSR)
  val dpcEn = Option.when(p.useDebugModule)(csr_address === CsrAddress.DPC)
  val dscratch0En = Option.when(p.useDebugModule)(csr_address === CsrAddress.DSCRATCH0)
  val dscratch1En = Option.when(p.useDebugModule)(csr_address === CsrAddress.DSCRATCH1)
  val mcontext0En = csr_address === CsrAddress.MCONTEXT0
  val mcontext1En = csr_address === CsrAddress.MCONTEXT1
  val mcontext2En = csr_address === CsrAddress.MCONTEXT2
  val mcontext3En = csr_address === CsrAddress.MCONTEXT3
  val mcontext4En = csr_address === CsrAddress.MCONTEXT4
  val mcontext5En = csr_address === CsrAddress.MCONTEXT5
  val mcontext6En = csr_address === CsrAddress.MCONTEXT6
  val mcontext7En = csr_address === CsrAddress.MCONTEXT7
  val mpcEn = csr_address === CsrAddress.MPC
  val mspEn = csr_address === CsrAddress.MSP
  // M-mode performance CSRs.
  val mcycleEn = csr_address === CsrAddress.MCYCLE
  val minstretEn = csr_address === CsrAddress.MINSTRET
  val mcyclehEn = csr_address === CsrAddress.MCYCLEH
  val minstrethEn = csr_address === CsrAddress.MINSTRETH
  // Vector CSRs.
  val vlenbEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VLENB }
  // M-mode information CSRs.
  val mvendoridEn = csr_address === CsrAddress.MVENDORID
  val marchidEn = csr_address === CsrAddress.MARCHID
  val mimpidEn = csr_address === CsrAddress.MIMPID
  val mhartidEn = csr_address === CsrAddress.MHARTID
  // Start of custom CSRs.
  val kisaEn = csr_address === CsrAddress.KISA
  val kscm0En = csr_address === CsrAddress.KSCM0
  val kscm1En = csr_address === CsrAddress.KSCM1
  val kscm2En = csr_address === CsrAddress.KSCM2
  val kscm3En = csr_address === CsrAddress.KSCM3
  val kscm4En = csr_address === CsrAddress.KSCM4

  // Pipeline Control.
  // halted/fault are sticky: once set, nothing in this module clears them.
  val vcoreUndef = if (p.enableVector) { io.vcore.get.undef } else { false.B }
  when (io.bru.in.halt || vcoreUndef) {
    halted := true.B
  }
  when (io.bru.in.fault || vcoreUndef) {
    fault := true.B
  }
  // WFI is entered on request from the BRU and held until io.irq is seen.
  wfi := Mux(wfi, !io.irq, io.bru.in.wfi)
  io.halted := halted
  io.fault := fault
  io.wfi := wfi
  assert(!(io.fault && !io.halted && !io.wfi))

  // Register state.
  val rs1 = io.rs1.data

  // Read value of the addressed CSR (zero if no enable matches).
  val rdata = MuxCase(0.U(32.W), Seq(
      fflagsEn -> Cat(0.U(27.W), fflags),
      frmEn -> Cat(0.U(29.W), frm),
      fcsrEn -> Cat(0.U(24.W), fcsr),
      mstatusEn -> Cat(0.U(19.W), mpp, 0.U(11.W)),
      misaEn -> misa,
      mieEn -> Cat(0.U(31.W), mie),
      mtvecEn -> mtvec,
      mscratchEn -> mscratch,
      mepcEn -> mepc,
      mcauseEn -> mcause,
      mtvalEn -> mtval,
      mcontext0En -> mcontext0,
      mcontext1En -> mcontext1,
      mcontext2En -> mcontext2,
      mcontext3En -> mcontext3,
      mcontext4En -> mcontext4,
      mcontext5En -> mcontext5,
      mcontext6En -> mcontext6,
      mcontext7En -> mcontext7,
      mpcEn -> mpc,
      mspEn -> msp,
      mcycleEn -> mcycle(31,0),
      mcyclehEn -> mcycle(63,32),
      minstretEn -> minstret(31,0),
      minstrethEn -> minstret(63,32),
      mvendoridEn -> mvendorid,
      marchidEn -> Cat(0.U(31.W), marchid),
      mimpidEn -> Cat(0.U(31.W), mimpid),
      mhartidEn -> mhartid,
      kisaEn -> kisa,
      kscm0En -> kscm(31,0),
      kscm1En -> kscm(63,32),
      kscm2En -> kscm(95,64),
      kscm3En -> kscm(127,96),
      kscm4En -> kscm(159,128),
    ) ++
    Option.when(p.enableRvv) {
      Seq(
        vstartEn.get -> io.rvv.get.vstart,
        vxrmEn.get -> io.rvv.get.vxrm,
        vxsatEn.get -> io.rvv.get.vxsat,
        // Vector length in bytes, derived from the configured VLEN (bits)
        // instead of a hard-coded constant.
        vlenbEn.get -> (p.rvvVlen / 8).U(32.W),
      )
    }.getOrElse(Seq()) ++
    Option.when(p.useDebugModule) {
      Seq(
        tselectEn.get -> tselect.get,
        tdata1En.get -> tdata1.get.asWord,
        tdata2En.get -> tdata2.get,
        tinfoEn.get -> tinfo.get,
        dcsrEn.get -> dcsr.get.asWord,
        dpcEn.get -> dpc.get,
        dscratch0En.get -> dscratch0.get,
        dscratch1En.get -> dscratch1.get,
      )
    }.getOrElse(Seq())
  )

  // Value to write back for the CSR op (read-modify-write for set/clear).
  val wdata = MuxLookup(req.bits.op, 0.U)(Seq(
    CsrOp.CSRRW -> rs1,
    CsrOp.CSRRS -> (rdata | rs1),
    CsrOp.CSRRC -> (rdata & ~rs1)
  ))

  when (req.valid) {
    when (fflagsEn) { fflags := wdata }
    when (frmEn) { frm := wdata }
    when (fcsrEn) { fflags := wdata(4,0)
                    frm := wdata(7,5) }
    when (mstatusEn) { mpp := wdata(12,11) }
    when (mieEn) { mie := wdata }
    when (mtvecEn) { mtvec := wdata }
    when (mscratchEn) { mscratch := wdata }
    when (mepcEn) { mepc := wdata }
    when (mcauseEn) { mcause := wdata }
    when (mtvalEn) { mtval := wdata }
    when (mpcEn) { mpc := wdata }
    when (mspEn) { msp := wdata }
    when (mcontext0En) { mcontext0 := wdata }
    when (mcontext1En) { mcontext1 := wdata }
    when (mcontext2En) { mcontext2 := wdata }
    when (mcontext3En) { mcontext3 := wdata }
    when (mcontext4En) { mcontext4 := wdata }
    when (mcontext5En) { mcontext5 := wdata }
    when (mcontext6En) { mcontext6 := wdata }
    when (mcontext7En) { mcontext7 := wdata }
    if (p.useDebugModule) {
      when (dscratch0En.get) { dscratch0.get := wdata }
      when (dscratch1En.get) { dscratch1.get := wdata }
      when (tdata1En.get) { tdata1.get := LegalizeTdata1(wdata) }
      when (tdata2En.get) { tdata2.get := wdata }
    }
  }

  // Vector CSR writes are forwarded to the RVV core.
  if (p.enableRvv) {
    io.rvv.get.vstart_write.valid := req.valid && vstartEn.get
    io.rvv.get.vstart_write.bits := wdata(log2Ceil(p.rvvVlen)-1, 0)
    io.rvv.get.vxrm_write.valid := req.valid && vxrmEn.get
    io.rvv.get.vxrm_write.bits := wdata(1,0)
    io.rvv.get.vxsat_write.valid := req.valid && vxsatEn.get
    io.rvv.get.vxsat_write.bits := wdata(0)
  }

  // mcycle implementation
  // If one of the enable signals for
  // the register are true, overwrite the enabled half
  // of the register.
  // Increment the value of mcycle by 1.
  val mcycle_th = Mux(mcyclehEn, wdata, mcycle(63,32))
  val mcycle_tl = Mux(mcycleEn, wdata, mcycle(31,0))
  val mcycle_t = Cat(mcycle_th, mcycle_tl)
  mcycle := Mux(req.valid, mcycle_t, mcycle) + 1.U

  // minstret implementation
  val minstret_th = Mux(minstrethEn, wdata, minstret(63,32))
  val minstret_tl = Mux(minstretEn, wdata, minstret(31,0))
  val minstret_t = Cat(minstret_th, minstret_tl)
  // Width-expanding adds (+&) keep the per-cycle total from wrapping when
  // several counters are non-zero in the same cycle.
  val minstretThisCycle = io.counters.rfwriteCount +&
      io.counters.storeCount +&
      io.counters.branchCount +&
      (if (p.enableVector) {
        io.counters.vrfwriteCount.get +&
          io.counters.vstoreCount.get
      } else { 0.U })
  // Only an explicit write to minstret/minstreth replaces the counter. An
  // access to any other CSR must still count this cycle's retirements.
  val minstretWrite = req.valid && (minstretEn || minstrethEn)
  minstret := Mux(minstretWrite, minstret_t, minstret + minstretThisCycle)

  if (p.useDebugModule) {
    val trigger_enabled = tdata1.get.isTrigger6
    val trigger_match = (trigger_enabled && io.dm.get.next_pc === tdata2.get)
    val entering_debug_mode = (mode =/= CsrMode.Debug) && (io.dm.get.debug_req || trigger_match)
    val exiting_debug_mode = (mode === CsrMode.Debug) && (io.dm.get.resume_req)
    // On resume, return to the privilege level recorded in dcsr.prv (Debug
    // Spec v1.0, 4.9.1) rather than always to Machine mode, which would
    // escalate a hart halted from User mode. prv holds 3 (M) or 0 (U): it is
    // set on entry below and legalized by LegalizeDcsr on writes.
    val resume_mode = Mux(dcsr.get.prv === 0.U, CsrMode.User, CsrMode.Machine)
    mode := MuxCase(mode, Seq(
      entering_debug_mode -> CsrMode.Debug,
      exiting_debug_mode -> resume_mode,
      io.bru.in.mode.valid -> io.bru.in.mode.bits,
    ))
    io.dm.get.debug_mode := (mode === CsrMode.Debug)
    dcsr.get := MuxCase(dcsr.get, Seq(
      entering_debug_mode -> {
        val newDcsr = Wire(new Dcsr)
        newDcsr := dcsr.get
        newDcsr.extcause := 0.U
        val causeWidth = newDcsr.cause.getWidth.W
        // A trigger has the highest debug-entry priority (Debug Spec v1.0,
        // 4.9.1), so it is checked first. While dcsr.step is set, a debug
        // request is reported as a step.
        newDcsr.cause := MuxCase(7.U(causeWidth), Seq(
          trigger_match -> 2.U(causeWidth),
          (io.dm.get.debug_req && !io.dm.get.dcsr_step) -> 3.U(causeWidth),
          io.dm.get.dcsr_step -> 4.U(causeWidth),
        ))
        // Privilege level to resume into: 3 = M, 0 = U.
        newDcsr.prv := Mux(mode === CsrMode.Machine, 3.U(2.W), 0.U(2.W))
        newDcsr
      },
      (req.valid && dcsrEn.get) -> LegalizeDcsr(wdata, dcsr.get),
    ))
    // dpc captures next_pc on debug entry and otherwise holds its value,
    // unless a CSR write updates it; it must not follow next_pc every cycle.
    dpc.get := MuxCase(dpc.get, Seq(
      entering_debug_mode -> io.dm.get.next_pc,
      (req.valid && dpcEn.get) -> wdata,
    ))
    io.dm.get.dcsr_step := dcsr.get.step
    io.dm.get.single_step := trigger_enabled
  } else {
    when (io.bru.in.mode.valid) {
      mode := io.bru.in.mode.bits
    }
  }

  // High bit of mcause is set for an external interrupt.
  val interrupt = mcause(31)

  // Trap CSR updates from the branch unit. These are connected after the CSR
  // write block, so they take precedence in the same cycle.
  when (io.bru.in.mcause.valid) {
    mcause := io.bru.in.mcause.bits
  }

  when (io.bru.in.mtval.valid) {
    mtval := io.bru.in.mtval.bits
  }

  when (io.bru.in.mepc.valid) {
    mepc := io.bru.in.mepc.bits
  }

  // Floating-point exception flags accumulate (OR) into fflags.
  if (p.enableFloat) {
    when (io.float.get.in.fflags.valid) {
      fflags := io.float.get.in.fflags.bits | fflags
    }
  }

  // Forwarding.
  // Same-cycle CSR writes are bypassed to consumers.
  io.bru.out.mode := mode
  io.bru.out.mepc := Mux(mepcEn && req.valid, wdata, mepc)
  io.bru.out.mtvec := Mux(mtvecEn && req.valid, wdata, mtvec)
  if (p.enableFloat) {
    // frm can be written directly or through fcsr (bits 7:5).
    io.float.get.out.frm := MuxCase(frm, Seq(
      (frmEn && req.valid) -> wdata(2,0),
      (fcsrEn && req.valid) -> wdata(7,5),
    ))
  }

  io.csr.out.value(0) := io.csr.in.value(12)
  io.csr.out.value(1) := mepc
  io.csr.out.value(2) := mtval
  io.csr.out.value(3) := mcause
  io.csr.out.value(4) := mcycle(31,0)
  io.csr.out.value(5) := mcycle(63,32)
  io.csr.out.value(6) := minstret(31,0)
  io.csr.out.value(7) := minstret(63,32)

  // Write port.
  // rd receives the pre-update CSR value.
  io.rd.valid := req.valid
  io.rd.bits.addr := req.bits.addr
  io.rd.bits.data := rdata

  if (p.useRetirementBuffer) {
    io.trace.get.valid := req.valid
    io.trace.get.addr := req.bits.index
    io.trace.get.data := wdata
  }

  // Assertions.
  assert(!(req.valid && !io.rs1.valid))
}