[GPU] Support multiple contraction dims in MmaSchedules (#18720)

This adds support for multiple M, N, and K dims in problems when
deducing a GPUMMASchedule. The new heuristic is similar to the old one,
but works on pairs of M and N dims. For example:
```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xN1xM0xN0>
```
The heuristic first distributes the seeded tile counts to `M0` and `N0`
(attempting an even split, and otherwise distributing to N followed by
M), and then distributes the residual counts to `M1` and `N1`. The K
tile counts are partitioned to `K0` first, and the residual tile counts
are then partitioned to `K1`.
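
The same greedy, GCD-based step is applied within each pair of M and N dims.
As a minimal standalone sketch (not the IREE implementation) of that residual
distribution, here is the K partitioning for the example problem
`{M:[4, 16], N:[2, 32], K:[3, 128]}` used in the comments of
`GPUHeuristics.cpp`, assuming a 16x16x16 intrinsic and a hypothetical seed of
8 K intrinsic tiles per subgroup:
```
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Greedily distribute a tile budget across the dims of one problem dimension,
// innermost first, taking the GCD with the remaining budget at each dim.
static std::vector<int64_t> distribute(const std::vector<int64_t> &tileCounts,
                                       int64_t budget) {
  std::vector<int64_t> sizes(tileCounts.size(), 1);
  for (int64_t i = static_cast<int64_t>(tileCounts.size()) - 1; i >= 0; --i) {
    sizes[i] = std::gcd(tileCounts[i], budget);
    budget /= sizes[i];
  }
  return sizes;
}

int main() {
  // K intrinsic tile counts per dim for K:[3, 128] with a 16x16x16 intrinsic:
  // {3, 128/16} = {3, 8}. With a (hypothetical) seed of 8 K tiles per
  // subgroup, the innermost dim absorbs the whole budget.
  std::vector<int64_t> kTileSizes = distribute({3, 8}, /*budget=*/8);
  for (int64_t s : kTileSizes)
    std::printf("%lld ", static_cast<long long>(s)); // prints: 1 8
  return 0;
}
```
Running this prints `1 8`: the innermost K dim takes a tile size of 8 and the
outer K dim is left at 1, matching how `kTileSizes` is filled from the
innermost K dim outward.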

This PR also updates the config selection logic for the TileAndFuse
pipeline to make use of the multiple contraction dimensions in MMA
schedules.
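
Concretely, the mapping looks roughly like the sketch below, which mirrors the
loops added to `ConfigUtils.cpp` in the diff but uses illustrative stand-in
types rather than the actual IREE data structures: each static M/N dim gets a
workgroup tile of `subgroupCount * tileSize` and a subgroup tile of
`tileSize`, while each K dim only gets a reduction tile.
```
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for GPUMMASchedule; field names mirror the real one.
struct Schedule {
  std::vector<int64_t> mSubgroupCounts, nSubgroupCounts;
  std::vector<int64_t> mTileSizes, nTileSizes, kTileSizes;
};

// Map a multi-dim schedule onto per-loop tile sizes, as the TileAndFuse config
// selection does for the static M/N/K loop indices it collected.
void setTileSizes(const Schedule &s, const std::vector<int64_t> &mDims,
                  const std::vector<int64_t> &nDims,
                  const std::vector<int64_t> &kDims, size_t numLoops,
                  std::vector<int64_t> &workgroupTileSizes,
                  std::vector<int64_t> &subgroupTileSizes,
                  std::vector<int64_t> &reductionTileSizes) {
  workgroupTileSizes.assign(numLoops, 0);
  subgroupTileSizes.assign(numLoops, 0);
  reductionTileSizes.assign(numLoops, 0);
  for (size_t i = 0; i < mDims.size(); ++i) {
    workgroupTileSizes[mDims[i]] = s.mSubgroupCounts[i] * s.mTileSizes[i];
    subgroupTileSizes[mDims[i]] = s.mTileSizes[i];
  }
  for (size_t i = 0; i < nDims.size(); ++i) {
    workgroupTileSizes[nDims[i]] = s.nSubgroupCounts[i] * s.nTileSizes[i];
    subgroupTileSizes[nDims[i]] = s.nTileSizes[i];
  }
  for (size_t i = 0; i < kDims.size(); ++i) {
    reductionTileSizes[kDims[i]] = s.kTileSizes[i];
  }
}
```
The flat workgroup size then becomes
`subgroupSize * prod(mSubgroupCounts) * prod(nSubgroupCounts)`, which is why
the TileAndFuse tests in the diff now expect `workgroup_size = [256, 1, 1]`.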

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index dc30783..790484d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@
 #include <cstdint>
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -20,51 +21,106 @@
 
 namespace mlir::iree_compiler {
 
+template <typename T>
 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
-                                     const GPUMMASchedule &schedule) {
-  os << "mSize: " << schedule.mSize << ", ";
-  os << "nSize: " << schedule.nSize << ", ";
-  os << "kSize: " << schedule.kSize << ", ";
-  os << "mTileCount: " << schedule.mTileCount << ", ";
-  os << "nTileCount: " << schedule.nTileCount << ", ";
-  os << "kTileCount: " << schedule.kTileCount << ", ";
-  os << "mWarpCount: " << schedule.mWarpCount << ", ";
-  os << "nWarpCount: " << schedule.nWarpCount;
+                                     const llvm::SmallVectorImpl<T> &vector) {
+  os << "[";
+  llvm::interleaveComma(vector, os);
+  os << "]";
   return os;
 }
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const GPUMMASchedule &schedule) {
+  os << "mSizes: " << schedule.mSize << ", ";
+  os << "nSizes: " << schedule.nSize << ", ";
+  os << "kSizes: " << schedule.kSize << ", ";
+  os << "mTileSizes: " << schedule.mTileSizes << ", ";
+  os << "nTileSizes: " << schedule.nTileSizes << ", ";
+  os << "kTileSizes: " << schedule.kTileSizes << ", ";
+  os << "mSubgroupCounts: " << schedule.mSubgroupCounts << ", ";
+  os << "nSubgroupCounts: " << schedule.nSubgroupCounts;
+  return os;
+}
+
+// Shortened helper to compute the product of `values`.
+static int64_t prod(ArrayRef<int64_t> values) {
+  return ShapedType::getNumElements(values);
+}
+
 static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
                                                 int64_t lhsBitwidth,
                                                 int64_t rhsBitwidth) {
-  int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
-  int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
-  int64_t tileK = schedule.kSize * schedule.kTileCount;
+
+  int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) *
+                  prod(schedule.mSubgroupCounts);
+  int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) *
+                  prod(schedule.nSubgroupCounts);
+  int64_t tileK = schedule.kSize * prod(schedule.kTileSizes);
   return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8;
 }
 
+/// Check that a GPUMMASchedule fits alignment restrictions. To be aligned,
+/// the problem must be evenly divisible by the number of elements in the
+/// schedule for each dimension. If `mustBeAligned` is false, then the innermost
+/// problem dimension is allowed to be unaligned.
 static bool isScheduleAligned(const GPUMatmulShapeType &problem,
                               const GPUMMASchedule &schedule,
                               bool mustBeAligned) {
-  auto alignedMSize =
-      mustBeAligned
-          ? problem.mSize
-          : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize;
-  auto alignedNSize =
-      mustBeAligned
-          ? problem.nSize
-          : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize;
-  auto alignedKSize =
-      mustBeAligned
-          ? problem.kSize
-          : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize;
-  bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount *
-                                   schedule.mWarpCount)) == 0;
-  bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
-                                   schedule.nWarpCount)) == 0;
-  bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
+  SmallVector<int64_t> alignedMSizes(problem.mSizes);
+  alignedMSizes.back() =
+      mustBeAligned ? problem.mSizes.back()
+                    : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) *
+                          schedule.mSize;
+  SmallVector<int64_t> alignedNSizes(problem.nSizes);
+  alignedNSizes.back() =
+      mustBeAligned ? problem.nSizes.back()
+                    : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) *
+                          schedule.nSize;
+  SmallVector<int64_t> alignedKSizes(problem.kSizes);
+  alignedKSizes.back() =
+      mustBeAligned ? problem.kSizes.back()
+                    : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) *
+                          schedule.kSize;
+  // Returns the number of elements in the schedule for each dimension.
+  auto getScheduleSizes =
+      [&](int64_t size, SmallVector<int64_t> tileCount,
+          std::optional<SmallVector<int64_t>> subgroupCount) {
+        SmallVector<int64_t> sizes = llvm::map_to_vector(
+            llvm::seq<int64_t>(tileCount.size()), [&](int64_t i) {
+              return subgroupCount ? tileCount[i] * subgroupCount.value()[i]
+                                   : tileCount[i];
+            });
+        sizes.back() *= size;
+        return sizes;
+      };
+  // Checks whether the elements of `a` are evenly divisible by the
+  // corresponding elements of `b`.
+  auto areAligned = [](SmallVector<int64_t> a, SmallVector<int64_t> b) {
+    for (auto [aVal, bVal] : llvm::zip_equal(a, b)) {
+      if (aVal % bVal != 0) {
+        return false;
+      }
+    }
+    return true;
+  };
+  bool isValidM = areAligned(
+      alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes,
+                                      schedule.mSubgroupCounts));
+  bool isValidN = areAligned(
+      alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes,
+                                      schedule.nSubgroupCounts));
+  bool isValidK = areAligned(
+      alignedKSizes,
+      getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt));
   return isValidM && isValidN && isValidK;
 }
 
+/// Returns whether or not a GPUMMASchedule is valid for the given problem.
+/// This checks that:
+///  - The problem is aligned to the schedule.
+///  - The number of threads in the schedule workgroup can be distributed
+///    to a corresponding vector.transfer_read in VectorDistribute.
 static bool isValidMMASchedule(const GPUMatmulShapeType &problem,
                                const GPUMMASchedule &schedule,
                                bool mustBeAligned, int64_t subgroupSize,
@@ -76,11 +132,13 @@
   const int64_t kMaxVectorLoadBitWidth = 128;
   int64_t elemsPerThread =
       kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
-  int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize;
-
-  int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
-  int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
-  int64_t kWgSize = schedule.kSize * schedule.kTileCount;
+  int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) *
+                      prod(schedule.nSubgroupCounts);
+  int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) *
+                    prod(schedule.mSubgroupCounts);
+  int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) *
+                    prod(schedule.nSubgroupCounts);
+  int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes);
   int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
   int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;
 
@@ -94,6 +152,10 @@
   return isAligned && isDistributableLhs && isDistributableRhs;
 }
 
+/// Tries to fit the schedule into shared memory by decrementing the size of the
+/// schedule dimensions from outermost to innermost until a valid schedule is
+/// found. The schedule sizes are reduced in the order of mTileSizes,
+/// nTileSizes, kTileSizes, mSubgroupCounts, nSubgroupCounts.
 static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
     GPUMatmulShapeType intrinsic, GPUMMASchedule schedule,
     llvm::function_ref<bool(const GPUMMASchedule &schedule)> isScheduleValid) {
@@ -105,31 +167,35 @@
       llvm::dbgs() << "Shrinking schedule...\n";
     });
 
-    auto decrementIfPossible = [](int64_t &c) -> LogicalResult {
-      if (c <= 1) {
-        return failure();
+    auto decrementIfPossible =
+        [](SmallVector<int64_t> &sizes) -> LogicalResult {
+      for (int64_t &size : sizes) {
+        if (size <= 1)
+          continue;
+        --size;
+        return success();
       }
-      --c;
-      return success();
+      return failure();
     };
 
     // Attempt to shrink the schedule along one of the dimensions.
     // TODO: A better solution should probably factor problem.mSize /
-    // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors
-    // one at a time, preferably trying to keep the tile "generally square."
-    if (succeeded(decrementIfPossible(schedule.mTileCount))) {
+    // (mSubgroupCount * mTileCount * mSize) and then pop off the smallest
+    // factors one at a time, preferably trying to keep the tile "generally
+    // square."
+    if (succeeded(decrementIfPossible(schedule.mTileSizes))) {
       continue;
     }
-    if (succeeded(decrementIfPossible(schedule.nTileCount))) {
+    if (succeeded(decrementIfPossible(schedule.nTileSizes))) {
       continue;
     }
-    if (succeeded(decrementIfPossible(schedule.kTileCount))) {
+    if (succeeded(decrementIfPossible(schedule.kTileSizes))) {
       continue;
     }
-    if (succeeded(decrementIfPossible(schedule.mWarpCount))) {
+    if (succeeded(decrementIfPossible(schedule.mSubgroupCounts))) {
       continue;
     }
-    if (succeeded(decrementIfPossible(schedule.nWarpCount))) {
+    if (succeeded(decrementIfPossible(schedule.nSubgroupCounts))) {
       continue;
     }
 
@@ -148,6 +214,9 @@
 static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem,
                                         const GPUMatmulShapeType &intrinsic,
                                         bool canUpcastAcc, bool mustBeAligned) {
+  assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 &&
+         intrinsic.kSizes.size() == 1 &&
+         "expected intrinsic to have a single M, N, and K dimension.");
   if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
     return failure(); // Cannot use this intrinsic for mismatched types
   }
@@ -161,17 +230,17 @@
     }
   }
 
-  if (mustBeAligned && (problem.mSize % intrinsic.mSize != 0 ||
-                        problem.nSize % intrinsic.nSize != 0 ||
-                        problem.kSize % intrinsic.kSize != 0)) {
+  if (mustBeAligned && (problem.mSizes.back() % intrinsic.mSizes[0] != 0 ||
+                        problem.nSizes.back() % intrinsic.nSizes[0] != 0 ||
+                        problem.kSizes.back() % intrinsic.kSizes[0] != 0)) {
     return failure(); // Cannot use this intrinsic for misaligned cases.
   }
 
   // Cannot use the intrinsic when the tile size is greater than problem size.
   // Because tiling is a no-op, and we can't infer tiling sizes from IR.
-  if (!mustBeAligned &&
-      (problem.mSize < intrinsic.mSize || problem.nSize < intrinsic.nSize ||
-       problem.kSize < intrinsic.kSize)) {
+  if (!mustBeAligned && (problem.mSizes.back() < intrinsic.mSizes[0] ||
+                         problem.nSizes.back() < intrinsic.nSizes[0] ||
+                         problem.kSizes.back() < intrinsic.kSizes[0])) {
     return failure();
   }
 
@@ -185,77 +254,123 @@
                                             const GPUMatmulShapeType &intrinsic,
                                             const GPUMMAHeuristicSeeds &seeds,
                                             uint64_t intrinsicIndex) {
-  int64_t mTotalTileCount = llvm::divideCeil(problem.mSize, intrinsic.mSize);
-  int64_t nTotalTileCount = llvm::divideCeil(problem.nSize, intrinsic.nSize);
+  assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 &&
+         intrinsic.kSizes.size() == 1 &&
+         "expected intrinsic to have a single M, N, and K dimension.");
+  // mTotalTileCounts and nTotalTileCounts represent the total number of
+  // intrinsics along the M or N dimensions needed to fill the problem size.
+  // For example, if the problem is {M:[4, 16], N:[2, 32], K:[3, 128]} for a
+  // 16x16x16 intrinsic, then:
+  //  - mTotalTileCounts would be 4 * (16/16) = 4
+  //  - nTotalTileCounts would be 2 * (32/16) = 4
+  SmallVector<int64_t> mTotalTileCounts = problem.mSizes;
+  SmallVector<int64_t> nTotalTileCounts = problem.nSizes;
+  mTotalTileCounts.back() =
+      llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
+  nTotalTileCounts.back() =
+      llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
 
-  int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup;
+  int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
   int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
-  // Assign more warps to the M dimension (used later) to balance thread
+  // Assign more subgroups to the M dimension (used later) to balance thread
   // counts along X and Y dimensions.
-  int64_t warpSqrt =
-      1ull << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2));
-  int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+  int mDim = problem.mSizes.size() - 1;
+  int nDim = problem.nSizes.size() - 1;
+  SmallVector<int64_t> mTileSizes(problem.mSizes.size(), 0),
+      nTileSizes(problem.nSizes.size(), 0),
+      mSubgroupCounts(problem.mSizes.size(), 0),
+      nSubgroupCounts(problem.nSizes.size(), 0);
+  // Start at the innermost nDim and mDim, and try to distribute evenly to M and
+  // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+  while (mDim >= 0 || nDim >= 0) {
+    int64_t subgroupSqrt =
+        1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
+    int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
 
-  int64_t mWarpCount = 0, nWarpCount = 0;
-  int64_t mTileCount = 0, nTileCount = 0;
+    // See if the square root can divide mTotalTileCounts[mDim]. If so, it
+    // means we can distribute to both dimensions evenly to minimize the
+    // number of global loads. Otherwise, try to distribute to N and then M.
+    if (mDim >= 0 && nDim >= 0 &&
+        mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
+        mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
+      mSubgroupCounts[mDim] = subgroupSqrt;
+      mTileSizes[mDim] = tileSqrt;
 
-  // See if the square root can divide mTotalTileCount. If so it means we can
-  // distribute to both dimensions evenly. Otherwise, try to distribute to N
-  // and then M.
-  if (mTotalTileCount > (warpSqrt * tileSqrt) &&
-      mTotalTileCount % (warpSqrt * tileSqrt) == 0) {
-    mWarpCount = warpSqrt;
-    mTileCount = tileSqrt;
+      remainingSubgroups /= subgroupSqrt;
+      remainingTiles /= tileSqrt;
 
-    remainingWarps /= warpSqrt;
-    remainingTiles /= tileSqrt;
+      APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+                                         APInt(64, remainingSubgroups));
+      nSubgroupCounts[nDim] = nGCD.getSExtValue();
+      nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
+      remainingSubgroups /= nSubgroupCounts[nDim];
 
-    APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                       APInt(64, remainingWarps));
-    nWarpCount = nGCD.getSExtValue();
-    nTotalTileCount /= nWarpCount;
-    remainingWarps /= nWarpCount;
+      nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+                                   APInt(64, remainingTiles));
+      nTileSizes[nDim] = nGCD.getSExtValue();
+      remainingTiles /= nTileSizes[nDim];
+    } else {
+      if (nDim >= 0) {
+        APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+                                           APInt(64, remainingSubgroups));
+        nSubgroupCounts[nDim] = nGCD.getSExtValue();
+        nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
+        remainingSubgroups /= nSubgroupCounts[nDim];
 
-    nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                 APInt(64, remainingTiles));
-    nTileCount = nGCD.getSExtValue();
-  } else {
-    APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                       APInt(64, remainingWarps));
-    nWarpCount = nGCD.getSExtValue();
-    nTotalTileCount /= nWarpCount;
-    remainingWarps /= nWarpCount;
+        nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+                                     APInt(64, remainingTiles));
+        nTileSizes[nDim] = nGCD.getSExtValue();
+        remainingTiles /= nTileSizes[nDim];
+      }
 
-    nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                 APInt(64, remainingTiles));
-    nTileCount = nGCD.getSExtValue();
-    remainingTiles /= nTileCount;
+      if (mDim >= 0) {
+        APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
+                                           APInt(64, remainingSubgroups));
+        mSubgroupCounts[mDim] = mGCD.getSExtValue();
+        mTotalTileCounts[mDim] /= mSubgroupCounts[mDim];
+        remainingSubgroups /= mSubgroupCounts[mDim];
 
-    APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
-                                       APInt(64, remainingWarps));
-    mWarpCount = mGCD.getSExtValue();
-    mTotalTileCount /= mWarpCount;
-    remainingWarps /= mWarpCount;
-
-    mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
-                                 APInt(64, remainingTiles));
-    mTileCount = mGCD.getSExtValue();
+        mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
+                                     APInt(64, remainingTiles));
+        mTileSizes[mDim] = mGCD.getSExtValue();
+        remainingTiles /= mTileSizes[mDim];
+      }
+    }
+    --mDim;
+    --nDim;
   }
 
-  const uint64_t kTotalTileCount =
-      llvm::divideCeil(problem.kSize, intrinsic.kSize);
+  // kTotalTileCounts is similar to m/nTotalTileCounts, representing the total
+  // number of intrinsics along the K dimensions needed to fill the problem.
+  // For the problem described above {M:[4, 16], N:[2, 32], K:[3, 128]} with a
+  // 16x16x16 intrinsic, then:
+  //  - kTotalTileCounts would be 3 * (128/16) = 24
+  SmallVector<int64_t> kTotalTileCounts = problem.kSizes;
+  kTotalTileCounts.back() =
+      llvm::divideCeil(problem.kSizes.back(), intrinsic.kSizes[0]);
+  // Compute the ideal number of intrinsics along K per subgroup based on the
+  // seed.
   int64_t bestKTileCountPerSubgroup =
       seeds.bestKElementCountPerSubgroup
           ? llvm::divideCeil(seeds.bestKElementCountPerSubgroup,
-                             intrinsic.kSize)
+                             intrinsic.kSizes[0])
           : seeds.bestKTileCountPerSubgroup;
-  APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount),
-                                     APInt(64, bestKTileCountPerSubgroup));
-  int64_t kTileCount = kGCD.getSExtValue();
+  SmallVector<int64_t> kTileSizes(problem.kSizes.size(), 0);
+  // Start at the innermost K dim, and tile each dim to try to satisfy the ideal
+  // K intrinsic count per subgroup with the overall product of K tile counts.
+  int kDim = problem.kSizes.size() - 1;
+  while (kDim >= 0) {
+    APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCounts[kDim]),
+                                       APInt(64, bestKTileCountPerSubgroup));
+    kTileSizes[kDim] = kGCD.getSExtValue();
+    bestKTileCountPerSubgroup /= kTileSizes[kDim];
+    --kDim;
+  }
 
-  return GPUMMASchedule{intrinsicIndex,  intrinsic.mSize, intrinsic.nSize,
-                        intrinsic.kSize, mWarpCount,      nWarpCount,
-                        mTileCount,      nTileCount,      kTileCount};
+  return GPUMMASchedule{
+      intrinsicIndex,      intrinsic.mSizes[0], intrinsic.nSizes[0],
+      intrinsic.kSizes[0], mSubgroupCounts,     nSubgroupCounts,
+      mTileSizes,          nTileSizes,          kTileSizes};
 }
 
 FailureOr<GPUMMASchedule> deduceMMASchedule(
@@ -297,7 +412,6 @@
 
       return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
     };
-
     return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule);
   }
   return failure();
@@ -309,7 +423,10 @@
     const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes,
     int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV,
     bool canUpcastAcc, bool mustBeAligned) {
-
+  assert(pvMatmul.mSizes.size() == 1 && pvMatmul.nSizes.size() == 1 &&
+         pvMatmul.kSizes.size() == 1 && qkMatmul.mSizes.size() == 1 &&
+         qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 &&
+         "unimplemented: multi M/N/K attention schedule");
   for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
     if (failed(canTargetIntrinsic(qkMatmul, intrinsic, canUpcastAcc,
                                   mustBeAligned))) {
@@ -329,7 +446,7 @@
       llvm::dbgs() << "  " << schedule << "\n";
     });
 
-    int64_t intrinsicK = intrinsic.kSize;
+    int64_t intrinsicK = intrinsic.kSizes[0];
     auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
       // Create a mma schedule for qkMatmul in attention.
       // qkMatmul.M = pvMatmul.M
@@ -339,11 +456,11 @@
                                 schedule.mSize,
                                 schedule.kSize,
                                 intrinsicK,
-                                /*mWarpCount=*/schedule.mWarpCount,
-                                /*nWarpCount=*/1,
-                                schedule.mTileCount,
-                                schedule.kTileCount,
-                                qkMatmul.kSize / intrinsicK};
+                                /*mSubgroupCount=*/schedule.mSubgroupCounts[0],
+                                /*nSubgroupCount=*/1,
+                                schedule.mTileSizes[0],
+                                schedule.kTileSizes[0],
+                                qkMatmul.kSizes[0] / intrinsicK};
 
       bool isQKAligned =
           isValidMMASchedule(qkMatmul, qkSchedule, mustBeAligned, subgroupSize,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index 8211443..13f6a56 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -10,15 +10,18 @@
 
 /// Struct containing information about a matmul's shape and type.
 struct GPUMatmulShapeType {
-  int64_t mSize;
-  int64_t nSize;
-  int64_t kSize;
+  SmallVector<int64_t> mSizes;
+  SmallVector<int64_t> nSizes;
+  SmallVector<int64_t> kSizes;
   Type aType;
   Type bType;
   Type cType;
 
   GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c)
-      : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {}
+      : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {}
+  GPUMatmulShapeType(SmallVector<int64_t> m, SmallVector<int64_t> n,
+                     SmallVector<int64_t> k, Type a, Type b, Type c)
+      : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {}
 };
 
 /// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
@@ -38,14 +41,42 @@
 struct GPUMMASchedule {
   // Index of the chosen intrinsic into the list of given MMA intrinsics
   uint64_t index;
-  int64_t mSize;      // Native MMA size along M dimension
-  int64_t nSize;      // Native MMA size along N dimension
-  int64_t kSize;      // Native MMA size along K dimension
-  int64_t mWarpCount; // Number of subgroups along M dimension
-  int64_t nWarpCount; // Number of subgroups along N dimension
-  int64_t mTileCount; // Number of tiles per subgroup along M dimension
-  int64_t nTileCount; // Number of tiles per subgroup along N dimension
-  int64_t kTileCount; // Number of tiles along K dimension
+  int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup.
+  int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup.
+  int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup.
+
+  // Number of subgroups along each M and N dimension.
+  SmallVector<int64_t> mSubgroupCounts;
+  SmallVector<int64_t> nSubgroupCounts;
+
+  // Tile sizes for each M, N, and K dimension. When there are multiple M, N,
+  // or K dimensions, the intrinsic sizes are targeted to the innermost
+  // dimension, and the outer dimensions can be thought of as unrolling factors
+  // along M, N, or K.
+  SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
+  SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
+  SmallVector<int64_t> kTileSizes; // K tile sizes.
+
+  // Constructor for multi M, N, K dim schedules.
+  GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+                 int64_t kIntrinsicSize, SmallVector<int64_t> mSubgroupCounts,
+                 SmallVector<int64_t> nSubgroupCounts,
+                 SmallVector<int64_t> mTileSizes,
+                 SmallVector<int64_t> nTileSizes,
+                 SmallVector<int64_t> kTileSizes)
+      : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+        kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
+        nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
+        nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}
+
+  // Constructor for single M, N, K dim schedules.
+  GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+                 int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup,
+                 int64_t mTileSize, int64_t nTileSize, int64_t kTileSize)
+      : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+        kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}),
+        nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}),
+        nTileSizes({nTileSize}), kTileSizes({kTileSize}) {}
 };
 
 /// Returns a schedule for using one of the given MMA |intrinsics| to target the
@@ -69,4 +100,7 @@
     bool transposedV = false, bool canUpcastAcc = false,
     bool mustBeAligned = true);
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const GPUMMASchedule &schedule);
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index ca23b0c..58bfdc0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -124,20 +125,37 @@
     return failure();
   }
 
-  // For now we are not being smart and trying to reshape dimensions to allow
-  // for better usage of intrinsics, and instead are tiling all dimensions
-  // except the inner most m, n, and k dimensions to 1.
-  int64_t mDim = contractionDims.m.back();
-  int64_t nDim = contractionDims.n.back();
-  int64_t kDim = contractionDims.k.back();
-
-  // Dynamic dims are expected to be taken care of earlier in the pipeline.
-  if (ShapedType::isDynamic(bounds[mDim]) ||
-      ShapedType::isDynamic(bounds[nDim]) ||
-      ShapedType::isDynamic(bounds[kDim])) {
+  // TODO(Max191): add dynamic shape support for innermost dims.
+  if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.n.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.k.back()])) {
     return failure();
   }
 
+  // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
+  // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
+  // computing an MMA schedule.
+  SmallVector<int64_t> mDims, nDims, kDims;
+  for (auto mDim : contractionDims.m) {
+    if (!ShapedType::isDynamic(bounds[mDim])) {
+      mDims.push_back(mDim);
+    }
+  }
+  for (auto nDim : contractionDims.n) {
+    if (!ShapedType::isDynamic(bounds[nDim])) {
+      nDims.push_back(nDim);
+    }
+  }
+  for (auto kDim : contractionDims.k) {
+    if (!ShapedType::isDynamic(bounds[kDim])) {
+      kDims.push_back(kDim);
+    }
+  }
+
+  auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
+    return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; });
+  };
+
   Value lhs = linalgOp.getDpsInputOperand(0)->get();
   Value rhs = linalgOp.getDpsInputOperand(1)->get();
   Value init = linalgOp.getDpsInitOperand(0)->get();
@@ -146,8 +164,9 @@
   Type rhsElemType = getElementTypeOrSelf(rhs);
   Type initElemType = getElementTypeOrSelf(init);
 
-  GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
-                             lhsElemType,  rhsElemType,  initElemType};
+  GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims),
+                             getDimBounds(kDims), lhsElemType,
+                             rhsElemType,         initElemType};
 
   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -166,7 +185,9 @@
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
   // See https://github.com/iree-org/iree/issues/16341 for details.
-  if (problem.mSize * problem.nSize <= 512 * 512) {
+  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  if (mSize * nSize <= 512 * 512) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
@@ -190,10 +211,10 @@
   // TODO: Drop this. This is only a consideration for other pipelines.
   SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
   bool transposedLhs =
-      kDim !=
+      kDims.back() !=
       llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
   bool transposedRhs =
-      nDim !=
+      nDims.back() !=
       llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
 
   // First try to find a schedule with an exactly matching intrinsic.
@@ -213,16 +234,13 @@
   }
 
   LDBG("Target Subgroup size: " << targetSubgroupSize);
-  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
-                           << schedule->kSize << "]");
-  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
-                                 << schedule->nTileCount << ", "
-                                 << schedule->kTileCount << "]");
-  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
-                                 << schedule->nWarpCount << "]");
+  LDBG("Schedule: " << schedule);
 
-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  int64_t flatWorkgroupSize =
+      targetSubgroupSize *
+      ShapedType::getNumElements(schedule->nSubgroupCounts) *
+      ShapedType::getNumElements(schedule->mSubgroupCounts);
+  std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};
 
   SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
   SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
@@ -244,16 +262,30 @@
     reductionTileSizes[k] = 1;
   }
 
-  // Compute the M/N dimension tile size by multiplying subgroup information.
-  workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
-  workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;
+  // Adjust the inner bound size for packing to intrinsic shapes, since tiling
+  // happens after packing.
+  assert(bounds[mDims.back()] % schedule->mSize == 0 &&
+         bounds[nDims.back()] % schedule->nSize == 0 &&
+         "expected inner bound to be evenly divisible by schedule sizes.");
+  bounds[mDims.back()] /= schedule->mSize;
+  bounds[nDims.back()] /= schedule->nSize;
 
-  // Specify the subgroup tile sizes from the mma schedule. This is applied
-  subgroupTileSizes[mDim] = schedule->mTileCount;
-  subgroupTileSizes[nDim] = schedule->nTileCount;
+  // Compute the M/N dimension tile sizes by multiplying subgroup information.
+  for (auto [i, mDim] : llvm::enumerate(mDims)) {
+    workgroupTileSizes[mDim] =
+        schedule->mSubgroupCounts[i] * schedule->mTileSizes[i];
+    subgroupTileSizes[mDim] = schedule->mTileSizes[i];
+  }
+  for (auto [i, nDim] : llvm::enumerate(nDims)) {
+    workgroupTileSizes[nDim] =
+        schedule->nSubgroupCounts[i] * schedule->nTileSizes[i];
+    subgroupTileSizes[nDim] = schedule->nTileSizes[i];
+  }
 
   // Similarly the reduction tile size is just the post-packing tile count.
-  reductionTileSizes[kDim] = schedule->kTileCount;
+  for (auto [i, kDim] : llvm::enumerate(kDims)) {
+    reductionTileSizes[kDim] = schedule->kTileSizes[i];
+  }
 
   IREE::GPU::MmaInterfaceAttr mmaKind =
       target.getWgp().getMma()[schedule->index];
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ff002ac..4b64cda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -301,6 +301,11 @@
   Type rhsElemType = getElementTypeOrSelf(rhs);
   Type initElemType = getElementTypeOrSelf(init);
 
+  // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+  // once the pipeline is able to support it. After adding multiple dimensions,
+  // all instances of schedule->m/nSubgroupCounts[0] and
+  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+  // just the first element.
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
@@ -339,8 +344,9 @@
     return failure();
   }
 
-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+                                           targetSubgroupSize,
+                                       schedule->mSubgroupCounts[0], 1};
 
   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -360,11 +366,11 @@
   }
   // Compute the M/N dimension tile size by multiply subgroup information.
   workgroupTileSizes[mDim] =
-      schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
   workgroupTileSizes[nDim] =
-      schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
 
-  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
 
   // Tile all filter loop dimensions to 1.
   for (int64_t filterDim : convolutionDims->filterLoop) {
@@ -386,8 +392,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
-      schedule->nWarpCount);
+      context, target.getWgp().getMma()[schedule->index],
+      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
@@ -489,6 +495,11 @@
       rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
   }
 
+  // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+  // once the pipeline is able to support it. After adding multiple dimensions,
+  // all instances of schedule->m/nSubgroupCounts[0] and
+  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+  // just the first element.
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
@@ -509,7 +520,7 @@
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
   // See https://github.com/iree-org/iree/issues/16341 for details.
-  if (problem.mSize * problem.nSize <= clGPUMatmulCThreshold) {
+  if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
@@ -573,16 +584,11 @@
   }
 
   LDBG("Target Subgroup size: " << targetSubgroupSize);
-  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
-                           << schedule->kSize << "]");
-  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
-                                 << schedule->nTileCount << ", "
-                                 << schedule->kTileCount << "]");
-  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
-                                 << schedule->nWarpCount << "]");
+  LDBG("Schedule: " << schedule);
 
-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+                                           targetSubgroupSize,
+                                       schedule->mSubgroupCounts[0], 1};
 
   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -605,11 +611,11 @@
 
   // Compute the M/N dimension tile size by multiply subgroup information.
   workgroupTileSizes[mDim] =
-      schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
   workgroupTileSizes[nDim] =
-      schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
 
-  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
 
   LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
                                        *contractionDims, workgroupTileSizes));
@@ -631,8 +637,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
-      schedule->nWarpCount);
+      context, target.getWgp().getMma()[schedule->index],
+      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
@@ -772,22 +778,17 @@
   // TODO: Due to a bug in layout configuration, we cannot set warp count on
   // the N dimension. This is however ok, because we generally do not want to
   // distribute subgroups on N dimension anyway.
-  if (schedule->nWarpCount != 1) {
-    schedule->nTileCount *= schedule->nWarpCount;
-    schedule->nWarpCount = 1;
+  if (schedule->nSubgroupCounts[0] != 1) {
+    schedule->nTileSizes[0] *= schedule->nSubgroupCounts[0];
+    schedule->nSubgroupCounts[0] = 1;
   }
 
   LDBG("Target Subgroup size: " << targetSubgroupSize);
-  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
-                           << schedule->kSize << "]");
-  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
-                                 << schedule->nTileCount << ", "
-                                 << schedule->kTileCount << "]");
-  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
-                                 << schedule->nWarpCount << "]");
+  LDBG("Schedule: " << schedule);
 
-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+                                           targetSubgroupSize,
+                                       schedule->mSubgroupCounts[0], 1};
 
   SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -811,11 +812,11 @@
 
   // Compute the M/N dimension tile size by multiply subgroup information.
   workgroupTileSizes[mDim] =
-      schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
   workgroupTileSizes[nDim] =
-      schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
 
-  reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize;
 
   MLIRContext *context = op.getContext();
   SmallVector<NamedAttribute, 2> attrs;
@@ -831,8 +832,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
-      schedule->nWarpCount);
+      context, target.getWgp().getMma()[schedule->index],
+      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index b98e85a..819b882 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -37,11 +37,79 @@
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 0, 4]
-//  CHECK-SAME:     subgroup = [0, 0, 4, 1, 0]
+//  CHECK-SAME:     subgroup = [1, 1, 4, 1, 0]
 //  CHECK-SAME:     workgroup = [1, 1, 4, 4, 0]
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %5 = tensor.empty() : tensor<10x4x32x32xf16>
+  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+  %7 = linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %8 = arith.mulf %in, %in_0 : f16
+    %9 = arith.addf %8, %out : f16
+    linalg.yield %9 : f16
+  } -> tensor<10x4x32x32xf16>
+  return %7 : tensor<10x4x32x32xf16>
+}
+
+// CHECK-LABEL: func.func @multi_dim_mma_schedule
+//  CHECK-SAME:   #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
+
+//       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
+//  CHECK-SAME:     promote_operands = [0, 1]
+//  CHECK-SAME:     reduction = [0, 0, 0, 0, 4, 1]
+//  CHECK-SAME:     subgroup = [2, 2, 1, 1, 0, 0]
+//  CHECK-SAME:     workgroup = [2, 2, 2, 2, 0, 0]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %d0 = tensor.dim %lhs, %c0 : tensor<?x6x16x?x16xf16>
+  %d2 = tensor.dim %rhs, %c0 : tensor<?x32x?x16xf16>
+  %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf16>
+  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<?x6x?x16x32xf16>) -> tensor<?x6x?x16x32xf16>
+  %7 = linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+    ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %8 = arith.mulf %in, %in_0 : f16
+    %9 = arith.addf %8, %out : f16
+    linalg.yield %9 : f16
+  } -> tensor<?x6x?x16x32xf16>
+  return %7 : tensor<?x6x?x16x32xf16>
+}
+
+// CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule
+//  CHECK-SAME:   #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
+
+//       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
+//  CHECK-SAME:     promote_operands = [0, 1]
+//  CHECK-SAME:     reduction = [0, 0, 0, 0, 0, 1, 1]
+//  CHECK-SAME:     subgroup = [0, 1, 0, 1, 1, 0, 0]
+//  CHECK-SAME:     workgroup = [1, 2, 1, 1, 2, 0, 0]
+
+// -----
+
 func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>) -> tensor<1024x1024xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
@@ -52,7 +120,7 @@
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
-//  CHECK-SAME:   #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [128, 2, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
 
 // Verify that the fill does not have the lowering config propagated to it.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
index 1fa2bae..9618281 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
@@ -33,4 +33,4 @@
 // CHECK-SAME:     lowering_config = #iree_gpu.lowering_config<
 // CHECK-SAME:         {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
 // CHECK-SAME:          promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8],
-// CHECK-SAME:          subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
+// CHECK-SAME:          subgroup = [1, 1, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 16a1acf..bbdec5c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -884,6 +884,11 @@
   Type lhsElem = getElementType(lhs);
   Type rhsElem = getElementType(rhs);
   Type initElem = getElementType(init);
+  // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+  // once the pipeline is able to support it. After adding multiple dimensions,
+  // all instances of schedule->m/nSubgroupCounts[0] and
+  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+  // just the first element.
   GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem);
 
   SmallVector<GPUMatmulShapeType> intrinsics;
@@ -921,8 +926,9 @@
 
   auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
 
-  std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * subgroupSize,
-                                       schedule->mWarpCount, 1};
+  std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+                                           subgroupSize,
+                                       schedule->mSubgroupCounts[0], 1};
 
   SmallVector<int64_t> vectorSizes(kIndex + 1, 0);
   if (isBM)
@@ -934,21 +940,23 @@
   SmallVector<int64_t> subgroupTileSizes(lastParallelDim + 1, 0);
   if (isBM)
     subgroupTileSizes[bIndex] = 1;
-  subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex];
-  subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex];
+  subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex];
+  subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex];
 
   SmallVector<int64_t> workgroupTileSizes(lastParallelDim + 1, 0);
   if (isBM)
     workgroupTileSizes[bIndex] = 1;
-  workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex];
-  workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex];
+  workgroupTileSizes[mIndex] =
+      schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex];
+  workgroupTileSizes[nIndex] =
+      schedule->nSubgroupCounts[0] * subgroupTileSizes[nIndex];
 
   // Also create one level for reduction. This is needed because of
   // SPIRVTileAndPromotePass requires it.
   // TODO(#10499): Consolidate tiling configuration across different pipelines.
   SmallVector<int64_t> reductionTileSizes;
   reductionTileSizes.append(kIndex, 0);
-  reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize);
+  reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize);
 
   TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes,
                                  reductionTileSizes, vectorSizes};
@@ -956,7 +964,7 @@
   // Don't do multibuffering if the inner reduction loop is folded out.
   auto pipelineDepth = softwarePipelineDepth;
   auto storeStage = softwarePipelineStoreStage;
-  if (schedule->kTileCount <= 1) {
+  if (schedule->kTileSizes[0] <= 1) {
     pipelineDepth = 0;
     storeStage = 0;
   }
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index ba415b3..922e508 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -242,16 +242,16 @@
       return llvm::divideCeil(value, padTo) * padTo - value;
     };
 
-    if (mSize % intrinsic.mSize != 0) {
-      mPadding = getPadding(mSize, intrinsic.mSize);
+    if (mSize % intrinsic.mSizes[0] != 0) {
+      mPadding = getPadding(mSize, intrinsic.mSizes[0]);
     }
 
-    if (nSize % intrinsic.nSize != 0) {
-      nPadding = getPadding(nSize, intrinsic.nSize);
+    if (nSize % intrinsic.nSizes[0] != 0) {
+      nPadding = getPadding(nSize, intrinsic.nSizes[0]);
     }
 
-    if (kSize % intrinsic.kSize != 0) {
-      kPadding = getPadding(kSize, intrinsic.kSize);
+    if (kSize % intrinsic.kSizes[0] != 0) {
+      kPadding = getPadding(kSize, intrinsic.kSizes[0]);
     }
 
     if (!mPadding && !nPadding && !kPadding) {
@@ -381,7 +381,7 @@
   for (GPUMatmulShapeType &intrinsic : intrinsics) {
     std::optional<OpFoldResult> mPadding, nPadding, kPadding;
     SmallVector<std::pair<int64_t, int64_t>> dimsToExpandCandidate;
-    if (mSize % intrinsic.mSize != 0 || ShapedType::isDynamic(mSize)) {
+    if (mSize % intrinsic.mSizes[0] != 0 || ShapedType::isDynamic(mSize)) {
       OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize);
       if (ShapedType::isDynamic(mSize)) {
         auto mOperandDimPair = getSrcOperandAndDim(mDim);
@@ -390,12 +390,12 @@
         auto [mOperand, mOperandDim] = mOperandDimPair.value();
         mSizeExpr = rewriter.create<tensor::DimOp>(loc, mOperand, mOperandDim)
                         .getResult();
-        dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSize);
+        dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSizes[0]);
       }
-      mPadding = getPadding(mSizeExpr, intrinsic.mSize);
+      mPadding = getPadding(mSizeExpr, intrinsic.mSizes[0]);
     }
 
-    if (nSize % intrinsic.nSize != 0 || ShapedType::isDynamic(nSize)) {
+    if (nSize % intrinsic.nSizes[0] != 0 || ShapedType::isDynamic(nSize)) {
       OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize);
       if (ShapedType::isDynamic(nSize)) {
         auto nOperandDimPair = getSrcOperandAndDim(nDim);
@@ -404,12 +404,12 @@
         auto [nOperand, nOperandDim] = nOperandDimPair.value();
         nSizeExpr = rewriter.create<tensor::DimOp>(loc, nOperand, nOperandDim)
                         .getResult();
-        dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSize);
+        dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSizes[0]);
       }
-      nPadding = getPadding(nSizeExpr, intrinsic.nSize);
+      nPadding = getPadding(nSizeExpr, intrinsic.nSizes[0]);
     }
 
-    if (kSize % intrinsic.kSize != 0 || ShapedType::isDynamic(kSize)) {
+    if (kSize % intrinsic.kSizes[0] != 0 || ShapedType::isDynamic(kSize)) {
       OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize);
       if (ShapedType::isDynamic(kSize)) {
         auto kOperandDimPair = getSrcOperandAndDim(kDim);
@@ -418,9 +418,9 @@
         auto [kOperand, kOperandDim] = kOperandDimPair.value();
         kSizeExpr = rewriter.create<tensor::DimOp>(loc, kOperand, kOperandDim)
                         .getResult();
-        dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSize);
+        dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSizes[0]);
       }
-      kPadding = getPadding(kSizeExpr, intrinsic.kSize);
+      kPadding = getPadding(kSizeExpr, intrinsic.kSizes[0]);
     }
 
     if (!mPadding && !nPadding && !kPadding) {