#!/usr/bin/env python3
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""compilation_info support in e2e matmul tests"""
from typing import Optional
import enum
import dataclasses
import typing
from tests.e2e.matmul.common import *


# Describes a workgroup and tiling schedule to target a specific MMA intrinsic.
@dataclasses.dataclass
class MMASchedule:
    intrinsic: str
    m_count: int  # Number of subgroups per workgroup along M
    n_count: int  # Number of subgroups per workgroup along N
    m_tile_count: int
    n_tile_count: int
    k_tile_count: int

    def get_subgroup_basis(self) -> str:
        return f"[[{self.m_count}, {self.n_count}, 1], [0, 1, 2]]"


# Enumerates the collections of compilation info that we can generate tests
# for. The values are the accepted values for the --compilation_info= flag.
@enum.unique
class CompilationInfoId(enum.Enum):
    NONE = ""
    LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore"
    LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync"
    LLVMGPUVectorDistributeMFMA = "LLVMGPUVectorDistributeMFMA"
    LLVMGPUVectorDistributeWMMAR3 = "LLVMGPUVectorDistributeWMMAR3"
    LLVMGPUVectorDistributeWMMAR4 = "LLVMGPUVectorDistributeWMMAR4"
    LLVMGPUVectorDistributeWMMA1250 = "LLVMGPUVectorDistributeWMMA1250"
    SPIRVCooperativeMatrixVectorize = "SPIRVCooperativeMatrixVectorize"
    SPIRVVectorizeMali = "SPIRVVectorizeMali"
    SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA"


# Describes how to construct compilation info for the testcase.
@dataclasses.dataclass
class CompilationInfo:
    # Compilation info
    workgroup_size: typing.List[int]
    subgroup_size: Optional[int]
    # Translation info
    dispatch_lowering_pass_pipeline: str

    # Returns the workgroup size rendered as an attribute string.
    def workgroup_size_str(self):
        return "workgroup_size = [" + ", ".join(map(str, self.workgroup_size)) + "]"

    # Overridden by subclasses to render the full compilation_info attribute.
    def get_compilation_info_attr(self) -> str:
        ...


@dataclasses.dataclass
class IREEGPUCompilationInfo(CompilationInfo):
    # Lowering Config
    workgroup_tile: list[int]
    reduction_tile: list[int]
    # Translation Info
    mma_schedule: Optional[MMASchedule]

    def get_compilation_info_attr(self) -> str:
        compiler_pipeline = self.dispatch_lowering_pass_pipeline
        # This pipeline always carries an MMA schedule.
        assert self.mma_schedule is not None
        subgroup_size_str = ""
        if self.subgroup_size is not None:
            subgroup_size_str = f"subgroup_size = {self.subgroup_size}"
        return (
            "#iree_codegen.compilation_info<\n"
            f"  lowering_config = #iree_gpu.lowering_config<{{"
            f" mma_kind = #iree_gpu.mma_layout<{self.mma_schedule.intrinsic}>, "
            f" subgroup_basis = {self.mma_schedule.get_subgroup_basis()}, "
            f" workgroup = {self.workgroup_tile}, "
            f" reduction = {self.reduction_tile} }}>,\n"
            f"  translation_info = #iree_codegen.translation_info<pipeline = {compiler_pipeline} {self.workgroup_size_str()}\n"
            f"  {subgroup_size_str}>>\n"
        )
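
# For illustration, MMASchedule("MFMA_F32_16x16x16_F16", 2, 2, 1, 1, 1) with
# workgroup_tile=[32, 32, 0], reduction_tile=[0, 0, 16],
# workgroup_size=[128, 2, 1], and subgroup_size=64 renders roughly as
# (whitespace elided):
#   #iree_codegen.compilation_info<
#     lowering_config = #iree_gpu.lowering_config<{
#       mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
#       subgroup_basis = [[2, 2, 1], [0, 1, 2]],
#       workgroup = [32, 32, 0], reduction = [0, 0, 16] }>,
#     translation_info = #iree_codegen.translation_info<
#       pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1]
#       subgroup_size = 64>>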


@dataclasses.dataclass
class LegacyCompilationInfo(CompilationInfo):
    # Lowering Config
    tile_sizes: typing.List[typing.List[int]]
    # Translation Info
    software_pipeline_depth: int

    def get_compilation_info_attr(self) -> str:
        requested_pipeline = self.dispatch_lowering_pass_pipeline
        compiler_pipeline = requested_pipeline
        if requested_pipeline == "SPIRVVectorizeMali":
            compiler_pipeline = "SPIRVBaseVectorize"
        elif requested_pipeline == "SPIRVCooperativeMatrixVectorize":
            compiler_pipeline = "SPIRVCooperativeMatrixVectorize"
        elif requested_pipeline == "SPIRVVectorizeNVIDIA":
            # TODO: change to test SPIRVMatmulPromoteVectorize too
            compiler_pipeline = "SPIRVBaseVectorize"
        subgroup_size_str = ""
        if self.subgroup_size is not None:
            subgroup_size_str = f"subgroup_size = {self.subgroup_size}"
        return (
            "#iree_codegen.compilation_info<\n"
            f"  lowering_config = #iree_codegen.lowering_config<tile_sizes = {self.tile_sizes}>,\n"
            f"  translation_info = #iree_codegen.translation_info<pipeline = {compiler_pipeline} {self.workgroup_size_str()}\n"
            f"  {subgroup_size_str},\n"
            f"  {{ pipeline_depth = {self.software_pipeline_depth}, store_stage = 1}}>>"
        )
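
# For illustration, tile_sizes=[[32, 128], [4, 4], [0, 0, 32]] with
# workgroup_size=[32, 8, 1] under SPIRVVectorizeNVIDIA (compiled as
# SPIRVBaseVectorize) renders roughly as (whitespace elided):
#   #iree_codegen.compilation_info<
#     lowering_config = #iree_codegen.lowering_config<
#       tile_sizes = [[32, 128], [4, 4], [0, 0, 32]]>,
#     translation_info = #iree_codegen.translation_info<
#       pipeline = SPIRVBaseVectorize workgroup_size = [32, 8, 1],
#       { pipeline_depth = 0, store_stage = 1}>>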


def generate_compilation_info_string_and_attr(compilation_info: CompilationInfo):
    if not compilation_info:
        return ("", "")
    index = generate_compilation_info_string_and_attr.compilation_index
    generate_compilation_info_string_and_attr.compilation_index += 1
    return (
        f"#compilation{index} = {compilation_info.get_compilation_info_attr()}",
        f"{{compilation_info = #compilation{index}}} ",
    )


# Module-level counter stored on the function, so each generated attribute
# definition gets a unique "#compilationN" name.
generate_compilation_info_string_and_attr.compilation_index = 0
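
# Example: the first call returns a pair like
#   ("#compilation0 = #iree_codegen.compilation_info<...>",
#    "{compilation_info = #compilation0} ")
# where the definition is meant to be emitted at module scope and the
# attribute reference attached to the matmul op under test.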


@dataclasses.dataclass
class TileWorkgroupSizePair:
    tile_size: typing.List[typing.List[int]]
    workgroup_size: typing.List[int]


# Constructs a TileWorkgroupSizePair for SPIR-V targets, enforcing the
# constraints between the workgroup size and the tile sizes.
def get_spirv_tile_workgroup_size_pair(
    workgroup_size, t_tile_k, t_tile_m=4, t_tile_n=4
):
    x, y, z = workgroup_size
    wg_tile_m = y * t_tile_m
    wg_tile_n = x * t_tile_n
    return TileWorkgroupSizePair(
        [[wg_tile_m, wg_tile_n], [t_tile_m, t_tile_n], [0, 0, t_tile_k]], workgroup_size
    )
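
# Example: get_spirv_tile_workgroup_size_pair([32, 8, 1], 32) gives tile sizes
# [[32, 128], [4, 4], [0, 0, 32]]: each workgroup covers 8 * 4 = 32 rows and
# 32 * 4 = 128 columns, with a K tile of 32.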


# Returns all the TileWorkgroupSizePairs for a given SPIR-V target.
def get_all_spirv_tile_workgroup_size_pairs(t_tile_k):
    tile_workgroup_size_pairs = [
        get_spirv_tile_workgroup_size_pair([32, 8, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([16, 8, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([64, 2, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([8, 8, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([16, 2, 1], t_tile_k),
        get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k),
    ]
    return tile_workgroup_size_pairs


def get_rocm_test_compilation_infos(
    compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
):
    intrinsic = ""
    if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeMFMA:
        intrinsic = "MFMA"
    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeWMMAR3:
        intrinsic = "WMMAR3"
    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeWMMAR4:
        intrinsic = "WMMAR4"
    elif compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeWMMA1250:
        intrinsic = "WMMA1250"
    else:
        raise ValueError(f"Unknown pipeline for ROCm: {compilation_info_id}")
    schedules = []
    if intrinsic == "MFMA":
        schedules = [
            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2),
            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
            MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 2),
            MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 2, 1),
            MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 2, 1, 1),
            MMASchedule("MFMA_F32_16x16x16_F16", 2, 2, 1, 1, 1),
            MMASchedule("MFMA_F32_16x16x16_F16", 2, 4, 2, 1, 2),
            MMASchedule("MFMA_F32_16x16x16_F16", 4, 2, 4, 2, 2),
            MMASchedule("MFMA_F32_32x32x8_F16", 1, 1, 1, 2, 2),
            MMASchedule("MFMA_F32_32x32x8_F16", 2, 2, 1, 1, 1),
            MMASchedule("MFMA_F32_32x32x8_F16", 1, 4, 2, 1, 2),
            MMASchedule("MFMA_F32_32x32x8_F16", 4, 2, 1, 2, 4),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 2, 4, 2, 1),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 1, 1, 2, 2),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FNUZ", 4, 2, 2, 2, 2),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FN", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FN", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FN", 4, 1, 4, 1, 1),
            MMASchedule("MFMA_F32_16x16x32_F8E4M3FN", 4, 2, 4, 2, 1),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FN", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FN", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FN", 4, 1, 1, 2, 2),
            MMASchedule("MFMA_F32_32x32x16_F8E4M3FN", 4, 2, 2, 2, 2),
            MMASchedule("MFMA_I32_16x16x32_I8", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_I32_16x16x32_I8", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_I32_16x16x32_I8", 4, 1, 4, 1, 1),
            MMASchedule("MFMA_I32_16x16x32_I8", 4, 2, 4, 2, 1),
            MMASchedule("MFMA_I32_32x32x16_I8", 1, 1, 1, 1, 1),
            MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2),
            MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2),
            MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2),
            MMASchedule("VMFMA_F32_16x16x32_F16", 1, 1, 1, 1, 1),
            MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
            MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
            MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
            MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
            MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
        ]
elif intrinsic == "WMMAR3":
schedules = [
MMASchedule("WMMAR3_F32_16x16x16_F16", 1, 1, 1, 1, 1),
MMASchedule("WMMAR3_F32_16x16x16_F16", 1, 1, 1, 1, 2),
MMASchedule("WMMAR3_F32_16x16x16_F16", 1, 1, 1, 2, 1),
MMASchedule("WMMAR3_F32_16x16x16_F16", 1, 1, 2, 1, 1),
MMASchedule("WMMAR3_F32_16x16x16_F16", 2, 2, 1, 1, 1),
MMASchedule("WMMAR3_F32_16x16x16_F16", 2, 4, 2, 1, 2),
MMASchedule("WMMAR3_F32_16x16x16_F16", 4, 2, 4, 2, 2),
MMASchedule("WMMAR3_I32_16x16x16_I8", 1, 1, 1, 1, 1),
MMASchedule("WMMAR3_I32_16x16x16_I8", 1, 1, 1, 1, 2),
MMASchedule("WMMAR3_I32_16x16x16_I8", 1, 1, 1, 2, 1),
MMASchedule("WMMAR3_I32_16x16x16_I8", 1, 1, 2, 1, 1),
MMASchedule("WMMAR3_I32_16x16x16_I8", 2, 2, 1, 1, 1),
MMASchedule("WMMAR3_I32_16x16x16_I8", 2, 4, 2, 1, 2),
MMASchedule("WMMAR3_I32_16x16x16_I8", 4, 2, 4, 2, 2),
]
elif intrinsic == "WMMAR4":
schedules = [
MMASchedule("WMMAR4_F32_16x16x16_F16", 1, 1, 1, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F16", 1, 1, 1, 1, 2),
MMASchedule("WMMAR4_F32_16x16x16_F16", 1, 1, 1, 2, 1),
MMASchedule("WMMAR4_F32_16x16x16_F16", 1, 1, 2, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F16", 2, 2, 1, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F16", 2, 4, 2, 1, 2),
MMASchedule("WMMAR4_F32_16x16x16_F16", 4, 2, 4, 2, 2),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 1, 1, 1, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 1, 1, 1, 1, 2),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 1, 1, 1, 2, 1),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 1, 1, 2, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 2, 2, 1, 1, 1),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 2, 4, 2, 1, 2),
MMASchedule("WMMAR4_F32_16x16x16_F8E4M3FN", 4, 2, 4, 2, 2),
MMASchedule("WMMAR4_I32_16x16x16_I8", 1, 1, 1, 1, 1),
MMASchedule("WMMAR4_I32_16x16x16_I8", 1, 1, 1, 1, 2),
MMASchedule("WMMAR4_I32_16x16x16_I8", 1, 1, 1, 2, 1),
MMASchedule("WMMAR4_I32_16x16x16_I8", 1, 1, 2, 1, 1),
MMASchedule("WMMAR4_I32_16x16x16_I8", 2, 2, 1, 1, 1),
MMASchedule("WMMAR4_I32_16x16x16_I8", 2, 4, 2, 1, 2),
MMASchedule("WMMAR4_I32_16x16x16_I8", 4, 2, 4, 2, 2),
]
elif intrinsic == "WMMA1250":
# gfx1250 WMMA intrinsics: 16x16 tiles with various K sizes.
# F16: K=32, F8E4M3FN: K=64 or K=128, I8: K=64
schedules = [
MMASchedule("WMMA_F32_16x16x32_F16", 1, 1, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x32_F16", 1, 1, 1, 1, 2),
MMASchedule("WMMA_F32_16x16x32_F16", 1, 1, 1, 2, 1),
MMASchedule("WMMA_F32_16x16x32_F16", 1, 1, 2, 1, 1),
MMASchedule("WMMA_F32_16x16x32_F16", 2, 2, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x32_F16", 2, 4, 2, 1, 2),
MMASchedule("WMMA_F32_16x16x32_F16", 4, 2, 4, 2, 2),
# K=64 F8E4M3FN
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 1, 1, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 1, 1, 1, 1, 2),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 1, 1, 1, 2, 1),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 1, 1, 2, 1, 1),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 2, 2, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 2, 4, 2, 1, 2),
MMASchedule("WMMA_F32_16x16x64_F8E4M3FN", 4, 2, 4, 2, 2),
# K=128 F8E4M3FN
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 1, 1, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 1, 1, 1, 1, 2),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 1, 1, 1, 2, 1),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 1, 1, 2, 1, 1),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 2, 2, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 2, 4, 2, 1, 2),
MMASchedule("WMMA_F32_16x16x128_F8E4M3FN", 4, 2, 4, 2, 2),
# I8
MMASchedule("WMMA_I32_16x16x64_I8", 1, 1, 1, 1, 1),
MMASchedule("WMMA_I32_16x16x64_I8", 1, 1, 1, 1, 2),
MMASchedule("WMMA_I32_16x16x64_I8", 1, 1, 1, 2, 1),
MMASchedule("WMMA_I32_16x16x64_I8", 1, 1, 2, 1, 1),
MMASchedule("WMMA_I32_16x16x64_I8", 2, 2, 1, 1, 1),
MMASchedule("WMMA_I32_16x16x64_I8", 2, 4, 2, 1, 2),
MMASchedule("WMMA_I32_16x16x64_I8", 4, 2, 4, 2, 2),
]
    else:
        raise NotImplementedError("unhandled intrinsic case")

    # MFMA intrinsics run at subgroup size 64; WMMAR3, WMMAR4, and WMMA1250
    # all use 32.
    subgroup_size = 64 if intrinsic == "MFMA" else 32

    infos = []
    for schedule in schedules:
        # Skip schedules whose intrinsic element type does not match the
        # requested one. The input type is extracted from the intrinsic name,
        # which follows the convention [output_type]_MxNxK_[input_type].
        input_type = schedule.intrinsic.split("_")[-1]
        if lhs_rhs_type.value.upper() != input_type:
            continue
if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 4
elif schedule.intrinsic == "MFMA_F32_16x16x16_F16":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic == "MFMA_F32_32x32x8_F16":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
elif (
schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
or schedule.intrinsic == "VMFMA_F32_16x16x32_F8E4M3FNUZ"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FN"
or schedule.intrinsic == "VMFMA_F32_16x16x32_F8E4M3FN"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
elif (
schedule.intrinsic == "VMFMA_F32_32x32x16_F16"
or schedule.intrinsic == "MFMA_F32_32x32x16_F8E4M3FNUZ"
or schedule.intrinsic == "MFMA_F32_32x32x16_F8E4M3FN"
or schedule.intrinsic == "MFMA_I32_32x32x16_I8"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic in (
"WMMAR3_F32_16x16x16_F16",
"WMMAR3_I32_16x16x16_I8",
"WMMAR4_F32_16x16x16_F16",
"WMMAR4_F32_16x16x16_F8E4M3FN",
"WMMAR4_I32_16x16x16_I8",
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic == "WMMA_F32_16x16x32_F16":
# gfx1250: M=16, N=16, K=32
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
elif schedule.intrinsic in (
"WMMA_F32_16x16x64_F8E4M3FN",
"WMMA_I32_16x16x64_I8",
):
# gfx1250: M=16, N=16, K=64
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 64
elif schedule.intrinsic == "WMMA_F32_16x16x128_F8E4M3FN":
# gfx1250: M=16, N=16, K=128
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 128
else:
raise NotImplementedError("unhandled intrinsic case")
        workgroup_tile = [wg_tile_m, wg_tile_n, 0]
        reduction_tile = [0, 0, wg_tile_k]
        workgroup_size = [schedule.n_count * subgroup_size, schedule.m_count, 1]
        infos.append(
            IREEGPUCompilationInfo(
                workgroup_tile=workgroup_tile,
                reduction_tile=reduction_tile,
                dispatch_lowering_pass_pipeline="LLVMGPUVectorDistribute",
                workgroup_size=workgroup_size,
                mma_schedule=schedule,
                subgroup_size=subgroup_size,
            )
        )
    return infos
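
# Worked example: MMASchedule("MFMA_F32_16x16x16_F16", 2, 2, 1, 1, 1) yields
# workgroup_tile = [2 * 1 * 16, 2 * 1 * 16, 0] = [32, 32, 0],
# reduction_tile = [0, 0, 1 * 16] = [0, 0, 16], and, with the MFMA subgroup
# size of 64, workgroup_size = [2 * 64, 2, 1] = [128, 2, 1].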


# Returns the list of CompilationInfos to use for the given CompilationInfoId.
def get_test_compilation_infos(
    compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
) -> typing.List[typing.Optional[CompilationInfo]]:
    if compilation_info_id == CompilationInfoId.NONE:
        return [None]

    if compilation_info_id in [
        CompilationInfoId.LLVMGPUVectorDistributeMFMA,
        CompilationInfoId.LLVMGPUVectorDistributeWMMAR3,
        CompilationInfoId.LLVMGPUVectorDistributeWMMAR4,
        CompilationInfoId.LLVMGPUVectorDistributeWMMA1250,
    ]:
        return get_rocm_test_compilation_infos(compilation_info_id, lhs_rhs_type)
    software_pipeline_depth = 0
    tile_workgroup_size_pairs = []
    if compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize:
        tile_workgroup_size_pairs = [
            TileWorkgroupSizePair(
                [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]], [64, 2, 1]
            )
        ]
    elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA:
        tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(32)
    elif compilation_info_id == CompilationInfoId.SPIRVVectorizeMali:
        tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(4)
    elif (
        compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCore
        or compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCoreMmaSync
    ):
        tile_workgroup_size_pairs = []
        ## WarpShape = 2x2
        tile_workgroup_size_pairs.append(
            TileWorkgroupSizePair([[32, 32, 16]], [64, 2, 1])
        )
        tile_workgroup_size_pairs.append(
            TileWorkgroupSizePair([[64, 64, 64]], [64, 2, 1])
        )
        ## WarpShape = 4x1
        tile_workgroup_size_pairs.append(
            TileWorkgroupSizePair([[32, 32, 32]], [64, 1, 1])
        )
        ## WarpShape = 2x2 with large tiles using the larger shared memory capacity.
        if lhs_rhs_type == MatrixElemTypeId.F16:
            tile_workgroup_size_pairs.append(
                TileWorkgroupSizePair([[128, 128, 64]], [64, 2, 1])
            )
        elif lhs_rhs_type == MatrixElemTypeId.F32:
            tile_workgroup_size_pairs.append(
                TileWorkgroupSizePair([[128, 128, 16]], [64, 2, 1])
            )
        software_pipeline_depth = 3

    compilation_infos = []
    for tile_workgroup_size_pair in tile_workgroup_size_pairs:
        compilation_infos.append(
            LegacyCompilationInfo(
                tile_sizes=tile_workgroup_size_pair.tile_size,
                dispatch_lowering_pass_pipeline=compilation_info_id.value,
                workgroup_size=tile_workgroup_size_pair.workgroup_size,
                subgroup_size=None,
                software_pipeline_depth=software_pipeline_depth,
            )
        )
    return compilation_infos
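
# Example usage (hypothetical driver code):
#   infos = get_test_compilation_infos(
#       CompilationInfoId.SPIRVVectorizeNVIDIA, MatrixElemTypeId.F32
#   )
#   for info in infos:
#       definition, attr = generate_compilation_info_string_and_attr(info)
#       # Emit `definition` at module scope; attach `attr` to the matmul op.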