blob: 50d978eb531781e5c2126bcf8358a5e0756a03d1 [file] [log] [blame]
package kelvin
import chisel3._
import chisel3.util._
// Hardware generators for byte-wise dot products with per-operand bias.
//
// Every 32-bit word is treated as four 8-bit lanes. A lane value is widened
// to 9 bits (sign-extended when the matching `asign`/`bsign` flag is set,
// zero-extended otherwise), a 9-bit bias interpreted as signed is added
// (10-bit result), and the two biased operands are multiplied (20-bit signed
// product). Products are summed into a 22-bit signed value which is then
// sign-extended to 32 bits.
//
// NOTE(review): `MuxOR` is defined outside this file; it is presumably
// "data when the enable is set, zero otherwise" (its results are OR-combined
// below) -- confirm against its definition.
object VDot {
  // Conv2D: dot product over the four byte lanes of one 32-bit word pair.
  //
  //   en            enable; the data and bias inputs are clamped with MuxOR.
  //   adata, bdata  32-bit operands (four 8-bit lanes each).
  //   abias, bbias  9-bit biases added to every `a` / `b` lane.
  //   asign, bsign  treat the `a` / `b` lanes as signed when set.
  //
  // Returns the 32-bit sign-extended sum of the four lane products.
  def apply(en: Bool, adata: UInt, bdata: UInt,
      abias: UInt, bbias: UInt, asign: Bool, bsign: Bool): UInt = {
    assert(abias.getWidth == 9)
    assert(bbias.getWidth == 9)
    assert(adata.getWidth == 32)
    assert(bdata.getWidth == 32)

    // Input clamps.
    val adatac = MuxOR(en, adata)
    val bdatac = MuxOR(en, bdata)
    val abiasc = MuxOR(en, abias)
    val bbiasc = MuxOR(en, bbias)

    val mul = Wire(Vec(4, SInt(20.W)))
    for (i <- 0 until 4) {
      mul(i) := lanemul(i, adatac, bdatac, abiasc, bbiasc, asign, bsign)
    }

    // Balanced adder tree: (20b +& 20b) +& (20b +& 20b) => 22 bits.
    val dotp = (mul(0) +& mul(1)) +& (mul(2) +& mul(3))
    sext32(dotp)
  }

  // Depthwise: three-tap dot products for every 32-bit word of the inputs.
  //
  //   alu     elaboration-time selector, must be 0 or 1; picks byte lanes
  //           {0, 2} or {1, 3} of each word (see `dwlane`).
  //   en      enable; data and bias inputs are clamped with MuxOR.
  //   adata   three equally wide vectors; the width must be a multiple of 32
  //           and at least 64 bits.
  //   bdata   three vectors read with the same word indices as `adata`.
  //   scalar  32-bit control word:
  //             [3:2]   sparse mode (0 = dense, 1 = [n-1,n,n+1], 2 = [n,n+1,n+2])
  //             [20:12] abias     [21] asign
  //             [30:22] bbias     [31] bsign
  //
  // Returns (out0, out1), each as wide as `adata(0)`: the per-word results of
  // `dwlane` collected into vectors and flattened with `asUInt`.
  def apply(alu: Int, en: Bool, adata: Vec[UInt], bdata: Vec[UInt],
      scalar: UInt): (UInt, UInt) = {
    assert(adata.length == 3)
    assert(bdata.length == 3)
    assert(scalar.getWidth == 32)
    assert(alu == 0 || alu == 1, s"VDot: alu must be 0 or 1, got $alu")

    val w = adata(0).getWidth
    val cnt = w / 32
    assert(w % 32 == 0, s"VDot: data width $w is not a multiple of 32")
    // The sparse-2 swizzle below reads words 0 and 1 of adata(1).
    assert(cnt >= 2, s"VDot: data width $w must hold at least two 32-bit words")

    val sparse = scalar(3,2)
    val abias = scalar(20,12)
    val asign = scalar(21)
    val bbias = scalar(30,22)
    val bsign = scalar(31)

    val sparse0 = sparse === 0.U
    val sparse1 = sparse === 1.U
    val sparse2 = sparse === 2.U

    // Per-mode enables, shared by every word of the matching swizzle.
    val en0 = en && sparse0
    val en1 = en && sparse1
    val en2 = en && sparse2

    // 32-bit word `i` of `data`.
    def word(data: UInt, i: Int): UInt = data(32 * i + 31, 32 * i)

    val dout0 = Wire(Vec(cnt, UInt(32.W)))
    val dout1 = Wire(Vec(cnt, UInt(32.W)))

    // Input clamps and dense/sparse swizzle.
    val adatac = Wire(Vec(3, Vec(cnt, UInt(32.W))))
    val bdatac = Wire(Vec(3, Vec(cnt, UInt(32.W))))
    val abiasc = MuxOR(en, abias)
    val bbiasc = MuxOR(en, bbias)

    // Sparse 1 [n-1,n,n+1]: last word of adata(0), all of adata(1), then the
    // first word of adata(2).
    val adata1 = Wire(Vec(cnt + 2, UInt(32.W)))
    adata1(0) := MuxOR(en1, word(adata(0), cnt - 1))
    for (i <- 0 until cnt) {
      adata1(i + 1) := MuxOR(en1, word(adata(1), i))
    }
    adata1(cnt + 1) := MuxOR(en1, word(adata(2), 0))

    // Sparse 2 [n,n+1,n+2]: all of adata(0), then the first two words of
    // adata(1).
    val adata2 = Wire(Vec(cnt + 2, UInt(32.W)))
    for (i <- 0 until cnt) {
      adata2(i) := MuxOR(en2, word(adata(0), i))
    }
    for (i <- 0 until 2) {
      adata2(cnt + i) := MuxOR(en2, word(adata(1), i))
    }

    // vdot(a,b) for sparse[0,1,2]: tap j of output word i reads word i of
    // adata(j) when dense, or entry i + j of the active sparse window. Only
    // one of the three sources is enabled, so they are merged with OR.
    for (j <- 0 until 3; i <- 0 until cnt) {
      val k = i + j
      val adata0 = MuxOR(en0, word(adata(j), i))
      adatac(j)(i) := adata0 | adata1(k) | adata2(k)
      bdatac(j)(i) := MuxOR(en, word(bdata(j), i))
    }

    for (i <- 0 until cnt) {
      val ad = VecInit(adatac(0)(i), adatac(1)(i), adatac(2)(i))
      val bd = VecInit(bdatac(0)(i), bdatac(1)(i), bdatac(2)(i))
      val (o0, o1) = dwlane(alu, en, ad, bd, abiasc, bbiasc, asign, bsign)
      dout0(i) := o0
      dout1(i) := o1
    }

    val out0 = dout0.asUInt
    val out1 = dout1.asUInt
    assert(out0.getWidth == w)
    assert(out1.getWidth == w)
    (out0, out1)
  }

  // Depthwise lane: two three-tap dot products for one 32-bit word position.
  //
  // Output j (j = 0, 1) uses byte lane m = 2 * j + alu of each tap, i.e.
  // alu 0 selects bytes {0, 2} and alu 1 selects bytes {1, 3}. `en` is
  // accepted for signature symmetry but is not read here; callers pass
  // already-clamped data and biases.
  private def dwlane(alu: Int, en: Bool, adata: Vec[UInt], bdata: Vec[UInt],
      abias: UInt, bbias: UInt, asign: Bool, bsign: Bool):
      (UInt, UInt) = {
    assert(adata.length == 3)
    assert(bdata.length == 3)
    assert(abias.getWidth == 9)
    assert(bbias.getWidth == 9)
    assert(alu == 0 || alu == 1, s"VDot: alu must be 0 or 1, got $alu")
    for (i <- 0 until 3) {
      assert(adata(i).getWidth == 32)
      assert(bdata(i).getWidth == 32)
    }

    val out = Wire(Vec(2, UInt(32.W)))
    for (j <- 0 until 2) {
      val m = 2 * j + alu  // alu[0]: {0, 2}; alu[1]: {1, 3}
      val mul = Wire(Vec(3, SInt(20.W)))
      for (i <- 0 until 3) {
        mul(i) := lanemul(m, adata(i), bdata(i), abias, bbias, asign, bsign)
      }
      // (20b +& 20b) +& 20b => 22 bits.
      val dotp = (mul(0) +& mul(1)) +& mul(2)
      out(j) := sext32(dotp)
    }
    (out(0), out(1))
  }

  // Biased product of byte lane `m` (0..3) of two 32-bit words.
  //
  // Each byte gets a ninth bit equal to its MSB gated by the sign flag, so it
  // is sign-extended when the flag is set and zero-extended otherwise. The
  // 9-bit bias is then added as a signed value and the two 10-bit sums are
  // multiplied into a 20-bit signed product.
  private def lanemul(m: Int, adata: UInt, bdata: UInt,
      abias: UInt, bbias: UInt, asign: Bool, bsign: Bool): SInt = {
    val lsb = 8 * m
    val msb = lsb + 7
    val as = adata(msb) & asign
    val bs = bdata(msb) & bsign
    val aval = Cat(as, adata(msb, lsb)).asSInt +& abias.asSInt
    val bval = Cat(bs, bdata(msb, lsb)).asSInt +& bbias.asSInt
    val mval = aval * bval
    assert(aval.getWidth == 10)
    assert(bval.getWidth == 10)
    assert(mval.getWidth == 20)
    mval
  }

  // Sign-extends a 22-bit signed accumulator to a 32-bit UInt by replicating
  // bit 21 into the ten upper bits.
  private def sext32(dotp: SInt): UInt = {
    assert(dotp.getWidth == 22)
    val sdotp = Cat(MuxOR(dotp(21), ~0.U(10.W)), dotp)
    assert(sdotp.getWidth == 32)
    sdotp
  }
}