Add WMMA to matmul test suite for LLVMGPUVectorDistribute (#17285)

These tests should not hardcode the chip target and should instead read it
from an environment variable. For now, to stay consistent with the MFMA
tests, they use hardcoded chip targets; this should be replaced with an
environment variable in the future.
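
For reference, a minimal sketch of what the environment-variable approach
could look like in the test generator; the IREE_GPU_TEST_CHIP variable name
is hypothetical and not part of this change:

    import os

    # Hypothetical: read the ROCm chip target from the environment,
    # falling back to the hardcoded default used by the MFMA suites.
    def get_rocm_target_chip(default: str = "gfx942") -> str:
        return os.environ.get("IREE_GPU_TEST_CHIP", default)

    # Yields e.g. "--iree-rocm-target-chip=gfx1100" for the WMMA suites.
    compiler_flag = f"--iree-rocm-target-chip={get_rocm_target_chip()}"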
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index c8511d1..b6fe7b8 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -532,7 +532,7 @@
         "--lhs_rhs_type=f16",
         "--acc_type=f32",
         "--shapes=gpu_large_aligned",
-        "--compilation_info=LLVMGPUVectorDistribute",
+        "--compilation_info=LLVMGPUVectorDistributeMFMA",
     ],
     tags = [
         "noasan",
@@ -542,7 +542,7 @@
         "requires-gpu-cdna3",
     ],
     target_backends_and_drivers = [
-        ("rocm", "rocm"),
+        ("rocm", "hip"),
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
@@ -559,7 +559,7 @@
         "--acc_type=f32",
         "--transpose_rhs",
         "--shapes=gpu_large_aligned",
-        "--compilation_info=LLVMGPUVectorDistribute",
+        "--compilation_info=LLVMGPUVectorDistributeMFMA",
     ],
     tags = [
         "noasan",
@@ -569,7 +569,61 @@
         "requires-gpu-cdna3",
     ],
     target_backends_and_drivers = [
-        ("rocm", "rocm"),
+        ("rocm", "hip"),
+    ],
+    test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
+    test_type = "matmul",
+)
+
+# Testing RDNA3 + matrix core path.
+iree_generated_e2e_runner_test(
+    name = "e2e_matmul_rocm_f16_large_rdna3_wmma",
+    compiler_flags = [
+        "--iree-rocm-target-chip=gfx1100",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=f16",
+        "--acc_type=f32",
+        "--shapes=gpu_large_aligned",
+        "--compilation_info=LLVMGPUVectorDistributeWMMA",
+    ],
+    tags = [
+        "noasan",
+        "nomsan",
+        "notsan",
+        "noubsan",
+        "requires-gpu-rdna3",
+    ],
+    target_backends_and_drivers = [
+        ("rocm", "hip"),
+    ],
+    test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
+    test_type = "matmul",
+)
+
+iree_generated_e2e_runner_test(
+    name = "e2e_matmul_rocm_f16_large_rdna3_wmma_tb",
+    compiler_flags = [
+        "--iree-rocm-target-chip=gfx1100",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=f16",
+        "--acc_type=f32",
+        "--transpose_rhs",
+        "--shapes=gpu_large_aligned",
+        "--compilation_info=LLVMGPUVectorDistributeWMMA",
+    ],
+    tags = [
+        "noasan",
+        "nomsan",
+        "notsan",
+        "noubsan",
+        "requires-gpu-rdna3",
+    ],
+    target_backends_and_drivers = [
+        ("rocm", "hip"),
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index f5fa735..cd4f225 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1823,13 +1823,13 @@
     "--lhs_rhs_type=f16"
     "--acc_type=f32"
     "--shapes=gpu_large_aligned"
-    "--compilation_info=LLVMGPUVectorDistribute"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
     "rocm"
   DRIVERS
-    "rocm"
+    "hip"
   COMPILER_FLAGS
     "--iree-rocm-target-chip=gfx942"
   LABELS
@@ -1852,13 +1852,13 @@
     "--acc_type=f32"
     "--transpose_rhs"
     "--shapes=gpu_large_aligned"
-    "--compilation_info=LLVMGPUVectorDistribute"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
     "rocm"
   DRIVERS
-    "rocm"
+    "hip"
   COMPILER_FLAGS
     "--iree-rocm-target-chip=gfx942"
   LABELS
@@ -1871,6 +1871,63 @@
 
 iree_generated_e2e_runner_test(
   NAME
+    e2e_matmul_rocm_f16_large_rdna3_wmma
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistributeWMMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    "--iree-rocm-target-chip=gfx1100"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-rdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_rocm_f16_large_rdna3_wmma_tb
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--transpose_rhs"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistributeWMMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    "--iree-rocm-target-chip=gfx1100"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-rdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
     e2e_matmul_vulkan_i8_large_valhall
   TEST_TYPE
     matmul
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 2e2aa79..596f610 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -48,7 +48,8 @@
     LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt"
     LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore"
     LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync"
-    LLVMGPUVectorDistribute = "LLVMGPUVectorDistribute"
+    LLVMGPUVectorDistributeMFMA = "LLVMGPUVectorDistributeMFMA"
+    LLVMGPUVectorDistributeWMMA = "LLVMGPUVectorDistributeWMMA"
     SPIRVCooperativeMatrixVectorize = "SPIRVCooperativeMatrixVectorize"
     SPIRVVectorizeMali = "SPIRVVectorizeMali"
     SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA"
@@ -246,21 +247,43 @@
 
 
 def get_rocm_test_compilation_infos(compilation_info_id: CompilationInfoId):
-    assert compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute
-    # TODO: Add test for WMMA layout.
-    schedules = [
-        MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
-        MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
-        MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
-        MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
-        MMASchedule("MFMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
-        MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
-        MMASchedule("MFMA_F16_16x16x16_F32", 4, 2, 4, 2, 2),
-        MMASchedule("MFMA_F16_32x32x8_F32", 1, 1, 1, 2, 2),
-        MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
-        MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
-        MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
-    ]
+    intrinsic = ""
+    if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeMFMA:
+        intrinsic = "MFMA"
+    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeWMMA:
+        intrinsic = "WMMA"
+    else:
+        raise ValueError("Unknown pipeline for rocm")
+
+    schedules = []
+    if intrinsic == "MFMA":
+        schedules = [
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
+            MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
+            MMASchedule("MFMA_F16_16x16x16_F32", 4, 2, 4, 2, 2),
+            MMASchedule("MFMA_F16_32x32x8_F32", 1, 1, 1, 2, 2),
+            MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
+            MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
+            MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
+        ]
+    elif intrinsic == "WMMA":
+        schedules = [
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 1, 1, 2, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 2, 2, 1, 1, 1),
+            MMASchedule("WMMA_F16_16x16x16_F32", 2, 4, 2, 1, 2),
+            MMASchedule("WMMA_F16_16x16x16_F32", 4, 2, 4, 2, 2),
+        ]
+    else:
+        raise NotImplementedError("unhandled intrinsic case")
+
+    subgroup_size = 64 if intrinsic == "MFMA" else 32
 
     infos = []
     for schedule in schedules:
@@ -272,21 +295,23 @@
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
             wg_tile_k = schedule.k_tile_count * 8
+        elif schedule.intrinsic == "WMMA_F16_16x16x16_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 16
         else:
             raise NotImplementedError("unhandled intrinsic case")
 
         workgroup_tile = [[wg_tile_m, wg_tile_n, wg_tile_k]]
-        workgroup_size = [schedule.n_count * 64, schedule.m_count, 1]
+        workgroup_size = [schedule.n_count * subgroup_size, schedule.m_count, 1]
         infos.append(
             CompilationInfo(
                 tile_sizes=workgroup_tile,
-                dispatch_lowering_pass_pipeline=compilation_info_id.value,
+                dispatch_lowering_pass_pipeline="LLVMGPUVectorDistribute",
                 workgroup_size=workgroup_size,
                 software_pipeline_depth=0,
                 mma_schedule=schedule,
-                # TODO: This is only valid for gfx9. Change this for RDNA3
-                # architectures.
-                subgroup_size=64,
+                subgroup_size=subgroup_size,
             )
         )
     return infos
@@ -299,7 +324,10 @@
     if compilation_info_id == CompilationInfoId.NONE:
         return [None]
 
-    if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute:
+    if compilation_info_id in [
+        CompilationInfoId.LLVMGPUVectorDistributeMFMA,
+        CompilationInfoId.LLVMGPUVectorDistributeWMMA,
+    ]:
         return get_rocm_test_compilation_infos(compilation_info_id)
 
     software_pipeline_depth = 0
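
For illustration, here is the tile-size arithmetic from
get_rocm_test_compilation_infos worked through for one of the new WMMA
entries, assuming MMASchedule's positional fields are (intrinsic, m_count,
n_count, m_tile_count, n_tile_count, k_tile_count), which matches how the
schedules are consumed above:

    # MMASchedule("WMMA_F16_16x16x16_F32", 2, 4, 2, 1, 2); the intrinsic
    # shape is M=N=K=16 and RDNA3 uses wave32, so subgroup_size = 32.
    m_count, n_count, m_tile_count, n_tile_count, k_tile_count = 2, 4, 2, 1, 2

    wg_tile_m = m_count * m_tile_count * 16  # 2 * 2 * 16 = 64
    wg_tile_n = n_count * n_tile_count * 16  # 4 * 1 * 16 = 64
    wg_tile_k = k_tile_count * 16            # 2 * 16 = 32

    subgroup_size = 32  # 64 on CDNA3 for the MFMA schedules
    workgroup_size = [n_count * subgroup_size, m_count, 1]  # [128, 2, 1]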