[VectorDistribution] Infer operand transpose for vector.contract distribution (#16414)

This patch also adds documentation on MFMA instruction layouts and a
test for the MTM contract type.
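
For example, the new MTM test below exercises a contraction whose LHS is
transposed (its indexing map is (d0, d1, d2) -> (d2, d0)):

    %output = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
        %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>

The pattern now infers that %a carries the transpose of the canonical MFMA
A layout and distributes the contraction to amdgpu.mfma ops instead of
failing to infer an MFMA type.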
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
index c6ea79b..ec21950 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -18,15 +18,94 @@
 enum class ContractMatrixType { A, B, C, D };
 enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
 
-// The naming scheme for these operators is:
-// InputType_MxNxK_OutputType.
+/// We define AMD MFMA instruction layouts only for contract type MM, i.e. C(i,
+/// k) += A(i, j) * B(j, k). We call these canonical layouts: layoutA, layoutB,
+/// layoutC, corresponding to the A, B, C matrices.
+///
+/// For any other contract type, the layout of the transposed operand is
+/// simply the transpose of its canonical layout. For example, for MMT, the
+/// layouts used are layoutA, layoutB.T, layoutC. To see why, think of the
+/// transpose as being hoisted out of the contract:
+///
+/// vector.contract {type = MMT} %a, %b
+///
+/// is equivalent to
+///
+/// %bt = vector.transpose %b
+/// vector.contract {type = MM} %a, %bt
+///
+/// You would then assign layouts based on contract type MM, and get the
+/// right layout for %b by transposing layoutB.
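+///
+/// Likewise, for MTM,
+///
+/// vector.contract {type = MTM} %a, %b
+///
+/// is equivalent to
+///
+/// %at = vector.transpose %a
+/// vector.contract {type = MM} %at, %b
+///
+/// so %a takes layoutA.T and %b keeps layoutB.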
+///
+/// With layoutA, layoutB, and layoutC defined, we now give the canonical
+/// layouts for each MFMA instruction. Each layout is drawn as the original
+/// matrix, where each element is the thread id in the subgroup that owns it.
+/// These layouts are taken from
+/// https://github.com/ROCm/amd_matrix_instruction_calculator
+///
+/// The naming scheme for these operators is InputType_MxNxK_OutputType.
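+/// For example, F16_16x16x16_F32 consumes f16 operands and accumulates into
+/// an f32 result, with M = N = K = 16.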
 enum class MFMAType {
+  /// layoutA:
+  ///  0  0  0  0 16 16 16 16 32 32 32 32 48 48 48 48
+  ///  1  1  1  1 17 17 17 17 33 33 33 33 49 49 49 49
+  ///  2  2  2  2 18 18 18 18 34 34 34 34 50 50 50 50
+  /// ...
+  /// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
+  ///
+  /// layoutB:
+  /// Transpose of layoutA
+  ///
+  /// layoutC:
+  /// Same as layoutB
   F16_16x16x16_F32,
+  /// layoutA:
+  ///  0  0  0  0 32 32 32 32
+  ///  1  1  1  1 33 33 33 33
+  ///  2  2  2  2 34 34 34 34
+  ///  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮  ⋮
+  /// 31 31 31 31 63 63 63 63
+  ///
+  /// layoutB:
+  /// Transpose of layoutA
+  ///
+  /// layoutC:
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  ///  0  1  2 ... 31
+  ///  ⋮  ⋮  ⋮ ...  ⋮
+  /// 32 33 34 ... 63
+  ///  0  1  2 ... 31
+  ///  ⋮  ⋮  ⋮ ...  ⋮
+  /// 32 33 34 ... 63
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
   F16_32x32x8_F32,
 };
 
 namespace {
 
+static bool isOperandATransposed(ContractType contractType) {
+  return (contractType == ContractType::MTM) ||
+         (contractType == ContractType::MTMT);
+}
+
+static bool isOperandBTransposed(ContractType contractType) {
+  return (contractType == ContractType::MMT) ||
+         (contractType == ContractType::MTMT);
+}
+
 struct DistributeContractions final
     : OpDistributionPattern<vector::ContractionOp> {
   using OpDistributionPattern::OpDistributionPattern;
@@ -54,8 +133,7 @@
 
   int64_t getReductionDimensionShape(int64_t rowBatch, int64_t colBatch,
                                      ContractType contractType) const {
-    if ((contractType == ContractType::MTM) ||
-        (contractType == ContractType::MTMT)) {
+    if (isOperandATransposed(contractType)) {
       return rowBatch;
     }
     return colBatch;
@@ -266,10 +344,20 @@
   // Output is inferred MFMAType or none (if layout is not compatible with any
   // MFMA layout).
   std::optional<MFMAType>
-  inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts) const {
+  inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
+                          ContractType contractType) const {
     std::optional<MFMAType> mfmaType{std::nullopt};
     SmallVector<ContractMatrixType> matrixTypes{
         ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
+
+    // Canonical layouts for MFMA are transposes of each other.
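+    // A transposed operand must match the transpose of its canonical layout.
+    // Since layoutA.T = layoutB for the supported MFMAs, this amounts to
+    // checking the operand against the other operand's canonical layout.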
+    if (isOperandATransposed(contractType)) {
+      matrixTypes[0] = ContractMatrixType::B;
+    }
+    if (isOperandBTransposed(contractType)) {
+      matrixTypes[1] = ContractMatrixType::A;
+    }
+
     for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
       mfmaType = inferMFMAType(layout, matrixType, mfmaType);
       if (!mfmaType)
@@ -305,10 +393,6 @@
     if (!resultLayout)
       return failure();
 
-    std::optional<MFMAType> mfmaType = inferCompatibleMFMAType(layouts);
-    if (!mfmaType)
-      return failure();
-
     Type elementType =
         llvm::cast<ShapedType>(operands[ACC].getType()).getElementType();
     SmallVector<int64_t> vectorShape = resultLayout.getDistributedShape();
@@ -322,6 +406,11 @@
     if (contractType == ContractType::UNSUPPORTED)
       return failure();
 
+    std::optional<MFMAType> mfmaType =
+        inferCompatibleMFMAType(layouts, contractType);
+    if (!mfmaType)
+      return failure();
+
     std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
     if (!rowBatch)
       return failure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
index 03f4011..c77ca89 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -1,13 +1,24 @@
 // RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s | FileCheck %s
 
+// Refer to the distribution pattern documentation for what layoutA, layoutB,
+// and layoutC mean and how these layouts are assigned based on the
+// instruction type.
+
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<16x16>, layout = layoutA
 #row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
 #col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 4, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<16x16>, layout = transpose(layoutB) = layoutA
+// Since the shapes are also the same, we can reuse the layout attribute
+// layout_a.
+
+// C: vector<16x16>, layout = layoutC
 #row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 4, 4]>
 #col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
-#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
@@ -23,7 +34,7 @@
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
-                               "__vector_layout_test_anchor_operand_1" = #layout_c,
+                               "__vector_layout_test_anchor_operand_1" = #layout_a,
                                "__vector_layout_test_anchor_operand_2" = #layout_c,
                                "__vector_layout_test_anchor_result_0" = #layout_c
                                }
@@ -39,17 +50,64 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<32x128>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<64x128>, layout = transpose(layoutB) = layoutA
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [4, 16]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+
+// C: vector<32x64>, layout = layoutC
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+    // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+    // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+                               kind = #vector.kind<add>,
+                               "__vector_layout_test_anchor_operand_0" = #layout_a,
+                               "__vector_layout_test_anchor_operand_1" = #layout_b,
+                               "__vector_layout_test_anchor_operand_2" = #layout_c,
+                               "__vector_layout_test_anchor_result_0" = #layout_c
+                               }
+                                %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+    return %output : vector<32x64xf32>
+  }
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// A: vector<32x8>, layout = layoutA
 #row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 32]>
 #col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 2, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<8x32>, layout = layoutB
 #row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
 #col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+
+// C: vector<32x32>, layout = layoutC
 #row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 4, 2, 4]>
 #col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
-#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
-#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
 #layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
   func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
@@ -81,22 +139,28 @@
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
-#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
-#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
-#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [4, 4, 4]>
-#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [8, 16]>
-#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
-#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// A: vector<8x64>, layout = transpose(layoutA) = layoutB
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [2, 32]>
 #layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
-#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
-#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+
+// B: vector<8x32>, layout = layoutB
+// The per-dim structure matches layout_a, but the shape differs (8x32 vs.
+// 8x64), so B needs its own layout attribute with BATCHY = 1.
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+
+// C: vector<64x32>, layout = layoutC
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [2, 4, 2, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
 builtin.module attributes { transform.with_named_sequence } {
-  func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
-    // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
-    // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+  func.func @distribute_mfma_32x32x8_mtm(%a : vector<8x64xf16>, %b : vector<8x32xf16>, %c : vector<64x32xf32>) -> vector<64x32xf32> {
+    // CHECK-LABEL: distribute_mfma_32x32x8_mtm
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>,
                                "__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -104,8 +168,14 @@
                                "__vector_layout_test_anchor_operand_2" = #layout_c,
                                "__vector_layout_test_anchor_result_0" = #layout_c
                                }
-                                %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
-    return %output : vector<32x64xf32>
+                                %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
+    // CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
+    // CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
+    // CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
+    // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
+    // CHECK-NOT: amdgpu.mfma
+    return %output : vector<64x32xf32>
   }
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op