Revert "Revert "[LLVMGPU][ROCm] Add MFMA_F32_16x16x4_F32 instruction"… (#17921)
… (#17894)"
This reverts commit 02c2000795e157e4cf63fbac89d21a1ed886a7b0.
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 7859c0b..5973c05 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
-// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>
// GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 2e4d771..86e7bc0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -211,6 +211,9 @@
Type i32 = IntegerType::get(context, 32);
switch (type) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
@@ -255,6 +258,24 @@
LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ);
(void)laneZ, (void)vectorZ;
switch (type) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+ // #layout_c = #iree_vector_ext.layout<#inner, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+ auto cNLayout = outer;
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
@@ -409,6 +430,12 @@
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers. So need to adjust accordingly. All vectors should be 1-D.
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ auto aType = VectorType::get({1}, getAType());
+ auto bType = VectorType::get({1}, getBType());
+ auto cType = VectorType::get({4}, getCType());
+ return std::make_tuple(aType, bType, cType);
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
@@ -456,6 +483,7 @@
int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
@@ -472,6 +500,7 @@
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
@@ -490,6 +519,10 @@
MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
+ /*element=*/{1, 1}};
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
@@ -518,6 +551,10 @@
MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
+ /*element=*/{1, 1}};
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
@@ -546,6 +583,7 @@
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
@@ -582,6 +620,17 @@
return failure();
}
switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ // Update the lhs and rhs to extract the first element since vector<1xT> is
+    // not supported by amdgpu.mfma op.
+ lhs = builder.create<vector::ExtractOp>(loc, lhs, ArrayRef{int64_t{0}});
+ rhs = builder.create<vector::ExtractOp>(loc, rhs, ArrayRef{int64_t{0}});
+ auto [m, n, k] = getMNKShape();
+ return builder
+ .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
+ rhs, acc)
+ .getResult();
+ }
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 108deb9..55a83b3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -99,9 +99,10 @@
}
// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
-def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
-def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
-def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
+def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
+def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
+def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
+def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
@@ -110,6 +111,7 @@
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
+ MFMA_F32_16x16x4_F32,
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_F8E4M3FNUZ_16x16x32_F32,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index c9e200f..2f3d254 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -122,6 +122,7 @@
const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
+ MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index 63ea310..721522a 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -214,10 +214,23 @@
int64_t nSize = bounds[nDim];
int64_t kSize = bounds[kDim];
+ auto inpElemType =
+ cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType())
+ .getElementType();
+ auto kernelElemType =
+ cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
+ .getElementType();
+
// TODO: Generalize to other dimensions.
// Try to search for pad value and check only filter dimension is blocked.
SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
for (const GPUMatmulShapeType &intrinsic : intrinsics) {
+
+ if (!(inpElemType == intrinsic.aType &&
+ kernelElemType == intrinsic.bType)) {
+ continue;
+ }
+
std::optional<int64_t> mPadding, nPadding, kPadding;
auto getPadding = [](int64_t value, int64_t padTo) {
return llvm::divideCeil(value, padTo) * padTo - value;
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index c908a8a..0556e75 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2291,6 +2291,34 @@
iree_generated_e2e_runner_test(
NAME
+ e2e_matmul_rocm_f32_large_cdna3_mfma
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--acc_type=f32"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistributeMFMA"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
e2e_matmul_rocm_f16_large_cdna3_mfma_tb
TEST_TYPE
matmul
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 8b1fbf6..56d590b 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -261,6 +261,11 @@
schedules = []
if intrinsic == "MFMA":
schedules = [
+ MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
+ MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2),
+ MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
+ MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
+ MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
@@ -304,10 +309,17 @@
for schedule in schedules:
# Skip schedules with an intrinsic which element type does not
# match the requested one.
- if lhs_rhs_type.value.upper() not in schedule.intrinsic:
+ # Extracts the input type from strings containing either 'MFMA' or 'WMMA'
+ # followed by an underscore.
+ extract_input_type = lambda s: re.search(r"(?:MFMA|WMMA)_([^_]+)_", s).group(1)
+ if lhs_rhs_type.value.upper() != extract_input_type(schedule.intrinsic):
continue
- if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
+ if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+ wg_tile_k = schedule.k_tile_count * 4
+ elif schedule.intrinsic == "MFMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16