Revert "[VectorDistribute] Refactor layout configuration to a simpler logic" (#21887)
CI is failing. The PkgCI sharktank tests are not required on pre-submit
(but they should be).
Reverts iree-org/iree#21883
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 98c0f5f..83768db 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -29,7 +29,6 @@
#define GEN_PASS_DEF_LLVMGPUCONFIGURETENSORLAYOUTSPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
-using IREE::GPU::MMASingleSubgroupLayout;
using IREE::VectorExt::NestedLayoutAttr;
using IREE::VectorExt::ToLayoutOp;
using IREE::VectorExt::VectorLayoutInterface;
@@ -80,123 +79,81 @@
return getSubgroupNCount(config).value();
}
-struct ContractionLayout {
- VectorLayoutInterface lhs;
- VectorLayoutInterface rhs;
- VectorLayoutInterface acc;
-};
-
-// Get the layouts to use for the contraction given the intrinsic to use and
-// the number of subgroups on the M and N dimensions.
-//
-// The contraction is expected to have 3 operands: the lhs and rhs inputs and a
-// single accumulator (acc).
-static FailureOr<ContractionLayout>
-getContractionLayout(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
- int64_t subgroupMCount, int64_t subgroupNCount,
- ArrayRef<int64_t> bounds,
- ArrayRef<AffineMap> contractIndexingMaps) {
- auto mmaIntrinsic = dyn_cast<IREE::GPU::MmaInterfaceAttr>(intrinsic);
- if (!mmaIntrinsic) {
- return failure();
- }
-
- int64_t rank = bounds.size();
- FailureOr<VectorContractOpInfo> maybeOpInfo =
- VectorContractOpInfo::inferFromIndexingMaps(contractIndexingMaps);
- if (failed(maybeOpInfo)) {
- return failure();
- }
- VectorContractOpInfo opInfo = maybeOpInfo.value();
- // Get the inner dimensions.
- int64_t innerMDim = opInfo.getMDims().back();
- int64_t innerNDim = opInfo.getNDims().back();
- int64_t innerKDim = opInfo.getKDims().back();
- // Find the number of subgroups being used from subgroupMCount/subgroupNCount.
- // Just assign them to the innermost dimension for now.
- // TODO: Use subgroup_basis to get this instead and allow distributing
- // subgroups on multiple dimensions.
- SmallVector<int64_t> subgroupCounts(rank, 1);
- SmallVector<int64_t> subgroupStrides(rank, 0);
- subgroupCounts[innerMDim] = subgroupMCount;
- subgroupCounts[innerNDim] = subgroupNCount;
- // Distribute on M and then N.
- // TODO: Use subgroup_basis to get the strides.
- subgroupStrides[innerMDim] = 1;
- subgroupStrides[innerNDim] = subgroupMCount;
-
- // Since these MMA intrinsics have a given tile size for each subgroup, we can
- // calculate the batch dimensions without looking at the subgroup layout.
- SmallVector<int64_t> subgroupSize(rank, 1);
- auto [mSize, nSize, kSize] = mmaIntrinsic.getMNKShape();
- subgroupSize[innerMDim] = mSize;
- subgroupSize[innerNDim] = nSize;
- subgroupSize[innerKDim] = kSize;
-
- SmallVector<int64_t> batchCounts(rank);
- for (auto [batchCount, subgroupCount, subgroupSize, bound] :
- llvm::zip_equal(batchCounts, subgroupCounts, subgroupSize, bounds)) {
- int64_t workgroupDimSize = subgroupCount * subgroupSize;
- batchCount = llvm::divideCeil(bound, workgroupDimSize);
- }
-
- // MMA intrinsics can be weird and usually don't have a single subgroup
- // iteration space, so we need to find the subgroup iteration space of each
- // fragment individually.
- auto getFragmentLayout = [&](IREE::GPU::MMAFragment fragment,
- int64_t outerDim, int64_t innerDim,
- AffineMap map) -> VectorLayoutInterface {
- // Note that the struct MMASingleSubgroupLayout contains the partial layout
- // for the canonical (M, K) x (K, N) -> (M, N) matmul form. We treat the
- // concrete nested layout as the layout for the innermost M, N, K
- // dimensions.
- SmallVector<int64_t> outerCounts(rank, 1);
- SmallVector<int64_t> elementCounts(rank, 1);
- SmallVector<int64_t> threadCounts(rank, 1);
- SmallVector<int64_t> threadStrides(rank, 0);
-
- MMASingleSubgroupLayout subgroupLayout =
- IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
- outerCounts[outerDim] = subgroupLayout.outer[0];
- outerCounts[innerDim] = subgroupLayout.outer[1];
- threadCounts[outerDim] = subgroupLayout.thread[0];
- threadCounts[innerDim] = subgroupLayout.thread[1];
- threadStrides[outerDim] = subgroupLayout.tstrides[0];
- threadStrides[innerDim] = subgroupLayout.tstrides[1];
- elementCounts[outerDim] = subgroupLayout.element[0];
- elementCounts[innerDim] = subgroupLayout.element[1];
- // Get the fragment layout for the entire iteration space and then project
- // it. This is significantly easier than trying to create a layout for each
- // fragment itself.
- auto fragmentSpaceLayout = NestedLayoutAttr::get(
- map.getContext(), subgroupCounts, batchCounts, outerCounts,
- threadCounts, elementCounts, subgroupStrides, threadStrides);
- return fragmentSpaceLayout.apply(map);
- };
-
- VectorLayoutInterface lhs =
- getFragmentLayout(IREE::GPU::MMAFragment::Lhs, innerMDim, innerKDim,
- contractIndexingMaps[0]);
- VectorLayoutInterface rhs =
- getFragmentLayout(IREE::GPU::MMAFragment::Rhs, innerKDim, innerNDim,
- contractIndexingMaps[1]);
- VectorLayoutInterface acc =
- getFragmentLayout(IREE::GPU::MMAFragment::Acc, innerMDim, innerNDim,
- contractIndexingMaps[2]);
-
- return ContractionLayout{lhs, rhs, acc};
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+ ArrayRef<int64_t> counts,
+ int64_t dim0, int64_t dim1) {
+ assert(counts.size() == 2 &&
+ "Unexpected non-rank 2 single subgroup dimension counts");
+ SmallVector<int64_t> res(rank, 1);
+ res[dim0] = counts[0];
+ res[dim1] = counts[1];
+ return res;
}
-SmallVector<int64_t> getIterationSpaceBounds(linalg::LinalgOp linalgOp) {
- SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
- std::optional<VectorizationTileSizes> sizes =
- inferSizesFromIR(linalgOp, std::nullopt);
- // Even though the op's shape could be dynamic, we may still be able to
- // infer static vector sizes.
- if (sizes.has_value()) {
- bounds = sizes.value().vectorSizes;
- }
- return bounds;
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+/// outerDim = 1 for the K dim
+/// innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+/// outerDim = 1 for K
+/// innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here, "innermost" refers to the iteration order, not the order in which the
+/// dimensions appear per fragment (because there is no relationship between
+/// the dimension order of M in A and in C, for example).
+static NestedLayoutAttr createNestedLayout(
+ MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+ ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+ ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Creating Nested Layout for::" << "\n outerDim = "
+ << outerDim << "\n innerDim = " << innerDim
+ << "\n subgroupSizes: " << llvm::interleaved(subgroupSizes)
+ << "\n subgroupStrides: "
+ << llvm::interleaved(subgroupStrides)
+ << "\n batchCount: " << llvm::interleaved(batchCount)
+ << "\n counts.outer: " << llvm::interleaved(counts.outer)
+ << "\n counts.thread: " << llvm::interleaved(counts.thread)
+ << "\n counts.element: "
+ << llvm::interleaved(counts.element)
+ << "\n counts.tstrides: "
+ << llvm::interleaved(counts.tstrides) << "\n";
+ });
+
+ SmallVector<int64_t> outerCount =
+ getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+ SmallVector<int64_t> threadCount =
+ getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+ SmallVector<int64_t> threadStrides =
+ getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+ SmallVector<int64_t> elementCount =
+ getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+ auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+ outerCount, threadCount, elementCount,
+ subgroupStrides, threadStrides);
+ return layoutAttr;
}
static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
@@ -212,10 +169,12 @@
assert(linalg::isaContractionOpInterface(contract) &&
"cannot set contraction anchor on non contraction op");
- SmallVector<int64_t> bounds = getIterationSpaceBounds(contract);
- auto layouts = getContractionLayout(
- schedule.getIntrinsic(), schedule.getSubgroupMCount(),
- schedule.getSubgroupNCount(), bounds, contract.getIndexingMapsArray());
+ FailureOr<VectorContractOpInfo> opInfo =
+ VectorContractOpInfo::inferFromIndexingMaps(
+ contract.getIndexingMapsArray());
+ assert(succeeded(opInfo) && "contraction should have been inferred");
+
+ auto layouts = getContractionLayout(schedule, opInfo.value(), contract);
if (failed(layouts)) {
return contract->emitError("cannot get concrete layout for contraction");
}
@@ -297,10 +256,15 @@
map = projectDims(map, filterDims, /*compressDimsFlag=*/false);
}
- SmallVector<int64_t> bounds = getIterationSpaceBounds(conv);
- auto layouts = getContractionLayout(
- schedule.getIntrinsic(), schedule.getSubgroupMCount(),
- schedule.getSubgroupNCount(), bounds, maps);
+ FailureOr<VectorContractOpInfo> opInfo =
+ VectorContractOpInfo::inferFromIndexingMaps(maps);
+ assert(succeeded(opInfo) &&
+ "unit filter dim convolution should have been infered");
+
+ auto layouts = getContractionLayout(schedule, opInfo.value(), conv);
+ if (failed(layouts)) {
+ return conv->emitError("cannot get concrete layout for convolution");
+ }
auto [aLayout, bLayout, cLayout] = *layouts;
Location loc = conv.getLoc();
@@ -492,6 +456,26 @@
pvMatmul);
}
+// Apply the permuted projection map to the layout.
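+// For example (illustrative map, not taken from a specific op): given a rank-3
+// layout and the map (d0, d1, d2) -> (d2, d0), the unused dim d1 is projected
+// out of the layout and the remaining dims are permuted so the result follows
+// the (d2, d0) order.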
+static IREE::VectorExt::VectorLayoutInterface
+getLayoutForMap(VectorLayoutInterface layout, AffineMap map) {
+ // Project out unusued dims in layout.
+ SmallVector<bool> projectedDims(layout.getRank(), false);
+ llvm::SmallBitVector unusedBits = getUnusedDimsBitVector(map);
+ for (int dim : unusedBits.set_bits()) {
+ projectedDims[dim] = true;
+ }
+ IREE::VectorExt::VectorLayoutInterface projectedLayout =
+ layout.project(projectedDims);
+
+ // Transpose dims in layout.
+ AffineMap permMap = compressUnusedDims(map);
+ SmallVector<int64_t> identity =
+ llvm::to_vector(llvm::seq<int64_t>(permMap.getNumDims()));
+ SmallVector<int64_t> perm = applyPermutationMap<int64_t>(permMap, identity);
+ return projectedLayout.permute(perm);
+}
+
static LogicalResult setDerivedThreadConfigLayout(
IREE::GPU::DerivedThreadConfigAttr config, linalg::LinalgOp linalgOp,
ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
@@ -568,7 +552,7 @@
rewriter.setInsertionPointAfter(linalgOp);
for (OpResult result : linalgOp->getResults()) {
VectorLayoutInterface resultLayout =
- layout.apply(linalgOp.getIndexingMapMatchingResult(result));
+ getLayoutForMap(layout, linalgOp.getIndexingMapMatchingResult(result));
auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
}
@@ -691,7 +675,14 @@
MLIRContext *context = config.getContext();
Location loc = candidate.getLoc();
- SmallVector<int64_t> bounds = getIterationSpaceBounds(candidate);
+ SmallVector<int64_t> bounds = candidate.getStaticLoopRanges();
+ std::optional<VectorizationTileSizes> sizes =
+ inferSizesFromIR(candidate, std::nullopt);
+ // Even though the op's shape could be dynamic, we may still be able to
+ // infer static vector sizes.
+ if (sizes.has_value()) {
+ bounds = sizes.value().vectorSizes;
+ }
// Subgroup distribution layouts.
SmallVector<int64_t> subgroupSizes, subgroupStrides;
@@ -732,7 +723,7 @@
rewriter.setInsertionPoint(candidate);
for (OpOperand &operand : candidate->getOpOperands()) {
VectorLayoutInterface operandLayout =
- layout.apply(candidate.getMatchingIndexingMap(&operand));
+ getLayoutForMap(layout, candidate.getMatchingIndexingMap(&operand));
auto toLayout =
rewriter.create<ToLayoutOp>(loc, operand.get(), operandLayout);
// Set shared memory promotion if requested.
@@ -744,7 +735,7 @@
rewriter.setInsertionPointAfter(candidate);
for (OpResult result : candidate->getResults()) {
VectorLayoutInterface resultLayout =
- layout.apply(candidate.getIndexingMapMatchingResult(result));
+ getLayoutForMap(layout, candidate.getIndexingMapMatchingResult(result));
auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
index 6cc7285..d8b9fc4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
@@ -415,4 +415,335 @@
packAllocs(builder, funcOp, aliasGroups);
}
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+ ArrayRef<int64_t> counts,
+ int64_t dim0, int64_t dim1) {
+ assert(counts.size() == 2 &&
+ "Unexpected non-rank 2 single subgroup dimension counts");
+ SmallVector<int64_t> res(rank, 1);
+ res[dim0] = counts[0];
+ res[dim1] = counts[1];
+ return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+/// outerDim = 1 for the K dim
+/// innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+/// outerDim = 1 for K
+/// innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here, "innermost" refers to the iteration order, not the order in which the
+/// dimensions appear per fragment (because there is no relationship between
+/// the dimension order of M in A and in C, for example).
+static NestedLayoutAttr createNestedLayout(
+ MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+ ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+ ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Creating Nested Layout for::";
+ llvm::dbgs() << "\n outerDim = " << outerDim;
+ llvm::dbgs() << "\n innerDim = " << innerDim;
+ llvm::dbgs() << "\n subgroupSizes: ";
+ llvm::interleaveComma(subgroupSizes, llvm::dbgs());
+ llvm::dbgs() << "\n subgroupStrides: ";
+ llvm::interleaveComma(subgroupStrides, llvm::dbgs());
+ llvm::dbgs() << "\n batchCount: ";
+ llvm::interleaveComma(batchCount, llvm::dbgs());
+ llvm::dbgs() << "\n counts.outer: ";
+ llvm::interleaveComma(counts.outer, llvm::dbgs());
+ llvm::dbgs() << "\n counts.thread: ";
+ llvm::interleaveComma(counts.thread, llvm::dbgs());
+ llvm::dbgs() << "\n counts.element: ";
+ llvm::interleaveComma(counts.element, llvm::dbgs());
+ llvm::dbgs() << "\n counts.tstrides: ";
+ llvm::interleaveComma(counts.tstrides, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ SmallVector<int64_t> outerCount =
+ getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+ SmallVector<int64_t> threadCount =
+ getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+ SmallVector<int64_t> threadStrides =
+ getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+ SmallVector<int64_t> elementCount =
+ getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+ auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+ outerCount, threadCount, elementCount,
+ subgroupStrides, threadStrides);
+ return layoutAttr;
+}
+
+template <typename ContractOpTy>
+static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayoutImpl(IREE::GPU::MMAScheduleAttr schedule,
+ VectorContractOpInfo &opInfo,
+ ContractOpTy contractOp) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
+ llvm::dbgs() << "For schedule: " << schedule << "\n";
+ });
+
+ int64_t rank = contractOp.getIteratorTypesArray().size();
+ auto mmaAttr =
+ llvm::dyn_cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
+ if (!mmaAttr) {
+ return failure();
+ }
+
+ MLIRContext *context = schedule.getContext();
+
+ SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
+ if (llvm::any_of(bounds,
+ [](int64_t x) { return x == ShapedType::kDynamic; })) {
+ return failure();
+ }
+
+ if (!llvm::all_of(opInfo.getBatchDims(),
+ [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
+ LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
+ return failure();
+ }
+
+ // Get the concrete nested layout for each matrix. Note that the struct
+ // MMASingleSubgroupLayout contains the partial layout for the canonical
+ // (M, K) x (K, N) -> (M, N) matmul form, while the specific contract op we
+ // are looking at right now may not be exactly in that form. So here we need
+ // to permute/transpose the canonical layout to match the concrete contract
+ // op.
+
+ // Note that no matter how we permute/transpose the input contraction
+ // problem, the way we view the hardware warps remains the same--that is,
+ // from the hardware's perspective, a single warp has the same warp ID no
+ // matter what part of the contraction it works on. Similarly here, we are
+ // delinearizing the linearized GPU hardware lane ID into an n-D concatenated
+ // logical warp+thread using the subgroup/thread basis, so the subgroup basis
+ // should remain the same for all of the A/B/C matrices.
+
+ auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
+
+ SmallVector<int64_t, 2> subgroupMBasis;
+ SmallVector<int64_t, 2> batchMSizes;
+ int64_t currMCount = schedule.getSubgroupMCount();
+
+ auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
+ int64_t minDimSize) -> std::pair<int64_t, int64_t> {
+ int64_t dividableDim = dimSize / minDimSize;
+ int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
+ dividableDim /= subgroupsUsed;
+ int64_t batchesUsed = dividableDim;
+ return {subgroupsUsed, batchesUsed};
+ };
+
+ // Greedily break up the M subgroup and batch counts along the "M" iteration
+ // bounds. For each M dim we distribute as many of the remaining subgroups as
+ // possible and assign the leftover extent to the batch dims. The innermost M
+ // dim is always the one used for the intrinsic, meaning that for a valid
+ // schedule the computed batch counts and subgroup basis satisfy
+ // totalMSize / intrinsicM = product(batchMSizes) * product(subgroupMBasis).
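+ // Worked example (illustrative numbers, not taken from a real schedule):
+ // with an M bound of 128, intrinsicM = 16, and subgroupMCount = 2, the
+ // innermost M dim gives dividableDim = 128 / 16 = 8, subgroupsUsed =
+ // gcd(2, 8) = 2, and batchesUsed = 8 / 2 = 4, so 128 / 16 = 4 * 2 holds.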
+ for (auto dim : opInfo.getMDims()) {
+ // Get the number of subgroups and batches used for this dimension based
+ // on the intrinsic size and the bound size.
+ int64_t subgroupsUsed, batchesUsed;
+ if (dim == opInfo.getMDims().back()) {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currMCount, bounds[dim], intrinsicM);
+ } else {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currMCount, bounds[dim], 1);
+ }
+ subgroupMBasis.push_back(subgroupsUsed);
+ batchMSizes.push_back(batchesUsed);
+ // Update available subgroup count.
+ currMCount /= subgroupsUsed;
+ }
+
+ SmallVector<int64_t, 2> subgroupNBasis;
+ SmallVector<int64_t, 2> batchNSizes;
+ int64_t currNCount = schedule.getSubgroupNCount();
+
+ // Do the same for N dims.
+ for (auto dim : opInfo.getNDims()) {
+ // Get the number of subgroups and batches used for this dimension based
+ // on the intrinsic size and the bound size.
+ int64_t subgroupsUsed, batchesUsed;
+ if (dim == opInfo.getNDims().back()) {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currNCount, bounds[dim], intrinsicN);
+ } else {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currNCount, bounds[dim], 1);
+ }
+ subgroupNBasis.push_back(subgroupsUsed);
+ batchNSizes.push_back(batchesUsed);
+ // Update available subgroup count.
+ currNCount /= subgroupsUsed;
+ }
+
+ SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
+ SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
+
+ auto mDimVec = opInfo.getMDims();
+ llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+ auto nDimVec = opInfo.getNDims();
+ llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+ // Because we currently require all batch dimensions to be unit, the
+ // subgroup basis can be constructed from the M and N bases. To keep things
+ // simple, the current heuristic is to distribute the loop dimensions from
+ // outer to inner.
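+ // Worked example (illustrative, hypothetical shapes): for iteration dims
+ // (bm, bn, m, n, k) with mDims = {0, 2}, nDims = {1, 3}, and
+ // subgroupMBasis = subgroupNBasis = [2, 1], walking the dims from innermost
+ // to outermost assigns strides n -> 1, m -> 1, bn -> 1, bm -> 2, i.e.
+ // subgroupMStrides = [2, 1] and subgroupNStrides = [1, 1].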
+ int64_t currStride = 1;
+ int64_t currM = subgroupMStrides.size() - 1;
+ int64_t currN = subgroupNStrides.size() - 1;
+ for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
+ if (mDims.contains(dim)) {
+ subgroupMStrides[currM] = currStride;
+ currStride *= subgroupMBasis[currM];
+ currM--;
+ continue;
+ }
+
+ if (nDims.contains(dim)) {
+ subgroupNStrides[currN] = currStride;
+ currStride *= subgroupNBasis[currN];
+ currN--;
+ continue;
+ }
+ }
+
+ // C matrix layout
+ auto [m, n] = opInfo.getResultMNIndex();
+ int64_t cRank = opInfo.getCRank();
+
+ // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+ // cNDims are the M and N dimensions of the C matrix in the order they are
+ // iterated over in the contraction.
+ SmallVector<int64_t> cMDims = opInfo.outMDims;
+ SmallVector<int64_t> cNDims = opInfo.outNDims;
+ SmallVector<int64_t> cBatchSizes(cRank, 1);
+ SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+ SmallVector<int64_t> cSubgroupStrides(cRank, 0);
+ for (auto [i, dim] : llvm::enumerate(cMDims)) {
+ cBatchSizes[dim] = batchMSizes[i];
+ cSubgroupSizes[dim] = subgroupMBasis[i];
+ cSubgroupStrides[dim] = subgroupMStrides[i];
+ }
+ for (auto [i, dim] : llvm::enumerate(cNDims)) {
+ cBatchSizes[dim] = batchNSizes[i];
+ cSubgroupSizes[dim] = subgroupNBasis[i];
+ cSubgroupStrides[dim] = subgroupNStrides[i];
+ }
+
+ IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
+ context, cRank, m, n,
+ /*subgroupCount=*/cSubgroupSizes,
+ /*subgroupStrides=*/cSubgroupStrides,
+ /*batchCount=*/cBatchSizes,
+ getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
+ LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
+
+ // A matrix layout
+ auto [afm, bfn] = opInfo.getOperandMNIndex();
+ auto [afk, bfk] = opInfo.getOperandKIndex();
+
+ int64_t aRank = opInfo.getARank();
+
+ SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+ SmallVector<int64_t> aBatchSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupStrides(aRank, 0);
+ for (auto [i, dim] : llvm::enumerate(aMDims)) {
+ aBatchSizes[dim] = batchMSizes[i];
+ aSubgroupSizes[dim] = subgroupMBasis[i];
+ aSubgroupStrides[dim] = subgroupMStrides[i];
+ }
+ for (auto [kDim, lhsKDim] :
+ llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
+ aBatchSizes[lhsKDim] = bounds[kDim];
+ }
+ aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+ IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
+ context, aRank, afm, afk,
+ /*subgroupCount=*/aSubgroupSizes,
+ /*subgroupStrides=*/aSubgroupStrides,
+ /*batchCount=*/aBatchSizes,
+ getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
+ LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
+
+ int64_t bRank = opInfo.getBRank();
+
+ SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+ SmallVector<int64_t> bBatchSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupStrides(bRank, 0);
+ for (auto [i, dim] : llvm::enumerate(bNDims)) {
+ bBatchSizes[dim] = batchNSizes[i];
+ bSubgroupSizes[dim] = subgroupNBasis[i];
+ bSubgroupStrides[dim] = subgroupNStrides[i];
+ }
+ for (auto [kDim, rhsKDim] :
+ llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
+ bBatchSizes[rhsKDim] = bounds[kDim];
+ }
+ bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+ IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
+ context, bRank, bfk, bfn,
+ /*subgroupCount=*/bSubgroupSizes,
+ /*subgroupStrides=*/bSubgroupStrides,
+ /*batchCount=*/bBatchSizes,
+ getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
+ LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
+
+ std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+ VectorLayoutInterface>
+ result = {aLayout, bLayout, cLayout};
+ return result;
+}
+
+/// Template specializations
+::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+ VectorContractOpInfo &opInfo,
+ linalg::LinalgOp contractOp) {
+ return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
+}
+
+::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+ VectorContractOpInfo &opInfo,
+ vector::ContractionOp contractOp) {
+ return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
+}
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 617b774..061e139 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -62,6 +62,24 @@
void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc,
ArrayRef<Operation *> aliasGroup, bool hasAsyncCopies = true);
+namespace IREE {
+namespace GPU {
+class MMAScheduleAttr;
+
+::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+ VectorContractOpInfo &opInfo, linalg::LinalgOp contractOp);
+
+::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+ VectorContractOpInfo &opInfo,
+ vector::ContractionOp contractOp);
+} // namespace GPU
+} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 074664c..a419f4a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -207,6 +207,57 @@
workgroup_size = [64, 1, 1]
subgroup_size = 64>
+#maps = [
+ affine_map<(bm, bn, m, n, k) -> (bm, m, k)>,
+ affine_map<(bm, bn, m, n, k) -> (bn, n, k)>,
+ affine_map<(bm, bn, m, n, k) -> (bm, m, bn, n)>
+]
+
+#traits = {
+ indexing_maps = #maps,
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
+ lowering_config = #iree_gpu.lowering_config<{promote_operands = [0],
+ mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 2, subgroup_n_count = 2}>
+}
+
+func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
+ %rhs: tensor<8x16x16xf16>,
+ %init: tensor<8x16x8x16xf32>)
+ -> tensor<8x16x8x16xf32>
+ attributes { translation_info = #translation } {
+ %out = linalg.generic #traits
+ ins(%lhs, %rhs: tensor<8x16x16xf16>, tensor<8x16x16xf16>)
+ outs(%init: tensor<8x16x8x16xf32>) {
+ ^bb0(%in: f16, %in_1: f16, %out: f32):
+ %ex = arith.extf %in : f16 to f32
+ %ex_1 = arith.extf %in_1 : f16 to f32
+ %mul = arith.mulf %ex, %ex_1 : f32
+ %sum = arith.addf %out, %mul : f32
+ linalg.yield %sum : f32
+ } -> tensor<8x16x8x16xf32>
+ return %out : tensor<8x16x8x16xf32>
+}
+
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [2, 0, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [1, 0, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
+// CHECK-LABEL: func.func @packed_matmul_128x128x128
+
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[ACC]]
+
+// -----
+
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64>
+
func.func @linalg_copy(%in : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
attributes { translation_info = #translation } {
%empty = tensor.empty() : tensor<16x16x16xf16>
diff --git a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
index b35fc42..3d14b1d 100644
--- a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
+++ b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
@@ -27,7 +27,7 @@
transform.iree.match.cast_compatible_type %in0 = tensor<?x?x?x?xf16> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
- lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [64, 4]
subgroup_size = 64 ,