[Codegen][GPU] Adding new heuristics to take all dimensions into account when distributing tiles (#21803)

The motivation of this PR is to address the multi-dimensional
distribution situation in convolution codegen.

A sample convolution config that wouldn't distribute properly looks like
the following:
> convbfp16 -n 16 -c 768 -H 48 -W 32 -k 2048 -y3 -x 3 -p 1 -q 1 -u 1 -v
1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1

There are 3 dimensions in M: [16, 48, 32]. There is one dimension in N:
[256]. Since N's last dimension is much larger than M's last dimension,
the current algorithm yields an extremely imbalanced tile that allocates
all subgroups and tiles to the N dimension, resulting in a small,
memory-bound workgroup.
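
To illustrate (with assumed seeds of 4 subgroups per workgroup, 8 M/N tiles
per subgroup, and a 16x16 intrinsic; these numbers are assumptions for the
example, not taken from the config above), the old innermost-pair
distribution plays out roughly as:

```
M innermost tile count = ceil(32 / 16) = 2        (too small for the even sqrt split)
N innermost tile count is large, say 128
N first: gcd(128, 4) = 4 subgroups, gcd(128/4, 8) = 8 tiles   -> budget exhausted
M next:  gcd(2, 1)   = 1 subgroup,  gcd(2, 1)     = 1 tile
```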

The new tile allocation algorithm prevents this by considering the
aggregated (collapsed) M and N dimensions together and finding an optimal,
balanced tile for the full scope. It then distributes the collapsed
allocation back to the individual sub-dimensions, which yields much more
reasonably distributed tiles. With the new algorithm, this convolution
improves from 5000us to 1500us, and performance improves by 5% across all
478 convolutions.
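
The core of the new scheme can be sketched in isolation. Below is a minimal,
standalone plain-C++ version of the two phases (collapse and split the seed
budget, then redistribute per dimension with GCDs). The seed values, tile
counts, and helper names (`distributeGCD`, `mCounts`, ...) are illustrative
assumptions rather than the actual GPUHeuristics.cpp implementation, and only
the "M divides evenly" branch is shown:

```cpp
// Minimal standalone sketch of the two-phase distribution (plain C++17, no
// IREE/LLVM helpers). Seed values, tile counts, and names are illustrative.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Peel gcd(total, budget) off both counters and return the peeled factor.
static int64_t distributeGCD(int64_t &total, int64_t &budget) {
  int64_t g = std::gcd(total, budget);
  total /= g;
  budget /= g;
  return g;
}

int main() {
  // Illustrative per-dimension MMA tile counts: a 3-D M and a 1-D N.
  std::vector<int64_t> mCounts = {16, 3, 2}; // collapsed product = 96
  std::vector<int64_t> nCounts = {128};      // collapsed product = 128
  int64_t mCollapsed = 16 * 3 * 2, nCollapsed = 128;

  // Illustrative seeds: subgroups per workgroup and M/N tiles per subgroup.
  int64_t subgroups = 4, tiles = 8;

  // Phase 1: split the budget between the *collapsed* M and N products.
  // Give M roughly the square root of each budget when that divides evenly,
  // then hand N the GCD of whatever remains, so neither side is starved.
  int64_t subgroupSqrt = 2, tileSqrt = 2; // pow-of-2 near sqrt(4), sqrt(8)
  int64_t mSub = 1, mTile = 1, nSub = 1, nTile = 1;
  if (mCollapsed % (subgroupSqrt * tileSqrt) == 0) {
    mSub = subgroupSqrt;
    mTile = tileSqrt;
    mCollapsed /= subgroupSqrt * tileSqrt;
    subgroups /= subgroupSqrt;
    tiles /= tileSqrt;
    nSub = distributeGCD(nCollapsed, subgroups);
    nTile = distributeGCD(nCollapsed, tiles);
  } // (the real heuristic also handles the N-first and GCD-only fallbacks)

  // Phase 2: push each collapsed factor back onto the individual dims from
  // innermost to outermost, again peeling off GCDs. Factors that divide no
  // individual dimension are simply left unassigned.
  std::vector<int64_t> mSubCounts(mCounts.size(), 1), mTiles(mCounts.size(), 1);
  for (int i = static_cast<int>(mCounts.size()) - 1; i >= 0; --i) {
    mSubCounts[i] = distributeGCD(mCounts[i], mSub);
    mTiles[i] = distributeGCD(mCounts[i], mTile);
  }
  std::vector<int64_t> nSubCounts(nCounts.size(), 1), nTiles(nCounts.size(), 1);
  for (int i = static_cast<int>(nCounts.size()) - 1; i >= 0; --i) {
    nSubCounts[i] = distributeGCD(nCounts[i], nSub);
    nTiles[i] = distributeGCD(nCounts[i], nTile);
  }

  // With these numbers: M subgroups [1, 1, 2], M tiles [2, 1, 1],
  //                     N subgroups [2],       N tiles [4].
  auto print = [](const char *name, const std::vector<int64_t> &v) {
    std::cout << name << ":";
    for (int64_t x : v)
      std::cout << ' ' << x;
    std::cout << '\n';
  };
  print("M subgroups", mSubCounts);
  print("M tiles", mTiles);
  print("N subgroups", nSubCounts);
  print("N tiles", nTiles);
  return 0;
}
```

With these assumed numbers the collapsed split lands at 2 subgroups x 2 tiles
for M and 2 subgroups x 4 tiles for N, instead of everything going to N, and
factors that divide no individual dimension are deliberately left unassigned
(see the note in the patch below).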

I'm pushing this for review while I gather GEMM and model perf numbers.
Since this change has little impact on M/N problems with a single
dimension, performance there should likely stay flat. I'll post perf
updates as follow-up comments soon.

---------

Signed-off-by: jerryyin <zhuoryin@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 8617327..6192fe0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -312,6 +312,72 @@
   return kTileSizes;
 }
 
+/// Distributes tilesToDistribute to totalTiles using their GCD. Both
+/// totalTiles and tilesToDistribute are updated to reflect the remaining
+/// tiles to distribute. The return value is the number of tiles distributed.
+static int64_t distributeTilesUsingGCD(int64_t &totalTiles,
+                                       int64_t &tilesToDistribute) {
+  APInt gcd = GreatestCommonDivisor(APInt(64, tilesToDistribute),
+                                    APInt(64, totalTiles));
+  int64_t distributeTileCount = gcd.getSExtValue();
+  totalTiles /= distributeTileCount;
+  tilesToDistribute /= distributeTileCount;
+
+  return distributeTileCount;
+}
+
+/// Distributes the square root of the subgroup and tile counts to both M and N
+/// dimensions. The first argument serves as a flag to indicate whether the
+/// distribution is for the M or N dimension. Both total tiles and remaining
+/// tiles are updated to reflect the remaining tiles to distribute. Note: This
+/// function should only be used for primary distribution as it assigns the sqrt
+/// directly to the dimension.
+static void distributeSqrtForDim(
+    bool isMDim, int64_t subgroupSqrt, int64_t tileSqrt,
+    int64_t &mTotalTileToDistribute, int64_t &nTotalTileToDistribute,
+    int64_t &mSubgroupDistributed, int64_t &nSubgroupDistributed,
+    int64_t &mTileSizeDistributed, int64_t &nTileSizeDistributed,
+    int64_t &remainingSubgroups, int64_t &remainingTiles) {
+  if (isMDim) {
+    mSubgroupDistributed = subgroupSqrt;
+    mTileSizeDistributed = tileSqrt;
+    mTotalTileToDistribute /= (subgroupSqrt * tileSqrt);
+  } else {
+    nSubgroupDistributed = subgroupSqrt;
+    nTileSizeDistributed = tileSqrt;
+    nTotalTileToDistribute /= (subgroupSqrt * tileSqrt);
+  }
+
+  remainingSubgroups /= subgroupSqrt;
+  remainingTiles /= tileSqrt;
+}
+
+/// Distributes tiles and subgroups to both M and N dimensions using their GCD.
+/// The first argument serves as a flag to indicate whether the distribution is
+/// for the M or N dimension. Both total tiles and remaining tiles are updated
+/// to reflect the remaining tiles to distribute.
+static void distributeGCDForDim(bool isMDim, int64_t &mTotalTileToDistribute,
+                                int64_t &nTotalTileToDistribute,
+                                int64_t &mSubgroupDistributed,
+                                int64_t &nSubgroupDistributed,
+                                int64_t &mTileSizeDistributed,
+                                int64_t &nTileSizeDistributed,
+                                int64_t &remainingSubgroups,
+                                int64_t &remainingTiles) {
+
+  int64_t &totalTilesToDistribute =
+      isMDim ? mTotalTileToDistribute : nTotalTileToDistribute;
+  int64_t &subgroupDistributed =
+      isMDim ? mSubgroupDistributed : nSubgroupDistributed;
+  int64_t &tileDistributed =
+      isMDim ? mTileSizeDistributed : nTileSizeDistributed;
+
+  subgroupDistributed =
+      distributeTilesUsingGCD(totalTilesToDistribute, remainingSubgroups);
+  tileDistributed =
+      distributeTilesUsingGCD(totalTilesToDistribute, remainingTiles);
+}
+
 /// Choose an optimal mma schedule with the heuristic that minimized the total
 /// amount of data read from global memory, per workgroup, respecting the
 /// heuristic seeds.
@@ -333,83 +399,111 @@
       llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
   nTotalTileCounts.back() =
       llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
+  int64_t mTotalTileToDistribute = prod(mTotalTileCounts);
+  int64_t nTotalTileToDistribute = prod(nTotalTileCounts);
 
   int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
   int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
-  // Assign more subgroups to the M dimension (used later) to balance thread
-  // counts along X and Y dimensions.
-  int mDim = problem.mSizes.size() - 1;
-  int nDim = problem.nSizes.size() - 1;
-  SmallVector<int64_t> mTileSizes(problem.mSizes.size(), 0),
-      nTileSizes(problem.nSizes.size(), 0),
-      mSubgroupCounts(problem.mSizes.size(), 0),
-      nSubgroupCounts(problem.nSizes.size(), 0);
-  // Start at the innermost nDim and mDim, and try to distribute evenly to M and
-  // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+
+  // Initial collapsed subgroup counts and tile sizes. Distribute to collapsed M
+  // and N dimensions to avoid starving either dimension. Once the collapsed
+  // distribution is determined, it will be distributed to individual dimensions
+  // of M and N.
+  int64_t mSubgroupDistributed = 1;
+  int64_t nSubgroupDistributed = 1;
+  int64_t mTileSizeDistributed = 1;
+  int64_t nTileSizeDistributed = 1;
+
   LDBG() << "Starting MMA schedule distribution";
-  while (mDim >= 0 || nDim >= 0) {
-    LDBG() << "Current iteration: mDim: " << mDim << ", nDim: " << nDim
-           << ", remainingSubgroups: " << remainingSubgroups
-           << ", remainingTiles: " << remainingTiles
-           << ", mTileSizes: " << mTileSizes << ", nTileSizes: " << nTileSizes;
-    int64_t subgroupSqrt =
-        1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
-    int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+  LDBG() << "mTotalTileCounts: " << mTotalTileCounts
+         << ", nTotalTileCounts: " << nTotalTileCounts
+         << ", remainingSubgroups: " << remainingSubgroups
+         << ", remainingTiles: " << remainingTiles;
 
-    // See if the square root can divide mTotalTileCount. If so it means we can
-    // distribute to both dimensions evenly to minimize the number of global
-    // loads. Otherwise, try to distribute to N and then M.
-    if (mDim >= 0 && nDim >= 0 &&
-        mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
-        mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
-      LDBG() << "Distributing evenly to M and N dimensions.";
-      mSubgroupCounts[mDim] = subgroupSqrt;
-      mTileSizes[mDim] = tileSqrt;
+  // This aims to be generous on subgroup splitting: produce the smallest
+  // power-of-two that is >= sqrt(remainingSubgroups).
+  int64_t subgroupSqrt =
+      1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
+  // This aims to be conservative on tile splitting: produce the largest
+  // power-of-two that is <= sqrt(remainingTiles).
+  int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+  int64_t splitFactor = subgroupSqrt * tileSqrt;
 
-      remainingSubgroups /= subgroupSqrt;
-      remainingTiles /= tileSqrt;
+  LDBG() << "splitFactor: " << splitFactor << ", subgroupSqrt: " << subgroupSqrt
+         << ", tileSqrt: " << tileSqrt;
 
-      APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
-                                         APInt(64, remainingSubgroups));
-      nSubgroupCounts[nDim] = nGCD.getSExtValue();
-      nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
-      remainingSubgroups /= nSubgroupCounts[nDim];
+  // See if the square root can divide the total tile count. If so, we can
+  // distribute to a dimension evenly to minimize the number of global
+  // loads. Otherwise, fall back to GCD distribution.
+  bool canMDistributeEvenly = mTotalTileToDistribute > splitFactor &&
+                              mTotalTileToDistribute % splitFactor == 0;
+  bool canNDistributeEvenly = nTotalTileToDistribute > splitFactor &&
+                              nTotalTileToDistribute % splitFactor == 0;
+  if (canMDistributeEvenly) {
+    LDBG() << "Distributing seed evenly to M dim";
+    distributeSqrtForDim(true, subgroupSqrt, tileSqrt, mTotalTileToDistribute,
+                         nTotalTileToDistribute, mSubgroupDistributed,
+                         nSubgroupDistributed, mTileSizeDistributed,
+                         nTileSizeDistributed, remainingSubgroups,
+                         remainingTiles);
+    distributeGCDForDim(false, mTotalTileToDistribute, nTotalTileToDistribute,
+                        mSubgroupDistributed, nSubgroupDistributed,
+                        mTileSizeDistributed, nTileSizeDistributed,
+                        remainingSubgroups, remainingTiles);
+  } else if (canNDistributeEvenly) {
+    LDBG() << "Distributing seed evenly to N dim";
+    distributeSqrtForDim(false, subgroupSqrt, tileSqrt, mTotalTileToDistribute,
+                         nTotalTileToDistribute, mSubgroupDistributed,
+                         nSubgroupDistributed, mTileSizeDistributed,
+                         nTileSizeDistributed, remainingSubgroups,
+                         remainingTiles);
+    distributeGCDForDim(true, mTotalTileToDistribute, nTotalTileToDistribute,
+                        mSubgroupDistributed, nSubgroupDistributed,
+                        mTileSizeDistributed, nTileSizeDistributed,
+                        remainingSubgroups, remainingTiles);
+  } else {
+    LDBG() << "Distributing seed using GCD";
+    distributeGCDForDim(false, mTotalTileToDistribute, nTotalTileToDistribute,
+                        mSubgroupDistributed, nSubgroupDistributed,
+                        mTileSizeDistributed, nTileSizeDistributed,
+                        remainingSubgroups, remainingTiles);
+    distributeGCDForDim(true, mTotalTileToDistribute, nTotalTileToDistribute,
+                        mSubgroupDistributed, nSubgroupDistributed,
+                        mTileSizeDistributed, nTileSizeDistributed,
+                        remainingSubgroups, remainingTiles);
+  }
 
-      nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
-                                   APInt(64, remainingTiles));
-      nTileSizes[nDim] = nGCD.getSExtValue();
-      remainingTiles /= nTileSizes[nDim];
-    } else {
-      if (nDim >= 0) {
-        LDBG() << "Distributing to N dimension first.";
-        APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
-                                           APInt(64, remainingSubgroups));
-        nSubgroupCounts[nDim] = nGCD.getSExtValue();
-        nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
-        remainingSubgroups /= nSubgroupCounts[nDim];
+  // Note: Experimentation has shown that leaving the leftover factors
+  // unassigned is better than greedily assigning them to the larger collapsed
+  // dimension. This is likely because assigning leftover factors often results
+  // in overly aggressive tiling that ends up reducing occupancy and increasing
+  // shared memory usage.
+  LDBG() << "Leftover factors: subgroups: " << remainingSubgroups
+         << ", tiles: " << remainingTiles;
+  LDBG() << "Collapsed subgroup counts: M: " << mSubgroupDistributed
+         << ", N: " << nSubgroupDistributed;
+  LDBG() << "Collapsed tile sizes: M: " << mTileSizeDistributed
+         << ", N: " << nTileSizeDistributed;
 
-        nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
-                                     APInt(64, remainingTiles));
-        nTileSizes[nDim] = nGCD.getSExtValue();
-        remainingTiles /= nTileSizes[nDim];
-      }
+  SmallVector<int64_t> mSubgroupCounts(problem.mSizes.size(), 0),
+      nSubgroupCounts(problem.nSizes.size(), 0),
+      mTileSizes(problem.mSizes.size(), 0),
+      nTileSizes(problem.nSizes.size(), 0);
 
-      if (mDim >= 0) {
-        LDBG() << "Distributing to M dimension next.";
-        APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
-                                           APInt(64, remainingSubgroups));
-        mSubgroupCounts[mDim] = mGCD.getSExtValue();
-        mTotalTileCounts[mDim] /= mSubgroupCounts[mDim];
-        remainingSubgroups /= mSubgroupCounts[mDim];
+  // Distribute collapsed tile to M dims from inner -> outer.
+  for (size_t e = problem.mSizes.size(), i = e - 1; i < e; --i) {
+    mSubgroupCounts[i] =
+        distributeTilesUsingGCD(mTotalTileCounts[i], mSubgroupDistributed);
+    mTileSizes[i] =
+        distributeTilesUsingGCD(mTotalTileCounts[i], mTileSizeDistributed);
+  }
 
-        mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
-                                     APInt(64, remainingTiles));
-        mTileSizes[mDim] = mGCD.getSExtValue();
-        remainingTiles /= mTileSizes[mDim];
-      }
-    }
-    --mDim;
-    --nDim;
+  // Distribute collapsed tile to N dims from inner -> outer.
+  for (size_t e = problem.nSizes.size(), i = e - 1; i < e; --i) {
+    nSubgroupCounts[i] =
+        distributeTilesUsingGCD(nTotalTileCounts[i], nSubgroupDistributed);
+    nTileSizes[i] =
+        distributeTilesUsingGCD(nTotalTileCounts[i], nTileSizeDistributed);
   }
 
   SmallVector<int64_t> kTileSizes =
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index 4c77bad..db1909e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -32,12 +32,12 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 
 //  GFX942-SAME:    reduction = [0, 0, 0, 0, 8]
-//  GFX942-SAME:    subgroup = [1, 4, 1, 1, 0]
+//  GFX942-SAME:    subgroup = [1, 2, 1, 2, 0]
 //  GFX942-SAME:    workgroup = [1, 4, 32, 64, 0]
 
 //  MI300X-SAME:    reduction = [0, 0, 0, 0, 8]
 //  MI300X-SAME:    subgroup = [1, 1, 1, 1, 0]
-//  MI300X-SAME:    workgroup = [1, 1, 32, 64, 0]}>
+//  MI300X-SAME:    workgroup = [1, 2, 32, 32, 0]}>
 
 // -----
 
@@ -66,12 +66,12 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME:     subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME:     subgroup = [1, 2, 2, 1, 0]
 // GFX942-SAME:     workgroup = [1, 64, 4, 32, 0]
 
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 64, 1, 32, 0]
+// MI300X-SAME:     workgroup = [1, 32, 2, 32, 0]
 
 // -----
 
@@ -104,11 +104,11 @@
 // GFX942-SAME:     subgroup = [2, 1, 1, 1, 0]
 // GFX942-SAME:     workgroup = [2, 1, 32, 64, 0]
 
-// MI300X-SAME:     padding = [1, 1, 32, 64, 32]
+// MI300X-SAME:     padding = [2, 1, 32, 32, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 1, 32, 64, 0]
+// MI300X-SAME:     workgroup = [2, 1, 32, 32, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [2, 1, 32, 64, 0, 0, 0]
 
@@ -140,14 +140,14 @@
 // GFX942-SAME:     padding = [1, 64, 4, 32, 32]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME:     subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME:     subgroup = [1, 2, 2, 1, 0]
 // GFX942-SAME:     workgroup = [1, 64, 4, 32, 0]
 
-// MI300X-SAME:     padding = [1, 64, 1, 32, 32]
+// MI300X-SAME:     padding = [1, 32, 2, 32, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 64, 1, 32, 0]
+// MI300X-SAME:     workgroup = [1, 32, 2, 32, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [1, 64, 4, 32, 0, 0, 0]
 
@@ -176,19 +176,19 @@
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 
-// GFX942-SAME:     padding = [1, 8, 32, 32, 32]
+// GFX942-SAME:     padding = [1, 4, 32, 32, 32]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 2]
-// GFX942-SAME:     subgroup = [1, 4, 1, 1, 0]
-// GFX942-SAME:     workgroup = [1, 8, 32, 32, 0]
+// GFX942-SAME:     subgroup = [1, 2, 1, 1, 0]
+// GFX942-SAME:     workgroup = [1, 4, 32, 32, 0]
 
-// MI300X-SAME:     padding = [1, 4, 32, 32, 32]
+// MI300X-SAME:     padding = [1, 2, 32, 32, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 2]
-// MI300X-SAME:     subgroup = [1, 2, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 4, 32, 32, 0]
+// MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
+// MI300X-SAME:     workgroup = [1, 2, 32, 32, 0]
 
-// PAD-CONV-GFX942:     padding_conv = [1, 8, 32, 32, 0, 0, 32]
+// PAD-CONV-GFX942:     padding_conv = [1, 4, 32, 32, 0, 0, 32]
 
 // -----
 
@@ -289,13 +289,13 @@
 // GFX942-SAME:     padding = [2, 2, 32, 64, 64]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 4]
-// GFX942-SAME:     subgroup = [2, 2, 1, 1, 0]
+// GFX942-SAME:     subgroup = [2, 1, 1, 2, 0]
 // GFX942-SAME:     workgroup = [2, 2, 32, 64, 0]
 
-// MI300X-SAME:     padding = [1, 1, 32, 64, 64]
+// MI300X-SAME:     padding = [1, 2, 32, 32, 64]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 4]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 1, 32, 64, 0]
+// MI300X-SAME:     workgroup = [1, 2, 32, 32, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [2, 2, 32, 64, 0, 0]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 5101daf..d858ea7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -84,7 +84,7 @@
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 0, 4, 1]
-//  CHECK-SAME:     subgroup = [2, 2, 1, 1, 0, 0]
+//  CHECK-SAME:     subgroup = [2, 1, 2, 1, 0, 0]
 //  CHECK-SAME:     workgroup = [2, 4, 32, 32, 0, 0]
 
 //        LATE:  LLVMGPUVectorDistribute
@@ -439,14 +439,14 @@
 // schedule with nTileSize of 16 while in reality it should be 8.
 
 // LATE-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
-// LATE-SAME:    #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// LATE-SAME:    #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // LATE-SAME:    {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
 //      LATE:    linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
-//  LATE-SAME:     padding = [1, 16, 512, 4]
+//  LATE-SAME:     padding = [1, 16, 128, 4]
 //  LATE-SAME:     promote_operands = [0, 1, 2]
 //  LATE-SAME:     reduction = [0, 0, 0, 1]
-//  LATE-SAME:     subgroup = [0, 1, 4, 0]
-//  LATE-SAME:     workgroup = [1, 16, 512, 0]
+//  LATE-SAME:     subgroup = [0, 1, 2, 0]
+//  LATE-SAME:     workgroup = [1, 16, 128, 0]
 
 // -----