[GPU] Fix alignment check for scaled matmul (#22737)

## Problem

The current alignment check in `GPUHeuristics.cpp` is incorrect for any
intrinsic that has multiple M, N, or K dimensions. The root cause is
that the product of the intrinsic sizes along each dimension is passed to
`GPUMMASchedule` instead of passing the individual dimension sizes as a vector.

## Example


https://github.com/iree-org/iree/blob/b98c1b92cb630bd696992f47df591bb2f247a8d7/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp#L516-L525

Consider the scaled MFMA where `intrinsic.kSizes = [K, KB] = [4, 32]`.
Instead of passing the vector `[4, 32]`, the value `128` (product: 4 ×
32) is passed to `GPUMMASchedule`.


https://github.com/iree-org/iree/blob/b98c1b92cb630bd696992f47df591bb2f247a8d7/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp#L98-L108

Assume the K tile sizes are `[4, 1]`. The returned schedule sizes become
`[4, 128]` instead of the correct `[16, 32]`. As a result, the alignment
check on the last dimension always fails, since the problem size of KB is
`32` and `32 % 128 != 0`.

When the alignment check fails, no intrinsic is selected and the
operation falls back to complete serialization. This leads to extremely
slow execution for workloads like Llama 405B FP4 prefill with direct
codegen.

## Solution

This PR passes all intrinsic sizes as vectors to `GPUMMASchedule`.

## Performance

**Llama 405B FP4 prefill direct codegen with shark-ai:**
- Before: 11 minutes
- After: 234 ms

Closes: #22559

ci-extra: test_torch

---------

Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 630f157..6c83c7f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -38,9 +38,9 @@
 llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                               const GPUMMASchedule &schedule) {
   os << "mmaKind " << schedule.mmaKind << ", ";
-  os << "mSizes: " << schedule.mSize << ", ";
-  os << "nSizes: " << schedule.nSize << ", ";
-  os << "kSizes: " << schedule.kSize << ", ";
+  os << "mSizes: " << schedule.mSizes << ", ";
+  os << "nSizes: " << schedule.nSizes << ", ";
+  os << "kSizes: " << schedule.kSizes << ", ";
   os << "mTileSizes: " << schedule.mTileSizes << ", ";
   os << "nTileSizes: " << schedule.nTileSizes << ", ";
   os << "kTileSizes: " << schedule.kTileSizes << ", ";
@@ -52,11 +52,11 @@
 static int64_t calculateOperandsSharedMemoryUsedInBytes(
     const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth,
     int64_t numRhs = 1) {
-  int64_t tileM = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
-                  llvm::product_of(schedule.mSubgroupCounts);
-  int64_t tileN = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
-                  llvm::product_of(schedule.nSubgroupCounts);
-  int64_t tileK = schedule.kSize * llvm::product_of(schedule.kTileSizes);
+  int64_t tileM = schedule.getTotalMSize() * schedule.getTotalMTileSize() *
+                  schedule.getTotalMSubgroupCount();
+  int64_t tileN = schedule.getTotalNSize() * schedule.getTotalNTileSize() *
+                  schedule.getTotalNSubgroupCount();
+  int64_t tileK = schedule.getTotalKSize() * schedule.getTotalKTileSize();
   return (tileM * tileK * lhsBitwidth + numRhs * tileN * tileK * rhsBitwidth) /
          8;
 }
@@ -65,50 +65,48 @@
 calculateResultSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
                                        int64_t resultBitwidth,
                                        int64_t numRes = 1) {
-  int64_t tileM = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
-                  llvm::product_of(schedule.mSubgroupCounts);
-  int64_t tileN = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
-                  llvm::product_of(schedule.nSubgroupCounts);
+  int64_t tileM = schedule.getTotalMSize() * schedule.getTotalMTileSize() *
+                  schedule.getTotalMSubgroupCount();
+  int64_t tileN = schedule.getTotalNSize() * schedule.getTotalNTileSize() *
+                  schedule.getTotalNSubgroupCount();
   return (numRes * tileM * tileN * resultBitwidth) / 8;
 }
 
 /// Check that a GPUMMASchedule fits alignment restrictions. To be aligned,
 /// the problem must be evenly divisible by the number of elements in the
