Simplifications in e2e matmul tests (#18889)
Two commits:
1. Stop inferring `acc_type`; require it to be specified explicitly. Only a few tests
were relying on the inference (see the sketch below this list).
2. Stop special-casing narrow float types (using f32 as the sole ABI type and
generating `arith.truncf` internally). This was only needed back when these
narrow float types were not yet supported in the rest of IREE.
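For illustration, here is a minimal sketch of the updated Bazel pattern for point 1: `--acc_type` is now passed explicitly alongside `--lhs_rhs_type` instead of being inferred. The rule name and the subset of attributes shown are illustrative only; the real targets in `tests/e2e/matmul/BUILD.bazel` (see the hunks below) carry additional attributes such as tags and driver lists.

```starlark
# Hypothetical target list; actual targets follow this shape with more attributes.
[iree_generated_e2e_runner_test(
    name = "e2e_matmul_%s_%s_small" % (lhs_rhs_type, acc_type),  # illustrative name
    generator = ":generate_e2e_matmul_tests",
    generator_args = [
        "--lhs_rhs_type=%s" % lhs_rhs_type,
        # acc_type is now required; the generator no longer infers it from lhs_rhs_type.
        "--acc_type=%s" % acc_type,
        "--shapes=small",
    ],
    test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
    test_type = "matmul",
) for (lhs_rhs_type, acc_type) in [
    ("i8", "i32"),
    ("f16", "f32"),
    ("f32", "f32"),
]]
```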
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 635ee0c..a82bfb6 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -360,6 +360,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
"--shapes=small",
],
target_backends_and_drivers = [
@@ -367,9 +368,9 @@
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
-) for lhs_rhs_type in [
- "i8",
- "f32",
+) for (lhs_rhs_type, acc_type) in [
+ ("i8", "i32"),
+ ("f32", "f32"),
]]
###########################################################################
@@ -383,6 +384,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulSimt",
],
@@ -411,6 +413,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCore",
],
@@ -437,6 +440,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
+ "--acc_type=f32",
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -461,6 +465,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
+ "--acc_type=f32",
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -486,6 +491,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
],
@@ -513,6 +519,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCore",
],
@@ -540,6 +547,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
],
@@ -566,6 +574,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -580,8 +589,8 @@
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
-) for lhs_rhs_type in [
- "f32",
+) for (lhs_rhs_type, acc_type) in [
+ ("f32", "f32"),
]]
###########################################################################
@@ -598,6 +607,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
"--shapes=easy_large_static",
"--compilation_info=SPIRVVectorizeMali",
],
@@ -611,10 +621,10 @@
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
-) for lhs_rhs_type in [
- "i8",
- "f16",
- "f32",
+) for (lhs_rhs_type, acc_type) in [
+ ("i8", "i32"),
+ ("f16", "f32"),
+ ("f32", "f32"),
]]
[iree_generated_e2e_runner_test(
@@ -625,6 +635,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
"--shapes=easy_large_static",
"--compilation_info=SPIRVVectorizeNVIDIA",
],
@@ -637,10 +648,10 @@
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
-) for lhs_rhs_type in [
- "i8",
- "f16",
- "f32",
+) for (lhs_rhs_type, acc_type) in [
+ ("i8", "i32"),
+ ("f16", "f32"),
+ ("f32", "f32"),
]]
iree_generated_e2e_runner_test(
@@ -651,6 +662,7 @@
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
+ "--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=SPIRVCooperativeMatrixVectorize",
],
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index aeb18c3..98a4ff1 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -927,6 +927,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -948,6 +949,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -969,6 +971,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulSimt"
TEST_RUNNER
@@ -994,6 +997,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCore"
TEST_RUNNER
@@ -1021,6 +1025,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1046,6 +1051,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1071,6 +1077,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
TEST_RUNNER
@@ -1098,6 +1105,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCore"
TEST_RUNNER
@@ -1125,6 +1133,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
TEST_RUNNER
@@ -1152,6 +1161,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1177,6 +1187,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1201,6 +1212,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1225,6 +1237,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1249,6 +1262,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1273,6 +1287,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1297,6 +1312,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1321,6 +1337,7 @@
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
+ "--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVCooperativeMatrixVectorize"
TEST_RUNNER
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index b5dac41..cd6f8eb 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -545,20 +545,6 @@
return s.value or "DYN"
-# Gets friendlier form/type that we can use as arg types which we can cast into the target_type.
-def cast_argtype_if_required(target_type: MatrixElemTypeId):
- if target_type == MatrixElemTypeId.F8E4M3FNUZ:
- return MatrixElemTypeId.F32
- return target_type
-
-
-# Gets the op needed to cast/convert from the friendly form/type into the target_type.
-def get_castback_from_arg_op(target_type: MatrixElemTypeId):
- if target_type == MatrixElemTypeId.F8E4M3FNUZ:
- return "arith.truncf"
- return ValueError(f"Unhandled castback type of {target_type}")
-
-
# Describes the fully resolved shape dimensions of all 3 input matrices,
# LHS, RHS, and Accumulator, in a testcase.
# Each value is a string, which may either represent a positive integer such as "123",
@@ -659,9 +645,8 @@
acc_r = int_or_question_mark(shapes.acc_rows)
acc_c = int_or_question_mark(shapes.acc_cols)
- casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
- lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
- rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
+ lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+ rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
if transpose_rhs:
@@ -680,15 +665,6 @@
func_definition = func_definition + compilation_info_string
generate_function.compilation_index += 1
compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
- if casted_lhs_rhs_type != lhs_rhs_type:
- castback_op = get_castback_from_arg_op(lhs_rhs_type)
- compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
- compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
- compute = (
- f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
- f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
- f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
- )
if shape.accumulate:
signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
@@ -818,9 +794,8 @@
rhs_shape = [shape.k, shape.n]
transpose_rhs = 0
- casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
- op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
- op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
+ op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+ op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -919,16 +894,15 @@
"f8E5M2FNUZ",
"f8E4M3FNUZ",
],
- help="Numeric type of input matrices",
+ help="Numeric type of input LHS and RHS matrices",
required=True,
)
parser.add_argument(
"--acc_type",
type=str,
choices=["i32", "f32", "f16", "bf16"],
- help="Numeric type of input matrices",
- default="",
- required=False,
+ help="Numeric type of the accumulator and result matrices",
+ required=True,
)
parser.add_argument(
"--shapes",
@@ -1005,30 +979,9 @@
file.write(module_definition)
-# For now, the accumulator type can always be inferred from the input LHS/RHS
-# type, so we do that. That is temporary: eventually there will be cases
-# where the same input types are used with different accumulator types, e.g.
-# f16 inputs with both f16 and f32 accumulator.
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
- if acc_type != MatrixElemTypeId.NONE:
- return acc_type
- if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
- return MatrixElemTypeId.F32
- if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
- return MatrixElemTypeId.F32
- if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
- return MatrixElemTypeId.F32
- if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
- return MatrixElemTypeId.F32
- if lhs_rhs_type == MatrixElemTypeId.I8:
- return MatrixElemTypeId.I32
- return lhs_rhs_type
-
-
def main(args):
lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
acc_type = MatrixElemTypeId(args.acc_type)
- acc_type = infer_acc_type(lhs_rhs_type, acc_type)
shapes_id = ShapesId(args.shapes)
compilation_info_id = CompilationInfoId(args.compilation_info)