FPU Multiply operation.
Change-Id: I0355ec74ae2773794bcfbda183d4939f8f5b2bc6
diff --git a/hdl/chisel/src/common/BUILD b/hdl/chisel/src/common/BUILD
index 2dc5305..8397f95 100644
--- a/hdl/chisel/src/common/BUILD
+++ b/hdl/chisel/src/common/BUILD
@@ -44,4 +44,25 @@
deps = [
":fp",
],
-)
\ No newline at end of file
+)
+
+chisel_library(
+ name = "fpu",
+ srcs = [
+ "Fpu.scala",
+ ],
+ deps = [
+ ":fp",
+ ],
+)
+
+chisel_test(
+ name = "fpu_test",
+ srcs = [
+ "FpuTest.scala",
+ ],
+ deps = [
+ ":fp",
+ ":fpu",
+ ],
+)
diff --git a/hdl/chisel/src/common/Fpu.scala b/hdl/chisel/src/common/Fpu.scala
new file mode 100644
index 0000000..384a1f6
--- /dev/null
+++ b/hdl/chisel/src/common/Fpu.scala
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+
+object FpuOp extends ChiselEnum {
+ val FpuMul = Value
+}
+
+class FpuCmd extends Bundle {
+ val op = FpuOp()
+ val ina = new Fp32
+ val inb = new Fp32
+}
+
+class FpuState1 extends Bundle {
+ val zero = Bool()
+ val inf = Bool()
+ val nan = Bool()
+ val sign = Bool()
+ val exponent = SInt(10.W)
+ val significand = UInt(48.W)
+}
+
+object Fpu {
+ def apply(cmd: FpuCmd): Fp32 = {
+ FpuStage2(FpuStage1(cmd))
+ }
+
+ def FpuStage1(cmd: FpuCmd): FpuState1 = {
+ val state = Wire(new FpuState1)
+
+ state.zero := cmd.ina.isZero() || cmd.inb.isZero()
+ state.inf := cmd.ina.isInf() || cmd.inb.isInf()
+ state.nan := cmd.ina.isNan() || cmd.inb.isNan()
+
+ state.sign := cmd.ina.sign ^ cmd.inb.sign
+ state.exponent := (cmd.ina.exponent +& cmd.inb.exponent).zext - 127.S
+ state.significand := cmd.ina.significand() * cmd.inb.significand()
+ state
+ }
+
+ def FpuStage2(state: FpuState1): Fp32 = {
+ // Grab 24-bits of the mantissa for rounding. At least one of the MSB (for
+ // when the significand product >= 2) or 2nd MSB is guarenteed to be set.
+ // The below mux effectively picks the correct 25-bit truncated significand
+ // depending if the MSB is set, then returns the lower 24-bits of that
+ // result (the mantissa of the truncated significand).
+ val mantissa24 = Mux(
+ state.significand(47),
+ state.significand(46, 23),
+ state.significand(45, 22))
+ // TODO(derekjchow): Rounding modes
+ val mantissa = ((mantissa24 + 1.U(1.W)) >> 1)(22, 0)
+
+ // If the significand product >= 2, we "shift the decimal" to the right
+ // by one bit. Add 1 to the exponent to compensate.
+ val exponent = state.exponent + state.significand(47).asUInt.zext
+ // Check for overflow.
+ val inf = state.inf || (exponent >= (1 << 8).S)
+ // Check for very small numbers that should round to zero.
+ val zero = state.zero || (exponent < 0.S)
+
+ MuxCase(
+ Fp32(state.sign, exponent(7, 0), mantissa),
+ Array(
+ state.nan -> Fp32(false.B, ((1<<8)-1).U, mantissa),
+ (state.zero && state.inf) -> Fp32.NaN(),
+ inf -> Fp32.Inf(state.sign),
+ zero -> Fp32.Zero(state.sign)
+ ))
+ }
+}
diff --git a/hdl/chisel/src/common/FpuTest.scala b/hdl/chisel/src/common/FpuTest.scala
new file mode 100644
index 0000000..c80cf7c
--- /dev/null
+++ b/hdl/chisel/src/common/FpuTest.scala
@@ -0,0 +1,129 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+import chiseltest._
+import org.scalatest.freespec.AnyFreeSpec
+import chisel3.experimental.BundleLiterals._
+
+class FpuTester extends Module {
+ val io = IO(new Bundle {
+ val ina = Input(UInt(32.W))
+ val inb = Input(UInt(32.W))
+ val op = Input(FpuOp())
+ val out = Output(new Fp32)
+ })
+
+ val fp_a = Fp32.fromWord(io.ina)
+ val fp_b = Fp32.fromWord(io.inb)
+
+ val cmd = Wire(new FpuCmd)
+ cmd.ina := fp_a
+ cmd.inb := fp_b
+ cmd.op := io.op
+
+ val stage1 = Fpu.FpuStage1(cmd)
+ io.out := Fpu.FpuStage2(stage1)
+}
+
+class FpuSpec extends AnyFreeSpec with ChiselScalatestTester {
+ def Float2BigInt(x: Float): BigInt = {
+ val abs = x.abs
+ var int = BigInt(java.lang.Float.floatToIntBits(abs))
+ if (x < 0) {
+ int += (BigInt(1) << 31)
+ }
+ int
+ }
+
+ "Zero" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(0))
+ dut.io.inb.poke(Float2BigInt(42))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(0) { dut.io.out.exponent.peekInt() }
+ assertResult(0) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "Identity" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(1))
+ dut.io.inb.poke(Float2BigInt(42))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(132) { dut.io.out.exponent.peekInt() }
+ assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "Negative" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(-1.0f))
+ dut.io.inb.poke(Float2BigInt(42))
+ assertResult(1) { dut.io.out.sign.peekInt() }
+ assertResult(132) { dut.io.out.exponent.peekInt() }
+ assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "Half" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(0.5f))
+ dut.io.inb.poke(Float2BigInt(42))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(131) { dut.io.out.exponent.peekInt() }
+ assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "Overflow" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(2e30f))
+ dut.io.inb.poke(Float2BigInt(2e30f))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(255) { dut.io.out.exponent.peekInt() }
+ assertResult(0) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "Rounds to Zero" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(1e-30f))
+ dut.io.inb.poke(Float2BigInt(1e-30f))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(0) { dut.io.out.exponent.peekInt() }
+ assertResult(0) { dut.io.out.mantissa.peekInt() }
+ }
+ }
+
+ "NaN" in {
+ test(new FpuTester()) { dut =>
+ dut.io.op.poke(FpuOp.FpuMul)
+ dut.io.ina.poke(Float2BigInt(Float.NaN))
+ dut.io.inb.poke(Float2BigInt(4.0f))
+ assertResult(0) { dut.io.out.sign.peekInt() }
+ assertResult(255) { dut.io.out.exponent.peekInt() }
+ assert(dut.io.out.mantissa.peekInt() != 0)
+ }
+ }
+}
\ No newline at end of file