-/// schedule for each dimension. If `mustBeAligned` is false, then the innermost
-/// problem dimension is allowed to be unaligned .
+/// schedule for each dimension. If `mustBeAligned` is false, then the problem
+/// is allowed to be unaligned and the function simply returns true.
 static bool isScheduleAligned(const GPUMatmulShapeType &problem,
                               const GPUMMASchedule &schedule,
                               bool mustBeAligned) {
-  SmallVector<int64_t, 2> alignedMSizes(problem.mSizes);
-  alignedMSizes.back() =
-      mustBeAligned ? problem.mSizes.back()
-                    : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) *
-                          schedule.mSize;
-  SmallVector<int64_t, 2> alignedNSizes(problem.nSizes);
-  alignedNSizes.back() =
-      mustBeAligned ? problem.nSizes.back()
-                    : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) *
-                          schedule.nSize;
-  SmallVector<int64_t, 2> alignedKSizes(problem.kSizes);
-  alignedKSizes.back() =
-      mustBeAligned ? problem.kSizes.back()
-                    : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) *
-                          schedule.kSize;
+  // If alignment is not required, skip checks and return true.
+  if (!mustBeAligned) {
+    return true;
+  }
   // Returns the number of elements in the schedule for each dimension.
-  auto getScheduleSizes =
-      [&](int64_t size, SmallVector<int64_t> tileCount,
-          std::optional<SmallVector<int64_t>> subgroupCount) {
-        SmallVector<int64_t> sizes = llvm::map_to_vector(
-            llvm::seq<int64_t>(tileCount.size()), [&](int64_t i) {
-              return subgroupCount ? tileCount[i] * subgroupCount.value()[i]
-                                   : tileCount[i];
-            });
-        sizes.back() *= size;
-        return sizes;
-      };
+  auto getScheduleSizes = [&](ArrayRef<int64_t> intrinsicSizes,
+                              ArrayRef<int64_t> tileCount,
+                              std::optional<ArrayRef<int64_t>> subgroupCount) {
+    SmallVector<int64_t> sizes = llvm::map_to_vector(
+        llvm::seq<int64_t>(tileCount.size()), [&](int64_t i) {
+          return subgroupCount ? tileCount[i] * subgroupCount.value()[i]
+                               : tileCount[i];
+        });
+    // Multiply by intrinsic sizes, applying to the inner dimensions, as
+    // the outer dimensions are unrolling factors. For example, if tileCount
+    // = [a, b, c, d] and intrinsicSizes = [x, y], the result is [a, b, c*x,
+    // d*y].
+    assert(intrinsicSizes.size() <= sizes.size() &&
+           "intrinsic sizes should not exceed tile count sizes");
+    for (auto [intrinsicSize, size] :
+         llvm::zip(llvm::reverse(intrinsicSizes), llvm::reverse(sizes))) {
+      size *= intrinsicSize;
+    }
+    return sizes;
+  };
   // Checks whether the elements of `a` are evenly divisible by the
   // corresponding elements of `b`.
