Add conversions for FP8 types (F8E5M2 and F8E4M3) (#16374)

This PR adds almost no code because the existing conversion code was
already essentially generic, so the F8E5M2 type at least comes for
free. F8E4M3 is a bit trickier because it has no infinities and reclaims
that encoding space to represent extra-large finite values.
diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h
index c3c29f9..58dd88d 100644
--- a/runtime/src/iree/base/internal/math.h
+++ b/runtime/src/iree/base/internal/math.h
@@ -264,7 +264,7 @@
 }
 
 //==============================================================================
-// FP16 and BFloat16 support
+// FP16, BFloat16 and FP8 support
 //==============================================================================
 
 // NOTE: We used to have code here using built-in _Float16 type support.
@@ -273,47 +273,61 @@
 // in slow generic fallbacks or test code, and we weren't able to use
 // a builtin for bf16 anyway.
 
-#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, bits, ebits)                   \
+// Defines helper constants (bit counts, shifts, masks and exponent bias) for a
+// floating-point format with the given number of {exponent,mantissa} bits.
+#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits)                  \
   const int prefix##exp_bits IREE_ATTRIBUTE_UNUSED = ebits;                  \
-  const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED =                    \
-      bits - 1 - prefix##exp_bits;                                           \
-  const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = bits - 1;             \
+  const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = mbits;             \
+  const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = ebits + mbits;        \
   const int prefix##exp_shift IREE_ATTRIBUTE_UNUSED = prefix##mantissa_bits; \
   const int prefix##sign_mask IREE_ATTRIBUTE_UNUSED = 1u                     \
                                                       << prefix##sign_shift; \
   const int prefix##mantissa_mask IREE_ATTRIBUTE_UNUSED =                    \
       (1u << prefix##exp_shift) - 1;                                         \
   const int prefix##exp_mask IREE_ATTRIBUTE_UNUSED =                         \
-      (1u << prefix##sign_shift) - (1u << prefix##exp_shift);
+      (1u << prefix##sign_shift) - (1u << prefix##exp_shift);                \
+  const int prefix##exp_bias IREE_ATTRIBUTE_UNUSED =                         \
+      (1u << (prefix##exp_bits - 1)) - 1;
 
-static inline float iree_math_generic_fp16_to_f32(uint16_t f16_value,
-                                                  int exp_bits) {
-  IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
-  IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
-  const uint32_t f16_sign = f16_value & f16_sign_mask;
-  const uint32_t f32_sign = f16_sign << (f32_sign_shift - f16_sign_shift);
-  const uint32_t f16_exp = f16_value & f16_exp_mask;
-  const uint32_t f16_mantissa = f16_value & f16_mantissa_mask;
+// Generic conversion from any less-than-32-bit floating-point format to f32.
+// The `src` value is typed as a uint32_t for genericity but occupies only the
+// bottom (1 + exp_bits + mantissa_bits) bits. The upper bits of `src` are
+// unused.
+static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits,
+                                                 int mantissa_bits,
+                                                 bool have_infinity) {
+  IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits)
+  IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
+  const uint32_t src_sign = src & src_sign_mask;
+  const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift);
+  const uint32_t src_exp = src & src_exp_mask;
+  const uint32_t src_mantissa = src & src_mantissa_mask;
   uint32_t f32_exp = 0;
   uint32_t f32_mantissa = 0;
-  if (f16_exp == f16_exp_mask) {
+  if (src_exp == src_exp_mask) {
+    // No infinities => more large finite values.
+    if (!have_infinity && src_mantissa != src_mantissa_mask) {
+      float sign = (src & src_sign_mask) ? -1.0f : 1.0f;
+      return sign * 2 * (1u << src_exp_bits) *
+             ((1u << src_mantissa_bits) + src_mantissa);
+    }
     // NaN or Inf case.
     f32_exp = f32_exp_mask;
-    if (f16_mantissa) {
+    if (src_mantissa) {
       // NaN. Generate a quiet NaN.
       f32_mantissa = f32_mantissa_mask;
     } else {
       // Inf. Leave zero mantissa.
     }
-  } else if (f16_exp == 0) {
+  } else if (src_exp == 0) {
     // Zero or subnormal. Generate zero. Leave zero mantissa.
   } else {
     // Normal finite value.
-    int arithmetic_f16_exp = f16_exp >> f16_exp_shift;
-    int arithmetic_f32_exp = arithmetic_f16_exp + (1 << (f32_exp_bits - 1)) -
-                             (1 << (f16_exp_bits - 1));
+    int arithmetic_src_exp = src_exp >> src_exp_shift;
+    int arithmetic_f32_exp = arithmetic_src_exp + (1 << (f32_exp_bits - 1)) -
+                             (1 << (src_exp_bits - 1));
     f32_exp = arithmetic_f32_exp << f32_exp_shift;
-    f32_mantissa = f16_mantissa << (f32_mantissa_bits - f16_mantissa_bits);
+    f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits);
   }
   const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa;
   float f32_value;
@@ -321,24 +335,28 @@
   return f32_value;
 }
 
-static inline uint16_t iree_math_f32_to_generic_fp16(float value,
-                                                     int exp_bits) {
-  IREE_MATH_FP_FORMAT_CONSTANTS(f16_, 16, exp_bits)
-  IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 32, 8)
+// Generic conversion from f32 to any less-than-32-bit floating-point format,
+// rounding to nearest-even. The return value is typed as a uint32_t for
+// genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits.
+// The upper bits of the return value are unused.
+static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even(
+    float value, int exp_bits, int mantissa_bits, bool have_infinity) {
+  IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits)
+  IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23)
   uint32_t u32_value;
   memcpy(&u32_value, &value, sizeof value);
   const uint32_t f32_sign = u32_value & f32_sign_mask;
-  const uint32_t f16_sign = f32_sign >> (f32_sign_shift - f16_sign_shift);
+  const uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift);
   const uint32_t f32_exp = u32_value & f32_exp_mask;
   const uint32_t f32_mantissa = u32_value & f32_mantissa_mask;
-  uint32_t f16_exp = 0;
-  uint32_t f16_mantissa = 0;
-  if (f32_exp == f32_exp_mask) {
+  uint32_t dst_exp = 0;
+  uint32_t dst_mantissa = 0;
+  if (f32_exp >= f32_exp_mask) {
     // NaN or Inf case.
-    f16_exp = f16_exp_mask;
-    if (f32_mantissa) {
+    dst_exp = dst_exp_mask;
+    if (f32_mantissa || !have_infinity) {
       // NaN. Generate a quiet NaN.
-      f16_mantissa = f16_mantissa_mask;
+      dst_mantissa = dst_mantissa_mask;
     } else {
       // Inf. Leave zero mantissa.
     }
@@ -346,18 +364,24 @@
     // Zero or subnormal. Generate zero. Leave zero mantissa.
   } else {
     // Normal finite value.
-    int arithmetic_exp = (f32_exp >> f32_exp_shift) - (1 << (f32_exp_bits - 1));
-    if (arithmetic_exp >= (1 << (f16_exp_bits - 1))) {
+    int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias;
+    // Test if the exponent is too large for the destination type. If
+    // the destination type does not have infinities, that frees up the
+    // max exponent value for additional finite values.
+    if (arithmetic_exp > (1 << (dst_exp_bits - 1)) - have_infinity) {
       // Overflow. Generate Inf. Leave zero mantissa.
-      f16_exp = f16_exp_mask;
-    } else if (arithmetic_exp < -(1 << (f16_exp_bits - 1))) {
+      dst_exp = dst_exp_mask;
+      if (!have_infinity) {
+        // Generate NaN.
+        dst_mantissa = dst_mantissa_mask;
+      }
+    } else if (arithmetic_exp < -(1 << (dst_exp_bits - 1))) {
       // Underflow. Generate zero. Leave zero mantissa.
-      f16_exp = 0;
+      dst_exp = 0;
     } else {
       // Normal case.
       // Implement round-to-nearest-even, by adding a bias before truncating.
-      // truncating.
-      int even_bit = 1u << (f32_mantissa_bits - f16_mantissa_bits);
+      int even_bit = 1u << (f32_mantissa_bits - dst_mantissa_bits);
       int odd_bit = even_bit >> 1;
       uint32_t biased_f32_mantissa =
           f32_mantissa +
@@ -377,52 +401,56 @@
         biased_f32_mantissa = 0;
         ++arithmetic_exp;
       }
-      // The exponent increment in the above if() branch may cause overflow.
-      // This is exercised by converting 65520.0f from f32 to f16. No special
-      // handling is needed for this case: the above if() branch already set
-      // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
-      // needed for infinite values.
-      f16_exp = (arithmetic_exp + (1 << (f16_exp_bits - 1))) << f16_exp_shift;
-      f16_mantissa =
-          biased_f32_mantissa >> (f32_mantissa_bits - f16_mantissa_bits);
+      // In the !have_infinity case, arithmetic_exp might have been the top
+      // value already, so incrementing it may have overflown it.
+      if (!have_infinity && arithmetic_exp > (1 << (dst_exp_bits - 1))) {
+        dst_exp = dst_exp_mask;
+        dst_mantissa = dst_mantissa_mask;
+      } else {
+        // The exponent increment in the above if() branch may cause overflow.
+        // This is exercised by converting 65520.0f from f32 to f16. No special
+        // handling is needed for this case: the above if() branch already set
+        // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as
+        // needed for infinite values.
+        dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift;
+        dst_mantissa =
+            biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits);
+      }
     }
   }
-  uint16_t f16_value = f16_sign | f16_exp | f16_mantissa;
-  return f16_value;
+  uint32_t dst_value = dst_sign | dst_exp | dst_mantissa;
+  return dst_value;
 }
 
-// Converts a fp16 value to a 32-bit C `float`.
-static inline float iree_math_f16_to_f32(uint16_t f16_value) {
-  return iree_math_generic_fp16_to_f32(f16_value, 5);
-}
+#define IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(NAME, INT_TYPE, EXP_BITS,     \
+                                          MANTISSA_BITS, HAVE_INFINITY) \
+  /* Converts a NAME value to a 32-bit C `float`. */                    \
+  static inline float iree_math_##NAME##_to_f32(INT_TYPE src) {         \
+    return iree_math_make_f32_from_bits(src, EXP_BITS, MANTISSA_BITS,   \
+                                        HAVE_INFINITY);                 \
+  }                                                                     \
+  /* Truncates a `float` to a NAME value, rounding to nearest even. */  \
+  static inline INT_TYPE iree_math_f32_to_##NAME(float value) {         \
+    return iree_math_truncate_f32_to_bits_rounding_to_nearest_even(     \
+        value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY);                 \
+  }                                                                     \
+  /* Round-trip f32->f32 rounding via the narrow float type */          \
+  static inline float iree_math_round_to_nearest_##NAME(float value) {  \
+    return iree_math_##NAME##_to_f32(iree_math_f32_to_##NAME(value));   \
+  }
 
