Add e2e tests for F8E5M2FNUZ and F8E4M3FNUZ data-tiled MFMA on CDNA3 (#18888)
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 36e1255..aeb18c3 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1526,7 +1526,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f16_large_cdna3_mfma_data_tiled
+ e2e_matmul_rocm_f16_cdna3_mfma_data_tiled
TEST_TYPE
matmul
GENERATOR
@@ -1555,7 +1555,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_i8_large_cdna3_mfma_data_tiled
+ e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
TEST_TYPE
matmul
GENERATOR
@@ -1584,7 +1584,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f32_large_cdna3_mfma_data_tiled
+ e2e_matmul_rocm_f32_cdna3_mfma_data_tiled
TEST_TYPE
matmul
GENERATOR
@@ -1611,6 +1611,64 @@
"requires-gpu-cdna3"
)
+# Generated e2e matmul test with f8E5M2FNUZ operands and an f32 accumulator.
+# Compiled for the rocm backend with the data-tiling flags below and run via
+# the hip driver; the "requires-gpu-cdna3" label restricts it to CDNA3 GPUs.
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f8E5M2FNUZ"
+ "--acc_type=f32"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-opt-data-tiling"
+ "--iree-global-opt-experimental-rocm-data-tiling"
+ "--iree-global-opt-enable-early-materialization=true"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+# Generated e2e matmul test with f8E4M3FNUZ operands and an f32 accumulator.
+# Compiled for the rocm backend with the data-tiling flags below and run via
+# the hip driver; the "requires-gpu-cdna3" label restricts it to CDNA3 GPUs.
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f8E4M3FNUZ"
+ "--acc_type=f32"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-opt-data-tiling"
+ "--iree-global-opt-experimental-rocm-data-tiling"
+ "--iree-global-opt-enable-early-materialization=true"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
endif()
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 30d210d..b5dac41 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -27,8 +27,11 @@
I32 = "i32"
F32 = "f32"
F16 = "f16"
- F8E4M3FNUZ = "f8E4M3FNUZ"
BF16 = "bf16"
+ F8E5M2 = "f8E5M2"
+ F8E4M3 = "f8E4M3"
+ F8E5M2FNUZ = "f8E5M2FNUZ"
+ F8E4M3FNUZ = "f8E4M3FNUZ"
# Enumerates of the collections of shapes that we can generate tests for.
@@ -905,7 +908,17 @@
parser.add_argument(
"--lhs_rhs_type",
type=str,
- choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
+ choices=[
+ "i32",
+ "i8",
+ "f32",
+ "f16",
+ "bf16",
+ "f8E5M2",
+ "f8E4M3",
+ "f8E5M2FNUZ",
+ "f8E4M3FNUZ",
+ ],
help="Numeric type of input matrices",
required=True,
)
@@ -999,6 +1012,12 @@
def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
if acc_type != MatrixElemTypeId.NONE:
return acc_type
+ if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
+ return MatrixElemTypeId.F32
+ if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
+ return MatrixElemTypeId.F32
+ if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
+ return MatrixElemTypeId.F32
if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
return MatrixElemTypeId.F32
if lhs_rhs_type == MatrixElemTypeId.I8:
diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc
index 2309560..ce589e2 100644
--- a/tools/testing/e2e/iree-e2e-matmul-test.cc
+++ b/tools/testing/e2e/iree-e2e-matmul-test.cc
@@ -128,6 +128,29 @@
result_data[n + m * n_size] = acc;
}
+// Defines reference_matmul_<LHSTYPE>_<RHSTYPE>_f32_f32, which computes the
+// single result element (m, n) for operands held as raw 8-bit float bytes.
+// Each operand byte is widened with iree_math_<TYPE>_to_f32 and the products
+// are accumulated in f32, starting from acc_data[m, n] when acc_data is
+// non-NULL and from 0 otherwise. The rhs is indexed as [n, k] when
+// transpose_rhs is set and as [k, n] otherwise. The lhs_type, rhs_type and
+// acc_type parameters are not read by the generated body.
+#define REFERENCE_MATMUL_F8(LHSTYPE, RHSTYPE) \
+ static void reference_matmul_##LHSTYPE##_##RHSTYPE##_f32_f32( \
+ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \
+ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \
+ iree_hal_element_type_t acc_type, bool transpose_rhs, \
+ const uint8_t* lhs_data, const uint8_t* rhs_data, const float* acc_data, \
+ float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { \
+ float acc = acc_data ? acc_data[n + m * n_size] : 0; \
+ for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
+ float lhs_float = \
+ iree_math_##LHSTYPE##_to_f32(lhs_data[k + m * k_size]); \
+ float rhs_float = iree_math_##RHSTYPE##_to_f32( \
+ rhs_data[transpose_rhs ? k + n * k_size : n + k * n_size]); \
+ acc += lhs_float * rhs_float; \
+ } \
+ result_data[n + m * n_size] = acc; \
+ }
+
+// One reference kernel per supported f8 format; lhs and rhs share the format.
+REFERENCE_MATMUL_F8(f8e5m2, f8e5m2)
+REFERENCE_MATMUL_F8(f8e4m3, f8e4m3)
+REFERENCE_MATMUL_F8(f8e5m2fnuz, f8e5m2fnuz)
+REFERENCE_MATMUL_F8(f8e4m3fnuz, f8e4m3fnuz)
+
// Helper for reference_matmul.
// Computes one element in the result matrix.
static iree_status_t reference_matmul_element(
@@ -185,6 +208,34 @@
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
(const float*)acc_data, (float*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_f8e5m2_f8e5m2_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+ (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_f8e4m3_f8e4m3_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+ (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_f8e5m2fnuz_f8e5m2fnuz_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+ (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_f8e4m3fnuz_f8e4m3fnuz_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
+ (const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
} else {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled combination of element types in matmul");
diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c
index a7119dc..c54c719 100644
--- a/tools/testing/e2e/test_utils.c
+++ b/tools/testing/e2e/test_utils.c
@@ -93,6 +93,36 @@
return result;
}
+// Builds an e2e test value tagged as f8E5M2 from its raw 8-bit encoding.
+// The byte is stored unchanged in f8_u8; no numeric conversion is applied.
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) {
+  iree_test_utils_e2e_value_t tagged;
+  tagged.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2;
+  tagged.f8_u8 = value;
+  return tagged;
+}
+
+// Builds an e2e test value tagged as f8E4M3 from its raw 8-bit encoding.
+// The byte is stored unchanged in f8_u8; no numeric conversion is applied.
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) {
+  iree_test_utils_e2e_value_t tagged;
+  tagged.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3;
+  tagged.f8_u8 = value;
+  return tagged;
+}
+
+// Builds an e2e test value tagged as f8E5M2FNUZ from its raw 8-bit encoding.
+// The parameter is uint8_t so that it matches the width of the f8_u8 field it
+// is stored in; a wider parameter would be silently narrowed on assignment.
+// The byte is stored unchanged; no numeric conversion is applied.
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ(
+    uint8_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ;
+  result.f8_u8 = value;
+  return result;
+}
+
+// Builds an e2e test value tagged as f8E4M3FNUZ from its raw 8-bit encoding.
+// The parameter is uint8_t so that it matches the width of the f8_u8 field it
+// is stored in; a wider parameter would be silently narrowed on assignment.
+// The byte is stored unchanged; no numeric conversion is applied.
+iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ(
+    uint8_t value) {
+  iree_test_utils_e2e_value_t result;
+  result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ;
+  result.f8_u8 = value;
+  return result;
+}
+
iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F16;
@@ -123,6 +153,14 @@
return iree_test_utils_value_make_i16(((int16_t*)data)[index]);
} else if (iree_hal_element_type_is_integer(result_type, 32)) {
return iree_test_utils_value_make_i32(((int32_t*)data)[index]);
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) {
+ return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]);
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) {
+ return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]);
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) {
+ return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]);
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) {
+ return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
return iree_test_utils_value_make_f16(((uint16_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
@@ -147,6 +185,22 @@
return snprintf(buf, bufsize, "%" PRIi32, value.i32);
case IREE_TEST_UTILS_VALUE_TYPE_I64:
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
+ case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
+ return snprintf(buf, bufsize,
+ precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+ iree_math_f8e5m2_to_f32(value.f8_u8));
+ case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
+ return snprintf(buf, bufsize,
+ precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+ iree_math_f8e4m3_to_f32(value.f8_u8));
+ case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
+ return snprintf(buf, bufsize,
+ precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+ iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
+ case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
+ return snprintf(buf, bufsize,
+ precision == PRECISION_HIGH ? "%.3g" : "%.2g",
+ iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
@@ -257,6 +311,18 @@
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
break;
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
+ *(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value);
+ break;
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
+ *(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value);
+ break;
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
+ *(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value);
+ break;
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
+ *(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value);
+ break;
WRITE_ELEMENT_CASE(FLOAT_32, float)
WRITE_ELEMENT_CASE(FLOAT_64, double)
// clang-format on
@@ -296,6 +362,10 @@
*max = +4;
break;
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
*min = -2;
*max = +2;
break;
diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h
index f095537..46d99f1 100644
--- a/tools/testing/e2e/test_utils.h
+++ b/tools/testing/e2e/test_utils.h
@@ -48,6 +48,11 @@
IREE_TEST_UTILS_VALUE_TYPE_F64 = 7,
// bfloat16
IREE_TEST_UTILS_VALUE_TYPE_BF16 = 8,
+ // 8-bit float types.
+ IREE_TEST_UTILS_VALUE_TYPE_F8E5M2 = 9,
+ IREE_TEST_UTILS_VALUE_TYPE_F8E4M3 = 10,
+ IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ = 11,
+ IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ = 12,
} iree_test_utils_value_type_t;
// Maximum size, in bytes, of any value type we can represent.
@@ -64,6 +69,7 @@
float f32;
uint16_t f16_u16;
uint16_t bf16_u16;
+ uint8_t f8_u8;
double f64;
uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all
// value types