Add utility function to convert integers to Fp32.

Change-Id: I473f86c3db8e25868f5ab121623032b7b3359241
diff --git a/hdl/chisel/src/common/BUILD b/hdl/chisel/src/common/BUILD
index 2805f14..1b5745b 100644
--- a/hdl/chisel/src/common/BUILD
+++ b/hdl/chisel/src/common/BUILD
@@ -16,6 +16,14 @@
      "chisel_cc_library", "chisel_library", "chisel_test")
 
 chisel_library(
+    name = "library",
+    srcs = [
+        "Library.scala",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+chisel_library(
     name = "common",
     srcs = [
         "FifoXe.scala",
@@ -23,10 +31,12 @@
         "FifoIxO.scala",
         "Fifo.scala",
         "IDiv.scala",
-        "Library.scala",
         "MathUtil.scala",
         "Slice.scala",
     ],
+    deps = [
+        ":library",
+    ],
     visibility = ["//visibility:public"],
 )
 
@@ -35,6 +45,21 @@
     srcs = [
         "Fp.scala",
     ],
+    deps = [
+        ":library",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+chisel_library(
+    name = "fp_test_utils",
+    srcs = [
+        "FpTestUtils.scala",
+    ],
+    deps = [
+        "@edu_berkeley_cs_chiseltest//jar",
+        ":fp",
+    ],
     visibility = ["//visibility:public"],
 )
 
@@ -45,6 +70,7 @@
     ],
     deps = [
         ":fp",
+        ":fp_test_utils",
     ],
 )
 
diff --git a/hdl/chisel/src/common/Fp.scala b/hdl/chisel/src/common/Fp.scala
index a75201d..bc4cbe6 100644
--- a/hdl/chisel/src/common/Fp.scala
+++ b/hdl/chisel/src/common/Fp.scala
@@ -73,6 +73,36 @@
     fp
   }
 
+  /** "static_cast" an integer into a fp32 number.
+    * @param word The integer to convert.
+    * @param sign If the integer is signed.
+    * @return The converted floating point number.
+    */
+  def fromInteger(int: UInt, sign: Bool): Fp32 = {
+    val intLength = int.getWidth
+    assert(intLength == 32, "fromInteger currently only supports 32-bit ints")
+    val floatSign = sign & int(intLength - 1)
+    val absInt = Mux(floatSign, (~int) + 1.U, int)
+
+    val preround = Wire(UInt(25.W))
+    val leadingZeros = Clz(absInt)
+    if (intLength >= 25) {
+      preround := (absInt << leadingZeros)(intLength - 1, intLength - 25)
+    } else {
+      val shift = (25 - intLength).U + leadingZeros
+      preround := (absInt << leadingZeros)(24, 0)
+    }
+
+    val zero = (preround === 0.U)
+    // TODO(derekjchow): Rounding mode
+    val rounded = preround +& (~floatSign) // 26 bits
+    val mantissa = Mux(rounded(25), rounded(24, 2), rounded(23, 1))
+    val exponent = Mux(
+        zero, 0.U(8.W), (intLength + 127 - 1).U(8.W) - leadingZeros)
+
+    Fp32(floatSign, exponent, mantissa)
+  }
+
   def Zero(sign: Bool): Fp32 = {
     val fp = Wire(new Fp32)
     fp.mantissa := 0.U
diff --git a/hdl/chisel/src/common/FpTest.scala b/hdl/chisel/src/common/FpTest.scala
index 70bf7b8..83bf181 100644
--- a/hdl/chisel/src/common/FpTest.scala
+++ b/hdl/chisel/src/common/FpTest.scala
@@ -34,6 +34,24 @@
   io.is_nan := fp.isNan()
 }
 
