[Codegen] PV and QK matmuls must have same acc layout (#21729)

Fixes issue #21602 where vector distribute failed due to an unresolvable
layout change in attention.

Check that the two matmuls have the same accumulator layout.

With this change, the reproducer in #21602 compiles down to a .vmfb. I
have not checked numerics or looked at any performance benchmarks.

---------

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index e8d9eb8..d9fe58c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@
 
 #include <cstdint>
 
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/Support/DebugLog.h"
@@ -666,6 +667,13 @@
   bool canReuseAOutputForB;
 };
 
+static bool matchLayout(IREE::GPU::MMASingleSubgroupLayout layoutA,
+                        IREE::GPU::MMASingleSubgroupLayout layoutB) {
+  return (layoutA.element == layoutB.element) &&
+         (layoutA.thread == layoutB.thread) &&
+         (layoutA.tstrides == layoutB.tstrides);
+}
+
 FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
     const GPUMatmulShapeType &qkMatmul, const GPUMatmulShapeType &pvMatmul,
     ArrayRef<GPUIntrinsicType> intrinsics,
@@ -677,28 +685,33 @@
          qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 &&
          "unimplemented: multi M/N/K attention schedule");
 
+  SmallVector<uint64_t> qkViableIntrinsicIndices;
+  SmallVector<uint64_t> pvViableIntrinsicIndices;
+  for (const auto &[index, intrinsic] : llvm::enumerate(intrinsics)) {
+    if (!failed(canTargetIntrinsic(qkMatmul, intrinsic, subgroupSize,
+                                   canUpcastAcc, mustBeAligned))) {
+      qkViableIntrinsicIndices.push_back(index);
+    }
+    if (!failed(canTargetIntrinsic(pvMatmul, intrinsic, subgroupSize,
+                                   canUpcastAcc, mustBeAligned))) {
+      pvViableIntrinsicIndices.push_back(index);
+    }
+  }
+
   std::vector<ChainedMMAIntrinsics> intrinsicPairs;
-
-  for (const GPUIntrinsicType &intrinsicA : intrinsics) {
-    for (const GPUIntrinsicType &intrinsicB : intrinsics) {
-      if (failed(canTargetIntrinsic(qkMatmul, intrinsicA, subgroupSize,
-                                    canUpcastAcc, mustBeAligned))) {
+  for (unsigned qkIndex : qkViableIntrinsicIndices) {
+    for (unsigned pvIndex : pvViableIntrinsicIndices) {
+      const GPUIntrinsicType &intrinsicA = intrinsics[qkIndex];
+      const GPUIntrinsicType &intrinsicB = intrinsics[pvIndex];
+      if (!matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
+                                               IREE::GPU::MMAFragment::Acc),
+                       getSingleSubgroupLayout(intrinsicB.mmaKind,
+                                               IREE::GPU::MMAFragment::Acc))) {
         continue;
       }
 
-      if (failed(canTargetIntrinsic(pvMatmul, intrinsicB, subgroupSize,
-                                    canUpcastAcc, mustBeAligned))) {
-        continue;
-      }
       // Check if we can reuse the output of intrinsicA for lhs/rhs of
       // intrinsicB.
-      auto matchLayout =
-          [](IREE::GPU::MMASingleSubgroupLayout layoutA,
-             IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
-        return (layoutA.element == layoutB.element) &&
-               (layoutA.thread == layoutB.thread) &&
-               (layoutA.tstrides == layoutB.tstrides);
-      };
       bool canReuseAOutForBLhs =
           matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
                                               IREE::GPU::MMAFragment::Acc),
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
index 49d27b7..4fe1b85 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
@@ -17,8 +17,6 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
 
 namespace mlir::iree_compiler::IREE::GPU {
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir
index 95038b9..57ab1be 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir
@@ -282,3 +282,33 @@
 // CHECK-SAME: subgroup_n_count = 1
 // CHECK-SAME: reduction =  [0, 0, 64, 0]
 // CHECK-SAME: workgroup =  [64, 0, 0, 64]
+
+// -----
+
+// The fix introduced for bug https://github.com/iree-org/iree/issues/21602 was
+// to constrain the MMA layout to be the same for the 2 matmuls inside
+// attention. Before this fix, the PV matmul used MFMA_F32_16x16x128_F8E4M3FN
+// and the QK matmul used MFMA_F32_32x32x64_F8E4M3FN. Vector distribution failed
+// to distribute these layouts to threads.
+
+//       CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {}>
+// CHECK-LABEL: func.func @attention_check_mma_accs_compatible
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+func.func @attention_check_mma_accs_compatible(%arg0: f32, %arg1: tensor<960x4096x64xf8E4M3FN>, %arg2: tensor<960x4096x64xf8E4M3FN>, %arg3: tensor<960x4096x64xf8E4M3FN>, %arg4: tensor<960x4096x64xf32>, %arg5: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<960x4096x64xf32>>) {
+  %0 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg1, %arg2, %arg3, %arg0 : tensor<960x4096x64xf8E4M3FN>, tensor<960x4096x64xf8E4M3FN>, tensor<960x4096x64xf8E4M3FN>, f32) outs(%arg4 : tensor<960x4096x64xf32>) {
+  ^bb0(%arg6: f32):
+    iree_linalg_ext.yield %arg6 : f32
+  } -> tensor<960x4096x64xf32>
+  iree_tensor_ext.dispatch.tensor.store %0, %arg5, offsets = [0, 0, 0], sizes = [960, 4096, 64], strides = [1, 1, 1] : tensor<960x4096x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<960x4096x64xf32>>
+  return
+}
+//      CHECK: decomposition_config =
+// CHECK-SAME: attention_pv_matmul
+// CHECK-SAME:   #iree_gpu.mma_layout<MFMA_F32_32x32x64_F8E4M3FN>
+// CHECK-SAME: attention_qk_matmul
+// CHECK-SAME:   #iree_gpu.mma_layout<MFMA_F32_32x32x64_F8E4M3FN>