[VectorDistribute] Refactor layout configuration to a simpler logic (#21883)

This patch refactors the layout configuration for MMA intrinsics into a
much simpler implementation and documents it better. The new
implementation will eventually allow us to stop using the
subgroup_m_count/subgroup_n_count attributes and instead use
subgroup_basis to control the number of subgroups on each dimension.

This patch is mostly NFC, except that it removes the ability to
automatically distribute subgroups greedily across multiple dimensions.
That path was not used or maintained. We will regain this ability in
follow-up patches once we switch to subgroup_basis.

The deleted test checked the greedy distribution across multiple
subgroups, which was not used anywhere else.

The tuning spec is changed because it was actually wrong and the
compiler was silently accepting the invalid configuration. The new
implementation crashes on that invalid specification.
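To illustrate the simpler logic, below is a minimal standalone sketch of how the per-dimension batch counts fall out of the iteration-space bounds, the intrinsic tile sizes, and the subgroup counts. This is plain C++ with hypothetical example values, not the IREE/MLIR API; the real computation lives in `getContractionLayout` in the diff.

```cpp
// Illustrative sketch only: mirrors the batchCounts loop in
// getContractionLayout, using plain C++ instead of MLIR types.
// Example values are hypothetical: a 128x128x128 matmul mapped onto a
// 16x16x16 intrinsic with 2 subgroups on M and 2 subgroups on N.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> bounds = {128, 128, 128};     // M, N, K iteration bounds
  std::vector<int64_t> subgroupCounts = {2, 2, 1};   // subgroups on M, N, K
  std::vector<int64_t> subgroupSizes = {16, 16, 16}; // intrinsic M, N, K tile

  for (size_t d = 0; d < bounds.size(); ++d) {
    // Each workgroup covers subgroupCount * intrinsicSize elements per step,
    // so the batch count is the (ceiling) number of such steps.
    int64_t workgroupDimSize = subgroupCounts[d] * subgroupSizes[d];
    int64_t batchCount = (bounds[d] + workgroupDimSize - 1) / workgroupDimSize;
    std::cout << "dim " << d << ": batchCount = " << batchCount << "\n";
  }
  // Prints 4, 4, 8: each subgroup owns a 4x4 grid of 16x16 output tiles and
  // steps through 8 batches along K.
  return 0;
}
```

The per-fragment layouts are then obtained by building a single `NestedLayoutAttr` for the full iteration space and applying each operand's indexing map to it, instead of constructing each fragment layout by hand.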
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 83768db..98c0f5f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -29,6 +29,7 @@
 #define GEN_PASS_DEF_LLVMGPUCONFIGURETENSORLAYOUTSPASS
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
 
+using IREE::GPU::MMASingleSubgroupLayout;
 using IREE::VectorExt::NestedLayoutAttr;
 using IREE::VectorExt::ToLayoutOp;
 using IREE::VectorExt::VectorLayoutInterface;
@@ -79,81 +80,123 @@
   return getSubgroupNCount(config).value();
 }
 
-/// Gets a unit vector of the given rank, but fills in the given dimensions
-/// from the 2 element array |counts|. |dim0| is the position in the returned
-/// vector to put the first element of |counts|, and |dim1| is the position to
-/// put the second element. For example,
-///
-/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
-/// returns [1, 5, 7]
-static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
-                                                  ArrayRef<int64_t> counts,
-                                                  int64_t dim0, int64_t dim1) {
-  assert(counts.size() == 2 &&
-         "Unexpected non-rank 2 single subgroup dimension counts");
-  SmallVector<int64_t> res(rank, 1);
-  res[dim0] = counts[0];
-  res[dim1] = counts[1];
-  return res;
+struct ContractionLayout {
+  VectorLayoutInterface lhs;
+  VectorLayoutInterface rhs;
+  VectorLayoutInterface acc;
+};
+
+// Get the layouts to use for the contraction given the intrinsic to use and
+// the number of subgroups on the M and N dimensions.
+//
+// The contraction is expected to have 3 operands: the lhs, the rhs, and a
+// single accumulator (acc).
+static FailureOr<ContractionLayout>
+getContractionLayout(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
+                     int64_t subgroupMCount, int64_t subgroupNCount,
+                     ArrayRef<int64_t> bounds,
+                     ArrayRef<AffineMap> contractIndexingMaps) {
+  auto mmaIntrinsic = dyn_cast<IREE::GPU::MmaInterfaceAttr>(intrinsic);
+  if (!mmaIntrinsic) {
+    return failure();
+  }
+
+  int64_t rank = bounds.size();
+  FailureOr<VectorContractOpInfo> maybeOpInfo =
+      VectorContractOpInfo::inferFromIndexingMaps(contractIndexingMaps);
+  if (failed(maybeOpInfo)) {
+    return failure();
+  }
+  VectorContractOpInfo opInfo = maybeOpInfo.value();
+  // Get the inner dimensions.
+  int64_t innerMDim = opInfo.getMDims().back();
+  int64_t innerNDim = opInfo.getNDims().back();
+  int64_t innerKDim = opInfo.getKDims().back();
+  // Find the number of subgroups being used from subgroupMCount/subgroupNCount.
+  // Just assign them to the innermost M and N dimensions for now.
+  // TODO: Use subgroup_basis to get this instead and allow distributing
+  // subgroups on multiple dimensions.
+  SmallVector<int64_t> subgroupCounts(rank, 1);
+  SmallVector<int64_t> subgroupStrides(rank, 0);
+  subgroupCounts[innerMDim] = subgroupMCount;
+  subgroupCounts[innerNDim] = subgroupNCount;
+  // Distribute on M and then N.
+  // TODO: Use subgroup_basis to get the strides.
+  subgroupStrides[innerMDim] = 1;
+  subgroupStrides[innerNDim] = subgroupMCount;
+
+  // Since these MMA intrinsics have a given tile size for each subgroup, we can
+  // calculate the batch dimensions without looking at the subgroup layout.
+  SmallVector<int64_t> subgroupSize(rank, 1);
+  auto [mSize, nSize, kSize] = mmaIntrinsic.getMNKShape();
+  subgroupSize[innerMDim] = mSize;
+  subgroupSize[innerNDim] = nSize;
+  subgroupSize[innerKDim] = kSize;
+
+  SmallVector<int64_t> batchCounts(rank);
+  for (auto [batchCount, subgroupCount, subgroupSize, bound] :
+       llvm::zip_equal(batchCounts, subgroupCounts, subgroupSize, bounds)) {
+    int64_t workgroupDimSize = subgroupCount * subgroupSize;
+    batchCount = llvm::divideCeil(bound, workgroupDimSize);
+  }
+
+  // MMA intrinsics can be weird and usually don't have a single subgroup
+  // iteration space, so we need to find each value's subgroup iteration space
+  // individually.
+  auto getFragmentLayout = [&](IREE::GPU::MMAFragment fragment,
+                               int64_t outerDim, int64_t innerDim,
+                               AffineMap map) -> VectorLayoutInterface {
+    // Note that the struct MMASingleSubgroupLayout contains the partial layout
+    // for the canonical (M, K) x (K, N) -> (M, N) matmul form. We treat the
+    // concrete nested layout as the layout for the innermost M, N, K
+    // dimensions.
+    SmallVector<int64_t> outerCounts(rank, 1);
+    SmallVector<int64_t> elementCounts(rank, 1);
+    SmallVector<int64_t> threadCounts(rank, 1);
+    SmallVector<int64_t> threadStrides(rank, 0);
+
+    MMASingleSubgroupLayout subgroupLayout =
+        IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
+    outerCounts[outerDim] = subgroupLayout.outer[0];
+    outerCounts[innerDim] = subgroupLayout.outer[1];
+    threadCounts[outerDim] = subgroupLayout.thread[0];
+    threadCounts[innerDim] = subgroupLayout.thread[1];
+    threadStrides[outerDim] = subgroupLayout.tstrides[0];
+    threadStrides[innerDim] = subgroupLayout.tstrides[1];
+    elementCounts[outerDim] = subgroupLayout.element[0];
+    elementCounts[innerDim] = subgroupLayout.element[1];
+    // Get the fragment layout for the entire iteration space and then project
+    // it. This is significantly easier than trying to create a layout for each
+    // fragment itself.
+    auto fragmentSpaceLayout = NestedLayoutAttr::get(
+        map.getContext(), subgroupCounts, batchCounts, outerCounts,
+        threadCounts, elementCounts, subgroupStrides, threadStrides);
+    return fragmentSpaceLayout.apply(map);
+  };
+
+  VectorLayoutInterface lhs =
+      getFragmentLayout(IREE::GPU::MMAFragment::Lhs, innerMDim, innerKDim,
+                        contractIndexingMaps[0]);
+  VectorLayoutInterface rhs =
+      getFragmentLayout(IREE::GPU::MMAFragment::Rhs, innerKDim, innerNDim,
+                        contractIndexingMaps[1]);
+  VectorLayoutInterface acc =
+      getFragmentLayout(IREE::GPU::MMAFragment::Acc, innerMDim, innerNDim,
+                        contractIndexingMaps[2]);
+
+  return ContractionLayout{lhs, rhs, acc};
 }
 
-/// Constructs the nested layout given the layout for a single subgroup and the
-/// subgroup/batch counts and orders, as well as the dimensions along which to
-/// distribute the intrinsic's layout.
-///
-/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
-/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
-/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
-/// we would have:
-///   outerDim = 1 for the K dim
-///   innerDim = 0 for the N dim
-///
-/// For something like MK_NKN_MN with multiple N dims, it would typically be:
-///   outerDim = 1 for K
-///   innerDim = 2 for the second N dim
-///
-/// Importantly these two dimensions always refer to the actual dimension
-/// positions in the undistributed vector. For each fragment, this means:
-///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
-///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
-///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
-///
-/// And here inner most is referential to the iteration order, not the order
-/// they appear per fragment (because there is no relationship between the
-/// dimension order of M in A and in C, for example).
-static NestedLayoutAttr createNestedLayout(
-    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
-    ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
-    ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "Creating Nested Layout for::" << "\n    outerDim = "
-                 << outerDim << "\n    innerDim = " << innerDim
-                 << "\n    subgroupSizes: " << llvm::interleaved(subgroupSizes)
-                 << "\n    subgroupStrides: "
-                 << llvm::interleaved(subgroupStrides)
-                 << "\n    batchCount: " << llvm::interleaved(batchCount)
-                 << "\n    counts.outer: " << llvm::interleaved(counts.outer)
-                 << "\n    counts.thread: " << llvm::interleaved(counts.thread)
-                 << "\n    counts.element: "
-                 << llvm::interleaved(counts.element)
-                 << "\n    counts.tstrides: "
-                 << llvm::interleaved(counts.tstrides) << "\n";
-  });
-
-  SmallVector<int64_t> outerCount =
-      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
-  SmallVector<int64_t> threadCount =
-      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
-  SmallVector<int64_t> threadStrides =
-      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
-  SmallVector<int64_t> elementCount =
-      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
-
-  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
-                                          outerCount, threadCount, elementCount,
-                                          subgroupStrides, threadStrides);
-  return layoutAttr;
+SmallVector<int64_t> getIterationSpaceBounds(linalg::LinalgOp linalgOp) {
+  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  std::optional<VectorizationTileSizes> sizes =
+      inferSizesFromIR(linalgOp, std::nullopt);
+  // Even though the op shape could be dynamic, we can potentially
+  // infer the vector shape from the IR.
+  if (sizes.has_value()) {
+    bounds = sizes.value().vectorSizes;
+  }
+  return bounds;
 }
 
 static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
@@ -169,12 +212,10 @@
   assert(linalg::isaContractionOpInterface(contract) &&
          "cannot set contraction anchor on non contraction op");
 
-  FailureOr<VectorContractOpInfo> opInfo =
-      VectorContractOpInfo::inferFromIndexingMaps(
-          contract.getIndexingMapsArray());
-  assert(succeeded(opInfo) && "contraction should have been inferred");
-
-  auto layouts = getContractionLayout(schedule, opInfo.value(), contract);
+  SmallVector<int64_t> bounds = getIterationSpaceBounds(contract);
+  auto layouts = getContractionLayout(
+      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
+      schedule.getSubgroupNCount(), bounds, contract.getIndexingMapsArray());
   if (failed(layouts)) {
     return contract->emitError("cannot get concrete layout for contraction");
   }
@@ -256,15 +297,10 @@
     map = projectDims(map, filterDims, /*compressDimsFlag=*/false);
   }
 
-  FailureOr<VectorContractOpInfo> opInfo =
-      VectorContractOpInfo::inferFromIndexingMaps(maps);
-  assert(succeeded(opInfo) &&
-         "unit filter dim convolution should have been infered");
-
-  auto layouts = getContractionLayout(schedule, opInfo.value(), conv);
-  if (failed(layouts)) {
-    return conv->emitError("cannot get concrete layout for convolution");
-  }
+  SmallVector<int64_t> bounds = getIterationSpaceBounds(conv);
+  auto layouts = getContractionLayout(
+      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
+      schedule.getSubgroupNCount(), bounds, maps);
 
   auto [aLayout, bLayout, cLayout] = *layouts;
   Location loc = conv.getLoc();
@@ -456,26 +492,6 @@
                               pvMatmul);
 }
 
-// Apply the permuted projection map to the layout.
-static IREE::VectorExt::VectorLayoutInterface
-getLayoutForMap(VectorLayoutInterface layout, AffineMap map) {
-  // Project out unusued dims in layout.
-  SmallVector<bool> projectedDims(layout.getRank(), false);
-  llvm::SmallBitVector unusedBits = getUnusedDimsBitVector(map);
-  for (int dim : unusedBits.set_bits()) {
-    projectedDims[dim] = true;
-  }
-  IREE::VectorExt::VectorLayoutInterface projectedLayout =
-      layout.project(projectedDims);
-
-  // Transpose dims in layout.
-  AffineMap permMap = compressUnusedDims(map);
-  SmallVector<int64_t> identity =
-      llvm::to_vector(llvm::seq<int64_t>(permMap.getNumDims()));
-  SmallVector<int64_t> perm = applyPermutationMap<int64_t>(permMap, identity);
-  return projectedLayout.permute(perm);
-}
-
 static LogicalResult setDerivedThreadConfigLayout(
     IREE::GPU::DerivedThreadConfigAttr config, linalg::LinalgOp linalgOp,
     ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
@@ -552,7 +568,7 @@
   rewriter.setInsertionPointAfter(linalgOp);
   for (OpResult result : linalgOp->getResults()) {
     VectorLayoutInterface resultLayout =
-        getLayoutForMap(layout, linalgOp.getIndexingMapMatchingResult(result));
+        layout.apply(linalgOp.getIndexingMapMatchingResult(result));
     auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
     rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
   }
@@ -675,14 +691,7 @@
   MLIRContext *context = config.getContext();
   Location loc = candidate.getLoc();
 
-  SmallVector<int64_t> bounds = candidate.getStaticLoopRanges();
-  std::optional<VectorizationTileSizes> sizes =
-      inferSizesFromIR(candidate, std::nullopt);
-  // Even though the opShape could be dynamic, we could potentially
-  // infer the vector shape
-  if (sizes.has_value()) {
-    bounds = sizes.value().vectorSizes;
-  }
+  SmallVector<int64_t> bounds = getIterationSpaceBounds(candidate);
 
   // Subgroup distribution layouts.
   SmallVector<int64_t> subgroupSizes, subgroupStrides;
@@ -723,7 +732,7 @@
   rewriter.setInsertionPoint(candidate);
   for (OpOperand &operand : candidate->getOpOperands()) {
     VectorLayoutInterface operandLayout =
-        getLayoutForMap(layout, candidate.getMatchingIndexingMap(&operand));
+        layout.apply(candidate.getMatchingIndexingMap(&operand));
     auto toLayout =
         rewriter.create<ToLayoutOp>(loc, operand.get(), operandLayout);
     // Set shared memory promotion if requested.
@@ -735,7 +744,7 @@
   rewriter.setInsertionPointAfter(candidate);
   for (OpResult result : candidate->getResults()) {
     VectorLayoutInterface resultLayout =
-        getLayoutForMap(layout, candidate.getIndexingMapMatchingResult(result));
+        layout.apply(candidate.getIndexingMapMatchingResult(result));
     auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
     rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
index d8b9fc4..6cc7285 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
@@ -415,335 +415,4 @@
   packAllocs(builder, funcOp, aliasGroups);
 }
 
-/// Gets a unit vector of the given rank, but fills in the given dimensions
-/// from the 2 element array |counts|. |dim0| is the position in the returned
-/// vector to put the first element of |counts|, and |dim1| is the position to
-/// put the second element. For example,
-///
-/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
-/// returns [1, 5, 7]
-static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
-                                                  ArrayRef<int64_t> counts,
-                                                  int64_t dim0, int64_t dim1) {
-  assert(counts.size() == 2 &&
-         "Unexpected non-rank 2 single subgroup dimension counts");
-  SmallVector<int64_t> res(rank, 1);
-  res[dim0] = counts[0];
-  res[dim1] = counts[1];
-  return res;
-}
-
-/// Constructs the nested layout given the layout for a single subgroup and the
-/// subgroup/batch counts and orders, as well as the dimensions along which to
-/// distribute the intrinsic's layout.
-///
-/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
-/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
-/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
-/// we would have:
-///   outerDim = 1 for the K dim
-///   innerDim = 0 for the N dim
-///
-/// For something like MK_NKN_MN with multiple N dims, it would typically be:
-///   outerDim = 1 for K
-///   innerDim = 2 for the second N dim
-///
-/// Importantly these two dimensions always refer to the actual dimension
-/// positions in the undistributed vector. For each fragment, this means:
-///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
-///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
-///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
-///
-/// And here inner most is referential to the iteration order, not the order
-/// they appear per fragment (because there is no relationship between the
-/// dimension order of M in A and in C, for example).
-static NestedLayoutAttr createNestedLayout(
-    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
-    ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
-    ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "Creating Nested Layout for::";
-    llvm::dbgs() << "\n    outerDim = " << outerDim;
-    llvm::dbgs() << "\n    innerDim = " << innerDim;
-    llvm::dbgs() << "\n    subgroupSizes: ";
-    llvm::interleaveComma(subgroupSizes, llvm::dbgs());
-    llvm::dbgs() << "\n    subgroupStrides: ";
-    llvm::interleaveComma(subgroupStrides, llvm::dbgs());
-    llvm::dbgs() << "\n    batchCount: ";
-    llvm::interleaveComma(batchCount, llvm::dbgs());
-    llvm::dbgs() << "\n    counts.outer: ";
-    llvm::interleaveComma(counts.outer, llvm::dbgs());
-    llvm::dbgs() << "\n    counts.thread: ";
-    llvm::interleaveComma(counts.thread, llvm::dbgs());
-    llvm::dbgs() << "\n    counts.element: ";
-    llvm::interleaveComma(counts.element, llvm::dbgs());
-    llvm::dbgs() << "\n    counts.tstrides: ";
-    llvm::interleaveComma(counts.tstrides, llvm::dbgs());
-    llvm::dbgs() << "\n";
-  });
-
-  SmallVector<int64_t> outerCount =
-      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
-  SmallVector<int64_t> threadCount =
-      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
-  SmallVector<int64_t> threadStrides =
-      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
-  SmallVector<int64_t> elementCount =
-      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
-
-  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
-                                          outerCount, threadCount, elementCount,
-                                          subgroupStrides, threadStrides);
-  return layoutAttr;
-}
-
-template <typename ContractOpTy>
-static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                            IREE::VectorExt::VectorLayoutInterface,
-                            IREE::VectorExt::VectorLayoutInterface>>
-getContractionLayoutImpl(IREE::GPU::MMAScheduleAttr schedule,
-                         VectorContractOpInfo &opInfo,
-                         ContractOpTy contractOp) {
-  LLVM_DEBUG({
-    llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
-    llvm::dbgs() << "For schedule: " << schedule << "\n";
-  });
-
-  int64_t rank = contractOp.getIteratorTypesArray().size();
-  auto mmaAttr =
-      llvm::dyn_cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
-  if (!mmaAttr) {
-    return failure();
-  }
-
-  MLIRContext *context = schedule.getContext();
-
-  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
-  if (llvm::any_of(bounds,
-                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
-    return failure();
-  }
-
-  if (!llvm::all_of(opInfo.getBatchDims(),
-                    [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
-    LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
-    return failure();
-  }
-
-  // Get the concrete nested layout for each matrix. Note that the struct
-  // MMASingleSubgroupLayout contains the partial layout for the
-  // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
-  // contract op we are looking at right now may not be exactly in that form.
-  // So here we need to permute/transpose the canonical layout to match with
-  // the concrete contract op.
-
-  // Note that no matter how we permute/transpose the input contraction
-  // problem, the way we view the hardware warps remain the same--that is,
-  // from the hardware's perspective, a single warp has the same warp ID no
-  // matter what part of the contraction it works on. Similarly here, we are
-  // delinearizing the linearized GPU hardware lane ID into a n-D concatenated
-  // logical warp+thread using the subgroup/thread basis, so the subgroup
-  // basis should remain the same for all A/B/C matrix.
-
-  auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
-
-  SmallVector<int64_t, 2> subgroupMBasis;
-  SmallVector<int64_t, 2> batchMSizes;
-  int64_t currMCount = schedule.getSubgroupMCount();
-
-  auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
-                           int64_t minDimSize) -> std::pair<int64_t, int64_t> {
-    int64_t dividableDim = dimSize / minDimSize;
-    int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
-    dividableDim /= subgroupsUsed;
-    int64_t batchesUsed = dividableDim;
-    return {subgroupsUsed, batchesUsed};
-  };
-
-  // Greedily break up the M subgroup and batch counts along the "M" iteration
-  // bounds. We distribute as many residual subgroups as possible per M dim,
-  // and then divide the remaining along batch dims. The inner most M dim is
-  // always the one used for the intrinsic, meaning for a valid schedule, the
-  // computed batch counts and subgroup basis will satisfy totalMSize /
-  // intrinsicM = product(batchMSizes) * product(subgroupMBasis)
-  for (auto dim : opInfo.getMDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getMDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], intrinsicM);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], 1);
-    }
-    subgroupMBasis.push_back(subgroupsUsed);
-    batchMSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currMCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t, 2> subgroupNBasis;
-  SmallVector<int64_t, 2> batchNSizes;
-  int64_t currNCount = schedule.getSubgroupNCount();
-
-  // Do the same for N dims.
-  for (auto dim : opInfo.getNDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getNDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], intrinsicN);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], 1);
-    }
-    subgroupNBasis.push_back(subgroupsUsed);
-    batchNSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currNCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
-  SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
-
-  auto mDimVec = opInfo.getMDims();
-  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
-  auto nDimVec = opInfo.getNDims();
-  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
-  // Because we currently require all batch dimensions to be unit, the
-  // subgroup basis can be constructed from the M and N bases. To keep things
-  // simple, the current heuristic is to distribute the loop dimensions from
-  // outer to inner.
-  int64_t currStride = 1;
-  int64_t currM = subgroupMStrides.size() - 1;
-  int64_t currN = subgroupNStrides.size() - 1;
-  for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
-    if (mDims.contains(dim)) {
-      subgroupMStrides[currM] = currStride;
-      currStride *= subgroupMBasis[currM];
-      currM--;
-      continue;
-    }
-
-    if (nDims.contains(dim)) {
-      subgroupNStrides[currN] = currStride;
-      currStride *= subgroupNBasis[currN];
-      currN--;
-      continue;
-    }
-  }
-
-  // C matrix layout
-  auto [m, n] = opInfo.getResultMNIndex();
-  int64_t cRank = opInfo.getCRank();
-
-  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
-  // cNDims are the M and N dimensions of the C matrix in the order they are
-  // iterated over in the contraction.
-  SmallVector<int64_t> cMDims = opInfo.outMDims;
-  SmallVector<int64_t> cNDims = opInfo.outNDims;
-  SmallVector<int64_t> cBatchSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupStrides(cRank, 0);
-  for (auto [i, dim] : llvm::enumerate(cMDims)) {
-    cBatchSizes[dim] = batchMSizes[i];
-    cSubgroupSizes[dim] = subgroupMBasis[i];
-    cSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [i, dim] : llvm::enumerate(cNDims)) {
-    cBatchSizes[dim] = batchNSizes[i];
-    cSubgroupSizes[dim] = subgroupNBasis[i];
-    cSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-
-  IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
-      context, cRank, m, n,
-      /*subgroupCount=*/cSubgroupSizes,
-      /*subgroupStrides=*/cSubgroupStrides,
-      /*batchCount=*/cBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
-  LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
-
-  // A matrix layout
-  auto [afm, bfn] = opInfo.getOperandMNIndex();
-  auto [afk, bfk] = opInfo.getOperandKIndex();
-
-  int64_t aRank = opInfo.getARank();
-
-  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
-  SmallVector<int64_t> aBatchSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupStrides(aRank, 0);
-  for (auto [i, dim] : llvm::enumerate(aMDims)) {
-    aBatchSizes[dim] = batchMSizes[i];
-    aSubgroupSizes[dim] = subgroupMBasis[i];
-    aSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [kDim, lhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
-    aBatchSizes[lhsKDim] = bounds[kDim];
-  }
-  aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
-      context, aRank, afm, afk,
-      /*subgroupCount=*/aSubgroupSizes,
-      /*subgroupStrides=*/aSubgroupStrides,
-      /*batchCount=*/aBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
-  LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
-
-  int64_t bRank = opInfo.getBRank();
-
-  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
-  SmallVector<int64_t> bBatchSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupStrides(bRank, 0);
-  for (auto [i, dim] : llvm::enumerate(bNDims)) {
-    bBatchSizes[dim] = batchNSizes[i];
-    bSubgroupSizes[dim] = subgroupNBasis[i];
-    bSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-  for (auto [kDim, rhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
-    bBatchSizes[rhsKDim] = bounds[kDim];
-  }
-  bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
-      context, bRank, bfk, bfn,
-      /*subgroupCount=*/bSubgroupSizes,
-      /*subgroupStrides=*/bSubgroupStrides,
-      /*batchCount=*/bBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
-  LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
-
-  std::tuple<VectorLayoutInterface, VectorLayoutInterface,
-             VectorLayoutInterface>
-      result = {aLayout, bLayout, cLayout};
-  return result;
-}
-
-/// Template specializations
-::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface>>
-IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
-                                VectorContractOpInfo &opInfo,
-                                linalg::LinalgOp contractOp) {
-  return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
-}
-
-::mlir::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface>>
-IREE::GPU::getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
-                                VectorContractOpInfo &opInfo,
-                                vector::ContractionOp contractOp) {
-  return getContractionLayoutImpl(scheduleAttr, opInfo, contractOp);
-}
-
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 061e139..617b774 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -62,24 +62,6 @@
 void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc,
                 ArrayRef<Operation *> aliasGroup, bool hasAsyncCopies = true);
 
-namespace IREE {
-namespace GPU {
-class MMAScheduleAttr;
-
-::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface>>
-getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
-                     VectorContractOpInfo &opInfo, linalg::LinalgOp contractOp);
-
-::llvm::FailureOr<::std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface,
-                               IREE::VectorExt::VectorLayoutInterface>>
-getContractionLayout(IREE::GPU::MMAScheduleAttr scheduleAttr,
-                     VectorContractOpInfo &opInfo,
-                     vector::ContractionOp contractOp);
-} // namespace GPU
-} // namespace IREE
 } // namespace iree_compiler
 } // namespace mlir
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index a419f4a..074664c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -207,57 +207,6 @@
                                               workgroup_size = [64, 1, 1]
                                               subgroup_size = 64>
 
-#maps = [
-  affine_map<(bm, bn, m, n, k) -> (bm, m, k)>,
-  affine_map<(bm, bn, m, n, k) -> (bn, n, k)>,
-  affine_map<(bm, bn, m, n, k) -> (bm, m, bn, n)>
-]
-
-#traits = {
-  indexing_maps = #maps,
-  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
-  lowering_config = #iree_gpu.lowering_config<{promote_operands = [0],
-                                              mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-                                              subgroup_m_count = 2, subgroup_n_count = 2}>
-}
-
-func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
-                                     %rhs: tensor<8x16x16xf16>,
-                                     %init: tensor<8x16x8x16xf32>)
-                                     -> tensor<8x16x8x16xf32>
-                           attributes { translation_info = #translation } {
-  %out = linalg.generic #traits
-                        ins(%lhs, %rhs: tensor<8x16x16xf16>, tensor<8x16x16xf16>)
-                        outs(%init: tensor<8x16x8x16xf32>) {
-    ^bb0(%in: f16, %in_1: f16, %out: f32):
-      %ex   = arith.extf %in   : f16 to f32
-      %ex_1 = arith.extf %in_1 : f16 to f32
-      %mul  = arith.mulf %ex, %ex_1 : f32
-      %sum  = arith.addf %out, %mul : f32
-      linalg.yield %sum : f32
-  } -> tensor<8x16x8x16xf32>
-  return %out : tensor<8x16x8x16xf32>
-}
-
-
-// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [2, 0, 0], thread_strides = [0, 1, 16]>
-// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 1], batch_tile = [4, 1, 1], outer_tile = [1, 1, 1], thread_tile = [1, 16, 4], element_tile = [1, 1, 4], subgroup_strides = [1, 0, 0], thread_strides = [0, 1, 16]>
-// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
-// CHECK-LABEL: func.func @packed_matmul_128x128x128
-
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
-// CHECK-SAME: outs(%[[ACC]]
-
-// -----
-
-#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
-                                              workgroup_size = [64, 1, 1]
-                                              subgroup_size = 64>
-
 func.func @linalg_copy(%in : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
                       attributes { translation_info = #translation } {
   %empty = tensor.empty() : tensor<16x16x16xf16>
diff --git a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
index 3d14b1d..b35fc42 100644
--- a/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
+++ b/tests/external/iree-test-suites/test_suite_files/attention_and_matmul_spec_unet_fp16_mi308.mlir
@@ -27,7 +27,7 @@
   transform.iree.match.cast_compatible_type %in0 = tensor<?x?x?x?xf16> : !transform.any_value
 
   %config = transform.param.constant #iree_codegen.compilation_info<
-          lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
+          lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
           translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
                                                             workgroup_size = [64, 4]
                                                             subgroup_size = 64 ,