-  auto areAligned = [](SmallVector<int64_t, 2> a, SmallVector<int64_t, 2> b) {
+  auto areAligned = [](ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
     for (auto [aVal, bVal] : llvm::zip_equal(a, b)) {
       if (aVal % bVal != 0) {
         return false;
@@ -117,14 +115,14 @@
     return true;
   };
   bool isValidM = areAligned(
-      alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes,
-                                      schedule.mSubgroupCounts));
+      problem.mSizes, getScheduleSizes(schedule.mSizes, schedule.mTileSizes,
+                                       schedule.mSubgroupCounts));
   bool isValidN = areAligned(
-      alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes,
-                                      schedule.nSubgroupCounts));
+      problem.nSizes, getScheduleSizes(schedule.nSizes, schedule.nTileSizes,
+                                       schedule.nSubgroupCounts));
   bool isValidK = areAligned(
-      alignedKSizes,
-      getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt));
+      problem.kSizes,
+      getScheduleSizes(schedule.kSizes, schedule.kTileSizes, std::nullopt));
   return isValidM && isValidN && isValidK;
 }
 
@@ -144,14 +142,13 @@
   const int64_t kMaxVectorLoadBitWidth = 128;
   int64_t elemsPerThread =
       kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
-  int64_t wgThreads = subgroupSize *
-                      llvm::product_of(schedule.mSubgroupCounts) *
-                      llvm::product_of(schedule.nSubgroupCounts);
-  int64_t mWgSize = schedule.mSize * llvm::product_of(schedule.mTileSizes) *
-                    llvm::product_of(schedule.mSubgroupCounts);
-  int64_t nWgSize = schedule.nSize * llvm::product_of(schedule.nTileSizes) *
-                    llvm::product_of(schedule.nSubgroupCounts);
-  int64_t kWgSize = schedule.kSize * llvm::product_of(schedule.kTileSizes);
+  int64_t wgThreads = subgroupSize * schedule.getTotalMSubgroupCount() *
+                      schedule.getTotalNSubgroupCount();
+  int64_t mWgSize = schedule.getTotalMSize() * schedule.getTotalMTileSize() *
+                    schedule.getTotalMSubgroupCount();
+  int64_t nWgSize = schedule.getTotalNSize() * schedule.getTotalNTileSize() *
+                    schedule.getTotalNSubgroupCount();
+  int64_t kWgSize = schedule.getTotalKSize() * schedule.getTotalKTileSize();
   int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
   int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;
 
@@ -178,7 +175,7 @@
            << schedule << "\nShrinking schedule...";
 
     auto decrementIfPossible =
