[LLVMGPU] allow multiple m and n dims in contraction distribution (#16943)

This adjusts the layout generation logic to allow distribution of
contractions with multiple m and n dimensions by greedily using the
subgroup/tile_counts of the mma_schedule with the outer dims. The
innermost m/n dimensions are still required to be divisible by the
intrinsic shape (and this only supports a single k dimension).

This also decouples the ordering logic of the batch/subgroup
distribution from the lane distribution for the intrinsics. Currently it
assumes intrinsics can only specify three important sizes: an M, N, and
K size. To support distributed batches, this would require adding a
fourth dim type.
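
As a reading aid (not part of the original PR text), here is a minimal
standalone sketch of the greedy split in plain C++. Names are hypothetical; it
mirrors the loops added in IREEGPUAttrs.cpp below, and the N dims are handled
identically:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Greedily split the schedule's subgroup/batch counts over the iteration
// bounds of the M (or N) dims, outermost dim first. Whatever remains on the
// innermost dim must be divisible by the intrinsic size.
void greedySplit(const std::vector<int64_t> &bounds, int64_t subgroupCount,
                 int64_t batchCount, std::vector<int64_t> &subgroupBasis,
                 std::vector<int64_t> &batchSizes) {
  for (int64_t bound : bounds) {
    int64_t subgroups = std::gcd(subgroupCount, bound);
    subgroupBasis.push_back(subgroups);
    subgroupCount /= subgroups;
    int64_t batches = std::min(bound / subgroups, batchCount);
    batchSizes.push_back(batches);
    batchCount /= batches;
  }
}

int main() {
  // Two M dims with bounds [2, 64], subgroup_m_count = 2, and
  // subgroup_m_tile_count = 4 (the numbers of the new multi-m test below).
  std::vector<int64_t> subgroupBasis, batchSizes;
  greedySplit({2, 64}, /*subgroupCount=*/2, /*batchCount=*/4, subgroupBasis,
              batchSizes);
  assert(subgroupBasis == (std::vector<int64_t>{2, 1}));
  assert(batchSizes == (std::vector<int64_t>{1, 4}));
  // The residual 64 / (1 * 4) = 16 on the innermost M dim is exactly the
  // intrinsic M size of MFMA_F16_16x16x16_F32.
  return 0;
}
```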
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index c2a2233..efe43d3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -97,8 +97,8 @@
     LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");
 
     // Offsets into the LHS/RHS batches.
-    SmallVector<int64_t, 2> lhsBatchOffsets(rank, 0);
-    SmallVector<int64_t, 2> rhsBatchOffsets(rank, 0);
+    SmallVector<int64_t, 2> lhsBatchOffsets(lhsLayout.getRank(), 0);
+    SmallVector<int64_t, 2> rhsBatchOffsets(rhsLayout.getRank(), 0);
 
     // Offsets into the result batches.
     ArrayRef<int64_t> resultBatches = resultLayout.getBatchesPerSubgroup();
@@ -183,7 +183,7 @@
   std::optional<int64_t> getKBatchSize(const VectorContractOpInfo &opDetail,
                                        NestedLayoutAttr lhsLayout,
                                        NestedLayoutAttr rhsLayout) const {
-    auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
+    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     int64_t lhsKBatch = lhsLayout.getBatchesPerSubgroup()[lhsK];
     int64_t rhsKBatch = rhsLayout.getBatchesPerSubgroup()[rhsK];
 
@@ -201,15 +201,21 @@
                                SmallVector<int64_t, 2> &rhsOffsets,
                                NestedLayoutAttr lhsLayout,
                                NestedLayoutAttr rhsLayout) const {
-    auto [lhsM, rhsN] = *opDetail.getOperandMNIndex();
-    auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
-    auto [resultM, resultN] = *opDetail.getResultMNIndex();
+
     // resultOffsets contains batch indices into the C/D vector. It is a 2-D
     // index for both M and N. We need to split out for M and N, and add index
     // for K.
-    lhsOffsets[lhsM] = resultOffsets[resultM];
+    for (auto [lhsM, resultM] :
+         llvm::zip_equal(opDetail.lhsMDims, opDetail.outMDims)) {
+      lhsOffsets[lhsM] = resultOffsets[resultM];
+    }
+    for (auto [rhsN, resultN] :
+         llvm::zip_equal(opDetail.rhsNDims, opDetail.outNDims)) {
+      rhsOffsets[rhsN] = resultOffsets[resultN];
+    }
+
+    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     lhsOffsets[lhsK] = kOffset;
-    rhsOffsets[rhsN] = resultOffsets[resultN];
     rhsOffsets[rhsK] = kOffset;
 
     // Now apply permutation on LHS/RHS according to their batch order.
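
A minimal standalone sketch (hypothetical names, plain C++ rather than the
pass itself) of what the new loops above do: every M batch offset of the
accumulator is scattered to the matching LHS dim, every N offset to the
matching RHS dim, and the single K offset is pinned on both operands:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// |dimPairs| holds (operand dim, result dim) pairs: the LHS M dims paired
// with the result M dims, or the RHS N dims paired with the result N dims.
void scatterBatchOffsets(
    const std::vector<int64_t> &resultOffsets,
    const std::vector<std::pair<int64_t, int64_t>> &dimPairs,
    int64_t operandKDim, int64_t kOffset,
    std::vector<int64_t> &operandOffsets) {
  for (auto [operandDim, resultDim] : dimPairs)
    operandOffsets[operandDim] = resultOffsets[resultDim];
  // There is only a single k dimension, shared by LHS and RHS.
  operandOffsets[operandKDim] = kOffset;
}
```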
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
index 5b0761f..059b1e7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
@@ -243,13 +243,13 @@
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.0 : f16
     %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
     %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
     // expected-remark @above {{thread_basis = [4, 16]}}
     %root_red = vector.multi_reduction<add>, %root, %cst0_1 [0]  : vector<16x16xf16> to vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
     %c = arith.mulf %root_red, %a : vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
     func.return %c : vector<16xf16>
   }
 
@@ -281,13 +281,13 @@
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.0 : f16
     %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
     %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
     // expected-remark @above {{thread_basis = [4, 16]}}
     %root_red = vector.multi_reduction<add>, %root, %cst0_1 [1]  : vector<16x16xf16> to vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
     %c = arith.mulf %root_red, %a : vector<16xf16>
-    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
     func.return %c : vector<16xf16>
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 208333e..f27f67d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include <numeric>
 
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
@@ -522,33 +523,178 @@
 // MMA Schedule Attributes
 //===----------------------------------------------------------------------===//
 
-NestedLayoutAttr permuteAndCreateNestedLayout(
-    MLIRContext *context, ArrayRef<int64_t> permute,
-    SmallVector<int64_t, 2> subgroupCount,
-    SmallVector<int64_t, 2> subgroupOrder, SmallVector<int64_t, 2> batchCount,
-    SmallVector<int64_t, 2> batchOrder, SmallVector<int64_t, 2> outerCount,
-    SmallVector<int64_t, 2> outerOrder, SmallVector<int64_t, 2> threadCount,
-    SmallVector<int64_t, 2> threadOrder, SmallVector<int64_t, 2> elementCount,
-    SmallVector<int64_t, 2> elementOrder, ArrayRef<int64_t> subgroupBasis,
-    ArrayRef<int64_t> threadBasis) {
-  if (!isIdentityPermutation(permute)) {
-    applyPermutationToVector(subgroupCount, permute);
-    applyPermutationToVector(subgroupOrder, permute);
-    applyPermutationToVector(batchCount, permute);
-    applyPermutationToVector(batchOrder, permute);
-    applyPermutationToVector(outerCount, permute);
-    applyPermutationToVector(outerOrder, permute);
-    applyPermutationToVector(threadCount, permute);
-    applyPermutationToVector(threadOrder, permute);
-    applyPermutationToVector(elementCount, permute);
-    applyPermutationToVector(elementOrder, permute);
-  }
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+                                           ArrayRef<int64_t> counts,
+                                           int64_t dim0, int64_t dim1) {
+  assert(counts.size() == 2 &&
+         "Unexpected non-rank 2 single subgroup dimension counts");
+  SmallVector<int64_t> res(rank, 1);
+  res[dim0] = counts[0];
+  res[dim1] = counts[1];
+  return res;
+}
 
-  return NestedLayoutAttr::get(
+SmallVector<int64_t> getIdentityPerm(int64_t rank) {
+  return llvm::to_vector(llvm::seq(static_cast<int64_t>(0), rank));
+}
+
+/// Constructs an identity permutation with the given rank, except it applies
+/// the given rank-2 |perm| to the two dimensions |dim0| and |dim1|, and then
+/// swaps the positions of dim0 and dim1 in the final permutation. For example,
+///
+/// rank = 3, perm = [1, 0], dim0 = 1, dim1 = 2
+/// returns [0, 2, 1]
+///
+/// This is essentially just applying two rank-2 permutations to two particular
+/// dimensions. First it applies |perm|, which corresponds to a permutation
+/// needed by the underlying intrinsic, then it does another permutation based
+/// on the order of actual dimensions for the MMA fragment. For example, for the
+/// B matrix, dim0 = K and dim1 = N, so for the element order of an MFMA
+/// 16x16x16, perm would be `[1, 0]`, however if the actual contraction is a
+/// matmul_transpose_b, then the element order needs to be [0, 1].
+SmallVector<int64_t> getIdentityPermWithSwap(int64_t rank,
+                                             ArrayRef<int64_t> perm,
+                                             int64_t dim0, int64_t dim1) {
+  assert(perm.size() == 2 &&
+         "Unexpected non-rank 2 single subgroup dimension order");
+  SmallVector<int64_t> res = getIdentityPerm(rank);
+  if (perm[0] > perm[1]) {
+    std::swap(dim0, dim1);
+  }
+  if (dim0 > dim1) {
+    res[dim0] = dim1;
+    res[dim1] = dim0;
+  }
+  return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+///   outerDim = 1 for the K dim
+///   innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+///   outerDim = 1 for K
+///   innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// And here inner most is referential to the iteration order, not the order
+/// they appear per fragment (because there is no relationship between the
+/// dimension order of M in A and in C, for example).
+NestedLayoutAttr permuteAndCreateNestedLayout(
+    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+    SmallVector<int64_t> subgroupCount, SmallVector<int64_t> subgroupOrder,
+    SmallVector<int64_t> batchCount, SmallVector<int64_t> batchOrder,
+    MMAAttr::SingleSubgroupLayout counts, MMAAttr::SingleSubgroupLayout orders,
+    ArrayRef<int64_t> dataDuplicate, ArrayRef<int64_t> subgroupBasis,
+    ArrayRef<bool> subgroupActiveIds) {
+
+  LLVM_DEBUG({
+    llvm::errs() << "Given:";
+    llvm::errs() << "\n    outerDim = " << outerDim;
+    llvm::errs() << "\n    innerDim = " << innerDim;
+    llvm::errs() << "\n    subgroupCount: ";
+    llvm::interleaveComma(subgroupCount, llvm::errs());
+    llvm::errs() << "\n    subgroupOrder: ";
+    llvm::interleaveComma(subgroupOrder, llvm::errs());
+    llvm::errs() << "\n    batchCount: ";
+    llvm::interleaveComma(batchCount, llvm::errs());
+    llvm::errs() << "\n    batchOrder: ";
+    llvm::interleaveComma(batchOrder, llvm::errs());
+    llvm::errs() << "\n    counts.outer: ";
+    llvm::interleaveComma(counts.outer, llvm::errs());
+    llvm::errs() << "\n    orders.outer: ";
+    llvm::interleaveComma(orders.outer, llvm::errs());
+    llvm::errs() << "\n    counts.thread: ";
+    llvm::interleaveComma(counts.thread, llvm::errs());
+    llvm::errs() << "\n    orders.thread: ";
+    llvm::interleaveComma(orders.thread, llvm::errs());
+    llvm::errs() << "\n    counts.element: ";
+    llvm::interleaveComma(counts.element, llvm::errs());
+    llvm::errs() << "\n    orders.element: ";
+    llvm::interleaveComma(orders.element, llvm::errs());
+    llvm::errs() << "\n    subgroupBasis: ";
+    llvm::interleaveComma(subgroupBasis, llvm::errs());
+    llvm::errs() << "\n    subgroupActiveIds: ";
+    llvm::interleaveComma(subgroupActiveIds, llvm::errs());
+    llvm::errs() << "\n";
+  });
+
+  SmallVector<int64_t> outerOrder =
+      getIdentityPermWithSwap(rank, orders.outer, outerDim, innerDim);
+  SmallVector<int64_t> threadOrder =
+      getIdentityPermWithSwap(rank, orders.thread, outerDim, innerDim);
+  SmallVector<int64_t> elementOrder =
+      getIdentityPermWithSwap(rank, orders.element, outerDim, innerDim);
+
+  SmallVector<int64_t> threadBasis =
+      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+  threadBasis[outerDim] *= dataDuplicate[0];
+  threadBasis[innerDim] *= dataDuplicate[1];
+  applyPermutationToVector(threadBasis, threadOrder);
+
+  SmallVector<int64_t> outerCount =
+      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+  SmallVector<int64_t> threadCount =
+      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+  SmallVector<int64_t> elementCount =
+      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+  LLVM_DEBUG({
+    llvm::errs() << "\nNew layout attr:";
+    llvm::errs() << "\n    subgroupCount: ";
+    llvm::interleaveComma(subgroupCount, llvm::errs());
+    llvm::errs() << "\n    subgroupOrder: ";
+    llvm::interleaveComma(subgroupOrder, llvm::errs());
+    llvm::errs() << "\n    batchCount: ";
+    llvm::interleaveComma(batchCount, llvm::errs());
+    llvm::errs() << "\n    batchOrder: ";
+    llvm::interleaveComma(batchOrder, llvm::errs());
+    llvm::errs() << "\n    outerCount: ";
+    llvm::interleaveComma(outerCount, llvm::errs());
+    llvm::errs() << "\n    outerOrder: ";
+    llvm::interleaveComma(outerOrder, llvm::errs());
+    llvm::errs() << "\n    threadCount: ";
+    llvm::interleaveComma(threadCount, llvm::errs());
+    llvm::errs() << "\n    threadOrder: ";
+    llvm::interleaveComma(threadOrder, llvm::errs());
+    llvm::errs() << "\n    elementCount: ";
+    llvm::interleaveComma(elementCount, llvm::errs());
+    llvm::errs() << "\n    elementOrder: ";
+    llvm::interleaveComma(elementOrder, llvm::errs());
+    llvm::errs() << "\n    subgroupBasis: ";
+    llvm::interleaveComma(subgroupBasis, llvm::errs());
+    llvm::errs() << "\n    subgroupActiveIds: ";
+    llvm::interleaveComma(subgroupActiveIds, llvm::errs());
+    llvm::errs() << "\n    threadBasis: ";
+    llvm::interleaveComma(threadBasis, llvm::errs());
+    llvm::errs() << "\n";
+  });
+
+  auto layoutAttr = NestedLayoutAttr::get(
       context, subgroupCount, subgroupOrder, batchCount, batchOrder, outerCount,
       outerOrder, threadCount, threadOrder, elementCount, elementOrder,
-      subgroupBasis, SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
+      subgroupBasis, subgroupActiveIds, threadBasis,
       SmallVector<bool>(threadBasis.size(), true));
+  return layoutAttr;
 }
 
 std::optional<std::tuple<VectorExt::VectorLayoutInterface,
@@ -556,19 +702,19 @@
                          VectorExt::VectorLayoutInterface>>
 MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
   VectorContractOpInfo opInfo(contractOp);
+  LLVM_DEBUG({
+    llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
+    llvm::errs() << "For schedule: " << *this << "\n";
+  });
   if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN)
     return std::nullopt;
 
-  auto [aM, bN] = *opInfo.getOperandMNIndex();
-  auto [aK, bK] = *opInfo.getOperandKIndex();
-  auto [cM, cN] = *opInfo.getResultMNIndex();
-  SmallVector<int64_t, 2> aPermute = {aM, aK};
-  SmallVector<int64_t, 2> bPermute = {bK, bN};
-  SmallVector<int64_t, 2> cPermute = {cM, cN};
-
   auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
   MLIRContext *context = getContext();
 
+  SmallVector<int64_t> bounds;
+  contractOp.getIterationBounds(bounds);
+
   // Get the concrete nested layout for each matrix. Note that the struct
   // MMAAttr::SingleSubgroupLayout contains the partial layout for the
   // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
@@ -583,8 +729,71 @@
   // the linearized GPU hardware lane ID into a n-D concatenated logical
   // warp+thread using the subgroup/thread basis, so the subgroup basis should
   // remain the same for all A/B/C matrix.
-  SmallVector<int64_t, 2> subgroupBasis = {getSubgroupMCount(),
-                                           getSubgroupNCount()};
+
+  SmallVector<int64_t, 2> subgroupMBasis;
+  SmallVector<int64_t, 2> batchMSizes;
+  int64_t currMCount = getSubgroupMCount();
+  int64_t currMBatch = getSubgroupMTileCount();
+
+  // Greedily break up the M subgroup and batch counts along the "M" iteration
+  // bounds. We distribute as many residual subgroups as possible per M dim, and
+  // then divide the remaining along batch dims. The inner most M dim is always
+  // the one used for the intrinsic, meaning for a valid schedule, the computed
+  // batch counts and subgroup basis will satisfy
+  // totalMSize / intrinsicM = product(batchMSizes) * product(subgroupMBasis)
+  for (auto dim : opInfo.getMDims()) {
+    int64_t threads = std::gcd(currMCount, bounds[dim]);
+    subgroupMBasis.push_back(threads);
+    currMCount /= threads;
+    int64_t batchCount = bounds[dim] / threads;
+    batchCount = batchCount >= currMBatch ? currMBatch : batchCount;
+    batchMSizes.push_back(batchCount);
+    currMBatch /= batchCount;
+  }
+
+  SmallVector<int64_t, 2> subgroupNBasis;
+  SmallVector<int64_t, 2> batchNSizes;
+  int64_t currNCount = getSubgroupNCount();
+  int64_t currNBatch = getSubgroupNTileCount();
+
+  // Do the same for N dims.
+  for (auto dim : opInfo.getNDims()) {
+    int64_t threads = std::gcd(currNCount, bounds[dim]);
+    subgroupNBasis.push_back(threads);
+    currNCount /= threads;
+    int64_t batchCount = bounds[dim] / threads;
+    batchCount = batchCount >= currNBatch ? currNBatch : batchCount;
+    batchNSizes.push_back(batchCount);
+    currNBatch /= batchCount;
+  }
+
+  SmallVector<int64_t> subgroupBasis;
+  auto mDimVec = opInfo.getMDims();
+  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+  auto nDimVec = opInfo.getNDims();
+  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+
+  int64_t currM = 0;
+  int64_t currN = 0;
+  // Because we currently require all batch dimensions to be unit, the subgroup
+  // basis can be constructed from the M and N bases. To keep things simple,
+  // the current heuristic is to distribute all M dims followed by all N dims.
+  for (auto dim : llvm::seq(static_cast<int64_t>(0), opInfo.getCRank())) {
+    if (mDims.contains(dim)) {
+      subgroupBasis.push_back(subgroupMBasis[currM]);
+      // Construct mDimVec such that it contains the order in which the M dims
+      // appear in the C matrix.
+      mDimVec[currM] = dim;
+      currM++;
+    }
+    if (nDims.contains(dim)) {
+      subgroupBasis.push_back(subgroupNBasis[currN]);
+      // Construct nDimVec such that it contains the order in which the N dims
+      // appear in the C matrix.
+      nDimVec[currN] = dim;
+      currN++;
+    }
+  }
 
   // For threads though, we also need to make sure the basis is consistent
   // across A, B, and C matrix. Though here we need to additionally think it
@@ -606,23 +815,43 @@
   MMAAttr::SingleSubgroupLayout cOrders =
       mmaAttr.getCSingleSubgroupLayoutOrder();
 
-  SmallVector<int64_t, 2> cThreadBasis = cCounts.thread;
-  SmallVector<int64_t, 2> cDataDuplicate = mmaAttr.getCDataDuplicate();
-  for (auto [idx, duplicateFactor] : llvm::enumerate(cDataDuplicate)) {
-    cThreadBasis[idx] *= duplicateFactor;
+  auto [m, n] = opInfo.getResultMNIndex();
+  int64_t cRank = opInfo.getCRank();
+
+  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+  // cNDims are the M and N dimensions of the C matrix in the order they are
+  // iterated over in the contraction.
+  SmallVector<int64_t> cMDims = opInfo.outMDims;
+  SmallVector<int64_t> cNDims = opInfo.outNDims;
+  SmallVector<int64_t> cBatchSizes(cRank, 1);
+  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+  SmallVector<int64_t> cOverallOrder(cRank, 0);
+  for (auto [i, dim] : llvm::enumerate(cMDims)) {
+    cBatchSizes[dim] = batchMSizes[i];
+    cSubgroupSizes[dim] = subgroupMBasis[i];
+    cOverallOrder[dim] = mDimVec[i];
   }
-  applyPermutationToVector(cThreadBasis, cOrders.thread);
+  for (auto [i, dim] : llvm::enumerate(cNDims)) {
+    cBatchSizes[dim] = batchNSizes[i];
+    cSubgroupSizes[dim] = subgroupNBasis[i];
+    cOverallOrder[dim] = nDimVec[i];
+  }
+
+  // Dummy 1 for the k dimension.
+  subgroupBasis.push_back(1);
+
+  SmallVector<bool> cActiveSubgroups(cRank + 1, true);
+  cActiveSubgroups.back() = false;
 
   auto cLayout = permuteAndCreateNestedLayout(
-      context, cPermute,
-      /*subgroupCount=*/{getSubgroupMCount(), getSubgroupNCount()},
-      /*subgroupOrder=*/{0, 1},
-      /*batchCount=*/{getSubgroupMTileCount(), getSubgroupNTileCount()},
-      /*batchOrder=*/{0, 1}, /*outerCount=*/cCounts.outer,
-      /*outerOrder=*/cOrders.outer, /*threadCount=*/cCounts.thread,
-      /*threadOrder=*/cOrders.thread,
-      /*elementCount=*/cCounts.element, /*elementOrder=*/cOrders.element,
-      subgroupBasis, cThreadBasis);
+      context, cRank, m, n,
+      /*subgroupCount=*/cSubgroupSizes,
+      /*subgroupOrder=*/cOverallOrder,
+      /*batchCount=*/cBatchSizes,
+      /*batchOrder=*/cOverallOrder, cCounts, cOrders,
+      /*dataDuplicate=*/mmaAttr.getCDataDuplicate(), subgroupBasis,
+      cActiveSubgroups);
+  LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
 
   // A matrix layout
   MMAAttr::SingleSubgroupLayout aCounts =
@@ -630,23 +859,41 @@
   MMAAttr::SingleSubgroupLayout aOrders =
       mmaAttr.getASingleSubgroupLayoutOrder();
 
-  SmallVector<int64_t, 2> aThreadBasis = aCounts.thread;
-  SmallVector<int64_t, 2> aDataDuplicate = mmaAttr.getADataDuplicate();
-  for (auto [idx, duplicateFactor] : llvm::enumerate(aDataDuplicate)) {
-    aThreadBasis[idx] *= duplicateFactor;
+  auto [afm, bfn] = opInfo.getOperandMNIndex();
+  auto [afk, bfk] = opInfo.getOperandKIndex();
+
+  int64_t aRank = opInfo.getARank();
+
+  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+  SmallVector<int64_t> aBatchSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupOrder(aRank, 0);
+  SmallVector<int64_t> aBatchOrder(aRank, 0);
+  for (auto [i, dim] : llvm::enumerate(aMDims)) {
+    aBatchSizes[dim] = batchMSizes[i];
+    aSubgroupSizes[dim] = subgroupMBasis[i];
+    aSubgroupOrder[dim] = i;
+    aBatchOrder[dim] = i >= afk ? i + 1 : i;
   }
-  applyPermutationToVector(aThreadBasis, aOrders.thread);
+  aSubgroupOrder[afk] = aRank - 1;
+  aBatchOrder[afk] = afk;
+  aBatchSizes[afk] = getSubgroupKTileCount();
+
+  SmallVector<bool> aActiveSubgroups(subgroupBasis.size(), false);
+  for (auto mDim : mDims) {
+    aActiveSubgroups[mDim] = true;
+  }
+  aActiveSubgroups.back() = true;
 
   auto aLayout = permuteAndCreateNestedLayout(
-      context, aPermute,
-      /*subgroupCount=*/{getSubgroupMCount(), 1},
-      /*subgroupOrder=*/{0, 1},
-      /*batchCount=*/{getSubgroupMTileCount(), getSubgroupKTileCount()},
-      /*batchOrder=*/{0, 1}, /*outerCount=*/aCounts.outer,
-      /*outerOrder=*/aOrders.outer, /*threadCount=*/aCounts.thread,
-      /*threadOrder=*/aOrders.thread,
-      /*elementCount=*/aCounts.element, /*elementOrder=*/aOrders.element,
-      subgroupBasis, aThreadBasis);
+      context, aRank, afm, afk,
+      /*subgroupCount=*/aSubgroupSizes,
+      /*subgroupOrder=*/aSubgroupOrder,
+      /*batchCount=*/aBatchSizes,
+      /*batchOrder=*/getIdentityPerm(aRank), aCounts, aOrders,
+      /*dataDuplicate=*/mmaAttr.getADataDuplicate(), subgroupBasis,
+      aActiveSubgroups);
+  LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
 
   // B matrix layout
   MMAAttr::SingleSubgroupLayout bCounts =
@@ -654,23 +901,38 @@
   MMAAttr::SingleSubgroupLayout bOrders =
       mmaAttr.getBSingleSubgroupLayoutOrder();
 
-  SmallVector<int64_t, 2> bThreadBasis = bCounts.thread;
-  SmallVector<int64_t, 2> bDataDuplicate = mmaAttr.getBDataDuplicate();
-  for (auto [idx, duplicateFactor] : llvm::enumerate(bDataDuplicate)) {
-    bThreadBasis[idx] *= duplicateFactor;
+  int64_t bRank = opInfo.getBRank();
+
+  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+  SmallVector<int64_t> bBatchSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupOrder(bRank, 0);
+  SmallVector<int64_t> bBatchOrder(bRank, 0);
+  for (auto [i, dim] : llvm::enumerate(bNDims)) {
+    bBatchSizes[dim] = batchNSizes[i];
+    bSubgroupSizes[dim] = subgroupNBasis[i];
+    bSubgroupOrder[dim] = i;
+    bBatchOrder[dim] = i >= bfk ? i + 1 : i;
   }
-  applyPermutationToVector(bThreadBasis, bOrders.thread);
+  bSubgroupOrder[bfk] = bRank - 1;
+  bBatchOrder[bfk] = bfk;
+  bBatchSizes[bfk] = getSubgroupKTileCount();
+
+  SmallVector<bool> bActiveSubgroups(subgroupBasis.size(), false);
+  for (auto nDim : nDims) {
+    bActiveSubgroups[nDim] = true;
+  }
+  bActiveSubgroups.back() = true;
 
   auto bLayout = permuteAndCreateNestedLayout(
-      context, bPermute,
-      /*subgroupCount=*/{1, getSubgroupNCount()},
-      /*subgroupOrder=*/{0, 1},
-      /*batchCount=*/{getSubgroupKTileCount(), getSubgroupNTileCount()},
-      /*batchOrder=*/{0, 1}, /*outerCount=*/bCounts.outer,
-      /*outerOrder=*/bOrders.outer, /*threadCount=*/bCounts.thread,
-      /*threadOrder=*/bOrders.thread,
-      /*elementCount=*/bCounts.element, /*elementOrder=*/bOrders.element,
-      subgroupBasis, bThreadBasis);
+      context, bRank, bfk, bfn,
+      /*subgroupCount=*/bSubgroupSizes,
+      /*subgroupOrder=*/bSubgroupOrder,
+      /*batchCount=*/bBatchSizes,
+      /*batchOrder=*/bBatchOrder, bCounts, bOrders,
+      /*dataDuplicate=*/mmaAttr.getBDataDuplicate(), subgroupBasis,
+      bActiveSubgroups);
+  LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
 
   return std::make_tuple(aLayout, bLayout, cLayout);
 }
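
To double-check the permutation helper above, here is a standalone trace (our
own simplified plain-C++ copy, not the file's code) reproducing the documented
getIdentityPermWithSwap example:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Simplified copy of getIdentityPermWithSwap for tracing: applies the rank-2
// intrinsic |perm| to positions dim0/dim1 of an identity permutation.
std::vector<int64_t> identityPermWithSwap(int64_t rank,
                                          std::pair<int64_t, int64_t> perm,
                                          int64_t dim0, int64_t dim1) {
  std::vector<int64_t> res(rank);
  for (int64_t i = 0; i < rank; ++i)
    res[i] = i;
  if (perm.first > perm.second)
    std::swap(dim0, dim1);
  if (dim0 > dim1)
    std::swap(res[dim0], res[dim1]);
  return res;
}

int main() {
  // The example from the doc comment: rank = 3, perm = [1, 0], dim0 = 1,
  // dim1 = 2 gives [0, 2, 1].
  assert(identityPermWithSwap(3, {1, 0}, 1, 2) ==
         (std::vector<int64_t>{0, 2, 1}));
  // An identity intrinsic order leaves the permutation untouched.
  assert(identityPermWithSwap(3, {0, 1}, 1, 2) ==
         (std::vector<int64_t>{0, 1, 2}));
  return 0;
}
```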
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index e51fef1..3aead0a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -45,7 +45,6 @@
 
     auto [dstAElemType, dstBElemType, dstCElemType] =
         intrinsic.getABCElementTypes();
-    auto [dstM, dstN, dstK] = intrinsic.getMNKShape();
 
     auto srcCElemFType = dyn_cast<FloatType>(srcCType.getElementType());
     auto dstCElemFType = dyn_cast<FloatType>(dstCElemType);
@@ -60,16 +59,6 @@
       return rewriter.notifyMatchFailure(contractOp, "a/b type mismatch");
     }
 
-    auto [srcCMIndex, srcCNIndex] = *opInfo.getResultMNIndex();
-    auto [srcAKIndex, srcBKIndex] = *opInfo.getOperandKIndex();
-    int64_t srcM = srcCType.getShape()[srcCMIndex];
-    int64_t srcN = srcCType.getShape()[srcCNIndex];
-    int64_t srcK = srcAType.getShape()[srcAKIndex];
-
-    if (srcM % dstM != 0 || srcN % dstN != 0 || srcK % dstK != 0) {
-      return rewriter.notifyMatchFailure(contractOp, "shape cannot divide");
-    }
-
     Location loc = contractOp.getLoc();
     auto dstCType = srcCType.clone(dstCElemFType);
     auto extOp =
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index d252c78..856cc01 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -45,26 +45,6 @@
 
 // -----
 
-func.func @mfma_matmul_96x64x16_mm_cannot_divide(%lhs: vector<95x16xf16>, %rhs: vector<16x64xf16>, %init: vector<95x64xf16>) -> vector<95x64xf16> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
-      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
-    %0 = vector.contract {
-      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
-      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
-      %lhs, %rhs, %init : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
-  return %0 : vector<95x64xf16>
-}
-
-// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_divide
-//   CHECK-NOT:   arith.extf
-//       CHECK:   vector.contract
-//  CHECK-SAME:     %{{.+}}, %{{.+}}, %{{.+}} : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
-//   CHECK-NOT:   arith.truncf
-
-// -----
-
 func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
     mma_schedule = #iree_gpu.mma_schedule<
       intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index b8ee6d5..06c5aec 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -15,15 +15,15 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
 
 // -----
 
@@ -42,15 +42,15 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME:   subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   outer_order = [1, 0], thread_order = [1, 0]
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 32]>
+// CHECK-SAME:   element_order = [1, 0]
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
 
 // -----
 
@@ -117,15 +117,15 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
 
 // -----
 
@@ -178,15 +178,15 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
 
 // -----
 
@@ -243,15 +243,15 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
 
 // -----
 
@@ -270,12 +270,113 @@
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME:   subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [1, 32]>
+// CHECK-SAME:   outer_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
 // CHECK-SAME:   element_order = [1, 0],
-// CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [2, 16]>
+// CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+      subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 4, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+    workgroup_size = [64, 2, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+  return %0 : vector<2x64x64xf32>
+}
+
+//      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 4, 1],
+// CHECK-SAME:   outers_per_batch = [1, 1, 1],
+// CHECK-SAME:   threads_per_outer = [1, 16, 4],
+// CHECK-SAME:   elements_per_thread = [1, 1, 4],
+// CHECK-SAME:   thread_order = [0, 2, 1],
+// CHECK-SAME:   subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, true, false, true],
+// CHECK-SAME:   thread_basis = [1, 4, 16]>
+//      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [1, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 4],
+// CHECK-SAME:   outers_per_batch = [1, 1],
+// CHECK-SAME:   threads_per_outer = [4, 16],
+// CHECK-SAME:   elements_per_thread = [4, 1],
+// CHECK-SAME:   subgroup_order = [1, 0],
+// CHECK-SAME:   element_order = [1, 0],
+// CHECK-SAME:   subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [false, false, true, true],
+// CHECK-SAME:   thread_basis = [4, 16]>
+//      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 4, 4],
+// CHECK-SAME:   outers_per_batch = [1, 1, 1],
+// CHECK-SAME:   threads_per_outer = [1, 4, 16],
+// CHECK-SAME:   elements_per_thread = [1, 4, 1],
+// CHECK-SAME:   element_order = [0, 2, 1],
+// CHECK-SAME:   subgroup_basis = [2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, true, true, false],
+// CHECK-SAME:   thread_basis = [1, 4, 16]>
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+      subgroup_m_count = 4, subgroup_n_count = 1, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+    workgroup_size = [64, 2, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+  return %0 : vector<2x64x64xf32>
+}
+
+//      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 2, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 2, 1],
+// CHECK-SAME:   subgroup_basis = [2, 2, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, true, false, true]
+//      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 2, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 2, 4],
+// CHECK-SAME:   subgroup_basis = [2, 2, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, true, true, false]
+
+// -----
+
+func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+      subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 8, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
+    workgroup_size = [128, 2, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<4x64x16xf16>, vector<2x16x64xf16> into vector<4x2x64x64xf32>
+  return %0 : vector<4x2x64x64xf32>
+}
+
+//      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME:   batches_per_subgroup = [2, 4, 1],
+// CHECK-SAME:   subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, false, true, false, true]
+//      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 1, 1],
+// CHECK-SAME:   batches_per_subgroup = [1, 1, 4],
+// CHECK-SAME:   subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [false, true, false, true, true]
+//      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME:   subgroups_per_workgroup = [2, 2, 1, 1],
+// CHECK-SAME:   batches_per_subgroup = [2, 1, 4, 4],
+// CHECK-SAME:   subgroup_basis = [2, 2, 1, 1, 1],
+// CHECK-SAME:   subgroup_active_ids = [true, true, true, true, false]
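
A quick sanity check of the multi_m_and_n expectations under the greedy scheme
(our arithmetic, not text from the PR): the M dims d0/d2 have bounds 4/64, so
subgroup_m_count = 2 splits as [2, 1] and subgroup_m_tile_count = 8 as [2, 4];
the N dims d1/d3 have bounds 2/64, so subgroup_n_count = 2 splits as [2, 1] and
subgroup_n_tile_count = 4 as [1, 4]. Interleaving the M and N bases in C's dim
order (d0, d1, d2, d3) and appending the dummy k entry yields
subgroup_basis = [2, 2, 1, 1, 1], matching the CHECK lines above.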
diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
index eec2e7c..20c69c2 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
@@ -78,6 +78,7 @@
     deps = [
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:VectorDialect",
     ],
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
index cefe73e..0c22b45 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -72,6 +72,7 @@
   DEPS
     LLVMSupport
     MLIRIR
+    MLIRLinalgDialect
     MLIRSupport
     MLIRVectorDialect
   PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
index 826ba1d..c5be60b 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
@@ -6,62 +6,61 @@
 
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 
 namespace mlir::iree_compiler {
 
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getOperandMNIndex() const {
-  switch (opKind) {
-  case OpKind::MK_KN_MN:
-    return std::make_pair(0, 1);
-  case OpKind::MK_NK_MN:
-    return std::make_pair(0, 0);
-  case OpKind::UNKNOWN:
-    break;
-  }
-  return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getOperandMNIndex() const {
+  return std::make_pair(lhsMDims.back(), rhsNDims.back());
 }
 
 // Returns the (LHS K, RHS K) dimension index pair.
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getOperandKIndex() const {
-  switch (opKind) {
-  case OpKind::MK_KN_MN:
-    return std::make_pair(1, 0);
-  case OpKind::MK_NK_MN:
-    return std::make_pair(1, 1);
-  case OpKind::UNKNOWN:
-    break;
-  }
-  return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getOperandKIndex() const {
+  return std::make_pair(lhsKDim, rhsKDim);
 }
 
 // Returns the result (M, N) dimension index pair.
-std::optional<std::pair<int, int>>
-VectorContractOpInfo::getResultMNIndex() const {
-  switch (opKind) {
-  case OpKind::MK_KN_MN:
-  case OpKind::MK_NK_MN:
-    return std::make_pair(0, 1);
-  default:
-    break;
-  }
-  return std::nullopt;
+std::pair<int, int> VectorContractOpInfo::getResultMNIndex() const {
+  return std::make_pair(outMDims.back(), outNDims.back());
 }
 
 VectorContractOpInfo::OpKind
 VectorContractOpInfo::inferOpKind(MLIRContext *ctx,
-                                  SmallVector<AffineMap> maps) const {
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
-  AffineExpr m, n, k;
-  bindDims(ctx, m, n, k);
-  if (maps == infer({{m, k}, {k, n}, {m, n}}))
-    return OpKind::MK_KN_MN;
-  if (maps == infer({{m, k}, {n, k}, {m, n}}))
-    return OpKind::MK_NK_MN;
+                                  SmallVector<AffineMap> maps) {
+  if (contractionDims.k.size() != 1) {
+    return OpKind::UNKNOWN;
+  }
+
+  int64_t innerM = contractionDims.m.back();
+  int64_t innerN = contractionDims.n.back();
+  int64_t k = contractionDims.k.back();
+
+  int64_t lhsM = *maps[0].getResultPosition(getAffineDimExpr(innerM, ctx));
+  lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
+  int64_t rhsN = *maps[1].getResultPosition(getAffineDimExpr(innerN, ctx));
+  rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
+  int64_t outM = *maps[2].getResultPosition(getAffineDimExpr(innerM, ctx));
+  int64_t outN = *maps[2].getResultPosition(getAffineDimExpr(innerN, ctx));
+
+  for (auto m : contractionDims.m) {
+    lhsMDims.push_back(*maps[0].getResultPosition(getAffineDimExpr(m, ctx)));
+    outMDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(m, ctx)));
+  }
+  for (auto n : contractionDims.n) {
+    rhsNDims.push_back(*maps[1].getResultPosition(getAffineDimExpr(n, ctx)));
+    outNDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(n, ctx)));
+  }
+
+  if (outM < outN) {
+    if (lhsM < lhsKDim) {
+      if (rhsN < rhsKDim) {
+        return OpKind::MK_NK_MN;
+      }
+      return OpKind::MK_KN_MN;
+    }
+  }
   return OpKind::UNKNOWN;
 }
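
To illustrate the new dimension bookkeeping, a standalone sketch (plain C++
arrays standing in for AffineMaps, names hypothetical) that reproduces
lhsMDims/rhsNDims/outMDims/outNDims for the multi_m_and_n maps from the tests
above:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for AffineMap::getResultPosition: each map lists the iteration
// dims an operand indexes; the "position" is the index within that list.
int64_t posOf(const std::vector<int64_t> &map, int64_t dim) {
  for (int64_t i = 0, e = map.size(); i < e; ++i)
    if (map[i] == dim)
      return i;
  return -1;
}

int main() {
  // A = (d0, d2, d4), B = (d1, d4, d3), C = (d0, d1, d2, d3),
  // with m = {0, 2}, n = {1, 3}, k = {4}.
  std::vector<int64_t> a{0, 2, 4}, b{1, 4, 3}, c{0, 1, 2, 3};
  std::vector<int64_t> lhsMDims, outMDims, rhsNDims, outNDims;
  for (int64_t m : {0, 2}) {
    lhsMDims.push_back(posOf(a, m));
    outMDims.push_back(posOf(c, m));
  }
  for (int64_t n : {1, 3}) {
    rhsNDims.push_back(posOf(b, n));
    outNDims.push_back(posOf(c, n));
  }
  assert(lhsMDims == (std::vector<int64_t>{0, 1}));
  assert(outMDims == (std::vector<int64_t>{0, 2}));
  assert(rhsNDims == (std::vector<int64_t>{0, 2}));
  assert(outNDims == (std::vector<int64_t>{1, 3}));
  // getOperandMNIndex() -> (1, 2); getOperandKIndex() -> (2, 1);
  // getResultMNIndex() -> (2, 3): always the innermost M/N dims.
  assert(posOf(a, 4) == 2 && posOf(b, 4) == 1);
  return 0;
}
```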
 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
index be1ba22..3864322 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
@@ -4,7 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 
 namespace mlir::iree_compiler {
 
@@ -14,25 +16,49 @@
   enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };
 
   explicit VectorContractOpInfo(vector::ContractionOp op) {
+    contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
     opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
   }
 
   OpKind getOpKind() const { return opKind; }
 
   // Returns the (LHS M, RHS N) dimension index pair.
-  std::optional<std::pair<int, int>> getOperandMNIndex() const;
+  std::pair<int, int> getOperandMNIndex() const;
 
   // Returns the (LHS K, RHS K) dimension index pair.
-  std::optional<std::pair<int, int>> getOperandKIndex() const;
+  std::pair<int, int> getOperandKIndex() const;
 
   // Returns the result (M, N) dimension index pair.
-  std::optional<std::pair<int, int>> getResultMNIndex() const;
+  std::pair<int, int> getResultMNIndex() const;
+
+  SmallVector<unsigned, 2> getMDims() const { return contractionDims.m; }
+
+  SmallVector<unsigned, 2> getNDims() const { return contractionDims.n; }
+
+  int64_t getARank() {
+    return contractionDims.m.size() + contractionDims.k.size();
+  }
+  int64_t getBRank() {
+    return contractionDims.k.size() + contractionDims.n.size();
+  }
+  int64_t getCRank() {
+    return contractionDims.m.size() + contractionDims.n.size();
+  }
+
+  SmallVector<int64_t> lhsMDims;
+  int64_t lhsKDim;
+  SmallVector<int64_t> rhsNDims;
+  int64_t rhsKDim;
+  SmallVector<int64_t> outMDims;
+  SmallVector<int64_t> outNDims;
 
 private:
   // Gets the kind of a contract op with the given indexing |maps|.
-  OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps) const;
+  OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps);
 
   OpKind opKind = OpKind::UNKNOWN;
+
+  linalg::ContractionDimensions contractionDims;
 };
 
 } // namespace mlir::iree_compiler
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 6e59a85..a970a2c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -740,7 +740,7 @@
   p << ']';
   if (llvm::any_of(mask, [](bool b) { return !b; })) {
     p << ',' << ' ';
-    p << maskName;
+    p << maskName << ' ';
     p << '=';
     p << ' ';
     p << '[';
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 6e60022..5600905 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -110,29 +110,89 @@
 func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
   %cst_0 = arith.constant 0.0 : f16
   %c0 = arith.constant 0 : index
-  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
-  %2 = iree_vector_ext.layout_conflict_resolution %result {
-    sourceLayout = #nested_1,
-    desiredLayout = #nested_2,
-    otherLayout0 = #nested_3,
-    otherLayout1 = #nested_4,
-    otherLayout2 = #nested_5
-  } : vector<32x32xf16> -> vector<32x32xf16>
-  return %2 : vector<32x32xf16>
+  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {
+    in_bounds = [true, true],
+    layout0 = #nested_1,
+    layout1 = #nested_2,
+    layout2 = #nested_3,
+    layout3 = #nested_4,
+    layout4 = #nested_5
+  } : memref<32x32xf16>, vector<32x32xf16>
+  return %result : vector<32x32xf16>
 }
 
-// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [2, 4]>
-// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], outer_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [4, 2]>
-// CHECK-DAG: #[[LAYOUT2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4]>
-// CHECK-DAG: #[[LAYOUT3:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4, 2], thread_active_ids= [false, true, true]>
-// CHECK-DAG: #[[LAYOUT4:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4], thread_basis = [4, 2]>
+// CHECK: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [2, 4],
+// CHECK-SAME: outers_per_batch = [4, 1],
+// CHECK-SAME: threads_per_outer = [4, 2],
+// CHECK-SAME: elements_per_thread = [1, 4],
+// CHECK-SAME: outer_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1],
+// CHECK-SAME: thread_basis = [4, 2]>
+
+// CHECK: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1],
+// CHECK-SAME: thread_basis = [2, 4]>
+
+// CHECK: #[[LAYOUT2:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4, 8],
+// CHECK-SAME: subgroup_active_ids = [true, true, false],
+// CHECK-SAME: thread_basis = [2, 4]>
+
+// CHECK: #[[LAYOUT3:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4, 8],
+// CHECK-SAME: subgroup_active_ids = [true, true, false],
+// CHECK-SAME: thread_basis = [2, 4, 2],
+// CHECK-SAME: thread_active_ids = [false, true, true]>
+
+// CHECK: #[[LAYOUT4:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1],
+// CHECK-SAME: batches_per_subgroup = [4, 2],
+// CHECK-SAME: outers_per_batch = [1, 4],
+// CHECK-SAME: threads_per_outer = [2, 4],
+// CHECK-SAME: elements_per_thread = [4, 1],
+// CHECK-SAME: subgroup_order = [1, 0],
+// CHECK-SAME: batch_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [2, 4],
+// CHECK-SAME: thread_basis = [4, 2]>
+
 // CHECK-LABEL: func.func @specify_nested
-// CHECK:      iree_vector_ext.layout_conflict_resolution
-// CHECK-SAME:         desiredLayout = #[[LAYOUT0]]
-// CHECK-SAME:         otherLayout0 = #[[LAYOUT2]]
-// CHECK-SAME:         otherLayout1 = #[[LAYOUT3]]
-// CHECK-SAME:         otherLayout2 = #[[LAYOUT4]]
-// CHECK-SAME:         sourceLayout = #[[LAYOUT1]]
+// CHECK:      vector.transfer_read
+// CHECK-SAME:         layout0 = #[[LAYOUT0]]
+// CHECK-SAME:         layout1 = #[[LAYOUT1]]
+// CHECK-SAME:         layout2 = #[[LAYOUT2]]
+// CHECK-SAME:         layout3 = #[[LAYOUT3]]
+// CHECK-SAME:         layout4 = #[[LAYOUT4]]
 
 // -----