Simplifications in e2e matmul tests (#18889)

Two commits:
1. Stop inferring `acc_type` and require it to be specified explicitly. Only
a few tests were relying on the inference.
2. Stop special-casing narrow float types (using f32 as the ABI type and
generating `arith.truncf` internally). This was only needed while these
narrow float types were unsupported in the rest of IREE.

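As a usage sketch (illustrative only, showing just the flags that appear in this diff), the generator must now be invoked with an explicit `--acc_type`; omitting it is an argument-parsing error rather than a silent fallback to inference:

```
# Illustrative invocation; other flags (e.g. output paths) are omitted.
python generate_e2e_matmul_tests.py \
    --lhs_rhs_type=f16 \
    --acc_type=f32 \
    --shapes=small
```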
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 635ee0c..a82bfb6 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -360,6 +360,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=small",
     ],
     target_backends_and_drivers = [
@@ -367,9 +368,9 @@
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f32", "f32"),
 ]]
 
 ###########################################################################
@@ -383,6 +384,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulSimt",
     ],
@@ -411,6 +413,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCore",
     ],
@@ -437,6 +440,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -461,6 +465,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -486,6 +491,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
     ],
@@ -513,6 +519,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCore",
     ],
@@ -540,6 +547,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
     ],
@@ -566,6 +574,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -580,8 +589,8 @@
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("f32", "f32"),
 ]]
 
 ###########################################################################
@@ -598,6 +607,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVVectorizeMali",
     ],
@@ -611,10 +621,10 @@
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f16",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f16", "f32"),
+    ("f32", "f32"),
 ]]
 
 [iree_generated_e2e_runner_test(
@@ -625,6 +635,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVVectorizeNVIDIA",
     ],
@@ -637,10 +648,10 @@
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f16",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f16", "f32"),
+    ("f32", "f32"),
 ]]
 
 iree_generated_e2e_runner_test(
@@ -651,6 +662,7 @@
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVCooperativeMatrixVectorize",
     ],
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index aeb18c3..98a4ff1 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -927,6 +927,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=small"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -948,6 +949,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=small"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -969,6 +971,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulSimt"
   TEST_RUNNER
@@ -994,6 +997,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCore"
   TEST_RUNNER
@@ -1021,6 +1025,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1046,6 +1051,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1071,6 +1077,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
   TEST_RUNNER
@@ -1098,6 +1105,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCore"
   TEST_RUNNER
@@ -1125,6 +1133,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
   TEST_RUNNER
@@ -1152,6 +1161,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1177,6 +1187,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1201,6 +1212,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1225,6 +1237,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1249,6 +1262,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1273,6 +1287,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1297,6 +1312,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1321,6 +1337,7 @@
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVCooperativeMatrixVectorize"
   TEST_RUNNER
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index b5dac41..cd6f8eb 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -545,20 +545,6 @@
     return s.value or "DYN"
 
 
-# Gets friendlier form/type that we can use as arg types which we can cast into the target_type.
-def cast_argtype_if_required(target_type: MatrixElemTypeId):
-    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return MatrixElemTypeId.F32
-    return target_type
-
-
-# Gets the op needed to cast/convert from the friendly form/type into the target_type.
-def get_castback_from_arg_op(target_type: MatrixElemTypeId):
-    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return "arith.truncf"
-    return ValueError(f"Unhandled castback type of {target_type}")
-
-
 # Describes the fully resolved shape dimensions of all 3 input matrices,
 # LHS, RHS, and Accumulator, in a testcase.
 # Each value is a string, which may either represent a positive integer such as "123",
@@ -659,9 +645,8 @@
     acc_r = int_or_question_mark(shapes.acc_rows)
     acc_c = int_or_question_mark(shapes.acc_cols)
 
-    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
-    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
-    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
+    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
     acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
 
     if transpose_rhs:
@@ -680,15 +665,6 @@
         func_definition = func_definition + compilation_info_string
         generate_function.compilation_index += 1
     compute = f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
-    if casted_lhs_rhs_type != lhs_rhs_type:
-        castback_op = get_castback_from_arg_op(lhs_rhs_type)
-        compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
-        compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
-        compute = (
-            f"  %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
-            f"  %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
-            f"  %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
-        )
     if shape.accumulate:
         signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
         import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
@@ -818,9 +794,8 @@
         rhs_shape = [shape.k, shape.n]
         transpose_rhs = 0
 
-    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
-    op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
-    op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
+    op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+    op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
     if shape.accumulate:
         op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
         # TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -919,16 +894,15 @@
             "f8E5M2FNUZ",
             "f8E4M3FNUZ",
         ],
-        help="Numeric type of input matrices",
+        help="Numeric type of input LHS and RHS matrices",
         required=True,
     )
     parser.add_argument(
         "--acc_type",
         type=str,
         choices=["i32", "f32", "f16", "bf16"],
-        help="Numeric type of input matrices",
-        default="",
-        required=False,
+        help="Numeric type of the accumulator and result matrices",
+        required=True,
     )
     parser.add_argument(
         "--shapes",
@@ -1005,30 +979,9 @@
         file.write(module_definition)
 
 
-# For now, the accumulator type can always be inferred from the input LHS/RHS
-# type, so we do that. That is temporary: eventually there will be cases
-# where the same input types are used with different accumulator types, e.g.
-# f16 inputs with both f16 and f32 accumulator.
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
-    if acc_type != MatrixElemTypeId.NONE:
-        return acc_type
-    if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.I8:
-        return MatrixElemTypeId.I32
-    return lhs_rhs_type
-
-
 def main(args):
     lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
     acc_type = MatrixElemTypeId(args.acc_type)
-    acc_type = infer_acc_type(lhs_rhs_type, acc_type)
     shapes_id = ShapesId(args.shapes)
     compilation_info_id = CompilationInfoId(args.compilation_info)