[Codegen][GPU] Adding new heuristics to take all dimensions into account when distributing tiles (#21803)
The motivation of this PR is to address multi-dimensional tile distribution
in convolution codegen.
A sample convolution config that wouldn't distribute properly looks like
the following:
> convbfp16 -n 16 -c 768 -H 48 -W 32 -k 2048 -y3 -x 3 -p 1 -q 1 -u 1 -v
1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1
There are 3 dimensions in M: [16, 48, 32] and one dimension in N: [256].
Since N's last dimension is much larger than M's last dimension, the current
algorithm yields an extremely imbalanced tile that allocates all subgroups
and tiles to the N dimension, resulting in a small, memory-bound workgroup.
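As a rough illustration (assuming a 16x16 MFMA intrinsic, which is an assumption for this example rather than something stated above): the innermost M dimension contributes only 32/16 = 2 intrinsic tiles while N contributes 256/16 = 16, so the per-dimension GCD distribution hands essentially all subgroups and tile multiples to N.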
The new tile allocation algorithm prevents this problem by considering the
aggregated (collapsed) M and N dimensions together and finding a balanced
tile for the full scope. It then distributes the collapsed allocation back to
each sub-dimension, yielding much more reasonably distributed tiles. With the
new algorithm, this convolution improves from 5000us to 1500us, and
performance improves by roughly 5% across all 478 convolutions.
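
Below is a minimal standalone sketch of the collapse-then-redistribute idea. It is not the code in this PR: it swaps `llvm::APInt`/`GreatestCommonDivisor` for `std::gcd`, uses hypothetical helper names (`takeGCD`, `floorLog2`), hardcodes an assumed 16x16 intrinsic and assumed seeds (4 subgroups, 8 M/N tiles per subgroup), and only shows the path where M takes the square-root split; the real logic is in `GPUHeuristics.cpp` in the diff below.

```cpp
// Standalone sketch of distributing subgroups/tiles over collapsed M and N.
#include <cstdint>
#include <cstdio>
#include <numeric>

// Move the GCD of (totalTiles, toDistribute) out of both counters and return
// it as the factor assigned to the current (collapsed or real) dimension.
static int64_t takeGCD(int64_t &totalTiles, int64_t &toDistribute) {
  int64_t g = std::gcd(totalTiles, toDistribute);
  totalTiles /= g;
  toDistribute /= g;
  return g;
}

static int64_t floorLog2(int64_t x) {
  int64_t l = 0;
  while (x >>= 1)
    ++l;
  return l;
}

int main() {
  // Collapsed tile counts for M = [16, 48, 32] and N = [256], assuming a
  // 16x16 intrinsic: only the innermost size is divided by the intrinsic.
  int64_t mCollapsed = 16 * 48 * (32 / 16); // 1536
  int64_t nCollapsed = 256 / 16;            // 16
  int64_t remainingSubgroups = 4;           // assumed seed
  int64_t remainingTiles = 8;               // assumed seed

  // Square-root split: generous for subgroups (round the exponent up),
  // conservative for tiles (round the exponent down).
  int64_t subgroupSqrt = 1ll << ((floorLog2(remainingSubgroups) + 1) / 2);
  int64_t tileSqrt = 1ll << (floorLog2(remainingTiles) / 2);
  int64_t splitFactor = subgroupSqrt * tileSqrt;

  int64_t mSubgroups = 1, mTiles = 1, nSubgroups = 1, nTiles = 1;
  if (mCollapsed > splitFactor && mCollapsed % splitFactor == 0) {
    // Give collapsed M its square-root share first.
    mSubgroups = subgroupSqrt;
    mTiles = tileSqrt;
    mCollapsed /= splitFactor;
    remainingSubgroups /= subgroupSqrt;
    remainingTiles /= tileSqrt;
  }
  // Then hand collapsed N whatever still divides, via GCD.
  nSubgroups = takeGCD(nCollapsed, remainingSubgroups);
  nTiles = takeGCD(nCollapsed, remainingTiles);

  // With the assumed seeds: M gets 2 subgroups x 2 tiles and N gets 2 x 4,
  // so both sides of the GEMM get a share instead of N taking everything.
  std::printf("M: %lld subgroups x %lld tiles, N: %lld x %lld\n",
              (long long)mSubgroups, (long long)mTiles, (long long)nSubgroups,
              (long long)nTiles);
  return 0;
}
```

The same GCD step is then reused to push each collapsed factor back onto the per-dimension tile counts from the innermost dimension outwards, which is what the two small loops at the end of the patch do.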
I'm pushing this up for review while I gather GEMM and model perf numbers.
Since this has little impact on M/N problems with a single dimension,
performance should likely stay flat. I'll post perf updates as follow-up
comments soon.
---------
Signed-off-by: jerryyin <zhuoryin@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 8617327..6192fe0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -312,6 +312,72 @@
return kTileSizes;
}
+/// Distributes tilesToDistribute into totalTiles using their GCD. Both
+/// totalTiles and tilesToDistribute are updated to reflect the remaining
+/// tiles to distribute. The return value is the number of tiles distributed.
+static int64_t distributeTilesUsingGCD(int64_t &totalTiles,
+ int64_t &tilesToDistribute) {
+ APInt gcd = GreatestCommonDivisor(APInt(64, tilesToDistribute),
+ APInt(64, totalTiles));
+ int64_t distributeTileCount = gcd.getSExtValue();
+ totalTiles /= distributeTileCount;
+ tilesToDistribute /= distributeTileCount;
+
+ return distributeTileCount;
+}
+
+/// Distributes the square root of the subgroup and tile counts to the M or N
+/// dimension. The first argument serves as a flag to indicate whether the
+/// distribution is for the M or N dimension. Both total tiles and remaining
+/// tiles are updated to reflect the remaining tiles to distribute. Note: This
+/// function should only be used for primary distribution as it assigns the sqrt
+/// directly to the dimension.
+static void distributeSqrtForDim(
+ bool isMDim, int64_t subgroupSqrt, int64_t tileSqrt,
+ int64_t &mTotalTileToDistribute, int64_t &nTotalTileToDistribute,
+ int64_t &mSubgroupDistributed, int64_t &nSubgroupDistributed,
+ int64_t &mTileSizeDistributed, int64_t &nTileSizeDistributed,
+ int64_t &remainingSubgroups, int64_t &remainingTiles) {
+ if (isMDim) {
+ mSubgroupDistributed = subgroupSqrt;
+ mTileSizeDistributed = tileSqrt;
+ mTotalTileToDistribute /= (subgroupSqrt * tileSqrt);
+ } else {
+ nSubgroupDistributed = subgroupSqrt;
+ nTileSizeDistributed = tileSqrt;
+ nTotalTileToDistribute /= (subgroupSqrt * tileSqrt);
+ }
+
+ remainingSubgroups /= subgroupSqrt;
+ remainingTiles /= tileSqrt;
+}
+
+/// Distributes tiles and subgroups to the M or N dimension using their GCD.
+/// The first argument serves as a flag to indicate whether the distribution is
+/// for the M or N dimension. Both total tiles and remaining tiles are updated
+/// to reflect the remaining tiles to distribute.
+static void distributeGCDForDim(bool isMDim, int64_t &mTotalTileToDistribute,
+ int64_t &nTotalTileToDistribute,
+ int64_t &mSubgroupDistributed,
+ int64_t &nSubgroupDistributed,
+ int64_t &mTileSizeDistributed,
+ int64_t &nTileSizeDistributed,
+ int64_t &remainingSubgroups,
+ int64_t &remainingTiles) {
+
+ int64_t &totalTilesToDistribute =
+ isMDim ? mTotalTileToDistribute : nTotalTileToDistribute;
+ int64_t &subgroupDistributed =
+ isMDim ? mSubgroupDistributed : nSubgroupDistributed;
+ int64_t &tileDistributed =
+ isMDim ? mTileSizeDistributed : nTileSizeDistributed;
+
+ subgroupDistributed =
+ distributeTilesUsingGCD(totalTilesToDistribute, remainingSubgroups);
+ tileDistributed =
+ distributeTilesUsingGCD(totalTilesToDistribute, remainingTiles);
+}
+
/// Choose an optimal mma schedule with the heuristic that minimized the total
/// amount of data read from global memory, per workgroup, respecting the
/// heuristic seeds.
@@ -333,83 +399,111 @@
llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
nTotalTileCounts.back() =
llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
+ int64_t mTotalTileToDistribute = prod(mTotalTileCounts);
+ int64_t nTotalTileToDistribute = prod(nTotalTileCounts);
int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
- // Assign more subgroups to the M dimension (used later) to balance thread
- // counts along X and Y dimensions.
- int mDim = problem.mSizes.size() - 1;
- int nDim = problem.nSizes.size() - 1;
- SmallVector<int64_t> mTileSizes(problem.mSizes.size(), 0),
- nTileSizes(problem.nSizes.size(), 0),
- mSubgroupCounts(problem.mSizes.size(), 0),
- nSubgroupCounts(problem.nSizes.size(), 0);
- // Start at the innermost nDim and mDim, and try to distribute evenly to M and
- // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+
+ // Initial collapsed subgroup counts and tile sizes. Distribute to collapsed M
+ // and N dimensions to avoid starving either dimension. Once the collapsed
+ // distribution is determined, it will be distributed to individual dimensions
+ // of M and N.
+ int64_t mSubgroupDistributed = 1;
+ int64_t nSubgroupDistributed = 1;
+ int64_t mTileSizeDistributed = 1;
+ int64_t nTileSizeDistributed = 1;
+
LDBG() << "Starting MMA schedule distribution";
- while (mDim >= 0 || nDim >= 0) {
- LDBG() << "Current iteration: mDim: " << mDim << ", nDim: " << nDim
- << ", remainingSubgroups: " << remainingSubgroups
- << ", remainingTiles: " << remainingTiles
- << ", mTileSizes: " << mTileSizes << ", nTileSizes: " << nTileSizes;
- int64_t subgroupSqrt =
- 1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
- int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+ LDBG() << "mTotalTileCounts: " << mTotalTileCounts
+ << ", nTotalTileCounts: " << nTotalTileCounts
+ << ", remainingSubgroups: " << remainingSubgroups
+ << ", remainingTiles: " << remainingTiles;
- // See if the square root can divide mTotalTileCount. If so it means we can
- // distribute to both dimensions evenly to minimize the number of global
- // loads. Otherwise, try to distribute to N and then M.
- if (mDim >= 0 && nDim >= 0 &&
- mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
- mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
- LDBG() << "Distributing evenly to M and N dimensions.";
- mSubgroupCounts[mDim] = subgroupSqrt;
- mTileSizes[mDim] = tileSqrt;
+ // This aims to be generous on subgroup splitting: produce the smallest
+ // power-of-two that is >= sqrt(remainingSubgroups)
+ int64_t subgroupSqrt =
+ 1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
+ // This aims to be conservative on tile splitting: produce the largest
+ // power-of-two that is <= sqrt(remainingTiles)
+ int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+ int64_t splitFactor = subgroupSqrt * tileSqrt;
- remainingSubgroups /= subgroupSqrt;
- remainingTiles /= tileSqrt;
+ LDBG() << "splitFactor: " << splitFactor << ", subgroupSqrt: " << subgroupSqrt
+ << ", tileSqrt: " << tileSqrt;
- APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
- APInt(64, remainingSubgroups));
- nSubgroupCounts[nDim] = nGCD.getSExtValue();
- nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
- remainingSubgroups /= nSubgroupCounts[nDim];
+ // See if the square root can divide the total tile count. If so, we can
+ // distribute evenly to that dimension to minimize the number of global
+ // loads. Otherwise, fall back to GCD distribution.
+ bool canMDistributeEvenly = mTotalTileToDistribute > splitFactor &&
+ mTotalTileToDistribute % splitFactor == 0;
+ bool canNDistributeEvenly = nTotalTileToDistribute > splitFactor &&
+ nTotalTileToDistribute % splitFactor == 0;
+ if (canMDistributeEvenly) {
+ LDBG() << "Distributing seed evenly to M dim";
+ distributeSqrtForDim(true, subgroupSqrt, tileSqrt, mTotalTileToDistribute,
+ nTotalTileToDistribute, mSubgroupDistributed,
+ nSubgroupDistributed, mTileSizeDistributed,
+ nTileSizeDistributed, remainingSubgroups,
+ remainingTiles);
+ distributeGCDForDim(false, mTotalTileToDistribute, nTotalTileToDistribute,
+ mSubgroupDistributed, nSubgroupDistributed,
+ mTileSizeDistributed, nTileSizeDistributed,
+ remainingSubgroups, remainingTiles);
+ } else if (canNDistributeEvenly) {
+ LDBG() << "Distributing seed evenly to N dim";
+ distributeSqrtForDim(false, subgroupSqrt, tileSqrt, mTotalTileToDistribute,
+ nTotalTileToDistribute, mSubgroupDistributed,
+ nSubgroupDistributed, mTileSizeDistributed,
+ nTileSizeDistributed, remainingSubgroups,
+ remainingTiles);
+ distributeGCDForDim(true, mTotalTileToDistribute, nTotalTileToDistribute,
+ mSubgroupDistributed, nSubgroupDistributed,
+ mTileSizeDistributed, nTileSizeDistributed,
+ remainingSubgroups, remainingTiles);
+ } else {
+ LDBG() << "Distributing seed using GCD";
+ distributeGCDForDim(false, mTotalTileToDistribute, nTotalTileToDistribute,
+ mSubgroupDistributed, nSubgroupDistributed,
+ mTileSizeDistributed, nTileSizeDistributed,
+ remainingSubgroups, remainingTiles);
+ distributeGCDForDim(true, mTotalTileToDistribute, nTotalTileToDistribute,
+ mSubgroupDistributed, nSubgroupDistributed,
+ mTileSizeDistributed, nTileSizeDistributed,
+ remainingSubgroups, remainingTiles);
+ }
- nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
- APInt(64, remainingTiles));
- nTileSizes[nDim] = nGCD.getSExtValue();
- remainingTiles /= nTileSizes[nDim];
- } else {
- if (nDim >= 0) {
- LDBG() << "Distributing to N dimension first.";
- APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
- APInt(64, remainingSubgroups));
- nSubgroupCounts[nDim] = nGCD.getSExtValue();
- nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
- remainingSubgroups /= nSubgroupCounts[nDim];
+ // Note: Experimentation has proved that leaving the leftover factors
+ // unassigned is better than greedily assigning them to the larger collapsed
+ // dimension. This is likely because assigning leftover factors often results
+ // in overly aggressive tiling that ends up reducing occupancy and increasing
+ // shared memory usage.
+ LDBG() << "Leftover factors: subgroups: " << remainingSubgroups
+ << ", tiles: " << remainingTiles;
+ LDBG() << "Collapsed subgroup counts: M: " << mSubgroupDistributed
+ << ", N: " << nSubgroupDistributed;
+ LDBG() << "Collapsed tile sizes: M: " << mTileSizeDistributed
+ << ", N: " << nTileSizeDistributed;
- nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
- APInt(64, remainingTiles));
- nTileSizes[nDim] = nGCD.getSExtValue();
- remainingTiles /= nTileSizes[nDim];
- }
+ SmallVector<int64_t> mSubgroupCounts(problem.mSizes.size(), 0),
+ nSubgroupCounts(problem.nSizes.size(), 0),
+ mTileSizes(problem.mSizes.size(), 0),
+ nTileSizes(problem.nSizes.size(), 0);
- if (mDim >= 0) {
- LDBG() << "Distributing to M dimension next.";
- APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
- APInt(64, remainingSubgroups));
- mSubgroupCounts[mDim] = mGCD.getSExtValue();
- mTotalTileCounts[mDim] /= mSubgroupCounts[mDim];
- remainingSubgroups /= mSubgroupCounts[mDim];
+ // Distribute collapsed tile to M dims from inner -> outer.
+ for (size_t e = problem.mSizes.size(), i = e - 1; i < e; --i) {
+ mSubgroupCounts[i] =
+ distributeTilesUsingGCD(mTotalTileCounts[i], mSubgroupDistributed);
+ mTileSizes[i] =
+ distributeTilesUsingGCD(mTotalTileCounts[i], mTileSizeDistributed);
+ }
- mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
- APInt(64, remainingTiles));
- mTileSizes[mDim] = mGCD.getSExtValue();
- remainingTiles /= mTileSizes[mDim];
- }
- }
- --mDim;
- --nDim;
+ // Distribute collapsed tile to N dims from inner -> outer.
+ for (size_t e = problem.nSizes.size(), i = e - 1; i < e; --i) {
+ nSubgroupCounts[i] =
+ distributeTilesUsingGCD(nTotalTileCounts[i], nSubgroupDistributed);
+ nTileSizes[i] =
+ distributeTilesUsingGCD(nTotalTileCounts[i], nTileSizeDistributed);
}
SmallVector<int64_t> kTileSizes =
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index 4c77bad..db1909e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -32,12 +32,12 @@
// CHECK-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME: subgroup = [1, 4, 1, 1, 0]
+// GFX942-SAME: subgroup = [1, 2, 1, 2, 0]
// GFX942-SAME: workgroup = [1, 4, 32, 64, 0]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 1, 32, 64, 0]}>
+// MI300X-SAME: workgroup = [1, 2, 32, 32, 0]}>
// -----
@@ -66,12 +66,12 @@
// CHECK-SAME: promote_operands = [0, 1]
// GFX942-SAME: reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME: subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME: subgroup = [1, 2, 2, 1, 0]
// GFX942-SAME: workgroup = [1, 64, 4, 32, 0]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 64, 1, 32, 0]
+// MI300X-SAME: workgroup = [1, 32, 2, 32, 0]
// -----
@@ -104,11 +104,11 @@
// GFX942-SAME: subgroup = [2, 1, 1, 1, 0]
// GFX942-SAME: workgroup = [2, 1, 32, 64, 0]
-// MI300X-SAME: padding = [1, 1, 32, 64, 32]
+// MI300X-SAME: padding = [2, 1, 32, 32, 32]
// MI300X-SAME: promote_operands = [0, 1, 2]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 1, 32, 64, 0]
+// MI300X-SAME: workgroup = [2, 1, 32, 32, 0]
// PAD-CONV-GFX942: padding_conv = [2, 1, 32, 64, 0, 0, 0]
@@ -140,14 +140,14 @@
// GFX942-SAME: padding = [1, 64, 4, 32, 32]
// GFX942-SAME: promote_operands = [0, 1, 2]
// GFX942-SAME: reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME: subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME: subgroup = [1, 2, 2, 1, 0]
// GFX942-SAME: workgroup = [1, 64, 4, 32, 0]
-// MI300X-SAME: padding = [1, 64, 1, 32, 32]
+// MI300X-SAME: padding = [1, 32, 2, 32, 32]
// MI300X-SAME: promote_operands = [0, 1, 2]
// MI300X-SAME: reduction = [0, 0, 0, 0, 8]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 64, 1, 32, 0]
+// MI300X-SAME: workgroup = [1, 32, 2, 32, 0]
// PAD-CONV-GFX942: padding_conv = [1, 64, 4, 32, 0, 0, 0]
@@ -176,19 +176,19 @@
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// GFX942-SAME: padding = [1, 8, 32, 32, 32]
+// GFX942-SAME: padding = [1, 4, 32, 32, 32]
// GFX942-SAME: promote_operands = [0, 1, 2]
// GFX942-SAME: reduction = [0, 0, 0, 0, 2]
-// GFX942-SAME: subgroup = [1, 4, 1, 1, 0]
-// GFX942-SAME: workgroup = [1, 8, 32, 32, 0]
+// GFX942-SAME: subgroup = [1, 2, 1, 1, 0]
+// GFX942-SAME: workgroup = [1, 4, 32, 32, 0]
-// MI300X-SAME: padding = [1, 4, 32, 32, 32]
+// MI300X-SAME: padding = [1, 2, 32, 32, 32]
// MI300X-SAME: promote_operands = [0, 1, 2]
// MI300X-SAME: reduction = [0, 0, 0, 0, 2]
-// MI300X-SAME: subgroup = [1, 2, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 4, 32, 32, 0]
+// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
+// MI300X-SAME: workgroup = [1, 2, 32, 32, 0]
-// PAD-CONV-GFX942: padding_conv = [1, 8, 32, 32, 0, 0, 32]
+// PAD-CONV-GFX942: padding_conv = [1, 4, 32, 32, 0, 0, 32]
// -----
@@ -289,13 +289,13 @@
// GFX942-SAME: padding = [2, 2, 32, 64, 64]
// GFX942-SAME: promote_operands = [0, 1, 2]
// GFX942-SAME: reduction = [0, 0, 0, 0, 4]
-// GFX942-SAME: subgroup = [2, 2, 1, 1, 0]
+// GFX942-SAME: subgroup = [2, 1, 1, 2, 0]
// GFX942-SAME: workgroup = [2, 2, 32, 64, 0]
-// MI300X-SAME: padding = [1, 1, 32, 64, 64]
+// MI300X-SAME: padding = [1, 2, 32, 32, 64]
// MI300X-SAME: promote_operands = [0, 1, 2]
// MI300X-SAME: reduction = [0, 0, 0, 0, 4]
// MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME: workgroup = [1, 1, 32, 64, 0]
+// MI300X-SAME: workgroup = [1, 2, 32, 32, 0]
// PAD-CONV-GFX942: padding_conv = [2, 2, 32, 64, 0, 0]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 5101daf..d858ea7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -84,7 +84,7 @@
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
-// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
+// CHECK-SAME: subgroup = [2, 1, 2, 1, 0, 0]
// CHECK-SAME: workgroup = [2, 4, 32, 32, 0, 0]
// LATE: LLVMGPUVectorDistribute
@@ -439,14 +439,14 @@
// schedule with nTileSize of 16 while in reality it should be 8.
// LATE-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
-// LATE-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// LATE-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
// LATE-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
// LATE: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
-// LATE-SAME: padding = [1, 16, 512, 4]
+// LATE-SAME: padding = [1, 16, 128, 4]
// LATE-SAME: promote_operands = [0, 1, 2]
// LATE-SAME: reduction = [0, 0, 0, 1]
-// LATE-SAME: subgroup = [0, 1, 4, 0]
-// LATE-SAME: workgroup = [1, 16, 512, 0]
+// LATE-SAME: subgroup = [0, 1, 2, 0]
+// LATE-SAME: workgroup = [1, 16, 128, 0]
// -----