+class Fp32CvtuTester extends Module {
+  val io = IO(new Bundle {
+    val int  = Input(UInt(32.W))
+    val fp = Output(new Fp32)
+  })
+
+  io.fp := Fp32.fromInteger(io.int, false.B)
+}
+
+class Fp32CvtTester extends Module {
+  val io = IO(new Bundle {
+    val int  = Input(SInt(32.W))
+    val fp = Output(new Fp32)
+  })
+
+  io.fp := Fp32.fromInteger(io.int.asUInt, true.B)
+}
+
 class FpSpec extends AnyFreeSpec with ChiselScalatestTester {
   "Zero" in {
     test(new Fp32Tester()) { dut =>
@@ -75,4 +93,24 @@
       assert(dut.io.is_nan.peekInt() == 1)
     }
   }
+
+  "Convert UInt to Float" in {
+    test(new Fp32CvtuTester) { dut =>
+      for (i <- 0 until 20000000 by 3000) {
+        dut.io.int.poke(i)
+        dut.clock.step()
+        assertResult(i.toFloat) { PeekFloat(dut.io.fp) }
+      }
+    }
+  }
+
+  "Convert SInt to Float" in {
+    test(new Fp32CvtTester) { dut =>
+      for (i <- -20000001 until 20000000 by 3000) {
+        dut.io.int.poke(i)
+        dut.clock.step()
+        assertResult(i.toFloat) { PeekFloat(dut.io.fp) }
+      }
+    }
+  }
 }
\ No newline at end of file
diff --git a/hdl/chisel/src/common/FpTestUtils.scala b/hdl/chisel/src/common/FpTestUtils.scala
new file mode 100644
index 0000000..729d813
--- /dev/null
+++ b/hdl/chisel/src/common/FpTestUtils.scala
@@ -0,0 +1,85 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+import chiseltest._
+
+/** Effectively "reinterpret_casts" a float32 into a BigInt. BigInt is used
+  * because there's no easy uint32 type in scala, but BigInt works well enough
+  * for Chisel use cases.
+  * @param f A scala float.
+  * @return A BigInt (representing uint32) bit interpretation of the float.
+  */
+object Float2BigInt {
+  def apply(f: Float): BigInt = {
+    val abs = f.abs
+    var int = BigInt(java.lang.Float.floatToIntBits(abs))
+    if (f < 0) {
+      int += (BigInt(1) << 31)
+    }
+    int
+  }
+}
+
+/** Breaks down a float32 into it's sign, exponent and mantissa.
+  * @param f A scala float.
+  * @return A tuple of the sign, exponent and mantissa.
+  */
+object Float2Bits {
+  def apply(f: Float): (Boolean, Int, Int) = {
+    val abs = f.abs
+    var int = java.lang.Float.floatToIntBits(abs)
+
+    val sign: Boolean = (f < 0)
+    val exponent: Int = int >> 23
+    val mantissa: Int = int & ((1 << 23) - 1)
+
+    (sign, exponent, mantissa)
+  }
+}
+
+/** Pokes a float.
+  * @param dut The float input.
+  * @param f A scala float.
+  */
+object PokeFloat {
+  def apply(dut: Fp32, f: Float) = {
+    val int = java.lang.Float.floatToRawIntBits(f)
+    val sign = if (int < 0) { true.B } else { false.B }
+    val mantissa = int & 0x7FFFFF
+    val exponent = (int >> 23) & 0xFF
+
+    dut.sign.poke(sign)
+    dut.mantissa.poke(mantissa)
+    dut.exponent.poke(exponent)
+  }
+}
+
+/** Peeks a float.
+  * @param dut The float input.
+  * @param f A scala float.
+  */
+object PeekFloat {
+  def apply(dut: Fp32): Float = {
+    val sign = dut.sign.peekInt().toInt
+    val exponent = dut.exponent.peekInt().toInt
+    val mantissa = dut.mantissa.peekInt().toInt
+
+    val i = (sign << 31) + (exponent << 23) + mantissa
+    java.lang.Float.intBitsToFloat(i)
+  }
+}
\ No newline at end of file
diff --git a/hdl/chisel/src/common/Library.scala b/hdl/chisel/src/common/Library.scala
index d722e89..381519f 100644
--- a/hdl/chisel/src/common/Library.scala
+++ b/hdl/chisel/src/common/Library.scala
@@ -34,4 +34,10 @@
     result.bits := bits
     result
   }
+}
+
+object Clz {
+  def apply(bits: UInt): UInt = {
+    PriorityEncoder(Cat(1.U(1.W), Reverse(bits)))
+  }
 }
\ No newline at end of file