[LLVMGPU][ROCM][Layoutv1] Landing Implementation of WMMA on layoutV1 (#17580)

Introducing WMMA on layoutV1, a prerequisite for enabling FlashAttention-2
(FA2) on RDNA3 GPUs.
This PR brings:

1. WMMA on layoutV1
2. Ensuring VectorX is specified on layouts that use VectorY, to prevent
unexpected behavior when computing the "step" of the SIMT index.
3. Update of amdgpu_distribute_vectors/contraction distribution so that
they follow the newer `set_contraction_layout_attributes` style of
explicitly setting intrinsic types, rather than the less safe method of
inferring intrinsic types (see the sketch after this list).
4. Updates to the tests to match the above.
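
With this change the intrinsic is attached to the contraction explicitly
before distribution. A minimal transform-script sketch, mirroring the updated
tests (shown with the new WMMA intrinsic; the MFMA intrinsics work the same
way):

```mlir
// Match the contraction and attach the chosen intrinsic layout to it.
%contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
%layout = transform.param.constant #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32> -> !transform.any_param
transform.iree.set_contraction_layout_attributes %contract, %layout : !transform.any_op, !transform.any_param
```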

Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ffe82da..554ecbf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -290,8 +290,9 @@
     auto aKLayout = inner;
     auto bKLayout = inner;
     auto bNLayout = outer;
-    auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {8, 2});
-    auto cNLayout = outer;
+    auto cMLayout =
+        PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
+    auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
     return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
                              bNLayout,     cMLayout, cNLayout};
   }
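
For context, the corrected `WMMA_F16_16x16x16_F32` accumulator (C) layout
corresponds to the following vector-ext attributes for a single batch, as
exercised by the tests below. Roughly: each lane holds eight f32 values along
M (`VECTORY` = 8), consecutive rows alternate between the two lane halves
(`LANEY` = 2, `VECTORX` = 1), and `LANEX` covers the 16 columns of N.

```mlir
#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>
#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
#layout_c = #iree_vector_ext.layout<#row_layout, #col_layout>
```
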
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 17492ea..94e8a16 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1647,6 +1647,7 @@
              << "invalid opaque mma layout for annotation " << mmaType;
     }
 
+    contract->setAttr("iree.amdgpu.mma", mmaType);
     auto [aLayout, bLayout, cLayout] = *maybeLayouts;
     contract->setAttr("__vector_layout_test_anchor_operand_0", aLayout);
     contract->setAttr("__vector_layout_test_anchor_operand_1", bLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
index 39cdd7c..6253dd7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -7,6 +7,7 @@
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 
@@ -18,83 +19,6 @@
 enum class ContractMatrixType { A, B, C, D };
 enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
 
-/// We define AMD MFMA instruction layouts only for contract type MM, i.e. C(i,
-/// k) += A(i, j) * B(j, k). We call these canonical layouts: layoutA, layoutB,
-/// layoutC, corresponding to the A, B, C matrices.
-///
-/// For any other contract type, the layout is simply transposed for that
-/// operand. For example, for MMT, the layouts used should be layoutA,
-/// layoutB.T, layoutC. For an easier understanding of this transposition,
-/// think of the transpose simply being outside the contract:
-///
-/// vector.contract {type = MMT} %a, %b
-///
-/// is equivalent to
-///
-/// %bt = vector.transpose %b
-/// vector.contract {type = MM} %a, %bt
-///
-/// Now, you would assign layouts based on contract type MM, and would get the
-/// right layout for %b by transposing the layout for B.
-///
-/// Now that we have defined what layoutA, layoutB, layoutC are, we will define
-/// what the canonical layouts are for each MFMA instruction. These are
-/// represented as the original matrix, with elements representing which thread
-/// id in the subgroup gets which element.
-/// These layouts were referenced from
-/// https://github.com/ROCm/amd_matrix_instruction_calculator
-///
-/// The naming scheme for these operators is InputType_MxNxK_OutputType.
-enum class MFMAType {
-  /// layoutA:
-  ///  0  0  0  0 16 16 16 16 32 32 32 32 48 48 48 48
-  ///  1  1  1  1 17 17 17 17 33 33 33 33 49 49 49 49
-  ///  2  2  2  2 18 18 18 18 34 34 34 34 50 50 50 50
-  /// ...
-  /// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
-  ///
-  /// layoutB:
-  /// Transpose of layoutA
-  ///
-  /// layoutC:
-  /// Same as layoutB
-  F16_16x16x16_F32,
-  /// layoutA:
-  ///  0  0  0  0 32 32 32 32
-  ///  1  1  1  1 33 33 33 33
-  ///  2  2  2  2 34 34 34 34
-  ///  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮
-  /// 31 31 31 31 63 63 63 63
-  ///
-  /// layoutB:
-  /// Transpose of layoutA
-  ///
-  /// layoutC:
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  ///  0  1  2 ... 31
-  ///  ⋮  ⋮  ⋮ ...  ⋮
-  /// 32 33 34 ... 63
-  ///  0  1  2 ... 31
-  ///  ⋮  ⋮  ⋮ ...  ⋮
-  /// 32 33 34 ... 63
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  ///  0  1  2 ... 31
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  /// 32 33 34 ... 63
-  F16_32x32x8_F32,
-};
-
 namespace {
 
 static bool isOperandATransposed(ContractType contractType) {
@@ -166,208 +90,6 @@
     return ContractType::UNSUPPORTED;
   }
 
-  Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
-                   MFMAType mfmaType) const {
-    uint32_t m, n, k, blks;
-    if (mfmaType == MFMAType::F16_16x16x16_F32) {
-      m = n = k = 16;
-    } else if (mfmaType == MFMAType::F16_32x32x8_F32) {
-      m = n = 32;
-      k = 8;
-    }
-    blks = 1;
-    return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
-                                           b, c);
-  }
-
-  PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
-                                      ArrayRef<LayoutDimension> dims,
-                                      ArrayRef<int64_t> shapes) const {
-    SmallVector<LayoutDimensionAttr> dimAttrs;
-    for (auto dim : dims)
-      dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
-    return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
-  }
-
-  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
-      LayoutDimension batchRowLabel, int64_t batchRow,
-      LayoutDimension batchColLabel, int64_t batchCol) const {
-    MLIRContext *ctx = getContext();
-    PerDimLayoutAttr rowLayout = createPerDimLayout(
-        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
-    PerDimLayoutAttr colLayout = createPerDimLayout(
-        ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
-        {batchCol, 4, 4});
-    return {rowLayout, colLayout};
-  }
-
-  bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
-        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
-                             int64_t batchCol) const {
-    return isCompatible16x16x16B(layout, batchRow, batchCol);
-  }
-
-  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
-  createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
-                                LayoutDimension batchColLabel, int64_t batchCol,
-                                ContractMatrixType matrixType) const {
-    MLIRContext *ctx = getContext();
-    PerDimLayoutAttr rowLayout = createPerDimLayout(
-        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
-    PerDimLayoutAttr colLayout;
-    if (matrixType == ContractMatrixType::C) {
-      colLayout =
-          createPerDimLayout(ctx,
-                             {batchColLabel, LayoutDimension::VECTORY,
-                              LayoutDimension::LANEY, LayoutDimension::VECTORX},
-                             {batchCol, 4, 2, 4});
-    } else {
-      colLayout = createPerDimLayout(
-          ctx,
-          {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
-          {batchCol, 2, 4});
-    }
-    return {rowLayout, colLayout};
-  }
-
-  bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
-        ContractMatrixType::A);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
-        ContractMatrixType::B);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
-                            int64_t batchCol) const {
-    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
-        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
-        ContractMatrixType::C);
-    LayoutAttr canonicalLayout =
-        LayoutAttr::get(getContext(), {rowLayout, colLayout});
-    return layout == canonicalLayout;
-  }
-
-  bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
-                            int64_t batchRow, int64_t batchCol) const {
-    switch (matrixType) {
-    case ContractMatrixType::A:
-      return isCompatible16x16x16A(layout, batchRow, batchCol);
-    case ContractMatrixType::B:
-      return isCompatible16x16x16B(layout, batchRow, batchCol);
-    default:
-      return isCompatible16x16x16C(layout, batchRow, batchCol);
-    }
-    return false;
-  }
-
-  bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
-                           int64_t batchRow, int64_t batchCol) const {
-    switch (matrixType) {
-    case ContractMatrixType::A:
-      return isCompatible32x32x8A(layout, batchRow, batchCol);
-    case ContractMatrixType::B:
-      return isCompatible32x32x8B(layout, batchRow, batchCol);
-    default:
-      return isCompatible32x32x8C(layout, batchRow, batchCol);
-    }
-    return false;
-  }
-
-  bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
-                    MFMAType mfmaType) const {
-    std::optional<int64_t> batchRow = layout.getBatchDim(0);
-    if (!batchRow)
-      return false;
-    std::optional<int64_t> batchCol = layout.getBatchDim(1);
-    if (!batchCol)
-      return false;
-    switch (mfmaType) {
-    case MFMAType::F16_16x16x16_F32:
-      return isCompatible16x16x16(layout, matrixType, batchRow.value(),
-                                  batchCol.value());
-    case MFMAType::F16_32x32x8_F32:
-      return isCompatible32x32x8(layout, matrixType, batchRow.value(),
-                                 batchCol.value());
-    default:
-      return false;
-    }
-    return false;
-  }
-
-  // If we have a prior guess of the MFMA type, only evaluate that type.
-  // Otherwise, evaluate all types to find a match.
-  std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
-                                        ContractMatrixType matrixType,
-                                        std::optional<MFMAType> prior) const {
-    SmallVector<MFMAType> mfmaTypes;
-    if (prior) {
-      mfmaTypes.push_back(prior.value());
-    } else {
-      mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
-    }
-    for (MFMAType mfmaType : mfmaTypes) {
-      if (isCompatible(layout, matrixType, mfmaType))
-        return mfmaType;
-    }
-    return std::nullopt;
-  }
-
-  // Inputs are LHS, RHS and ACC operands and corresponding layouts.
-  // Output is inferred MFMAType or none (if layout is not compatible with any
-  // MFMA layout).
-  std::optional<MFMAType>
-  inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
-                          ContractType contractType) const {
-    std::optional<MFMAType> mfmaType{std::nullopt};
-    SmallVector<ContractMatrixType> matrixTypes{
-        ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
-
-    // Canonical layouts for MFMA are transposes of each other.
-    if (isOperandATransposed(contractType)) {
-      matrixTypes[0] = ContractMatrixType::B;
-    }
-    if (isOperandBTransposed(contractType)) {
-      matrixTypes[1] = ContractMatrixType::A;
-    }
-
-    for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
-      mfmaType = inferMFMAType(layout, matrixType, mfmaType);
-      if (!mfmaType)
-        return std::nullopt;
-    }
-    return mfmaType;
-  }
-
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
@@ -408,10 +130,12 @@
     if (contractType == ContractType::UNSUPPORTED)
       return failure();
 
-    std::optional<MFMAType> mfmaType =
-        inferCompatibleMFMAType(layouts, contractType);
-    if (!mfmaType)
-      return failure();
+    auto mmaAttr =
+        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+    if (!mmaAttr) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "missing iree.amdgpu.mma intrinsic attribute");
+    }
 
     std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
     if (!rowBatch)
@@ -434,10 +158,13 @@
         Value bMatrix = rewriter.create<vector::ExtractOp>(
             loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
             getIndices(contractType, ContractMatrixType::B, k, indices[1]));
-        dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
-                             mfmaType.value());
+        dMatrix = mmaAttr
+                      .buildMmaOperation(rewriter, loc, dMatrix.getType(),
+                                         aMatrix, bMatrix, dMatrix)
+                      .value();
       }
       vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
+      return success();
     };
 
     LayoutIterator iterator(resultLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index c5f7fd2..695ecb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -29,6 +29,7 @@
         "//compiler/src/iree/compiler/Codegen/Common",
         "//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
         "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index b94f2d6..49a9e7e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -41,6 +41,7 @@
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Common::VectorLayoutAnalysis
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index ec04bc8..1757b5c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -20,6 +20,7 @@
         [
             "amdgpu_chained_matmul.mlir",
             "amdgpu_contraction_distribution.mlir",
+            "amdgpu_set_anchor_layouts.mlir",
             "attention.mlir",
             "attention_mfma.mlir",
             "conv_pipeline_test_cuda.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index be60c97..2ff84aa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "amdgpu_chained_matmul.mlir"
     "amdgpu_contraction_distribution.mlir"
+    "amdgpu_set_anchor_layouts.mlir"
     "attention.mlir"
     "attention_mfma.mlir"
     "cast_address_space_function.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
index b46a8d0..fd12413 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -4,6 +4,8 @@
 // layoutC means and how these layouts are assigned based on the instruction
 // type.
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -22,15 +24,6 @@
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
-    // CHECK-LABEL: distribute_mfma_16x16x16_mmt
-    // CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
-    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
-    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -42,14 +35,31 @@
     return %output : vector<16x16xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -70,8 +80,6 @@
 #layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
-    // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
-    // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -83,14 +91,24 @@
     return %output : vector<32x64xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -111,15 +129,6 @@
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
-    // CHECK-LABEL: distribute_mfma_32x32x8_mm
-    // CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
-    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
-    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
-    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
-    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<16xf32>
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -131,14 +140,31 @@
     return %output : vector<32x32xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
+// CHECK-LABEL: distribute_mfma_32x32x8_mm
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
 // -----
 
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
 #map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -160,7 +186,6 @@
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_32x32x8_mtm(%a : vector<8x64xf16>, %b : vector<8x32xf16>, %c : vector<64x32xf32>) -> vector<64x32xf32> {
-    // CHECK-LABEL: distribute_mfma_32x32x8_mtm
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -169,17 +194,126 @@
                                "__vector_layout_test_anchor_result_0" = #layout_c
                                }
                                 %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
-    // CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
-    // CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
-    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
-    // CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
-    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
-    // CHECK-NOT: amdgpu.mfma
     return %output : vector<64x32xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
+
+// CHECK-LABEL: distribute_mfma_32x32x8_mtm
+
+// CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
+// CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
+// CHECK-NOT: amdgpu.mfma
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<16x16>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<16x16>, layout = transpose(layoutB) = layoutA
+// Since the shapes are also the same, we can reuse the same layout attribute, layout_a.
+
+// C: vector<16x16>, layout = layoutC
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_wmma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_a,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x8xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<8xf32> from vector<1x1x8xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.wmma %[[AV]] * %[[BV]] + %[[CV]] : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<32x128>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<64x128>, layout = transpose(layoutB) = layoutA
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [4, 16]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+
+// C: vector<32x64>, layout = layoutC
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [2, 8, 2, 1]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_wmma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_b,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
new file mode 100644
index 0000000..2269dc6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s --verify-diagnostics
+
+// This tests that the compiler is setting the correct layout anchors for various vectorOps and shapes.
+// Currently only testing on contraction layoutV1, but can be expanded to others.
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_mfma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+    // CHECK-LABEL: anchor_mfma_16x16x16_mmt
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 4, 4]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 4, 4]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEY,  VECTORX], [1, 4, 4]>, <[ BATCHY,  LANEX], [1, 16]>>}}
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_mfma_16x16x16_mmt_batch(%a : memref<32x128xf16>, %b : memref<64x128xf16>, %init : vector<32x64xf32>) -> vector<32x64xf32> {
+    // CHECK-LABEL: anchor_mfma_16x16x16_mmt_batch
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x128xf16>, vector<32x128xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [2, 16]>, <[ BATCHY,  LANEY,  VECTORX], [8, 4, 4]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [4, 16]>, <[ BATCHY,  LANEY,  VECTORX], [8, 4, 4]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEY,  VECTORX], [2, 4, 4]>, <[ BATCHY,  LANEX], [4, 16]>>}}
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @anchor_wmma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+    // CHECK-LABEL: anchor_wmma_16x16x16_mmt
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 1, 16]>>}}
+    %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 1, 16]>>}}
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  VECTORY,  LANEY,  VECTORX], [1, 8, 2, 1]>, <[ BATCHY,  LANEX], [1, 16]>>}}
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
+    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+    transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 8d17162..dd52c2c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -158,6 +158,8 @@
       if (isVectorDimension(name)) {
         int64_t step{1};
         if (name == LayoutDimension::VECTORY) {
+          assert(ranges.contains(LayoutDimension::VECTORX) &&
+                 "Expected VectorX to be specified on layouts with VectorY.");
           step = ranges.lookup(LayoutDimension::VECTORX).stop;
         }
         vecOffset = vecOffset.value_or(0) + it.getPosition() * step;
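
Note: this assertion guards the SIMT index computation above. The "step" for a
VECTORY dimension is taken from the `stop` of the VECTORX range, so a layout
that uses VECTORY without specifying VECTORX would otherwise read a
default-constructed range and compute an incorrect offset. This is the
motivation for item 2 in the description.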