blob: d6de873b58077b412ba8c8ffe1b4794eccf3c6af [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
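// Fetch Unit: a 4-way fetcher that directly feeds the 4 decoders.
// The fetcher has a partial decoder to identify branches: backward conditional
// branches are predicted taken and forward conditional branches are predicted
// not taken; JAL is always redirected, and `ret` is redirected when the link
// register value is available.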
package kelvin
import chisel3._
import chisel3.util._
import common._
import _root_.circt.stage.ChiselStage
/** Companion factory for [[Fetch]]. */
object Fetch {
  /** Instantiates a [[Fetch]] module as a child of the module currently being elaborated.
    *
    * @param p core configuration parameters, forwarded unchanged to the [[Fetch]] constructor
    * @return the instantiated Fetch module
    */
  def apply(p: Parameters): Fetch = Module(new Fetch(p))
}
/** Instruction-bus port as seen from the fetch unit.
  *
  * The control (address) phase is a valid/ready handshake driven by the
  * fetcher; the read phase returns one fetch line on `rdata`.
  *
  * @param p core parameters supplying the address and data widths
  */
class IBusIO(p: Parameters) extends Bundle {
  // Control phase: request valid, bus ready, and the requested line address.
  val valid: Bool = Output(Bool())
  val ready: Bool = Input(Bool())
  val addr: UInt = Output(UInt(p.fetchAddrBits.W))
  // Read phase: returned line data.
  val rdata: UInt = Input(UInt(p.fetchDataBits.W))
}
/** One fetched-instruction lane handed from the fetch unit to a decoder.
  *
  * @param p core parameters supplying the program-counter and instruction widths
  */
class FetchInstruction(p: Parameters) extends Bundle {
  // Valid/ready handshake with the consuming decoder.
  val valid: Bool = Output(Bool())
  val ready: Bool = Input(Bool())
  // Program counter of this lane's instruction.
  val addr: UInt = Output(UInt(p.programCounterBits.W))
  // Raw instruction word.
  val inst: UInt = Output(UInt(p.instructionBits.W))
  // Set on the lane whose predecoded branch the fetcher itself redirected on.
  val brchFwd: Bool = Output(Bool())
}
/** All decoder-facing lanes of the fetch unit, one per `p.instructionLanes`.
  *
  * @param p core parameters supplying the lane count
  */
class FetchIO(p: Parameters) extends Bundle {
  val lanes: Vec[FetchInstruction] = Vec(p.instructionLanes, new FetchInstruction(p))
}
// Instruction fetch unit. Keeps a small direct-mapped line buffer ("L0 cache")
// filled over io.ibus and presents p.instructionLanes sequential instructions
// per cycle on io.inst. It redirects on execute-stage branches (io.branch) and
// on branches it predecodes itself (JAL, backward conditional branches, and
// `ret` when io.linkPort is valid).
class Fetch(p: Parameters) extends Module {
val io = IO(new Bundle {
// csr.value(0) supplies the reset program counter (see the reset block near the end).
val csr = new CsrInIO(p)
// Line-fill requests to the instruction bus.
val ibus = new IBusIO(p)
// Decoder-facing instruction lanes.
val inst = new FetchIO(p)
// Execute-stage branch redirects, one per lane; the lowest-index valid lane has priority.
val branch = Flipped(Vec(p.instructionLanes, new BranchTakenIO(p)))
// Link register value used as the predicted target of `ret`.
val linkPort = Flipped(new RegfileLinkPortIO)
// Flush request: invalidates every L0 line and drops in-flight request marks.
val iflush = Flipped(new IFlushIO(p))
})
// This is the only compiled and tested configuration (at this time).
// These compare Scala Ints, so they are checked at elaboration time.
assert(p.fetchAddrBits == 32)
assert(p.fetchDataBits == 256)
// Address request stage feeding io.ibus. Slice is a project helper; presumably a
// registered valid/ready pipeline slice — TODO confirm.
val aslice = Slice(UInt(p.fetchAddrBits.W), true)
// Address of the most recently accepted bus request, captured on the handshake.
val readAddr = Reg(UInt(p.fetchAddrBits.W))
// High the cycle after an accepted request: io.ibus.rdata is consumed then.
val readDataEn = RegInit(false.B)
val readAddrEn = io.ibus.valid && io.ibus.ready
val readData = io.ibus.rdata
// A flush in the handshake cycle suppresses the write of the returning data.
readDataEn := readAddrEn && !io.iflush.valid
// The flush is only acknowledged once no request is waiting at the slice output.
io.iflush.ready := !aslice.io.out.valid
// L0 cache
// ____________________________________
// | Tag |Index|xxxxx|
// ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
val lanes = p.fetchDataBits / p.instructionBits // input lanes
// Number of lines held in the L0 cache.
val indices = p.fetchCacheBytes * 8 / p.fetchDataBits
// Address bit ranges for the line index and tag; bits below indexLsb are the byte offset in a line.
val indexLsb = log2Ceil(p.fetchDataBits / 8)
val indexMsb = log2Ceil(indices) + indexLsb - 1
val tagLsb = indexMsb + 1
val tagMsb = p.fetchAddrBits - 1
// NOTE(review): only used by the assertion below; log2Ceil(indices - 1) differs
// from log2Ceil(indices) for some sizes — confirm which was intended.
val indexCountBits = log2Ceil(indices - 1)
if (p.fetchCacheBytes == 1024) {
assert(indexLsb == 5)
assert(indexMsb == 9)
assert(tagLsb == 10)
assert(tagMsb == 31)
assert(indices == 32)
assert(indexCountBits == 5)
assert(lanes == 8)
}
// Per-line state: valid bit, "request in flight" bit, tag, and line data.
val l0valid = RegInit(0.U(indices.W))
val l0req = RegInit(0.U(indices.W))
val l0tag = Reg(Vec(indices, UInt((tagMsb - tagLsb + 1).W)))
val l0data = Reg(Vec(indices, UInt(p.fetchDataBits.W)))
// Instruction outputs.
// NOTE(review): instAddr is sized with instructionBits, while the output port
// addr uses programCounterBits — confirm these are meant to be equal.
val instValid = RegInit(VecInit(Seq.fill(p.instructionLanes)(false.B)))
val instAddr = Reg(Vec(p.instructionLanes, UInt(p.instructionBits.W)))
val instBits = Reg(Vec(p.instructionLanes, UInt(p.instructionBits.W)))
// Line containing lane 0 (instAligned0) and the line after it (instAligned1).
val instAligned0 = Cat(instAddr(0)(31, indexLsb), 0.U(indexLsb.W))
val instAligned1 = instAligned0 + Cat(1.U, 0.U(indexLsb.W))
val instIndex0 = instAligned0(indexMsb, indexLsb)
val instIndex1 = instAligned1(indexMsb, indexLsb)
val instTag0 = instAligned0(tagMsb, tagLsb)
val instTag1 = instAligned1(tagMsb, tagLsb)
val l0valid0 = l0valid(instIndex0)
val l0valid1 = l0valid(instIndex1)
val l0tag0 = VecAt(l0tag, instIndex0)
val l0tag1 = VecAt(l0tag, instIndex1)
// L0 hit for the current line and for the next line.
val match0 = l0valid0 && instTag0 === l0tag0
val match1 = l0valid1 && instTag1 === l0tag1
// Read interface.
// Do not request entries that are already inflight.
// Perform a branch tag lookup to see if target is in cache.
// JAL-only predecoder, used here just to prefetch the jump target's line.
// Returns (isJal, addr + J-type immediate).
def Predecode(addr: UInt, op: UInt): (Bool, UInt) = {
val jal = DecodeBits(op, "xxxxxxxxxxxxxxxxxxxx_xxxxx_1101111")
val immed = Cat(Fill(12, op(31)), op(19,12), op(20), op(30,21), 0.U(1.W))
val target = addr + immed
(jal, target)
}
val preBranch = (0 until p.instructionLanes).map(x => Predecode(instAddr(x), instBits(x)))
val preBranchTakens = preBranch.map { case (taken, target) => taken }
val preBranchTargets = preBranch.map { case (taken, target) => target }
// Any currently presented (valid) lane holds a JAL.
val preBranchTaken = (0 until p.instructionLanes).map(i =>
io.inst.lanes(i).valid && preBranchTakens(i)).reduce(_ || _)
// Target of the lowest-index JAL lane (the last lane is the default).
val preBranchTarget = MuxCase(
preBranchTargets(p.instructionLanes - 1),
(0 until p.instructionLanes - 1).map(i => preBranchTakens(i) -> preBranchTargets(i))
)
val preBranchTag = preBranchTarget(tagMsb, tagLsb)
val preBranchIndex = preBranchTarget(indexMsb, indexLsb)
// L0 lookups for each execute-stage branch target and for the JAL target.
val branchTags = io.branch.map(x => x.value(tagMsb, tagLsb))
val branchIndices = io.branch.map(x => x.value(indexMsb, indexLsb))
val l0valids = (0 until p.instructionLanes).map(x => l0valid(branchIndices(x)))
val l0validP = l0valid(preBranchIndex)
val l0tags = (0 until p.instructionLanes).map(x => VecAt(l0tag, branchIndices(x)))
val l0tagP = VecAt(l0tag, preBranchIndex)
// Execute branch lane x needs a fill if its target line misses and is not already requested.
val reqBValid = (0 until p.instructionLanes).map(x =>
io.branch(x).valid && !l0req(branchIndices(x)) &&
(branchTags(x) =/= l0tags(x) || !l0valids(x)))
// prevValid(x) is the OR of io.branch(0..x-1).valid, so only the first valid branch lane may request.
val prevValid = io.branch.map(_.valid).scan(false.B)(_||_)
val reqs = (0 until p.instructionLanes).map(x => reqBValid(x) && !prevValid(x))
val reqP = preBranchTaken && !l0req(preBranchIndex) && (preBranchTag =/= l0tagP || !l0validP)
// Sequential fills for the current and the next line.
val req0 = !match0 && !l0req(instIndex0)
val req1 = !match1 && !l0req(instIndex1)
aslice.io.in.valid := (reqs ++ Seq(reqP, req0, req1)).reduce(_ || _) && !io.iflush.valid
// Request priority: execute branch > predecoded JAL > current line > next line.
// req1 has no explicit case; it is served by the instAligned1 default.
aslice.io.in.bits := MuxCase(instAligned1,
(0 until p.instructionLanes).map(x => reqs(x) -> Cat(io.branch(x).value(31,indexLsb), 0.U(indexLsb.W))) ++
Array(
reqP -> Cat(preBranchTarget(31,indexLsb), 0.U(indexLsb.W)),
req0 -> instAligned0,
)
)
when (readAddrEn) {
readAddr := io.ibus.addr
}
io.ibus.valid := aslice.io.out.valid
// During a flush the slice output is drained without issuing it to the bus.
aslice.io.out.ready := io.ibus.ready || io.iflush.valid
io.ibus.addr := aslice.io.out.bits
// NOTE(review): the original comment here read "initialize tags to 1s as
// 0xfffxxxxx are invalid instruction addresses", but l0tag is a plain Reg with
// no reset value; line validity is tracked by l0valid instead.
// Per-cycle set/clear masks for l0valid and l0req, applied together below.
val l0validClr = WireInit(0.U(indices.W))
val l0validSet = WireInit(0.U(indices.W))
val l0reqClr = WireInit(0.U(indices.W))
val l0reqSet = WireInit(0.U(indices.W))
val readIdx = readAddr(indexMsb, indexLsb)
// Line fill: write the returned data and its tag into the indexed entry.
for (i <- 0 until indices) {
when (readDataEn && readIdx === i.U) {
l0tag(i.U) := readAddr(tagMsb, tagLsb)
l0data(i.U) := readData
}
}
when (readDataEn) {
val bits = UIntToOH(readIdx, indices)
l0validSet := bits
l0reqClr := bits
}
when (io.iflush.valid) {
val clr = ~(0.U(l0validClr.getWidth.W))
l0validClr := clr
l0reqClr := clr
}
// Mark the requested line as in flight when a request enters the slice.
when (aslice.io.in.valid && aslice.io.in.ready) {
l0reqSet := UIntToOH(aslice.io.in.bits(indexMsb, indexLsb), indices)
}
// Clear wins over set when both target the same bit.
when (l0validClr =/= 0.U || l0validSet =/= 0.U) {
l0valid := (l0valid | l0validSet) & ~l0validClr
}
when (l0reqClr =/= 0.U || l0reqSet =/= 0.U) {
l0req := (l0req | l0reqSet) & ~l0reqClr
}
// Instruction Outputs
// Do not use the next instruction address directly in the lookup, as that
// creates excessive timing pressure. We know that the match is either on
// the old line or the next line, so can late mux on lookups of prior.
// Widen the arithmetic paths and select from results.
val fetchEn = Wire(Vec(p.instructionLanes, Bool()))
for (i <- 0 until p.instructionLanes) {
fetchEn(i) := io.inst.lanes(i).valid && io.inst.lanes(i).ready
}
// fsel(j) is set when exactly j lanes were consumed this cycle: fsel(0) (fselb)
// means none; fsel(x + 1) means lane x was consumed and no higher lane was.
val fsela = Cat((0 until p.instructionLanes).reverse.map(x =>
(x until p.instructionLanes).map(y =>
(if (y == x) { fetchEn(y) } else { !fetchEn(y) })
).reduce(_ && _)
))
val fselb = (0 until p.instructionLanes).map(x => !fetchEn(x)).reduce(_ && _)
val fsel = Cat(fsela, fselb)
// Entry k is the address k words past instAddr(0), so nxtInstAddr(i) is lane i's
// address after advancing by the number of consumed lanes.
val nxtInstAddrOffset = instAddr.map(x => x) ++ instAddr.map(x => x + (p.instructionLanes * 4).U)
val nxtInstAddr = (0 until p.instructionLanes).map(i =>
(0 until (p.instructionLanes + 1)).map(
j => MuxOR(fsel(j), nxtInstAddrOffset(j + i))).reduce(_|_))
val nxtInstIndex0 = nxtInstAddr(0)(indexMsb, indexLsb)
val nxtInstIndex1 = nxtInstAddr(p.instructionLanes - 1)(indexMsb, indexLsb)
// Forward the line returning this cycle when it is the current or next line.
val readFwd0 =
readDataEn && readAddr(31,indexLsb) === instAligned0(31,indexLsb)
val readFwd1 =
readDataEn && readAddr(31,indexLsb) === instAligned1(31,indexLsb)
val nxtMatch0Fwd = match0 || readFwd0
val nxtMatch1Fwd = match1 || readFwd1
// The next lines are the current or the following line, distinguished by the index LSB.
val nxtMatch0 =
Mux(instIndex0(0) === nxtInstIndex0(0), nxtMatch0Fwd, nxtMatch1Fwd)
val nxtMatch1 =
Mux(instIndex0(0) === nxtInstIndex1(0), nxtMatch0Fwd, nxtMatch1Fwd)
val nxtInstValid = Wire(Vec(p.instructionLanes, Bool()))
val nxtInstBits0 = Mux(readFwd0, readData, VecAt(l0data, instIndex0))
val nxtInstBits1 = Mux(readFwd1, readData, VecAt(l0data, instIndex1))
// 16 words: 0-7 from the current line, 8-15 from the next line.
// NOTE(review): 16/8/32 are hard-coded for 256-bit lines of 32-bit instructions.
val nxtInstBits = Wire(Vec(16, UInt(p.instructionBits.W)))
for (i <- 0 until 8) {
val offset = 32 * i
nxtInstBits(i + 0) := nxtInstBits0(31 + offset, offset)
nxtInstBits(i + 8) := nxtInstBits1(31 + offset, offset)
}
// Redirect helper for a fetcher-predicted target `value`: computes the lane
// addresses, per-lane L0 hit bits (no forwarding from readData), and the lane
// instruction words. Returns (valid, laneValidBits, laneAddrs, laneBits).
// NOTE(review): (4+x) and bit 5 assume 4 lanes and 8 words per line (indexLsb == 5).
def BranchMatchDe(valid: Bool, value: UInt):
(Bool, UInt, Vec[UInt], Vec[UInt]) = {
val addr = VecInit((0 until p.instructionLanes).map(x => value + (x * 4).U))
val match0 = l0valid(addr(0)(indexMsb,indexLsb)) &&
addr(0)(tagMsb,tagLsb) === VecAt(l0tag, addr(0)(indexMsb,indexLsb))
val match1 = l0valid(addr(p.instructionLanes - 1)(indexMsb,indexLsb)) &&
addr(p.instructionLanes - 1)(tagMsb,tagLsb) === VecAt(l0tag, addr(p.instructionLanes - 1)(indexMsb,indexLsb))
val vvalid = VecInit((0 until p.instructionLanes).reverse.map(x =>
Mux(addr(0)(2 + log2Ceil(p.instructionLanes),2) <= (4+x).U, match0, match1)))
val muxbits0 = VecAt(l0data, addr(0)(indexMsb,indexLsb))
val muxbits1 = VecAt(l0data, addr(p.instructionLanes - 1)(indexMsb,indexLsb))
val muxbits = Wire(Vec(16, UInt(p.instructionBits.W)))
for (i <- 0 until 8) {
val offset = 32 * i
muxbits(i + 0) := muxbits0(31 + offset, offset)
muxbits(i + 8) := muxbits1(31 + offset, offset)
}
val bits = Wire(Vec(p.instructionLanes, UInt(p.instructionBits.W)))
for (i <- 0 until p.instructionLanes) {
val idx = Cat(addr(0)(5) =/= addr(i)(5), addr(i)(4,2))
bits(i) := VecAt(muxbits, idx)
}
(valid, vvalid.asUInt, addr, bits)
}
// Same as BranchMatchDe but for execute-stage redirects: valid if any branch lane
// is valid, and the target of the lowest-index valid lane is used.
def BranchMatchEx(branch: Vec[BranchTakenIO]):
(Bool, UInt, Vec[UInt], Vec[UInt]) = {
val valid = branch.map(x => x.valid).reduce(_ || _)
val addr = VecInit((0 until branch.length).map(x =>
MuxCase(branch(branch.length - 1).value + (x * 4).U, (
(0 until branch.length - 1).map(y =>
branch(y).valid -> (branch(y).value + (x * 4).U)
)
))))
val match0 = l0valid(addr(0)(indexMsb,indexLsb)) &&
addr(0)(tagMsb,tagLsb) === VecAt(l0tag, addr(0)(indexMsb,indexLsb))
val match1 = l0valid(addr(branch.length - 1)(indexMsb,indexLsb)) &&
addr(branch.length - 1)(tagMsb,tagLsb) === VecAt(l0tag, addr(branch.length - 1)(indexMsb,indexLsb))
val vvalid = VecInit((0 until branch.length).reverse.map(x =>
Mux(addr(0)(2 + log2Ceil(branch.length),2) <= (4 + x).U, match0, match1)))
val muxbits0 = VecAt(l0data, addr(0)(indexMsb,indexLsb))
val muxbits1 = VecAt(l0data, addr(branch.length - 1)(indexMsb,indexLsb))
val muxbits = Wire(Vec(16, UInt(p.instructionBits.W)))
for (i <- 0 until 8) {
val offset = 32 * i
muxbits(i + 0) := muxbits0(31 + offset, offset)
muxbits(i + 8) := muxbits1(31 + offset, offset)
}
val bits = Wire(Vec(branch.length, UInt(p.instructionBits.W)))
for (i <- 0 until branch.length) {
val idx = Cat(addr(0)(5) =/= addr(i)(5), addr(i)(4,2))
bits(i) := VecAt(muxbits, idx)
}
(valid, vvalid.asUInt, addr, bits)
}
// Fetch-stage branch predecoder. Predicts taken for:
// - JAL (always);
// - `ret` (jalr x0, 0(x1)) when io.linkPort.valid, with target io.linkPort.value;
// - conditional branches with op(31) set (negative offset, i.e. backward),
//   excluding funct3 = 01x, which is not assigned to branches in RV32I.
// Returns (predictedTaken, target).
def PredecodeDe(addr: UInt, op: UInt): (Bool, UInt) = {
val jal = DecodeBits(op, "xxxxxxxxxxxxxxxxxxxx_xxxxx_1101111")
val ret = DecodeBits(op, "000000000000_00001_000_00000_1100111") &&
io.linkPort.valid
val bxx = DecodeBits(op, "xxxxxxx_xxxxx_xxxxx_xxx_xxxxx_1100011") &&
op(31) && op(14,13) =/= 1.U
val immjal = Cat(Fill(12, op(31)), op(19,12), op(20), op(30,21), 0.U(1.W))
val immbxx = Cat(Fill(20, op(31)), op(7), op(30,25), op(11,8), 0.U(1.W))
// Opcode bit 2 is set for JAL (1101111) and clear for branches (1100011).
val immed = Mux(op(2), immjal, immbxx)
val target = Mux(ret, io.linkPort.value, addr + immed)
(jal || ret || bxx, target)
}
val brchDe = (0 until p.instructionLanes).map(x => PredecodeDe(instAddr(x), instBits(x)))
val brchTakensDe = brchDe.map { case (taken, target) => taken }
val brchTargetsDe = brchDe.map { case (taken, target) => target }
// Redirect only when a predicted-taken lane is actually consumed this cycle.
val brchTakenDeOr = (0 until p.instructionLanes).map(x =>
io.inst.lanes(x).ready && io.inst.lanes(x).valid && brchTakensDe(x)
).reduce(_ || _)
val brchTargetDe = MuxCase(brchTargetsDe(p.instructionLanes - 1),
(0 until p.instructionLanes - 1).map(x => brchTakensDe(x) -> brchTargetsDe(x))
)
val (brchTakenDe, brchValidDe, brchAddrDe, brchBitsDe) =
BranchMatchDe(brchTakenDeOr, brchTargetDe)
val (brchTakenEx, brchValidEx, brchAddrEx, brchBitsEx) =
BranchMatchEx(io.branch)
// Lane x is presented only if no lower lane is a predicted-taken branch, so lanes
// after a predicted branch are suppressed.
val brchValidDeMask =
Cat((0 until p.instructionLanes).reverse.map(x =>
if (x == 0) { true.B } else {
(0 until x).map(y =>
!brchTakensDe(y)
).reduce(_ && _)
}
))
// One-hot flag on the lowest-index predicted-taken lane.
val brchFwd =
Cat((0 until p.instructionLanes).reverse.map(x =>
brchTakensDe(x) && (if (x == 0) { true.B } else { (0 until x).map(y => !brchTakensDe(y)).reduce(_ && _) })
))
// Next-state update. Priority: execute redirect > predecoded redirect > sequential.
// A lane is valid only if it and every lower lane hit, keeping valid lanes a
// contiguous prefix starting at lane 0.
for (i <- 0 until p.instructionLanes) {
// 1, 11, 111, ...
// NOTE(review): (7 - i) assumes 8 instructions per line.
nxtInstValid(i) := Mux(
nxtInstAddr(0)(4,2) <= (7 - i).U,
nxtMatch0,
nxtMatch1)
val nxtInstValidUInt = nxtInstValid.asUInt
instValid(i) := Mux(brchTakenEx, brchValidEx(i,0) === ~0.U((i+1).W),
Mux(brchTakenDe, brchValidDe(i,0) === ~0.U((i+1).W),
nxtInstValidUInt(i,0) === ~0.U((i+1).W))) && !io.iflush.valid
instAddr(i) := Mux(brchTakenEx, brchAddrEx(i),
Mux(brchTakenDe, brchAddrDe(i), nxtInstAddr(i)))
// The (2,0) bits are the offset within the base line plus the next line.
// The (3) bit of the index must factor the base difference of addresses
// instAddr and nxtInstAddr which are line aligned.
val idx = Cat(instAddr(0)(5) =/= nxtInstAddr(i)(5), nxtInstAddr(i)(4,2))
instBits(i) := Mux(brchTakenEx, brchBitsEx(i),
Mux(brchTakenDe, brchBitsDe(i),
VecAt(nxtInstBits, idx)))
}
// This pattern of separate when() blocks requires resets after the data.
// On reset, lanes start at the word-aligned csr.value(0) address.
when (reset.asBool) {
val addr = Cat(io.csr.value(0)(31,2), 0.U(2.W))
instAddr := (0 until p.instructionLanes).map(i => addr + (4 * i).U)
}
// Outputs
for (i <- 0 until p.instructionLanes) {
io.inst.lanes(i).valid := instValid(i) & brchValidDeMask(i)
io.inst.lanes(i).addr := instAddr(i)
io.inst.lanes(i).inst := instBits(i)
io.inst.lanes(i).brchFwd := brchFwd(i)
}
// Assertions.
// Lane addresses are consecutive words.
for (i <- 1 until p.instructionLanes) {
assert(instAddr(0) + (4 * i).U === instAddr(i))
}
// fsel is at most one-hot.
assert(fsel.getWidth == (p.instructionLanes + 1))
assert(PopCount(fsel) <= 1.U)
// Valid lanes and ready lanes must each form a contiguous prefix from lane 0.
val instValidUInt = instValid.asUInt
val instLanesReady = Cat((0 until p.instructionLanes).reverse.map(x => io.inst.lanes(x).ready))
for (i <- 0 until p.instructionLanes - 1) {
assert(!(!instValidUInt(i) && (instValidUInt(p.instructionLanes - 1, i + 1) =/= 0.U)))
assert(!(!instLanesReady(i) && (instLanesReady(p.instructionLanes - 1, i + 1) =/= 0.U)))
}
}
/** Command-line entry point that emits SystemVerilog for a default-configured [[Fetch]].
  *
  * Uses an explicit `main` instead of `extends App`, so the body does not rely on
  * the `App` trait's DelayedInit-based initialization.
  */
object EmitFetch {
  /** @param args command-line arguments, forwarded unchanged to ChiselStage */
  def main(args: Array[String]): Unit = {
    val p = new Parameters
    ChiselStage.emitSystemVerilogFile(new Fetch(p), args)
  }
}