GPU data tiling: Refine tile dimensions, more preparation for thread distribution. (#18556)

* The unrolling factors in `DataTiledMMAAttr` get split between plain
unrolling and unroll-to-subgroups (see the attribute sketch after this
list).
* The dimensions in `TileSwizzle` get an enum indicating whether they are
internal, cross-thread, or cross-intrinsic.
* `getSwizzle` gets moved to GPUTileSwizzleUtils, since it will be used in
codegen outside of MaterializeEncoding.
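
For illustration, here is how the split shows up in the updated
`gpu_materialize_encoding.mlir` test for the MFMA_F32_16x16x4_F32 case (a
before/after sketch of the attribute, no functionality beyond this patch).
Plain unrolling factors become `CrossIntrinsic` dimensions in the resulting
`TileSwizzle`, while `*_to_subgroups` factors become `CrossThread` dimensions:

```mlir
// Before this change: all unrolling happens on one thread.
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32,
                                       unroll_m = 8, unroll_n = 8, unroll_k = 4>

// After this change: unroll_n = 2 stays on the same thread (CrossIntrinsic),
// while unroll_n_to_subgroups = 4 is distributed across subgroups (CrossThread).
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32,
                                       unroll_m = 8, unroll_n = 2,
                                       unroll_n_to_subgroups = 4, unroll_k = 4>
```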

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index aaa4ad3..0cc1f28 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -143,6 +143,7 @@
         "TileDispatchUsingForall.cpp",
         "TileDispatchUsingInterface.cpp",
         "TileSizeSelection.cpp",
+        "TileSwizzle.cpp",
         "TypePropagationPass.cpp",
         "UserConfig.cpp",
         "VectorizeMemrefCopy.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index d828e08..3a94dcd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -135,6 +135,7 @@
     "TileDispatchUsingForall.cpp"
     "TileDispatchUsingInterface.cpp"
     "TileSizeSelection.cpp"
+    "TileSwizzle.cpp"
     "TypePropagationPass.cpp"
     "UserConfig.cpp"
     "VectorizeMemrefCopy.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 9b8468b..96d7594 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -210,10 +210,12 @@
 }
 
 SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape) {
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
   SmallVector<int64_t> result;
-  for (auto expandShapeDim : expandShape) {
-    result.append(expandShapeDim);
+  for (auto e : expandShape) {
+    for (auto d : e) {
+      result.push_back(d.size);
+    }
   }
   return result;
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index b7d75c9..2905028 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -143,7 +143,7 @@
 
 /// Concatenates the vectors.
 SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape);
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);
 
 } // namespace mlir::iree_compiler
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 4ffa632..1174e81 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -37,84 +37,6 @@
 #define GEN_PASS_DEF_GPUMATERIALIZEDEVICEENCODINGPASS
 #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
 
-/// Returns the index of the dimension whose flattened size (flattening inner
-/// dimensions into it) matches the given `targetSize`. This is used to compute
-/// interleaving indices.
-///
-/// Example:
-///    Input shape = [16, 8, 4, 4]
-///    Input targetSize = 16
-/// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
-///    whose product equals targetSize.
-static int64_t getDimIdxForTargetSize(ArrayRef<int64_t> shape,
-                                      int64_t targetSize) {
-  int interleaveAt = 0;
-  int size = 1;
-  for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
-    assert(size <= targetSize);
-    assert((targetSize % size) == 0);
-    if (size == targetSize) {
-      break;
-    }
-    size *= shape[interleaveAt];
-  }
-  return interleaveAt;
-}
-
-/// Generates the swizzle for the full data-tiled-mma tile, including all the
-/// relevant unrolling factors.
-static TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
-                              IREE::GPU::MMAFragment fragment) {
-  auto [AType, BType, CType] = mma.getABCElementTypes();
-  int ABits = AType.getIntOrFloatBitWidth();
-  int BBits = BType.getIntOrFloatBitWidth();
-  // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
-  const int targetPreferredLoadBitWidth = 128;
-  auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
-  switch (fragment) {
-  case IREE::GPU::MMAFragment::Lhs:
-    // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
-    // Unroll on K with interleaving, then on M.
-    if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK());
-      int interleavingIdx = getDimIdxForTargetSize(
-          swizzle.expandShape[1],
-          targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
-      interleave(swizzle, 1, interleavingIdx);
-    }
-    if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM());
-    }
-    break;
-  case IREE::GPU::MMAFragment::Rhs:
-    // B-matrix (RHS). Since the pack ops already took care of transposing B,
-    // source dimensions are N (index 0) and K (index 1).
-    // Unroll on K with interleaving, then on N.
-    if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK());
-      int interleavingIdx = getDimIdxForTargetSize(
-          swizzle.expandShape[1],
-          targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
-      interleave(swizzle, 1, interleavingIdx);
-    }
-    if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 0, mma.getUnrollN());
-    }
-    break;
-  case IREE::GPU::MMAFragment::Acc:
-    // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
-    // 1). Unroll on N, then on M.
-    if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 1, mma.getUnrollN());
-    }
-    if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM());
-    }
-    break;
-  }
-  return swizzle;
-}
-
 static bool hasIntrinsic(IREE::GPU::TargetAttr target,
                          IREE::GPU::MMAIntrinsic intrinsic) {
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -133,13 +55,16 @@
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
-  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollN,
+  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToThreads,
+                   int unrollN, int unrollNToThreads,
                    int unrollK) -> std::optional<DataTiledMMAAttr> {
     if (!hasIntrinsic(target, intrinsic)) {
       return std::nullopt;
     }
     auto candidate = DataTiledMMAAttr::get(
-        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), unrollM, unrollN, unrollK);
+        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), /*unroll_m=*/unrollM,
+        /*unroll_m_to_subgroups=*/unrollMToThreads, /*unroll_n=*/unrollN,
+        /*unroll_n_to_subgroups=*/unrollNToThreads, /*unroll_k=*/unrollK);
     auto [candidateLhs, candidateRhs, candidateOut] =
         candidate.getABCElementTypes();
     if (candidateLhs != lhs || candidateRhs != rhs || candidateOut != out) {
@@ -147,13 +72,13 @@
     }
     return candidate;
   };
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 8, 4)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 1, 2, 4, 4)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 1, 2, 4, 2)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 1, 2, 4, 2)) {
     return m;
   }
   // Fallback - no architecture-optimized tile size for this case.
@@ -220,7 +145,7 @@
 
 SmallVector<ReassociationIndices>
 getReassociationIndices(int outerDims,
-                        SmallVector<SmallVector<int64_t>> expandShape) {
+                        const TileSwizzle::ExpandShapeType &expandShape) {
   SmallVector<ReassociationIndices> result;
   int expandedIdx = 0;
   for (int i = 0; i < outerDims; ++i) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
index b225e69..94335c4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
@@ -5,7 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 
 namespace mlir::iree_compiler {
 
@@ -13,7 +12,7 @@
 // dimensions to expanded dimensions, returns the index of the first expanded
 // dimension corresponding to the given source dimension index.
 static int64_t
-getExpandedDimFirstIdx(const SmallVector<SmallVector<int64_t>> &expandShape,
+getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
                        int64_t srcIndex) {
   int dstIndexFirst = 0;
   for (int i = 0; i < srcIndex; ++i) {
@@ -22,14 +21,17 @@
   return dstIndexFirst;
 }
 
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor) {
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+            TileSwizzle::Dim::Kind kind) {
   assert(unrollFactor > 1);
   int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
-
+  TileSwizzle::Dim unrollDim;
+  unrollDim.size = unrollFactor;
+  unrollDim.kind = kind;
   // The new unrolling dimension is inserted at the start of the expandShape
   // dimensions group corresponding to srcIndex.
   swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
-                                       unrollFactor);
+                                       unrollDim);
   // Since we are not interleaving here, generating side-by-side copies of the
   // original layout, the new unrolling dimension is the new outermost
   // dimension. Existing entries get shifted to make room for it.
@@ -97,7 +99,10 @@
   // shape expansion for now.
   TileSwizzle swizzle;
   for (auto t : layout.thread) {
-    swizzle.expandShape.push_back({t});
+    TileSwizzle::Dim dim;
+    dim.size = t;
+    dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
+    swizzle.expandShape.push_back({dim});
   }
   // The layout strides decide the initial swizzle.permutation.
   // Some WMMA intrinsics have tstrides=0 values, assert on that as that
@@ -112,9 +117,12 @@
   // Deal with any element size greater than 1 by inserting it innermost.
   // Notice that this is similar to the unroll() function, just creating an
   // inner dimension instead of an outer dimension.
-  for (int i = 0; i < layout.element.size(); ++i) {
-    if (layout.element[i] != 1) {
-      swizzle.expandShape[i].push_back(layout.element[i]);
+  for (auto [i, e] : llvm::enumerate(layout.element)) {
+    if (e != 1) {
+      TileSwizzle::Dim dim;
+      dim.size = e;
+      dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
+      swizzle.expandShape[i].push_back(dim);
       int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
       for (auto &p : swizzle.permutation) {
         p += (p >= newIndex);
@@ -125,13 +133,105 @@
   // Deal with any outer size greater than 1 as just a call to unroll.
   // Iterate over dims in reverse order because we are creating a new outermost
   // dimension each time.
-  for (int i = layout.outer.size() - 1; i >= 0; --i) {
-    if (layout.outer[i] != 1) {
-      unroll(swizzle, i, layout.outer[i]);
+  for (auto [i, o] : llvm::enumerate(layout.outer)) {
+    if (o != 1) {
+      // `layout.outer` means additional Internal dimensions, just like
+      // `layout.element`, but swizzled outermost.
+      unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
     }
   }
 
   return swizzle;
 }
 
+// Returns the index of the dimension whose flattened size (flattening inner
+// dimensions into it) matches the given `targetSize`. This is used to compute
+// interleaving indices.
+//
+// Example:
+//    Input shape = [16, 8, 4, 4]
+//    Input targetSize = 16
+// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
+//    whose product equals targetSize.
+static int64_t
+getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape,
+                       int64_t targetSize) {
+  int interleaveAt = 0;
+  int size = 1;
+  for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
+    assert(size <= targetSize);
+    assert((targetSize % size) == 0);
+    if (size == targetSize) {
+      break;
+    }
+    size *= shape[interleaveAt].size;
+  }
+  return interleaveAt;
+}
+
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+                       IREE::GPU::MMAFragment fragment) {
+  auto [AType, BType, CType] = mma.getABCElementTypes();
+  int ABits = AType.getIntOrFloatBitWidth();
+  int BBits = BType.getIntOrFloatBitWidth();
+  // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
+  const int targetPreferredLoadBitWidth = 128;
+  auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
+  using Kind = TileSwizzle::Dim::Kind;
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs:
+    // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
+    // Unroll on K with interleaving, then on M.
+    if (mma.getUnrollK() > 1) {
+      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      int interleavingIdx = getDimIdxForTargetSize(
+          swizzle.expandShape[1],
+          targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
+      interleave(swizzle, 1, interleavingIdx);
+    }
+    if (mma.getUnrollM() > 1) {
+      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollMToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Rhs:
+    // B-matrix (RHS). Since the pack ops already took care of transposing B,
+    // source dimensions are N (index 0) and K (index 1).
+    // Unroll on K with interleaving, then on N.
+    if (mma.getUnrollK() > 1) {
+      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      int interleavingIdx = getDimIdxForTargetSize(
+          swizzle.expandShape[1],
+          targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
+      interleave(swizzle, 1, interleavingIdx);
+    }
+    if (mma.getUnrollN() > 1) {
+      unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollNToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Acc:
+    // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
+    // 1). Unroll on N, then on M.
+    if (mma.getUnrollN() > 1) {
+      unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollNToSubgroups() > 1) {
+      unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+    }
+    if (mma.getUnrollM() > 1) {
+      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollMToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  }
+  return swizzle;
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
index fc5af79..fc79bf0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
 
 #include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 
 namespace mlir::iree_compiler {
@@ -17,17 +18,26 @@
 TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
                                 IREE::GPU::MMAFragment fragment);
 
+// Returns the swizzle for the full data-tiled-mma tile, including all the
+// relevant unrolling factors.
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+                       IREE::GPU::MMAFragment fragment);
+
 // Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
 // This is not interleaving layouts. The layout will consist of multiple copies
 // of the input tile, side by side.
 //
+// The enum parameter `kind` initializes the corresponding member on the newly
+// created TileSwizzle::Dim.
+//
 // Example:
 //    Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
 //    Input srcIndex = 1
 //    Input unrollFactor = 4
 // -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
 //
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor);
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+            TileSwizzle::Dim::Kind kind);
 
 // Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
 // move permutation[0], the outer-most dimension (which the unroll() function
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index ac28e84..8e5927b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -128,11 +128,11 @@
 // CHECK-SAME:      inner_tiles = [128, 16]
 // CHECK-SAME:      : tensor<255x513xf32> -> tensor<5x16x128x16xf32>
 // CHECK:         %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME       : tensor<5x16x128x16xf32> into tensor<5x16x8x16x4x4xf32>
+// CHECK-SAME       : tensor<5x16x128x16xf32> into tensor<5x16x4x2x16x4x4xf32>
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%[[EXPAND]] : tensor<5x16x8x16x4x4xf32>)
-// CHECK-SAME:       outs({{.*}} : tensor<5x16x8x4x16x4xf32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME:       ins(%[[EXPAND]] : tensor<5x16x4x2x16x4x4xf32>)
+// CHECK-SAME:       outs({{.*}} : tensor<5x16x4x2x4x16x4xf32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 3, 6, 4, 5]
 // CHECK:         flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -161,11 +161,11 @@
 // CHECK-SAME:      inner_tiles = [128, 128]
 // CHECK-SAME:      : tensor<255x513xf32> -> tensor<2x5x128x128xf32>
 // CHECK:         %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME       : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x8x16xf32>
+// CHECK-SAME       : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME:       ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK:         flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -189,11 +189,11 @@
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%{{.+}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME:       ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK:         %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME:      : tensor<2x5x8x4x4x8x16xf32> into tensor<2x5x128x128xf32>
+// CHECK-SAME:      : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
 // CHECK:         %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME:      outer_dims_perm = [0, 1]
 // CHECK-SAME:      inner_dims_pos = [0, 1]
@@ -232,11 +232,11 @@
 }
 // CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%{{.+}} : tensor<?x?x8x8x4x16x4xf32>)
-// CHECK-SAME:       outs({{.*}} : tensor<?x?x8x4x4x8x16xf32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME:       ins(%{{.+}} : tensor<?x?x8x4x2x4x16x4xf32>)
+// CHECK-SAME:       outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK:         %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME:      : tensor<?x?x8x4x4x8x16xf32> into tensor<?x?x128x128xf32>
+// CHECK-SAME:      : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
 // CHECK:         %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME:      outer_dims_perm = [0, 1]
 // CHECK-SAME:      inner_dims_pos = [0, 1]
@@ -295,12 +295,12 @@
 // CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG:   %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xf32>
+// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
+// CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
 // CHECK:       %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 8, unroll_k = 4>
+// CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
 // CHECK:       flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
 
 
@@ -365,11 +365,11 @@
 // CHECK-SAME:      inner_tiles = [128, 64]
 // CHECK-SAME:      : tensor<255x513xi8> -> tensor<5x4x128x64xi8>
 // CHECK:         %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME       : tensor<5x4x128x64xi8> into tensor<5x4x8x16x2x4x8xi8>
+// CHECK-SAME       : tensor<5x4x128x64xi8> into tensor<5x4x4x2x16x2x4x8xi8>
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%[[EXPAND]] : tensor<5x4x8x16x2x4x8xi8>)
-// CHECK-SAME:       outs({{.*}} : tensor<5x4x8x4x16x2x8xi8>)
-// CHECK-SAME:       permutation = [0, 1, 2, 5, 3, 4, 6]
+// CHECK-SAME:       ins(%[[EXPAND]] : tensor<5x4x4x2x16x2x4x8xi8>)
+// CHECK-SAME:       outs({{.*}} : tensor<5x4x4x2x4x16x2x8xi8>)
+// CHECK-SAME:       permutation = [0, 1, 2, 3, 6, 4, 5, 7]
 // CHECK:         flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -398,11 +398,11 @@
 // CHECK-SAME:      inner_tiles = [128, 128]
 // CHECK-SAME:      : tensor<255x513xi32> -> tensor<2x5x128x128xi32>
 // CHECK:         %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME       : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x8x16xi32>
+// CHECK-SAME       : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME:       ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK:         flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -426,11 +426,11 @@
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK:         %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME:       ins(%{{.+}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME:       ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME:       outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME:       permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK:         %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME:      : tensor<2x5x8x4x4x8x16xi32> into tensor<2x5x128x128xi32>
+// CHECK-SAME:      : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
 // CHECK:         %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME:      outer_dims_perm = [0, 1]
 // CHECK-SAME:      inner_dims_pos = [0, 1]
@@ -490,10 +490,10 @@
 // CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG:   %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xi32>
+// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
+// CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
 // CHECK:       %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 8, unroll_k = 2>
+// CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
 // CHECK:       flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
new file mode 100644
index 0000000..7ae46e6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
@@ -0,0 +1,46 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir::iree_compiler {
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              TileSwizzle::Dim::Kind kind) {
+  switch (kind) {
+  case TileSwizzle::Dim::Kind::Internal:
+    return os << "Internal";
+  case TileSwizzle::Dim::Kind::CrossThread:
+    return os << "CrossThread";
+  case TileSwizzle::Dim::Kind::CrossIntrinsic:
+    return os << "CrossIntrinsic";
+  }
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim) {
+  return os << dim.size << "(" << dim.kind << ")";
+}
+
+static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+           const TileSwizzle::ExpandShapeDimVectorType &expandShapeDimVector) {
+  os << "[";
+  llvm::interleaveComma(expandShapeDimVector, os);
+  return os << "]";
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TileSwizzle &swizzle) {
+  os << "{expandShape = [";
+  llvm::interleaveComma(swizzle.expandShape, os);
+  os << "], swizzle = [";
+  llvm::interleaveComma(swizzle.permutation, os);
+  os << "]}";
+  return os;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
index b908ae4..738bb6a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
@@ -9,6 +9,7 @@
 
 #include <cstdint>
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::iree_compiler {
 
@@ -16,13 +17,50 @@
 // pair of ops performing a change of layout within the tiles. This is used
 // on GPU, where the tiles themselves can have an arbitrary layout.
 struct TileSwizzle {
+  struct Dim {
+    // Describes what varies across this dimension.
+    enum class Kind : int8_t {
+      // This dimension is internal to one intrinsic on one thread. This
+      // is only seen for intrinsic operands that are themselves vectors.
+      // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+      // the C-matrix operand is a vector of 4 floats already at the level of
+      // one intrinsic on one thread. That dimension of size 4 is 'Internal'.
+      Internal,
+      // This dimension is internal to one intrinsic, but is across threads.
+      // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+      // the A-matrix tile has shape 16x4, and these two dimensions of size 16
+      // and 4 are 'CrossThread': neither is visible at the single-thread level
+      // (in the intrinsic itself, the A-matrix operand is a single scalar) but
+      // as we move along these dimensions, we are moving over the 64 threads
+      // of the subgroup.
+      //
+      // Another example of cross-thread dimensions is in kernels that are
+      // "unrolled" across subgroups. Such dimensions are cross-subgroup, so in
+      // particular they are cross-thread.
+      CrossThread,
+      // This dimension is across intrinsics, as in, actual instructions in the
+      // generated code. In other words, it is an actual unrolling factor,
+      // resulting in this many more instructions being generated and executed
+      // on each thread/subgroup.
+      CrossIntrinsic
+    };
+
+    Kind kind = Kind::Internal;
+
+    // The size of the dimension.
+    int16_t size = 0;
+  };
+
+  using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
+  using ExpandShapeType = llvm::SmallVector<ExpandShapeDimVectorType>;
+
   // This vector-of-vectors contains all the information needed to generate
   // a `tensor.expand_shape` creating additional internal dimensions into the
   // tile. For example, expandShape = [[16], [4, 2]] means that the original
   // tile shape [16, 8] gets expanded such that the first dimension 16 is left
   // unchanged, and the second dimension 8 gets split into two internal dims
   // of size 4 and 2.
-  llvm::SmallVector<llvm::SmallVector<int64_t>> expandShape;
+  ExpandShapeType expandShape;
   // This permutation vector applies to the expanded dimensions and is used
   // to generate a `linalg.transpose` changing the layout of the tile. For
   // example, permutation[0] dictates which of the expanded dimensions becomes
@@ -30,6 +68,14 @@
   llvm::SmallVector<int64_t> permutation;
 };
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              TileSwizzle::Dim::Kind kind);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TileSwizzle &swizzle);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_TILESWIZZLE_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 07cd27d..b8b8806 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -898,7 +898,8 @@
 std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
-  return {opaqueLayout.mSize * getUnrollM(), opaqueLayout.nSize * getUnrollN(),
+  return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(),
+          opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(),
           opaqueLayout.kSize * getUnrollK()};
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index be04d19..d3dc53a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -255,9 +255,11 @@
 
   let parameters = (ins
     "::mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr":$intrinsic,
-    "int64_t":$unroll_m,
-    "int64_t":$unroll_n,
-    "int64_t":$unroll_k
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, on the same thread.">:$unroll_m,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$unroll_m_to_subgroups,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, on the same thread.">:$unroll_n,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$unroll_n_to_subgroups,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the K dimension, on the same thread, with interleaved layout.">:$unroll_k
   );
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
index 046a3c8..0ebe947 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
@@ -29,21 +29,21 @@
 
 module {
   func.func @test_data_tiled_mfma_f32_16x16x4_f32() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2, unroll_k = 1>} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
-//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2>
 
 module {
   func.func @test_data_tiled_mfma_f32_16x16x16_f16() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n_to_subgroups = 2, unroll_k = 2>} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x16_f16
-//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_n_to_subgroups = 2, unroll_k = 2>
 
 module {
   func.func @test_data_tiled_mfma_i32_16x16x32_i8() attributes {
@@ -52,7 +52,7 @@
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_i32_16x16x32_i8
-//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+//  CHECK-SAME:   mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8>
 
 module {
   func.func @test_any_lowering_config() attributes {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index b174f8f..0aa922e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -227,7 +227,7 @@
   %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
     indexing_maps = #contraction_accesses,
     iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
-    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
   } : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
   return %0 : tensor<?x?x4x16x4x1xf32>
 }
@@ -240,7 +240,7 @@
 //       CHECK:   iree_gpu.multi_mma %arg0, %arg1, %arg2
 //  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME:       iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-//  CHECK-SAME:       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+//  CHECK-SAME:       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
 //  CHECK-SAME:     : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
 
 // -----
@@ -272,6 +272,34 @@
 
 // -----
 
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor<?x?x2x4x16x1x4xf32>, %rhs: tensor<?x?x2x4x16x1x4xf32>, %acc: tensor<?x?x2x2x4x16x4x1xf32>) -> tensor<?x?x2x2x4x16x4x1xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+  } : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>
+  return %0 : tensor<?x?x2x2x4x16x4x1xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma
+//       CHECK:   iree_gpu.multi_mma %arg0, %arg1, %arg2
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME:       iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+//  CHECK-SAME:       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+//  CHECK-SAME:     : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>
+
+
+// -----
+
 func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
   %out = iree_gpu.value_barrier %input : tensor<?xf16>
   return %out : tensor<?xf16>