More cleanup in `IREEGPUAttrs`. (#19161)

* Remove nearly 1,000 lines of code from `IREEGPUAttrs.cpp`, plus some
methods from `IREEGPUAttrs.td`. The bulk of that is code that was only
used in a single place and didn't need to live in this central location,
so it moves to its single user. A few helper functions were completely
dead code, and a few additional simplifications are applied.
* Write up-to-date comments for core data structures, particularly the
all-important `SingleSubgroupLayout`.
* Within `IREEGPUAttrs.cpp`, group at the top of the file the code that
encodes the raw information about MMA intrinsics, i.e. the code that
needs to be updated when a new intrinsic is supported (see the sketch
below for what that information looks like).
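
For reference, the raw information about MMA intrinsics boils down to one
`MMASingleSubgroupLayout` per (intrinsic, fragment) pair, plus element types
and the subgroup size. The sketch below uses made-up values for a hypothetical
16x16x16 intrinsic on a 64-wide subgroup; the real tables live in
`getSingleSubgroupLayout()` in `IREEGPUAttrs.cpp`, and the function name here
is invented for illustration:

```cpp
// Sketch only: made-up layout values, assuming the MMASingleSubgroupLayout and
// MMAFragment definitions from the IREE GPU dialect headers.
MMASingleSubgroupLayout getSketchSingleSubgroupLayout(MMAFragment fragment) {
  switch (fragment) {
  case MMAFragment::Lhs: // 16(M)x16(K) tile: 16 lanes along M, 4 along K,
                         // 4 elements per lane along K.
    return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
            /*element=*/{1, 4}};
  case MMAFragment::Rhs: // 16(K)x16(N) tile: 4 lanes along K, 16 along N,
                         // 4 elements per lane along K.
    return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
            /*element=*/{4, 1}};
  case MMAFragment::Acc: // 16(M)x16(N) tile: 4 lanes along M, 16 along N,
                         // 4 accumulator elements per lane along M.
    return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
            /*element=*/{4, 1}};
  }
  return {};
}
```

In this sketch the product of the `thread` sizes is 64 for every fragment,
matching the subgroup size, so each lane owns a distinct slice of the tile.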

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 0a27437..19a7b3a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -12,7 +12,6 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
-#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/STLForwardCompat.h"
@@ -45,169 +44,34 @@
 #define GET_ATTRDEF_CLASSES
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc"
 
-using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension;
-using LayoutDimensionAttr =
-    mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr;
-using VectorLayoutInterface =
-    mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface;
-using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr;
-using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr;
-using NestedLayoutAttr = mlir::iree_compiler::IREE::VectorExt::NestedLayoutAttr;
-
 namespace mlir::iree_compiler::IREE::GPU {
 
-namespace {
-// Struct containing abstract MMA shape and type information.
-struct OpaqueMmaLayout {
-  int64_t mSize = 0;
-  int64_t nSize = 0;
-  int64_t kSize = 0;
-  Type aType;
-  Type bType;
-  Type cType;
-};
-
-// Struct containing concrete MMA shape, type, and layout information.
-struct ConcreteMmaLayout {
-  OpaqueMmaLayout base;
-  PerDimLayoutAttr aMLayout;
-  PerDimLayoutAttr aKLayout;
-  PerDimLayoutAttr bKLayout;
-  PerDimLayoutAttr bNLayout;
-  PerDimLayoutAttr cMLayout;
-  PerDimLayoutAttr cNLayout;
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
-// #iree_gpu.mma_vector_layout
+// MMA intrinsics semantics: shapes, layouts, operand element types.
 //===----------------------------------------------------------------------===//
 
-static PerDimLayoutAttr getBatchedPerDimLayoutAttr(LayoutDimensionAttr batchDim,
-                                                   PerDimLayoutAttr baseLayout,
-                                                   int64_t problemSize,
-                                                   int64_t fragmentDimSize) {
-  assert(problemSize % fragmentDimSize == 0 &&
-         "invalid layout fragment for problem size");
-
-  SmallVector<LayoutDimensionAttr, 3> dimAttrs(baseLayout.getLabels());
-  dimAttrs.insert(dimAttrs.begin(), batchDim);
-
-  SmallVector<int64_t, 3> shapes(baseLayout.getShapes());
-  shapes.insert(shapes.begin(), problemSize / fragmentDimSize);
-  auto layout =
-      PerDimLayoutAttr::get(baseLayout.getContext(), dimAttrs, shapes);
-  return layout;
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
+  // Not supporting any block size other than 1 at the moment.
+  return 1;
 }
 
-// Get the batched layout attributes for the given fragment layouts, indexing
-// map, and problem shape. The canonical fragment map is used to compare against
-// the problem map |indexingMap|. For example, for mma fragment B (RHS):
-//
-// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2) # Transposed B
-// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1)
-// problemShape = [32, 64]
-// fragmentSize = [16, 8]
-// fragmentLayouts = [kLayout, nLayout]
-//
-// Gives batched layout
-//
-// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape]
-// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape]
-static LayoutAttr
-getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap,
-                     ArrayRef<int64_t> problemShape,
-                     ArrayRef<int64_t> fragmentSize,
-                     ArrayRef<PerDimLayoutAttr> fragmentLayouts) {
-  // Current distribution to MFMA operations does not support batched
-  // contractions so that is reflected here.
-  assert(indexingMap.getNumResults() == 2 &&
-         "invalid indexing map to non-batched simple contraction");
-
-  LayoutDimensionAttr batchX = LayoutDimensionAttr::get(
-      indexingMap.getContext(), LayoutDimension::BATCHX);
-  LayoutDimensionAttr batchY = LayoutDimensionAttr::get(
-      indexingMap.getContext(), LayoutDimension::BATCHY);
-
-  SmallVector<PerDimLayoutAttr, 2> perDimAttrs;
-  for (auto [expr, batchType] :
-       llvm::zip_equal(indexingMap.getResults(),
-                       SmallVector<LayoutDimensionAttr, 2>{batchX, batchY})) {
-    auto maybeResultPosition = fragmentMap.getResultPosition(expr);
-    assert(maybeResultPosition && "fragment map and problem map mismatch");
-    int64_t idx = *maybeResultPosition;
-    perDimAttrs.push_back(getBatchedPerDimLayoutAttr(
-        batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx]));
-  }
-
-  return LayoutAttr::get(indexingMap.getContext(), perDimAttrs);
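+// The enum value of an MMAIntrinsic encodes, in its high byte, an ID for the
+// architecture family providing the intrinsic; getArchID extracts that byte.
+// The range checks in is_AMD_MFMA / is_AMD_WMMA below rely on this encoding
+// (e.g. a hypothetical enum value 0x1234 would land in the CDNA MFMA range).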
+static uint32_t getArchID(MMAIntrinsic intrinsic) {
+  return static_cast<int>(intrinsic) & 0xFF00;
 }
 
-static FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
-                            VectorLayoutInterface>>
-getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
-  MLIRContext *context = contract.getContext();
-  FailureOr<linalg::ContractionDimensions> maybeContractionDims =
-      linalg::inferContractionDims(contract.getIndexingMapsArray());
-  if (failed(maybeContractionDims)) {
-    return failure();
-  }
-  auto contractionDims = *maybeContractionDims;
-  // TODO: Relax this condition to strictly alignment requirements.
-  if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 ||
-      contractionDims.n.size() != 1) {
-    return failure();
-  }
-  // TODO: Support batched contractions.
-  if (contractionDims.batch.size() > 0) {
-    return failure();
-  }
-  unsigned mDim = contractionDims.m[0];
-  unsigned nDim = contractionDims.n[0];
-  unsigned kDim = contractionDims.k[0];
-
-  SmallVector<int64_t> iterationBounds;
-  contract.getIterationBounds(iterationBounds);
-
-  int64_t problemMSize = iterationBounds[mDim];
-  int64_t problemNSize = iterationBounds[nDim];
-  int64_t problemKSize = iterationBounds[kDim];
-
-  int64_t mSize = layout.base.mSize;
-  int64_t nSize = layout.base.nSize;
-  int64_t kSize = layout.base.kSize;
-
-  // The problem size currently must be strictly aligned to the size of the mma.
-  // This is expected to succeed assuming the correct [masked] vector size was
-  // set at strategy configuration time (for this mma).
-  if (problemMSize % mSize != 0 || problemNSize % nSize ||
-      problemKSize % kSize) {
-    return failure();
-  }
-
-  LayoutAttr aLayout = getBatchedLayoutAttr(
-      contract.getIndexingMapsArray()[0],
-      AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context),
-      {problemMSize, problemKSize}, {mSize, kSize},
-      {layout.aMLayout, layout.aKLayout});
-  LayoutAttr bLayout = getBatchedLayoutAttr(
-      contract.getIndexingMapsArray()[1],
-      AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context),
-      {problemKSize, problemNSize}, {kSize, nSize},
-      {layout.bKLayout, layout.bNLayout});
-  LayoutAttr cLayout = getBatchedLayoutAttr(
-      contract.getIndexingMapsArray()[2],
-      AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context),
-      {problemMSize, problemNSize}, {mSize, nSize},
-      {layout.cMLayout, layout.cNLayout});
-
-  return std::make_tuple<VectorLayoutInterface, VectorLayoutInterface,
-                         VectorLayoutInterface>(aLayout, bLayout, cLayout);
+static bool is_AMD_MFMA(MMAIntrinsic intrinsic) {
+  return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF;
 }
 
-//===----------------------------------------------------------------------===//
-// Layout Attribute Building Helpers
-//===----------------------------------------------------------------------===//
+static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
+  return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
+}
+
+static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
+  // Not using Wave64 at all at the moment, so the only case where the
+  // subgroup size is 64 is CDNA* (MFMA) architectures.
+  return is_AMD_MFMA(intrinsic) ? 64 : 32;
+}
 
 static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
                                                        MMAIntrinsic intrinsic) {
@@ -263,233 +127,6 @@
   return {};
 }
 
-template <typename MMAIntrinsicType>
-static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
-                                          MMAIntrinsicType intrinsic) {
-  OpaqueMmaLayout o;
-  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
-  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
-  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
-  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
-  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
-  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
-  return o;
-}
-
-static std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
-getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
-  // Step 1: obtain the swizzled tile shape, but keeping track of the source
-  // dimension indices.
-  struct SrcIndexAndSwizzleDim {
-    size_t srcIndex;
-    TileSwizzle::Dim dim;
-  };
-  SmallVector<SrcIndexAndSwizzleDim> swizzledShape;
-  for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) {
-    for (TileSwizzle::Dim d : e) {
-      swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d});
-    }
-  }
-  applyPermutationToVector(swizzledShape, swizzle.permutation);
-
-  // Step 2: collect the appropriate labels to use for the swizzled dims.
-  LayoutDimension internalLabels[] = {LayoutDimension::VECTORZ,
-                                      LayoutDimension::VECTORY,
-                                      LayoutDimension::VECTORX};
-  LayoutDimension crossThreadLabels[] = {
-      LayoutDimension::LANEZ, LayoutDimension::LANEY, LayoutDimension::LANEX};
-  auto internalLabelIter = std::end(internalLabels);
-  auto crossThreadLabelIter = std::end(crossThreadLabels);
-  for (SrcIndexAndSwizzleDim d : swizzledShape) {
-    if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) {
-      assert(internalLabelIter != std::begin(internalLabels));
-      --internalLabelIter;
-    } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
-      assert(crossThreadLabelIter != std::begin(crossThreadLabels));
-      --crossThreadLabelIter;
-    } else {
-      assert(false && "unexpected dimension kind in intrinsic swizzle");
-    }
-  }
-
-  // Step 3: put together the result PerDimLayoutAttr'd for the two source dims.
-  SmallVector<LayoutDimensionAttr> labels[2];
-  SmallVector<int64_t> shape[2];
-  for (SrcIndexAndSwizzleDim d : swizzledShape) {
-    shape[d.srcIndex].push_back(d.dim.size);
-    auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal)
-                             ? internalLabelIter
-                             : crossThreadLabelIter;
-    labels[d.srcIndex].push_back(LayoutDimensionAttr::get(
-        context, static_cast<LayoutDimension>(*labelIterRef++)));
-  }
-  return {PerDimLayoutAttr::get(context, labels[0], shape[0]),
-          PerDimLayoutAttr::get(context, labels[1], shape[1])};
-};
-
-static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
-                                              MMAIntrinsic intrinsic) {
-  auto opaque = getOpaqueMMALayout(context, intrinsic);
-  ConcreteMmaLayout concreteLayout;
-  concreteLayout.base = opaque;
-  auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs);
-  auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Rhs);
-  auto accSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Acc);
-  std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) =
-      getPerDimLayoutAttrs(context, lhsSwizzle);
-  std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) =
-      getPerDimLayoutAttrs(context, rhsSwizzle);
-  std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) =
-      getPerDimLayoutAttrs(context, accSwizzle);
-  return concreteLayout;
-}
-
-//===----------------------------------------------------------------------===//
-// MmaInterface Attribute Helper Functions
-//===----------------------------------------------------------------------===//
-
-MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
-  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
-    return mmaAttr.getASingleSubgroupLayout();
-  } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
-    return vmmaAttr.getASingleSubgroupLayout();
-  } else {
-    assert(false && "unhandled MMA Interface type.");
-    return {};
-  }
-}
-
-MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
-  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
-    return mmaAttr.getBSingleSubgroupLayout();
-  } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
-    return vmmaAttr.getBSingleSubgroupLayout();
-  } else {
-    assert(false && "unhandled MMA Interface type.");
-    return {};
-  }
-}
-
-MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
-  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
-    return mmaAttr.getCSingleSubgroupLayout();
-  } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
-    return vmmaAttr.getCSingleSubgroupLayout();
-  } else {
-    assert(false && "unhandled MMA Interface type.");
-    return {};
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// MFMA Attributes
-//===----------------------------------------------------------------------===//
-
-Attribute MMAAttr::parse(AsmParser &p, Type type) {
-  if (failed(p.parseLess()))
-    return {};
-
-  FailureOr<MMAIntrinsicAttr> mmaIntrinsic =
-      FieldParser<MMAIntrinsicAttr>::parse(p);
-  if (failed(mmaIntrinsic)) {
-    p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier");
-    return {};
-  }
-
-  if (failed(p.parseGreater()))
-    return {};
-
-  return get(p.getContext(), mmaIntrinsic->getValue());
-}
-
-void MMAAttr::print(AsmPrinter &p) const {
-  auto &os = p.getStream();
-  os << "<";
-  os << stringifyMMAIntrinsic(getIntrinsic().getValue());
-  os << ">";
-}
-
-MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
-  auto layout = getOpaqueMMALayout(context, type);
-  return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
-                   layout.nSize, layout.kSize, layout.aType, layout.bType,
-                   layout.cType);
-}
-
-std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
-  return {getAType(), getBType(), getCType()};
-}
-
-std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
-  return {getMSize(), getNSize(), getKSize()};
-}
-
-template <typename MMAIntrinsicType>
-static VectorType getVectorType(MLIRContext *context,
-                                MMAIntrinsicType intrinsic,
-                                MMAFragment fragment) {
-  auto o = getOpaqueMMALayout(context, intrinsic);
-  auto s = getSingleSubgroupLayout(intrinsic, fragment);
-  Type elemType = (fragment == MMAFragment::Lhs)   ? o.aType
-                  : (fragment == MMAFragment::Rhs) ? o.bType
-                                                   : o.cType;
-  return VectorType::get(
-      {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType);
-}
-
-std::tuple<VectorType, VectorType, VectorType>
-MMAAttr::getABCVectorTypes() const {
-  MLIRContext *context = getContext();
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
-  VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
-  VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
-  return {aVecType, bVecType, cVecType};
-}
-
-FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
-                     VectorLayoutInterface>>
-MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
-  ConcreteMmaLayout layout =
-      getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue());
-  return IREE::GPU::getContractionLayout(contract, layout);
-}
-
-static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
-  // Not supporting any block size other than 1 at the moment.
-  return 1;
-}
-
-int64_t MMAAttr::getBlockSize() const {
-  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
-}
-
-static uint32_t getArchID(MMAIntrinsic intrinsic) {
-  return static_cast<int>(intrinsic) & 0xFF00;
-}
-
-static bool is_AMD_MFMA(MMAIntrinsic intrinsic) {
-  return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF;
-}
-
-static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
-  return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
-}
-
-static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
-  // Not using Wave64 at all at the moment, so the only place where the
-  // subgroup size is CDNA* architectures.
-  return is_AMD_MFMA(intrinsic) ? 64 : 32;
-}
-
-int64_t MMAAttr::getSubgroupSize() const {
-  return getIntrinsicSubgroupSize(getIntrinsic().getValue());
-}
-
-FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
-  return IREE::GPU::MMAScope::Subgroup;
-}
-
 MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
                                                 MMAFragment fragment) {
   switch (intrinsic) {
@@ -637,6 +274,139 @@
   return {};
 }
 
+template <typename MMAIntrinsicType>
+static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                          MMAIntrinsicType intrinsic) {
+  OpaqueMmaLayout o;
+  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  return o;
+}
+
+OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                   IREE::GPU::MMAIntrinsic intrinsic) {
+  return getOpaqueMMALayout<IREE::GPU::MMAIntrinsic>(context, intrinsic);
+}
+
+//===----------------------------------------------------------------------===//
+// MmaInterface Attribute Helper Functions
+//===----------------------------------------------------------------------===//
+
+MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+    return mmaAttr.getASingleSubgroupLayout();
+  }
+  if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+    return vmmaAttr.getASingleSubgroupLayout();
+  }
+  assert(false && "unhandled MMA Interface type.");
+  return {};
+}
+
+MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+    return mmaAttr.getBSingleSubgroupLayout();
+  }
+  if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+    return vmmaAttr.getBSingleSubgroupLayout();
+  }
+  assert(false && "unhandled MMA Interface type.");
+  return {};
+}
+
+MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+  if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+    return mmaAttr.getCSingleSubgroupLayout();
+  }
+  if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+    return vmmaAttr.getCSingleSubgroupLayout();
+  }
+  assert(false && "unhandled MMA Interface type.");
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// MMA Attributes
+//===----------------------------------------------------------------------===//
+
+Attribute MMAAttr::parse(AsmParser &p, Type type) {
+  if (failed(p.parseLess()))
+    return {};
+
+  FailureOr<MMAIntrinsicAttr> mmaIntrinsic =
+      FieldParser<MMAIntrinsicAttr>::parse(p);
+  if (failed(mmaIntrinsic)) {
+    p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier");
+    return {};
+  }
+
+  if (failed(p.parseGreater()))
+    return {};
+
+  return get(p.getContext(), mmaIntrinsic->getValue());
+}
+
+void MMAAttr::print(AsmPrinter &p) const {
+  auto &os = p.getStream();
+  os << "<";
+  os << stringifyMMAIntrinsic(getIntrinsic().getValue());
+  os << ">";
+}
+
+MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
+  auto layout = getOpaqueMMALayout(context, type);
+  return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
+                   layout.nSize, layout.kSize, layout.aType, layout.bType,
+                   layout.cType);
+}
+
+std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
+  return {getAType(), getBType(), getCType()};
+}
+
+std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
+  return {getMSize(), getNSize(), getKSize()};
+}
+
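+// Returns the per-thread vector type of one fragment: a single thread owns
+// outer[0] * outer[1] * element[0] * element[1] elements of that fragment,
+// while the `thread` counts are distributed across lanes and so do not appear
+// in the per-thread vector shape.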
+template <typename MMAIntrinsicType>
+static VectorType getVectorType(MLIRContext *context,
+                                MMAIntrinsicType intrinsic,
+                                MMAFragment fragment) {
+  auto o = getOpaqueMMALayout(context, intrinsic);
+  auto s = getSingleSubgroupLayout(intrinsic, fragment);
+  Type elemType = (fragment == MMAFragment::Lhs)   ? o.aType
+                  : (fragment == MMAFragment::Rhs) ? o.bType
+                                                   : o.cType;
+  return VectorType::get(
+      {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType);
+}
+
+std::tuple<VectorType, VectorType, VectorType>
+MMAAttr::getABCVectorTypes() const {
+  MLIRContext *context = getContext();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
+  VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
+  VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
+  return {aVecType, bVecType, cVecType};
+}
+
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
+int64_t MMAAttr::getSubgroupSize() const {
+  return getIntrinsicSubgroupSize(getIntrinsic().getValue());
+}
+
+FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
+  return IREE::GPU::MMAScope::Subgroup;
+}
+
 MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
   return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
 }
@@ -781,22 +551,8 @@
     SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
     SmallVector<OpFoldResult> &strides) const {
 
-  MMASingleSubgroupLayout subgroupLayout;
-  switch (fragment) {
-  case IREE::GPU::MMAFragment::Lhs: {
-    subgroupLayout = getASingleSubgroupLayout();
-    break;
-  }
-  case IREE::GPU::MMAFragment::Rhs: {
-    subgroupLayout = getBSingleSubgroupLayout();
-    break;
-  }
-  case IREE::GPU::MMAFragment::Acc: {
-    subgroupLayout = getCSingleSubgroupLayout();
-    break;
-  }
-  }
-
+  MMASingleSubgroupLayout subgroupLayout =
+      getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
   SmallVector<OpFoldResult> canonicalOffsets;
   SmallVector<OpFoldResult> canonicalSizes;
   if (failed(populateCanonicalOffsetsSizesAndStrides(
@@ -810,82 +566,6 @@
   return success();
 }
 
-LogicalResult MMAAttr::materializeOperandConcreteShape(
-    OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
-    std::optional<ArrayRef<int64_t>> permutation,
-    SmallVector<ReassociationIndices> &reassociations,
-    RankedTensorType &resultType) const {
-
-  SmallVector<int64_t, 2> outerSizes;
-  SmallVector<int64_t, 2> opaqueSizes;
-  auto [m, n, k] = getMNKShape();
-  switch (fragment) {
-  case IREE::GPU::MMAFragment::Lhs: {
-    outerSizes = getASingleSubgroupLayout().outer;
-    opaqueSizes.append({m, k});
-    break;
-  }
-  case IREE::GPU::MMAFragment::Rhs: {
-    outerSizes = getBSingleSubgroupLayout().outer;
-    opaqueSizes.append({k, n});
-    break;
-  }
-  case IREE::GPU::MMAFragment::Acc: {
-    outerSizes = getCSingleSubgroupLayout().outer;
-    opaqueSizes.append({m, n});
-    break;
-  }
-  }
-  if (permutation.has_value()) {
-    if (permutation.value().size() != outerSizes.size()) {
-      return failure();
-    }
-    applyPermutationToVector(opaqueSizes, permutation.value());
-    applyPermutationToVector(outerSizes, permutation.value());
-  }
-
-  // Inner tile must have sizes matching the opaque layout.
-  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
-  ArrayRef<int64_t> operandShape = operandType.getShape();
-  SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
-                                     operandShape.end());
-  if (!llvm::equal(opaqueSizes, innerShape)) {
-    return failure();
-  }
-
-  // Expand the shape of the inner tile to reflect the MMA thread layout.
-  SmallVector<int64_t, 4> resultShape(operandShape.begin(),
-                                      operandShape.end() - 2);
-  SmallVector<ReassociationIndices> reInds =
-      llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
-                          [](int64_t idx) -> ReassociationIndices {
-                            return ReassociationIndices({idx});
-                          });
-  int idx = reInds.size();
-  for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) {
-    // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives
-    // a guarantee that the |element| counts are contiguous within the layout,
-    // and a unit outer implies a single offset and size for that dimension.
-    if (outer == 1) {
-      resultShape.push_back(native);
-      reInds.push_back(ReassociationIndices({idx++}));
-      continue;
-    }
-
-    // Reshape to [outer, native / outer] == [outer, thread * element]. This
-    // corresponds to |outer| repetitions of the thread/element sublayout.
-    resultShape.push_back(outer);
-    assert(native % outer == 0 && "invalid mma layout");
-    resultShape.push_back(native / outer);
-    reInds.push_back(ReassociationIndices{idx, idx + 1});
-    idx += 2;
-  }
-
-  reassociations = reInds;
-  resultType = operandType.clone(resultShape);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // DataTiledMMA Attributes
 //===----------------------------------------------------------------------===//
@@ -1265,22 +945,8 @@
     SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
     SmallVector<OpFoldResult> &strides) const {
 
-  MMASingleSubgroupLayout subgroupLayout;
-  switch (fragment) {
-  case IREE::GPU::MMAFragment::Lhs: {
-    subgroupLayout = getASingleSubgroupLayout();
-    break;
-  }
-  case IREE::GPU::MMAFragment::Rhs: {
-    subgroupLayout = getBSingleSubgroupLayout();
-    break;
-  }
-  case IREE::GPU::MMAFragment::Acc: {
-    subgroupLayout = getCSingleSubgroupLayout();
-    break;
-  }
-  }
-
+  MMASingleSubgroupLayout subgroupLayout =
+      getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
   SmallVector<OpFoldResult> canonicalOffsets;
   SmallVector<OpFoldResult> canonicalSizes;
   if (failed(populateCanonicalOffsetsSizesAndStrides(
@@ -1445,348 +1111,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// MMA Schedule Attributes
-//===----------------------------------------------------------------------===//
-
-/// Gets a unit vector of the given rank, but fills in the given dimensions
-/// from the 2 element array |counts|. |dim0| is the position in the returned
-/// vector to put the first element of |counts|, and |dim1| is the position to
-/// put the second element. For example,
-///
-/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
-/// returns [1, 5, 7]
-SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
-                                           ArrayRef<int64_t> counts,
-                                           int64_t dim0, int64_t dim1) {
-  assert(counts.size() == 2 &&
-         "Unexpected non-rank 2 single subgroup dimension counts");
-  SmallVector<int64_t> res(rank, 1);
-  res[dim0] = counts[0];
-  res[dim1] = counts[1];
-  return res;
-}
-
-SmallVector<int64_t> getIdentityPerm(int64_t rank) {
-  return llvm::to_vector(llvm::seq(static_cast<int64_t>(0), rank));
-}
-
-/// Constructs an identity permutation with the given rank, except it applies
-/// the given rank-2 |perm| to the two dimensions |dim0| and |dim1|, and then
-/// swaps the positions of dim0 and dim1 in the final permutation. For example,
-///
-/// rank = 3, perm = [1, 0], dim0 = 1, dim1 = 2
-/// returns [0, 1, 2]
-///
-/// This is essentially just applying two rank-2 permutations to two particular
-/// dimensions. First it applies |perm|, which corresponds to a permutation
-/// needed by the underlying intrinsic, then it does another permutation based
-/// on the order of actual dimensions for the MMA fragment. For example, for the
-/// B matrix, dim0 = K and dim1 = N, so for the element order of an MFMA
-/// 16x16x16, perm would be `[1, 0]`, however if the actual contraction is a
-/// matmul_transpose_b, then the element order needs to be [0, 1].
-SmallVector<int64_t> getIdentityPermWithSwap(int64_t rank,
-                                             ArrayRef<int64_t> perm,
-                                             int64_t dim0, int64_t dim1) {
-  assert(perm.size() == 2 &&
-         "Unexpected non-rank 2 single subgroup dimension order");
-  SmallVector<int64_t> res = getIdentityPerm(rank);
-  if (perm[0] > perm[1]) {
-    std::swap(dim0, dim1);
-  }
-  if (dim0 > dim1) {
-    res[dim0] = dim1;
-    res[dim1] = dim0;
-  }
-  return res;
-}
-
-/// Constructs the nested layout given the layout for a single subgroup and the
-/// subgroup/batch counts and orders, as well as the dimensions along which to
-/// distribute the intrinsic's layout.
-///
-/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
-/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
-/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
-/// we would have:
-///   outerDim = 1 for the K dim
-///   innerDim = 0 for the N dim
-///
-/// For something like MK_NKN_MN with multiple N dims, it would typically be:
-///   outerDim = 1 for K
-///   innerDim = 2 for the second N dim
-///
-/// Importantly these two dimensions always refer to the actual dimension
-/// positions in the undistributed vector. For each fragment, this means:
-///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
-///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
-///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
-///
-/// And here inner most is referential to the iteration order, not the order
-/// they appear per fragment (because there is no relationship between the
-/// dimension order of M in A and in C, for example).
-NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
-                                    int64_t outerDim, int64_t innerDim,
-                                    SmallVector<int64_t> subgroupSizes,
-                                    SmallVector<int64_t> subgroupStrides,
-                                    SmallVector<int64_t> batchCount,
-                                    MMASingleSubgroupLayout counts) {
-
-  LLVM_DEBUG({
-    llvm::errs() << "Creating Nested Layout for::";
-    llvm::errs() << "\n    outerDim = " << outerDim;
-    llvm::errs() << "\n    innerDim = " << innerDim;
-    llvm::errs() << "\n    subgroupSizes: ";
-    llvm::interleaveComma(subgroupSizes, llvm::errs());
-    llvm::errs() << "\n    subgroupStrides: ";
-    llvm::interleaveComma(subgroupStrides, llvm::errs());
-    llvm::errs() << "\n    batchCount: ";
-    llvm::interleaveComma(batchCount, llvm::errs());
-    llvm::errs() << "\n    counts.outer: ";
-    llvm::interleaveComma(counts.outer, llvm::errs());
-    llvm::errs() << "\n    counts.thread: ";
-    llvm::interleaveComma(counts.thread, llvm::errs());
-    llvm::errs() << "\n    counts.element: ";
-    llvm::interleaveComma(counts.element, llvm::errs());
-    llvm::errs() << "\n    counts.tstrides: ";
-    llvm::interleaveComma(counts.tstrides, llvm::errs());
-    llvm::errs() << "\n";
-  });
-
-  SmallVector<int64_t> outerCount =
-      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
-  SmallVector<int64_t> threadCount =
-      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
-  SmallVector<int64_t> threadStrides =
-      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
-  SmallVector<int64_t> elementCount =
-      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
-
-  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
-                                          outerCount, threadCount, elementCount,
-                                          subgroupStrides, threadStrides);
-  return layoutAttr;
-}
-
-FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
-                     VectorExt::VectorLayoutInterface,
-                     VectorExt::VectorLayoutInterface>>
-MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
-                                      linalg::LinalgOp contractOp) const {
-  LLVM_DEBUG({
-    llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
-    llvm::errs() << "For schedule: " << *this << "\n";
-  });
-
-  int64_t rank = contractOp.getIteratorTypesArray().size();
-  auto mmaAttr = llvm::cast<MmaInterfaceAttr>(getIntrinsic());
-  MLIRContext *context = getContext();
-
-  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
-  if (llvm::any_of(bounds,
-                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
-    return failure();
-  }
-
-  if (!llvm::all_of(opInfo.getBatchDims(),
-                    [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
-    LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
-    return failure();
-  }
-
-  // Get the concrete nested layout for each matrix. Note that the struct
-  // MMASingleSubgroupLayout contains the partial layout for the
-  // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
-  // contract op we are looking at right now may not be exactly in that form.
-  // So here we need to permute/transpose the canonical layout to match with
-  // the concrete contract op.
-
-  // Note that no matter how we permute/transpose the input contraction
-  // problem, the way we view the hardware warps remain the same--that is,
-  // from the hardware's perspective, a single warp has the same warp ID no
-  // matter what part of the contraction it works on. Similarly here, we are
-  // delinearizing the linearized GPU hardware lane ID into a n-D concatenated
-  // logical warp+thread using the subgroup/thread basis, so the subgroup
-  // basis should remain the same for all A/B/C matrix.
-
-  auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
-
-  SmallVector<int64_t, 2> subgroupMBasis;
-  SmallVector<int64_t, 2> batchMSizes;
-  int64_t currMCount = getSubgroupMCount();
-
-  auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
-                           int64_t minDimSize) -> std::pair<int64_t, int64_t> {
-    int64_t dividableDim = dimSize / minDimSize;
-    int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
-    dividableDim /= subgroupsUsed;
-    int64_t batchesUsed = dividableDim;
-    return {subgroupsUsed, batchesUsed};
-  };
-
-  // Greedily break up the M subgroup and batch counts along the "M" iteration
-  // bounds. We distribute as many residual subgroups as possible per M dim,
-  // and then divide the remaining along batch dims. The inner most M dim is
-  // always the one used for the intrinsic, meaning for a valid schedule, the
-  // computed batch counts and subgroup basis will satisfy totalMSize /
-  // intrinsicM = product(batchMSizes) * product(subgroupMBasis)
-  for (auto dim : opInfo.getMDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getMDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], intrinsicM);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], 1);
-    }
-    subgroupMBasis.push_back(subgroupsUsed);
-    batchMSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currMCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t, 2> subgroupNBasis;
-  SmallVector<int64_t, 2> batchNSizes;
-  int64_t currNCount = getSubgroupNCount();
-
-  // Do the same for N dims.
-  for (auto dim : opInfo.getNDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getNDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], intrinsicN);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], 1);
-    }
-    subgroupNBasis.push_back(subgroupsUsed);
-    batchNSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currNCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
-  SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
-
-  auto mDimVec = opInfo.getMDims();
-  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
-  auto nDimVec = opInfo.getNDims();
-  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
-  // Because we currently require all batch dimensions to be unit, the
-  // subgroup basis can be constructed from the M and N bases. To keep things
-  // simple, the current heuristic is to distribute the loop dimensions from
-  // outer to inner.
-  int64_t currStride = 1;
-  int64_t currM = subgroupMStrides.size() - 1;
-  int64_t currN = subgroupNStrides.size() - 1;
-  for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
-    if (mDims.contains(dim)) {
-      subgroupMStrides[currM] = currStride;
-      currStride *= subgroupMBasis[currM];
-      currM--;
-      continue;
-    }
-
-    if (nDims.contains(dim)) {
-      subgroupNStrides[currN] = currStride;
-      currStride *= subgroupNBasis[currN];
-      currN--;
-      continue;
-    }
-  }
-
-  // C matrix layout
-  auto [m, n] = opInfo.getResultMNIndex();
-  int64_t cRank = opInfo.getCRank();
-
-  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
-  // cNDims are the M and N dimensions of the C matrix in the order they are
-  // iterated over in the contraction.
-  SmallVector<int64_t> cMDims = opInfo.outMDims;
-  SmallVector<int64_t> cNDims = opInfo.outNDims;
-  SmallVector<int64_t> cBatchSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupStrides(cRank, 0);
-  for (auto [i, dim] : llvm::enumerate(cMDims)) {
-    cBatchSizes[dim] = batchMSizes[i];
-    cSubgroupSizes[dim] = subgroupMBasis[i];
-    cSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [i, dim] : llvm::enumerate(cNDims)) {
-    cBatchSizes[dim] = batchNSizes[i];
-    cSubgroupSizes[dim] = subgroupNBasis[i];
-    cSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-
-  auto cLayout = createNestedLayout(context, cRank, m, n,
-                                    /*subgroupCount=*/cSubgroupSizes,
-                                    /*subgroupStrides=*/cSubgroupStrides,
-                                    /*batchCount=*/cBatchSizes,
-                                    getCSingleSubgroupLayout(mmaAttr));
-  LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
-
-  // A matrix layout
-  auto [afm, bfn] = opInfo.getOperandMNIndex();
-  auto [afk, bfk] = opInfo.getOperandKIndex();
-
-  int64_t aRank = opInfo.getARank();
-
-  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
-  SmallVector<int64_t> aBatchSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupStrides(aRank, 0);
-  for (auto [i, dim] : llvm::enumerate(aMDims)) {
-    aBatchSizes[dim] = batchMSizes[i];
-    aSubgroupSizes[dim] = subgroupMBasis[i];
-    aSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [kDim, lhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
-    aBatchSizes[lhsKDim] = bounds[kDim];
-  }
-  aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  auto aLayout = createNestedLayout(context, aRank, afm, afk,
-                                    /*subgroupCount=*/aSubgroupSizes,
-                                    /*subgroupStrides=*/aSubgroupStrides,
-                                    /*batchCount=*/aBatchSizes,
-                                    getASingleSubgroupLayout(mmaAttr));
-  LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
-
-  int64_t bRank = opInfo.getBRank();
-
-  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
-  SmallVector<int64_t> bBatchSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupStrides(bRank, 0);
-  for (auto [i, dim] : llvm::enumerate(bNDims)) {
-    bBatchSizes[dim] = batchNSizes[i];
-    bSubgroupSizes[dim] = subgroupNBasis[i];
-    bSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-  for (auto [kDim, rhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
-    bBatchSizes[rhsKDim] = bounds[kDim];
-  }
-  bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
-                                    /*subgroupCount=*/bSubgroupSizes,
-                                    /*subgroupStrides=*/bSubgroupStrides,
-                                    /*batchCount=*/bBatchSizes,
-                                    getBSingleSubgroupLayout(mmaAttr));
-  LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
-
-  std::tuple<VectorLayoutInterface, VectorLayoutInterface,
-             VectorLayoutInterface>
-      result = {aLayout, bLayout, cLayout};
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
 // Target Attributes
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
index 7942066..92ce4f4 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
@@ -22,12 +22,40 @@
 
 namespace mlir::iree_compiler::IREE::GPU {
 
-// Partial nested layout for an MMA intrinsic's matrix input/output inside
-// a single subgroup.
+// Struct describing the detailed subgroup-level layout of an MMA intrinsic.
+// Together with element type information and the subgroup size, it completes
+// the full description of the semantics of an MMA intrinsic.
+//
+// Note: it is not possible to infer the subgroup size from the information in
+// this struct. The product of the `thread` sizes here is often, but not
+// always, equal to the subgroup size. When that product (call it `P`) is
+// smaller than the subgroup size, it must be a divisor of it, and the
+// semantics in that case are that threads within the subgroup whose
+// thread-ids differ by a multiple of `P` access the same elements.
+//
+// Example observed in RDNA3 WMMA Wave64 intrinsics:
+// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, each
+// element is accessed by 2 threads (2 = 64/32), and the threads accessing the
+// same element are those whose thread-ids are exactly 32 apart.
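+// For instance (made-up layout): with thread = [2, 16], so `P` = 2 * 16 = 32,
+// lanes 5 and 37 (= 5 + 32) would read exactly the same elements.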
 struct MMASingleSubgroupLayout {
+  // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
+  // outer-most in the layout. This happens when an MMA op, seen on a single
+  // thread, has an operand that consists of multiple elements, and these
+  // elements are NOT contiguous.
+  // This is not used by every MMA op; ops that don't need it simply have 1's.
   SmallVector<int64_t, 2> outer;
+  // Cross-thread dimensions (as in TileSwizzle::Dim::Kind::CrossThread).
+  // This is the kind of dimension that is present in all GPU MMA ops, by
+  // definition of "SIMT". It is still possible for one of the `thread` dims to
+  // be 1, but not both.
   SmallVector<int64_t, 2> thread;
+  // Strides corresponding to the cross-thread dimensions.
   SmallVector<int64_t, 2> tstrides;
+  // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
+  // inner-most in the layout. This happens when an MMA op, seen on a single
+  // thread, has an operand that consists of multiple elements, and these
+  // elements ARE contiguous.
+  // This is not used by every MMA op; ops that don't need it simply have 1's.
   SmallVector<int64_t, 2> element;
 };
 
@@ -43,6 +71,22 @@
 
 MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
 
+// Struct describing the shape of an MMA operation, but not the detailed layout.
+// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is
+// LLVMGPU/TransformExtensions, so maybe make that internal again if/when that
+// goes away.
+struct OpaqueMmaLayout {
+  int64_t mSize = 0;
+  int64_t nSize = 0;
+  int64_t kSize = 0;
+  Type aType;
+  Type bType;
+  Type cType;
+};
+
+OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                   IREE::GPU::MMAIntrinsic intrinsic);
+
 } // namespace mlir::iree_compiler::IREE::GPU
 
 // clang-format off
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index bc6f2d3..51674db 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -148,7 +148,6 @@
   DeclareAttrInterfaceMethods<IREEGPU_MmaInterfaceAttr, [
     "getABCElementTypes",
     "getABCVectorTypes",
-    "getContractionLayout",
     "getMNKShape",
     "getSubgroupSize",
     "getMmaScope",
@@ -157,7 +156,6 @@
     "getCSingleSubgroupLayout",
     "buildMmaOperation",
     "populateOperandOffsetsSizesStrides",
-    "materializeOperandConcreteShape",
   ]>
 ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
@@ -259,19 +257,11 @@
     |intrinsic| field specifies which particular MMA intrinsic is targeted by
     the data-tiling.
 
-    The tile swizzling already happens, so the attribute does not need to
-    implement materializeOperandConcreteShape interface method. E.g., if the
-    target intrinsic is MFMA_F32_16x16x4_F32:
-      - The inner tile shape of LHS is 4x16.
-      - The inner tile shape of RHS is 4x16.
-      - The inner tile shape of ACC is 4x16x4.
-
-    Furthermore, the unrolling and interleaving can be represented with the
-    attribute. In the concept of data-tiling, we always unroll the parallel
-    dimensions (i.e., M, N dimensions) to be outermost, and interleave the
-    unrolled K dimension. I.e., the unrolled K dimension becomes the innermost
-    dimension. The constraint can be relaxed based on data-tiling needs. The
-    additional information can be added to `parameters`.
+    The other fields default to one; with those defaults, this attribute
+    describes a single intrinsic, equivalent to MMAAttr. Values greater than
+    one describe wider "kernels" consisting of multiple intrinsics, with the
+    data layout already swizzled into a tile layout that lets each intrinsic
+    access its data at offsets that are a simple function of the thread ID.
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -369,15 +359,6 @@
   );
 
   let assemblyFormat = "`<` struct(params) `>`";
-
-  let extraClassDeclaration = [{
-    // Returns the A/B/C matrix concrete layout targeting |contractOp|.
-    ::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
-                                 VectorExt::VectorLayoutInterface,
-                                 VectorExt::VectorLayoutInterface>>
-      getContractionLayout(::mlir::iree_compiler::VectorContractOpInfo &opInfo,
-                           ::mlir::linalg::LinalgOp contractOp) const;
-  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
index 94bcf3d..6b2e834 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
@@ -23,7 +23,80 @@
   using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
   void runOnOperation() override;
 };
-} // namespace
+
+LogicalResult materializeOperandConcreteShape(
+    OpBuilder &builder, MMAAttr mma, IREE::GPU::MMAFragment fragment,
+    Value operand, std::optional<ArrayRef<int64_t>> permutation,
+    SmallVector<ReassociationIndices> &reassociations,
+    RankedTensorType &resultType) {
+
+  SmallVector<int64_t, 2> outerSizes;
+  SmallVector<int64_t, 2> opaqueSizes;
+  auto [m, n, k] = mma.getMNKShape();
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs: {
+    outerSizes = mma.getASingleSubgroupLayout().outer;
+    opaqueSizes.append({m, k});
+    break;
+  }
+  case IREE::GPU::MMAFragment::Rhs: {
+    outerSizes = mma.getBSingleSubgroupLayout().outer;
+    opaqueSizes.append({k, n});
+    break;
+  }
+  case IREE::GPU::MMAFragment::Acc: {
+    outerSizes = mma.getCSingleSubgroupLayout().outer;
+    opaqueSizes.append({m, n});
+    break;
+  }
+  }
+  if (permutation.has_value()) {
+    if (permutation.value().size() != outerSizes.size()) {
+      return failure();
+    }
+    applyPermutationToVector(opaqueSizes, permutation.value());
+    applyPermutationToVector(outerSizes, permutation.value());
+  }
+
+  // Inner tile must have sizes matching the opaque layout.
+  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+  ArrayRef<int64_t> operandShape = operandType.getShape();
+  if (opaqueSizes != operandShape.take_back(opaqueSizes.size())) {
+    return failure();
+  }
+
+  // Expand the shape of the inner tile to reflect the MMA thread layout.
+  SmallVector<int64_t, 4> resultShape(operandShape.begin(),
+                                      operandShape.end() - 2);
+  SmallVector<ReassociationIndices> reInds =
+      llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
+                          [](int64_t idx) -> ReassociationIndices {
+                            return ReassociationIndices({idx});
+                          });
+  int idx = reInds.size();
+  for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) {
+    // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives
+    // a guarantee that the |element| counts are contiguous within the layout,
+    // and a unit outer implies a single offset and size for that dimension.
+    if (outer == 1) {
+      resultShape.push_back(native);
+      reInds.push_back(ReassociationIndices({idx++}));
+      continue;
+    }
+
+    // Reshape to [outer, native / outer] == [outer, thread * element]. This
+    // corresponds to |outer| repetitions of the thread/element sublayout.
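+    // For example (made-up sizes): outer = 2 and native = 8 turn this single
+    // dimension of size 8 into two dimensions of sizes [2, 4].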
+    resultShape.push_back(outer);
+    assert(native % outer == 0 && "invalid mma layout");
+    resultShape.push_back(native / outer);
+    reInds.push_back(ReassociationIndices{idx, idx + 1});
+    idx += 2;
+  }
+
+  reassociations = reInds;
+  resultType = operandType.clone(resultShape);
+  return success();
+}
 
 struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -56,12 +129,15 @@
     }
 
     // Get the reassociation indices and result type of the expand_shape op.
-    MmaInterfaceAttr kind = mmaOp.getKind();
+    MMAAttr kind = dyn_cast<MMAAttr>(mmaOp.getKind());
+    if (!kind) {
+      return failure();
+    }
     SmallVector<ReassociationIndices> reassociations;
     RankedTensorType concreteType;
-    if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
-                                                    permutation, reassociations,
-                                                    concreteType))) {
+    if (failed(materializeOperandConcreteShape(rewriter, kind, fragment,
+                                               operand, permutation,
+                                               reassociations, concreteType))) {
       return failure();
     }
 
@@ -140,6 +216,8 @@
   MMAFragment fragment;
 };
 
+} // namespace
+
 void ConcretizeMmaShapesPass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 69333a1..9df22c7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -8,6 +8,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -71,6 +72,310 @@
   return *config.getSubgroupNCount();
 }
 
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 7, 5]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+                                                  ArrayRef<int64_t> counts,
+                                                  int64_t dim0, int64_t dim1) {
+  assert(counts.size() == 2 &&
+         "Unexpected non-rank 2 single subgroup dimension counts");
+  SmallVector<int64_t> res(rank, 1);
+  res[dim0] = counts[0];
+  res[dim1] = counts[1];
+  return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+///   outerDim = 1 for the K dim
+///   innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+///   outerDim = 1 for K
+///   innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+///   A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+///   B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+///   C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// Here, "innermost" refers to the iteration order, not to the order in which
+/// the dimensions appear in each fragment (there is no fixed relationship
+/// between the dimension order of M in A and in C, for example).
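+///
+/// Illustrative example (hypothetical values, for exposition only): with
+/// rank = 3, outerDim = 1, innerDim = 2, and counts = {outer = [1, 1],
+/// thread = [16, 4], tstrides = [1, 16], element = [1, 4]}, the per-dimension
+/// vectors passed to NestedLayoutAttr::get become outer = [1, 1, 1],
+/// thread = [1, 16, 4], element = [1, 1, 4], and thread strides = [1, 1, 16];
+/// subgroupSizes, subgroupStrides, and batchCount are passed through as given.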
+static NestedLayoutAttr createNestedLayout(
+    MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+    ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+    ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Creating Nested Layout for::";
+    llvm::dbgs() << "\n    outerDim = " << outerDim;
+    llvm::dbgs() << "\n    innerDim = " << innerDim;
+    llvm::dbgs() << "\n    subgroupSizes: ";
+    llvm::interleaveComma(subgroupSizes, llvm::dbgs());
+    llvm::dbgs() << "\n    subgroupStrides: ";
+    llvm::interleaveComma(subgroupStrides, llvm::dbgs());
+    llvm::dbgs() << "\n    batchCount: ";
+    llvm::interleaveComma(batchCount, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.outer: ";
+    llvm::interleaveComma(counts.outer, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.thread: ";
+    llvm::interleaveComma(counts.thread, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.element: ";
+    llvm::interleaveComma(counts.element, llvm::dbgs());
+    llvm::dbgs() << "\n    counts.tstrides: ";
+    llvm::interleaveComma(counts.tstrides, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  SmallVector<int64_t> outerCount =
+      getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+  SmallVector<int64_t> threadCount =
+      getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+  SmallVector<int64_t> threadStrides =
+      getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+  SmallVector<int64_t> elementCount =
+      getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+  auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+                                          outerCount, threadCount, elementCount,
+                                          subgroupStrides, threadStrides);
+  return layoutAttr;
+}
+
+static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
+                            IREE::VectorExt::VectorLayoutInterface,
+                            IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
+                     VectorContractOpInfo &opInfo,
+                     linalg::LinalgOp contractOp) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
+    llvm::dbgs() << "For schedule: " << schedule << "\n";
+  });
+
+  int64_t rank = contractOp.getIteratorTypesArray().size();
+  auto mmaAttr =
+      llvm::cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
+  MLIRContext *context = schedule.getContext();
+
+  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
+  if (llvm::any_of(bounds,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    return failure();
+  }
+
+  if (!llvm::all_of(opInfo.getBatchDims(),
+                    [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
+    LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
+    return failure();
+  }
+
+  // Get the concrete nested layout for each matrix. Note that the struct
+  // MMASingleSubgroupLayout contains the partial layout for the canonical
+  // (M, K) x (K, N) -> (M, N) matmul form, while the specific contract op we
+  // are looking at right now may not be exactly in that form. So here we need
+  // to permute/transpose the canonical layout to match the concrete contract
+  // op.
+
+  // Note that no matter how we permute/transpose the input contraction
+  // problem, the way we view the hardware warps remains the same--that is,
+  // from the hardware's perspective, a single warp has the same warp ID no
+  // matter what part of the contraction it works on. Similarly here, we are
+  // delinearizing the linearized GPU hardware lane ID into an n-D concatenated
+  // logical warp+thread using the subgroup/thread basis, so the subgroup
+  // basis should remain the same for all A/B/C matrices.
+
+  auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
+
+  SmallVector<int64_t, 2> subgroupMBasis;
+  SmallVector<int64_t, 2> batchMSizes;
+  int64_t currMCount = schedule.getSubgroupMCount();
+
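+  // Splits a dimension of size |dimSize| as subgroupsUsed * batchesUsed *
+  // minDimSize, using at most |availableSubgroups| subgroups. Illustrative
+  // example (hypothetical values): divideGreedily(4, 128, 16) computes
+  // dividableDim = 128 / 16 = 8, subgroupsUsed = gcd(4, 8) = 4, and
+  // batchesUsed = 8 / 4 = 2, i.e. it returns {4, 2}.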
+  auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
+                           int64_t minDimSize) -> std::pair<int64_t, int64_t> {
+    int64_t dividableDim = dimSize / minDimSize;
+    int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
+    dividableDim /= subgroupsUsed;
+    int64_t batchesUsed = dividableDim;
+    return {subgroupsUsed, batchesUsed};
+  };
+
+  // Greedily break up the M subgroup and batch counts along the "M" iteration
+  // bounds. We distribute as many residual subgroups as possible per M dim,
+  // and then assign the remainder to batch dims. The innermost M dim is
+  // always the one used for the intrinsic, meaning that for a valid schedule,
+  // the computed batch counts and subgroup basis satisfy
+  //   totalMSize / intrinsicM == product(batchMSizes) * product(subgroupMBasis)
+  for (auto dim : opInfo.getMDims()) {
+    // Get the number of subgroups and batches used for this dimension based
+    // on the intrinsic size and the bound size.
+    int64_t subgroupsUsed, batchesUsed;
+    if (dim == opInfo.getMDims().back()) {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currMCount, bounds[dim], intrinsicM);
+    } else {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currMCount, bounds[dim], 1);
+    }
+    subgroupMBasis.push_back(subgroupsUsed);
+    batchMSizes.push_back(batchesUsed);
+    // Update available subgroup count.
+    currMCount /= subgroupsUsed;
+  }
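+  // Illustrative example (hypothetical values): with M iteration bounds
+  // [4, 64], intrinsicM = 16, and subgroupMCount = 8, the first (outer) M dim
+  // takes gcd(8, 4) = 4 subgroups and 1 batch, leaving currMCount = 2; the
+  // last M dim takes gcd(2, 64 / 16) = 2 subgroups and 2 batches. This gives
+  // subgroupMBasis = [4, 2] and batchMSizes = [1, 2], satisfying
+  // (4 * 64) / 16 == (1 * 2) * (4 * 2) == 16.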
+
+  SmallVector<int64_t, 2> subgroupNBasis;
+  SmallVector<int64_t, 2> batchNSizes;
+  int64_t currNCount = schedule.getSubgroupNCount();
+
+  // Do the same for N dims.
+  for (auto dim : opInfo.getNDims()) {
+    // Get the number of subgroups and batches used for this dimension based
+    // on the intrinsic size and the bound size.
+    int64_t subgroupsUsed, batchesUsed;
+    if (dim == opInfo.getNDims().back()) {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currNCount, bounds[dim], intrinsicN);
+    } else {
+      std::tie(subgroupsUsed, batchesUsed) =
+          divideGreedily(currNCount, bounds[dim], 1);
+    }
+    subgroupNBasis.push_back(subgroupsUsed);
+    batchNSizes.push_back(batchesUsed);
+    // Update available subgroup count.
+    currNCount /= subgroupsUsed;
+  }
+
+  SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
+  SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
+
+  auto mDimVec = opInfo.getMDims();
+  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+  auto nDimVec = opInfo.getNDims();
+  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+  // Because we currently require all batch dimensions to be unit, the
+  // subgroup basis can be constructed from the M and N bases. To keep things
+  // simple, the current heuristic is to distribute the loop dimensions from
+  // outer to inner.
+  int64_t currStride = 1;
+  int64_t currM = subgroupMStrides.size() - 1;
+  int64_t currN = subgroupNStrides.size() - 1;
+  for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
+    if (mDims.contains(dim)) {
+      subgroupMStrides[currM] = currStride;
+      currStride *= subgroupMBasis[currM];
+      currM--;
+      continue;
+    }
+
+    if (nDims.contains(dim)) {
+      subgroupNStrides[currN] = currStride;
+      currStride *= subgroupNBasis[currN];
+      currN--;
+      continue;
+    }
+  }
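+  // Illustrative example (hypothetical values): for a plain matmul with
+  // iteration dims [M, N, K], subgroupMBasis = [2], and subgroupNBasis = [4],
+  // walking the dims in reverse assigns subgroupNStrides = [1] and
+  // subgroupMStrides = [4], i.e. consecutive subgroup IDs advance along N
+  // first and then along M (row-major over the 2x4 subgroup grid).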
+
+  // C matrix layout
+  auto [m, n] = opInfo.getResultMNIndex();
+  int64_t cRank = opInfo.getCRank();
+
+  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+  // cNDims are the M and N dimensions of the C matrix in the order they are
+  // iterated over in the contraction.
+  SmallVector<int64_t> cMDims = opInfo.outMDims;
+  SmallVector<int64_t> cNDims = opInfo.outNDims;
+  SmallVector<int64_t> cBatchSizes(cRank, 1);
+  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+  SmallVector<int64_t> cSubgroupStrides(cRank, 0);
+  for (auto [i, dim] : llvm::enumerate(cMDims)) {
+    cBatchSizes[dim] = batchMSizes[i];
+    cSubgroupSizes[dim] = subgroupMBasis[i];
+    cSubgroupStrides[dim] = subgroupMStrides[i];
+  }
+  for (auto [i, dim] : llvm::enumerate(cNDims)) {
+    cBatchSizes[dim] = batchNSizes[i];
+    cSubgroupSizes[dim] = subgroupNBasis[i];
+    cSubgroupStrides[dim] = subgroupNStrides[i];
+  }
+
+  auto cLayout = createNestedLayout(context, cRank, m, n,
+                                    /*subgroupCount=*/cSubgroupSizes,
+                                    /*subgroupStrides=*/cSubgroupStrides,
+                                    /*batchCount=*/cBatchSizes,
+                                    getCSingleSubgroupLayout(mmaAttr));
+  LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
+
+  // A matrix layout
+  auto [afm, bfn] = opInfo.getOperandMNIndex();
+  auto [afk, bfk] = opInfo.getOperandKIndex();
+
+  int64_t aRank = opInfo.getARank();
+
+  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+  SmallVector<int64_t> aBatchSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+  SmallVector<int64_t> aSubgroupStrides(aRank, 0);
+  for (auto [i, dim] : llvm::enumerate(aMDims)) {
+    aBatchSizes[dim] = batchMSizes[i];
+    aSubgroupSizes[dim] = subgroupMBasis[i];
+    aSubgroupStrides[dim] = subgroupMStrides[i];
+  }
+  for (auto [kDim, lhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
+    aBatchSizes[lhsKDim] = bounds[kDim];
+  }
+  aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
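+  // For example (hypothetical values): with K iteration bounds [2, 32] and
+  // intrinsicK = 8, the outer K dim contributes a batch count of 2 while the
+  // innermost K dim (position afk in the LHS) contributes 32 / 8 = 4 batches,
+  // overwriting the bound assigned to it by the loop above.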
+
+  auto aLayout = createNestedLayout(context, aRank, afm, afk,
+                                    /*subgroupCount=*/aSubgroupSizes,
+                                    /*subgroupStrides=*/aSubgroupStrides,
+                                    /*batchCount=*/aBatchSizes,
+                                    getASingleSubgroupLayout(mmaAttr));
+  LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
+
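+  // B matrix layout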
+  int64_t bRank = opInfo.getBRank();
+
+  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+  SmallVector<int64_t> bBatchSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+  SmallVector<int64_t> bSubgroupStrides(bRank, 0);
+  for (auto [i, dim] : llvm::enumerate(bNDims)) {
+    bBatchSizes[dim] = batchNSizes[i];
+    bSubgroupSizes[dim] = subgroupNBasis[i];
+    bSubgroupStrides[dim] = subgroupNStrides[i];
+  }
+  for (auto [kDim, rhsKDim] :
+       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
+    bBatchSizes[rhsKDim] = bounds[kDim];
+  }
+  bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+  auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
+                                    /*subgroupCount=*/bSubgroupSizes,
+                                    /*subgroupStrides=*/bSubgroupStrides,
+                                    /*batchCount=*/bBatchSizes,
+                                    getBSingleSubgroupLayout(mmaAttr));
+  LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
+
+  std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+             VectorLayoutInterface>
+      result = {aLayout, bLayout, cLayout};
+  return result;
+}
+
 static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
                                           SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
@@ -89,7 +394,7 @@
           contract.getIndexingMapsArray());
   assert(succeeded(opInfo) && "contraction should have been inferred");
 
-  auto layouts = schedule.getContractionLayout(opInfo.value(), contract);
+  auto layouts = getContractionLayout(schedule, opInfo.value(), contract);
   if (failed(layouts)) {
     return contract->emitError("cannot get concrete layout for contraction");
   }
@@ -176,7 +481,7 @@
   assert(succeeded(opInfo) &&
          "unit filter dim convolution should have been infered");
 
-  auto layouts = schedule.getContractionLayout(opInfo.value(), conv);
+  auto layouts = getContractionLayout(schedule, opInfo.value(), conv);
   if (failed(layouts)) {
     return conv->emitError("cannot get concrete layout for convolution");
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 3dd0c12..a2c2187 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -1599,6 +1600,219 @@
   return VectorExt::LayoutAttr::get(ctx, perDimLayouts);
 }
 
+// Struct containing concrete MMA shape, type, and layout information.
+struct ConcreteMmaLayout {
+  GPU::OpaqueMmaLayout base;
+  VectorExt::PerDimLayoutAttr aMLayout;
+  VectorExt::PerDimLayoutAttr aKLayout;
+  VectorExt::PerDimLayoutAttr bKLayout;
+  VectorExt::PerDimLayoutAttr bNLayout;
+  VectorExt::PerDimLayoutAttr cMLayout;
+  VectorExt::PerDimLayoutAttr cNLayout;
+};
+
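+// Converts |swizzle| into a PerDimLayoutAttr for each of the two source
+// dimensions. Illustrative trace on a hypothetical swizzle (not claimed to
+// match any real intrinsic): expandShape =
+// [[CrossThread:16], [CrossThread:4, Internal:4]] with permutation [1, 0, 2]
+// gives swizzledShape = [(1, CrossThread:4), (0, CrossThread:16),
+// (1, Internal:4)]; the two cross-thread dims then receive LANEY and LANEX
+// (in that order) and the internal dim receives VECTORX, so the result is
+//   dim 0: labels [LANEX],          shapes [16]
+//   dim 1: labels [LANEY, VECTORX], shapes [4, 4]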
+static std::tuple<VectorExt::PerDimLayoutAttr, VectorExt::PerDimLayoutAttr>
+getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
+  // Step 1: obtain the swizzled tile shape, while keeping track of the source
+  // dimension indices.
+  struct SrcIndexAndSwizzleDim {
+    size_t srcIndex;
+    TileSwizzle::Dim dim;
+  };
+  SmallVector<SrcIndexAndSwizzleDim> swizzledShape;
+  for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) {
+    for (TileSwizzle::Dim d : e) {
+      swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d});
+    }
+  }
+  applyPermutationToVector(swizzledShape, swizzle.permutation);
+
+  // Step 2: collect the appropriate labels to use for the swizzled dims.
+  VectorExt::LayoutDimension internalLabels[] = {
+      VectorExt::LayoutDimension::VECTORZ, VectorExt::LayoutDimension::VECTORY,
+      VectorExt::LayoutDimension::VECTORX};
+  VectorExt::LayoutDimension crossThreadLabels[] = {
+      VectorExt::LayoutDimension::LANEZ, VectorExt::LayoutDimension::LANEY,
+      VectorExt::LayoutDimension::LANEX};
+  auto internalLabelIter = std::end(internalLabels);
+  auto crossThreadLabelIter = std::end(crossThreadLabels);
+  for (SrcIndexAndSwizzleDim d : swizzledShape) {
+    if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) {
+      assert(internalLabelIter != std::begin(internalLabels));
+      --internalLabelIter;
+    } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
+      assert(crossThreadLabelIter != std::begin(crossThreadLabels));
+      --crossThreadLabelIter;
+    } else {
+      assert(false && "unexpected dimension kind in intrinsic swizzle");
+    }
+  }
+
+  // Step 3: assemble the resulting PerDimLayoutAttrs for the two source dims.
+  SmallVector<VectorExt::LayoutDimensionAttr> labels[2];
+  SmallVector<int64_t> shape[2];
+  for (SrcIndexAndSwizzleDim d : swizzledShape) {
+    shape[d.srcIndex].push_back(d.dim.size);
+    auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal)
+                             ? internalLabelIter
+                             : crossThreadLabelIter;
+    labels[d.srcIndex].push_back(VectorExt::LayoutDimensionAttr::get(
+        context, static_cast<VectorExt::LayoutDimension>(*labelIterRef++)));
+  }
+  return {VectorExt::PerDimLayoutAttr::get(context, labels[0], shape[0]),
+          VectorExt::PerDimLayoutAttr::get(context, labels[1], shape[1])};
+}
+
+static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
+                                              GPU::MMAIntrinsic intrinsic) {
+  auto opaque = GPU::getOpaqueMMALayout(context, intrinsic);
+  ConcreteMmaLayout concreteLayout;
+  concreteLayout.base = opaque;
+  auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Lhs);
+  auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Rhs);
+  auto accSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Acc);
+  std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) =
+      getPerDimLayoutAttrs(context, lhsSwizzle);
+  std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) =
+      getPerDimLayoutAttrs(context, rhsSwizzle);
+  std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) =
+      getPerDimLayoutAttrs(context, accSwizzle);
+  return concreteLayout;
+}
+
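+// Prepends |batchDim| to |baseLayout| with a batch count of
+// problemSize / fragmentDimSize. For example (hypothetical values):
+// batchDim = BATCHX, baseLayout = (labels [LANEX], shapes [16]),
+// problemSize = 64, and fragmentDimSize = 16 yield
+// (labels [BATCHX, LANEX], shapes [4, 16]).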
+static VectorExt::PerDimLayoutAttr
+getBatchedPerDimLayoutAttr(VectorExt::LayoutDimensionAttr batchDim,
+                           VectorExt::PerDimLayoutAttr baseLayout,
+                           int64_t problemSize, int64_t fragmentDimSize) {
+  assert(problemSize % fragmentDimSize == 0 &&
+         "invalid layout fragment for problem size");
+
+  SmallVector<VectorExt::LayoutDimensionAttr, 3> dimAttrs(
+      baseLayout.getLabels());
+  dimAttrs.insert(dimAttrs.begin(), batchDim);
+
+  SmallVector<int64_t, 3> shapes(baseLayout.getShapes());
+  shapes.insert(shapes.begin(), problemSize / fragmentDimSize);
+  auto layout = VectorExt::PerDimLayoutAttr::get(baseLayout.getContext(),
+                                                 dimAttrs, shapes);
+  return layout;
+}
+
+// Get the batched layout attributes for the given fragment layouts, indexing
+// map, and problem shape. The canonical fragment map is used to compare against
+// the problem map |indexingMap|. For example, for mma fragment B (RHS):
+//
+// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2)> # Transposed B
+// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1)>
+// problemShape = [32, 64]
+// fragmentSize = [16, 8]
+// fragmentLayouts = [kLayout, nLayout]
+//
+// Gives batched layout
+//
+// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape]
+// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape]
+static VectorExt::LayoutAttr
+getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap,
+                     ArrayRef<int64_t> problemShape,
+                     ArrayRef<int64_t> fragmentSize,
+                     ArrayRef<VectorExt::PerDimLayoutAttr> fragmentLayouts) {
+  // Current distribution to MFMA operations does not support batched
+  // contractions so that is reflected here.
+  assert(indexingMap.getNumResults() == 2 &&
+         "invalid indexing map to non-batched simple contraction");
+
+  VectorExt::LayoutDimensionAttr batchX = VectorExt::LayoutDimensionAttr::get(
+      indexingMap.getContext(), VectorExt::LayoutDimension::BATCHX);
+  VectorExt::LayoutDimensionAttr batchY = VectorExt::LayoutDimensionAttr::get(
+      indexingMap.getContext(), VectorExt::LayoutDimension::BATCHY);
+
+  SmallVector<VectorExt::PerDimLayoutAttr, 2> perDimAttrs;
+  for (auto [expr, batchType] : llvm::zip_equal(
+           indexingMap.getResults(),
+           SmallVector<VectorExt::LayoutDimensionAttr, 2>{batchX, batchY})) {
+    auto maybeResultPosition = fragmentMap.getResultPosition(expr);
+    assert(maybeResultPosition && "fragment map and problem map mismatch");
+    int64_t idx = *maybeResultPosition;
+    perDimAttrs.push_back(getBatchedPerDimLayoutAttr(
+        batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx]));
+  }
+
+  return VectorExt::LayoutAttr::get(indexingMap.getContext(), perDimAttrs);
+}
+
+static FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+                            VectorLayoutInterface>>
+getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
+  MLIRContext *context = contract.getContext();
+  FailureOr<linalg::ContractionDimensions> maybeContractionDims =
+      linalg::inferContractionDims(contract.getIndexingMapsArray());
+  if (failed(maybeContractionDims)) {
+    return failure();
+  }
+  auto contractionDims = *maybeContractionDims;
+  // TODO: Relax this condition to just alignment requirements.
+  if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 ||
+      contractionDims.n.size() != 1) {
+    return failure();
+  }
+  // TODO: Support batched contractions.
+  if (contractionDims.batch.size() > 0) {
+    return failure();
+  }
+  unsigned mDim = contractionDims.m[0];
+  unsigned nDim = contractionDims.n[0];
+  unsigned kDim = contractionDims.k[0];
+
+  SmallVector<int64_t> iterationBounds;
+  contract.getIterationBounds(iterationBounds);
+
+  int64_t problemMSize = iterationBounds[mDim];
+  int64_t problemNSize = iterationBounds[nDim];
+  int64_t problemKSize = iterationBounds[kDim];
+
+  int64_t mSize = layout.base.mSize;
+  int64_t nSize = layout.base.nSize;
+  int64_t kSize = layout.base.kSize;
+
+  // The problem size currently must be strictly aligned to the size of the mma.
+  // This is expected to succeed assuming the correct [masked] vector size was
+  // set at strategy configuration time (for this mma).
+  if (problemMSize % mSize != 0 || problemNSize % nSize ||
+      problemKSize % kSize) {
+    return failure();
+  }
+
+  VectorExt::LayoutAttr aLayout = getBatchedLayoutAttr(
+      contract.getIndexingMapsArray()[0],
+      AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context),
+      {problemMSize, problemKSize}, {mSize, kSize},
+      {layout.aMLayout, layout.aKLayout});
+  VectorExt::LayoutAttr bLayout = getBatchedLayoutAttr(
+      contract.getIndexingMapsArray()[1],
+      AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context),
+      {problemKSize, problemNSize}, {kSize, nSize},
+      {layout.bKLayout, layout.bNLayout});
+  VectorExt::LayoutAttr cLayout = getBatchedLayoutAttr(
+      contract.getIndexingMapsArray()[2],
+      AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context),
+      {problemMSize, problemNSize}, {mSize, nSize},
+      {layout.cMLayout, layout.cNLayout});
+
+  return std::make_tuple<VectorLayoutInterface, VectorLayoutInterface,
+                         VectorLayoutInterface>(aLayout, bLayout, cLayout);
+}
+
+FailureOr<std::tuple<
+    VectorLayoutInterface, VectorLayoutInterface,
+    VectorLayoutInterface>> static getContractionLayout(GPU::MMAAttr mma,
+                                                        vector::ContractionOp
+                                                            contract) {
+  ConcreteMmaLayout layout = getConcreteMMALayout(
+      contract->getContext(), mma.getIntrinsic().getValue());
+  return getContractionLayout(contract, layout);
+}
+
 DiagnosedSilenceableFailure
 transform_dialect::SetContractionLayoutAttributes::apply(
     transform::TransformRewriter &rewriter,
@@ -1609,7 +1823,7 @@
     return emitDefiniteFailure()
            << "invalid more than one attribute for contraction annotation";
   }
-  auto mmaType = llvm::dyn_cast<IREE::GPU::MmaInterfaceAttr>(typeList.front());
+  auto mmaType = llvm::dyn_cast<GPU::MMAAttr>(typeList.front());
   if (!mmaType) {
     return emitDefiniteFailure()
            << "invalid non-mma attribute for contraction annotation "
@@ -1623,7 +1837,7 @@
              << "invalid non-contraction annotation " << payload;
     }
 
-    auto maybeLayouts = mmaType.getContractionLayout(contract);
+    auto maybeLayouts = getContractionLayout(mmaType, contract);
     if (failed(maybeLayouts)) {
       return emitDefiniteFailure()
              << "invalid opaque mma layout for annotation " << mmaType;