[rocdl] Add e2e matmul test for cdna3 matrix core (#16510)

This commit adds support for matmul correctness
test over CodeGen paths that target cdna3 mfma.
Right now we only cover (M, K) x (K, N) -> (M, N)
matmul for v_mfma_f32_16x16x16_f16.
diff --git a/build_tools/bazel/iree_e2e_matmul_test.bzl b/build_tools/bazel/iree_e2e_matmul_test.bzl
index f2a9e86..dcf8f0f 100644
--- a/build_tools/bazel/iree_e2e_matmul_test.bzl
+++ b/build_tools/bazel/iree_e2e_matmul_test.bzl
@@ -218,8 +218,8 @@
 
     tests = []
     for backend, driver in target_backends_and_drivers:
-        # CUDA backend/driver not supported by Bazel build.
-        if backend == "cuda" or driver == "cuda":
+        # CUDA/ROCm backend/driver not supported by Bazel build.
+        if backend == "cuda" or driver == "cuda" or backend == "rocm" or driver == "hip":
             continue
         suite_entry_name = "_".join([name, backend, driver])
         iree_single_backend_e2e_matmul_test(
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 89c7c94..c6cb3ec 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -126,12 +126,6 @@
   if(DEFINED _RULE_DRIVER)
     string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
     string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER})
-    string(TOUPPER "${IREE_EXTERNAL_HAL_DRIVERS}" _UPPERCASE_EXTERNAL_DRIVERS)
-    string(REPLACE "-" "_" _NORMALIZED_EXTERNAL_DRIVERS "${_UPPERCASE_EXTERNAL_DRIVERS}")
-    if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
-       (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS))
-      message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. Check IREE_HAL_DRIVER_*/IREE_EXTERNAL_HAL_DRIVERS options.")
-    endif()
     if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
        (NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND))
       set(_TEST_DISABLED TRUE)
diff --git a/build_tools/cmake/iree_e2e_matmul_test.cmake b/build_tools/cmake/iree_e2e_matmul_test.cmake
index 15850b4..9bc630e 100644
--- a/build_tools/cmake/iree_e2e_matmul_test.cmake
+++ b/build_tools/cmake/iree_e2e_matmul_test.cmake
@@ -216,12 +216,6 @@
   if(DEFINED _RULE_DRIVER)
     string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
     string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER})
-    string(TOUPPER "${IREE_EXTERNAL_HAL_DRIVERS}" _UPPERCASE_EXTERNAL_DRIVERS)
-    string(REPLACE "-" "_" _NORMALIZED_EXTERNAL_DRIVERS "${_UPPERCASE_EXTERNAL_DRIVERS}")
-    if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
-       (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS))
-      message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. Check IREE_HAL_DRIVER_*/IREE_EXTERNAL_HAL_DRIVERS options.")
-    endif()
     if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
        (NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND))
       set(_TEST_DISABLED TRUE)
diff --git a/build_tools/cmake/test_riscv.sh b/build_tools/cmake/test_riscv.sh
index b9d31ea..e14855b 100755
--- a/build_tools/cmake/test_riscv.sh
+++ b/build_tools/cmake/test_riscv.sh
@@ -36,6 +36,8 @@
   "^driver=vulkan$"
   "^driver=metal$"
   "^driver=cuda$"
+  "^driver=rocm$"
+  "^driver=hip$"
   "^vulkan_uses_vk_khr_shader_float16_int8$"
   "^requires-filesystem$"
   "^requires-dtz$"
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 8c2f03c..99ffe99 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -425,6 +425,39 @@
 
 ###########################################################################
 ##