-        [](SmallVector<int64_t> &sizes) -> LogicalResult {
+        [](MutableArrayRef<int64_t> sizes) -> LogicalResult {
       for (int64_t &size : sizes) {
         if (size <= 1)
           continue;
@@ -513,15 +510,9 @@
   SmallVector<int64_t> kTileSizes =
       getBestKTileSizes(problem, intrinsic, seeds);
 
-  return GPUMMASchedule{intrinsic.mmaKind,
-                        llvm::product_of(intrinsic.mSizes),
-                        llvm::product_of(intrinsic.nSizes),
-                        llvm::product_of(intrinsic.kSizes),
-                        mSubgroupCounts,
-                        nSubgroupCounts,
-                        mTileSizes,
-                        nTileSizes,
-                        kTileSizes};
+  return GPUMMASchedule{intrinsic.mmaKind, intrinsic.mSizes, intrinsic.nSizes,
+                        intrinsic.kSizes,  mSubgroupCounts,  nSubgroupCounts,
+                        mTileSizes,        nTileSizes,       kTileSizes};
 }
 
 /// Compare the MMA intrinsics by following precedence rules:
@@ -917,9 +908,9 @@
     qkKSizes.back() = qkMatmul.kSizes.back() / intrinsicAK;
     GPUMMASchedule qkSchedule{
         intrinsicA.mmaKind,
-        pvSchedule->mSize,
-        pvSchedule->kSize,
-        intrinsicAK,
+        pvSchedule->mSizes,
+        pvSchedule->kSizes,
+        {intrinsicAK},
         /*mSubgroupCount=*/pvSchedule->mSubgroupCounts,
         /*nSubgroupCount=*/SmallVector<int64_t>(qkMatmul.nSizes.size(), 1),
         pvSchedule->mTileSizes,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index d1d4091..db998e5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -68,34 +68,35 @@
 struct GPUMMASchedule {
   // The MMA intrinsic kind to use for this schedule.
   IREE::Codegen::InnerTileDescAttrInterface mmaKind;
-  // Native MMA intrinsic size along M dimension for a subgroup.
-  int64_t mSize = 0;
-  // Native MMA intrinsic size along N dimension for a subgroup.
-  int64_t nSize = 0;
-  // Native MMA intrinsic size along K dimension for a subgroup.
-  int64_t kSize = 0;
+
+  // Native MMA intrinsic sizes along M, N and K dimensions for a subgroup.
+  SmallVector<int64_t, 2> mSizes;
+  SmallVector<int64_t, 2> nSizes;
+  SmallVector<int64_t, 2> kSizes;
 
   // Number of subgroups along each M and N dimension.
-  SmallVector<int64_t> mSubgroupCounts;
-  SmallVector<int64_t> nSubgroupCounts;
+  SmallVector<int64_t, 2> mSubgroupCounts;
+  SmallVector<int64_t, 2> nSubgroupCounts;
 
   // Tile sizes for each M, N, and K dimension. When there are multiple M, N,
   // or K dimensions, the intrinsic sizes are targeted to the innermost
   // dimension, and the outer dimensions can be thought of as unrolling factors
   // along M, N, or K.
-  SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
-  SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
-  SmallVector<int64_t> kTileSizes; // K tile sizes.
+  SmallVector<int64_t, 2> mTileSizes; // M tile sizes per subgroup.
+  SmallVector<int64_t, 2> nTileSizes; // N tile sizes per subgroup.
+  SmallVector<int64_t, 2> kTileSizes; // K tile sizes.
 
   // Constructor for multi M, N, K dim schedules.
   GPUMMASchedule(IREE::Codegen::InnerTileDescAttrInterface kind,
-                 int64_t mIntrinsicSize, int64_t nIntrinsicSize,
-                 int64_t kIntrinsicSize, ArrayRef<int64_t> mSubgroupCounts,
+                 ArrayRef<int64_t> mIntrinsicSizes,
+                 ArrayRef<int64_t> nIntrinsicSizes,
+                 ArrayRef<int64_t> kIntrinsicSizes,
+                 ArrayRef<int64_t> mSubgroupCounts,
                  ArrayRef<int64_t> nSubgroupCounts,
                  ArrayRef<int64_t> mTileSizes, ArrayRef<int64_t> nTileSizes,
                  ArrayRef<int64_t> kTileSizes)
-      : mmaKind(kind), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
-        kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
+      : mmaKind(kind), mSizes(mIntrinsicSizes), nSizes(nIntrinsicSizes),
+        kSizes(kIntrinsicSizes), mSubgroupCounts(mSubgroupCounts),
         nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
         nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}
 
@@ -104,10 +105,36 @@
                  int64_t mIntrinsicSize, int64_t nIntrinsicSize,
                  int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup,
                  int64_t mTileSize, int64_t nTileSize, int64_t kTileSize)
-      : mmaKind(kind), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
-        kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}),
+      : mmaKind(kind), mSizes({mIntrinsicSize}), nSizes({nIntrinsicSize}),
+        kSizes({kIntrinsicSize}), mSubgroupCounts({mSubgroup}),
         nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}),
         nTileSizes({nTileSize}), kTileSizes({kTileSize}) {}
