blob: 82ca9f75db7da4b6013873c09327083bd6746682 [file] [log] [blame]
// Copyright 2023 Google LLC
package matcha
import chisel3._
import chisel3.util._
/** Factory that instantiates a [[Crossbar]] as a child of the module currently being elaborated. */
object Crossbar {
  /** Shorthand for `Module(new Crossbar(ports, addrbits, databits, idbits))`.
    *
    * @param ports    number of requester ports
    * @param addrbits byte-address width
    * @param databits data width in bits
    * @param idbits   request/response id width
    * @return the elaborated child module
    */
  def apply(ports: Int, addrbits: Int, databits: Int, idbits: Int): Crossbar =
    Module(new Crossbar(ports = ports, addrbits = addrbits, databits = databits, idbits = idbits))
}
/** Requester-side interface of one [[Crossbar]] port.
  *
  * Directions are declared from the requester's point of view; the Crossbar
  * instantiates a Vec of these `Flipped`. Field order affects the generated
  * port order and the `asUInt` layout, so do not reorder fields.
  *
  * @param addrbits width of the byte address `caddr`
  * @param databits width of `wdata`/`rdata`; `wmask` has `databits / 8` bits
  * @param idbits   width of the request id `cid` and the response id `rid`
  */
class CrossbarIO(addrbits: Int, databits: Int, idbits: Int) extends Bundle {
// Command. Transferred on a cycle where cvalid && cready; cready is driven by
// the Crossbar from its registered grant.
val cvalid = Output(Bool())
val cready = Input(Bool())
// true = write, false = read.
val cwrite = Output(Bool())
// Byte address; the Crossbar ignores the low log2Ceil(databits / 8) bits.
val caddr = Output(UInt(addrbits.W))
// Request id; echoed back on rid for reads.
val cid = Output(UInt(idbits.W))
// Write payload. Only captured by the Crossbar for write commands.
val wdata = Output(UInt(databits.W))
// One mask bit per 8-bit lane of wdata.
val wmask = Output(UInt((databits / 8).W))
// Read Response. rvalid is high for one cycle per read (writes get no
// response); rid/rdata are driven to zero while rvalid is low.
val rvalid = Input(Bool())
val rid = Input(UInt(idbits.W))
val rdata = Input(UInt(databits.W))
}
/** Single-ported SRAM command/data interface driven by the [[Crossbar]].
  *
  * Directions are from the Crossbar's (master's) point of view. Field order
  * affects the generated port order and the `asUInt` layout.
  *
  * @param addrbits width of the byte address `addr`
  * @param databits width of `wdata`/`rdata`; `wmask` has `databits / 8` bits
  */
class SramIO(addrbits: Int, databits: Int) extends Bundle {
// 1cc read response: the Crossbar samples rdata on the cycle after it presents
// valid && !write.
val valid = Output(Bool())
// true = write, false = read.
val write = Output(Bool())
// Byte address; the Crossbar always drives the low log2Ceil(databits / 8) bits to zero.
val addr = Output(UInt(addrbits.W))
val wdata = Output(UInt(databits.W))
// One mask bit per 8-bit lane of wdata.
val wmask = Output(UInt((databits / 8).W))
val rdata = Input(UInt(databits.W))
}
/** Multi-port arbiter in front of a single-ported SRAM.
  *
  * `ports` requesters share one [[SramIO]]. A fixed-priority arbiter (lowest
  * port index wins) grants one requester at a time. The accepted command is
  * registered and presented to the SRAM, and read data is registered and
  * returned to the originating port. Writes produce no response.
  *
  * Pipeline, relative to the cycle of the `cvalid && cready` handshake:
  *   - Cycle 0: handshake on the granted `io.in(i)`; command fields are registered.
  *   - Cycle 1: the registered command drives `io.out` (`io.out.valid`).
  *   - Cycle 2: SRAM `rdata` (1-cycle read latency) is registered.
  *   - Cycle 3: the read response appears on `io.in(i).rvalid/rid/rdata`.
  *
  * Arbitration: `cready` comes from the registered one-hot grant `csel0`, not
  * from the current `cvalid`s. The grant is re-evaluated on every cycle in
  * which any port is valid and is held otherwise. A newly-valid port that does
  * not already hold the grant therefore waits one cycle.
  * NOTE: with fixed priority, a continuously valid low-index port starves
  * higher-index ports.
  *
  * Ids: the returned `rid` is the `cid` sent with the read. The port index is
  * carried internally in the upper `log2Ceil(ports)` bits to route the response.
  *
  * @param ports    number of requester ports (>= 1)
  * @param addrbits byte-address width of `caddr` / `io.out.addr`
  * @param databits data width in bits (>= 8); `wmask` has `databits / 8` bits
  * @param idbits   width of the per-request id
  */
class Crossbar(ports: Int, addrbits: Int, databits: Int, idbits: Int) extends Module {
  // Both configurations already failed during elaboration (indexing an empty
  // Vec, log2Ceil(0)); fail early with a readable message instead.
  require(ports >= 1, s"Crossbar: ports must be >= 1, got $ports")
  require(databits >= 8, s"Crossbar: databits must be >= 8, got $databits")

