[VectorDistribution] Infer operand transpose for vector.contract distribution (#16414)
This patch also adds documentation on MFMA instruction layouts and a
test for the MTM contract type.
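
For illustration, here is a sketch of the MTM case this patch handles, distilled from the new test in amdgpu_contraction_distribution.mlir (the map names follow the test file):

```mlir
// MTM: C(d0, d1) += A(d2, d0) * B(d2, d1). The LHS map lists the reduction
// dim d2 first, i.e. A is read transposed, so the distribution pattern infers
// its layout as transpose(layoutA), which for MFMA equals layoutB.
#map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>

%output = vector.contract {indexing_maps = [#map1, #map2, #map3],
              iterator_types = ["parallel", "parallel", "reduction"],
              kind = #vector.kind<add>}
    %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
```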
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
index c6ea79b..ec21950 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -18,15 +18,94 @@
enum class ContractMatrixType { A, B, C, D };
enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
-// The naming scheme for these operators is:
-// InputType_MxNxK_OutputType.
+/// We define AMD MFMA instruction layouts only for contract type MM, i.e.,
+/// C(i, k) += A(i, j) * B(j, k). We call these the canonical layouts: layoutA,
+/// layoutB, and layoutC, corresponding to the A, B, and C matrices.
+///
+/// For any other contract type, the layout of the affected operand is simply
+/// transposed. For example, for MMT, the layouts used should be layoutA,
+/// layoutB.T, layoutC. To make this transposition easier to see, think of the
+/// transpose as sitting outside the contract:
+///
+/// vector.contract {type = MMT} %a, %b
+///
+/// is equivalent to
+///
+/// %bt = vector.transpose %b
+/// vector.contract {type = MM} %a, %bt
+///
+/// You would then assign layouts based on contract type MM and get the right
+/// layout for %b by transposing layoutB.
+///
+/// Having defined layoutA, layoutB, and layoutC, we now give the canonical
+/// layouts for each MFMA instruction. Each layout is drawn as the original
+/// matrix, with each entry showing the thread id in the subgroup that owns
+/// that element.
+/// These layouts are taken from
+/// https://github.com/ROCm/amd_matrix_instruction_calculator
+///
+/// The naming scheme for these operators is InputType_MxNxK_OutputType.
enum class MFMAType {
+ /// layoutA:
+ /// 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
+ /// 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
+ /// 2 2 2 2 18 18 18 18 34 34 34 34 50 50 50 50
+ /// ...
+ /// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
+ ///
+ /// layoutB:
+ /// Transpose of layoutA
+ ///
+ /// layoutC:
+ /// Same as layoutB
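+  ///
+  /// For reference, in the tests below layoutA for this instruction is
+  /// written as rows <[BATCHX, LANEX], [1, 16]> and columns
+  /// <[BATCHY, LANEY, VECTORX], [1, 4, 4]>.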
F16_16x16x16_F32,
+ /// layoutA:
+ /// 0 0 0 0 32 32 32 32
+ /// 1 1 1 1 33 33 33 33
+ /// 2 2 2 2 34 34 34 34
+ /// ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
+ /// 31 31 31 31 63 63 63 63
+ ///
+ /// layoutB:
+ /// Transpose of layoutA
+ ///
+ /// layoutC:
+  /// The following 8-row block repeats 4 times (once per VECTORY index),
+  /// giving 32 rows in total:
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  ///  0  1  2 ... 31
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
+  /// 32 33 34 ... 63
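+  ///
+  /// For reference, in the tests below layoutA for this instruction is
+  /// written as rows <[BATCHX, LANEX], [1, 32]> and columns
+  /// <[BATCHY, LANEY, VECTORX], [1, 2, 4]>, and layoutC as rows
+  /// <[BATCHX, VECTORY, LANEY, VECTORX], [1, 4, 2, 4]> and columns
+  /// <[BATCHY, LANEX], [1, 32]>.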
F16_32x32x8_F32,
};
namespace {
+static bool isOperandATransposed(ContractType contractType) {
+ return (contractType == ContractType::MTM) ||
+ (contractType == ContractType::MTMT);
+}
+
+static bool isOperandBTransposed(ContractType contractType) {
+  return (contractType == ContractType::MMT) ||
+         (contractType == ContractType::MTMT);
+}
+
struct DistributeContractions final
: OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -54,8 +133,7 @@
int64_t getReductionDimensionShape(int64_t rowBatch, int64_t colBatch,
ContractType contractType) const {
- if ((contractType == ContractType::MTM) ||
- (contractType == ContractType::MTMT)) {
+ if (isOperandATransposed(contractType)) {
return rowBatch;
}
return colBatch;
@@ -266,10 +344,20 @@
// Output is inferred MFMAType or none (if layout is not compatible with any
// MFMA layout).
std::optional<MFMAType>
- inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts) const {
+ inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
+ ContractType contractType) const {
std::optional<MFMAType> mfmaType{std::nullopt};
SmallVector<ContractMatrixType> matrixTypes{
ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
+
+    // The canonical MFMA layouts for A and B are transposes of each other, so
+    // a transposed operand is validated against the other operand's canonical
+    // layout.
+ if (isOperandATransposed(contractType)) {
+ matrixTypes[0] = ContractMatrixType::B;
+ }
+ if (isOperandBTransposed(contractType)) {
+ matrixTypes[1] = ContractMatrixType::A;
+ }
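+    // e.g., for MTM (C = A^T * B), the layouts validated per operand are:
+    //   LHS: transpose(layoutA) == layoutB -> validated as ContractMatrixType::B
+    //   RHS: layoutB -> validated as ContractMatrixType::B
+    //   ACC: layoutC -> validated as ContractMatrixType::C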
+
for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
mfmaType = inferMFMAType(layout, matrixType, mfmaType);
if (!mfmaType)
@@ -305,10 +393,6 @@
if (!resultLayout)
return failure();
- std::optional<MFMAType> mfmaType = inferCompatibleMFMAType(layouts);
- if (!mfmaType)
- return failure();
-
Type elementType =
llvm::cast<ShapedType>(operands[ACC].getType()).getElementType();
SmallVector<int64_t> vectorShape = resultLayout.getDistributedShape();
@@ -322,6 +406,11 @@
if (contractType == ContractType::UNSUPPORTED)
return failure();
+ std::optional<MFMAType> mfmaType =
+ inferCompatibleMFMAType(layouts, contractType);
+ if (!mfmaType)
+ return failure();
+
std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
if (!rowBatch)
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
index 03f4011..c77ca89 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -1,13 +1,24 @@
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s | FileCheck %s
+// Refer to the distribution pattern documentation for what layoutA, layoutB,
+// and layoutC mean and how these layouts are assigned based on the
+// instruction and contract type.
+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<16x16>, layout = layoutA
#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 4, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<16x16>, layout = transpose(layoutB) = layoutA
+// Since the shapes are also the same, we can reuse the same layout attribute,
+// layout_a.
+
+// C: vector<16x16>, layout = layoutC
#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 4, 4]>
#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
-#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
@@ -23,7 +34,7 @@
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
- "__vector_layout_test_anchor_operand_1" = #layout_c,
+ "__vector_layout_test_anchor_operand_1" = #layout_a,
"__vector_layout_test_anchor_operand_2" = #layout_c,
"__vector_layout_test_anchor_result_0" = #layout_c
}
@@ -39,17 +50,64 @@
// -----
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<32x128>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<64x128>, layout = transpose(layoutB) = layoutA
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [4, 16]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+
+// C: vector<32x64>, layout = layoutC
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+ // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+ // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_b,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+ return %output : vector<32x64xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// A: vector<32x8>, layout = layoutA
#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 32]>
#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 2, 4]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<8x32>, layout = layoutB
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+
+// C: vector<32x32>, layout = layoutC
#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 4, 2, 4]>
#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
-#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
-#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
@@ -81,22 +139,28 @@
// -----
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
-#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
-#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
-#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [4, 4, 4]>
-#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [8, 16]>
-#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
-#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// A: vector<8x64>, layout = transpose(layoutA) = layoutB
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [2, 32]>
#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
-#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
-#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+
+// B: vector<8x32>, layout = layoutB
+// The row layout matches A's, but the column shapes differ (32 vs. 64
+// columns), so B needs its own layout attribute.
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+
+// C: vector<64x32>, layout = layoutC
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [2, 4, 2, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
- func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
- // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
- // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+ func.func @distribute_mfma_32x32x8_mtm(%a : vector<8x64xf16>, %b : vector<8x32xf16>, %c : vector<64x32xf32>) -> vector<64x32xf32> {
+ // CHECK-LABEL: distribute_mfma_32x32x8_mtm
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -104,8 +168,14 @@
"__vector_layout_test_anchor_operand_2" = #layout_c,
"__vector_layout_test_anchor_result_0" = #layout_c
}
- %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
- return %output : vector<32x64xf32>
+ %a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
+ // CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
+ // CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
+ // CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
+ // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
+ // CHECK-NOT: amdgpu.mfma
+ return %output : vector<64x32xf32>
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op