+
+  // Helper methods to get the total product of intrinsic sizes.
+  int64_t getTotalMSize() const { return llvm::product_of(mSizes); }
+  int64_t getTotalNSize() const { return llvm::product_of(nSizes); }
+  int64_t getTotalKSize() const { return llvm::product_of(kSizes); }
+
+  // Helper methods to get the total product of tile sizes.
+  int64_t getTotalMTileSize() const { return llvm::product_of(mTileSizes); }
+  int64_t getTotalNTileSize() const { return llvm::product_of(nTileSizes); }
+  int64_t getTotalKTileSize() const { return llvm::product_of(kTileSizes); }
+
+  // Helper methods to get the total product of subgroup counts.
+  int64_t getTotalMSubgroupCount() const {
+    return llvm::product_of(mSubgroupCounts);
+  }
+  int64_t getTotalNSubgroupCount() const {
+    return llvm::product_of(nSubgroupCounts);
+  }
+
+  // Check if all schedule dimensions are single-element.
+  bool hasSingleDimensions() const {
+    return llvm::all_equal({size_t(1), mSubgroupCounts.size(),
+                            nSubgroupCounts.size(), mTileSizes.size(),
+                            nTileSizes.size(), kTileSizes.size(), mSizes.size(),
+                            nSizes.size(), kSizes.size()});
+  }
 };
 
 /// Returns a schedule for using one of the given MMA |intrinsics| to target the
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3238d23..545dcba 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -761,7 +761,7 @@
     // Multiply by the intrinsic shape for the inner most dim as we distribute
     // to workgroups before packing to intrinsic.
     if (i == mDims.size() - 1) {
-      workgroupTileSizes[mDim] *= schedule->mSize;
+      workgroupTileSizes[mDim] *= schedule->getTotalMSize();
     }
     subgroupTileSizes[mDim] = schedule->mTileSizes[i];
   }
@@ -771,7 +771,7 @@
     // Multiply by the intrinsic shape for the inner most dim as we distribute
     // to workgroups before packing to intrinsic.
     if (i == nDims.size() - 1) {
-      workgroupTileSizes[nDim] *= schedule->nSize;
+      workgroupTileSizes[nDim] *= schedule->getTotalNSize();
     }
     subgroupTileSizes[nDim] = schedule->nTileSizes[i];
   }
@@ -1742,11 +1742,12 @@
   }
 
   // Compute the M/N dimension tile size by multiply subgroup information.
-  workgroupTileSizes[mDim] =
-      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
+  assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension");
+  workgroupTileSizes[mDim] = schedule->mSubgroupCounts[0] *
+                             schedule->mTileSizes[0] * schedule->mSizes[0];
   subgroupTileSizes[mDim] = schedule->mTileSizes[0];
-  workgroupTileSizes[nDim] =
-      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
+  workgroupTileSizes[nDim] = schedule->nSubgroupCounts[0] *
+                             schedule->nTileSizes[0] * schedule->nSizes[0];
   subgroupTileSizes[nDim] = schedule->nTileSizes[0];
 
   // The reduction tile size is just the post-packing tile count.
@@ -1763,7 +1764,7 @@
 
   if (!mustBeAligned) {
     SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;
-    paddingTileSizes[kDim] = reductionTileSizes[kDim] * schedule->kSize;
+    paddingTileSizes[kDim] = reductionTileSizes[kDim] * schedule->kSizes[0];
     attrs.emplace_back("padding_conv", b.getI64ArrayAttr(paddingTileSizes));
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 44259ce..6c2c893 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -285,9 +285,9 @@
 
   // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
   // once the pipeline is able to support it. After adding multiple dimensions,
-  // all instances of schedule->m/nSubgroupCounts[0] and
-  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
-  // just the first element.
+  // all instances of schedule->m/nSubgroupCounts[0],
+  // schedule->m/n/kTileSizes[0] and schedule->m/n/kSizes[0] need to use the
+  // full list of sizes instead of just the first element.
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
@@ -345,6 +345,8 @@
 
   LDBG() << "Schedule: " << schedule;
 
+  assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension");
+
   int64_t flatWorkgroupSize =
       targetSubgroupSize *
       ShapedType::getNumElements(schedule->nSubgroupCounts) *
@@ -368,12 +370,12 @@
     reductionTileSizes[ic] = 1;
   }
   // Compute the M/N dimension tile size by multiply subgroup information.
