Revert "[VectorDistribute] Refactor layout configuration to a simpler logic" (#21887)

CI is failing. The PkgCI sharktank tests are not required on pre-submit (they
should be), so the breakage was not caught before merging.

Reverts iree-org/iree#21883
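
For context while reviewing: the configuration path restored by this revert greedily splits each M/N loop bound into subgroup and batch counts before assembling the nested layouts. The sketch below mirrors the `divideGreedily` and `getUnitOfRankWithDims` helpers from the diff as a standalone program (outside MLIR; the example bound, intrinsic size, and subgroup count are hypothetical) purely to illustrate the arithmetic. It is illustrative only and not part of the patch.

```cpp
// Illustrative-only sketch; the helper bodies mirror the restored code in the
// diff below, and the example values are made up.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Unit vector of `rank` ones with counts[0] placed at dim0 and counts[1] at
// dim1 (mirrors getUnitOfRankWithDims below).
static std::vector<int64_t> unitOfRankWithDims(int64_t rank,
                                               const std::vector<int64_t> &counts,
                                               int64_t dim0, int64_t dim1) {
  assert(counts.size() == 2 && "expected a 2-element single-subgroup tile");
  std::vector<int64_t> res(rank, 1);
  res[dim0] = counts[0];
  res[dim1] = counts[1];
  return res;
}

// Greedy split of one loop bound into (subgroups used, batches used) given
// the intrinsic tile size along that dimension (mirrors divideGreedily below).
static std::pair<int64_t, int64_t> divideGreedily(int64_t availableSubgroups,
                                                  int64_t dimSize,
                                                  int64_t minDimSize) {
  int64_t dividableDim = dimSize / minDimSize;
  int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
  return {subgroupsUsed, dividableDim / subgroupsUsed};
}

int main() {
  // Hypothetical M dim: bound 128, intrinsic M = 16, subgroup_m_count = 2.
  auto [subgroups, batches] = divideGreedily(2, 128, 16);
  assert(subgroups == 2 && batches == 4); // 2 subgroups * 4 batches * 16 = 128

  // rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1  ->  [1, 7, 5]
  std::vector<int64_t> tile = unitOfRankWithDims(3, {5, 7}, 2, 1);
  assert(tile[0] == 1 && tile[1] == 7 && tile[2] == 5);
  return 0;
}
```
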
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 98c0f5f..83768db 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -29,7 +29,6 @@
 #define GEN_PASS_DEF_LLVMGPUCONFIGURETENSORLAYOUTSPASS
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
 
-using IREE::GPU::MMASingleSubgroupLayout;
 using IREE::VectorExt::NestedLayoutAttr;
 using IREE::VectorExt::ToLayoutOp;
 using IREE::VectorExt::VectorLayoutInterface;
@@ -80,123 +79,81 @@
   return getSubgroupNCount(config).value();
 }
 
-struct ContractionLayout {
-  VectorLayoutInterface lhs;
-  VectorLayoutInterface rhs;
-  VectorLayoutInterface acc;
-};
-
-// Get the layouts to use for the contraction given the intrinsic to use and
-// number of subgroups on the M and N dimension.
-//
-// The contraction is expected to have 3 operands: lhs, rhs and acc of the
-// contraction and a single accumulator.
-static FailureOr<ContractionLayout>
-getContractionLayout(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
-                     int64_t subgroupMCount, int64_t subgroupNCount,
-                     ArrayRef<int64_t> bounds,
-                     ArrayRef<AffineMap> contractIndexingMaps) {
-  auto mmaIntrinsic = dyn_cast<IREE::GPU::MmaInterfaceAttr>(intrinsic);
-  if (!mmaIntrinsic) {
-    return failure();
-  }
-
-  int64_t rank = bounds.size();
-  FailureOr<VectorContractOpInfo> maybeOpInfo =
-      VectorContractOpInfo::inferFromIndexingMaps(contractIndexingMaps);
-  if (failed(maybeOpInfo)) {
-    return failure();
-  }
-  VectorContractOpInfo opInfo = maybeOpInfo.value();
-  // Get the inner dimensions.
-  int64_t innerMDim = opInfo.getMDims().back();
-  int64_t innerNDim = opInfo.getNDims().back();
-  int64_t innerKDim = opInfo.getKDims().back();
-  // Find the number of subgroups being used from subgroupMCount/subgroupNCount.
-  // Just assign them to the innermost dimension for now.
-  // TODO: Use subgroup_basis to get this instead and allow distributing
-  // subgroups on multiple dimensions.
-  SmallVector<int64_t> subgroupCounts(rank, 1);
-  SmallVector<int64_t> subgroupStrides(rank, 0);
-  subgroupCounts[innerMDim] = subgroupMCount;
-  subgroupCounts[innerNDim] = subgroupNCount;
-  // Distribute on M and then N.
-  // TODO: Use subgroup_basis to get the strides.
-  subgroupStrides[innerMDim] = 1;
-  subgroupStrides[innerNDim] = subgroupMCount;
-
-  // Since these MMA intrinsics have a given tile size for each subgroup, we can
-  // calculate the batch dimensions without looking at the subgroup layout.
-  SmallVector<int64_t> subgroupSize(rank, 1);
-  auto [mSize, nSize, kSize] = mmaIntrinsic.getMNKShape();
-  subgroupSize[innerMDim] = mSize;
-  subgroupSize[innerNDim] = nSize;
-  subgroupSize[innerKDim] = kSize;
-
-  SmallVector<int64_t> batchCounts(rank);
-  for (auto [batchCount, subgroupCount, subgroupSize, bound] :
-       llvm::zip_equal(batchCounts, subgroupCounts, subgroupSize, bounds)) {
-    int64_t workgroupDimSize = subgroupCount * subgroupSize;
-    batchCount = llvm::divideCeil(bound, workgroupDimSize);
-  }
-
-  // MMA intrinsics can be irregular and usually don't share a single subgroup
-  // iteration space, so we need to find the subgroup iteration space of each
-  // value individually.
-  auto getFragmentLayout = [&](IREE::GPU::MMAFragment fragment,
-                               int64_t outerDim, int64_t innerDim,
-                               AffineMap map) -> VectorLayoutInterface {
-    // Note that the struct MMASingleSubgroupLayout contains the partial layout
-    // for the canonical (M, K) x (K, N) -> (M, N) matmul form. We treat the
-    // concrete nested layout as the layout for the innermost M, N, K
-    // dimensions.
-    SmallVector<int64_t> outerCounts(rank, 1);
-    SmallVector<int64_t> elementCounts(rank, 1);
-    SmallVector<int64_t> threadCounts(rank, 1);
-    SmallVector<int64_t> threadStrides(rank, 0);
-
-    MMASingleSubgroupLayout subgroupLayout =
-        IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
-    outerCounts[outerDim] = subgroupLayout.outer[0];
-    outerCounts[innerDim] = subgroupLayout.outer[1];
-    threadCounts[outerDim] = subgroupLayout.thread[0];
-    threadCounts[innerDim] = subgroupLayout.thread[1];
-    threadStrides[outerDim] = subgroupLayout.tstrides[0];
-    threadStrides[innerDim] = subgroupLayout.tstrides[1];
-    elementCounts[outerDim] = subgroupLayout.element[0];
-    elementCounts[innerDim] = subgroupLayout.element[1];
-    // Get the fragment layout for the entire iteration space and then project
-    // it. This is significantly easier than trying to create a layout for each
-    // fragment itself.
-    auto fragmentSpaceLayout = NestedLayoutAttr::get(
-        map.getContext(), subgroupCounts, batchCounts, outerCounts,
-        threadCounts, elementCounts, subgroupStrides, threadStrides);
-    return fragmentSpaceLayout.apply(map);
-  };
-
-  VectorLayoutInterface lhs =
-      getFragmentLayout(IREE::GPU::MMAFragment::Lhs, innerMDim, innerKDim,
-                        contractIndexingMaps[0]);
-  VectorLayoutInterface rhs =
-      getFragmentLayout(IREE::GPU::MMAFragment::Rhs, innerKDim, innerNDim,
-                        contractIndexingMaps[1]);
-  VectorLayoutInterface acc =
-      getFragmentLayout(IREE::GPU::MMAFragment::Acc, innerMDim, innerNDim,
-                        contractIndexingMaps[2]);
-
-  return ContractionLayout{lhs, rhs, acc};
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+                                                  ArrayRef<int64_t> counts,
+                                                  int64_t dim0, int64_t dim1) {
+  assert(counts.size() == 2 &&
+         "Unexpected non-rank 2 single subgroup dimension counts");
+  SmallVector<int64_t> res(rank, 1);
+  res[dim0] = counts[0];
+  res[dim1] = counts[1];
+  return res;
 }
 
-SmallVector<int64_t> getIterationSpaceBounds(linalg::LinalgOp linalgOp) {
-  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
-  std::optional<VectorizationTileSizes> sizes =
-      inferSizesFromIR(linalgOp, std::nullopt);
-  // Even though the op shape could be dynamic, we may still be able to
-  // infer static vector sizes.
-  if (sizes.has_value()) {
-    bounds = sizes.value().vectorSizes;
-  }
-  return bounds;
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+///   outerDim = 1 for the K dim
+///   innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+///   outerDim = 1 for K
+///   innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here "innermost" is relative to the iteration order, not the order in
+/// which the dimensions appear in each fragment (there is no fixed
+/// relationship between the position of M in A and in C, for example).
+static NestedLayoutAttr createNestedLayout(
+    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+    ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+    ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Creating Nested Layout for::" << "\n    outerDim = "
+                 << outerDim << "\n    innerDim = " << innerDim
+                 << "\n    subgroupSizes: " << llvm::interleaved(subgroupSizes)
+                 << "\n    subgroupStrides: "
+                 << llvm::interleaved(subgroupStrides)
+                 << "\n    batchCount: " << llvm::interleaved(batchCount)
+                 << "\n    counts.outer: " << llvm::interleaved(counts.outer)
+                 << "\n    counts.thread: " << llvm::interleaved(counts.thread)
+                 << "\n    counts.element: "
+                 << llvm::interleaved(counts.element)
+                 << "\n    counts.tstrides: "
+                 << llvm::interleaved(counts.tstrides) << "\n";
+  });
+
+  SmallVector<int64_t> outerCount =
+      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+  SmallVector<int64_t> threadCount =
+      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+  SmallVector<int64_t> threadStrides =
+      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+  SmallVector<int64_t> elementCount =
+      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+                                          outerCount, threadCount, elementCount,
+                                          subgroupStrides, threadStrides);
+  return layoutAttr;
 }
 
 static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
@@ -212,10 +169,12 @@
   assert(linalg::isaContractionOpInterface(contract) &&
          "cannot set contraction anchor on non contraction op");
 
-  SmallVector<int64_t> bounds = getIterationSpaceBounds(contract);
-  auto layouts = getContractionLayout(
-      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
-      schedule.getSubgroupNCount(), bounds, contract.getIndexingMapsArray());
+  FailureOr<VectorContractOpInfo> opInfo =
+      VectorContractOpInfo::inferFromIndexingMaps(
+          contract.getIndexingMapsArray());
+  assert(succeeded(opInfo) && "contraction should have been inferred");
+
+  auto layouts = getContractionLayout(schedule, opInfo.value(), contract);
   if (failed(layouts)) {
     return contract->emitError("cannot get concrete layout for contraction");
   }
@@ -297,10 +256,15 @@
     map = projectDims(map, filterDims, /*compressDimsFlag=*/false);
   }
 
-  SmallVector<int64_t> bounds = getIterationSpaceBounds(conv);
-  auto layouts = getContractionLayout(
-      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
-      schedule.getSubgroupNCount(), bounds, maps);
+  FailureOr<VectorContractOpInfo> opInfo =
+      VectorContractOpInfo::inferFromIndexingMaps(maps);
+  assert(succeeded(opInfo) &&
+         "unit filter dim convolution should have been infered");
+
+  auto layouts = getContractionLayout(schedule, opInfo.value(), conv);
+  if (failed(layouts)) {
+    return conv->emitError("cannot get concrete layout for convolution");
+  }
 
   auto [aLayout, bLayout, cLayout] = *layouts;
   Location loc = conv.getLoc();
@@ -492,6 +456,26 @@
                               pvMatmul);
 }
 
+// Apply the permuted projection map to the layout.
+static IREE::VectorExt::VectorLayoutInterface
+getLayoutForMap(VectorLayoutInterface layout, AffineMap map) {
+  // Project out unusued dims in layout.
+  SmallVector<bool> projectedDims(layout.getRank(), false);
+  llvm::SmallBitVector unusedBits = getUnusedDimsBitVector(map);
+  for (int dim : unusedBits.set_bits()) {
+    projectedDims[dim] = true;
+  }
+  IREE::VectorExt::VectorLayoutInterface projectedLayout =
+      layout.project(projectedDims);
+
+  // Transpose dims in layout.
+  AffineMap permMap = compressUnusedDims(map);
+  SmallVector<int64_t> identity =
+      llvm::to_vector(llvm::seq<int64_t>(permMap.getNumDims()));
+  SmallVector<int64_t> perm = applyPermutationMap<int64_t>(permMap, identity);
+  return projectedLayout.permute(perm);
+}
+
 static LogicalResult setDerivedThreadConfigLayout(
     IREE::GPU::DerivedThreadConfigAttr config, linalg::LinalgOp linalgOp,
     ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
@@ -568,7 +552,7 @@
   rewriter.setInsertionPointAfter(linalgOp);
   for (OpResult result : linalgOp->getResults()) {
     VectorLayoutInterface resultLayout =
-        layout.apply(linalgOp.getIndexingMapMatchingResult(result));
+        getLayoutForMap(layout, linalgOp.getIndexingMapMatchingResult(result));
     auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
     rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
   }
@@ -691,7 +675,14 @@
   MLIRContext *context = config.getContext();
   Location loc = candidate.getLoc();
 
-  SmallVector<int64_t> bounds = getIterationSpaceBounds(candidate);
+  SmallVector<int64_t> bounds = candidate.getStaticLoopRanges();
+  std::optional<VectorizationTileSizes> sizes =
+      inferSizesFromIR(candidate, std::nullopt);
+  // Even though the op shape could be dynamic, we may still be able to
+  // infer static vector sizes.
+  if (sizes.has_value()) {
+    bounds = sizes.value().vectorSizes;
+  }
 
   // Subgroup distribution layouts.
   SmallVector<int64_t> subgroupSizes, subgroupStrides;
@@ -732,7 +723,7 @@
   rewriter.setInsertionPoint(candidate);
   for (OpOperand &operand : candidate->getOpOperands()) {
     VectorLayoutInterface operandLayout =
-        layout.apply(candidate.getMatchingIndexingMap(&operand));
+        getLayoutForMap(layout, candidate.getMatchingIndexingMap(&operand));
     auto toLayout =
         rewriter.create<ToLayoutOp>(loc, operand.get(), operandLayout);
     // Set shared memory promotion if requested.
@@ -744,7 +735,7 @@
   rewriter.setInsertionPointAfter(candidate);
   for (OpResult result : candidate->getResults()) {
     VectorLayoutInterface resultLayout =
-        layout.apply(candidate.getIndexingMapMatchingResult(result));
+        getLayoutForMap(layout, candidate.getIndexingMapMatchingResult(result));
     auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
     rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
index 6cc7285..d8b9fc4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
@@ -415,4 +415,335 @@
   packAllocs(builder, funcOp, aliasGroups);
 }
 
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+                                                  ArrayRef<int64_t> counts,
+                                                  int64_t dim0, int64_t dim1) {
+  assert(counts.size() == 2 &&
+         "Unexpected non-rank 2 single subgroup dimension counts");
+  SmallVector<int64_t> res(rank, 1);
+  res[dim0] = counts[0];
+  res[dim1] = counts[1];
+  return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+///   outerDim = 1 for the K dim
+///   innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+///   outerDim = 1 for K
+///   innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here "innermost" is relative to the iteration order, not the order in
+/// which the dimensions appear in each fragment (there is no fixed
+/// relationship between the position of M in A and in C, for example).
+static NestedLayoutAttr createNestedLayout(
+    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+    ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+    ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Creating Nested Layout for::";
+    llvm::dbgs() << "\n    outerDim = " << outerDim;
+    llvm::dbgs() << "\n    innerDim = " << innerDim;
+    llvm::dbgs() << "\n    subgroupSizes: ";
+    llvm::interleaveComma(subgroupSizes, llvm::dbgs());
+    llvm::dbgs() << "\n    subgroupStrides: ";
+    llvm::interleaveComma(subgroupStrides, llvm::dbgs());
+    llvm::dbgs() << "\n    batchCount: ";
+    llvm::interleaveComma(batchCount, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.outer: ";
+    llvm::interleaveComma(counts.outer, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.thread: ";
+    llvm::interleaveComma(counts.thread, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.element: ";
+    llvm::interleaveComma(counts.element, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.tstrides: ";
+    llvm::interleaveComma(counts.tstrides, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  SmallVector<int64_t> outerCount =
+      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+  SmallVector<int64_t> threadCount =
+      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+  SmallVector<int64_t> threadStrides =
+      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+  SmallVector<int64_t> elementCount =
+      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+                                          outerCount, threadCount, elementCount,
+                                          subgroupStrides, threadStrides);
+  return layoutAttr;
+}
+
+template <typename ContractOpTy>
+static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                            IREE::VectorExt::VectorLayoutInterface,
+                            IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayoutImpl(IREE::GPU::MMAScheduleAttr schedule,
+                         VectorContractOpInfo &opInfo,
+                         ContractOpTy contractOp) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
+    llvm::dbgs() << "For schedule: " << schedule << "\n";
+  });
+
+  int64_t rank = contractOp.getIteratorTypesArray().size();
+  auto mmaAttr =
+      llvm::dyn_cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
+  if (!mmaAttr) {
+    return failure();
+  }
+
+  MLIRContext *context = schedule.getContext();
+
+  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
+  if (llvm::any_of(bounds,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    return failure();
+  }
+
+  if (!llvm::all_of(opInfo.getBatchDims(),
+                    [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
+    LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
+    return failure();
+  }
+
+  // Get the concrete nested layout for each matrix. Note that the struct
+  // MMASingleSubgroupLayout contains the partial layout for the
+  // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
+  // contract op we are looking at right now may not be exactly in that form.
+  // So here we need to permute/transpose the canonical layout to match with
+  // the concrete contract op.
+
+  // Note that no matter how we permute/transpose the input contraction
+  // problem, the way we view the hardware warps remains the same--that is,
+  // from the hardware's perspective, a single warp has the same warp ID no
+  // matter what part of the contraction it works on. Similarly here, we are
+  // delinearizing the linearized GPU hardware lane ID into an n-D
+  // concatenated logical warp+thread index using the subgroup/thread basis,
+  // so the subgroup basis should remain the same for the A/B/C matrices.
+
+  auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
+
+  SmallVector<int64_t, 2> subgroupMBasis;
+  SmallVector<int64_t, 2> batchMSizes;
+  int64_t currMCount = schedule.getSubgroupMCount();
+
+  auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
+                           int64_t minDimSize) -> std::pair<int64_t, int64_t> {
+    int64_t dividableDim = dimSize / minDimSize;
+    int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
+    dividableDim /= subgroupsUsed;
+    int64_t batchesUsed = dividableDim;
+    return {subgroupsUsed, batchesUsed};
+  };
+
+  // Greedily break up the M subgroup and batch counts along the "M" iteration
+  // bounds. We distribute as many residual subgroups as possible per M dim,
+  // and then divide the remainder along the batch dims. The innermost M dim
+  // is always the one used for the intrinsic, meaning that for a valid
+  // schedule, the computed batch counts and subgroup basis satisfy
+  // totalMSize / intrinsicM = product(batchMSizes) * product(subgroupMBasis).
+  for (auto dim : opInfo.getMDims()) {
+    // Get the number of subgroups and batches used for this dimension based
+    // on the intrinsic size and the bound size.
+    int64_t subgroupsUsed, batchesUsed;
+    if (dim == opInfo.getMDims().back()) {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currMCount, bounds[dim], intrinsicM);
+    } else {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currMCount, bounds[dim], 1);
+    }
+    subgroupMBasis.push_back(subgroupsUsed);
+    batchMSizes.push_back(batchesUsed);
+    // Update available subgroup count.
+    currMCount /= subgroupsUsed;
+  }
+
+  SmallVector<int64_t, 2> subgroupNBasis;
+  SmallVector<int64_t, 2> batchNSizes;
+  int64_t currNCount = schedule.getSubgroupNCount();
+
+  // Do the same for N dims.
+  for (auto dim : opInfo.getNDims()) {
+    // Get the number of subgroups and batches used for this dimension based
+    // on the intrinsic size and the bound size.
+    int64_t subgroupsUsed, batchesUsed;
+    if (dim == opInfo.getNDims().back()) {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currNCount, bounds[dim], intrinsicN);
+    } else {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currNCount, bounds[dim], 1);
+    }
+    subgroupNBasis.push_back(subgroupsUsed);
+    batchNSizes.push_back(batchesUsed);
+    // Update available subgroup count.
+    currNCount /= subgroupsUsed;
+  }
+
+  SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
+  SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
+
+  auto mDimVec = opInfo.getMDims();
+  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+  auto nDimVec = opInfo.getNDims();
+  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+  // Because we currently require all batch dimensions to be unit, the
+  // subgroup basis can be constructed from the M and N bases. To keep things
+  // simple, the current heuristic is to distribute the loop dimensions from
+  // outer to inner.
+  int64_t currStride = 1;
+  int64_t currM = subgroupMStrides.size() - 1;
+  int64_t currN = subgroupNStrides.size() - 1;
+  for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
+    if (mDims.contains(dim)) {
+      subgroupMStrides[currM] = currStride;
+      currStride *= subgroupMBasis[currM];
+      currM--;
+      continue;
+    }
+
+    if (nDims.contains(dim)) {
+      subgroupNStrides[currN] = currStride;
+      currStride *= subgroupNBasis[currN];
+      currN--;
+      continue;
+    }
+  }
+
+  // C matrix layout
+  auto [m, n] = opInfo.getResultMNIndex();
+  int64_t cRank = opInfo.getCRank();
+
+  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+  // cNDims are the M and N dimensions of the C matrix in the order they are
+  // iterated over in the contraction.
+  SmallVector<int64_t> cMDims = opInfo.outMDims;
+  SmallVector<int64_t> cNDims = opInfo.outNDims;
+  SmallVector<int64_t> cBatchSizes(cRank, 1);
+  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+  SmallVector<int64_t> cSubgroupStrides(cRank, 0);
+  for (auto [i, dim] : llvm::enumerate(cMDims)) {
+    cBatchSizes[dim] = batchMSizes[i];
+    cSubgroupSizes[dim] = subgroupMBasis[i];
+    cSubgroupStrides[dim] = subgroupMStrides[i];
+  }
+  for (auto [i, dim] : llvm::enumerate(cNDims)) {
+    cBatchSizes[dim] = batchNSizes[i];
+    cSubgroupSizes[dim] = subgroupNBasis[i];
+    cSubgroupStrides[dim] = subgroupNStrides[i];
+  }
+
+  IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
+      context, cRank, m, n,
+      /*subgroupCount=*/cSubgroupSizes,
+      /*subgroupStrides=*/cSubgroupStrides,
+      /*batchCount=*/cBatchSizes,
+      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
+  LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
+
+  // A matrix layout
+  auto [afm, bfn] = opInfo.getOperandMNIndex();
+  auto [afk, bfk] = opInfo.getOperandKIndex();
+
+  int64_t aRank = opInfo.getARank();
+
+  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+  SmallVector<int64_t> aBatchSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupStrides(aRank, 0);
+  for (auto [i, dim] : llvm::enumerate(aMDims)) {
+    aBatchSizes[dim] = batchMSizes[i];
+    aSubgroupSizes[dim] = subgroupMBasis[i];
+    aSubgroupStrides[dim] = subgroupMStrides[i];
+  }
+  for (auto [kDim, lhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
+    aBatchSizes[lhsKDim] = bounds[kDim];
+  }
+  aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+  IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
+      context, aRank, afm, afk,
+      /*subgroupCount=*/aSubgroupSizes,
+      /*subgroupStrides=*/aSubgroupStrides,
+      /*batchCount=*/aBatchSizes,
+      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
+  LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
+
+  int64_t bRank = opInfo.getBRank();
+
+  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+  SmallVector<int64_t> bBatchSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupStrides(bRank, 0);
+  for (auto [i, dim] : llvm::enumerate(bNDims)) {
+    bBatchSizes[dim] = batchNSizes[i];
+    bSubgroupSizes[dim] = subgroupNBasis[i];
+    bSubgroupStrides[dim] = subgroupNStrides[i];
+  }
+  for (auto [kDim, rhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
+    bBatchSizes[rhsKDim] = bounds[kDim];
+  }
+  bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+  IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
+      context, bRank, bfk, bfn,
+      /*subgroupCount=*/bSubgroupSizes,
+      /*subgroupStrides=*/bSubgroupStrides,
+      /*batchCount=*/bBatchSizes,
+      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
+  LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
+
+  std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+             VectorLayoutInterface>
+      result = {aLayout, bLayout, cLayout};
+  return result;
+}
+
+/// Template specializations
+::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface>>
+IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+                                VectorContractOpInfo &opInfo,
+                                linalg::LinalgOp contractOp) {
+  return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
+}
+
+::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface>>
+IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+                                VectorContractOpInfo &opInfo,
+                                vector::ContractionOp contractOp) {
+  return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 617b774..061e139 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -62,6 +62,24 @@
 void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc,
                 ArrayRef<Operation *> aliasGroup, bool hasAsyncCopies = true);
 
+namespace IREE {
+namespace GPU {
+class MMAScheduleAttr;
+
+::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+                     VectorContractOpInfo &opInfo, linalg::LinalgOp contractOp);
+
+::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface,
+                               IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
+                     VectorContractOpInfo &opInfo,
+                     vector::ContractionOp contractOp);
+} // namespace GPU
+} // namespace IREE
 } // namespace iree_compiler
 } // namespace mlir
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 074664c..a419f4a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -207,6 +207,57 @@
                                               workgroup_size = [64, 1, 1]
                                               subgroup_size = 64>
 
+#maps = [
+  affine_map<(bm, bn, m, n, k) -> (bm, m, k)>,
+  affine_map<(bm, bn, m, n, k) -> (bn, n, k)>,
+  affine_map<(bm, bn, m, n, k) -> (bm, m, bn, n)>
+]
+
+#traits = {
+  indexing_maps = #maps,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
+  lowering_config = #iree_gpu.lowering_config<{promote_operands = [0],
+                                              mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+                                              subgroup_m_count = 2, subgroup_n_count = 2}>
+}
+
+func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
+                                     %rhs: tensor<8x16x16xf16>,
+                                     %init: tensor<8x16x8x16xf32>)
+                                     -> tensor<8x16x8x16xf32>
+                           attributes { translation_info = #translation } {
+  %out = linalg.generic #traits
+                        ins(%lhs, %rhs: tensor<8x16x16xf16>, tensor<8x16x16xf16>)
+                        outs(%init: tensor<8x16x8x16xf32>) {
+    ^bb0(%in: f16, %in_1: f16, %out: f32):
+      %ex   = arith.extf %in   : f16 to f32
+      %ex_1 = arith.extf %in_1 : f16 to f32
+      %mul  = arith.mulf %ex, %ex_1 : f32
+      %sum  = arith.addf %out, %mul : f32
+      linalg.yield %sum : f32
+  } -> tensor<8x16x8x16xf32>
+  return %out : tensor<8x16x8x16xf32>
+}
+
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [2, 0, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [1, 0, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
+// CHECK-LABEL: func.func @packed_matmul_128x128x128
+
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[ACC]]
+
+// -----
+
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64>
+
 func.func @linalg_copy(%in : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
                       attributes { translation_info = #translation } {
   %empty = tensor.empty() : tensor<16x16x16xf16>
diff --git a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
index b35fc42..3d14b1d 100644
--- a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
+++ b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
@@ -27,7 +27,7 @@
   transform.iree.match.cast_compatible_type %in0 = tensor<?x?x?x?xf16> : !transform.any_value
 
   %config = transform.param.constant #iree_codegen.compilation_info<
-          lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
+          lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
           translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
                                                             workgroup_size = [64, 4]
                                                             subgroup_size = 64 ,