FPU Multiply operation.

Change-Id: I0355ec74ae2773794bcfbda183d4939f8f5b2bc6
diff --git a/hdl/chisel/src/common/BUILD b/hdl/chisel/src/common/BUILD
index 2dc5305..8397f95 100644
--- a/hdl/chisel/src/common/BUILD
+++ b/hdl/chisel/src/common/BUILD
@@ -44,4 +44,25 @@
     deps = [
         ":fp",
     ],
-)
\ No newline at end of file
+)
+
+chisel_library(
+    name = "fpu",
+    srcs = [
+        "Fpu.scala",
+    ],
+    deps = [
+        ":fp",
+    ],
+)
+
+chisel_test(
+    name = "fpu_test",
+    srcs = [
+        "FpuTest.scala",
+    ],
+    deps = [
+        ":fp",
+        ":fpu",
+    ],
+)
diff --git a/hdl/chisel/src/common/Fpu.scala b/hdl/chisel/src/common/Fpu.scala
new file mode 100644
index 0000000..384a1f6
--- /dev/null
+++ b/hdl/chisel/src/common/Fpu.scala
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+
+object FpuOp extends ChiselEnum {
+  val FpuMul = Value
+}
+
+class FpuCmd extends Bundle {
+  val op = FpuOp()
+  val ina = new Fp32
+  val inb = new Fp32
+}
+
+class FpuState1 extends Bundle {
+  val zero        = Bool()
+  val inf         = Bool()
+  val nan         = Bool()
+  val sign        = Bool()
+  val exponent    = SInt(10.W)
+  val significand = UInt(48.W)
+}
+
+object Fpu {
+  def apply(cmd: FpuCmd): Fp32 = {
+    FpuStage2(FpuStage1(cmd))
+  }
+
+  def FpuStage1(cmd: FpuCmd): FpuState1 = {
+    val state = Wire(new FpuState1)
+
+    state.zero := cmd.ina.isZero() || cmd.inb.isZero()
+    state.inf := cmd.ina.isInf() || cmd.inb.isInf()
+    state.nan := cmd.ina.isNan() || cmd.inb.isNan()
+
+    state.sign := cmd.ina.sign ^ cmd.inb.sign
+    state.exponent := (cmd.ina.exponent +& cmd.inb.exponent).zext - 127.S
+    state.significand := cmd.ina.significand() * cmd.inb.significand()
+    state
+  }
+
+  def FpuStage2(state: FpuState1): Fp32 = {
+    // Grab 24-bits of the mantissa for rounding. At least one of the MSB (for
+    // when the significand product >= 2) or 2nd MSB is guarenteed to be set.
+    // The below mux effectively picks the correct 25-bit truncated significand
+    // depending if the MSB is set, then returns the lower 24-bits of that
+    // result (the mantissa of the truncated significand).
+    val mantissa24 = Mux(
+        state.significand(47),
+        state.significand(46, 23),
+        state.significand(45, 22))
+    // TODO(derekjchow): Rounding modes
+    val mantissa = ((mantissa24 + 1.U(1.W)) >> 1)(22, 0)
+
+    // If the significand product >= 2, we "shift the decimal" to the right
+    // by one bit. Add 1 to the exponent to compensate.
+    val exponent = state.exponent + state.significand(47).asUInt.zext
+    // Check for overflow.
+    val inf = state.inf || (exponent >= (1 << 8).S)
+    // Check for very small numbers that should round to zero.
+    val zero = state.zero || (exponent < 0.S)
+
+    MuxCase(
+        Fp32(state.sign, exponent(7, 0), mantissa),
+        Array(
+            state.nan -> Fp32(false.B, ((1<<8)-1).U, mantissa),
+            (state.zero && state.inf) -> Fp32.NaN(),
+            inf -> Fp32.Inf(state.sign),
+            zero -> Fp32.Zero(state.sign)
+        ))
+  }
+}
diff --git a/hdl/chisel/src/common/FpuTest.scala b/hdl/chisel/src/common/FpuTest.scala
new file mode 100644
index 0000000..c80cf7c
--- /dev/null
+++ b/hdl/chisel/src/common/FpuTest.scala
@@ -0,0 +1,129 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+import chiseltest._
+import org.scalatest.freespec.AnyFreeSpec
+import chisel3.experimental.BundleLiterals._
+
+class FpuTester extends Module {
+  val io = IO(new Bundle {
+    val ina    = Input(UInt(32.W))
+    val inb    = Input(UInt(32.W))
+    val op     = Input(FpuOp())
+    val out    = Output(new Fp32)
+  })
+
+  val fp_a = Fp32.fromWord(io.ina)
+  val fp_b = Fp32.fromWord(io.inb)
+
+  val cmd = Wire(new FpuCmd)
+  cmd.ina := fp_a
+  cmd.inb := fp_b
+  cmd.op := io.op
+
+  val stage1 = Fpu.FpuStage1(cmd)
+  io.out := Fpu.FpuStage2(stage1)
+}
+
+class FpuSpec extends AnyFreeSpec with ChiselScalatestTester {
+  def Float2BigInt(x: Float): BigInt = {
+    val abs = x.abs
+    var int = BigInt(java.lang.Float.floatToIntBits(abs))
+    if (x < 0) {
+      int += (BigInt(1) << 31)
+    }
+    int
+  }
+
+  "Zero" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(0))
+      dut.io.inb.poke(Float2BigInt(42))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(0) { dut.io.out.exponent.peekInt() }
+      assertResult(0) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "Identity" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(1))
+      dut.io.inb.poke(Float2BigInt(42))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(132) { dut.io.out.exponent.peekInt() }
+      assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "Negative" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(-1.0f))
+      dut.io.inb.poke(Float2BigInt(42))
+      assertResult(1) { dut.io.out.sign.peekInt() }
+      assertResult(132) { dut.io.out.exponent.peekInt() }
+      assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "Half" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(0.5f))
+      dut.io.inb.poke(Float2BigInt(42))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(131) { dut.io.out.exponent.peekInt() }
+      assertResult(2621440) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "Overflow" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(2e30f))
+      dut.io.inb.poke(Float2BigInt(2e30f))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(255) { dut.io.out.exponent.peekInt() }
+      assertResult(0) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "Rounds to Zero" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(1e-30f))
+      dut.io.inb.poke(Float2BigInt(1e-30f))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(0) { dut.io.out.exponent.peekInt() }
+      assertResult(0) { dut.io.out.mantissa.peekInt() }
+    }
+  }
+
+  "NaN" in {
+    test(new FpuTester()) { dut =>
+      dut.io.op.poke(FpuOp.FpuMul)
+      dut.io.ina.poke(Float2BigInt(Float.NaN))
+      dut.io.inb.poke(Float2BigInt(4.0f))
+      assertResult(0) { dut.io.out.sign.peekInt() }
+      assertResult(255) { dut.io.out.exponent.peekInt() }
+      assert(dut.io.out.mantissa.peekInt() != 0)
+    }
+  }
+}
\ No newline at end of file