-  workgroupTileSizes[mDim] =
-      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
-  workgroupTileSizes[nDim] =
-      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
+  workgroupTileSizes[mDim] = schedule->mSubgroupCounts[0] *
+                             schedule->mTileSizes[0] * schedule->mSizes[0];
+  workgroupTileSizes[nDim] = schedule->nSubgroupCounts[0] *
+                             schedule->nTileSizes[0] * schedule->nSizes[0];
 
-  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSizes[0];
 
   // Tile all filter loop dimensions to 1.
   for (int64_t filterDim : convolutionDims->filterLoop) {
@@ -510,9 +512,9 @@
 
   // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
   // once the pipeline is able to support it. After adding multiple dimensions,
-  // all instances of schedule->m/nSubgroupCounts[0] and
-  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
-  // just the first element.
+  // all instances of schedule->m/nSubgroupCounts[0],
+  // schedule->m/n/kTileSizes[0] and schedule->m/n/kSizes[0] need to use the
+  // full list of sizes instead of just the first element.
   GPUMatmulShapeType problem{
       {bounds[mDim]}, {bounds[nDim]}, {bounds[kDim]}, getDimBounds(batchDims),
       lhsElemType,    rhsElemType,    initElemType,   numHorizontallyFusedOps};
@@ -597,6 +599,8 @@
   LDBG() << "Target Subgroup size: " << targetSubgroupSize;
   LDBG() << "Schedule: " << schedule;
 
+  assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension");
+
   int64_t flatWorkgroupSize =
       targetSubgroupSize *
       ShapedType::getNumElements(schedule->nSubgroupCounts) *
@@ -623,12 +627,12 @@
   }
 
   // Compute the M/N dimension tile size by multiply subgroup information.
-  workgroupTileSizes[mDim] =
-      schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
-  workgroupTileSizes[nDim] =
-      schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
+  workgroupTileSizes[mDim] = schedule->mSubgroupCounts[0] *
+                             schedule->mTileSizes[0] * schedule->mSizes[0];
+  workgroupTileSizes[nDim] = schedule->nSubgroupCounts[0] *
+                             schedule->nTileSizes[0] * schedule->nSizes[0];
 
-  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSizes[0];
 
   LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
                                        *contractionDims, workgroupTileSizes));
@@ -893,7 +897,7 @@
         pvSchedule.mSubgroupCounts[i] * pvSchedule.mTileSizes[i];
     // Multiply by the intrinsic shape for the inner most dim.
     if (i == mDims.size() - 1) {
-      workgroupTileSizes[mDim] *= pvSchedule.mSize;
+      workgroupTileSizes[mDim] *= llvm::product_of(pvSchedule.mSizes);
     }
     subgroupBasis.counts[mDim] = pvSchedule.mSubgroupCounts[i];
   }
@@ -902,7 +906,7 @@
         pvSchedule.nSubgroupCounts[i] * pvSchedule.nTileSizes[i];
     // Multiply by the intrinsic shape for the inner most dim.
     if (i == nDims.size() - 1) {
-      workgroupTileSizes[nDim] *= pvSchedule.nSize;
+      workgroupTileSizes[nDim] *= llvm::product_of(pvSchedule.nSizes);
     }
     subgroupBasis.counts[nDim] = pvSchedule.nSubgroupCounts[i];
   }
@@ -910,7 +914,7 @@
     reductionTileSizes[k2Dim] = pvSchedule.kTileSizes[i];
     // Multiply by the intrinsic shape for the inner most dim.
     if (i == k2Dims.size() - 1) {
-      reductionTileSizes[k2Dim] *= pvSchedule.kSize;
+      reductionTileSizes[k2Dim] *= llvm::product_of(pvSchedule.kSizes);
     }
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
index 9d50043..d98e73f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
@@ -100,6 +100,40 @@
 
 // -----
 
