Revert "Revert "[LLVMGPU][ROCm] Add MFMA_F32_16x16x4_F32 instruction"… (#17921)

… (#17894)"

This reverts commit 02c2000795e157e4cf63fbac89d21a1ed886a7b0.
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 7859c0b..5973c05 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
 // GFX942: target = #iree_gpu.target<arch = "gfx942",
 // GFX942-SAME: wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8,
 // GFX942-SAME:         subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32,
-// GFX942-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX942-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
 // GFX942-SAME: chip = <wgp_count = 304>>
 
 // GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX940-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 
 // GFX1100: target = #iree_gpu.target<arch = "gfx1100",
 // GFX1100-SAME:        mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 2e4d771..86e7bc0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -211,6 +211,9 @@
   Type i32 = IntegerType::get(context, 32);
 
   switch (type) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
   }
@@ -255,6 +258,24 @@
       LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ);
   (void)laneZ, (void)vectorZ;
   switch (type) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]>
+    // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+    // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+    // #layout_c = #iree_vector_ext.layout<#inner, #outer>
+
+    auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+    auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1});
+    auto aMLayout = outer;
+    auto aKLayout = inner;
+    auto bKLayout = inner;
+    auto bNLayout = outer;
+    auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+    auto cNLayout = outer;
+    return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+                             bNLayout,     cMLayout, cNLayout};
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
     // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
@@ -409,6 +430,12 @@
   // amd_matrix_instruction_calculator tells us about the number of 32-bit
   // registers. So need to adjust accordingly. All vectors should be 1-D.
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    auto aType = VectorType::get({1}, getAType());
+    auto bType = VectorType::get({1}, getBType());
+    auto cType = VectorType::get({4}, getCType());
+    return std::make_tuple(aType, bType, cType);
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     auto aType = VectorType::get({4}, getAType());
     auto bType = VectorType::get({4}, getBType());
@@ -456,6 +483,7 @@
 
 int64_t MMAAttr::getBlockSize() const {
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
   case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
@@ -472,6 +500,7 @@
 
 int64_t MMAAttr::getSubgroupSize() const {
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
   case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
@@ -490,6 +519,10 @@
 
 MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
+            /*element=*/{1, 1}};
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
             /*element=*/{1, 4}};
@@ -518,6 +551,10 @@
 
 MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
+            /*element=*/{1, 1}};
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
             /*element=*/{4, 1}};
@@ -546,6 +583,7 @@
 
 MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
@@ -582,6 +620,17 @@
     return failure();
   }
   switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+    // Update the lhs and rhs to extract the first element since vector<1xT> is
+    // not supoorted by amgpu.mfma op.
+    lhs = builder.create<vector::ExtractOp>(loc, lhs, ArrayRef{int64_t{0}});
+    rhs = builder.create<vector::ExtractOp>(loc, rhs, ArrayRef{int64_t{0}});
+    auto [m, n, k] = getMNKShape();
+    return builder
+        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
+                                rhs, acc)
+        .getResult();
+  }
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
   case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 108deb9..55a83b3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -99,9 +99,10 @@
 }
 
 // Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
-def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
-def MFMA_F16_32x32x8_F32  : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
-def MFMA_F8E4M3FNUZ_16x16x32_F32  : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
+def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
+def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
+def MFMA_F16_32x32x8_F32  : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
+def MFMA_F8E4M3FNUZ_16x16x32_F32  : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
 def MFMA_I8_16x16x32_I32  : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
 def MFMA_I8_32x32x16_I32  : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
 // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
@@ -110,6 +111,7 @@
 
 def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
     "Descriptor for different MMA intrinsics", [
+      MFMA_F32_16x16x4_F32,
       MFMA_F16_16x16x16_F32,
       MFMA_F16_32x32x8_F32,
       MFMA_F8E4M3FNUZ_16x16x32_F32,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index c9e200f..2f3d254 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -122,6 +122,7 @@
 
 const WgpDetails *getCDNA3WgpDetails() {
   static const MMAIntrinsic cdna3MMAOps[] = {
+      MMAIntrinsic::MFMA_F32_16x16x4_F32,
       MMAIntrinsic::MFMA_F16_16x16x16_F32,
       MMAIntrinsic::MFMA_F16_32x32x8_F32,
       MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index 63ea310..721522a 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -214,10 +214,23 @@
   int64_t nSize = bounds[nDim];
   int64_t kSize = bounds[kDim];
 
+  auto inpElemType =
+      cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType())
+          .getElementType();
+  auto kernelElemType =
+      cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
+          .getElementType();
+
   // TODO: Generalize to other dimensions.
   // Try to search for pad value and check only filter dimension is blocked.
   SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
   for (const GPUMatmulShapeType &intrinsic : intrinsics) {
+
+    if (!(inpElemType == intrinsic.aType &&
+          kernelElemType == intrinsic.bType)) {
+      continue;
+    }
+
     std::optional<int64_t> mPadding, nPadding, kPadding;
     auto getPadding = [](int64_t value, int64_t padTo) {
       return llvm::divideCeil(value, padTo) * padTo - value;
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index c908a8a..0556e75 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2291,6 +2291,34 @@
 
 iree_generated_e2e_runner_test(
   NAME
+    e2e_matmul_rocm_f32_large_cdna3_mfma
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
     e2e_matmul_rocm_f16_large_cdna3_mfma_tb
   TEST_TYPE
     matmul
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 8b1fbf6..56d590b 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -261,6 +261,11 @@
     schedules = []
     if intrinsic == "MFMA":
         schedules = [
+            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2),
+            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
+            MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
+            MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
             MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
             MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
@@ -304,10 +309,17 @@
     for schedule in schedules:
         # Skip schedules with an intrinsic which element type does not
         # match the requested one.
-        if lhs_rhs_type.value.upper() not in schedule.intrinsic:
+        # Extracts the input type from strings containing either 'MFMA' or 'WMMA'
+        # followed by an underscore.
+        extract_input_type = lambda s: re.search(r"(?:MFMA|WMMA)_([^_]+)_", s).group(1)
+        if lhs_rhs_type.value.upper() != extract_input_type(schedule.intrinsic):
             continue
 
-        if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
+        if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 4
+        elif schedule.intrinsic == "MFMA_F16_16x16x16_F32":
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
             wg_tile_k = schedule.k_tile_count * 16