[VectorDistribution] Add pattern to distribute contractions (#16172)

This PR adds a pattern to distribute vector contractions to AMD MFMA ops.
Currently, the F16 16x16x16 and 32x32x8 MFMA intrinsics (accumulating to F32)
are supported.
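
For illustration, the pattern is driven in the lit tests below through a new
test transform op. A minimal sketch, taken from those tests (anchor layouts
are attached to the vector.contract via "__vector_layout_test_anchor_*"
attributes):

  builtin.module attributes { transform.with_named_sequence } {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.iree.test_amdgpu_contraction_distribution %func : !transform.any_op
      transform.yield
    }
  }
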
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 84dce8d..5e9762c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -85,30 +85,6 @@
   return simdIndex;
 }
 
-/// Given the state of the iterator, compute the indices of the distributed
-/// vector that the current iterator state is iterating over. The indices
-/// are not parameterized by thread, and it is expected that the indices for
-/// all threads are same.
-static SmallVector<int64_t> computeSIMTIndex(const LayoutIterator::State &state,
-                                             LayoutAttr layout) {
-  constexpr LayoutDimension labels[] = {
-      LayoutDimension::BATCHX, LayoutDimension::BATCHY,
-      LayoutDimension::VECTORZ, LayoutDimension::VECTORY,
-      LayoutDimension::VECTORX};
-
-  SmallVector<int64_t> offset;
-  for (LayoutDimension label : labels) {
-    std::optional shape = findDimShape(layout, label);
-    if (!shape) {
-      continue;
-    }
-    // Get current position for the label.
-    int64_t position = state.lookup(label).getPosition();
-    offset.push_back(position);
-  }
-  return offset;
-}
-
 struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
@@ -129,8 +105,8 @@
 
     // Replace the original op with the distributed op.
     Type elementType = constant.getType().getElementType();
-    auto vectorType = VectorType::get(
-        layout.getDistributedShape(constant.getType()), elementType);
+    auto vectorType =
+        VectorType::get(layout.getDistributedShape(), elementType);
     Operation *distirbutedOp = rewriter.create<arith::ConstantOp>(
         constantOp.getLoc(), vectorType,
         SplatElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()));
@@ -165,9 +141,8 @@
 
       // Distribute vector result types.
       if (auto vectorResult = dyn_cast<VectorValue>(result)) {
-        resultType = VectorType::get(
-            resLayout.getDistributedShape(vectorResult.getType()),
-            vectorResult.getType().getElementType());
+        resultType = VectorType::get(resLayout.getDistributedShape(),
+                                     vectorResult.getType().getElementType());
       }
       resultTypes.push_back(resultType);
     }
@@ -248,7 +223,7 @@
     iterator.apply([&](const LayoutIterator::State &state) {
       SmallVector<Value> memoryIndices =
           getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
-      SmallVector<int64_t> accIndices = computeSIMTIndex(state, vectorLayout);
+      SmallVector<int64_t> accIndices = state.computeSIMTIndex();
       accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
                                vectorLayout, memoryLayout, rewriter);
     });
@@ -301,7 +276,6 @@
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
-    VectorValue vector = readOp.getVector();
     LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.results[0]);
     if (!vectorLayout) {
       return failure();
@@ -310,8 +284,8 @@
     // TODO: Return failure if we need masking.
 
     Type elementType = readOp.getSource().getType().getElementType();
-    auto vectorType = VectorType::get(
-        vectorLayout.getDistributedShape(vector.getType()), elementType);
+    auto vectorType =
+        VectorType::get(vectorLayout.getDistributedShape(), elementType);
     Value zero = rewriter.create<arith::ConstantOp>(
         readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
     VectorValue acc = cast<VectorValue>(zero);
@@ -345,7 +319,6 @@
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
-    VectorValue vector = writeOp.getVector();
     LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.operands[0]);
     if (!vectorLayout) {
       return failure();
@@ -354,8 +327,8 @@
     // TODO: Return failure if we need masking.
 
     Type elementType = writeOp.getSource().getType().getElementType();
-    auto vectorType = VectorType::get(
-        vectorLayout.getDistributedShape(vector.getType()), elementType);
+    auto vectorType =
+        VectorType::get(vectorLayout.getDistributedShape(), elementType);
     Value zero = rewriter.create<arith::ConstantOp>(
         writeOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
     VectorValue acc = cast<VectorValue>(zero);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 7de8a9b..4f98b6a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -103,8 +103,7 @@
     return cast<VectorValue>(toSIMD.getInput());
   }
   // Create a "to_simt" op to convert the value to the distributed layout.
-  SmallVector<int64_t> distributedShape =
-      layout.getDistributedShape(value.getType());
+  SmallVector<int64_t> distributedShape = layout.getDistributedShape();
   VectorType distributedType =
       VectorType::get(distributedShape, value.getType().getElementType());
   auto toSIMT = rewriter.create<IREE::VectorExt::ToSIMTOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index a7ab386..2d35a9a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -7,15 +7,15 @@
   func.func @distribute_elementwise_f16(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.0 : f16
-    // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+    // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
     %root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<0.0> : vector<16x16xf16>
-    // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<4xf16>
-    // CHECK-DAG: %[[C:.*]] = arith.mulf %[[ROOT]], %[[B]] : vector<4xf16>
+    // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
+    // CHECK-DAG: %[[C:.*]] = arith.mulf %[[ROOT]], %[[B]] : vector<16xf16>
     %c = arith.mulf %root, %b : vector<16x16xf16>
-    // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<4xf16>
-    // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<4xf16>
+    // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
+    // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<16xf16>
     %d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<16x16xf16>
-    // CHECK: iree_vector_ext.to_simd %[[D]] : vector<4xf16> -> vector<16x16xf16>
+    // CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xf16> -> vector<16x16xf16>
     return %d : vector<16x16xf16>
   }
 
@@ -23,15 +23,15 @@
   func.func @distribute_elementwise_i32(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> vector<16x16xi32> {
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0 : i32
-    // CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<4xi32>
+    // CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<16xi32>
     %root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<0> : vector<16x16xi32>
-    // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<4xi32>
-    // CHECK-DAG: %[[C:.*]] = arith.muli %[[ROOT]], %[[B]] : vector<4xi32>
+    // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
+    // CHECK-DAG: %[[C:.*]] = arith.muli %[[ROOT]], %[[B]] : vector<16xi32>
     %c = arith.muli %root, %b : vector<16x16xi32>
-    // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<4xi32>
-    // CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<4xi32>
+    // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
+    // CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<16xi32>
     %d = arith.addi %c, %a : vector<16x16xi32>
-    // CHECK: iree_vector_ext.to_simd %[[D]] : vector<4xi32> -> vector<16x16xi32>
+    // CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xi32> -> vector<16x16xi32>
     return %d : vector<16x16xi32>
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 810b7ea..008e033 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -890,32 +890,6 @@
 // TestVectorLayoutAnalysisOp
 //===----------------------------------------------------------------------===//
 
-static void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
-                                       Operation *root) {
-  root->walk([&](Operation *op) {
-    for (NamedAttribute attr : op->getAttrs()) {
-      StringRef name = attr.getName().strref();
-      if (name.find("__vector_layout_test_anchor_operand_") !=
-          std::string::npos) {
-        int operandNum;
-        name.substr(name.find_last_of("_") + 1)
-            .getAsInteger(/*Radix=*/10, operandNum);
-        assert(operandNum < op->getNumOperands() &&
-               "operand number out of range");
-        analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
-      }
-      if (name.find("__vector_layout_test_anchor_result_") !=
-          std::string::npos) {
-        int resultNum;
-        name.substr(name.find_last_of("_") + 1)
-            .getAsInteger(/*Radix=*/10, resultNum);
-        assert(resultNum < op->getNumResults() && "result number out of range");
-        analysis.setAnchor(op->getResult(resultNum), attr.getValue());
-      }
-    }
-  });
-}
-
 static void emitLayoutRemarks(VectorLayoutAnalysis &analysis,
                               func::FuncOp funcOp) {
   funcOp.walk([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 22ec064..e01c4d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -1019,3 +1019,31 @@
   print(llvm::dbgs());
   llvm::dbgs() << "\n";
 }
+
+namespace mlir::iree_compiler {
+
+void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
+                                Operation *root) {
+  root->walk([&](Operation *op) {
+    for (NamedAttribute attr : op->getAttrs()) {
+      StringRef name = attr.getName().strref();
+      if (name.contains("__vector_layout_test_anchor_operand_")) {
+        int operandNum;
+        name.substr(name.find_last_of("_") + 1)
+            .getAsInteger(/*Radix=*/10, operandNum);
+        assert(operandNum < op->getNumOperands() &&
+               "operand number out of range");
+        analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
+      }
+      if (name.contains("__vector_layout_test_anchor_result_")) {
+        int resultNum;
+        name.substr(name.find_last_of("_") + 1)
+            .getAsInteger(/*Radix=*/10, resultNum);
+        assert(resultNum < op->getNumResults() && "result number out of range");
+        analysis.setAnchor(op->getResult(resultNum), attr.getValue());
+      }
+    }
+  });
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
index 3c9ebae..0282cb2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
@@ -146,6 +146,9 @@
   DataFlowSolver solver;
 };
 
+void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
+                                Operation *root);
+
 }; // namespace iree_compiler
 }; // namespace mlir
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
index 7f68c8d..274228e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
@@ -65,6 +65,7 @@
         "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:BufferizationDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
index 8491b58..0390d66 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
@@ -34,6 +34,7 @@
     IREEDialectsTransforms
     IREELinalgTransformDialect
     LLVMSupport
+    MLIRAMDGPUDialect
     MLIRAffineDialect
     MLIRArithDialect
     MLIRBufferizationDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 7fd04fc..0ec4618 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -8,6 +8,8 @@
 
 #include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -17,6 +19,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -56,6 +59,7 @@
   // CreateAsyncGroupsOp depends on the following two dialects.
   declareGeneratedDialect<gpu::GPUDialect>();
   declareGeneratedDialect<nvgpu::NVGPUDialect>();
+  declareGeneratedDialect<amdgpu::AMDGPUDialect>();
 
   registerTransformOps<
 #define GET_OP_LIST
@@ -1496,5 +1500,32 @@
   transform::modifiesPayload(effects);
 }
 
+class TestVectorLayoutOptions : public VectorLayoutOptions {
+public:
+  TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {}
+
+  void setAnchorOps(VectorLayoutAnalysis &analysis) override {
+    setAnchorOpsFromAttributes(analysis, root);
+  }
+};
+
+DiagnosedSilenceableFailure
+transform_dialect::TestAMDGPUContractionDistribution::applyToOne(
+    transform::TransformRewriter &rewriter, func::FuncOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  TestVectorLayoutOptions options(target);
+  RewritePatternSet patterns(target.getContext());
+  populateAMDGPUDistributionPatterns(patterns);
+  distributeVectorOps(target, patterns, options);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::TestAMDGPUContractionDistribution::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 9c0d0b4..7b9e208 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -312,8 +312,8 @@
   let results = (outs);
 
   let assemblyFormat = [{
-    $target 
-    attr-dict 
+    $target
+    attr-dict
     `:` functional-type($target, results)
   }];
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
@@ -437,9 +437,9 @@
 
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
-  let assemblyFormat = [{ 
-    $for_op 
-    attr-dict 
+  let assemblyFormat = [{
+    $for_op
+    attr-dict
     `:` functional-type(operands, results)}];
 
   let extraClassDeclaration = [{
@@ -472,9 +472,9 @@
 
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
-  let assemblyFormat = [{ 
-    $for_op 
-    attr-dict 
+  let assemblyFormat = [{
+    $for_op
+    attr-dict
     `:` functional-type(operands, results)}];
 
   let extraClassDeclaration = [{
@@ -509,9 +509,9 @@
                    UnitAttr:$use_mma_sync);
   let results = (outs);
 
-  let assemblyFormat = [{ 
-    $target 
-    attr-dict 
+  let assemblyFormat = [{
+    $target
+    attr-dict
     `:` functional-type(operands, results)}];
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
@@ -682,4 +682,41 @@
   }];
 }
 
+def TestAMDGPUContractionDistribution :
+  Op<Transform_Dialect, "iree.test_amdgpu_contraction_distribution",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformEachOpTrait,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Run AMDGPU vector contraction distribution using the target op as the
+    root.
+
+    The anchor points are set using the attributes
+    "__vector_layout_test_anchor_operand_x" and
+    "__vector_layout_test_anchor_result_x", where "x" is the operand/result
+    number.
+
+    This op produces amdgpu MFMA ops.
+
+    #### Return modes
+
+    This transform does not consume the target handle and always returns
+    success.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = [{ $target attr-dict `:` type($target)}];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::func::FuncOp funcOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // IREE_COMPILER_CODEGEN_LLVMGPU_TRANSFORMEXTENSIONS_LLVMGPUEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
new file mode 100644
index 0000000..cf0a0cb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -0,0 +1,358 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
+#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+
+namespace mlir::iree_compiler {
+
+using namespace mlir::iree_compiler::IREE::VectorExt;
+using VectorValue = TypedValue<VectorType>;
+
+enum class ContractMatrixType { A, B, C, D };
+enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
+
+// The naming scheme for these operators is:
+// InputType_MxNxK_OutputType.
+enum class MFMAType {
+  F16_16x16x16_F32,
+  F16_32x32x8_F32,
+};
+
+namespace {
+
+struct DistributeContractions final
+    : OpDistributionPattern<vector::ContractionOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  // For an MM contraction, we compute C(i, k) += A(i, j) * B(j, k).
+  // For an MMT contraction, we compute C(i, k) += A(i, j) * B(k, j).
+  // This function returns the appropriate indices for the A and B matrices.
+  // Given incoming indices (i, j), it returns them either unchanged or
+  // swapped, depending on the type of contraction and the type of matrix.
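+  // For example, in an MMT contraction C(i, k) += A(i, j) * B(k, j), the B
+  // matrix is stored transposed: an incoming access (j, k) is swapped to
+  // (k, j), while accesses into A pass through unchanged.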
+  SmallVector<int64_t> getIndices(ContractType contractType,
+                                  ContractMatrixType matrixType, int i,
+                                  int j) const {
+    SmallVector<int64_t> originalIndices{i, j};
+    SmallVector<int64_t> swappedIndices{j, i};
+    if (contractType == ContractType::MTMT)
+      return swappedIndices;
+    if ((contractType == ContractType::MTM) &&
+        (matrixType == ContractMatrixType::A))
+      return swappedIndices;
+    if ((contractType == ContractType::MMT) &&
+        (matrixType == ContractMatrixType::B))
+      return swappedIndices;
+    return originalIndices;
+  }
+
+  int64_t getReductionDimensionShape(int64_t rowBatch, int64_t colBatch,
+                                     ContractType contractType) const {
+    if ((contractType == ContractType::MTM) ||
+        (contractType == ContractType::MTMT)) {
+      return rowBatch;
+    }
+    return colBatch;
+  }
+
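+  // Infers the contraction type from the indexing maps. For example, the
+  // maps {(m, k), (k, n), (m, n)} describe C(m, n) += A(m, k) * B(k, n),
+  // i.e. a plain MM contraction, whereas {(m, k), (n, k), (m, n)} describe
+  // an MMT contraction with a transposed B.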
+  ContractType inferContractType(MLIRContext *ctx,
+                                 SmallVector<AffineMap> maps) const {
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m, n, k;
+    bindDims(ctx, m, n, k);
+    if ((maps == infer({{m, k}, {k, n}, {m, n}})) ||
+        (maps == infer({{n, k}, {k, m}, {n, m}}))) {
+      return ContractType::MM;
+    }
+    if ((maps == infer({{m, k}, {n, k}, {m, n}})) ||
+        (maps == infer({{n, k}, {m, k}, {n, m}}))) {
+      return ContractType::MMT;
+    }
+    if ((maps == infer({{k, m}, {k, n}, {m, n}})) ||
+        (maps == infer({{k, n}, {k, m}, {n, m}}))) {
+      return ContractType::MTM;
+    }
+    if ((maps == infer({{k, m}, {n, k}, {m, n}})) ||
+        (maps == infer({{k, n}, {m, k}, {n, m}}))) {
+      return ContractType::MTMT;
+    }
+    return ContractType::UNSUPPORTED;
+  }
+
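+  // Emits a single amdgpu.mfma op. For F16_16x16x16_F32 this produces, as in
+  // the lit tests below:
+  //   %d = amdgpu.mfma %a * %b + %c
+  //     {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32}
+  //     blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>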
+  Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
+                   MFMAType mfmaType) const {
+    uint32_t m, n, k, blks;
+    if (mfmaType == MFMAType::F16_16x16x16_F32) {
+      m = n = k = 16;
+    } else if (mfmaType == MFMAType::F16_32x32x8_F32) {
+      m = n = 32;
+      k = 8;
+    }
+    blks = 1;
+    return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
+                                           b, c);
+  }
+
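+  // Builds a per-dimension layout attribute. For example,
+  // createPerDimLayout(ctx, {BATCHX, LANEX}, {1, 16}) yields
+  // #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>.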
+  PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
+                                      ArrayRef<LayoutDimension> dims,
+                                      ArrayRef<int64_t> shapes) const {
+    SmallVector<LayoutDimensionAttr> dimAttrs;
+    for (auto dim : dims)
+      dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
+    return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
+  }
+
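+  // The canonical 16x16x16 layout distributes one matrix dimension over 16
+  // lanes (LANEX) and the other over 4 lanes (LANEY), each holding 4
+  // contiguous elements (VECTORX), matching the register layout expected by
+  // the 16x16x16 MFMA intrinsic.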
+  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
+      LayoutDimension batchRowLabel, int64_t batchRow,
+      LayoutDimension batchColLabel, int64_t batchCol) const {
+    MLIRContext *ctx = getContext();
+    PerDimLayoutAttr rowLayout = createPerDimLayout(
+        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
+    PerDimLayoutAttr colLayout = createPerDimLayout(
+        ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
+        {batchCol, 4, 4});
+    return {rowLayout, colLayout};
+  }
+
+  bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
+                             int64_t batchCol) const {
+    auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
+        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
+    LayoutAttr canonicalLayout =
+        LayoutAttr::get(getContext(), {rowLayout, colLayout});
+    return layout == canonicalLayout;
+  }
+
+  bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
+                             int64_t batchCol) const {
+    auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
+        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
+    LayoutAttr canonicalLayout =
+        LayoutAttr::get(getContext(), {rowLayout, colLayout});
+    return layout == canonicalLayout;
+  }
+
+  bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
+                             int64_t batchCol) const {
+    return isCompatible16x16x16B(layout, batchRow, batchCol);
+  }
+
+  std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
+  createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
+                                LayoutDimension batchColLabel, int64_t batchCol,
+                                ContractMatrixType matrixType) const {
+    MLIRContext *ctx = getContext();
+    PerDimLayoutAttr rowLayout = createPerDimLayout(
+        ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
+    PerDimLayoutAttr colLayout;
+    if (matrixType == ContractMatrixType::C) {
+      colLayout =
+          createPerDimLayout(ctx,
+                             {batchColLabel, LayoutDimension::VECTORY,
+                              LayoutDimension::LANEY, LayoutDimension::VECTORX},
+                             {batchCol, 4, 2, 4});
+    } else {
+      colLayout = createPerDimLayout(
+          ctx,
+          {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
+          {batchCol, 2, 4});
+    }
+    return {rowLayout, colLayout};
+  }
+
+  bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
+                            int64_t batchCol) const {
+    auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
+        LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
+        ContractMatrixType::A);
+    LayoutAttr canonicalLayout =
+        LayoutAttr::get(getContext(), {rowLayout, colLayout});
+    return layout == canonicalLayout;
+  }
+
+  bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
+                            int64_t batchCol) const {
+    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
+        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
+        ContractMatrixType::B);
+    LayoutAttr canonicalLayout =
+        LayoutAttr::get(getContext(), {rowLayout, colLayout});
+    return layout == canonicalLayout;
+  }
+
+  bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
+                            int64_t batchCol) const {
+    auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
+        LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
+        ContractMatrixType::C);
+    LayoutAttr canonicalLayout =
+        LayoutAttr::get(getContext(), {rowLayout, colLayout});
+    return layout == canonicalLayout;
+  }
+
+  bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
+                            int64_t batchRow, int64_t batchCol) const {
+    switch (matrixType) {
+    case ContractMatrixType::A:
+      return isCompatible16x16x16A(layout, batchRow, batchCol);
+    case ContractMatrixType::B:
+      return isCompatible16x16x16B(layout, batchRow, batchCol);
+    default:
+      return isCompatible16x16x16C(layout, batchRow, batchCol);
+    }
+    return false;
+  }
+
+  bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
+                           int64_t batchRow, int64_t batchCol) const {
+    switch (matrixType) {
+    case ContractMatrixType::A:
+      return isCompatible32x32x8A(layout, batchRow, batchCol);
+    case ContractMatrixType::B:
+      return isCompatible32x32x8B(layout, batchRow, batchCol);
+    default:
+      return isCompatible32x32x8C(layout, batchRow, batchCol);
+    }
+    return false;
+  }
+
+  bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
+                    MFMAType mfmaType) const {
+    std::optional<int64_t> batchRow = layout.getBatchDim(0);
+    if (!batchRow)
+      return false;
+    std::optional<int64_t> batchCol = layout.getBatchDim(1);
+    if (!batchCol)
+      return false;
+    switch (mfmaType) {
+    case MFMAType::F16_16x16x16_F32:
+      return isCompatible16x16x16(layout, matrixType, batchRow.value(),
+                                  batchCol.value());
+    case MFMAType::F16_32x32x8_F32:
+      return isCompatible32x32x8(layout, matrixType, batchRow.value(),
+                                 batchCol.value());
+    default:
+      return false;
+    }
+    return false;
+  }
+
+  // If we have a prior guess of the MFMA type, only evaluate that type.
+  // Otherwise, evaluate all types to find a match.
+  std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
+                                        ContractMatrixType matrixType,
+                                        std::optional<MFMAType> prior) const {
+    SmallVector<MFMAType> mfmaTypes;
+    if (prior) {
+      mfmaTypes.push_back(prior.value());
+    } else {
+      mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
+    }
+    for (MFMAType mfmaType : mfmaTypes) {
+      if (isCompatible(layout, matrixType, mfmaType))
+        return mfmaType;
+    }
+    return std::nullopt;
+  }
+
+  // Inputs are the layouts of the LHS, RHS, and ACC operands (in that
+  // order). The output is the inferred MFMAType, or none if the layouts are
+  // not compatible with any supported MFMA layout.
+  std::optional<MFMAType>
+  inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts) const {
+    std::optional<MFMAType> mfmaType{std::nullopt};
+    SmallVector<ContractMatrixType> matrixTypes{
+        ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
+    for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
+      mfmaType = inferMFMAType(layout, matrixType, mfmaType);
+      if (!mfmaType)
+        return std::nullopt;
+    }
+    return mfmaType;
+  }
+
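+  // The rewrite collects the distributed LHS/RHS/ACC operands and their
+  // layouts, infers a compatible MFMA type, and then walks the batch
+  // dimensions of the result, accumulating one amdgpu.mfma per reduction
+  // step into the distributed result vector.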
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    constexpr int LHS = 0;
+    constexpr int RHS = 1;
+    constexpr int ACC = 2;
+    SmallVector<VectorValue> operands;
+    SmallVector<LayoutAttr> layouts;
+    for (auto [operand, layout] :
+         llvm::zip_equal(contractOp->getOperands(), signature.operands)) {
+      if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
+        if (auto vectorLayout = dyn_cast<LayoutAttr>(layout)) {
+          operands.push_back(vectorOperand);
+          layouts.push_back(vectorLayout);
+        }
+      }
+    }
+    LayoutAttr resultLayout = dyn_cast<LayoutAttr>(signature.results[0]);
+    if (!resultLayout)
+      return failure();
+    std::optional<MFMAType> mfmaType = inferCompatibleMFMAType(layouts);
+    if (!mfmaType)
+      return failure();
+
+    Type elementType =
+        llvm::cast<ShapedType>(operands[ACC].getType()).getElementType();
+    SmallVector<int64_t> vectorShape = resultLayout.getDistributedShape();
+    auto vectorType = VectorType::get(vectorShape, elementType);
+    Location loc = contractOp.getLoc();
+    Value vector = rewriter.create<arith::ConstantOp>(
+        loc, vectorType, rewriter.getZeroAttr(vectorType));
+
+    ContractType contractType = inferContractType(
+        contractOp.getContext(), contractOp.getIndexingMapsArray());
+    if (contractType == ContractType::UNSUPPORTED)
+      return failure();
+
+    std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
+    if (!rowBatch)
+      return failure();
+    std::optional<int64_t> colBatch = layouts[LHS].getBatchDim(1);
+    if (!colBatch)
+      return failure();
+
+    int K = getReductionDimensionShape(rowBatch.value(), colBatch.value(),
+                                       contractType);
+
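+    // For each batch position (i, j) of the result, seed the accumulator
+    // from ACC and chain K mfma ops over the reduction dimension.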
+    auto contractFn = [&](const LayoutIterator::State &state) {
+      SmallVector<int64_t> indices = state.computeIteratorProjectedSIMTIndex();
+      Value dMatrix = rewriter.create<vector::ExtractOp>(
+          loc, getDistributed(rewriter, operands[ACC], layouts[ACC]), indices);
+      for (int k = 0; k < K; k++) {
+        Value aMatrix = rewriter.create<vector::ExtractOp>(
+            loc, getDistributed(rewriter, operands[LHS], layouts[LHS]),
+            getIndices(contractType, ContractMatrixType::A, indices[0], k));
+        Value bMatrix = rewriter.create<vector::ExtractOp>(
+            loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
+            getIndices(contractType, ContractMatrixType::B, k, indices[1]));
+        dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
+                             mfmaType.value());
+      }
+      vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
+    };
+
+    LayoutIterator iterator(resultLayout);
+    LayoutIterator batchIterator = iterator.getBatchIterator();
+    batchIterator.apply(contractFn);
+    replaceOpWithDistributedValues(rewriter, contractOp, vector);
+    return success();
+  }
+};
+} // namespace
+
+void populateAMDGPUDistributionPatterns(RewritePatternSet &patterns) {
+  patterns.add<DistributeContractions>(patterns.getContext());
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index 62e6766..b6a328f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -17,6 +17,7 @@
 iree_compiler_cc_library(
     name = "Utils",
     srcs = [
+        "AMDGPUDistributionPatterns.cpp",
         "LLVMGPULayoutAnalysisAndDistribution.cpp",
         "LLVMGPUUtils.cpp",
     ],
@@ -24,9 +25,13 @@
         "LLVMGPUUtils.h",
     ],
     deps = [
+        "//compiler/src/iree/compiler/Codegen/Common",
+        "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
+        "//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index 8609095..10edda1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -16,10 +16,13 @@
   HDRS
     "LLVMGPUUtils.h"
   SRCS
+    "AMDGPUDistributionPatterns.cpp"
     "LLVMGPULayoutAnalysisAndDistribution.cpp"
     "LLVMGPUUtils.cpp"
   DEPS
+    IREEVectorExtDialect
     LLVMSupport
+    MLIRAMDGPUDialect
     MLIRAffineDialect
     MLIRArithDialect
     MLIRFuncDialect
@@ -29,6 +32,8 @@
     MLIRMemRefDialect
     MLIRNVGPUDialect
     MLIRVectorDialect
+    iree::compiler::Codegen::Common
+    iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
   PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 7a6a632..eb7b126 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -32,6 +32,9 @@
 /// from the previous alias group before starting a new one.
 void packSharedMemoryAlloc(func::FuncOp funcOp);
 
+/// Adds patterns to distribute vector contractions to AMDGPU MFMA ops.
+void populateAMDGPUDistributionPatterns(RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler
 
 #endif
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index bd80615..3fa0c2f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -18,6 +18,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "amdgpu_contraction_distribution.mlir",
             "attention.mlir",
             "conv_pipeline_test.mlir",
             "convert_to_nvvm.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index d20d23d..cb082fc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "amdgpu_contraction_distribution.mlir"
     "attention.mlir"
     "cast_address_space_function.mlir"
     "config_matvec.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
new file mode 100644
index 0000000..03f4011
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s | FileCheck %s
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 4, 4]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 4, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
+    // CHECK-LABEL: distribute_mfma_16x16x16_mmt
+    // CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
+    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
+    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
+    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_c,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+    return %output : vector<16x16xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 32]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 2, 4]>
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 4, 2, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
+    // CHECK-LABEL: distribute_mfma_32x32x8_mm
+    // CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
+    // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
+    // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
+    // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
+    // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+    // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
+    // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+    // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_b,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<32x8xf16>, vector<8x32xf16> into vector<32x32xf32>
+    return %output : vector<32x32xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [4, 4, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [8, 16]>
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+    // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+    // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_b,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
index 69588af..0b4562a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
@@ -109,6 +109,10 @@
   );
   let assemblyFormat = "`<`$layouts`>`";
   let genVerifyDecl = 0;
+  let extraClassDeclaration = [{
+    std::optional<int64_t> getBatchDim(int64_t dim);
+    PerDimLayoutAttr getDimLayout(int64_t dim) const;
+  }];
 }
 
 #endif // IREE_DIALECT_VECTOREXT_BASE
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
index 743186b..6d84b14 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -39,7 +39,7 @@
       /*description=*/"Get the distributed shape for the given vector type.",
       /*retTy=*/"SmallVector<int64_t>",
       /*methodName=*/"getDistributedShape",
-      /*args=*/(ins "VectorType":$type)
+      /*args=*/(ins)
     >
   ];
 }
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index f13d2f0..ec33eb1 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -59,7 +59,6 @@
   DimensionalIterator begin() const { return DimensionalIterator(start, step); }
   DimensionalIterator end() const { return DimensionalIterator(stop, step); }
 
-private:
   int64_t start, stop, step;
 };
 
@@ -70,27 +69,70 @@
 // required during distribution.
 class LayoutIterator {
 public:
-  using State = llvm::MapVector<LayoutDimension, DimensionalIterator>;
-  using DimensionMapping =
-      llvm::DenseMap<int64_t, SmallVector<LayoutDimension>>;
-  void maybeFreezeAndConcatenate(const LayoutIterator &frozenIterator);
+  struct State {
+    SmallVector<int64_t> computeSIMTIndex() const;
+    SmallVector<int64_t> computeIteratorProjectedSIMTIndex() const;
+    bool contains(LayoutDimension dim) const { return iterators.contains(dim); }
+    void erase(LayoutDimension dim) { iterators.erase(dim); }
+    DimensionalIterator lookup(LayoutDimension dim) const {
+      return iterators.lookup(dim);
+    }
+    DimensionalIterator &operator[](LayoutDimension dim) {
+      return iterators[dim];
+    }
+    void print() const {
+      for (const auto &[dim, it] : iterators) {
+        llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
+                            std::to_string(*it) + ", ";
+      }
+      llvm::outs() << "\n";
+    }
+    llvm::MapVector<LayoutDimension, DimensionalIterator> iterators;
+    DenseMap<int64_t, DenseSet<LayoutDimension>> simdToLayoutDim;
+    llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
+    SmallVector<LayoutDimension> labels{
+        LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+        LayoutDimension::VECTORY, LayoutDimension::VECTORX};
+  };
+  void maybeFreezeAndConcatenate(const LayoutIterator::State &frozenState);
+  LayoutIterator(LayoutAttr &attr);
+  LayoutIterator(LayoutAttr &attr, int64_t simtIndex);
   LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides);
+  LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides,
+                 int64_t simtIndex);
   LayoutIterator(PerDimLayoutAttr &attr,
                  DenseMap<LayoutDimension, int64_t> strides);
   void apply(std::function<void(const LayoutIterator::State &)>);
   LayoutIterator &operator++();
   State getState() const { return state; }
+  void erase(LayoutDimension dim);
+  LayoutIterator getBatchIterator() const;
 
 private:
-  void initialize(PerDimLayoutAttr &attr,
-                  DenseMap<LayoutDimension, int64_t> strides);
+  void initialize(const PerDimLayoutAttr &attr,
+                  DenseMap<LayoutDimension, int64_t> strides,
+                  std::optional<int64_t> simdIndex);
   bool iterationComplete();
   State state;
-  llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
-  DimensionMapping simdDimensionToLayoutDimension;
   DenseSet<LayoutDimension> frozenDimensions;
+  int64_t iterations{0};
+  int64_t maxIterations{1};
 };
 
+inline bool isBatchDimension(LayoutDimension dim) {
+  return (dim == LayoutDimension::BATCHX) || (dim == LayoutDimension::BATCHY);
+}
+
+inline bool isLaneDimension(LayoutDimension dim) {
+  return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
+         (dim == LayoutDimension::LANEZ);
+}
+
+inline bool isVectorDimension(LayoutDimension dim) {
+  return (dim == LayoutDimension::VECTORX) ||
+         (dim == LayoutDimension::VECTORY) || (dim == LayoutDimension::VECTORZ);
+}
+
 } // namespace mlir::iree_compiler::IREE::VectorExt
 
 #endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index e2973a1..04cb0dc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -109,17 +109,47 @@
   return LayoutAttr::get(getContext(), newLayouts);
 }
 
-SmallVector<int64_t> LayoutAttr::getDistributedShape(VectorType) const {
-  LayoutDimension labels[] = {LayoutDimension::BATCHX, LayoutDimension::BATCHY,
-                              LayoutDimension::VECTORX};
+// Returns the distributed shape of the SIMT vector. Dimensions are emitted
+// in the order BATCHX, BATCHY, VECTORY, VECTORX, with the vector dimensions
+// (VECTORY, VECTORX) combined into a single trailing SIMT vector dimension.
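+// For example, a layout with per-dim layouts
+// <[BATCHX, LANEX], [2, 16]> and <[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+// distributes to the SIMT shape [2, 8, 4].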
+SmallVector<int64_t> LayoutAttr::getDistributedShape() const {
+  SmallVector<LayoutDimension> labels{
+      LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+      LayoutDimension::VECTORY, LayoutDimension::VECTORX};
   SmallVector<int64_t> simtVectorShape;
+  std::optional<int64_t> vectorShape;
   for (LayoutDimension dim : labels) {
     ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
     for (PerDimLayoutAttr layout : layouts) {
       if (!layout.contains(dim))
         continue;
-      simtVectorShape.push_back(layout.getShape(dim).value());
+      int64_t shape = layout.getShape(dim).value();
+      if (isVectorDimension(dim)) {
+        vectorShape = shape * vectorShape.value_or(1);
+        continue;
+      }
+      simtVectorShape.push_back(shape);
     }
   }
+  if (vectorShape)
+    simtVectorShape.push_back(vectorShape.value());
   return simtVectorShape;
 }
+
+PerDimLayoutAttr LayoutAttr::getDimLayout(int64_t dim) const {
+  assert(dim >= 0 && dim < getLayouts().size());
+  return getLayouts()[dim];
+}
+
+std::optional<int64_t> LayoutAttr::getBatchDim(int64_t dim) {
+  assert(dim < getLayouts().size());
+  PerDimLayoutAttr layout = getDimLayout(dim);
+  for (auto [name, shape] :
+       llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
+    if (isBatchDimension(name.getValue()))
+      return shape;
+  }
+  return std::nullopt;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 8894d69..228cc10 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -52,22 +52,9 @@
   return {};
 }
 
-LayoutIterator &LayoutIterator::operator++() {
-  for (auto &[dim, it] : state) {
-    if (frozenDimensions.contains(dim))
-      continue;
-    ++it;
-    if (it < ranges[dim].end()) {
-      break;
-    }
-    it = ranges[dim].begin();
-  }
-  return *this;
-}
-
 void LayoutIterator::maybeFreezeAndConcatenate(
-    const LayoutIterator &frozenIterator) {
-  for (auto &[frozenDim, frozenIt] : frozenIterator.getState()) {
+    const LayoutIterator::State &frozenState) {
+  for (auto &[frozenDim, frozenIt] : frozenState.iterators) {
     if (!state.contains(frozenDim)) {
       frozenDimensions.insert(frozenDim);
       state[frozenDim] = frozenIt;
@@ -75,13 +62,9 @@
   }
 }
 
-static bool isLaneDimension(LayoutDimension dim) {
-  return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
-         (dim == LayoutDimension::LANEZ);
-}
-
-void LayoutIterator::initialize(PerDimLayoutAttr &attr,
-                                DenseMap<LayoutDimension, int64_t> strides) {
+void LayoutIterator::initialize(const PerDimLayoutAttr &attr,
+                                DenseMap<LayoutDimension, int64_t> strides,
+                                std::optional<int64_t> simdIndex) {
   auto reversedLabels = llvm::reverse(attr.getLabels());
   auto reversedShapes = llvm::reverse(attr.getShapes());
   for (auto [nameAttr, shape] : llvm::zip(reversedLabels, reversedShapes)) {
@@ -89,44 +72,138 @@
     if (isLaneDimension(dim))
       continue;
     int64_t stride = strides.contains(dim) ? strides[dim] : 1;
-    ranges[dim] = DimensionalRange(0, shape, stride);
-    state[dim] = ranges[dim].begin();
+    state.ranges[dim] = DimensionalRange(0, shape, stride);
+    state.iterators[dim] = state.ranges[dim].begin();
+    maxIterations *= shape / stride;
+    if (simdIndex) {
+      int64_t index = simdIndex.value();
+      if (!state.simdToLayoutDim.contains(index))
+        state.simdToLayoutDim[index] = {};
+      state.simdToLayoutDim[index].insert(dim);
+    }
   }
 }
 
 LayoutIterator::LayoutIterator(LayoutAttr &attr,
                                DenseMap<LayoutDimension, int64_t> strides) {
-  for (PerDimLayoutAttr perDimAttr : attr.getLayouts()) {
-    initialize(perDimAttr, strides);
+  for (auto perDimAttr : llvm::enumerate(attr.getLayouts())) {
+    initialize(perDimAttr.value(), strides, perDimAttr.index());
+  }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr) {
+  DenseMap<LayoutDimension, int64_t> strides;
+  for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+    initialize(attr, strides, idx);
+  }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr,
+                               DenseMap<LayoutDimension, int64_t> strides,
+                               int64_t simtIndex) {
+  for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+    if (idx != simtIndex)
+      continue;
+    initialize(attr, strides, idx);
+  }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr, int64_t simtIndex) {
+  DenseMap<LayoutDimension, int64_t> strides;
+  for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+    if (idx != simtIndex)
+      continue;
+    initialize(attr, strides, idx);
   }
 }
 
 LayoutIterator::LayoutIterator(PerDimLayoutAttr &attr,
                                DenseMap<LayoutDimension, int64_t> strides) {
-  initialize(attr, strides);
+  initialize(attr, strides, std::nullopt);
 }
 
-/// The iterator is done when it returns back to
-/// its begin state.
-bool LayoutIterator::iterationComplete() {
-  bool complete{true};
-  for (auto &[dim, it] : state) {
+LayoutIterator &LayoutIterator::operator++() {
+  for (auto &[dim, it] : state.iterators) {
     if (frozenDimensions.contains(dim))
       continue;
-    if (it != ranges[dim].begin()) {
-      complete = false;
-      break;
+    ++it;
+    if (it == state.ranges[dim].end()) {
+      it = state.ranges[dim].begin();
+      continue;
     }
+    break;
   }
-  return complete;
+  ++iterations;
+  return *this;
 }
 
+/// The iterator is done when all the loops are complete.
+bool LayoutIterator::iterationComplete() { return iterations == maxIterations; }
+
 void LayoutIterator::apply(
     std::function<void(const LayoutIterator::State &)> callback) {
-  do {
+  for (; !iterationComplete(); ++(*this)) {
     callback(state);
-    ++(*this);
-  } while (!iterationComplete());
+  }
+}
+
+// Computes the offset into the SIMT vector corresponding to the current
+// iterator state. Batch dimensions each contribute one offset; the vector
+// dimensions (VECTORY, VECTORX) are combined into a single trailing offset,
+// assuming the vector dimension is the last dimension of the distributed
+// vector.
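+// For example, with iterators at BATCHX=1, BATCHY=0, VECTORY=2, VECTORX=3
+// and a VECTORX range of size 4, the result is [1, 0, 2 * 4 + 3] = [1, 0, 11].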
+SmallVector<int64_t> LayoutIterator::State::computeSIMTIndex() const {
+  SmallVector<int64_t> offset;
+  std::optional<int64_t> vecOffset;
+  for (auto label : labels) {
+    for (auto [name, it] : iterators) {
+      if (name != label)
+        continue;
+      if (isBatchDimension(name)) {
+        offset.push_back(it.getPosition());
+        continue;
+      }
+      if (isVectorDimension(name)) {
+        int64_t step{1};
+        if (name == LayoutDimension::VECTORY) {
+          step = ranges.lookup(LayoutDimension::VECTORX).stop;
+        }
+        vecOffset = vecOffset.value_or(0) + it.getPosition() * step;
+      }
+    }
+  }
+  if (vecOffset)
+    offset.push_back(vecOffset.value());
+  return offset;
+}
+
+SmallVector<int64_t>
+LayoutIterator::State::computeIteratorProjectedSIMTIndex() const {
+  SmallVector<int64_t> indices = computeSIMTIndex();
+  SmallVector<int64_t> projectedIndices;
+  for (size_t i = 0, e = labels.size(); i != e; ++i) {
+    for (auto [name, it] : iterators) {
+      if (name == labels[i])
+        projectedIndices.push_back(indices[i]);
+    }
+  }
+  return projectedIndices;
+}
+
+void LayoutIterator::erase(LayoutDimension dim) {
+  if (state.contains(dim))
+    state.erase(dim);
+}
+
+LayoutIterator LayoutIterator::getBatchIterator() const {
+  LayoutIterator projectedIterator = *this;
+  for (auto [dim, it] : state.iterators) {
+    if (!isBatchDimension(dim)) {
+      DimensionalRange range = state.ranges.lookup(dim);
+      projectedIterator.maxIterations /= (range.stop / range.step);
+      projectedIterator.erase(dim);
+    }
+  }
+  return projectedIterator;
 }
 
 // clang-format off
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
index fce0953..5aacc9e 100644
--- a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
@@ -33,7 +33,7 @@
   }
   // Prints the layout so that LIT can test it for correctness.
   static void printFn(const LayoutIterator::State &state) {
-    for (const auto &[dim, it] : state) {
+    for (const auto &[dim, it] : state.iterators) {
       llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
                           std::to_string(*it) + ", ";
     }
@@ -66,7 +66,7 @@
     DenseMap<LayoutDimension, int64_t> strides;
     LayoutIterator iterator(layout, strides);
     auto frozenIterator = createFrozenIterator(op->getContext(), strides);
-    iterator.maybeFreezeAndConcatenate(frozenIterator);
+    iterator.maybeFreezeAndConcatenate(frozenIterator.getState());
     iterator.apply(printFn);
   }
   void runOnOperation() override {