+#lhs_map = affine_map<(b, m, n, ko, kb) -> (b, m, ko, kb)>
+#rhs_map = affine_map<(b, m, n, ko, kb) -> (n, ko, kb)>
+#scale_lhs = affine_map<(b, m, n, ko, kb) -> (b, m, ko)>
+#scale_rhs = affine_map<(b, m, n, ko, kb) -> (n, ko)>
+#out_map = affine_map<(b, m, n, ko, kb) -> (b, m, n)>
+func.func @scaled_matmul_with_dynamic_batch(
+    %A: tensor<?x128x512x32xf4E2M1FN>, %B: tensor<16384x512x32xf4E2M1FN>, %A_scales: tensor<?x128x512xf8E8M0FNU>, %B_scales: tensor<16384x512xf8E8M0FNU>, %C: tensor<?x128x16384xf32>) -> tensor<?x128x16384xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#lhs_map, #rhs_map, #scale_lhs, #scale_rhs, #out_map],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+  } ins(%A, %B, %A_scales, %B_scales : tensor<?x128x512x32xf4E2M1FN>, tensor<16384x512x32xf4E2M1FN>, tensor<?x128x512xf8E8M0FNU>, tensor<16384x512xf8E8M0FNU>) outs(%C : tensor<?x128x16384xf32>) {
+  ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %out: f32):
+    %1 = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
+    %2 = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<?x128x16384xf32>
+  return %0 : tensor<?x128x16384xf32>
+}
+
+// CHECK-LABEL: func.func @scaled_matmul_with_dynamic_batch
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
+//       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
+//  CHECK-SAME:     promote_operands = [0, 1, 2, 3]
+//   CHECK-NOT:     promotion_types = [{{.*}}#iree_gpu.use_global_load_dma{{.*}}
+//  CHECK-SAME:     reduction = [0, 0, 0, 4, 1]
+//  CHECK-SAME:     subgroup = [0, 2, 4, 0, 0]
+//  CHECK-SAME:     workgroup = [1, 64, 128, 0, 0]
+
+// -----
+
 #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
 #rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
 #scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 400bde2..efba839 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -895,9 +895,9 @@
   Type initElem = getElementType(init);
   // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
   // once the pipeline is able to support it. After adding multiple dimensions,
-  // all instances of schedule->m/nSubgroupCounts[0] and
-  // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
-  // just the first element.
+  // all instances of schedule->m/nSubgroupCounts[0],
+  // schedule->m/n/kTileSizes[0] and schedule->m/n/kSizes[0] need to use the
+  // full list of sizes instead of just the first element.
   GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem);
 
   SmallVector<GPUIntrinsicType> intrinsics;
@@ -930,6 +930,7 @@
                         subgroupSize, transposedLhs, transposedRhs);
   if (failed(schedule))
     return failure();
+  assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension");
 
   auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
 
@@ -940,9 +941,9 @@
   SmallVector<int64_t> vectorSizes(kIndex + 1, 0);
   if (isBM)
     vectorSizes[bIndex] = 1;
-  vectorSizes[mIndex] = schedule->mSize;
-  vectorSizes[nIndex] = schedule->nSize;
-  vectorSizes[kIndex] = schedule->kSize;
+  vectorSizes[mIndex] = schedule->mSizes[0];
+  vectorSizes[nIndex] = schedule->nSizes[0];
+  vectorSizes[kIndex] = schedule->kSizes[0];
 
   SmallVector<int64_t> subgroupTileSizes(lastParallelDim + 1, 0);
   if (isBM)
@@ -963,7 +964,7 @@
   // TODO(#10499): Consolidate tiling configuration across different pipelines.
   SmallVector<int64_t> reductionTileSizes;
   reductionTileSizes.append(kIndex, 0);
-  reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize);
+  reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSizes[0]);
 
   TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes,
                                  reductionTileSizes, vectorSizes};