  val io = IO(new Bundle {
    val in = Flipped(Vec(ports, new CrossbarIO(addrbits, databits, idbits)))
    val out = new SramIO(addrbits, databits)
  })

  // Internal id width: the port index is prepended to the requester's cid.
  val pidbits = idbits + log2Ceil(ports)
  // Byte-offset bits dropped from caddr (one SRAM row is databits / 8 bytes).
  val alsb = log2Ceil(databits / 8)
  val amsb = addrbits - 1
  // Row-index width forwarded to the SRAM.
  val indexbits = addrbits - alsb

  withReset(reset.asAsyncReset()) {
    // -------------------------------------------------------------------------
    // Arbitrate: registered one-hot grant, lowest valid port index wins.
    val cvalid = VecInit(io.in.map(_.cvalid))
    val csel0 = RegInit(1.U(ports.W))
    assert(PopCount(csel0) === 1.U)
    // Hold the previous grant when no port is requesting.
    when (cvalid.asUInt =/= 0.U) {
      csel0 := PriorityEncoderOH(cvalid.asUInt)
    }
    for ((port, i) <- io.in.zipWithIndex) {
      port.cready := csel0(i)
    }

    // -------------------------------------------------------------------------
    // Controls: a command is accepted when the granted port is valid.
    val cen0 = (cvalid.asUInt & csel0).orR

    // -------------------------------------------------------------------------
    // Command register (cycle 1).
    val cvalid1 = RegInit(false.B)
    val cwrite1 = RegInit(false.B)
    val cindex1 = Reg(UInt(indexbits.W))
    val cid1 = Reg(UInt(pidbits.W))
    val wdata1 = Reg(UInt(databits.W))
    val wmask1 = Reg(UInt((databits / 8).W))

    cvalid1 := cen0
    when (cen0) {
      // csel0 is one-hot (asserted above), so Mux1H forwards exactly the
      // granted port's fields.
      val cwriteNxt = Mux1H(csel0, io.in.map(_.cwrite))
      cwrite1 := cwriteNxt
      cindex1 := Mux1H(csel0, io.in.map(_.caddr(amsb, alsb)))
      // Tag the id with the port index so the read response can be routed back.
      cid1 := Mux1H(csel0, io.in.zipWithIndex.map { case (port, i) => Cat(i.U, port.cid) })
      // The write payload is only captured for writes; reads leave it unchanged.
      when (cwriteNxt) {
        wdata1 := Mux1H(csel0, io.in.map(_.wdata))
        wmask1 := Mux1H(csel0, io.in.map(_.wmask))
      }
    }

    io.out.valid := cvalid1
    io.out.write := cwrite1
    // Re-expand the row index to a byte address with zero byte-offset bits.
    // A shift (rather than Cat with 0.U(alsb.W)) also works when alsb == 0.
    io.out.addr := cindex1 << alsb
    io.out.wdata := wdata1
    io.out.wmask := wmask1
    // The byte-offset field only exists when alsb > 0 (databits >= 16).
    // Extracting addr(-1, 0) is an invalid bit range and would fail elaboration.
    if (alsb > 0) {
      assert(!(io.out.valid && io.out.addr(alsb - 1, 0) =/= 0.U))
    }

    // -------------------------------------------------------------------------
    // Read data (cycle 2): SRAM rdata is sampled one cycle after a read command.
    val rvalid2 = RegInit(false.B)
    val rid2 = Reg(UInt(pidbits.W))
    rvalid2 := cvalid1 && !cwrite1
    rid2 := cid1

    // -------------------------------------------------------------------------
    // Read response (cycle 3): route to the originating port.
    val rvalid3 = RegInit(VecInit(Seq.fill(ports)(false.B)))
    val rid3 = Reg(UInt(idbits.W))
    val rdata3 = Reg(UInt(databits.W))
    if (ports > 1) {
      for (i <- 0 until ports) {
        rvalid3(i) := rvalid2 && rid2(pidbits - 1, idbits) === i.U
      }
    } else {
      // Single port: pidbits == idbits, so there is no port-index field.
      rvalid3(0) := rvalid2
    }
    when (rvalid2) {
      rdata3 := io.out.rdata
      rid3 := rid2(idbits - 1, 0)
    }
    // Response fields are forced to zero on ports that are not being answered.
    for (i <- 0 until ports) {
      io.in(i).rvalid := rvalid3(i)
      io.in(i).rdata := Mux(rvalid3(i), rdata3, 0.U)
      io.in(i).rid := Mux(rvalid3(i), rid3, 0.U)
    }
  }
}
/** Command-line entry point that writes Verilog for the reference [[Crossbar]] configuration.
  *
  * Command-line arguments are forwarded unchanged to `ChiselStage.emitVerilog`.
  */
object EmitCrossbar extends App {
  // Reference configuration: 4 ports, 22-bit byte address, 256-bit rows, 8-bit ids.
  // 2^22 bytes = 4MB = 2^17 rows of 256 / 8 = 32 bytes each.
  private val stage = new chisel3.stage.ChiselStage
  stage.emitVerilog(new Crossbar(ports = 4, addrbits = 22, databits = 256, idbits = 8), args)
}