-// Converts a 32-bit C `float` value to a fp16 value, rounding to nearest
-// even.
-static inline uint16_t iree_math_f32_to_f16(float value) {
-  return iree_math_f32_to_generic_fp16(value, 5);
-}
+// IEEE half-precision a.k.a. float16 (5 exponent bits, 10 mantissa bits),
+// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true)
 
-// Rounds of 32-bit C `float` value to nearest 16-bit value and returns
-// 32-bit `float`
-static inline float iree_math_round_to_nearest_f16(float f32_value) {
-  return iree_math_f16_to_f32(iree_math_f32_to_f16(f32_value));
-}
+// bfloat16 (bf16), https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true)
 
-// Converts a bfloat16 value to a 32-bit C `float`.
-static inline float iree_math_bf16_to_f32(uint16_t bf16_value) {
-  return iree_math_generic_fp16_to_f32(bf16_value, 8);
-}
+// F8E5M2 type (has infinities), https://arxiv.org/abs/2209.05433
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true)
 
-// Converts a 32-bit C `float` value to a bfloat16 value, rounding to nearest
-// even.
-static inline uint16_t iree_math_f32_to_bf16(float value) {
-  return iree_math_f32_to_generic_fp16(value, 8);
-}
-
-// Rounds of 32-bit C `float` value to nearest bfloat16 value and returns
-// 32-bit `float`
-static inline float iree_math_round_to_nearest_bf16(float f32_value) {
-  return iree_math_bf16_to_f32(iree_math_f32_to_bf16(f32_value));
-}
+// F8E4M3 type (no infinities), https://arxiv.org/abs/2209.05433.
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3, uint8_t, 4, 3,
+                                  /*have_infinity=*/false)
 
 #endif  // IREE_BASE_INTERNAL_MATH_H_
diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc
index be191cc..b3548d7 100644
--- a/runtime/src/iree/base/internal/math_test.cc
+++ b/runtime/src/iree/base/internal/math_test.cc
@@ -169,7 +169,7 @@
 
 TEST(F16ConversionTest, F32ToF16) {
   constexpr float kF16Max = 65504.f;
-  constexpr float kF16Min = 0.0000610351563f;
+  constexpr float kF16Min = 1.f / 16384.f;
   // Within range, normal truncation.
   EXPECT_EQ(0x3400, iree_math_f32_to_f16(0.25f));
   EXPECT_EQ(0xd646, iree_math_f32_to_f16(-100.375f));
@@ -201,7 +201,7 @@
 
 TEST(F16ConversionTest, F32ToF16ToF32) {
   constexpr float kF16Max = 65504.f;
-  constexpr float kF16Min = 0.0000610351563f;
+  constexpr float kF16Min = 1.f / 16384.f;
   // Within range, should just round.
   EXPECT_EQ(0.25f, iree_math_f16_to_f32(iree_math_f32_to_f16(0.25f)));
   EXPECT_EQ(-0.25f, iree_math_f16_to_f32(iree_math_f32_to_f16(-0.25f)));
@@ -257,6 +257,10 @@
   EXPECT_NE(nan, nan);
 }
 
+//==============================================================================
+// Bfloat16 support
+//==============================================================================
+
 TEST(BF16ConversionTest, F32ToBF16) {
   // Within range, normal truncation.
   EXPECT_EQ(0x3e80, iree_math_f32_to_bf16(0.25f));
@@ -319,4 +323,204 @@
   EXPECT_NE(nan, nan);
 }
 
+//==============================================================================
+// F8E5M2 support
+//==============================================================================
+
+TEST(F8E5M2ConversionTest, F32ToF8E5M2) {
+  // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+  constexpr float kF8E5M2Max = 57344.f;
+  constexpr float kF8E5M2Min = 1.f / 16384.f;
+  // Within range, normal truncation.
+  EXPECT_EQ(0x34, iree_math_f32_to_f8e5m2(0.25f));
+  EXPECT_EQ(0xd6, iree_math_f32_to_f8e5m2(-100.375f));
+  EXPECT_EQ(0x7A, iree_math_f32_to_f8e5m2(49152.f));
+  EXPECT_EQ(0xFA, iree_math_f32_to_f8e5m2(-49152.f));
+  EXPECT_EQ(0x7B, iree_math_f32_to_f8e5m2(kF8E5M2Max));
+  EXPECT_EQ(0xFB, iree_math_f32_to_f8e5m2(-kF8E5M2Max));
+  EXPECT_EQ(0x04, iree_math_f32_to_f8e5m2(kF8E5M2Min));
+  EXPECT_EQ(0x84, iree_math_f32_to_f8e5m2(-kF8E5M2Min));
+  // Infinity maps to the all-ones exponent with a zero mantissa.
+  EXPECT_EQ(0x7c, iree_math_f32_to_f8e5m2(INFINITY));
+  EXPECT_EQ(0xfc, iree_math_f32_to_f8e5m2(-INFINITY));
+  // Overflow: finite values too large for F8E5M2 become infinity.
+  EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(FLT_MAX));
+  EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-FLT_MAX));
+  // Important case to test: overflow due to rounding to nearest-even of 61440
+  // to 65536.
+  EXPECT_EQ(0x7B, iree_math_f32_to_f8e5m2(61439.f));
+  EXPECT_EQ(0xFB, iree_math_f32_to_f8e5m2(-61439.f));
+  EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(61440.f));
+  EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-61440.f));
+  EXPECT_EQ(0x7C, iree_math_f32_to_f8e5m2(65536.f));
+  EXPECT_EQ(0xFC, iree_math_f32_to_f8e5m2(-65536.f));
+  // Underflow: values too small for F8E5M2 become signed zero.
+  EXPECT_EQ(0, iree_math_f32_to_f8e5m2(FLT_MIN));
+  EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2(-FLT_MIN));
+  // Denormals may or may not get flushed to zero. Accept both ways.
+  uint8_t positive_denormal = iree_math_f32_to_f8e5m2(kF8E5M2Min / 2);
+  EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x02);
+  uint8_t negative_denormal = iree_math_f32_to_f8e5m2(-kF8E5M2Min / 2);
+  EXPECT_TRUE(negative_denormal == 0x80 || negative_denormal == 0x82);
+}
+
+TEST(F8E5M2ConversionTest, F32ToF8E5M2ToF32) {
+  // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+  constexpr float kF8E5M2Max = 57344.f;
+  constexpr float kF8E5M2Min = 1.f / 16384.f;
+  // Within range, should just round.
+  EXPECT_EQ(0.25f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(0.25f)));
+  EXPECT_EQ(-0.25f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-0.25f)));
+  EXPECT_EQ(96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(100.375f)));
+  EXPECT_EQ(-96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-100.375f)));
+  EXPECT_EQ(96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(96.f)));
+  EXPECT_EQ(-96.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-96.f)));
+  EXPECT_EQ(kF8E5M2Max,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Max)));
+  EXPECT_EQ(-kF8E5M2Max,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-kF8E5M2Max)));
+  EXPECT_EQ(kF8E5M2Min,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Min)));
+  EXPECT_EQ(-kF8E5M2Min,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-kF8E5M2Min)));
+  // Powers of two should always be exactly representable across the
+  // exponent range.
+  EXPECT_EQ(32768.f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(32768.f)));
+  EXPECT_EQ(-32768.f,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-32768.f)));
+  // Overflow: FLT_MAX becomes infinity; just above the max rounds back down.
+  EXPECT_EQ(INFINITY,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(FLT_MAX)));
+  EXPECT_EQ(-INFINITY,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-FLT_MAX)));
+  EXPECT_GT(kF8E5M2Max + 1.f,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Max + 1.f)));
+  // Underflow: FLT_MIN is too small for F8E5M2 and comes back as zero.
+  EXPECT_EQ(0.0f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(FLT_MIN)));
+  EXPECT_EQ(0.0f, iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-FLT_MIN)));
+  // Denormals may or may not get flushed to zero. Accept both ways.
+  float positive_denormal =
+      iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(kF8E5M2Min / 2));
+  EXPECT_TRUE(positive_denormal == 0.0f ||
+              positive_denormal == 3.05175781e-05f);
+  // Inf and NaN
+  EXPECT_EQ(INFINITY,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(INFINITY)));
+  EXPECT_EQ(-INFINITY,
+            iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(-INFINITY)));
+  // Check that the result is a NaN, i.e. it compares unequal to itself.
+  float nan = iree_math_f8e5m2_to_f32(iree_math_f32_to_f8e5m2(NAN));
+  EXPECT_NE(nan, nan);
+}
+
+//==============================================================================
+// F8E4M3 support
+//==============================================================================
+
+TEST(F8E4M3ConversionTest, F32ToF8E4M3) {
+  // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+  // The F8E4M3 format is special: it has no infinities, and has some larger
+  // finite values instead.
+  constexpr float kF8E4M3Max = 448.f;
+  constexpr float kF8E4M3Min = 1.f / 64.f;
+  // Within range, normal truncation.
+  EXPECT_EQ(0x28, iree_math_f32_to_f8e4m3(0.25f));
+  EXPECT_EQ(0xED, iree_math_f32_to_f8e4m3(-100.375f));
+  // Extra large finite values thanks to not having infinities.
+  EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3(256.0f));
+  EXPECT_EQ(0x79, iree_math_f32_to_f8e4m3(288.0f));
+  EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(320.0f));
+  EXPECT_EQ(0x7B, iree_math_f32_to_f8e4m3(352.0f));
+  EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3(384.0f));
+  EXPECT_EQ(0x7D, iree_math_f32_to_f8e4m3(416.0f));
+  EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3(kF8E4M3Max));
+  EXPECT_EQ(0xFE, iree_math_f32_to_f8e4m3(-kF8E4M3Max));
+  // Min normal values.
+  EXPECT_EQ(0x08, iree_math_f32_to_f8e4m3(kF8E4M3Min));
+  EXPECT_EQ(0x88, iree_math_f32_to_f8e4m3(-kF8E4M3Min));
+  // Infinity: F8E4M3 has no infinities, so +/-Inf converts to NaN.
+  EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(INFINITY));
+  EXPECT_EQ(0xfF, iree_math_f32_to_f8e4m3(-INFINITY));
+  // Overflow: out-of-range finite values also convert to NaN.
+  EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(FLT_MAX));
+  EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-FLT_MAX));
+  // Test some round-to-nearest-even behavior.
+  EXPECT_EQ(0x70, iree_math_f32_to_f8e4m3(136.0f));
+  EXPECT_EQ(0x72, iree_math_f32_to_f8e4m3(152.0f));
+  EXPECT_EQ(0x72, iree_math_f32_to_f8e4m3(168.0f));
+  EXPECT_EQ(0x74, iree_math_f32_to_f8e4m3(184.0f));
+  EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3(272.0f));
+  EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(304.0f));
+  EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3(336.0f));
+  EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3(368.0f));
+  // Important case to test: overflow due to rounding to nearest-even of 465
+  // to 512, while 464 gets rounded to nearest-even 448, not overflowing.
+  EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3(464.f));
+  EXPECT_EQ(0xFE, iree_math_f32_to_f8e4m3(-464.f));
+  EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(465.f));
+  EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-465.f));
+  // Largest float value in the same exponent bucket, a tricky case.
+  EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3(511.f));
+  EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3(-511.f));
+  // Underflow: values too small for F8E4M3 become signed zero.
+  EXPECT_EQ(0, iree_math_f32_to_f8e4m3(FLT_MIN));
+  EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3(-FLT_MIN));
+  // Denormals may or may not get flushed to zero. Accept both ways.
+  uint8_t positive_denormal = iree_math_f32_to_f8e4m3(kF8E4M3Min / 2);
+  EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x04);
+  uint8_t negative_denormal = iree_math_f32_to_f8e4m3(-kF8E4M3Min / 2);
+  EXPECT_TRUE(negative_denormal == 0x80 || negative_denormal == 0x84);
+}
+
+TEST(F8E4M3ConversionTest, F32ToF8E4M3ToF32) {
+  // See https://arxiv.org/pdf/2209.05433.pdf, Table 1.
+  // The F8E4M3 format is special: it has no infinities, and has some larger
+  // finite values instead.
+  constexpr float kF8E4M3Max = 448.f;
+  constexpr float kF8E4M3Min = 1.f / 64.f;
+  // Within range, should just round.
+  EXPECT_EQ(0.25f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(0.25f)));
+  EXPECT_EQ(-0.25f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-0.25f)));
+  EXPECT_EQ(104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(100.375f)));
+  EXPECT_EQ(-104.f,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-100.375f)));
+  EXPECT_EQ(104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(100.4f)));
+  EXPECT_EQ(-104.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-100.4f)));
+  EXPECT_EQ(kF8E4M3Max,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Max)));
+  EXPECT_EQ(-kF8E4M3Max,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-kF8E4M3Max)));
+  EXPECT_EQ(kF8E4M3Min,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Min)));
+  EXPECT_EQ(-kF8E4M3Min,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-kF8E4M3Min)));
+  // Powers of two should always be exactly representable across the
+  // exponent range.
+  EXPECT_EQ(256.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(256.f)));
+  EXPECT_EQ(-256.f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-256.f)));
+  // Overflow: FLT_MAX becomes NaN; just above the max rounds back down.
+  EXPECT_TRUE(
+      std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(FLT_MAX))));
+  EXPECT_TRUE(
+      std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-FLT_MAX))));
+  EXPECT_GT(kF8E4M3Max + 1.f,
+            iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Max + 1.f)));
+  // Underflow: FLT_MIN is too small for F8E4M3 and comes back as zero.
+  EXPECT_EQ(0.0f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(FLT_MIN)));
+  EXPECT_EQ(0.0f, iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-FLT_MIN)));
+  // Denormals may or may not get flushed to zero. Accept both ways.
+  float positive_denormal =
+      iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(kF8E4M3Min / 2));
+  EXPECT_TRUE(positive_denormal == 0.0f ||
+              positive_denormal == kF8E4M3Min / 2);
+  // Inf and NaN: F8E4M3 has no infinities, so both come back as NaN.
+  EXPECT_TRUE(
+      std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(INFINITY))));
+  EXPECT_TRUE(
+      std::isnan(iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(-INFINITY))));
+  // Check that the result is a NaN, i.e. it compares unequal to itself.
+  float nan = iree_math_f8e4m3_to_f32(iree_math_f32_to_f8e4m3(NAN));
+  EXPECT_NE(nan, nan);
+}
+
 }  // namespace