Add utility function to convert integers to Fp32. Change-Id: I473f86c3db8e25868f5ab121623032b7b3359241
diff --git a/hdl/chisel/src/common/BUILD b/hdl/chisel/src/common/BUILD index 2805f14..1b5745b 100644 --- a/hdl/chisel/src/common/BUILD +++ b/hdl/chisel/src/common/BUILD
@@ -16,6 +16,14 @@ "chisel_cc_library", "chisel_library", "chisel_test") chisel_library( + name = "library", + srcs = [ + "Library.scala", + ], + visibility = ["//visibility:public"], +) + +chisel_library( name = "common", srcs = [ "FifoXe.scala", @@ -23,10 +31,12 @@ "FifoIxO.scala", "Fifo.scala", "IDiv.scala", - "Library.scala", "MathUtil.scala", "Slice.scala", ], + deps = [ + ":library", + ], visibility = ["//visibility:public"], ) @@ -35,6 +45,21 @@ srcs = [ "Fp.scala", ], + deps = [ + ":library", + ], + visibility = ["//visibility:public"], +) + +chisel_library( + name = "fp_test_utils", + srcs = [ + "FpTestUtils.scala", + ], + deps = [ + "@edu_berkeley_cs_chiseltest//jar", + ":fp", + ], visibility = ["//visibility:public"], ) @@ -45,6 +70,7 @@ ], deps = [ ":fp", + ":fp_test_utils", ], )
diff --git a/hdl/chisel/src/common/Fp.scala b/hdl/chisel/src/common/Fp.scala index a75201d..bc4cbe6 100644 --- a/hdl/chisel/src/common/Fp.scala +++ b/hdl/chisel/src/common/Fp.scala
@@ -73,6 +73,36 @@ fp } + /** "static_cast" an integer into a fp32 number. + * @param word The integer to convert. + * @param sign If the integer is signed. + * @return The converted floating point number. + */ + def fromInteger(int: UInt, sign: Bool): Fp32 = { + val intLength = int.getWidth + assert(intLength == 32, "fromInteger currently only supports 32-bit ints") + val floatSign = sign & int(intLength - 1) + val absInt = Mux(floatSign, (~int) + 1.U, int) + + val preround = Wire(UInt(25.W)) + val leadingZeros = Clz(absInt) + if (intLength >= 25) { + preround := (absInt << leadingZeros)(intLength - 1, intLength - 25) + } else { + val shift = (25 - intLength).U + leadingZeros + preround := (absInt << leadingZeros)(24, 0) + } + + val zero = (preround === 0.U) + // TODO(derekjchow): Rounding mode + val rounded = preround +& (~floatSign) // 26 bits + val mantissa = Mux(rounded(25), rounded(24, 2), rounded(23, 1)) + val exponent = Mux( + zero, 0.U(8.W), (intLength + 127 - 1).U(8.W) - leadingZeros) + + Fp32(floatSign, exponent, mantissa) + } + def Zero(sign: Bool): Fp32 = { val fp = Wire(new Fp32) fp.mantissa := 0.U
diff --git a/hdl/chisel/src/common/FpTest.scala b/hdl/chisel/src/common/FpTest.scala index 70bf7b8..83bf181 100644 --- a/hdl/chisel/src/common/FpTest.scala +++ b/hdl/chisel/src/common/FpTest.scala
@@ -34,6 +34,24 @@ io.is_nan := fp.isNan() } +class Fp32CvtuTester extends Module { + val io = IO(new Bundle { + val int = Input(UInt(32.W)) + val fp = Output(new Fp32) + }) + + io.fp := Fp32.fromInteger(io.int, false.B) +} + +class Fp32CvtTester extends Module { + val io = IO(new Bundle { + val int = Input(SInt(32.W)) + val fp = Output(new Fp32) + }) + + io.fp := Fp32.fromInteger(io.int.asUInt, true.B) +} + class FpSpec extends AnyFreeSpec with ChiselScalatestTester { "Zero" in { test(new Fp32Tester()) { dut => @@ -75,4 +93,24 @@ assert(dut.io.is_nan.peekInt() == 1) } } + + "Convert UInt to Float" in { + test(new Fp32CvtuTester) { dut => + for (i <- 0 until 20000000 by 3000) { + dut.io.int.poke(i) + dut.clock.step() + assertResult(i.toFloat) { PeekFloat(dut.io.fp) } + } + } + } + + "Convert SInt to Float" in { + test(new Fp32CvtTester) { dut => + for (i <- -20000001 until 20000000 by 3000) { + dut.io.int.poke(i) + dut.clock.step() + assertResult(i.toFloat) { PeekFloat(dut.io.fp) } + } + } + } } \ No newline at end of file
diff --git a/hdl/chisel/src/common/FpTestUtils.scala b/hdl/chisel/src/common/FpTestUtils.scala new file mode 100644 index 0000000..729d813 --- /dev/null +++ b/hdl/chisel/src/common/FpTestUtils.scala
@@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import chisel3._ +import chisel3.util._ +import chiseltest._ + +/** Effectively "reinterpret_casts" a float32 into a BigInt. BigInt is used + * because there's no easy uint32 type in scala, but BigInt works well enough + * for Chisel use cases. + * @param f A scala float. + * @return A BigInt (representing uint32) bit interpretation of the float. + */ +object Float2BigInt { + def apply(f: Float): BigInt = { + val abs = f.abs + var int = BigInt(java.lang.Float.floatToIntBits(abs)) + if (f < 0) { + int += (BigInt(1) << 31) + } + int + } +} + +/** Breaks down a float32 into it's sign, exponent and mantissa. + * @param f A scala float. + * @return A tuple of the sign, exponent and mantissa. + */ +object Float2Bits { + def apply(f: Float): (Boolean, Int, Int) = { + val abs = f.abs + var int = java.lang.Float.floatToIntBits(abs) + + val sign: Boolean = (f < 0) + val exponent: Int = int >> 23 + val mantissa: Int = int & ((1 << 23) - 1) + + (sign, exponent, mantissa) + } +} + +/** Pokes a float. + * @param dut The float input. + * @param f A scala float. + */ +object PokeFloat { + def apply(dut: Fp32, f: Float) = { + val int = java.lang.Float.floatToRawIntBits(f) + val sign = if (int < 0) { true.B } else { false.B } + val mantissa = int & 0x7FFFFF + val exponent = (int >> 23) & 0xFF + + dut.sign.poke(sign) + dut.mantissa.poke(mantissa) + dut.exponent.poke(exponent) + } +} + +/** Peeks a float. + * @param dut The float input. + * @param f A scala float. + */ +object PeekFloat { + def apply(dut: Fp32): Float = { + val sign = dut.sign.peekInt().toInt + val exponent = dut.exponent.peekInt().toInt + val mantissa = dut.mantissa.peekInt().toInt + + val i = (sign << 31) + (exponent << 23) + mantissa + java.lang.Float.intBitsToFloat(i) + } +} \ No newline at end of file
diff --git a/hdl/chisel/src/common/Library.scala b/hdl/chisel/src/common/Library.scala index d722e89..381519f 100644 --- a/hdl/chisel/src/common/Library.scala +++ b/hdl/chisel/src/common/Library.scala
@@ -34,4 +34,10 @@ result.bits := bits result } +} + +object Clz { + def apply(bits: UInt): UInt = { + PriorityEncoder(Cat(1.U(1.W), Reverse(bits))) + } } \ No newline at end of file