Add a reference matmul to e2e matmul tests (#16280)
This PR adds a i32*i32+i32 reference matmul and option to use i32
lhs_rhs type in the generator python script.
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index ca07e0c..456947c 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -659,7 +659,7 @@
parser.add_argument(
"--lhs_rhs_type",
type=str,
- choices=["i8", "f32", "f16", "bf16"],
+ choices=["i32", "i8", "f32", "f16", "bf16"],
help="Numeric type of input matrices",
required=True,
)
diff --git a/tools/iree-e2e-matmul-test.cc b/tools/iree-e2e-matmul-test.cc
index 0a3b60a..a0bdabb 100644
--- a/tools/iree-e2e-matmul-test.cc
+++ b/tools/iree-e2e-matmul-test.cc
@@ -226,6 +226,11 @@
// [i32 <= i8 * i8 + i32]
REFERENCE_MATMUL(int8_t, int8_t, int32_t, int32_t)
+// Reference mamtul instantiations from macro REFERENCE_MATMUL
+// for the int8_t input, int32_t accumlation, and int32_t result.
+// [i32 <= i32 * i32 + i32]
+REFERENCE_MATMUL(int32_t, int32_t, int32_t, int32_t)
+
// Reference mamtul for the f16 input, f16 accumlation, and f16 result.
// [f16 <= f16 * f16 + f16]
static void reference_matmul_f16_f16_f16_f16(
@@ -311,6 +316,13 @@
m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
(const int8_t*)lhs_data, (const int8_t*)rhs_data,
(const int32_t*)acc_data, (int32_t*)result_data, m, n);
+ } else if (iree_hal_element_type_is_integer(lhs_type, 32) &&
+ iree_hal_element_type_is_integer(rhs_type, 32) &&
+ iree_hal_element_type_is_integer(acc_type, 32)) {
+ reference_matmul_int32_t_int32_t_int32_t_int32_t(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const int32_t*)lhs_data, (const int32_t*)rhs_data,
+ (const int32_t*)acc_data, (int32_t*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {