[LLVMGPU] allow multiple m and n dims in contraction distribution (#16943)
This adjusts the layout generation logic to allow distribution of
contractions with multiple m and n dimensions by greedily assigning the
subgroup/tile_counts of the mma_schedule to the outer dims. The
innermost m/n dimensions are still required to be divisible by the
intrinsic shape, and only a single k dimension is supported.
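As a concrete illustration of the greedy logic (a minimal standalone C++
sketch, not the actual helper; the values mirror the new
matmul_192x64x16_mmt_multi_m test added below, and the main wrapper is
illustrative only):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main() {
      // Two m dims with iteration bounds [2, 64], subgroup_m_count = 2,
      // subgroup_m_tile_count = 4, and a 16x16x16 intrinsic.
      std::vector<int64_t> mBounds = {2, 64};
      int64_t currMCount = 2; // subgroup_m_count
      int64_t currMBatch = 4; // subgroup_m_tile_count
      std::vector<int64_t> subgroupMBasis, batchMSizes;
      for (int64_t bound : mBounds) {
        // Assign as many residual subgroups as evenly divide this dim.
        int64_t threads = std::gcd(currMCount, bound);
        subgroupMBasis.push_back(threads);
        currMCount /= threads;
        // Give what remains to batches, clamped to the residual tile count.
        int64_t batch = std::min(bound / threads, currMBatch);
        batchMSizes.push_back(batch);
        currMBatch /= batch;
      }
      // Yields subgroupMBasis = [2, 1] and batchMSizes = [1, 4]; the leftover
      // 64 / (1 * 4) = 16 on the innermost m dim is exactly the intrinsic M,
      // so totalM / intrinsicM = 128 / 16 = 8 = (2 * 1) * (1 * 4).
      return 0;
    }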
This also decouples the ordering logic of the batch/subgroup
distribution from the lane distribution for the intrinsics. Currently
the logic assumes intrinsics only specify three significant sizes: M, N,
and K. Supporting distributed batch dimensions would require adding a
fourth dim type.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index c2a2233..efe43d3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -97,8 +97,8 @@
LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");
// Offsets into the LHS/RHS batches.
- SmallVector<int64_t, 2> lhsBatchOffsets(rank, 0);
- SmallVector<int64_t, 2> rhsBatchOffsets(rank, 0);
+ SmallVector<int64_t, 2> lhsBatchOffsets(lhsLayout.getRank(), 0);
+ SmallVector<int64_t, 2> rhsBatchOffsets(rhsLayout.getRank(), 0);
// Offsets into the result batches.
ArrayRef<int64_t> resultBatches = resultLayout.getBatchesPerSubgroup();
@@ -183,7 +183,7 @@
std::optional<int64_t> getKBatchSize(const VectorContractOpInfo &opDetail,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
- auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
+ auto [lhsK, rhsK] = opDetail.getOperandKIndex();
int64_t lhsKBatch = lhsLayout.getBatchesPerSubgroup()[lhsK];
int64_t rhsKBatch = rhsLayout.getBatchesPerSubgroup()[rhsK];
@@ -201,15 +201,21 @@
SmallVector<int64_t, 2> &rhsOffsets,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
- auto [lhsM, rhsN] = *opDetail.getOperandMNIndex();
- auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
- auto [resultM, resultN] = *opDetail.getResultMNIndex();
+
// resultOffsets contains batch indices into the C/D vector. It is a 2-D
// index for both M and N. We need to split out for M and N, and add index
// for K.
- lhsOffsets[lhsM] = resultOffsets[resultM];
+ for (auto [lhsM, resultM] :
+ llvm::zip_equal(opDetail.lhsMDims, opDetail.outMDims)) {
+ lhsOffsets[lhsM] = resultOffsets[resultM];
+ }
+ for (auto [rhsN, resultN] :
+ llvm::zip_equal(opDetail.rhsNDims, opDetail.outNDims)) {
+ rhsOffsets[rhsN] = resultOffsets[resultN];
+ }
+
+ auto [lhsK, rhsK] = opDetail.getOperandKIndex();
lhsOffsets[lhsK] = kOffset;
- rhsOffsets[rhsN] = resultOffsets[resultN];
rhsOffsets[rhsK] = kOffset;
// Now apply permutation on LHS/RHS according to their batch order.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
index 5b0761f..059b1e7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
@@ -243,13 +243,13 @@
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
%cst0_1 = arith.constant dense<0.0> : vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
%root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
// expected-remark @above {{thread_basis = [4, 16]}}
%root_red = vector.multi_reduction<add>, %root, %cst0_1 [0] : vector<16x16xf16> to vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
%c = arith.mulf %root_red, %a : vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
func.return %c : vector<16xf16>
}
@@ -281,13 +281,13 @@
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
%cst0_1 = arith.constant dense<0.0> : vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
%root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
// expected-remark @above {{thread_basis = [4, 16]}}
%root_red = vector.multi_reduction<add>, %root, %cst0_1 [1] : vector<16x16xf16> to vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
%c = arith.mulf %root_red, %a : vector<16xf16>
- // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+ // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
func.return %c : vector<16xf16>
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 208333e..f27f67d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include <numeric>
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
@@ -522,33 +523,178 @@
// MMA Schedule Attributes
//===----------------------------------------------------------------------===//
-NestedLayoutAttr permuteAndCreateNestedLayout(
- MLIRContext *context, ArrayRef<int64_t> permute,
- SmallVector<int64_t, 2> subgroupCount,
- SmallVector<int64_t, 2> subgroupOrder, SmallVector<int64_t, 2> batchCount,
- SmallVector<int64_t, 2> batchOrder, SmallVector<int64_t, 2> outerCount,
- SmallVector<int64_t, 2> outerOrder, SmallVector<int64_t, 2> threadCount,
- SmallVector<int64_t, 2> threadOrder, SmallVector<int64_t, 2> elementCount,
- SmallVector<int64_t, 2> elementOrder, ArrayRef<int64_t> subgroupBasis,
- ArrayRef<int64_t> threadBasis) {
- if (!isIdentityPermutation(permute)) {
- applyPermutationToVector(subgroupCount, permute);
- applyPermutationToVector(subgroupOrder, permute);
- applyPermutationToVector(batchCount, permute);
- applyPermutationToVector(batchOrder, permute);
- applyPermutationToVector(outerCount, permute);
- applyPermutationToVector(outerOrder, permute);
- applyPermutationToVector(threadCount, permute);
- applyPermutationToVector(threadOrder, permute);
- applyPermutationToVector(elementCount, permute);
- applyPermutationToVector(elementOrder, permute);
- }
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+ ArrayRef<int64_t> counts,
+ int64_t dim0, int64_t dim1) {
+ assert(counts.size() == 2 &&
+ "Unexpected non-rank 2 single subgroup dimension counts");
+ SmallVector<int64_t> res(rank, 1);
+ res[dim0] = counts[0];
+ res[dim1] = counts[1];
+ return res;
+}
- return NestedLayoutAttr::get(
+SmallVector<int64_t> getIdentityPerm(int64_t rank) {
+ return llvm::to_vector(llvm::seq(static_cast<int64_t>(0), rank));
+}
+
+/// Constructs an identity permutation with the given rank, except it applies
+/// the given rank-2 |perm| to the two dimensions |dim0| and |dim1|, and then
+/// swaps the positions of dim0 and dim1 in the final permutation. For example,
+///
+/// rank = 3, perm = [1, 0], dim0 = 1, dim1 = 2
+/// returns [0, 2, 1]
+///
+/// This is essentially just applying two rank-2 permutations to two particular
+/// dimensions. First it applies |perm|, which corresponds to a permutation
+/// needed by the underlying intrinsic, then it does another permutation based
+/// on the order of actual dimensions for the MMA fragment. For example, for the
+/// B matrix, dim0 = K and dim1 = N, so for the element order of an MFMA
+/// 16x16x16, perm would be `[1, 0]`; however, if the actual contraction is a
+/// matmul_transpose_b, then the element order needs to be [0, 1].
+SmallVector<int64_t> getIdentityPermWithSwap(int64_t rank,
+ ArrayRef<int64_t> perm,
+ int64_t dim0, int64_t dim1) {
+ assert(perm.size() == 2 &&
+ "Unexpected non-rank 2 single subgroup dimension order");
+ SmallVector<int64_t> res = getIdentityPerm(rank);
+ if (perm[0] > perm[1]) {
+ std::swap(dim0, dim1);
+ }
+ if (dim0 > dim1) {
+ res[dim0] = dim1;
+ res[dim1] = dim0;
+ }
+ return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+/// outerDim = 1 for the K dim
+/// innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+/// outerDim = 1 for K
+/// innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here "innermost" refers to the iteration order, not the order the
+/// dimensions appear per fragment (because there is no relationship between
+/// the dimension order of M in A and in C, for example).
+NestedLayoutAttr permuteAndCreateNestedLayout(
+ MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+ SmallVector<int64_t> subgroupCount, SmallVector<int64_t> subgroupOrder,
+ SmallVector<int64_t> batchCount, SmallVector<int64_t> batchOrder,
+ MMAAttr::SingleSubgroupLayout counts, MMAAttr::SingleSubgroupLayout orders,
+ ArrayRef<int64_t> dataDuplicate, ArrayRef<int64_t> subgroupBasis,
+ ArrayRef<bool> subgroupActiveIds) {
+
+ LLVM_DEBUG({
+ llvm::errs() << "Given:";
+ llvm::errs() << "\n outerDim = " << outerDim;
+ llvm::errs() << "\n innerDim = " << innerDim;
+ llvm::errs() << "\n subgroupCount: ";
+ llvm::interleaveComma(subgroupCount, llvm::errs());
+ llvm::errs() << "\n subgroupOrder: ";
+ llvm::interleaveComma(subgroupOrder, llvm::errs());
+ llvm::errs() << "\n batchCount: ";
+ llvm::interleaveComma(batchCount, llvm::errs());
+ llvm::errs() << "\n batchOrder: ";
+ llvm::interleaveComma(batchOrder, llvm::errs());
+ llvm::errs() << "\n counts.outer: ";
+ llvm::interleaveComma(counts.outer, llvm::errs());
+ llvm::errs() << "\n orders.outer: ";
+ llvm::interleaveComma(orders.outer, llvm::errs());
+ llvm::errs() << "\n counts.thread: ";
+ llvm::interleaveComma(counts.thread, llvm::errs());
+ llvm::errs() << "\n orders.thread: ";
+ llvm::interleaveComma(orders.thread, llvm::errs());
+ llvm::errs() << "\n counts.element: ";
+ llvm::interleaveComma(counts.element, llvm::errs());
+ llvm::errs() << "\n orders.element: ";
+ llvm::interleaveComma(orders.element, llvm::errs());
+ llvm::errs() << "\n subgroupBasis: ";
+ llvm::interleaveComma(subgroupBasis, llvm::errs());
+ llvm::errs() << "\n subgroupActiveIds: ";
+ llvm::interleaveComma(subgroupActiveIds, llvm::errs());
+ llvm::errs() << "\n";
+ });
+
+ SmallVector<int64_t> outerOrder =
+ getIdentityPermWithSwap(rank, orders.outer, outerDim, innerDim);
+ SmallVector<int64_t> threadOrder =
+ getIdentityPermWithSwap(rank, orders.thread, outerDim, innerDim);
+ SmallVector<int64_t> elementOrder =
+ getIdentityPermWithSwap(rank, orders.element, outerDim, innerDim);
+
+ SmallVector<int64_t> threadBasis =
+ getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+ threadBasis[outerDim] *= dataDuplicate[0];
+ threadBasis[innerDim] *= dataDuplicate[1];
+ applyPermutationToVector(threadBasis, threadOrder);
+
+ SmallVector<int64_t> outerCount =
+ getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+ SmallVector<int64_t> threadCount =
+ getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+ SmallVector<int64_t> elementCount =
+ getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+ LLVM_DEBUG({
+ llvm::errs() << "\nNew layout attr:";
+ llvm::errs() << "\n subgroupCount: ";
+ llvm::interleaveComma(subgroupCount, llvm::errs());
+ llvm::errs() << "\n subgroupOrder: ";
+ llvm::interleaveComma(subgroupOrder, llvm::errs());
+ llvm::errs() << "\n batchCount: ";
+ llvm::interleaveComma(batchCount, llvm::errs());
+ llvm::errs() << "\n batchOrder: ";
+ llvm::interleaveComma(batchOrder, llvm::errs());
+ llvm::errs() << "\n outerCount: ";
+ llvm::interleaveComma(outerCount, llvm::errs());
+ llvm::errs() << "\n outerOrder: ";
+ llvm::interleaveComma(outerOrder, llvm::errs());
+ llvm::errs() << "\n threadCount: ";
+ llvm::interleaveComma(threadCount, llvm::errs());
+ llvm::errs() << "\n threadOrder: ";
+ llvm::interleaveComma(threadOrder, llvm::errs());
+ llvm::errs() << "\n elementCount: ";
+ llvm::interleaveComma(elementCount, llvm::errs());
+ llvm::errs() << "\n elementOrder: ";
+ llvm::interleaveComma(elementOrder, llvm::errs());
+ llvm::errs() << "\n subgroupBasis: ";
+ llvm::interleaveComma(subgroupBasis, llvm::errs());
+ llvm::errs() << "\n subgroupActiveIds: ";
+ llvm::interleaveComma(subgroupActiveIds, llvm::errs());
+ llvm::errs() << "\n threadBasis: ";
+ llvm::interleaveComma(threadBasis, llvm::errs());
+ llvm::errs() << "\n";
+ });
+
+ auto layoutAttr = NestedLayoutAttr::get(
context, subgroupCount, subgroupOrder, batchCount, batchOrder, outerCount,
outerOrder, threadCount, threadOrder, elementCount, elementOrder,
- subgroupBasis, SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
+ subgroupBasis, subgroupActiveIds, threadBasis,
SmallVector<bool>(threadBasis.size(), true));
+ return layoutAttr;
}
std::optional<std::tuple<VectorExt::VectorLayoutInterface,
@@ -556,19 +702,19 @@
VectorExt::VectorLayoutInterface>>
MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
VectorContractOpInfo opInfo(contractOp);
+ LLVM_DEBUG({
+ llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
+ llvm::errs() << "For schedule: " << *this << "\n";
+ });
if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN)
return std::nullopt;
- auto [aM, bN] = *opInfo.getOperandMNIndex();
- auto [aK, bK] = *opInfo.getOperandKIndex();
- auto [cM, cN] = *opInfo.getResultMNIndex();
- SmallVector<int64_t, 2> aPermute = {aM, aK};
- SmallVector<int64_t, 2> bPermute = {bK, bN};
- SmallVector<int64_t, 2> cPermute = {cM, cN};
-
auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
MLIRContext *context = getContext();
+ SmallVector<int64_t> bounds;
+ contractOp.getIterationBounds(bounds);
+
// Get the concrete nested layout for each matrix. Note that the struct
// MMAAttr::SingleSubgroupLayout contains the partial layout for the
// canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
@@ -583,8 +729,71 @@
// the linearized GPU hardware lane ID into a n-D concatenated logical
// warp+thread using the subgroup/thread basis, so the subgroup basis should
// remain the same for all A/B/C matrix.
- SmallVector<int64_t, 2> subgroupBasis = {getSubgroupMCount(),
- getSubgroupNCount()};
+
+ SmallVector<int64_t, 2> subgroupMBasis;
+ SmallVector<int64_t, 2> batchMSizes;
+ int64_t currMCount = getSubgroupMCount();
+ int64_t currMBatch = getSubgroupMTileCount();
+
+ // Greedily break up the M subgroup and batch counts along the "M" iteration
+ // bounds. We distribute as many residual subgroups as possible per M dim, and
+// then divide the remainder along batch dims. The innermost M dim is always
+ // the one used for the intrinsic, meaning for a valid schedule, the computed
+ // batch counts and subgroup basis will satisfy
+ // totalMSize / intrinsicM = product(batchMSizes) * product(subgroupMBasis)
+ for (auto dim : opInfo.getMDims()) {
+ int64_t threads = std::gcd(currMCount, bounds[dim]);
+ subgroupMBasis.push_back(threads);
+ currMCount /= threads;
+ int64_t batchCount = bounds[dim] / threads;
+ batchCount = batchCount >= currMBatch ? currMBatch : batchCount;
+ batchMSizes.push_back(batchCount);
+ currMBatch /= batchCount;
+ }
+
+ SmallVector<int64_t, 2> subgroupNBasis;
+ SmallVector<int64_t, 2> batchNSizes;
+ int64_t currNCount = getSubgroupNCount();
+ int64_t currNBatch = getSubgroupNTileCount();
+
+ // Do the same for N dims.
+ for (auto dim : opInfo.getNDims()) {
+ int64_t threads = std::gcd(currNCount, bounds[dim]);
+ subgroupNBasis.push_back(threads);
+ currNCount /= threads;
+ int64_t batchCount = bounds[dim] / threads;
+ batchCount = batchCount >= currNBatch ? currNBatch : batchCount;
+ batchNSizes.push_back(batchCount);
+ currNBatch /= batchCount;
+ }
+
+ SmallVector<int64_t> subgroupBasis;
+ auto mDimVec = opInfo.getMDims();
+ llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+ auto nDimVec = opInfo.getNDims();
+ llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+
+ int64_t currM = 0;
+ int64_t currN = 0;
+ // Because we currently require all batch dimensions to be unit, the subgroup
+ // basis can be constructed from the M and N bases. To keep things simple,
+ // the current heuristic is to distribute all M dims followed by all N dims.
+ for (auto dim : llvm::seq(static_cast<int64_t>(0), opInfo.getCRank())) {
+ if (mDims.contains(dim)) {
+ subgroupBasis.push_back(subgroupMBasis[currM]);
+ // Construct mDimVec such that it contains the order in which the M dims
+ // appear in the C matrix.
+ mDimVec[currM] = dim;
+ currM++;
+ }
+ if (nDims.contains(dim)) {
+ subgroupBasis.push_back(subgroupNBasis[currN]);
+ // Construct nDimVec such that it contains the order in which the N dims
+ // appear in the C matrix.
+ nDimVec[currN] = dim;
+ currN++;
+ }
+ }
// For threads though, we also need to make sure the basis is consistent
// across A, B, and C matrix. Though here we need to additionally think it
@@ -606,23 +815,43 @@
MMAAttr::SingleSubgroupLayout cOrders =
mmaAttr.getCSingleSubgroupLayoutOrder();
- SmallVector<int64_t, 2> cThreadBasis = cCounts.thread;
- SmallVector<int64_t, 2> cDataDuplicate = mmaAttr.getCDataDuplicate();
- for (auto [idx, duplicateFactor] : llvm::enumerate(cDataDuplicate)) {
- cThreadBasis[idx] *= duplicateFactor;
+ auto [m, n] = opInfo.getResultMNIndex();
+ int64_t cRank = opInfo.getCRank();
+
+ // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+ // cNDims are the M and N dimensions of the C matrix in the order they are
+ // iterated over in the contraction.
+ SmallVector<int64_t> cMDims = opInfo.outMDims;
+ SmallVector<int64_t> cNDims = opInfo.outNDims;
+ SmallVector<int64_t> cBatchSizes(cRank, 1);
+ SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+ SmallVector<int64_t> cOverallOrder(cRank, 0);
+ for (auto [i, dim] : llvm::enumerate(cMDims)) {
+ cBatchSizes[dim] = batchMSizes[i];
+ cSubgroupSizes[dim] = subgroupMBasis[i];
+ cOverallOrder[dim] = mDimVec[i];
}
- applyPermutationToVector(cThreadBasis, cOrders.thread);
+ for (auto [i, dim] : llvm::enumerate(cNDims)) {
+ cBatchSizes[dim] = batchNSizes[i];
+ cSubgroupSizes[dim] = subgroupNBasis[i];
+ cOverallOrder[dim] = nDimVec[i];
+ }
+
+ // Dummy 1 for the k dimension.
+ subgroupBasis.push_back(1);
+
+ SmallVector<bool> cActiveSubgroups(cRank + 1, true);
+ cActiveSubgroups.back() = false;
auto cLayout = permuteAndCreateNestedLayout(
- context, cPermute,
- /*subgroupCount=*/{getSubgroupMCount(), getSubgroupNCount()},
- /*subgroupOrder=*/{0, 1},
- /*batchCount=*/{getSubgroupMTileCount(), getSubgroupNTileCount()},
- /*batchOrder=*/{0, 1}, /*outerCount=*/cCounts.outer,
- /*outerOrder=*/cOrders.outer, /*threadCount=*/cCounts.thread,
- /*threadOrder=*/cOrders.thread,
- /*elementCount=*/cCounts.element, /*elementOrder=*/cOrders.element,
- subgroupBasis, cThreadBasis);
+ context, cRank, m, n,
+ /*subgroupCount=*/cSubgroupSizes,
+ /*subgroupOrder=*/cOverallOrder,
+ /*batchCount=*/cBatchSizes,
+ /*batchOrder=*/cOverallOrder, cCounts, cOrders,
+ /*dataDuplicate=*/mmaAttr.getCDataDuplicate(), subgroupBasis,
+ cActiveSubgroups);
+ LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
// A matrix layout
MMAAttr::SingleSubgroupLayout aCounts =
@@ -630,23 +859,41 @@
MMAAttr::SingleSubgroupLayout aOrders =
mmaAttr.getASingleSubgroupLayoutOrder();
- SmallVector<int64_t, 2> aThreadBasis = aCounts.thread;
- SmallVector<int64_t, 2> aDataDuplicate = mmaAttr.getADataDuplicate();
- for (auto [idx, duplicateFactor] : llvm::enumerate(aDataDuplicate)) {
- aThreadBasis[idx] *= duplicateFactor;
+ auto [afm, bfn] = opInfo.getOperandMNIndex();
+ auto [afk, bfk] = opInfo.getOperandKIndex();
+
+ int64_t aRank = opInfo.getARank();
+
+ SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+ SmallVector<int64_t> aBatchSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupOrder(aRank, 0);
+ SmallVector<int64_t> aBatchOrder(aRank, 0);
+ for (auto [i, dim] : llvm::enumerate(aMDims)) {
+ aBatchSizes[dim] = batchMSizes[i];
+ aSubgroupSizes[dim] = subgroupMBasis[i];
+ aSubgroupOrder[dim] = i;
+ aBatchOrder[dim] = i >= afk ? i + 1 : i;
}
- applyPermutationToVector(aThreadBasis, aOrders.thread);
+ aSubgroupOrder[afk] = aRank - 1;
+ aBatchOrder[afk] = afk;
+ aBatchSizes[afk] = getSubgroupKTileCount();
+
+ SmallVector<bool> aActiveSubgroups(subgroupBasis.size(), false);
+ for (auto mDim : mDims) {
+ aActiveSubgroups[mDim] = true;
+ }
+ aActiveSubgroups.back() = true;
auto aLayout = permuteAndCreateNestedLayout(
- context, aPermute,
- /*subgroupCount=*/{getSubgroupMCount(), 1},
- /*subgroupOrder=*/{0, 1},
- /*batchCount=*/{getSubgroupMTileCount(), getSubgroupKTileCount()},
- /*batchOrder=*/{0, 1}, /*outerCount=*/aCounts.outer,
- /*outerOrder=*/aOrders.outer, /*threadCount=*/aCounts.thread,
- /*threadOrder=*/aOrders.thread,
- /*elementCount=*/aCounts.element, /*elementOrder=*/aOrders.element,
- subgroupBasis, aThreadBasis);
+ context, aRank, afm, afk,
+ /*subgroupCount=*/aSubgroupSizes,
+ /*subgroupOrder=*/aSubgroupOrder,
+ /*batchCount=*/aBatchSizes,
+ /*batchOrder=*/getIdentityPerm(aRank), aCounts, aOrders,
+ /*dataDuplicate=*/mmaAttr.getADataDuplicate(), subgroupBasis,
+ aActiveSubgroups);
+ LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
// B matrix layout
MMAAttr::SingleSubgroupLayout bCounts =
@@ -654,23 +901,38 @@
MMAAttr::SingleSubgroupLayout bOrders =
mmaAttr.getBSingleSubgroupLayoutOrder();
- SmallVector<int64_t, 2> bThreadBasis = bCounts.thread;
- SmallVector<int64_t, 2> bDataDuplicate = mmaAttr.getBDataDuplicate();
- for (auto [idx, duplicateFactor] : llvm::enumerate(bDataDuplicate)) {
- bThreadBasis[idx] *= duplicateFactor;
+ int64_t bRank = opInfo.getBRank();
+
+ SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+ SmallVector<int64_t> bBatchSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupOrder(bRank, 0);
+ SmallVector<int64_t> bBatchOrder(bRank, 0);
+ for (auto [i, dim] : llvm::enumerate(bNDims)) {
+ bBatchSizes[dim] = batchNSizes[i];
+ bSubgroupSizes[dim] = subgroupNBasis[i];
+ bSubgroupOrder[dim] = i;
+ bBatchOrder[dim] = i >= bfk ? i + 1 : i;
}
- applyPermutationToVector(bThreadBasis, bOrders.thread);
+ bSubgroupOrder[bfk] = bRank - 1;
+ bBatchOrder[bfk] = bfk;
+ bBatchSizes[bfk] = getSubgroupKTileCount();
+
+ SmallVector<bool> bActiveSubgroups(subgroupBasis.size(), false);
+ for (auto nDim : nDims) {
+ bActiveSubgroups[nDim] = true;
+ }
+ bActiveSubgroups.back() = true;
auto bLayout = permuteAndCreateNestedLayout(
- context, bPermute,
- /*subgroupCount=*/{1, getSubgroupNCount()},
- /*subgroupOrder=*/{0, 1},
- /*batchCount=*/{getSubgroupKTileCount(), getSubgroupNTileCount()},
- /*batchOrder=*/{0, 1}, /*outerCount=*/bCounts.outer,
- /*outerOrder=*/bOrders.outer, /*threadCount=*/bCounts.thread,
- /*threadOrder=*/bOrders.thread,
- /*elementCount=*/bCounts.element, /*elementOrder=*/bOrders.element,
- subgroupBasis, bThreadBasis);
+ context, bRank, bfk, bfn,
+ /*subgroupCount=*/bSubgroupSizes,
+ /*subgroupOrder=*/bSubgroupOrder,
+ /*batchCount=*/bBatchSizes,
+ /*batchOrder=*/bBatchOrder, bCounts, bOrders,
+ /*dataDuplicate=*/mmaAttr.getBDataDuplicate(), subgroupBasis,
+ bActiveSubgroups);
+ LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
return std::make_tuple(aLayout, bLayout, cLayout);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index e51fef1..3aead0a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -45,7 +45,6 @@
auto [dstAElemType, dstBElemType, dstCElemType] =
intrinsic.getABCElementTypes();
- auto [dstM, dstN, dstK] = intrinsic.getMNKShape();
auto srcCElemFType = dyn_cast<FloatType>(srcCType.getElementType());
auto dstCElemFType = dyn_cast<FloatType>(dstCElemType);
@@ -60,16 +59,6 @@
return rewriter.notifyMatchFailure(contractOp, "a/b type mismatch");
}
- auto [srcCMIndex, srcCNIndex] = *opInfo.getResultMNIndex();
- auto [srcAKIndex, srcBKIndex] = *opInfo.getOperandKIndex();
- int64_t srcM = srcCType.getShape()[srcCMIndex];
- int64_t srcN = srcCType.getShape()[srcCNIndex];
- int64_t srcK = srcAType.getShape()[srcAKIndex];
-
- if (srcM % dstM != 0 || srcN % dstN != 0 || srcK % dstK != 0) {
- return rewriter.notifyMatchFailure(contractOp, "shape cannot divide");
- }
-
Location loc = contractOp.getLoc();
auto dstCType = srcCType.clone(dstCElemFType);
auto extOp =
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index d252c78..856cc01 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -45,26 +45,6 @@
// -----
-func.func @mfma_matmul_96x64x16_mm_cannot_divide(%lhs: vector<95x16xf16>, %rhs: vector<16x64xf16>, %init: vector<95x64xf16>) -> vector<95x64xf16> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
- subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
- workgroup_size = [64, 1, 1]} {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
- return %0 : vector<95x64xf16>
-}
-
-// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_divide
-// CHECK-NOT: arith.extf
-// CHECK: vector.contract
-// CHECK-SAME: %{{.+}}, %{{.+}}, %{{.+}} : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
-// CHECK-NOT: arith.truncf
-
-// -----
-
func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index b8ee6d5..06c5aec 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -15,15 +15,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
// -----
@@ -42,15 +42,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: outer_order = [1, 0], thread_order = [1, 0]
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME: element_order = [1, 0]
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
// -----
@@ -117,15 +117,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
// -----
@@ -178,15 +178,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
// -----
@@ -243,15 +243,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
// -----
@@ -270,12 +270,113 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME: outer_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
// CHECK-SAME: element_order = [1, 0],
-// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 16]>
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+ subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 4, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+ workgroup_size = [64, 2, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+ return %0 : vector<2x64x64xf32>
+}
+
+// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 4, 1],
+// CHECK-SAME: outers_per_batch = [1, 1, 1],
+// CHECK-SAME: threads_per_outer = [1, 16, 4],
+// CHECK-SAME: elements_per_thread = [1, 1, 4],
+// CHECK-SAME: thread_order = [0, 2, 1],
+// CHECK-SAME: subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, true, false, true],
+// CHECK-SAME: thread_basis = [1, 4, 16]>
+// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 4],
+// CHECK-SAME: outers_per_batch = [1, 1],
+// CHECK-SAME: threads_per_outer = [4, 16],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [false, false, true, true],
+// CHECK-SAME: thread_basis = [4, 16]>
+// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 4, 4],
+// CHECK-SAME: outers_per_batch = [1, 1, 1],
+// CHECK-SAME: threads_per_outer = [1, 4, 16],
+// CHECK-SAME: elements_per_thread = [1, 4, 1],
+// CHECK-SAME: element_order = [0, 2, 1],
+// CHECK-SAME: subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, true, true, false],
+// CHECK-SAME: thread_basis = [1, 4, 16]>
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+ subgroup_m_count = 4, subgroup_n_count = 1, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+ workgroup_size = [64, 2, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+ return %0 : vector<2x64x64xf32>
+}
+
+// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 2, 1],
+// CHECK-SAME: subgroup_basis = [2, 2, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, true, false, true]
+// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 2, 4],
+// CHECK-SAME: subgroup_basis = [2, 2, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, true, true, false]
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+ subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 8, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+ workgroup_size = [128, 2, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<4x64x16xf16>, vector<2x16x64xf16> into vector<4x2x64x64xf32>
+ return %0 : vector<4x2x64x64xf32>
+}
+
+// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME: batches_per_subgroup = [2, 4, 1],
+// CHECK-SAME: subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, false, true, false, true]
+// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME: batches_per_subgroup = [1, 1, 4],
+// CHECK-SAME: subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [false, true, false, true, true]
+// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1, 1],
+// CHECK-SAME: batches_per_subgroup = [2, 1, 4, 4],
+// CHECK-SAME: subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME: subgroup_active_ids = [true, true, true, true, false]
diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
index eec2e7c..20c69c2 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
@@ -78,6 +78,7 @@
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:VectorDialect",
],
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
index cefe73e..0c22b45 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -72,6 +72,7 @@
DEPS
LLVMSupport
MLIRIR
+ MLIRLinalgDialect
MLIRSupport
MLIRVectorDialect
PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
index 826ba1d..c5be60b 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
@@ -6,62 +6,61 @@
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
namespace mlir::iree_compiler {
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getOperandMNIndex() const {
- switch (opKind) {
- case OpKind::MK_KN_MN:
- return std::make_pair(0, 1);
- case OpKind::MK_NK_MN:
- return std::make_pair(0, 0);
- case OpKind::UNKNOWN:
- break;
- }
- return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getOperandMNIndex() const {
+ return std::make_pair(lhsMDims.back(), rhsNDims.back());
}
// Returns the (LHS K, RHS K) dimension index pair.
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getOperandKIndex() const {
- switch (opKind) {
- case OpKind::MK_KN_MN:
- return std::make_pair(1, 0);
- case OpKind::MK_NK_MN:
- return std::make_pair(1, 1);
- case OpKind::UNKNOWN:
- break;
- }
- return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getOperandKIndex() const {
+ return std::make_pair(lhsKDim, rhsKDim);
}
// Returns the result (M, N) dimension index pair.
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getResultMNIndex() const {
- switch (opKind) {
- case OpKind::MK_KN_MN:
- case OpKind::MK_NK_MN:
- return std::make_pair(0, 1);
- default:
- break;
- }
- return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getResultMNIndex() const {
+ return std::make_pair(outMDims.back(), outNDims.back());
}
VectorContractOpInfo::OpKind
VectorContractOpInfo::inferOpKind(MLIRContext *ctx,
- SmallVector<AffineMap> maps) const {
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
- AffineExpr m, n, k;
- bindDims(ctx, m, n, k);
- if (maps == infer({{m, k}, {k, n}, {m, n}}))
- return OpKind::MK_KN_MN;
- if (maps == infer({{m, k}, {n, k}, {m, n}}))
- return OpKind::MK_NK_MN;
+ SmallVector<AffineMap> maps) {
+ if (contractionDims.k.size() != 1) {
+ return OpKind::UNKNOWN;
+ }
+
+ int64_t innerM = contractionDims.m.back();
+ int64_t innerN = contractionDims.n.back();
+ int64_t k = contractionDims.k.back();
+
+ int64_t lhsM = *maps[0].getResultPosition(getAffineDimExpr(innerM, ctx));
+ lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
+ int64_t rhsN = *maps[1].getResultPosition(getAffineDimExpr(innerN, ctx));
+ rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
+ int64_t outM = *maps[2].getResultPosition(getAffineDimExpr(innerM, ctx));
+ int64_t outN = *maps[2].getResultPosition(getAffineDimExpr(innerN, ctx));
+
+ for (auto m : contractionDims.m) {
+ lhsMDims.push_back(*maps[0].getResultPosition(getAffineDimExpr(m, ctx)));
+ outMDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(m, ctx)));
+ }
+ for (auto n : contractionDims.n) {
+ rhsNDims.push_back(*maps[1].getResultPosition(getAffineDimExpr(n, ctx)));
+ outNDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(n, ctx)));
+ }
+
+ if (outM < outN) {
+ if (lhsM < lhsKDim) {
+ if (rhsN < rhsKDim) {
+ return OpKind::MK_NK_MN;
+ }
+ return OpKind::MK_KN_MN;
+ }
+ }
return OpKind::UNKNOWN;
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
index be1ba22..3864322 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
@@ -4,7 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
namespace mlir::iree_compiler {
@@ -14,25 +16,49 @@
enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };
explicit VectorContractOpInfo(vector::ContractionOp op) {
+ contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
}
OpKind getOpKind() const { return opKind; }
// Returns the (LHS M, RHS N) dimension index pair.
- std::optional<std::pair<int, int>> getOperandMNIndex() const;
+ std::pair<int, int> getOperandMNIndex() const;
// Returns the (LHS K, RHS K) dimension index pair.
- std::optional<std::pair<int, int>> getOperandKIndex() const;
+ std::pair<int, int> getOperandKIndex() const;
// Returns the result (M, N) dimension index pair.
- std::optional<std::pair<int, int>> getResultMNIndex() const;
+ std::pair<int, int> getResultMNIndex() const;
+
+ SmallVector<unsigned, 2> getMDims() const { return contractionDims.m; }
+
+ SmallVector<unsigned, 2> getNDims() const { return contractionDims.n; }
+
+ int64_t getARank() {
+ return contractionDims.m.size() + contractionDims.k.size();
+ }
+ int64_t getBRank() {
+ return contractionDims.k.size() + contractionDims.n.size();
+ }
+ int64_t getCRank() {
+ return contractionDims.m.size() + contractionDims.n.size();
+ }
+
+ SmallVector<int64_t> lhsMDims;
+ int64_t lhsKDim;
+ SmallVector<int64_t> rhsNDims;
+ int64_t rhsKDim;
+ SmallVector<int64_t> outMDims;
+ SmallVector<int64_t> outNDims;
private:
// Gets the kind of a contract op with the given indexing |maps|.
- OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps) const;
+ OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps);
OpKind opKind = OpKind::UNKNOWN;
+
+ linalg::ContractionDimensions contractionDims;
};
} // namespace mlir::iree_compiler
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 6e59a85..a970a2c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -740,7 +740,7 @@
p << ']';
if (llvm::any_of(mask, [](bool b) { return !b; })) {
p << ',' << ' ';
- p << maskName;
+ p << maskName << ' ';
p << '=';
p << ' ';
p << '[';
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 6e60022..5600905 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -110,29 +110,89 @@
func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
- %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
- %2 = iree_vector_ext.layout_conflict_resolution %result {
- sourceLayout = #nested_1,
- desiredLayout = #nested_2,
- otherLayout0 = #nested_3,
- otherLayout1 = #nested_4,
- otherLayout2 = #nested_5
- } : vector<32x32xf16> -> vector<32x32xf16>
- return %2 : vector<32x32xf16>
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {
+ in_bounds = [true, true],
+ layout0 = #nested_1,
+ layout1 = #nested_2,
+ layout2 = #nested_3,
+ layout3 = #nested_4,
+ layout4 = #nested_5
+ } : memref<32x32xf16>, vector<32x32xf16>
+ return %result : vector<32x32xf16>
}
-// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [2, 4]>
-// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], outer_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [4, 2]>
-// CHECK-DAG: #[[LAYOUT2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4]>
-// CHECK-DAG: #[[LAYOUT3:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4, 2], thread_active_ids= [false, true, true]>
-// CHECK-DAG: #[[LAYOUT4:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4], thread_basis = [4, 2]>
+// CHECK: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [2, 4],
+// CHECK-SAME: outers_per_batch = [4, 1],
+// CHECK-SAME: threads_per_outer = [4, 2],
+// CHECK-SAME: elements_per_thread = [1, 4],
+// CHECK-SAME: outer_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1],
+// CHECK-SAME: thread_basis = [4, 2]>
+
+// CHECK: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1],
+// CHECK-SAME: thread_basis = [2, 4]>
+
+// CHECK: #[[LAYOUT2:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4, 8],
+// CHECK-SAME: subgroup_active_ids = [true, true, false],
+// CHECK-SAME: thread_basis = [2, 4]>
+
+// CHECK: #[[LAYOUT3:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4, 8],
+// CHECK-SAME: subgroup_active_ids = [true, true, false],
+// CHECK-SAME: thread_basis = [2, 4, 2],
+// CHECK-SAME: thread_active_ids = [false, true, true]>
+
+// CHECK: #[[LAYOUT4:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4],
+// CHECK-SAME: thread_basis = [4, 2]>
+
// CHECK-LABEL: func.func @specify_nested
-// CHECK: iree_vector_ext.layout_conflict_resolution
-// CHECK-SAME: desiredLayout = #[[LAYOUT0]]
-// CHECK-SAME: otherLayout0 = #[[LAYOUT2]]
-// CHECK-SAME: otherLayout1 = #[[LAYOUT3]]
-// CHECK-SAME: otherLayout2 = #[[LAYOUT4]]
-// CHECK-SAME: sourceLayout = #[[LAYOUT1]]
+// CHECK: vector.transfer_read
+// CHECK-SAME: layout0 = #[[LAYOUT0]]
+// CHECK-SAME: layout1 = #[[LAYOUT1]]
+// CHECK-SAME: layout2 = #[[LAYOUT2]]
+// CHECK-SAME: layout3 = #[[LAYOUT3]]
+// CHECK-SAME: layout4 = #[[LAYOUT4]]
// -----