blob: cb9f419393d86cbb0334f21d7ce56bef629587e0 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kelvin
import chisel3._
import chisel3.util._
import common.Fifo4x4
// Factory for the vector decode stage.
object VDecode {
  // Instantiates a VDecode module in the current Chisel module context.
  //   p: kelvin core parameters, forwarded unchanged to the module.
  //   returns: the instantiated VDecode module.
  // The method body is a single expression, so no explicit `return` is needed.
  def apply(p: Parameters): VDecode = Module(new VDecode(p))
}
// Vector instruction decode and dispatch stage.
//
// Up to 4 instruction lanes per cycle enter a Fifo4x4 from io.in. The four
// FIFO head entries are decoded by four VDecodeInstruction instances, and
// their vs..vz operands are tagged (see TagAddr). Up to 4 decoded
// instructions are held in registers (data/cmdq/actv2). Held entries are
// kept contiguous from slot 0 and are offered on io.out in order. Slot i is
// offered only when slots 0..i are all free of register hazards against
// io.vrfsb, io.active and older held entries.
class VDecode(p: Parameters) extends Module {
val io = IO(new Bundle {
// Up to 4 instruction lanes per transfer; each lane carries its own valid.
val in = Flipped(Decoupled(Vec(4, Valid(new VectorInstructionLane))))
// Decoded instructions. io.out(i).valid implies io.out(0..i-1).valid.
val out = Vec(4, Decoupled(new VDecodeBits))
// Command-queue flags for the instruction in the matching out slot.
val cmdq = Vec(4, Output(new VDecodeCmdq))
val actv = Vec(4, Output(new VDecodeActive)) // used in testbench
// High when FIFO occupancy plus this cycle's incoming lanes exceeds depth - guard.
val stall = Output(Bool())
// 64-bit per-register mask OR'ed into the write-hazard mask (see wactive0).
val active = Input(UInt(64.W))
// Register-file scoreboard: `data` is read for hazard checks and `set`
// records the write masks of instructions dispatched this cycle.
val vrfsb = new VRegfileScoreboardIO
// High when io.in is valid and any of the four decoders reports undef.
val undef = Output(Bool())
// Registered "stage not empty" status.
val nempty = Output(Bool())
})
val guard = 8 // two cycles of 4-way dispatch
// FIFO depth. io.stall is raised once occupancy would exceed depth - guard.
// The guard entries are presumably headroom for lanes that still arrive
// after stall is raised.
val depth = 16 + guard
// NOTE(review): `enc` is not referenced anywhere in this class.
val enc = new VEncodeOp()
// 4-wide FIFO of incoming instruction lanes.
val f = Fifo4x4(new VectorInstructionLane, depth)
// One decoder per FIFO output lane.
val d = Seq(Module(new VDecodeInstruction(p)),
Module(new VDecodeInstruction(p)),
Module(new VDecodeInstruction(p)),
Module(new VDecodeInstruction(p)))
// Decoder outputs with the vs..vz operands replaced by their tagged versions.
val e = Wire(Vec(4, new VDecodeBits))
// Occupancy of the 4 held slots (kept contiguous from slot 0).
val valid = RegInit(VecInit(Seq.fill(4)(false.B)))
// Held decoded instruction and its command-queue flags.
val data = Reg(Vec(4, new VDecodeBits))
val cmdq = Reg(Vec(4, new VDecodeCmdq))
// actv: actv2 with the two 64-bit wactive tag halves OR'ed into one mask.
val actv = Wire(Vec(4, new VDecodeActive))
// Held read mask and tagged (128-bit) write mask per slot.
val actv2 = Reg(Vec(4, new VDecodeActive2))
// Next-state values for data/cmdq/actv2, chosen by the compaction muxes below.
val dataNxt = Wire(Vec(4, new VDecodeBits))
val cmdqNxt = Wire(Vec(4, new VDecodeCmdq))
val actvNxt = Wire(Vec(4, new VDecodeActive2))
// ---------------------------------------------------------------------------
// Decode.
for (i <- 0 until 4) {
d(i).io.in := f.io.out(i).bits
}
// ---------------------------------------------------------------------------
// Apply "out-of-order" tags to read/write registers.
// Since only one write may be outstanding, track using 1bit which side of
// write the read usage is occurring on.
// tagReg holds one tag bit per vector register (64 bits, one per 6-bit
// address). tags(i) is the state seen by FIFO lane i: tagReg toggled by
// the write masks of lanes 0..i-1. tags(4) includes all four lanes.
val tagReg = RegInit(0.U(64.W))
val tag0 = tagReg
val tag1 = tag0 ^ d(0).io.actv.wactive
val tag2 = tag1 ^ d(1).io.actv.wactive
val tag3 = tag2 ^ d(2).io.actv.wactive
val tag4 = tag3 ^ d(3).io.actv.wactive
val tags = Seq(tag0, tag1, tag2, tag3, tag4)
// f.io.out is ordered, so can use a priority tree.
// Advance tagReg past the highest FIFO lane accepted this cycle.
when(f.io.out(3).valid && f.io.out(3).ready) {
tagReg := tag4
} .elsewhen(f.io.out(2).valid && f.io.out(2).ready) {
tagReg := tag3
} .elsewhen(f.io.out(1).valid && f.io.out(1).ready) {
tagReg := tag2
} .elsewhen(f.io.out(0).valid && f.io.out(0).ready) {
tagReg := tag1
}
// Returns `v` with its `tag` field set to the 4-bit nibble of `tag`
// covering the aligned group of 4 registers that contains v.addr
// (the nibble index is v.addr(5,2)). valid and addr are copied unchanged.
// The getWidth checks are elaboration-time Scala asserts.
// `v.tag === 0.U` is a hardware assertion that the decoder emitted an
// untagged address. VecAt is defined elsewhere and presumably selects
// tagm(addrm).
def TagAddr(tag: UInt, v: VAddrTag): VAddrTag = {
assert(tag.getWidth == 64)
assert(v.addr.getWidth == 6)
assert(v.tag === 0.U)
val addr = v.addr
val addrm = addr(5,2)
val tagm = Wire(Vec(16, UInt(4.W)))
for (i <- 0 until 16) {
tagm(i) := tag(4 * i + 3, 4 * i)
}
val r = Wire(new VAddrTag())
r.valid := v.valid
r.addr := v.addr
r.tag := VecAt(tagm, addrm)
r
}
// Copy each decoder output and override the six VAddrTag operands
// (last connect wins). They are tagged with tags(i), the state before
// lane i's own write.
for (i <- 0 until 4) {
e(i) := d(i).io.out
e(i).vs := TagAddr(tags(i), d(i).io.out.vs)
e(i).vt := TagAddr(tags(i), d(i).io.out.vt)
e(i).vu := TagAddr(tags(i), d(i).io.out.vu)
e(i).vx := TagAddr(tags(i), d(i).io.out.vx)
e(i).vy := TagAddr(tags(i), d(i).io.out.vy)
e(i).vz := TagAddr(tags(i), d(i).io.out.vz)
}
// ---------------------------------------------------------------------------
// Undef. (io.in.ready ignored to signal as early as possible)
io.undef := io.in.valid && (d(0).io.undef || d(1).io.undef || d(2).io.undef || d(3).io.undef)
// ---------------------------------------------------------------------------
// Fifo.
f.io.in <> io.in
// icount: number of valid lanes offered on io.in this cycle. MuxOR is
// defined elsewhere and presumably yields 0 when io.in.valid is low.
val icount = MuxOR(io.in.valid, PopCount(Cat(io.in.bits(0).valid, io.in.bits(1).valid, io.in.bits(2).valid, io.in.bits(3).valid)))
assert(icount.getWidth == 3)
// ocount: held entries that stay in this stage (valid and not dispatched).
val ocount = PopCount(Cat(valid(0) && !(io.out(0).valid && io.out(0).ready),
valid(1) && !(io.out(1).valid && io.out(1).ready),
valid(2) && !(io.out(2).valid && io.out(2).ready),
valid(3) && !(io.out(3).valid && io.out(3).ready)))
assert(ocount.getWidth == 3)
// FIFO lane i is accepted only if it fits in the slots left after the
// ocount remaining entries. This also bounds ocount + fcount to at most 4.
for (i <- 0 until 4) {
f.io.out(i).ready := (i.U + ocount) < 4.U
}
// ---------------------------------------------------------------------------
// Valid.
// fcount: FIFO lanes accepted this cycle. Next occupancy is
// ocount + fcount, held as a contiguous run starting at slot 0.
val fcount = PopCount(Cat(f.io.out(0).valid && f.io.out(0).ready,
f.io.out(1).valid && f.io.out(1).ready,
f.io.out(2).valid && f.io.out(2).ready,
f.io.out(3).valid && f.io.out(3).ready))
assert(fcount.getWidth == 3)
for (i <- 0 until 4) {
valid(i) := (ocount + fcount) > i.U
}
// ---------------------------------------------------------------------------
// Stall.
io.stall := (f.io.count + icount) > (depth - guard).U
// ---------------------------------------------------------------------------
// Dependencies.
// depends(i) is set when slot i has a register hazard with older work.
val depends = Wire(Vec(4, Bool()))
// Writes must not proceed past any outstanding reads or writes,
// or past any dispatching writes.
// io.vrfsb.data is 128 bits (two 64-bit tag halves) folded to 64 here.
val wactive0 = io.vrfsb.data(63, 0) | io.vrfsb.data(127, 64) | io.active
val wactive1 = actv(0).ractive | actv(0).wactive | wactive0
val wactive2 = actv(1).ractive | actv(1).wactive | wactive1
val wactive3 = actv(2).ractive | actv(2).wactive | wactive2
val wactive = VecInit(wactive0, wactive1, wactive2, wactive3)
// Reads must not proceed past any dispatching writes.
val ractive0 = 0.U(64.W)
val ractive1 = actv(0).wactive | ractive0
val ractive2 = actv(1).wactive | ractive1
val ractive3 = actv(2).wactive | ractive2
val ractive = VecInit(ractive0, ractive1, ractive2, ractive3)
for (i <- 0 until 4) {
depends(i) := (wactive(i) & actv(i).wactive) =/= 0.U ||
(ractive(i) & actv(i).ractive) =/= 0.U
}
// ---------------------------------------------------------------------------
// Data.
// FIFO output valids must be contiguous from lane 0.
val fvalid = VecInit(f.io.out(0).valid, f.io.out(1).valid,
f.io.out(2).valid, f.io.out(3).valid).asUInt
assert(!(fvalid(1) && fvalid(0,0) =/= 1.U))
assert(!(fvalid(2) && fvalid(1,0) =/= 3.U))
assert(!(fvalid(3) && fvalid(2,0) =/= 7.U))
// Register is updated when fifo has state or contents are active.
val dataEn = fvalid(0) || valid.asUInt =/= 0.U
for (i <- 0 until 4) {
when (dataEn) {
data(i) := dataNxt(i)
cmdq(i) := cmdqNxt(i)
actv2(i) := actvNxt(i)
}
}
// Fold the tagged 128-bit write mask into a single 64-bit register mask.
for (i <- 0 until 4) {
actv(i).ractive := actv2(i).ractive
actv(i).wactive := actv2(i).wactive(63, 0) | actv2(i).wactive(127, 64)
}
// Tag the decode wactive.
// The low 64 bits hold writes whose register tag bit is 0 in tags(i + 1),
// the state after lane i's own toggle. The high 64 bits hold writes whose
// tag bit is 1.
val dactv = Wire(Vec(4, new VDecodeActive2))
for (i <- 0 until 4) {
val w0 = d(i).io.actv.wactive & ~tags(i + 1)
val w1 = d(i).io.actv.wactive & tags(i + 1)
dactv(i).ractive := d(i).io.actv.ractive
dactv(i).wactive := Cat(w1, w0)
}
// Data multiplexor of current values and fifo+decode output.
// Entries 0-3 are the held slots; entries 4-7 are FIFO lanes 0-3 after decode.
val dataMux = VecInit(data(0), data(1), data(2), data(3),
e(0), e(1), e(2), e(3))
val cmdqMux = VecInit(cmdq(0), cmdq(1), cmdq(2), cmdq(3),
d(0).io.cmdq, d(1).io.cmdq, d(2).io.cmdq, d(3).io.cmdq)
val actvMux = VecInit(actv2(0), actv2(1), actv2(2), actv2(3),
dactv(0), dactv(1), dactv(2), dactv(3))
// Mark the multiplexor entries that need to be kept.
// markedK is a thermometer mask of the mux entries consumed by slots 0..K.
// Slot K takes the lowest-index entry that is not yet consumed and is either
// a held entry that is not dispatched this cycle or a new FIFO entry (4..7).
// Surviving held entries are thus compacted toward slot 0, followed by the
// new lanes. The asserts check that the masks only ever grow.
val marked0 = Wire(UInt(5.W))
val marked1 = Wire(UInt(6.W))
val marked2 = Wire(UInt(7.W))
assert((marked1 & marked0) === marked0)
assert((marked2 & marked0) === marked0)
assert((marked2 & marked1) === marked1)
// output(i): slot i is dispatched (fires) this cycle.
val output = Cat(io.out(3).valid && io.out(3).ready,
io.out(2).valid && io.out(2).ready,
io.out(1).valid && io.out(1).ready,
io.out(0).valid && io.out(0).ready)
// Slot 0 next state.
when (valid(0) && !output(0)) {
dataNxt(0) := dataMux(0)
cmdqNxt(0) := cmdqMux(0)
actvNxt(0) := actvMux(0)
marked0 := 0x01.U
} .elsewhen (valid(1) && !output(1)) {
dataNxt(0) := dataMux(1)
cmdqNxt(0) := cmdqMux(1)
actvNxt(0) := actvMux(1)
marked0 := 0x03.U
} .elsewhen (valid(2) && !output(2)) {
dataNxt(0) := dataMux(2)
cmdqNxt(0) := cmdqMux(2)
actvNxt(0) := actvMux(2)
marked0 := 0x07.U
} .elsewhen (valid(3) && !output(3)) {
dataNxt(0) := dataMux(3)
cmdqNxt(0) := cmdqMux(3)
actvNxt(0) := actvMux(3)
marked0 := 0x0f.U
} .otherwise {
dataNxt(0) := dataMux(4)
cmdqNxt(0) := cmdqMux(4)
actvNxt(0) := actvMux(4)
marked0 := 0x1f.U
}
// Slot 1 next state (skips entries consumed by slot 0).
when (!marked0(1) && valid(1) && !output(1)) {
dataNxt(1) := dataMux(1)
cmdqNxt(1) := cmdqMux(1)
actvNxt(1) := actvMux(1)
marked1 := 0x03.U
} .elsewhen (!marked0(2) && valid(2) && !output(2)) {
dataNxt(1) := dataMux(2)
cmdqNxt(1) := cmdqMux(2)
actvNxt(1) := actvMux(2)
marked1 := 0x07.U
} .elsewhen (!marked0(3) && valid(3) && !output(3)) {
dataNxt(1) := dataMux(3)
cmdqNxt(1) := cmdqMux(3)
actvNxt(1) := actvMux(3)
marked1 := 0x0f.U
} .elsewhen (!marked0(4)) {
dataNxt(1) := dataMux(4)
cmdqNxt(1) := cmdqMux(4)
actvNxt(1) := actvMux(4)
marked1 := 0x1f.U
} .otherwise {
dataNxt(1) := dataMux(5)
cmdqNxt(1) := cmdqMux(5)
actvNxt(1) := actvMux(5)
marked1 := 0x3f.U
}
// Slot 2 next state (skips entries consumed by slots 0..1).
when (!marked1(2) && valid(2) && !output(2)) {
dataNxt(2) := dataMux(2)
cmdqNxt(2) := cmdqMux(2)
actvNxt(2) := actvMux(2)
marked2 := 0x07.U
} .elsewhen (!marked1(3) && valid(3) && !output(3)) {
dataNxt(2) := dataMux(3)
cmdqNxt(2) := cmdqMux(3)
actvNxt(2) := actvMux(3)
marked2 := 0x0f.U
} .elsewhen (!marked1(4)) {
dataNxt(2) := dataMux(4)
cmdqNxt(2) := cmdqMux(4)
actvNxt(2) := actvMux(4)
marked2 := 0x1f.U
} .elsewhen (!marked1(5)) {
dataNxt(2) := dataMux(5)
cmdqNxt(2) := cmdqMux(5)
actvNxt(2) := actvMux(5)
marked2 := 0x3f.U
} .otherwise {
dataNxt(2) := dataMux(6)
cmdqNxt(2) := cmdqMux(6)
actvNxt(2) := actvMux(6)
marked2 := 0x7f.U
}
// Slot 3 next state (skips entries consumed by slots 0..2).
when (!marked2(3) && valid(3) && !output(3)) {
dataNxt(3) := dataMux(3)
cmdqNxt(3) := cmdqMux(3)
actvNxt(3) := actvMux(3)
} .elsewhen (!marked2(4)) {
dataNxt(3) := dataMux(4)
cmdqNxt(3) := cmdqMux(4)
actvNxt(3) := actvMux(4)
} .elsewhen (!marked2(5)) {
dataNxt(3) := dataMux(5)
cmdqNxt(3) := cmdqMux(5)
actvNxt(3) := actvMux(5)
} .elsewhen (!marked2(6)) {
dataNxt(3) := dataMux(6)
cmdqNxt(3) := cmdqMux(6)
actvNxt(3) := actvMux(6)
} .otherwise {
dataNxt(3) := dataMux(7)
cmdqNxt(3) := cmdqMux(7)
actvNxt(3) := actvMux(7)
}
// ---------------------------------------------------------------------------
// Scoreboard.
// Set the tagged write masks of every slot dispatched this cycle.
// The asserts check two things: no register is set in both tag halves, and
// no newly set register is already pending in the scoreboard.
io.vrfsb.set.valid := output(0) || output(1) || output(2) || output(3)
io.vrfsb.set.bits := (MuxOR(output(0), actv2(0).wactive) |
MuxOR(output(1), actv2(1).wactive) |
MuxOR(output(2), actv2(2).wactive) |
MuxOR(output(3), actv2(3).wactive))
assert((io.vrfsb.set.bits(63, 0) & io.vrfsb.set.bits(127, 64)) === 0.U)
assert(((io.vrfsb.data(63, 0) | io.vrfsb.data(127, 64)) & (io.vrfsb.set.bits(63, 0) | io.vrfsb.set.bits(127, 64))) === 0.U)
// ---------------------------------------------------------------------------
// Outputs.
// outvalid(i): slot i holds an instruction with no register hazard.
val outvalid = Wire(Vec(4, Bool()))
val cmdsync = Wire(Vec(4, Bool()))
for (i <- 0 until 4) {
outvalid(i) := valid(i) && !depends(i)
cmdsync(i) := data(i).cmdsync
}
for (i <- 0 until 4) {
// Synchronize commands at cmdsync instance or if found in history.
// Note: {vdwinit, vdwconv, vdmulh}, vdmulh must not issue before vdwconv.
// `ordered`: slots 0..i are all hazard free, so issue is strictly in order.
// The `if (false)` branch is a disabled alternative that would allow
// out-of-order issue except around cmdsync. `synchronize` and `unorder`
// are only used by that disabled branch.
val synchronize = cmdsync.asUInt(i,0) =/= 0.U
val ordered = (~outvalid.asUInt(i,0)) === 0.U
val unorder = outvalid(i)
if (false) {
io.out(i).valid := Mux(synchronize, ordered, unorder)
} else {
io.out(i).valid := ordered
}
io.out(i).bits := data(i)
io.cmdq(i) := cmdq(i)
io.actv(i) := actv(i)
}
// ---------------------------------------------------------------------------
// Status.
// Registered, so io.nempty lags the underlying state by one cycle.
val nempty = RegInit(false.B)
// Simple implementation, will overlap downstream units redundantly.
nempty := io.in.valid || f.io.nempty || valid.asUInt =/= 0.U
io.nempty := nempty
}
// Decoded vector instruction as held in VDecode and presented on io.out.
// Field declaration order defines the bundle's bit layout, so do not reorder.
class VDecodeBits extends Bundle {
// Opcode; width is taken from VEncodeOp().bits.
val op = UInt(new VEncodeOp().bits.W)
val f2 = UInt(3.W) // func2
val sz = UInt(3.W) // onehot size
val m = Bool() // stripmine
// Untagged register addresses. VDecode does not apply TagAddr to these;
// presumably they are destination operands (confirm against the decoder).
val vd = new VAddr()
val ve = new VAddr()
val vf = new VAddr()
val vg = new VAddr()
// Tagged register addresses. VDecode fills `tag` via TagAddr; presumably
// these are source (read) operands.
val vs = new VAddrTag()
val vt = new VAddrTag()
val vu = new VAddrTag()
val vx = new VAddrTag()
val vy = new VAddrTag()
val vz = new VAddrTag()
// Scalar address/data operand.
val sv = new SAddrData()
val cmdsync = Bool() // Dual command queues synchronize.
}
// Per-instruction command-queue flags output alongside each decoded
// instruction (io.cmdq in VDecode). Presumably these select which
// downstream queue(s) receive the instruction; confirm against consumers.
class VDecodeCmdq extends Bundle {
val alu = Bool() // ALU
val conv = Bool() // Convolution vregfile
val ldst = Bool() // L1Dcache load/store
val ld = Bool() // Uncached load
val st = Bool() // Uncached store
}
// Untagged register activity masks, with one bit per vector register (64).
class VDecodeActive extends Bundle {
// Registers read by the instruction.
val ractive = UInt(64.W)
// Registers written by the instruction.
val wactive = UInt(64.W)
}
// Register activity masks whose write mask is split by tag bit.
class VDecodeActive2 extends Bundle {
// Registers read by the instruction, with one bit per register.
val ractive = UInt(64.W)
// Bits 63:0 hold writes to registers whose tag bit is 0; bits 127:64 hold
// writes to registers whose tag bit is 1. OR the halves for a plain mask.
val wactive = UInt(128.W) // even/odd tags
}
// Optional vector register address without an out-of-order tag.
class VAddr extends Bundle {
val valid = Bool()
// 6-bit vector register index (64 registers).
val addr = UInt(6.W)
}
// Optional vector register address carrying a 4-bit out-of-order tag.
class VAddrTag extends Bundle {
val valid = Bool()
// 6-bit vector register index (64 registers).
val addr = UInt(6.W)
// Tag bits of the aligned group of 4 registers containing addr. VDecode
// expects the decoder to leave this 0 and fills it in itself.
val tag = UInt(4.W)
}
// Optional scalar operand carrying a 32-bit address and 32-bit data value.
class SAddrData extends Bundle {
val valid = Bool()
val addr = UInt(32.W)
val data = UInt(32.W)
}
// Optional 32-bit scalar data value (not referenced elsewhere in this file).
class SData extends Bundle {
val valid = Bool()
val data = UInt(32.W)
}
// Command-line entry point that elaborates VDecode and emits Verilog.
// An explicit main replaces `extends App`. App relies on DelayedInit, which
// is deprecated and defers field initialization until main runs.
object EmitVDecode {
  // args: forwarded unchanged to ChiselStage.emitVerilog (e.g. --target-dir).
  def main(args: Array[String]): Unit = {
    val p = new Parameters
    (new chisel3.stage.ChiselStage).emitVerilog(new VDecode(p), args)
  }
}