[GPU] Support multiple contraction dims in MmaSchedules (#18720)
This adds support for multiple M, N, and K dims in problems when
deducing a GPUMMASchedule. The new heuristic is similar to the old one,
but works on pairs of M and N dims. For example:
```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xN1xM0xN0>
```
This will try to distribute the seeded subgroup and tile counts to `M0` and
`N0` (first attempting to distribute evenly, and otherwise distributing to N
followed by M), and then distribute the residual counts to `M1` and `N1`. The
K tile counts are partitioned to `K0` first, and the residual tile counts are
then partitioned to `K1`.
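For reference, the per-dimension partitioning boils down to a greedy, GCD-based
split applied from the innermost dim outward (the even M/N split is tried first
when a power-of-two square root divides the M tile count). Below is a minimal
standalone sketch of that GCD step; `distributeSeedToDims` and its signature are
illustrative assumptions, not the actual helper in `GPUHeuristics.cpp`:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Greedily split a seeded tile-count budget across the dims of one dimension
// group (M, N, or K), innermost (last entry) first, taking the largest factor
// of the remaining budget that evenly divides each dim's intrinsic tile count.
std::vector<int64_t> distributeSeedToDims(std::vector<int64_t> totalTileCounts,
                                          int64_t seedBudget) {
  std::vector<int64_t> tileSizes(totalTileCounts.size(), 1);
  for (int64_t dim = static_cast<int64_t>(totalTileCounts.size()) - 1;
       dim >= 0; --dim) {
    tileSizes[dim] = std::gcd(totalTileCounts[dim], seedBudget);
    seedBudget /= tileSizes[dim];
  }
  return tileSizes;
}

// Example: K dims [K1=3, K0=128] with a 16-wide K intrinsic give per-dim
// intrinsic tile counts [3, 8]; a seed budget of 8 yields [1, 8]
// (all of K0, none of K1).
```

The real heuristic interleaves this with subgroup-count distribution for M and
N, but the innermost-first, divisor-based shape of the partitioning is the same.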
This PR also updates the config selection logic for the TileAndFuse
pipeline to make use of the multiple contraction dimensions in mma
schedules.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index dc30783..790484d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@
#include <cstdint>
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -20,51 +21,106 @@
namespace mlir::iree_compiler {
+template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
- const GPUMMASchedule &schedule) {
- os << "mSize: " << schedule.mSize << ", ";
- os << "nSize: " << schedule.nSize << ", ";
- os << "kSize: " << schedule.kSize << ", ";
- os << "mTileCount: " << schedule.mTileCount << ", ";
- os << "nTileCount: " << schedule.nTileCount << ", ";
- os << "kTileCount: " << schedule.kTileCount << ", ";
- os << "mWarpCount: " << schedule.mWarpCount << ", ";
- os << "nWarpCount: " << schedule.nWarpCount;
+ const llvm::SmallVectorImpl<T> &vector) {
+ os << "[";
+ llvm::interleaveComma(vector, os);
+ os << "]";
return os;
}
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const GPUMMASchedule &schedule) {
+ os << "mSizes: " << schedule.mSize << ", ";
+ os << "nSizes: " << schedule.nSize << ", ";
+ os << "kSizes: " << schedule.kSize << ", ";
+ os << "mTileSizes: " << schedule.mTileSizes << ", ";
+ os << "nTileSizes: " << schedule.nTileSizes << ", ";
+ os << "kTileSizes: " << schedule.kTileSizes << ", ";
+ os << "mSubgroupCounts: " << schedule.mSubgroupCounts << ", ";
+ os << "nSubgroupCounts: " << schedule.nSubgroupCounts;
+ return os;
+}
+
+// Shortened helper to compute the product of `values`.
+static int64_t prod(ArrayRef<int64_t> values) {
+ return ShapedType::getNumElements(values);
+}
+
static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
int64_t lhsBitwidth,
int64_t rhsBitwidth) {
- int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
- int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
- int64_t tileK = schedule.kSize * schedule.kTileCount;
+
+ int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) *
+ prod(schedule.mSubgroupCounts);
+ int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) *
+ prod(schedule.nSubgroupCounts);
+ int64_t tileK = schedule.kSize * prod(schedule.kTileSizes);
return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8;
}
+/// Check that a GPUMMASchedule fits alignment restrictions. To be aligned,
+/// the problem must be evenly divisible by the number of elements in the
+/// schedule for each dimension. If `mustBeAligned` is false, then the innermost
+/// problem dimension is allowed to be unaligned.
static bool isScheduleAligned(const GPUMatmulShapeType &problem,
const GPUMMASchedule &schedule,
bool mustBeAligned) {
- auto alignedMSize =
- mustBeAligned
- ? problem.mSize
- : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize;
- auto alignedNSize =
- mustBeAligned
- ? problem.nSize
- : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize;
- auto alignedKSize =
- mustBeAligned
- ? problem.kSize
- : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize;
- bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount *
- schedule.mWarpCount)) == 0;
- bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
- schedule.nWarpCount)) == 0;
- bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
+ SmallVector<int64_t> alignedMSizes(problem.mSizes);
+ alignedMSizes.back() =
+ mustBeAligned ? problem.mSizes.back()
+ : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) *
+ schedule.mSize;
+ SmallVector<int64_t> alignedNSizes(problem.nSizes);
+ alignedNSizes.back() =
+ mustBeAligned ? problem.nSizes.back()
+ : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) *
+ schedule.nSize;
+ SmallVector<int64_t> alignedKSizes(problem.kSizes);
+ alignedKSizes.back() =
+ mustBeAligned ? problem.kSizes.back()
+ : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) *
+ schedule.kSize;
+ // Returns the number of elements in the schedule for each dimension.
+ auto getScheduleSizes =
+ [&](int64_t size, SmallVector<int64_t> tileCount,
+ std::optional<SmallVector<int64_t>> subgroupCount) {
+ SmallVector<int64_t> sizes = llvm::map_to_vector(
+ llvm::seq<int64_t>(tileCount.size()), [&](int64_t i) {
+ return subgroupCount ? tileCount[i] * subgroupCount.value()[i]
+ : tileCount[i];
+ });
+ sizes.back() *= size;
+ return sizes;
+ };
+ // Checks whether the elements of `a` are evenly divisible by the
+ // corresponding elements of `b`.
+ auto areAligned = [](SmallVector<int64_t> a, SmallVector<int64_t> b) {
+ for (auto [aVal, bVal] : llvm::zip_equal(a, b)) {
+ if (aVal % bVal != 0) {
+ return false;
+ }
+ }
+ return true;
+ };
+ bool isValidM = areAligned(
+ alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes,
+ schedule.mSubgroupCounts));
+ bool isValidN = areAligned(
+ alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes,
+ schedule.nSubgroupCounts));
+ bool isValidK = areAligned(
+ alignedKSizes,
+ getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt));
return isValidM && isValidN && isValidK;
}
+/// Returns whether or not a GPUMMASchedule is valid for the given problem.
+/// This checks that:
+/// - The problem is aligned to the schedule
+/// - The number of threads in the schedule workgroup can be distributed
+/// to a corresponding vector.transfer read in VectorDistribute.
static bool isValidMMASchedule(const GPUMatmulShapeType &problem,
const GPUMMASchedule &schedule,
bool mustBeAligned, int64_t subgroupSize,
@@ -76,11 +132,13 @@
const int64_t kMaxVectorLoadBitWidth = 128;
int64_t elemsPerThread =
kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
- int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize;
-
- int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
- int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
- int64_t kWgSize = schedule.kSize * schedule.kTileCount;
+ int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) *
+ prod(schedule.nSubgroupCounts);
+ int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) *
+ prod(schedule.mSubgroupCounts);
+ int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) *
+ prod(schedule.nSubgroupCounts);
+ int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes);
int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;
@@ -94,6 +152,10 @@
return isAligned && isDistributableLhs && isDistributableRhs;
}
+/// Tries to fit the schedule into shared memory by decrementing the size of the
+/// schedule dimensions from outermost to innermost until a valid schedule is
+/// found. The schedule sizes are reduced in the order of mTileSizes,
+/// nTileSizes, kTileSizes, mSubgroupCounts, nSubgroupCounts.
static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
GPUMatmulShapeType intrinsic, GPUMMASchedule schedule,
llvm::function_ref<bool(const GPUMMASchedule &schedule)> isScheduleValid) {
@@ -105,31 +167,35 @@
llvm::dbgs() << "Shrinking schedule...\n";
});
- auto decrementIfPossible = [](int64_t &c) -> LogicalResult {
- if (c <= 1) {
- return failure();
+ auto decrementIfPossible =
+ [](SmallVector<int64_t> &sizes) -> LogicalResult {
+ for (int64_t &size : sizes) {
+ if (size <= 1)
+ continue;
+ --size;
+ return success();
}
- --c;
- return success();
+ return failure();
};
// Attempt to shrink the schedule along one of the dimensions.
// TODO: A better solution should probably factor problem.mSize /
- // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors
- // one at a time, preferably trying to keep the tile "generally square."
- if (succeeded(decrementIfPossible(schedule.mTileCount))) {
+ // (mSubgroupCount * mTileCount * mSize) and then pop off the smallest
+ // factors one at a time, preferably trying to keep the tile "generally
+ // square."
+ if (succeeded(decrementIfPossible(schedule.mTileSizes))) {
continue;
}
- if (succeeded(decrementIfPossible(schedule.nTileCount))) {
+ if (succeeded(decrementIfPossible(schedule.nTileSizes))) {
continue;
}
- if (succeeded(decrementIfPossible(schedule.kTileCount))) {
+ if (succeeded(decrementIfPossible(schedule.kTileSizes))) {
continue;
}
- if (succeeded(decrementIfPossible(schedule.mWarpCount))) {
+ if (succeeded(decrementIfPossible(schedule.mSubgroupCounts))) {
continue;
}
- if (succeeded(decrementIfPossible(schedule.nWarpCount))) {
+ if (succeeded(decrementIfPossible(schedule.nSubgroupCounts))) {
continue;
}
@@ -148,6 +214,9 @@
static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem,
const GPUMatmulShapeType &intrinsic,
bool canUpcastAcc, bool mustBeAligned) {
+ assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 &&
+ intrinsic.kSizes.size() == 1 &&
+ "expected intrinsic to have a single M, N, and K dimension.");
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
return failure(); // Cannot use this intrinsic for mismatched types
}
@@ -161,17 +230,17 @@
}
}
- if (mustBeAligned && (problem.mSize % intrinsic.mSize != 0 ||
- problem.nSize % intrinsic.nSize != 0 ||
- problem.kSize % intrinsic.kSize != 0)) {
+ if (mustBeAligned && (problem.mSizes.back() % intrinsic.mSizes[0] != 0 ||
+ problem.nSizes.back() % intrinsic.nSizes[0] != 0 ||
+ problem.kSizes.back() % intrinsic.kSizes[0] != 0)) {
return failure(); // Cannot use this intrinsic for misaligned cases.
}
// Cannot use the intrinsic when the tile size is greater than problem size.
// Because tiling is a no-op, and we can't infer tiling sizes from IR.
- if (!mustBeAligned &&
- (problem.mSize < intrinsic.mSize || problem.nSize < intrinsic.nSize ||
- problem.kSize < intrinsic.kSize)) {
+ if (!mustBeAligned && (problem.mSizes.back() < intrinsic.mSizes[0] ||
+ problem.nSizes.back() < intrinsic.nSizes[0] ||
+ problem.kSizes.back() < intrinsic.kSizes[0])) {
return failure();
}
@@ -185,77 +254,123 @@
const GPUMatmulShapeType &intrinsic,
const GPUMMAHeuristicSeeds &seeds,
uint64_t intrinsicIndex) {
- int64_t mTotalTileCount = llvm::divideCeil(problem.mSize, intrinsic.mSize);
- int64_t nTotalTileCount = llvm::divideCeil(problem.nSize, intrinsic.nSize);
+ assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 &&
+ intrinsic.kSizes.size() == 1 &&
+ "expected intrinsic to have a single M, N, and K dimension.");
+ // mTotalTileCounts and nTotalTileCounts represent the total number of
+ // intrinsics along the M or N dimensions needed to fill the problem size.
+ // For example, if the problem is {M:[4, 16], N:[2, 32], K:[3, 128]} for a
+ // 16x16x16 intrinsic, then:
+ // - mTotalTileCounts would be 4 * (16/16) = 4
+ // - nTotalTileCounts would be 2 * (32/16) = 4
+ SmallVector<int64_t> mTotalTileCounts = problem.mSizes;
+ SmallVector<int64_t> nTotalTileCounts = problem.nSizes;
+ mTotalTileCounts.back() =
+ llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
+ nTotalTileCounts.back() =
+ llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
- int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup;
+ int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
- // Assign more warps to the M dimension (used later) to balance thread
+ // Assign more subgroups to the M dimension (used later) to balance thread
// counts along X and Y dimensions.
- int64_t warpSqrt =
- 1ull << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2));
- int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+ int mDim = problem.mSizes.size() - 1;
+ int nDim = problem.nSizes.size() - 1;
+ SmallVector<int64_t> mTileSizes(problem.mSizes.size(), 0),
+ nTileSizes(problem.nSizes.size(), 0),
+ mSubgroupCounts(problem.mSizes.size(), 0),
+ nSubgroupCounts(problem.nSizes.size(), 0);
+ // Start at the innermost nDim and mDim, and try to distribute evenly to M and
+ // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+ while (mDim >= 0 || nDim >= 0) {
+ int64_t subgroupSqrt =
+ 1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
+ int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
- int64_t mWarpCount = 0, nWarpCount = 0;
- int64_t mTileCount = 0, nTileCount = 0;
+ // See if the square root can divide mTotalTileCount. If so it means we can
+ // distribute to both dimensions evenly to minimize the number of global
+ // loads. Otherwise, try to distribute to N and then M.
+ if (mDim >= 0 && nDim >= 0 &&
+ mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
+ mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
+ mSubgroupCounts[mDim] = subgroupSqrt;
+ mTileSizes[mDim] = tileSqrt;
- // See if the square root can divide mTotalTileCount. If so it means we can
- // distribute to both dimensions evenly. Otherwise, try to distribute to N
- // and then M.
- if (mTotalTileCount > (warpSqrt * tileSqrt) &&
- mTotalTileCount % (warpSqrt * tileSqrt) == 0) {
- mWarpCount = warpSqrt;
- mTileCount = tileSqrt;
+ remainingSubgroups /= subgroupSqrt;
+ remainingTiles /= tileSqrt;
- remainingWarps /= warpSqrt;
- remainingTiles /= tileSqrt;
+ APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+ APInt(64, remainingSubgroups));
+ nSubgroupCounts[nDim] = nGCD.getSExtValue();
+ nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
+ remainingSubgroups /= nSubgroupCounts[nDim];
- APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
- APInt(64, remainingWarps));
- nWarpCount = nGCD.getSExtValue();
- nTotalTileCount /= nWarpCount;
- remainingWarps /= nWarpCount;
+ nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+ APInt(64, remainingTiles));
+ nTileSizes[nDim] = nGCD.getSExtValue();
+ remainingTiles /= nTileSizes[nDim];
+ } else {
+ if (nDim >= 0) {
+ APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+ APInt(64, remainingSubgroups));
+ nSubgroupCounts[nDim] = nGCD.getSExtValue();
+ nTotalTileCounts[nDim] /= nSubgroupCounts[nDim];
+ remainingSubgroups /= nSubgroupCounts[nDim];
- nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
- APInt(64, remainingTiles));
- nTileCount = nGCD.getSExtValue();
- } else {
- APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
- APInt(64, remainingWarps));
- nWarpCount = nGCD.getSExtValue();
- nTotalTileCount /= nWarpCount;
- remainingWarps /= nWarpCount;
+ nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
+ APInt(64, remainingTiles));
+ nTileSizes[nDim] = nGCD.getSExtValue();
+ remainingTiles /= nTileSizes[nDim];
+ }
- nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
- APInt(64, remainingTiles));
- nTileCount = nGCD.getSExtValue();
- remainingTiles /= nTileCount;
+ if (mDim >= 0) {
+ APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
+ APInt(64, remainingSubgroups));
+ mSubgroupCounts[mDim] = mGCD.getSExtValue();
+ mTotalTileCounts[mDim] /= mSubgroupCounts[mDim];
+ remainingSubgroups /= mSubgroupCounts[mDim];
- APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
- APInt(64, remainingWarps));
- mWarpCount = mGCD.getSExtValue();
- mTotalTileCount /= mWarpCount;
- remainingWarps /= mWarpCount;
-
- mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
- APInt(64, remainingTiles));
- mTileCount = mGCD.getSExtValue();
+ mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
+ APInt(64, remainingTiles));
+ mTileSizes[mDim] = mGCD.getSExtValue();
+ remainingTiles /= mTileSizes[mDim];
+ }
+ }
+ --mDim;
+ --nDim;
}
- const uint64_t kTotalTileCount =
- llvm::divideCeil(problem.kSize, intrinsic.kSize);
+ // kTotalTileCounts is similar to m/nTotalTileCounts, representing the total
+ // number of intrinsics along the K dimensions needed to fill the problem.
+ // For the problem described above {M:[4, 16], N:[2, 32], K:[3, 128]} with a
+ // 16x16x16 intrinsic, then:
+ // - kTotalTileCounts would be 3 * (128/16) = 24
+ SmallVector<int64_t> kTotalTileCounts = problem.kSizes;
+ kTotalTileCounts.back() =
+ llvm::divideCeil(problem.kSizes.back(), intrinsic.kSizes[0]);
+ // Compute the ideal number of intrinsics along K per subgroup based on the
+ // seed.
int64_t bestKTileCountPerSubgroup =
seeds.bestKElementCountPerSubgroup
? llvm::divideCeil(seeds.bestKElementCountPerSubgroup,
- intrinsic.kSize)
+ intrinsic.kSizes[0])
: seeds.bestKTileCountPerSubgroup;
- APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount),
- APInt(64, bestKTileCountPerSubgroup));
- int64_t kTileCount = kGCD.getSExtValue();
+ SmallVector<int64_t> kTileSizes(problem.kSizes.size(), 0);
+ // Start at the innermost K dim, and tile each dim to try to satisfy the ideal
+ // K intrinsic count per subgroup with the overall product of K tile counts.
+ int kDim = problem.kSizes.size() - 1;
+ while (kDim >= 0) {
+ APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCounts[kDim]),
+ APInt(64, bestKTileCountPerSubgroup));
+ kTileSizes[kDim] = kGCD.getSExtValue();
+ bestKTileCountPerSubgroup /= kTileSizes[kDim];
+ --kDim;
+ }
- return GPUMMASchedule{intrinsicIndex, intrinsic.mSize, intrinsic.nSize,
- intrinsic.kSize, mWarpCount, nWarpCount,
- mTileCount, nTileCount, kTileCount};
+ return GPUMMASchedule{
+ intrinsicIndex, intrinsic.mSizes[0], intrinsic.nSizes[0],
+ intrinsic.kSizes[0], mSubgroupCounts, nSubgroupCounts,
+ mTileSizes, nTileSizes, kTileSizes};
}
FailureOr<GPUMMASchedule> deduceMMASchedule(
@@ -297,7 +412,6 @@
return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
};
-
return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule);
}
return failure();
@@ -309,7 +423,10 @@
const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV,
bool canUpcastAcc, bool mustBeAligned) {
-
+ assert(pvMatmul.mSizes.size() == 1 && pvMatmul.nSizes.size() == 1 &&
+ pvMatmul.kSizes.size() == 1 && qkMatmul.mSizes.size() == 1 &&
+ qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 &&
+ "unimplemented: multi M/N/K attention schedule");
for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
if (failed(canTargetIntrinsic(qkMatmul, intrinsic, canUpcastAcc,
mustBeAligned))) {
@@ -329,7 +446,7 @@
llvm::dbgs() << " " << schedule << "\n";
});
- int64_t intrinsicK = intrinsic.kSize;
+ int64_t intrinsicK = intrinsic.kSizes[0];
auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
// Create a mma schedule for qkMatmul in attention.
// qkMatmul.M = pvMatmul.M
@@ -339,11 +456,11 @@
schedule.mSize,
schedule.kSize,
intrinsicK,
- /*mWarpCount=*/schedule.mWarpCount,
- /*nWarpCount=*/1,
- schedule.mTileCount,
- schedule.kTileCount,
- qkMatmul.kSize / intrinsicK};
+ /*mSubgroupCount=*/schedule.mSubgroupCounts[0],
+ /*nSubgroupCount=*/1,
+ schedule.mTileSizes[0],
+ schedule.kTileSizes[0],
+ qkMatmul.kSizes[0] / intrinsicK};
bool isQKAligned =
isValidMMASchedule(qkMatmul, qkSchedule, mustBeAligned, subgroupSize,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index 8211443..13f6a56 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -10,15 +10,18 @@
/// Struct containing information about a matmul's shape and type.
struct GPUMatmulShapeType {
- int64_t mSize;
- int64_t nSize;
- int64_t kSize;
+ SmallVector<int64_t> mSizes;
+ SmallVector<int64_t> nSizes;
+ SmallVector<int64_t> kSizes;
Type aType;
Type bType;
Type cType;
GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c)
- : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {}
+ : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {}
+ GPUMatmulShapeType(SmallVector<int64_t> m, SmallVector<int64_t> n,
+ SmallVector<int64_t> k, Type a, Type b, Type c)
+ : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {}
};
/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
@@ -38,14 +41,42 @@
struct GPUMMASchedule {
// Index of the chosen intrinsic into the list of given MMA intrinsics
uint64_t index;
- int64_t mSize; // Native MMA size along M dimension
- int64_t nSize; // Native MMA size along N dimension
- int64_t kSize; // Native MMA size along K dimension
- int64_t mWarpCount; // Number of subgroups along M dimension
- int64_t nWarpCount; // Number of subgroups along N dimension
- int64_t mTileCount; // Number of tiles per subgroup along M dimension
- int64_t nTileCount; // Number of tiles per subgroup along N dimension
- int64_t kTileCount; // Number of tiles along K dimension
+ int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup.
+ int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup.
+ int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup.
+
+ // Number of subgroups along each M and N dimension.
+ SmallVector<int64_t> mSubgroupCounts;
+ SmallVector<int64_t> nSubgroupCounts;
+
+ // Tile sizes for each M, N, and K dimension. When there are multiple M, N,
+ // or K dimensions, the intrinsic sizes are targeted to the innermost
+ // dimension, and the outer dimensions can be thought of as unrolling factors
+ // along M, N, or K.
+ SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
+ SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
+ SmallVector<int64_t> kTileSizes; // K tile sizes.
+
+ // Constructor for multi M, N, K dim schedules.
+ GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+ int64_t kIntrinsicSize, SmallVector<int64_t> mSubgroupCounts,
+ SmallVector<int64_t> nSubgroupCounts,
+ SmallVector<int64_t> mTileSizes,
+ SmallVector<int64_t> nTileSizes,
+ SmallVector<int64_t> kTileSizes)
+ : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+ kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
+ nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
+ nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}
+
+ // Constructor for single M, N, K dim schedules.
+ GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+ int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup,
+ int64_t mTileSize, int64_t nTileSize, int64_t kTileSize)
+ : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+ kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}),
+ nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}),
+ nTileSizes({nTileSize}), kTileSizes({kTileSize}) {}
};
/// Returns a schedule for using one of the given MMA |intrinsics| to target the
@@ -69,4 +100,7 @@
bool transposedV = false, bool canUpcastAcc = false,
bool mustBeAligned = true);
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const GPUMMASchedule &schedule);
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index ca23b0c..58bfdc0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -124,20 +125,37 @@
return failure();
}
- // For now we are not being smart and trying to reshape dimensions to allow
- // for better usage of intrinsics, and instead are tiling all dimensions
- // except the inner most m, n, and k dimensions to 1.
- int64_t mDim = contractionDims.m.back();
- int64_t nDim = contractionDims.n.back();
- int64_t kDim = contractionDims.k.back();
-
- // Dynamic dims are expected to be taken care of earlier in the pipeline.
- if (ShapedType::isDynamic(bounds[mDim]) ||
- ShapedType::isDynamic(bounds[nDim]) ||
- ShapedType::isDynamic(bounds[kDim])) {
+ // TODO(Max191): add dynamic shape support for inner most dims.
+ if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) ||
+ ShapedType::isDynamic(bounds[contractionDims.n.back()]) ||
+ ShapedType::isDynamic(bounds[contractionDims.k.back()])) {
return failure();
}
+ // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
+ // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
+ // computing an MMA schedule.
+ SmallVector<int64_t> mDims, nDims, kDims;
+ for (auto mDim : contractionDims.m) {
+ if (!ShapedType::isDynamic(bounds[mDim])) {
+ mDims.push_back(mDim);
+ }
+ }
+ for (auto nDim : contractionDims.n) {
+ if (!ShapedType::isDynamic(bounds[nDim])) {
+ nDims.push_back(nDim);
+ }
+ }
+ for (auto kDim : contractionDims.k) {
+ if (!ShapedType::isDynamic(bounds[kDim])) {
+ kDims.push_back(kDim);
+ }
+ }
+
+ auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
+ return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; });
+ };
+
Value lhs = linalgOp.getDpsInputOperand(0)->get();
Value rhs = linalgOp.getDpsInputOperand(1)->get();
Value init = linalgOp.getDpsInitOperand(0)->get();
@@ -146,8 +164,9 @@
Type rhsElemType = getElementTypeOrSelf(rhs);
Type initElemType = getElementTypeOrSelf(init);
- GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
- lhsElemType, rhsElemType, initElemType};
+ GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims),
+ getDimBounds(kDims), lhsElemType,
+ rhsElemType, initElemType};
SmallVector<GPUMatmulShapeType> intrinsics;
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -166,7 +185,9 @@
// Note that the following heuristic seeds are just placeholder values.
// We need to clean it up and make it adjusting to different targets.
// See https://github.com/iree-org/iree/issues/16341 for details.
- if (problem.mSize * problem.nSize <= 512 * 512) {
+ int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+ int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+ if (mSize * nSize <= 512 * 512) {
// For matmuls with small M*N size, we want to distribute M*N onto more
// workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
// and a larger bestKTileCountPerSubgroup.
@@ -190,10 +211,10 @@
// TODO: Drop this. This is only a consideration for other pipelines.
SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
bool transposedLhs =
- kDim !=
+ kDims.back() !=
llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
bool transposedRhs =
- nDim !=
+ nDims.back() !=
llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
// First try to find a schedule with an exactly matching intrinsic.
@@ -213,16 +234,13 @@
}
LDBG("Target Subgroup size: " << targetSubgroupSize);
- LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
- << schedule->kSize << "]");
- LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
- << schedule->nTileCount << ", "
- << schedule->kTileCount << "]");
- LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
- << schedule->nWarpCount << "]");
+ LDBG("Schedule: " << schedule);
- std::array<int64_t, 3> workgroupSize{
- schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+ int64_t flatWorkgroupSize =
+ targetSubgroupSize *
+ ShapedType::getNumElements(schedule->nSubgroupCounts) *
+ ShapedType::getNumElements(schedule->mSubgroupCounts);
+ std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};
SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
@@ -244,16 +262,30 @@
reductionTileSizes[k] = 1;
}
- // Compute the M/N dimension tile size by multiplying subgroup information.
- workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
- workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;
+ // Adjust the inner bound size for packing to intrinsic shapes, since tiling
+ // happens after packing.
+ assert(bounds[mDims.back()] % schedule->mSize == 0 &&
+ bounds[nDims.back()] % schedule->nSize == 0 &&
+ "expected inner bound to be evenly divisible by schedule sizes.");
+ bounds[mDims.back()] /= schedule->mSize;
+ bounds[nDims.back()] /= schedule->nSize;
- // Specify the subgroup tile sizes from the mma schedule. This is applied
- subgroupTileSizes[mDim] = schedule->mTileCount;
- subgroupTileSizes[nDim] = schedule->nTileCount;
+ // Compute the M/N dimension tile sizes by multiplying subgroup information.
+ for (auto [i, mDim] : llvm::enumerate(mDims)) {
+ workgroupTileSizes[mDim] =
+ schedule->mSubgroupCounts[i] * schedule->mTileSizes[i];
+ subgroupTileSizes[mDim] = schedule->mTileSizes[i];
+ }
+ for (auto [i, nDim] : llvm::enumerate(nDims)) {
+ workgroupTileSizes[nDim] =
+ schedule->nSubgroupCounts[i] * schedule->nTileSizes[i];
+ subgroupTileSizes[nDim] = schedule->nTileSizes[i];
+ }
// Similarly the reduction tile size is just the post-packing tile count.
- reductionTileSizes[kDim] = schedule->kTileCount;
+ for (auto [i, kDim] : llvm::enumerate(kDims)) {
+ reductionTileSizes[kDim] = schedule->kTileSizes[i];
+ }
IREE::GPU::MmaInterfaceAttr mmaKind =
target.getWgp().getMma()[schedule->index];
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ff002ac..4b64cda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -301,6 +301,11 @@
Type rhsElemType = getElementTypeOrSelf(rhs);
Type initElemType = getElementTypeOrSelf(init);
+ // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+ // once the pipeline is able to support it. After adding multiple dimensions,
+ // all instances of schedule->m/nSubgroupCounts[0] and
+ // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+ // just the first element.
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
@@ -339,8 +344,9 @@
return failure();
}
- std::array<int64_t, 3> workgroupSize{
- schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+ std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+ targetSubgroupSize,
+ schedule->mSubgroupCounts[0], 1};
SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -360,11 +366,11 @@
}
// Compute the M/N dimension tile size by multiply subgroup information.
workgroupTileSizes[mDim] =
- schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+ schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
workgroupTileSizes[nDim] =
- schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+ schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
- reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+ reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
// Tile all filter loop dimensions to 1.
for (int64_t filterDim : convolutionDims->filterLoop) {
@@ -386,8 +392,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
- schedule->nWarpCount);
+ context, target.getWgp().getMma()[schedule->index],
+ schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
@@ -489,6 +495,11 @@
rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
}
+ // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+ // once the pipeline is able to support it. After adding multiple dimensions,
+ // all instances of schedule->m/nSubgroupCounts[0] and
+ // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+ // just the first element.
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
@@ -509,7 +520,7 @@
// Note that the following heuristic seeds are just placeholder values.
// We need to clean it up and make it adjusting to different targets.
// See https://github.com/iree-org/iree/issues/16341 for details.
- if (problem.mSize * problem.nSize <= clGPUMatmulCThreshold) {
+ if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) {
// For matmuls with small M*N size, we want to distribute M*N onto more
// workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
// and a larger bestKTileCountPerSubgroup.
@@ -573,16 +584,11 @@
}
LDBG("Target Subgroup size: " << targetSubgroupSize);
- LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
- << schedule->kSize << "]");
- LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
- << schedule->nTileCount << ", "
- << schedule->kTileCount << "]");
- LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
- << schedule->nWarpCount << "]");
+ LDBG("Schedule: " << schedule);
- std::array<int64_t, 3> workgroupSize{
- schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+ std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+ targetSubgroupSize,
+ schedule->mSubgroupCounts[0], 1};
SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -605,11 +611,11 @@
// Compute the M/N dimension tile size by multiply subgroup information.
workgroupTileSizes[mDim] =
- schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+ schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
workgroupTileSizes[nDim] =
- schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+ schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
- reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+ reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize;
LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
*contractionDims, workgroupTileSizes));
@@ -631,8 +637,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
- schedule->nWarpCount);
+ context, target.getWgp().getMma()[schedule->index],
+ schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
@@ -772,22 +778,17 @@
// TODO: Due to a bug in layout configuration, we cannot set warp count on
// the N dimension. This is however ok, because we generally do not want to
// distribute subgroups on N dimension anyway.
- if (schedule->nWarpCount != 1) {
- schedule->nTileCount *= schedule->nWarpCount;
- schedule->nWarpCount = 1;
+ if (schedule->nSubgroupCounts[0] != 1) {
+ schedule->nTileSizes[0] *= schedule->nSubgroupCounts[0];
+ schedule->nSubgroupCounts[0] = 1;
}
LDBG("Target Subgroup size: " << targetSubgroupSize);
- LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
- << schedule->kSize << "]");
- LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
- << schedule->nTileCount << ", "
- << schedule->kTileCount << "]");
- LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
- << schedule->nWarpCount << "]");
+ LDBG("Schedule: " << schedule);
- std::array<int64_t, 3> workgroupSize{
- schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+ std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+ targetSubgroupSize,
+ schedule->mSubgroupCounts[0], 1};
SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
@@ -811,11 +812,11 @@
// Compute the M/N dimension tile size by multiply subgroup information.
workgroupTileSizes[mDim] =
- schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
+ schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize;
workgroupTileSizes[nDim] =
- schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+ schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize;
- reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;
+ reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize;
MLIRContext *context = op.getContext();
SmallVector<NamedAttribute, 2> attrs;
@@ -831,8 +832,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
- schedule->nWarpCount);
+ context, target.getWgp().getMma()[schedule->index],
+ schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index b98e85a..819b882 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -37,11 +37,79 @@
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
-// CHECK-SAME: subgroup = [0, 0, 4, 1, 0]
+// CHECK-SAME: subgroup = [1, 1, 4, 1, 0]
// CHECK-SAME: workgroup = [1, 1, 4, 4, 0]
// -----
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %5 = tensor.empty() : tensor<10x4x32x32xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+ %7 = linalg.generic {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %8 = arith.mulf %in, %in_0 : f16
+ %9 = arith.addf %8, %out : f16
+ linalg.yield %9 : f16
+ } -> tensor<10x4x32x32xf16>
+ return %7 : tensor<10x4x32x32xf16>
+}
+
+// CHECK-LABEL: func.func @multi_dim_mma_schedule
+// CHECK-SAME: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
+
+// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
+// CHECK-SAME: promote_operands = [0, 1]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
+// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0]
+// CHECK-SAME: workgroup = [2, 2, 2, 2, 0, 0]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %d0 = tensor.dim %lhs, %c0 : tensor<?x6x16x?x16xf16>
+ %d2 = tensor.dim %rhs, %c0 : tensor<?x32x?x16xf16>
+ %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<?x6x?x16x32xf16>) -> tensor<?x6x?x16x32xf16>
+ %7 = linalg.generic {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %8 = arith.mulf %in, %in_0 : f16
+ %9 = arith.addf %8, %out : f16
+ linalg.yield %9 : f16
+ } -> tensor<?x6x?x16x32xf16>
+ return %7 : tensor<?x6x?x16x32xf16>
+}
+
+// CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule
+// CHECK-SAME: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
+
+// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
+// CHECK-SAME: promote_operands = [0, 1]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 1, 1]
+// CHECK-SAME: subgroup = [0, 1, 0, 1, 1, 0, 0]
+// CHECK-SAME: workgroup = [1, 2, 1, 1, 2, 0, 0]
+
+// -----
+
func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>) -> tensor<1024x1024xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
@@ -52,7 +120,7 @@
}
// CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
-// CHECK-SAME: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [128, 2, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>
// Verify that the fill does not have the lowering config propagated to it.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
index 1fa2bae..9618281 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
@@ -33,4 +33,4 @@
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<
// CHECK-SAME: {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
// CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8],
-// CHECK-SAME: subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
+// CHECK-SAME: subgroup = [1, 1, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 16a1acf..bbdec5c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -884,6 +884,11 @@
Type lhsElem = getElementType(lhs);
Type rhsElem = getElementType(rhs);
Type initElem = getElementType(init);
+ // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules
+ // once the pipeline is able to support it. After adding multiple dimensions,
+ // all instances of schedule->m/nSubgroupCounts[0] and
+ // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of
+ // just the first element.
GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem);
SmallVector<GPUMatmulShapeType> intrinsics;
@@ -921,8 +926,9 @@
auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
- std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * subgroupSize,
- schedule->mWarpCount, 1};
+ std::array<int64_t, 3> workgroupSize{schedule->nSubgroupCounts[0] *
+ subgroupSize,
+ schedule->mSubgroupCounts[0], 1};
SmallVector<int64_t> vectorSizes(kIndex + 1, 0);
if (isBM)
@@ -934,21 +940,23 @@
SmallVector<int64_t> subgroupTileSizes(lastParallelDim + 1, 0);
if (isBM)
subgroupTileSizes[bIndex] = 1;
- subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex];
- subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex];
+ subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex];
+ subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex];
SmallVector<int64_t> workgroupTileSizes(lastParallelDim + 1, 0);
if (isBM)
workgroupTileSizes[bIndex] = 1;
- workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex];
- workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex];
+ workgroupTileSizes[mIndex] =
+ schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex];
+ workgroupTileSizes[nIndex] =
+ schedule->nSubgroupCounts[0] * subgroupTileSizes[nIndex];
// Also create one level for reduction. This is needed because of
// SPIRVTileAndPromotePass requires it.
// TODO(#10499): Consolidate tiling configuration across different pipelines.
SmallVector<int64_t> reductionTileSizes;
reductionTileSizes.append(kIndex, 0);
- reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize);
+ reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize);
TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes,
reductionTileSizes, vectorSizes};
@@ -956,7 +964,7 @@
// Don't do multibuffering if the inner reduction loop is folded out.
auto pipelineDepth = softwarePipelineDepth;
auto storeStage = softwarePipelineStoreStage;
- if (schedule->kTileCount <= 1) {
+ if (schedule->kTileSizes[0] <= 1) {
pipelineDepth = 0;
storeStage = 0;
}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index ba415b3..922e508 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -242,16 +242,16 @@
return llvm::divideCeil(value, padTo) * padTo - value;
};
- if (mSize % intrinsic.mSize != 0) {
- mPadding = getPadding(mSize, intrinsic.mSize);
+ if (mSize % intrinsic.mSizes[0] != 0) {
+ mPadding = getPadding(mSize, intrinsic.mSizes[0]);
}
- if (nSize % intrinsic.nSize != 0) {
- nPadding = getPadding(nSize, intrinsic.nSize);
+ if (nSize % intrinsic.nSizes[0] != 0) {
+ nPadding = getPadding(nSize, intrinsic.nSizes[0]);
}
- if (kSize % intrinsic.kSize != 0) {
- kPadding = getPadding(kSize, intrinsic.kSize);
+ if (kSize % intrinsic.kSizes[0] != 0) {
+ kPadding = getPadding(kSize, intrinsic.kSizes[0]);
}
if (!mPadding && !nPadding && !kPadding) {
@@ -381,7 +381,7 @@
for (GPUMatmulShapeType &intrinsic : intrinsics) {
std::optional<OpFoldResult> mPadding, nPadding, kPadding;
SmallVector<std::pair<int64_t, int64_t>> dimsToExpandCandidate;
- if (mSize % intrinsic.mSize != 0 || ShapedType::isDynamic(mSize)) {
+ if (mSize % intrinsic.mSizes[0] != 0 || ShapedType::isDynamic(mSize)) {
OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize);
if (ShapedType::isDynamic(mSize)) {
auto mOperandDimPair = getSrcOperandAndDim(mDim);
@@ -390,12 +390,12 @@
auto [mOperand, mOperandDim] = mOperandDimPair.value();
mSizeExpr = rewriter.create<tensor::DimOp>(loc, mOperand, mOperandDim)
.getResult();
- dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSize);
+ dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSizes[0]);
}
- mPadding = getPadding(mSizeExpr, intrinsic.mSize);
+ mPadding = getPadding(mSizeExpr, intrinsic.mSizes[0]);
}
- if (nSize % intrinsic.nSize != 0 || ShapedType::isDynamic(nSize)) {
+ if (nSize % intrinsic.nSizes[0] != 0 || ShapedType::isDynamic(nSize)) {
OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize);
if (ShapedType::isDynamic(nSize)) {
auto nOperandDimPair = getSrcOperandAndDim(nDim);
@@ -404,12 +404,12 @@
auto [nOperand, nOperandDim] = nOperandDimPair.value();
nSizeExpr = rewriter.create<tensor::DimOp>(loc, nOperand, nOperandDim)
.getResult();
- dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSize);
+ dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSizes[0]);
}
- nPadding = getPadding(nSizeExpr, intrinsic.nSize);
+ nPadding = getPadding(nSizeExpr, intrinsic.nSizes[0]);
}
- if (kSize % intrinsic.kSize != 0 || ShapedType::isDynamic(kSize)) {
+ if (kSize % intrinsic.kSizes[0] != 0 || ShapedType::isDynamic(kSize)) {
OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize);
if (ShapedType::isDynamic(kSize)) {
auto kOperandDimPair = getSrcOperandAndDim(kDim);
@@ -418,9 +418,9 @@
auto [kOperand, kOperandDim] = kOperandDimPair.value();
kSizeExpr = rewriter.create<tensor::DimOp>(loc, kOperand, kOperandDim)
.getResult();
- dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSize);
+ dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSizes[0]);
}
- kPadding = getPadding(kSizeExpr, intrinsic.kSize);
+ kPadding = getPadding(kSizeExpr, intrinsic.kSizes[0]);
}
if (!mPadding && !nPadding && !kPadding) {