+## ROCm backend
+##
+###########################################################################
+
+# Testing CDNA3 + matrix core path.
+# v_mfma_f32_16x16x16_f16
+iree_generated_e2e_matmul_test(
+    name = "e2e_matmul_rocm_f16_large_cdna3_matrixcore",
+    compiler_flags = [
+        "--iree-rocm-target-chip=gfx942",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=f16",
+        "--acc_type=f32",
+        "--shapes=gpu_large_aligned",
+        "--compilation_info=LLVMGPUVectorDistribute",
+    ],
+    tags = [
+        "noasan",
+        "nomsan",
+        "notsan",
+        "noubsan",
+        "requires-gpu-cdna3",
+    ],
+    target_backends_and_drivers = [
+        ("rocm", "rocm"),
+    ],
+    test_runner = "//tools:iree-e2e-matmul-test",
+)
+
+###########################################################################
+##
 ## Vulkan backend
 ##
 ###########################################################################
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 187ed7b..9939e35 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -978,6 +978,32 @@
 
 iree_generated_e2e_matmul_test(
   NAME
+    e2e_matmul_rocm_f16_large_cdna3_matrixcore
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistribute"
+  TEST_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "rocm"
+  COMPILER_FLAGS
+    "--iree-rocm-target-chip=gfx942"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_matmul_test(
+  NAME
     e2e_matmul_vulkan_i8_large_valhall
   GENERATOR
     "generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 456947c..151ca42 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -48,6 +48,7 @@
     LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt"
     LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore"
     LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync"
+    LLVMGPUVectorDistribute = "LLVMGPUVectorDistribute"
     SPIRVCooperativeMatrixVectorize = "SPIRVCooperativeMatrixVectorize"
     SPIRVVectorizeMali = "SPIRVVectorizeMali"
     SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA"
@@ -81,6 +82,28 @@
     accumulate: bool
 
 
+# Describes a workgroup and tiling schedule to target a specific MMA intrinsic.
+@dataclasses.dataclass
+class MMASchedule:
+    intrinsic: str
+    m_count: int  # Number of subgroups per workgroup along M
+    n_count: int  # Number of subgroups per workgroup along N
+    m_tile_count: int
+    n_tile_count: int
+    k_tile_count: int
+
+    def __str__(self):
+        return (
+            "mma_schedule = #iree_gpu.mma_schedule<"
+            + f"intrinsic = #iree_gpu.mfma_layout<{self.intrinsic}>, "
+            + f"subgroup_m_count = {self.m_count}, "
+            + f"subgroup_n_count = {self.n_count}, "
+            + f"subgroup_m_tile_count = {self.m_tile_count}, "
+            + f"subgroup_n_tile_count = {self.n_tile_count}, "
+            + f"subgroup_k_tile_count = {self.k_tile_count}>"
+        )
+
+
 # Describes how to construct compilation info for the testcase.
 @dataclasses.dataclass
 class CompilationInfo:
@@ -88,8 +111,8 @@
     tile_sizes: typing.List[typing.List[int]]
     # Translation Info
     dispatch_lowering_pass_pipeline: str
-    workload_per_wg: typing.List[int]
     software_pipeline_depth: int
+    mma_schedule: typing.Optional[MMASchedule]
     # Compilation info
     workgroup_size: typing.List[int]
 
@@ -155,8 +178,8 @@
         ]
     if shapes_id == ShapesId.GPU_LARGE_ALIGNED:
         return [
-            TestShape(m=256, k=128, n=512, accumulate=True),
-            TestShape(m=256, k=128, n=512, accumulate=False),
+            TestShape(m=512, k=128, n=512, accumulate=True),
+            TestShape(m=512, k=128, n=512, accumulate=False),
         ]
     if shapes_id == ShapesId.GPU_LARGE:
         return [
@@ -197,8 +220,8 @@
     workgroup_size: typing.List[int]
 
 
-# Constructs a TileWorkgroupSizePair for SPIRV Targets enforcing the constraints between
-# the workgroup_size and tile size
+# Constructs a TileWorkgroupSizePair for SPIR-V targets enforcing the
+# constraints between the workgroup_size and tile size
 def get_spirv_tile_workgroup_size_pair(
     workgroup_size, t_tile_k, t_tile_m=4, t_tile_n=4
 ):
@@ -224,12 +247,57 @@
     return tile_workgroup_size_pairs
 
 
+def get_rocm_test_compilation_infos(compilation_info_id: CompilationInfoId):
+    assert compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute
+
+    schedules = [
+        MMASchedule("F16_16x16x16_F32", 1, 1, 1, 1, 1),
+        MMASchedule("F16_16x16x16_F32", 1, 1, 1, 1, 2),
+        MMASchedule("F16_16x16x16_F32", 1, 1, 1, 2, 1),
+        MMASchedule("F16_16x16x16_F32", 1, 1, 2, 1, 1),
+        MMASchedule("F16_16x16x16_F32", 2, 2, 1, 1, 1),
+        MMASchedule("F16_16x16x16_F32", 2, 4, 2, 1, 2),
+        MMASchedule("F16_16x16x16_F32", 4, 2, 4, 2, 2),
+    ]
+
+    infos = []
+    for schedule in schedules:
+        if schedule.intrinsic == "F16_16x16x16_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 16
+        elif schedule.intrinsic == "F16_32x32x8_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+            wg_tile_k = schedule.k_tile_count * 8
+        else:
+            raise NotImplementedError("unhandled intrinsic case")
+
+        workgroup_tile = [[wg_tile_m, wg_tile_n, wg_tile_k]]
+        workgroup_size = [schedule.n_count * 64, schedule.m_count, 1]
+        infos.append(
+            CompilationInfo(
+                tile_sizes=workgroup_tile,
+                dispatch_lowering_pass_pipeline=compilation_info_id.value,
+                workgroup_size=workgroup_size,
+                software_pipeline_depth=0,
+                mma_schedule=schedule,
+            )
+        )
+    return infos
+
+
 # Returns the list of CompilationInfo's to use for the CompilationInfoId.
 def get_test_compilation_infos(
     compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
 ) -> typing.List[typing.Optional[CompilationInfo]]:
     if compilation_info_id == CompilationInfoId.NONE:
         return [None]
+
+    if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute:
+        return get_rocm_test_compilation_infos(compilation_info_id)
+
+    software_pipeline_depth = 0
     if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt:
         tile_workgroup_size_pairs = [
             TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]),
@@ -240,6 +308,7 @@
             TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]),
             TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]),
         ]
