Add utility function to convert integers to Fp32.
Change-Id: I473f86c3db8e25868f5ab121623032b7b3359241
diff --git a/hdl/chisel/src/common/BUILD b/hdl/chisel/src/common/BUILD
index 2805f14..1b5745b 100644
--- a/hdl/chisel/src/common/BUILD
+++ b/hdl/chisel/src/common/BUILD
@@ -16,6 +16,14 @@
"chisel_cc_library", "chisel_library", "chisel_test")
chisel_library(
+ name = "library",
+ srcs = [
+ "Library.scala",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+chisel_library(
name = "common",
srcs = [
"FifoXe.scala",
@@ -23,10 +31,12 @@
"FifoIxO.scala",
"Fifo.scala",
"IDiv.scala",
- "Library.scala",
"MathUtil.scala",
"Slice.scala",
],
+ deps = [
+ ":library",
+ ],
visibility = ["//visibility:public"],
)
@@ -35,6 +45,21 @@
srcs = [
"Fp.scala",
],
+ deps = [
+ ":library",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+chisel_library(
+ name = "fp_test_utils",
+ srcs = [
+ "FpTestUtils.scala",
+ ],
+ deps = [
+ "@edu_berkeley_cs_chiseltest//jar",
+ ":fp",
+ ],
visibility = ["//visibility:public"],
)
@@ -45,6 +70,7 @@
],
deps = [
":fp",
+ ":fp_test_utils",
],
)
diff --git a/hdl/chisel/src/common/Fp.scala b/hdl/chisel/src/common/Fp.scala
index a75201d..bc4cbe6 100644
--- a/hdl/chisel/src/common/Fp.scala
+++ b/hdl/chisel/src/common/Fp.scala
@@ -73,6 +73,36 @@
fp
}
+ /** "static_cast" an integer into a fp32 number.
+ * @param word The integer to convert.
+ * @param sign If the integer is signed.
+ * @return The converted floating point number.
+ */
+ def fromInteger(int: UInt, sign: Bool): Fp32 = {
+ val intLength = int.getWidth
+ assert(intLength == 32, "fromInteger currently only supports 32-bit ints")
+ val floatSign = sign & int(intLength - 1)
+ val absInt = Mux(floatSign, (~int) + 1.U, int)
+
+ val preround = Wire(UInt(25.W))
+ val leadingZeros = Clz(absInt)
+ if (intLength >= 25) {
+ preround := (absInt << leadingZeros)(intLength - 1, intLength - 25)
+ } else {
+ val shift = (25 - intLength).U + leadingZeros
+ preround := (absInt << leadingZeros)(24, 0)
+ }
+
+ val zero = (preround === 0.U)
+ // TODO(derekjchow): Rounding mode
+ val rounded = preround +& (~floatSign) // 26 bits
+ val mantissa = Mux(rounded(25), rounded(24, 2), rounded(23, 1))
+ val exponent = Mux(
+ zero, 0.U(8.W), (intLength + 127 - 1).U(8.W) - leadingZeros)
+
+ Fp32(floatSign, exponent, mantissa)
+ }
+
def Zero(sign: Bool): Fp32 = {
val fp = Wire(new Fp32)
fp.mantissa := 0.U
diff --git a/hdl/chisel/src/common/FpTest.scala b/hdl/chisel/src/common/FpTest.scala
index 70bf7b8..83bf181 100644
--- a/hdl/chisel/src/common/FpTest.scala
+++ b/hdl/chisel/src/common/FpTest.scala
@@ -34,6 +34,24 @@
io.is_nan := fp.isNan()
}
+class Fp32CvtuTester extends Module {
+ val io = IO(new Bundle {
+ val int = Input(UInt(32.W))
+ val fp = Output(new Fp32)
+ })
+
+ io.fp := Fp32.fromInteger(io.int, false.B)
+}
+
+class Fp32CvtTester extends Module {
+ val io = IO(new Bundle {
+ val int = Input(SInt(32.W))
+ val fp = Output(new Fp32)
+ })
+
+ io.fp := Fp32.fromInteger(io.int.asUInt, true.B)
+}
+
class FpSpec extends AnyFreeSpec with ChiselScalatestTester {
"Zero" in {
test(new Fp32Tester()) { dut =>
@@ -75,4 +93,24 @@
assert(dut.io.is_nan.peekInt() == 1)
}
}
+
+ "Convert UInt to Float" in {
+ test(new Fp32CvtuTester) { dut =>
+ for (i <- 0 until 20000000 by 3000) {
+ dut.io.int.poke(i)
+ dut.clock.step()
+ assertResult(i.toFloat) { PeekFloat(dut.io.fp) }
+ }
+ }
+ }
+
+ "Convert SInt to Float" in {
+ test(new Fp32CvtTester) { dut =>
+ for (i <- -20000001 until 20000000 by 3000) {
+ dut.io.int.poke(i)
+ dut.clock.step()
+ assertResult(i.toFloat) { PeekFloat(dut.io.fp) }
+ }
+ }
+ }
}
\ No newline at end of file
diff --git a/hdl/chisel/src/common/FpTestUtils.scala b/hdl/chisel/src/common/FpTestUtils.scala
new file mode 100644
index 0000000..729d813
--- /dev/null
+++ b/hdl/chisel/src/common/FpTestUtils.scala
@@ -0,0 +1,85 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import chisel3._
+import chisel3.util._
+import chiseltest._
+
+/** Effectively "reinterpret_casts" a float32 into a BigInt. BigInt is used
+ * because there's no easy uint32 type in scala, but BigInt works well enough
+ * for Chisel use cases.
+ * @param f A scala float.
+ * @return A BigInt (representing uint32) bit interpretation of the float.
+ */
+object Float2BigInt {
+ def apply(f: Float): BigInt = {
+ val abs = f.abs
+ var int = BigInt(java.lang.Float.floatToIntBits(abs))
+ if (f < 0) {
+ int += (BigInt(1) << 31)
+ }
+ int
+ }
+}
+
+/** Breaks down a float32 into it's sign, exponent and mantissa.
+ * @param f A scala float.
+ * @return A tuple of the sign, exponent and mantissa.
+ */
+object Float2Bits {
+ def apply(f: Float): (Boolean, Int, Int) = {
+ val abs = f.abs
+ var int = java.lang.Float.floatToIntBits(abs)
+
+ val sign: Boolean = (f < 0)
+ val exponent: Int = int >> 23
+ val mantissa: Int = int & ((1 << 23) - 1)
+
+ (sign, exponent, mantissa)
+ }
+}
+
+/** Pokes a float.
+ * @param dut The float input.
+ * @param f A scala float.
+ */
+object PokeFloat {
+ def apply(dut: Fp32, f: Float) = {
+ val int = java.lang.Float.floatToRawIntBits(f)
+ val sign = if (int < 0) { true.B } else { false.B }
+ val mantissa = int & 0x7FFFFF
+ val exponent = (int >> 23) & 0xFF
+
+ dut.sign.poke(sign)
+ dut.mantissa.poke(mantissa)
+ dut.exponent.poke(exponent)
+ }
+}
+
+/** Peeks a float.
+ * @param dut The float input.
+ * @param f A scala float.
+ */
+object PeekFloat {
+ def apply(dut: Fp32): Float = {
+ val sign = dut.sign.peekInt().toInt
+ val exponent = dut.exponent.peekInt().toInt
+ val mantissa = dut.mantissa.peekInt().toInt
+
+ val i = (sign << 31) + (exponent << 23) + mantissa
+ java.lang.Float.intBitsToFloat(i)
+ }
+}
\ No newline at end of file
diff --git a/hdl/chisel/src/common/Library.scala b/hdl/chisel/src/common/Library.scala
index d722e89..381519f 100644
--- a/hdl/chisel/src/common/Library.scala
+++ b/hdl/chisel/src/common/Library.scala
@@ -34,4 +34,10 @@
result.bits := bits
result
}
+}
+
+object Clz {
+ def apply(bits: UInt): UInt = {
+ PriorityEncoder(Cat(1.U(1.W), Reverse(bits)))
+ }
}
\ No newline at end of file