[rocdl] Add e2e matmul test for cdna3 matrix core (#16510)
This commit adds matmul correctness tests
for the CodeGen paths that target cdna3 MFMA.
Right now we only cover (M, K) x (K, N) -> (M, N)
matmul for v_mfma_f32_16x16x16_f16.
diff --git a/build_tools/bazel/iree_e2e_matmul_test.bzl b/build_tools/bazel/iree_e2e_matmul_test.bzl
index f2a9e86..dcf8f0f 100644
--- a/build_tools/bazel/iree_e2e_matmul_test.bzl
+++ b/build_tools/bazel/iree_e2e_matmul_test.bzl
@@ -218,8 +218,8 @@
tests = []
for backend, driver in target_backends_and_drivers:
- # CUDA backend/driver not supported by Bazel build.
- if backend == "cuda" or driver == "cuda":
+ # CUDA/ROCm backend/driver not supported by Bazel build.
+ if backend == "cuda" or driver == "cuda" or backend == "rocm" or driver == "hip":
continue
suite_entry_name = "_".join([name, backend, driver])
iree_single_backend_e2e_matmul_test(
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 89c7c94..c6cb3ec 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -126,12 +126,6 @@
if(DEFINED _RULE_DRIVER)
string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER})
- string(TOUPPER "${IREE_EXTERNAL_HAL_DRIVERS}" _UPPERCASE_EXTERNAL_DRIVERS)
- string(REPLACE "-" "_" _NORMALIZED_EXTERNAL_DRIVERS "${_UPPERCASE_EXTERNAL_DRIVERS}")
- if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
- (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS))
- message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. Check IREE_HAL_DRIVER_*/IREE_EXTERNAL_HAL_DRIVERS options.")
- endif()
if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
(NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND))
set(_TEST_DISABLED TRUE)
diff --git a/build_tools/cmake/iree_e2e_matmul_test.cmake b/build_tools/cmake/iree_e2e_matmul_test.cmake
index 15850b4..9bc630e 100644
--- a/build_tools/cmake/iree_e2e_matmul_test.cmake
+++ b/build_tools/cmake/iree_e2e_matmul_test.cmake
@@ -216,12 +216,6 @@
if(DEFINED _RULE_DRIVER)
string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER})
- string(TOUPPER "${IREE_EXTERNAL_HAL_DRIVERS}" _UPPERCASE_EXTERNAL_DRIVERS)
- string(REPLACE "-" "_" _NORMALIZED_EXTERNAL_DRIVERS "${_UPPERCASE_EXTERNAL_DRIVERS}")
- if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
- (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS))
- message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. Check IREE_HAL_DRIVER_*/IREE_EXTERNAL_HAL_DRIVERS options.")
- endif()
if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND
(NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND))
set(_TEST_DISABLED TRUE)
diff --git a/build_tools/cmake/test_riscv.sh b/build_tools/cmake/test_riscv.sh
index b9d31ea..e14855b 100755
--- a/build_tools/cmake/test_riscv.sh
+++ b/build_tools/cmake/test_riscv.sh
@@ -36,6 +36,8 @@
"^driver=vulkan$"
"^driver=metal$"
"^driver=cuda$"
+ "^driver=rocm$"
+ "^driver=hip$"
"^vulkan_uses_vk_khr_shader_float16_int8$"
"^requires-filesystem$"
"^requires-dtz$"
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 8c2f03c..99ffe99 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -425,6 +425,39 @@
###########################################################################
##
+## ROCm backend
+##
+###########################################################################
+
+# Testing CDNA3 + matrix core path.
+# v_mfma_f32_16x16x16_f16
+iree_generated_e2e_matmul_test(
+ name = "e2e_matmul_rocm_f16_large_cdna3_matrixcore",
+ compiler_flags = [
+ "--iree-rocm-target-chip=gfx942",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=f16",
+ "--acc_type=f32",
+ "--shapes=gpu_large_aligned",
+ "--compilation_info=LLVMGPUVectorDistribute",
+ ],
+ tags = [
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-cdna3",
+ ],
+ target_backends_and_drivers = [
+ ("rocm", "rocm"),
+ ],
+ test_runner = "//tools:iree-e2e-matmul-test",
+)
+
+###########################################################################
+##
## Vulkan backend
##
###########################################################################
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 187ed7b..9939e35 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -978,6 +978,32 @@
iree_generated_e2e_matmul_test(
NAME
+ e2e_matmul_rocm_f16_large_cdna3_matrixcore
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistribute"
+ TEST_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "rocm"
+ COMPILER_FLAGS
+ "--iree-rocm-target-chip=gfx942"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_matmul_test(
+ NAME
e2e_matmul_vulkan_i8_large_valhall
GENERATOR
"generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 456947c..151ca42 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -48,6 +48,7 @@
LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt"
LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore"
LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync"
+ LLVMGPUVectorDistribute = "LLVMGPUVectorDistribute"
SPIRVCooperativeMatrixVectorize = "SPIRVCooperativeMatrixVectorize"
SPIRVVectorizeMali = "SPIRVVectorizeMali"
SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA"
@@ -81,6 +82,28 @@
accumulate: bool
+# Describes a workgroup and tiling schedule to target a specific MMA intrinsic.
+@dataclasses.dataclass
+class MMASchedule:
+ intrinsic: str
+ m_count: int # Number of subgroups per workgroup along M
+ n_count: int # Number of subgroups per workgroup along N
+ m_tile_count: int
+ n_tile_count: int
+ k_tile_count: int
+
+ def __str__(self):
+ return (
+ "mma_schedule = #iree_gpu.mma_schedule<"
+ + f"intrinsic = #iree_gpu.mfma_layout<{self.intrinsic}>, "
+ + f"subgroup_m_count = {self.m_count}, "
+ + f"subgroup_n_count = {self.n_count}, "
+ + f"subgroup_m_tile_count = {self.m_tile_count}, "
+ + f"subgroup_n_tile_count = {self.n_tile_count}, "
+ + f"subgroup_k_tile_count = {self.k_tile_count}>"
+ )
+
+
# Describes how to construct compilation info for the testcase.
@dataclasses.dataclass
class CompilationInfo:
@@ -88,8 +111,8 @@
tile_sizes: typing.List[typing.List[int]]
# Translation Info
dispatch_lowering_pass_pipeline: str
- workload_per_wg: typing.List[int]
software_pipeline_depth: int
+ mma_schedule: typing.Optional[MMASchedule]
# Compilation info
workgroup_size: typing.List[int]
@@ -155,8 +178,8 @@
]
if shapes_id == ShapesId.GPU_LARGE_ALIGNED:
return [
- TestShape(m=256, k=128, n=512, accumulate=True),
- TestShape(m=256, k=128, n=512, accumulate=False),
+ TestShape(m=512, k=128, n=512, accumulate=True),
+ TestShape(m=512, k=128, n=512, accumulate=False),
]
if shapes_id == ShapesId.GPU_LARGE:
return [
@@ -197,8 +220,8 @@
workgroup_size: typing.List[int]
-# Constructs a TileWorkgroupSizePair for SPIRV Targets enforcing the constraints between
-# the workgroup_size and tile size
+# Constructs a TileWorkgroupSizePair for SPIR-V targets enforcing the
+# constraints between the workgroup_size and tile size
def get_spirv_tile_workgroup_size_pair(
workgroup_size, t_tile_k, t_tile_m=4, t_tile_n=4
):
@@ -224,12 +247,57 @@
return tile_workgroup_size_pairs
+def get_rocm_test_compilation_infos(compilation_info_id: CompilationInfoId):
+ assert compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute
+
+ schedules = [
+ MMASchedule("F16_16x16x16_F32", 1, 1, 1, 1, 1),
+ MMASchedule("F16_16x16x16_F32", 1, 1, 1, 1, 2),
+ MMASchedule("F16_16x16x16_F32", 1, 1, 1, 2, 1),
+ MMASchedule("F16_16x16x16_F32", 1, 1, 2, 1, 1),
+ MMASchedule("F16_16x16x16_F32", 2, 2, 1, 1, 1),
+ MMASchedule("F16_16x16x16_F32", 2, 4, 2, 1, 2),
+ MMASchedule("F16_16x16x16_F32", 4, 2, 4, 2, 2),
+ ]
+
+ infos = []
+ for schedule in schedules:
+ if schedule.intrinsic == "F16_16x16x16_F32":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+ wg_tile_k = schedule.k_tile_count * 16
+ elif schedule.intrinsic == "F16_32x32x8_F32":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+ wg_tile_k = schedule.k_tile_count * 8
+ else:
+ raise NotImplementedError("unhandled intrinsic case")
+
+ workgroup_tile = [[wg_tile_m, wg_tile_n, wg_tile_k]]
+ workgroup_size = [schedule.n_count * 64, schedule.m_count, 1]
+ infos.append(
+ CompilationInfo(
+ tile_sizes=workgroup_tile,
+ dispatch_lowering_pass_pipeline=compilation_info_id.value,
+ workgroup_size=workgroup_size,
+ software_pipeline_depth=0,
+ mma_schedule=schedule,
+ )
+ )
+ return infos
+
+
# Returns the list of CompilationInfo's to use for the CompilationInfoId.
def get_test_compilation_infos(
compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
) -> typing.List[typing.Optional[CompilationInfo]]:
if compilation_info_id == CompilationInfoId.NONE:
return [None]
+
+ if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistribute:
+ return get_rocm_test_compilation_infos(compilation_info_id)
+
+ software_pipeline_depth = 0
if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt:
tile_workgroup_size_pairs = [
TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]),
@@ -240,6 +308,7 @@
TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]),
TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]),
]
+ software_pipeline_depth = 3
elif compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize:
tile_workgroup_size_pairs = [
TileWorkgroupSizePair(
@@ -277,6 +346,7 @@
tile_workgroup_size_pairs.append(
TileWorkgroupSizePair([[128, 128, 16]], [64, 2, 1])
)
+ software_pipeline_depth = 3
compilation_infos = []
for tile_workgroup_size_pair in tile_workgroup_size_pairs:
@@ -284,11 +354,9 @@
CompilationInfo(
tile_sizes=tile_workgroup_size_pair.tile_size,
dispatch_lowering_pass_pipeline=compilation_info_id.value,
- workload_per_wg=[
- a for a in reversed(tile_workgroup_size_pair.tile_size[0:2])
- ],
workgroup_size=tile_workgroup_size_pair.workgroup_size,
- software_pipeline_depth=3,
+ software_pipeline_depth=software_pipeline_depth,
+ mma_schedule=None,
)
)
return compilation_infos
@@ -389,7 +457,10 @@
info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
matmul_kind = "matmul_accumulate" if accumulate else "matmul"
- return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+ return (
+ f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_"
+ + f"{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+ )
# Represents a generated test function.
@@ -429,30 +500,26 @@
func_definition = ""
compilation_info_attr = ""
if compilation_info:
- if (
- "SPIRV"
- in compilation_info.dispatch_lowering_pass_pipeline
- == "SPIRVVectorizeMali"
- ):
- dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize"
- elif (
- compilation_info.dispatch_lowering_pass_pipeline
- == "SPIRVCooperativeMatrixVectorize"
- ):
- dispatch_lowering_pass_pipeline = "SPIRVCooperativeMatrixVectorize"
- elif compilation_info.dispatch_lowering_pass_pipeline == "SPIRVVectorizeNVIDIA":
+ requested_pipeline = compilation_info.dispatch_lowering_pass_pipeline
+ compiler_pipeline = requested_pipeline
+ if requested_pipeline == "SPIRVVectorizeMali":
+ compiler_pipeline = "SPIRVBaseVectorize"
+ elif requested_pipeline == "SPIRVCooperativeMatrixVectorize":
+ compiler_pipeline = "SPIRVCooperativeMatrixVectorize"
+ elif requested_pipeline == "SPIRVVectorizeNVIDIA":
# TODO: change to test SPIRVMatmulPromoteVectorize too
- dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize"
- else:
- dispatch_lowering_pass_pipeline = (
- compilation_info.dispatch_lowering_pass_pipeline
- )
+ compiler_pipeline = "SPIRVBaseVectorize"
+
+ mma_schedule = ""
+ if compilation_info.mma_schedule is not None:
+ mma_schedule = ", {}".format(compilation_info.mma_schedule)
compilation_info_string = (
- f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n"
+ f"#compilation{generate_function.compilation_index} = "
+ "#iree_codegen.compilation_info<\n"
f" lowering_config = <tile_sizes = {compilation_info.tile_sizes}>,\n"
- f" translation_info = <{dispatch_lowering_pass_pipeline},\n"
+ f" translation_info = <{compiler_pipeline},\n"
f" {{ pipeline_depth = {compilation_info.software_pipeline_depth}, "
- f" store_stage = 1 }}>,\n"
+ f" store_stage = 1{mma_schedule} }}>,\n"
f" workgroup_size = {compilation_info.workgroup_size_str()}>\n"
)
compilation_info_attr = (