Add e2e tests for F8E5M2FNUZ and F8E4M3FNUZ data-tiled MFMA on CDNA3 (#18888)

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 36e1255..aeb18c3 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1526,7 +1526,7 @@
 
 iree_generated_e2e_runner_test(
   NAME
-    e2e_matmul_rocm_f16_large_cdna3_mfma_data_tiled
+    e2e_matmul_rocm_f16_cdna3_mfma_data_tiled
   TEST_TYPE
     matmul
   GENERATOR
@@ -1555,7 +1555,7 @@
 
 iree_generated_e2e_runner_test(
   NAME
-    e2e_matmul_rocm_i8_large_cdna3_mfma_data_tiled
+    e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
   TEST_TYPE
     matmul
   GENERATOR
@@ -1584,7 +1584,7 @@
 
 iree_generated_e2e_runner_test(
   NAME
-    e2e_matmul_rocm_f32_large_cdna3_mfma_data_tiled
+    e2e_matmul_rocm_f32_cdna3_mfma_data_tiled
   TEST_TYPE
     matmul
   GENERATOR
@@ -1611,6 +1611,64 @@
     "requires-gpu-cdna3"
 )
 
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f8E5M2FNUZ"
+    "--acc_type=f32"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-experimental-rocm-data-tiling"
+    "--iree-global-opt-enable-early-materialization=true"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f8E4M3FNUZ"
+    "--acc_type=f32"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-experimental-rocm-data-tiling"
+    "--iree-global-opt-enable-early-materialization=true"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
 endif()
 
 elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 30d210d..b5dac41 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -27,8 +27,11 @@
     I32 = "i32"
     F32 = "f32"
     F16 = "f16"
-    F8E4M3FNUZ = "f8E4M3FNUZ"
     BF16 = "bf16"
+    F8E5M2 = "f8E5M2"
+    F8E4M3 = "f8E4M3"
+    F8E5M2FNUZ = "f8E5M2FNUZ"
+    F8E4M3FNUZ = "f8E4M3FNUZ"
 
 
 # Enumerates of the collections of shapes that we can generate tests for.
@@ -905,7 +908,17 @@
     parser.add_argument(
         "--lhs_rhs_type",
         type=str,
-        choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
+        choices=[
+            "i32",
+            "i8",
+            "f32",
+            "f16",
+            "bf16",
+            "f8E5M2",
+            "f8E4M3",
+            "f8E5M2FNUZ",
+            "f8E4M3FNUZ",
+        ],
         help="Numeric type of input matrices",
         required=True,
     )
@@ -999,6 +1012,12 @@
 def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
     if acc_type != MatrixElemTypeId.NONE:
         return acc_type
+    if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
+        return MatrixElemTypeId.F32
+    if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
+        return MatrixElemTypeId.F32
+    if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
+        return MatrixElemTypeId.F32
     if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
         return MatrixElemTypeId.F32
     if lhs_rhs_type == MatrixElemTypeId.I8:
diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc
index 2309560..ce589e2 100644
--- a/tools/testing/e2e/iree-e2e-matmul-test.cc
+++ b/tools/testing/e2e/iree-e2e-matmul-test.cc
@@ -128,6 +128,29 @@
   result_data[n + m * n_size] = acc;
 }
 
+#define REFERENCE_MATMUL_F8(LHSTYPE, RHSTYPE)                                  \
+  static void reference_matmul_##LHSTYPE##_##RHSTYPE##_f32_f32(                \
+      iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,     \
+      iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,      \
+      iree_hal_element_type_t acc_type, bool transpose_rhs,                    \
+      const uint8_t* lhs_data, const uint8_t* rhs_data, const float* acc_data, \
+      float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {                \
+    float acc = acc_data ? acc_data[n + m * n_size] : 0;                       \
+    for (iree_hal_dim_t k = 0; k < k_size; ++k) {                              \
+      float lhs_float =                                                        \
+          iree_math_##LHSTYPE##_to_f32(lhs_data[k + m * k_size]);              \
+      float rhs_float = iree_math_##RHSTYPE##_to_f32(                          \
+          rhs_data[transpose_rhs ? k + n * k_size : n + k * n_size]);          \
+      acc += lhs_float * rhs_float;                                            \
+    }                                                                          \
+    result_data[n + m * n_size] = acc;                                         \
+  }
+
+REFERENCE_MATMUL_F8(f8e5m2, f8e5m2)
+REFERENCE_MATMUL_F8(f8e4m3, f8e4m3)
+REFERENCE_MATMUL_F8(f8e5m2fnuz, f8e5m2fnuz)
+REFERENCE_MATMUL_F8(f8e4m3fnuz, f8e4m3fnuz)
+
 // Helper for reference_matmul.
 // Computes one element in the result matrix.
 static iree_status_t reference_matmul_element(
@@ -185,6 +208,34 @@
         m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
         (const float*)acc_data, (float*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_f8e5m2_f8e5m2_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+        (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_f8e4m3_f8e4m3_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+        (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_f8e5m2fnuz_f8e5m2fnuz_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+        (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_f8e4m3fnuz_f8e4m3fnuz_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+        (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
   } else {
     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                             "unhandled combination of element types in matmul");
diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c
index a7119dc..c54c719 100644
--- a/tools/testing/e2e/test_utils.c
+++ b/tools/testing/e2e/test_utils.c
@@ -93,6 +93,36 @@
   return result;
 }
 
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2;
+  result.f8_u8 = value;
+  return result;
+}
+
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3;
+  result.f8_u8 = value;
+  return result;
+}
+
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ(
+    uint16_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ;
+  result.f8_u8 = value;
+  return result;
+}
+
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ(
+    uint16_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ;
+  result.f8_u8 = value;
+  return result;
+}
+
 iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) {
   iree_test_utils_e2e_value_t result;
   result.type = IREE_TEST_UTILS_VALUE_TYPE_F16;
@@ -123,6 +153,14 @@
     return iree_test_utils_value_make_i16(((int16_t*)data)[index]);
   } else if (iree_hal_element_type_is_integer(result_type, 32)) {
     return iree_test_utils_value_make_i32(((int32_t*)data)[index]);
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) {
+    return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]);
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) {
+    return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]);
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) {
+    return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]);
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) {
+    return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]);
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
     return iree_test_utils_value_make_f16(((uint16_t*)data)[index]);
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
@@ -147,6 +185,22 @@
       return snprintf(buf, bufsize, "%" PRIi32, value.i32);
     case IREE_TEST_UTILS_VALUE_TYPE_I64:
       return snprintf(buf, bufsize, "%" PRIi64, value.i64);