+        software_pipeline_depth = 3
     elif compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize:
         tile_workgroup_size_pairs = [
             TileWorkgroupSizePair(
@@ -277,6 +346,7 @@
             tile_workgroup_size_pairs.append(
                 TileWorkgroupSizePair([[128, 128, 16]], [64, 2, 1])
             )
+        software_pipeline_depth = 3
 
     compilation_infos = []
     for tile_workgroup_size_pair in tile_workgroup_size_pairs:
@@ -284,11 +354,9 @@
             CompilationInfo(
                 tile_sizes=tile_workgroup_size_pair.tile_size,
                 dispatch_lowering_pass_pipeline=compilation_info_id.value,
-                workload_per_wg=[
-                    a for a in reversed(tile_workgroup_size_pair.tile_size[0:2])
-                ],
                 workgroup_size=tile_workgroup_size_pair.workgroup_size,
-                software_pipeline_depth=3,
+                software_pipeline_depth=software_pipeline_depth,
+                mma_schedule=None,
             )
         )
     return compilation_infos
@@ -389,7 +457,10 @@
         info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
 
     matmul_kind = "matmul_accumulate" if accumulate else "matmul"
-    return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+    return (
+        f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_"
+        + f"{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+    )
 
 
 # Represents a generated test function.
@@ -429,30 +500,26 @@
     func_definition = ""
     compilation_info_attr = ""
     if compilation_info:
-        if (
-            "SPIRV"
-            in compilation_info.dispatch_lowering_pass_pipeline
-            == "SPIRVVectorizeMali"
-        ):
-            dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize"
-        elif (
-            compilation_info.dispatch_lowering_pass_pipeline
-            == "SPIRVCooperativeMatrixVectorize"
-        ):
-            dispatch_lowering_pass_pipeline = "SPIRVCooperativeMatrixVectorize"
-        elif compilation_info.dispatch_lowering_pass_pipeline == "SPIRVVectorizeNVIDIA":
+        requested_pipeline = compilation_info.dispatch_lowering_pass_pipeline
+        compiler_pipeline = requested_pipeline
+        if requested_pipeline == "SPIRVVectorizeMali":
+            compiler_pipeline = "SPIRVBaseVectorize"
+        elif requested_pipeline == "SPIRVCooperativeMatrixVectorize":
+            compiler_pipeline = "SPIRVCooperativeMatrixVectorize"
+        elif requested_pipeline == "SPIRVVectorizeNVIDIA":
             # TODO: change to test SPIRVMatmulPromoteVectorize too
-            dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize"
-        else:
-            dispatch_lowering_pass_pipeline = (
-                compilation_info.dispatch_lowering_pass_pipeline
-            )
+            compiler_pipeline = "SPIRVBaseVectorize"
+
+        mma_schedule = ""
+        if compilation_info.mma_schedule is not None:
+            mma_schedule = ", {}".format(compilation_info.mma_schedule)
         compilation_info_string = (
-            f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n"
+            f"#compilation{generate_function.compilation_index} = "
+            "#iree_codegen.compilation_info<\n"
             f"  lowering_config = <tile_sizes = {compilation_info.tile_sizes}>,\n"
-            f"  translation_info = <{dispatch_lowering_pass_pipeline},\n"
+            f"  translation_info = <{compiler_pipeline},\n"
             f"  {{ pipeline_depth = {compilation_info.software_pipeline_depth}, "
-            f"  store_stage = 1 }}>,\n"
+            f"  store_stage = 1{mma_schedule} }}>,\n"
             f"  workgroup_size = {compilation_info.workgroup_size_str()}>\n"
         )
         compilation_info_attr = (