Add conversions for FP8 types (F8E5M2 and F8E4M3) (#16374)
This PR barely increases code size because the existing conversion code
was already essentially generic, so the F8E5M2 type at least comes for
free. F8E4M3 is a bit trickier: it has no infinities and reclaims that
encoding space to represent additional, larger finite values.
diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h
index c3c29f9..58dd88d 100644
--- a/runtime/src/iree/base/internal/math.h
+++ b/runtime/src/iree/base/internal/math.h
@@ -264,7 +264,7 @@
}
//==============================================================================
-// FP16 and BFloat16 support
+// FP16, BFloat16 and FP8 support
//==============================================================================
// NOTE: We used to have code here using built-in _Float16 type support.
@@ -273,47 +273,61 @@
// in slow generic fallbacks or test code, and we weren't able to use
// a builtin for bf16 anyway.
-#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, bits, ebits) \
+// Define some helper constants for working with a floating-point format with
+// the given number of {exponent,mantissa} bits.
+#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits) \
const int prefix##exp_bits IREE_ATTRIBUTE_UNUSED = ebits; \
- const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = \
- bits - 1 - prefix##exp_bits; \
- const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = bits - 1; \
+ const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = mbits; \
+ const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = ebits + mbits; \
const int prefix##exp_shift IREE_ATTRIBUTE_UNUSED = prefix##mantissa_bits; \
const int prefix##sign_mask IREE_ATTRIBUTE_UNUSED = 1u \
<< prefix##sign_shift; \
const int prefix##mantissa_mask IREE_ATTRIBUTE_UNUSED = \
(1u << prefix##exp_shift) - 1; \
const int prefix##exp_mask IREE_ATTRIBUTE_UNUSED = \
- (1u << prefix##sign_shift) - (1u << prefix##exp_shift);
+ (1u << prefix##sign_shift) - (1u << prefix##exp_shift); \
+ const int prefix##exp_bias IREE_ATTRIBUTE_UNUSED = \
+ (1u << (prefix##exp_bits - 1)) - 1;
-static inline float iree_math_generic_fp16_to_f32(uint16_t f16_value,
- int exp_bits) {
- IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
- IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
- const uint32_t f16_sign = f16_value & f16_sign_mask;
- const uint32_t f32_sign = f16_sign << (f32_sign_shift - f16_sign_shift);
- const uint32_t f16_exp = f16_value & f16_exp_mask;
- const uint32_t f16_mantissa = f16_value & f16_mantissa_mask;
+// Generic conversion from any less-than-32-bit floating-point format to f32.
+// The `src` value is typed as a uint32_t for genericity but occupies only the
+// bottom (1 + exp_bits + mantissa_bits) bits. The upper bits of `src` are
+// unused.
+static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits,
+ int mantissa_bits,
+ bool have_infinity) {
+ IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits)
+ IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
+ const uint32_t src_sign = src & src_sign_mask;
+ const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift);
+ const uint32_t src_exp = src & src_exp_mask;
+ const uint32_t src_mantissa = src & src_mantissa_mask;
uint32_t f32_exp = 0;
uint32_t f32_mantissa = 0;
- if (f16_exp == f16_exp_mask) {
+ if (src_exp == src_exp_mask) {
+ // No infinities => more large finite values.
+ if (!have_infinity && src_mantissa != src_mantissa_mask) {
+ float sign = (src & src_sign_mask) ? -1.0f : 1.0f;
+ return sign * 2 * (1u << src_exp_bits) *
+ ((1u << src_mantissa_bits) + src_mantissa);
+ }
// NaN or Inf case.
f32_exp = f32_exp_mask;
- if (f16_mantissa) {
+ if (src_mantissa) {
// NaN. Generate a quiet NaN.
f32_mantissa = f32_mantissa_mask;
} else {
// Inf. Leave zero mantissa.
}
- } else if (f16_exp == 0) {
+ } else if (src_exp == 0) {
// Zero or subnormal. Generate zero. Leave zero mantissa.
} else {
// Normal finite value.
- int arithmetic_f16_exp = f16_exp >> f16_exp_shift;
- int arithmetic_f32_exp = arithmetic_f16_exp + (1 << (f32_exp_bits - 1)) -
- (1 << (f16_exp_bits - 1));
+ int arithmetic_src_exp = src_exp >> src_exp_shift;
+ int arithmetic_f32_exp = arithmetic_src_exp + (1 << (f32_exp_bits - 1)) -
+ (1 << (src_exp_bits - 1));
f32_exp = arithmetic_f32_exp << f32_exp_shift;
- f32_mantissa = f16_mantissa << (f32_mantissa_bits - f16_mantissa_bits);
+ f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits);
}
const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa;
float f32_value;
@@ -321,24 +335,28 @@
return f32_value;
}
-static inline uint16_t iree_math_f32_to_generic_fp16(float value,
- int exp_bits) {
- IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
- IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
+// Generic conversion from f32 to any less-than-32-bit floating-point format,
+// rounding to nearest-even. The return value is typed as a uint32_t for
+// genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits.
+// The upper bits of the return value are unused.
+static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even(
+ float value, int exp_bits, int mantissa_bits, bool have_infinity) {
+ IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits)
+ IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
uint32_t u32_value;
memcpy(&u32_value, &value, sizeof value);
const uint32_t f32_sign = u32_value & f32_sign_mask;
- const uint32_t f16_sign = f32_sign >> (f32_sign_shift - f16_sign_shift);
+ const uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift);
const uint32_t f32_exp = u32_value & f32_exp_mask;
const uint32_t f32_mantissa = u32_value & f32_mantissa_mask;
- uint32_t f16_exp = 0;
- uint32_t f16_mantissa = 0;
- if (f32_exp == f32_exp_mask) {
+ uint32_t dst_exp = 0;
+ uint32_t dst_mantissa = 0;
+ if (f32_exp >= f32_exp_mask) {
// NaN or Inf case.
- f16_exp = f16_exp_mask;
- if (f32_mantissa) {
+ dst_exp = dst_exp_mask;
+ if (f32_mantissa || !have_infinity) {
// NaN. Generate a quiet NaN.
- f16_mantissa = f16_mantissa_mask;
+ dst_mantissa = dst_mantissa_mask;
} else {
// Inf. Leave zero mantissa.
}
@@ -346,18 +364,24 @@
// Zero or subnormal. Generate zero. Leave zero mantissa.
} else {
// Normal finite value.
- int arithmetic_exp = (f32_exp >> f32_exp_shift) - (1 << (f32_exp_bits - 1));
- if (arithmetic_exp >= (1 << (f16_exp_bits - 1))) {
+ int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias;
+ // Test if the exponent is too large for the destination type. If
+ // the destination type does not have infinities, that frees up the
+ // max exponent value for additional finite values.
+ if (arithmetic_exp > (1 << (dst_exp_bits - 1)) - have_infinity) {
// Overflow. Generate Inf. Leave zero mantissa.
- f16_exp = f16_exp_mask;
- } else if (arithmetic_exp < -(1 << (f16_exp_bits - 1))) {
+ dst_exp = dst_exp_mask;
+ if (!have_infinity) {
+ // Generate NaN.
+ dst_mantissa = dst_mantissa_mask;
+ }
+ } else if (arithmetic_exp < -(1 << (dst_exp_bits - 1))) {
// Underflow. Generate zero. Leave zero mantissa.
- f16_exp = 0;
+ dst_exp = 0;
} else {
// Normal case.
// Implement round-to-nearest-even, by adding a bias before truncating.
- // truncating.
- int even_bit = 1u << (f32_mantissa_bits - f16_mantissa_bits);
+ int even_bit = 1u << (f32_mantissa_bits - dst_mantissa_bits);
int odd_bit = even_bit >> 1;
uint32_t biased_f32_mantissa =
f32_mantissa +
@@ -377,52 +401,56 @@
biased_f32_mantissa = 0;
++arithmetic_exp;
}
- // The exponent increment in the above if() branch may cause overflow.
- // This is exercised by converting 65520.0f from f32 to f16. No special
- // handling is needed for this case: the above if() branch already set
- // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
- // needed for infinite values.
- f16_exp = (arithmetic_exp + (1 << (f16_exp_bits - 1))) << f16_exp_shift;
- f16_mantissa =
- biased_f32_mantissa >> (f32_mantissa_bits - f16_mantissa_bits);
+ // In the !have_infinity case, arithmetic_exp might have been the top
+ // value already, so incrementing it may have overflowed it.
+ if (!have_infinity && arithmetic_exp > (1 << (dst_exp_bits - 1))) {
+ dst_exp = dst_exp_mask;
+ dst_mantissa = dst_mantissa_mask;
+ } else {
+ // The exponent increment in the above if() branch may cause overflow.
+ // This is exercised by converting 65520.0f from f32 to f16. No special
+ // handling is needed for this case: the above if() branch already set
+ // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
+ // needed for infinite values.
+ dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift;
+ dst_mantissa =
+ biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits);
+ }
}
}
- uint16_t f16_value = f16_sign | f16_exp | f16_mantissa;
- return f16_value;
+ uint32_t dst_value = dst_sign | dst_exp | dst_mantissa;
+ return dst_value;
}
-// Converts a fp16 value to a 32-bit C `float`.
-static inline float iree_math_f16_to_f32(uint16_t f16_value) {
- return iree_math_generic_fp16_to_f32(f16_value, 5);
-}
+#define IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(NAME, INT_TYPE, EXP_BITS, \
+ MANTISSA_BITS, HAVE_INFINITY) \
+ /* Converts a NAME value to a 32-bit C `float`. */ \
+ static inline float iree_math_##NAME##_to_f32(INT_TYPE src) { \
+ return iree_math_make_f32_from_bits(src, EXP_BITS, MANTISSA_BITS, \
+ HAVE_INFINITY); \
+ } \
+ /* Truncates a 32-bit C `float`, rounding to nearest even. */ \
+ static inline INT_TYPE iree_math_f32_to_##NAME(float value) { \
+ return iree_math_truncate_f32_to_bits_rounding_to_nearest_even( \
+ value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY); \
+ } \
+ /* Round-trip f32->f32 rounding via the narrow float type */ \
+ static inline float iree_math_round_to_nearest_##NAME(float value) { \
+ return iree_math_##NAME##_to_f32(iree_math_f32_to_##NAME(value)); \
+ }
-// Converts a 32-bit C `float` value to a fp16 value, rounding to nearest
-// even.
-static inline uint16_t iree_math_f32_to_f16(float value) {
- return iree_math_f32_to_generic_fp16(value, 5);
-}
+// IEEE half-precision a.k.a. float16,
+// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true)
-// Rounds of 32-bit C `float` value to nearest 16-bit value and returns
-// 32-bit `float`
-static inline float iree_math_round_to_nearest_f16(float f32_value) {
- return iree_math_f16_to_f32(iree_math_f32_to_f16(f32_value));
-}
+// Bfloat16, https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true)
-// Converts a bfloat16 value to a 32-bit C `float`.
-static inline float iree_math_bf16_to_f32(uint16_t bf16_value) {
- return iree_math_generic_fp16_to_f32(bf16_value, 8);
-}
+// F8E5M2 type, https://arxiv.org/abs/2209.05433
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true)
-// Converts a 32-bit C `float` value to a bfloat16 value, rounding to nearest
-// even.
-static inline uint16_t iree_math_f32_to_bf16(float value) {
- return iree_math_f32_to_generic_fp16(value, 8);
-}
-
-// Rounds of 32-bit C `float` value to nearest bfloat16 value and returns
-// 32-bit `float`
-static inline float iree_math_round_to_nearest_bf16(float f32_value) {
- return iree_math_bf16_to_f32(iree_math_f32_to_bf16(f32_value));
-}
+// F8E4M3 type, https://arxiv.org/abs/2209.05433.
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3, uint8_t, 4, 3,
+ /*have_infinity=*/false)
#endif // IREE_BASE_INTERNAL_MATH_H_
diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc
index be191cc..b3548d7 100644
--- a/runtime/src/iree/base/internal/math_test.cc
+++ b/runtime/src/iree/base/internal/math_test.cc
@@ -169,7 +169,7 @@
TEST(F16ConversionTest, F32ToF16) {
constexpr float kF16Max = 65504.f;
- constexpr float kF16Min = 0.0000610351563f;
+ constexpr float kF16Min = 1.f / 16384.f;
// Within range, normal truncation.
EXPECT_EQ(0x3400, iree_math_f32_to_f16(0.25f));
EXPECT_EQ(0xd646, iree_math_f32_to_f16(-100.375f));
@@ -201,7 +201,7 @@
TEST(F16ConversionTest, F32ToF16ToF32) {
constexpr float kF16Max = 65504.f;
- constexpr float kF16Min = 0.0000610351563f;
+ constexpr float kF16Min = 1.f / 16384.f;
// Within range, should just round.
EXPECT_EQ(0.25f, iree_math_f16_to_f32(iree_math_f32_to_f16(0.25f)));
EXPECT_EQ(-0.25f, iree_math_f16_to_f32(iree_math_f32_to_f16(-0.25f)));
@@ -257,6 +257,10 @@
EXPECT_NE(nan, nan);
}
+//==============================================================================
+// Bfloat16 support
+//==============================================================================
+
TEST(BF16ConversionTest, F32ToBF16) {
// Within range, normal truncation.
EXPECT_EQ(0x3e80, iree_math_f32_to_bf16(0.25f));
@@ -319,4 +323,204 @@
EXPECT_NE(nan, nan);
}
+//==============================================================================
+// F8E5M2 support
+//==============================================================================
+
+TEST(F8E5M2ConversionTest, F32ToF8E5M2) {
+ // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+ constexpr float kF8E5M2Max = 57344.f;
+ constexpr float kF8E5M2Min = 1.f / 16384.f;
+ // Within range, normal truncation.
+ EXPECT_EQ(0x34, iree_math_f32_to_f8e5m2(0.25f));
+ EXPECT_EQ(0xd6, iree_math_f32_to_f8e5m2(-100.375f));
+ EXPECT_EQ(0x7A, iree_math_f32_to_f8e5m2(49152.f));
+ EXPECT_EQ(0xFA, iree_math_f32_to_f8e5m2(-49152.f));
+ EXPECT_EQ(0x7B, iree_math_f32_to_f8e5m2(kF8E5M2Max));
+ EXPECT_EQ(0xFB, iree_math_f32_to_f8e5m2(-kF8E5M2Max));
+ EXPECT_EQ(0x04, iree_math_f32_to_f8e5m2(kF8E5M2Min));
+ EXPECT_EQ(0x84, iree_math_f32_to_f8e5m2(-kF8E5M2Min));
+ // Infinity
+ EXPECT_EQ(0x7c, iree_math_f32_to_f8e5m2(INFINITY));
+ EXPECT_EQ(0xfc, iree_math_f32_to_f8e5m2(-INFINITY));
+ // Overflow
+ EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(FLT_MAX));
+ EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-FLT_MAX));
+ // Important case to test: overflow due to rounding to nearest-even of 61440
+ // to 65536.
+ EXPECT_EQ(0x7B, iree_math_f32_to_f8e5m2(61439.f));
+ EXPECT_EQ(0xFB, iree_math_f32_to_f8e5m2(-61439.f));
+ EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(61440.f));
+ EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-61440.f));
+ EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(65536.f));
+ EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-65536.f));
+ // Underflow
+ EXPECT_EQ(0, iree_math_f32_to_f8e5m2(FLT_MIN));
+ EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2(-FLT_MIN));
+ // Denormals may or may not get flushed to zero. Accept both ways.
+ uint16_t positive_denormal = iree_math_f32_to_f8e5m2(kF8E5M2Min / 2);
+ EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x02);
+ uint16_t negative_denormal = iree_math_f32_to_f8e5m2(-kF8E5M2Min / 2);
+ EXPECT_TRUE(negative_denormal == 0x80 || negative_denormal == 0x82);
+}
+
+TEST(F8E5M2ConversionTest, F32ToF8E5M2ToF32) {
+ // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+ constexpr float kF8E5M2Max = 57344.f;
+ constexpr float kF8E5M2Min = 1.f / 16384.f;
+ // Within range, should just round.
+ EXPECT_EQ(0.25f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(0.25f)));
+ EXPECT_EQ(-0.25f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-0.25f)));
+ EXPECT_EQ(96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(100.375f)));
+ EXPECT_EQ(-96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-100.375f)));
+ EXPECT_EQ(96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(96.f)));
+ EXPECT_EQ(-96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-96.f)));
+ EXPECT_EQ(kF8E5M2Max,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Max)));
+ EXPECT_EQ(-kF8E5M2Max,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-kF8E5M2Max)));
+ EXPECT_EQ(kF8E5M2Min,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Min)));
+ EXPECT_EQ(-kF8E5M2Min,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-kF8E5M2Min)));
+ // Powers of two should always be exactly representable across the
+ // exponent range.
+ EXPECT_EQ(32768.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(32768.f)));
+ EXPECT_EQ(-32768.f,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-32768.f)));
+ // Overflow
+ EXPECT_EQ(INFINITY,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(FLT_MAX)));
+ EXPECT_EQ(-INFINITY,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-FLT_MAX)));
+ EXPECT_GT(kF8E5M2Max + 1.f,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Max + 1.f)));
+ // Underflow
+ EXPECT_EQ(0.0f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(FLT_MIN)));
+ EXPECT_EQ(0.0f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-FLT_MIN)));
+ // Denormals may or may not get flushed to zero. Accept both ways.
+ float positive_denormal =
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Min / 2));
+ EXPECT_TRUE(positive_denormal == 0.0f ||
+ positive_denormal == 3.05175781e-05f);
+ // Inf and Nan
+ EXPECT_EQ(INFINITY,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(INFINITY)));
+ EXPECT_EQ(-INFINITY,
+ iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-INFINITY)));
+ // Check that the result is a Nan with nan != nan.
+ float nan = iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(NAN));
+ EXPECT_NE(nan, nan);
+}
+
+//==============================================================================
+// F8E4M3 support
+//==============================================================================
+
+TEST(F8E4M3ConversionTest, F32ToF8E4M3) {
+ // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+ // The F8E4M3 format is special: it has no infinities, and has some larger
+ // finite values instead.
+ constexpr float kF8E4M3Max = 448.f;
+ constexpr float kF8E4M3Min = 1.f / 64.f;
+ // Within range, normal truncation.
+ EXPECT_EQ(0x28, iree_math_f32_to_f8e4m3(0.25f));
+ EXPECT_EQ(0xED, iree_math_f32_to_f8e4m3(-100.375f));
+ // Extra large finite values thanks to not having infinities.
+ EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3(256.0f));
+ EXPECT_EQ(0x79, iree_math_f32_to_f8e4m3(288.0f));
+ EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(320.0f));
+ EXPECT_EQ(0x7B, iree_math_f32_to_f8e4m3(352.0f));
+ EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3(384.0f));
+ EXPECT_EQ(0x7D, iree_math_f32_to_f8e4m3(416.0f));
+ EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3(kF8E4M3Max));
+ EXPECT_EQ(0xFE, iree_math_f32_to_f8e4m3(-kF8E4M3Max));
+ // Min normal values.
+ EXPECT_EQ(0x08, iree_math_f32_to_f8e4m3(kF8E4M3Min));
+ EXPECT_EQ(0x88, iree_math_f32_to_f8e4m3(-kF8E4M3Min));
+ // Infinity
+ EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(INFINITY));
+ EXPECT_EQ(0xfF, iree_math_f32_to_f8e4m3(-INFINITY));
+ // Overflow
+ EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(FLT_MAX));
+ EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-FLT_MAX));
+ // Test some round-to-nearest-even behavior.
+ EXPECT_EQ(0x70, iree_math_f32_to_f8e4m3(136.0f));
+ EXPECT_EQ(0x72, iree_math_f32_to_f8e4m3(152.0f));
+ EXPECT_EQ(0x72, iree_math_f32_to_f8e4m3(168.0f));
+ EXPECT_EQ(0x74, iree_math_f32_to_f8e4m3(184.0f));
+ EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3(272.0f));
+ EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(304.0f));
+ EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(336.0f));
+ EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3(368.0f));
+ // Important case to test: overflow due to rounding to nearest-even of 465
+ // to 512, while 464 gets rounded to nearest-even 448, not overflowing.
+ EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3(464.f));
+ EXPECT_EQ(0xFE, iree_math_f32_to_f8e4m3(-464.f));
+ EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(465.f));
+ EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-465.f));
+ // Largest float value in the same exponent bucket, a tricky case.
+ EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(511.f));
+ EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-511.f));
+ // Underflow
+ EXPECT_EQ(0, iree_math_f32_to_f8e4m3(FLT_MIN));
+ EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3(-FLT_MIN));
+ // Denormals may or may not get flushed to zero. Accept both ways.
+ uint8_t positive_denormal = iree_math_f32_to_f8e4m3(kF8E4M3Min / 2);
+ EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x04);
+ uint8_t negative_denormal = iree_math_f32_to_f8e4m3(-kF8E4M3Min / 2);
+ EXPECT_TRUE(negative_denormal == 0x80 || negative_denormal == 0x84);
+}
+
+TEST(F8E4M3ConversionTest, F32ToF8E4M3ToF32) {
+ // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+ // The F8E4M3 format is special: it has no infinities, and has some larger
+ // finite values instead.
+ constexpr float kF8E4M3Max = 448.f;
+ constexpr float kF8E4M3Min = 1.f / 64.f;
+ // Within range, should just round.
+ EXPECT_EQ(0.25f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(0.25f)));
+ EXPECT_EQ(-0.25f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-0.25f)));
+ EXPECT_EQ(104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(100.375f)));
+ EXPECT_EQ(-104.f,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-100.375f)));
+ EXPECT_EQ(104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(100.4f)));
+ EXPECT_EQ(-104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-100.4f)));
+ EXPECT_EQ(kF8E4M3Max,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Max)));
+ EXPECT_EQ(-kF8E4M3Max,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-kF8E4M3Max)));
+ EXPECT_EQ(kF8E4M3Min,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Min)));
+ EXPECT_EQ(-kF8E4M3Min,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-kF8E4M3Min)));
+ // Powers of two should always be exactly representable across the
+ // exponent range.
+ EXPECT_EQ(256.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(256.f)));
+ EXPECT_EQ(-256.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-256.f)));
+ // Overflow
+ EXPECT_TRUE(
+ std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(FLT_MAX))));
+ EXPECT_TRUE(
+ std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-FLT_MAX))));
+ EXPECT_GT(kF8E4M3Max + 1.f,
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Max + 1.f)));
+ // Underflow
+ EXPECT_EQ(0.0f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(FLT_MIN)));
+ EXPECT_EQ(0.0f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-FLT_MIN)));
+ // Denormals may or may not get flushed to zero. Accept both ways.
+ float positive_denormal =
+ iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Min / 2));
+ EXPECT_TRUE(positive_denormal == 0.0f ||
+ positive_denormal == 3.05175781e-05f);
+ // Inf and Nan
+ EXPECT_TRUE(
+ std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(INFINITY))));
+ EXPECT_TRUE(
+ std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-INFINITY))));
+ // Check that the result is a Nan with nan != nan.
+ float nan = iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(NAN));
+ EXPECT_NE(nan, nan);
+}
+
} // namespace