+    case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
+      return snprintf(buf, bufsize,
+                      precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+                      iree_math_f8e5m2_to_f32(value.f8_u8));
+    case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
+      return snprintf(buf, bufsize,
+                      precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+                      iree_math_f8e4m3_to_f32(value.f8_u8));
+    case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
+      return snprintf(buf, bufsize,
+                      precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+                      iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
+    case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
+      return snprintf(buf, bufsize,
+                      precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+                      iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
     case IREE_TEST_UTILS_VALUE_TYPE_F16:
       return snprintf(buf, bufsize,
                       precision == PRECISION_HIGH ? "%.5g" : "%.4g",
@@ -257,6 +311,18 @@
     case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
       *(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
       break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
+      *(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value);
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
+      *(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value);
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
+      *(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value);
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
+      *(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value);
+      break;
     WRITE_ELEMENT_CASE(FLOAT_32, float)
     WRITE_ELEMENT_CASE(FLOAT_64, double)
     // clang-format on
@@ -296,6 +362,10 @@
       *max = +4;
       break;
     case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
       *min = -2;
       *max = +2;
       break;
diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h
index f095537..46d99f1 100644
--- a/tools/testing/e2e/test_utils.h
+++ b/tools/testing/e2e/test_utils.h
@@ -48,6 +48,11 @@
   IREE_TEST_UTILS_VALUE_TYPE_F64 = 7,
   // bfloat16
   IREE_TEST_UTILS_VALUE_TYPE_BF16 = 8,
+  // 8-bit float types.
+  IREE_TEST_UTILS_VALUE_TYPE_F8E5M2 = 9,
+  IREE_TEST_UTILS_VALUE_TYPE_F8E4M3 = 10,
+  IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ = 11,
+  IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ = 12,
 } iree_test_utils_value_type_t;
 
 // Maximum size, in bytes, of any value type we can represent.
@@ -64,6 +69,7 @@
     float f32;
     uint16_t f16_u16;
     uint16_t bf16_u16;
+    uint8_t f8_u8;
     double f64;
     uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE];  // max size of all
                                                               // value types