[VectorDistribution] Reuse intrinsic layout in chained gemm (#18505)

This patch teaches the attention codegen pipeline to reuse the intrinsic
layout of the first matmul's output as the lhs layout of the second matmul.
This is possible for the 16x16x16 and 32x32x8 MFMA intrinsic layouts.
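
Whether the first matmul's accumulator can be fed straight into the second
matmul is decided by comparing the intrinsic's output (C) single-subgroup
layout against its lhs (A) and rhs (B) layouts. The following is a minimal,
self-contained sketch of that decision; SubgroupLayout is an assumed
stand-in for the actual MMASingleSubgroupLayout, not the in-tree API:

    #include <cstdint>
    #include <vector>

    // Assumed stand-in for IREE::GPU::MMASingleSubgroupLayout.
    struct SubgroupLayout {
      std::vector<int64_t> element;  // per-thread element tile sizes
      std::vector<int64_t> thread;   // thread grid within one subgroup
      std::vector<int64_t> tstrides; // thread strides
    };

    static bool matchLayout(const SubgroupLayout &a, const SubgroupLayout &b) {
      return a.element == b.element && a.thread == b.thread &&
             a.tstrides == b.tstrides;
    }

    // Given the intrinsic's A (lhs), B (rhs) and C (output) layouts:
    //  * C matches A -> reuse the QK accumulator directly as the PV lhs;
    //                   no shared-memory round trip for S is needed.
    //  * C matches B -> swap operands (and indexing maps) of both matmuls so
    //                   the accumulator comes out in the layout PV expects.
    //  * otherwise   -> fall back to the shared-memory conversion.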
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 55e1b24..296b316 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -104,6 +104,7 @@
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
+        "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 4ded89f..e078969 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -135,6 +135,7 @@
     iree::compiler::Codegen::Utils::VectorOpUtils
     iree::compiler::Dialect::Encoding::IR
     iree::compiler::Dialect::HAL::IR
+    iree::compiler::Utils
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index ad0dbc2..58ed23a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -1020,10 +1020,16 @@
                                 PatternRewriter &rewriter) const override {
     auto input = cast<VectorValue>(toLayoutOp.getInput());
     auto output = cast<VectorValue>(toLayoutOp.getOutput());
-    VectorLayoutInterface currentLayout =
-        dyn_cast<LayoutAttr>(signature[input]);
-    VectorLayoutInterface targetLayout =
-        dyn_cast<LayoutAttr>(signature[output]);
+    VectorLayoutInterface currentLayout = signature[input];
+    VectorLayoutInterface targetLayout = signature[output];
+
+    if (!currentLayout) {
+      return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on input");
+    }
+
+    if (!targetLayout) {
+      return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on output");
+    }
 
     if (currentLayout != targetLayout) {
       return rewriter.notifyMatchFailure(toLayoutOp,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 260f7c2..ec86bf2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Utils/Permutation.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -601,6 +602,93 @@
   }
 };
 
+struct DistributeBatchOuterToLayoutConversions final
+    : OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    Location loc = toLayoutOp.getLoc();
+    auto input = cast<VectorValue>(toLayoutOp.getInput());
+    auto output = cast<VectorValue>(toLayoutOp.getOutput());
+    auto layoutA = dyn_cast<NestedLayoutAttr>(signature[input]);
+    auto layoutB = dyn_cast<NestedLayoutAttr>(signature[output]);
+
+    if (!layoutA || !layoutB) {
+      return rewriter.notifyMatchFailure(toLayoutOp, "non-nested layout");
+    }
+
+    // Check if everything other than batch and outer tile matches.
+    if (layoutA.getSubgroupTile() != layoutB.getSubgroupTile()) {
+      return failure();
+    }
+    if (layoutA.getSubgroupStrides() != layoutB.getSubgroupStrides()) {
+      return failure();
+    }
+    if (layoutA.getThreadTile() != layoutB.getThreadTile()) {
+      return failure();
+    }
+    if (layoutA.getThreadStrides() != layoutB.getThreadStrides()) {
+      return failure();
+    }
+    if (layoutA.getElementTile() != layoutB.getElementTile()) {
+      return failure();
+    }
+
+    auto batchTileA = SmallVector<int64_t>(layoutA.getBatchTile());
+    auto outerTileA = SmallVector<int64_t>(layoutA.getOuterTile());
+    auto batchTileB = SmallVector<int64_t>(layoutB.getBatchTile());
+    auto outerTileB = SmallVector<int64_t>(layoutB.getOuterTile());
+
+    // Check if there is a batch/outer tile mismatch.
+    if (batchTileA == batchTileB && outerTileA == outerTileB) {
+      return rewriter.notifyMatchFailure(toLayoutOp,
+                                         "trivial layout conversion");
+    }
+
+    SmallVector<int64_t> shapeA = layoutA.getDistributedShape();
+    SmallVector<int64_t> shapeB = layoutB.getDistributedShape();
+    int64_t rank = layoutA.getRank();
+
+    // Interleave batch and outer dims by transposing.
+
+    // Build a permutation for interleaving.
+    auto interleavePermutation =
+        llvm::to_vector(llvm::seq<int64_t>(shapeA.size()));
+    for (int i = 0; i < rank; ++i) {
+      // Batch tile dims : [0, rank)
+      // Outer tile dims : [rank, 2*rank)
+      // Interleave : [batch0, outer0, batch1, outer1,...]
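+      // e.g. for rank = 2, the permutation is [0, 2, 1, 3, 4, 5], taking
+      // [batch0, batch1, outer0, outer1, elem0, elem1] to
+      // [batch0, outer0, batch1, outer1, elem0, elem1].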
+      interleavePermutation[2 * i] = i;
+      interleavePermutation[2 * i + 1] = i + rank;
+    }
+
+    auto interleaved = rewriter.create<vector::TransposeOp>(
+        loc, getDistributed(rewriter, input, layoutA), interleavePermutation);
+
+    // Shape cast to match the new layout.
+
+    SmallVector<int64_t> transposedShapeB(shapeB);
+    applyPermutationToVector(transposedShapeB, interleavePermutation);
+    Type reshapedType = VectorType::get(
+        transposedShapeB, interleaved.getResultVectorType().getElementType());
+
+    auto reshaped =
+        rewriter.create<vector::ShapeCastOp>(loc, reshapedType, interleaved);
+
+    // Inverse transpose to preserve original order.
+    SmallVector<int64_t> invertedPermutation =
+        invertPermutationVector(interleavePermutation);
+
+    auto layouted = rewriter.create<vector::TransposeOp>(loc, reshaped,
+                                                         invertedPermutation);
+
+    replaceOpWithDistributedValues(rewriter, toLayoutOp, layouted.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
@@ -612,6 +700,7 @@
   patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
   patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
                                          maxBitsPerShuffle);
+  patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
 }
 
 }; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index bdc31eb..3f84454 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -23,9 +23,11 @@
 
 namespace {
 
-LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                   RewriterBase &rewriter,
-                                   linalg::LinalgOp contract) {
+static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          RewriterBase &rewriter,
+                                          linalg::LinalgOp contract,
+                                          bool promoteLhs = true,
+                                          bool promoteRhs = true) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
     return contract->emitError("missing mma schedule for contraction");
@@ -65,8 +67,13 @@
   // TODO: We should read this from the lowering_config on the operation.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
   // should decide where to put these shared memory conversions.
-  layoutedLhs.setSharedMemoryConversion(true);
-  layoutedRhs.setSharedMemoryConversion(true);
+  if (promoteLhs) {
+    layoutedLhs.setSharedMemoryConversion(true);
+  }
+
+  if (promoteRhs) {
+    layoutedRhs.setSharedMemoryConversion(true);
+  }
 
   contract->setOperand(0, layoutedLhs.getResult());
   contract->setOperand(1, layoutedRhs.getResult());
@@ -82,9 +89,9 @@
   return success();
 }
 
-LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                   RewriterBase &rewriter,
-                                   linalg::LinalgOp conv) {
+static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          RewriterBase &rewriter,
+                                          linalg::LinalgOp conv) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
     return conv->emitError("missing mma schedule for convolution");
@@ -160,35 +167,164 @@
   return success();
 }
 
-LogicalResult setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                       RewriterBase &rewriter,
-                                       linalg::LinalgOp contract) {
+/// Let's assume we have a matmul intrinsic (@) computing a matmul
+/// ((M, K) X (K, N)) which produces a particular layout:
+///
+/// C = A @ B
+///
+/// If we transpose and swap the operands, we can keep the same matmul
+/// intrinsic, but the layout of its output comes out transposed:
+///
+/// A.T = transpose(A)
+/// B.T = transpose(B)
+/// C.T = B.T @ A.T
+/// C = transpose(C.T)
+///
+/// This is useful when the "@" instruction that the hardware lowers to
+/// has a specific thread layout, but the further uses of C expect a layout
+/// transposed relative to the produced one.
+///
+/// For example, for "@" lowering to AMDGPU MFMA instructions, the operands
+/// have layouts L and L.T and the result has layout L.T.
+/// So if you have a chain of matmuls:
+///
+/// C (L.T) = A (L) @ B (L.T)
+/// E (L.T) = C (L.T)  @ D (L.T)
+///            ^^^^^^^
+///            Expected layout by instruction is L
+///
+/// To fix this, we can apply this transformation to the first matmul:
+///
+/// C.T (L.T) = B.T (L) @ A (L.T)
+/// C   (L)   = transpose C.T (L.T)
+/// E   (L.T) = C (L)  @ D (L.T)
+///            ^^^^^
+///            Layout matches the instruction!
+///
+/// Note that the mathematical formula
+///   C = A @ B --> C.T = B.T @ A.T
+/// holds only for the standard "@" function; it may be a different
+/// transformation for other indexing maps.
+///
+/// For linalg ops, since the indexing maps are part of the op definition,
+/// we can achieve the same transformation by simply swapping the operands
+/// (together with their indexing maps).
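+///
+/// For a plain matmul generic, for instance, with indexing maps
+/// [(m, k), (k, n), (m, n)], swapping the two inputs together with their
+/// maps gives [(k, n), (m, k), (m, n)]: the op computes the same values,
+/// but B now sits in the lhs operand slot and A in the rhs slot, which is
+/// the B.T @ A.T form described above.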
+static void swapOperandsToTransposeIntrinsic(RewriterBase &rewriter,
+                                             linalg::GenericOp contractOp) {
+  Value lhs = contractOp->getOperand(0);
+  Value rhs = contractOp->getOperand(1);
+
+  SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+  std::swap(indexingMaps[0], indexingMaps[1]);
+
+  contractOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(indexingMaps));
+  contractOp->setOperand(0, rhs);
+  contractOp->setOperand(1, lhs);
+}
+
+static IREE::GPU::MMAScheduleAttr
+transposeSchedule(RewriterBase &rewriter, IREE::GPU::MMAScheduleAttr schedule) {
+  return rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+      schedule.getIntrinsic(), schedule.getSubgroupNCount(),
+      schedule.getSubgroupMCount());
+}
+
+static LogicalResult
+setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                         RewriterBase &rewriter, linalg::LinalgOp qkMatmul,
+                         linalg::LinalgOp pvMatmul) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
-    return contract->emitError("missing mma schedule for contraction");
+    return pvMatmul->emitError("missing mma schedule for contraction");
   }
 
-  if (contract->hasAttr("attention_qk_matmul")) {
-    // subgroup_n count for attention matmul is always 1, because it is the
-    // reduction dimension. The subgroup_n count is in reality, for the second
-    // matmul.
-    IREE::GPU::MMAScheduleAttr qkSchedule =
-        rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
-            schedule.getIntrinsic(),
-            /*subgroup_m_count=*/schedule.getSubgroupMCount(),
-            /*subgroup_n_count=*/1);
-    return setContractionAnchor(qkSchedule, rewriter, contract);
+  // Check if the intrinsic output of qkMatmul can be reused by pvMatmul.
+  // We know that pvMatmul takes the result of qkMatmul as its lhs, so reuse
+  // is possible when the intrinsic's output layout matches its lhs layout.
+  // If the output layout instead matches the rhs layout, we swap the
+  // operands of both contracts to get the output in the transposed
+  // intrinsic layout.
+  bool reuseIntrinsicOutput = false;
+  bool transposeIntrinsic = false;
+
+  auto intrinsic = cast<IREE::GPU::MMAAttr>(schedule.getIntrinsic());
+  IREE::GPU::MMASingleSubgroupLayout lhsLayout =
+      intrinsic.getASingleSubgroupLayout();
+  IREE::GPU::MMASingleSubgroupLayout rhsLayout =
+      intrinsic.getBSingleSubgroupLayout();
+  IREE::GPU::MMASingleSubgroupLayout outLayout =
+      intrinsic.getCSingleSubgroupLayout();
+
+  auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
+                        IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
+    return (layoutA.element == layoutB.element) &&
+           (layoutA.thread == layoutB.thread) &&
+           (layoutA.tstrides == layoutB.tstrides);
+  };
+
+  // TODO: Move this check to KernelConfig and set appropriate attributes
+  // in lowering_config for the operation. This allows us to check shared
+  // memory usage and decide what kind of pipelining we can do.
+  if (matchLayout(outLayout, lhsLayout)) {
+    reuseIntrinsicOutput = true;
+  } else if (matchLayout(outLayout, rhsLayout)) {
+    reuseIntrinsicOutput = true;
+    transposeIntrinsic = true;
   }
 
-  if (contract->hasAttr("attention_pv_matmul")) {
-    // subgroup_n count for attention matmul is always 1, because it is the
-    // reduction dimension. The subgroup_n count is in reality, for the second
-    // matmul.
-    return setContractionAnchor(schedule, rewriter, contract);
+  // The subgroup_n count for the QK matmul is always 1, because its N
+  // dimension is a reduction dimension of attention. The schedule's
+  // subgroup_n count really applies to pvMatmul.
+  IREE::GPU::MMAScheduleAttr qkSchedule =
+      rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+          schedule.getIntrinsic(),
+          /*subgroup_m_count=*/schedule.getSubgroupMCount(),
+          /*subgroup_n_count=*/1);
+  IREE::GPU::MMAScheduleAttr pvSchedule = schedule;
+
+  // Transpose the intrinsic if requested. See docs for
+  // swapOperandsToTransposeIntrinsic for more information on why this is done.
+  if (transposeIntrinsic) {
+    auto qkGeneric = dyn_cast<linalg::GenericOp>(qkMatmul.getOperation());
+    auto pvGeneric = dyn_cast<linalg::GenericOp>(pvMatmul.getOperation());
+    if (!qkGeneric || !pvGeneric) {
+      pvMatmul->emitOpError("transposing the intrinsic for non-generic "
+                            "qkMatmul/pvMatmul is not yet implemented");
+      return failure();
+    }
+    swapOperandsToTransposeIntrinsic(rewriter, qkGeneric);
+    swapOperandsToTransposeIntrinsic(rewriter, pvGeneric);
+    qkSchedule = transposeSchedule(rewriter, qkSchedule);
+    pvSchedule = transposeSchedule(rewriter, pvSchedule);
   }
 
-  return contract->emitError("attention matmul should have either "
-                             "attention_qk_matmul or attention_pv_matmul set");
+  if (failed(setContractionAnchor(qkSchedule, rewriter, qkMatmul))) {
+    return failure();
+  }
+
+  // Do not promote lhs of pvMatmul if we are reusing the intrinsic output.
+  bool promoteLhs = !reuseIntrinsicOutput;
+  bool promoteRhs = true;
+  if (transposeIntrinsic) {
+    std::swap(promoteLhs, promoteRhs);
+  }
+
+  return setContractionAnchor(pvSchedule, rewriter, pvMatmul, promoteLhs,
+                              promoteRhs);
+}
+
+static Operation *getOpWithAttr(Operation *root, StringRef attr) {
+  Operation *result = nullptr;
+  WalkResult walkResult = root->walk([&](Operation *op) {
+    if (op->hasAttr(attr)) {
+      if (result) {
+        return WalkResult::interrupt();
+      }
+      result = op;
+    }
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted()) {
+    return nullptr;
+  }
+  return result;
 }
 
 struct LLVMGPUConfigureTensorLayoutsPass final
@@ -212,19 +348,33 @@
     // now, layout setting for other problems like reductions is TODO.
     SmallVector<linalg::LinalgOp> contracts;
     SmallVector<linalg::LinalgOp> convs;
-    SmallVector<linalg::LinalgOp> attentionMatmuls;
+
+    auto attentionQKMatmul = dyn_cast_or_null<linalg::LinalgOp>(
+        getOpWithAttr(func, "attention_qk_matmul"));
+    auto attentionPVMatmul = dyn_cast_or_null<linalg::LinalgOp>(
+        getOpWithAttr(func, "attention_pv_matmul"));
+
+    if (attentionQKMatmul && !attentionPVMatmul) {
+      func->emitError("Expected attention attributes to be set properly");
+      return signalPassFailure();
+    }
+
+    if (!attentionQKMatmul && attentionPVMatmul) {
+      func->emitError("Expected attention attributes to be set properly");
+      return signalPassFailure();
+    }
 
     func->walk([&](linalg::LinalgOp linalgOp) {
+      if (linalgOp == attentionQKMatmul || linalgOp == attentionPVMatmul) {
+        return WalkResult::advance();
+      }
+
       if (linalg::isaContractionOpInterface(linalgOp)) {
-        if (linalgOp->hasAttr("attention_qk_matmul") ||
-            linalgOp->hasAttr("attention_pv_matmul")) {
-          attentionMatmuls.push_back(linalgOp);
-        } else {
-          contracts.push_back(linalgOp);
-        }
+        contracts.push_back(linalgOp);
       } else if (succeeded(linalg::inferConvolutionDims(linalgOp))) {
         convs.push_back(linalgOp);
       }
+      return WalkResult::advance();
     });
 
     IRRewriter rewriter(func);
@@ -241,9 +391,9 @@
       }
     }
 
-    for (linalg::LinalgOp attentionMatmul : attentionMatmuls) {
-      if (failed(setAttentionMatmulAnchor(scheduleAttr, rewriter,
-                                          attentionMatmul))) {
+    if (attentionQKMatmul && attentionPVMatmul) {
+      if (failed(setAttentionMatmulAnchor(
+              scheduleAttr, rewriter, attentionQKMatmul, attentionPVMatmul))) {
         return signalPassFailure();
       }
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index faec538..cbe19e5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -868,7 +868,6 @@
 
   // Preprocessing for vector distribution.
   funcPassManager.addPass(createLLVMGPUCastTypeToFitMMAPass());
-  funcPassManager.addPass(createAMDGPUPrepareForChainedMatmulPass());
 
   // Vector SIMD -> Vector SIMT
   funcPassManager.addPass(createLLVMGPUConfigureVectorLayoutsPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 86e7f0b..eb8f4f1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -3,6 +3,11 @@
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s
 
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s --check-prefix=MEMORY
+
 #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
@@ -591,10 +596,16 @@
 // CHECK: transfer_read
 
 // CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf32>)
+// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x4x1x1x1x4xf32>)
 // CHECK-COUNT-48:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
 
+// Check that we only use alloc for Q, K, and V. No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_20x4096x64x4096x64()
+// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc
+
 // -----
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0, 64, 64]]>
@@ -640,6 +651,67 @@
 
 // CHECK-LABEL: func.func @attention_multiple_m_transpose()
 // CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
+// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x8x1x1x1x4xf32>)
 // CHECK-COUNT-96:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
+
+// Check that we only use alloc for Q, K, and V. No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_multiple_m_transpose()
+// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128, 0, 32, 64]]>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable private @attention_mfma_32x32x8 {
+  hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @attention_mfma_32x32x8 ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @attention_mfma_32x32x8() attributes {translation_info = #translation} {
+        %cst = arith.constant 1.0 : f16
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x64x4608x128xf16>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
+        %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x4608x24x128xf16>>
+        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [24, 64, 4608, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x64x4608x128xf16>> -> tensor<24x64x4608x128xf16>
+        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
+        %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
+        %7 = tensor.empty() : tensor<64x4608x24x128xf16>
+        %8 = tensor.empty() : tensor<24x64x4608x128xf16>
+        %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
+        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
+        ^bb0(%in: f16, %out: f16):
+          linalg.yield %in : f16
+        } -> tensor<64x4608x24x128xf16>
+        flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0, 0], sizes = [64, 4608, 24, 128], strides = [1, 1, 1, 1] : tensor<64x4608x24x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x4608x24x128xf16>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @attention_mfma_32x32x8()
+// CHECK: scf.for %{{.*}} = %c0 to %c144 step %c1
+// CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x4x1x4x1x4xf32>)
+// CHECK-COUNT-32:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: scf.yield
+
+// Check that we only use alloc for Q, K, and V. No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_mfma_32x32x8()
+// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc