[LLVMGPU][ROCM][Layoutv1] Landing Implementation of WMMA on layoutV1 (#17580)
Introducing WMMA on layoutV1. This feature is a prerequisite for
getting FlashAttention-2 (FA2) on RDNA3 GPUs.
This PR brings:
1. WMMA on layoutV1 (see the layout sketch below).
2. Ensuring VECTORX is specified on layouts that use VECTORY, preventing
unexpected behavior when computing the "step" during SIMT index
calculation.
3. Update of `amdgpu_distribute_vectors`/contract distribution so that it
follows the newer `set_contraction_layout_attributes` style of explicitly
setting intrinsic types, as opposed to the less safe method of inferring
intrinsic types (see the transform sketch below).
4. Update of tests to match the above.
Co-authored-by: Groverkss <groverkss@gmail.com>
---------
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
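
For reference, a minimal sketch of the anchored accumulator layout for the
WMMA_F16_16x16x16_F32 intrinsic, taken from the updated tests below: the C
matrix distributes its M dimension over VECTORY/LANEY/VECTORX with sizes
8/2/1 and its N dimension over LANEX:

```mlir
// Accumulator (C) layout for WMMA_F16_16x16x16_F32 on layoutV1.
#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>
#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
#layout_c = #iree_vector_ext.layout<#row_layout, #col_layout>
```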
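
And a minimal sketch of the new anchoring flow (mirroring the updated tests):
the intrinsic type is attached explicitly via
`transform.iree.set_contraction_layout_attributes` before distribution, rather
than inferred from the operand layouts:

```mlir
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
  // Anchor the WMMA intrinsic on the contraction explicitly.
  %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
  %layout = transform.param.constant #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32> -> !transform.any_param
  transform.iree.set_contraction_layout_attributes %contract, %layout : !transform.any_op, !transform.any_param

  // Distribute vectors using the anchored layouts.
  %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
  transform.iree.amdgpu_distribute_vectors %func test_conversion : (!transform.any_op) -> !transform.any_op
  transform.yield
}
```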
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ffe82da..554ecbf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -290,8 +290,9 @@
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
- auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {8, 2});
- auto cNLayout = outer;
+ auto cMLayout =
+ PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
+ auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 17492ea..94e8a16 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1647,6 +1647,7 @@
<< "invalid opaque mma layout for annotation " << mmaType;
}
+ contract->setAttr("iree.amdgpu.mma", mmaType);
auto [aLayout, bLayout, cLayout] = *maybeLayouts;
contract->setAttr("__vector_layout_test_anchor_operand_0", aLayout);
contract->setAttr("__vector_layout_test_anchor_operand_1", bLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
index 39cdd7c..6253dd7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -18,83 +19,6 @@
enum class ContractMatrixType { A, B, C, D };
enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
-/// We define AMD MFMA instruction layouts only for contract type MM, i.e. C(i,
-/// k) += A(i, j) * B(j, k). We call these canonical layouts: layoutA, layoutB,
-/// layoutC, corresponding to the A, B, C matrices.
-///
-/// For any other contract type, the layout is simply transposed for that
-/// operand. For example, for MMT, the layouts used should be layoutA,
-/// layoutB.T, layoutC. For an easier understanding of this transposition,
-/// think of the transpose simply being outside the contract:
-///
-/// vector.contract {type = MMT} %a, %b
-///
-/// is equivalent to
-///
-/// %bt = vector.transpose %b
-/// vector.contract {type = MM} %a, %bt
-///
-/// Now, you would assign layouts based on contract type MM, and would get the
-/// right layout for %b by transposing the layout for B.
-///
-/// Now that we have defined what layoutA, layoutB, layoutC are, we will define
-/// what the canonical layouts are for each MFMA instruction. These are
-/// represented as the original matrix, with elements representing which thread
-/// id in the subgroup gets which element.
-/// These layouts were referenced from
-/// https://github.com/ROCm/amd_matrix_instruction_calculator
-///
-/// The naming scheme for these operators is InputType_MxNxK_OutputType.
-enum class MFMAType {
- /// layoutA:
- /// 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
- /// 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
- /// 2 2 2 2 18 18 18 18 34 34 34 34 50 50 50 50
- /// ...
- /// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
- ///
- /// layoutB:
- /// Transpose of layoutA
- ///
- /// layoutC:
- /// Same as layoutB
- F16_16x16x16_F32,
- /// layoutA:
- /// 0 0 0 0 32 32 32 32
- /// 1 1 1 1 33 33 33 33
- /// 2 2 2 2 34 34 34 34
- /// ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
- /// 31 31 31 31 63 63 63 63
- ///
- /// layoutB:
- /// Transpose of layoutA
- ///
- /// layoutC:
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- /// 0 1 2 ... 31
- /// ⋮ ⋮ ⋮ ... ⋮
- /// 32 33 34 ... 63
- /// 0 1 2 ... 31
- /// ⋮ ⋮ ⋮ ... ⋮
- /// 32 33 34 ... 63
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 0 1 2 ... 31
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- /// 32 33 34 ... 63
- F16_32x32x8_F32,
-};
-
namespace {
static bool isOperandATransposed(ContractType contractType) {
@@ -166,208 +90,6 @@
return ContractType::UNSUPPORTED;
}
- Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
- MFMAType mfmaType) const {
- uint32_t m, n, k, blks;
- if (mfmaType == MFMAType::F16_16x16x16_F32) {
- m = n = k = 16;
- } else if (mfmaType == MFMAType::F16_32x32x8_F32) {
- m = n = 32;
- k = 8;
- }
- blks = 1;
- return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
- b, c);
- }
-
- PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
- ArrayRef<LayoutDimension> dims,
- ArrayRef<int64_t> shapes) const {
- SmallVector<LayoutDimensionAttr> dimAttrs;
- for (auto dim : dims)
- dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
- return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
- }
-
- std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
- LayoutDimension batchRowLabel, int64_t batchRow,
- LayoutDimension batchColLabel, int64_t batchCol) const {
- MLIRContext *ctx = getContext();
- PerDimLayoutAttr rowLayout = createPerDimLayout(
- ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
- PerDimLayoutAttr colLayout = createPerDimLayout(
- ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
- {batchCol, 4, 4});
- return {rowLayout, colLayout};
- }
-
- bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
- LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
- LayoutAttr canonicalLayout =
- LayoutAttr::get(getContext(), {rowLayout, colLayout});
- return layout == canonicalLayout;
- }
-
- bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
- LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
- LayoutAttr canonicalLayout =
- LayoutAttr::get(getContext(), {rowLayout, colLayout});
- return layout == canonicalLayout;
- }
-
- bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- return isCompatible16x16x16B(layout, batchRow, batchCol);
- }
-
- std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
- createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
- LayoutDimension batchColLabel, int64_t batchCol,
- ContractMatrixType matrixType) const {
- MLIRContext *ctx = getContext();
- PerDimLayoutAttr rowLayout = createPerDimLayout(
- ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
- PerDimLayoutAttr colLayout;
- if (matrixType == ContractMatrixType::C) {
- colLayout =
- createPerDimLayout(ctx,
- {batchColLabel, LayoutDimension::VECTORY,
- LayoutDimension::LANEY, LayoutDimension::VECTORX},
- {batchCol, 4, 2, 4});
- } else {
- colLayout = createPerDimLayout(
- ctx,
- {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
- {batchCol, 2, 4});
- }
- return {rowLayout, colLayout};
- }
-
- bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
- LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
- ContractMatrixType::A);
- LayoutAttr canonicalLayout =
- LayoutAttr::get(getContext(), {rowLayout, colLayout});
- return layout == canonicalLayout;
- }
-
- bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
- LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
- ContractMatrixType::B);
- LayoutAttr canonicalLayout =
- LayoutAttr::get(getContext(), {rowLayout, colLayout});
- return layout == canonicalLayout;
- }
-
- bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
- int64_t batchCol) const {
- auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
- LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
- ContractMatrixType::C);
- LayoutAttr canonicalLayout =
- LayoutAttr::get(getContext(), {rowLayout, colLayout});
- return layout == canonicalLayout;
- }
-
- bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
- int64_t batchRow, int64_t batchCol) const {
- switch (matrixType) {
- case ContractMatrixType::A:
- return isCompatible16x16x16A(layout, batchRow, batchCol);
- case ContractMatrixType::B:
- return isCompatible16x16x16B(layout, batchRow, batchCol);
- default:
- return isCompatible16x16x16C(layout, batchRow, batchCol);
- }
- return false;
- }
-
- bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
- int64_t batchRow, int64_t batchCol) const {
- switch (matrixType) {
- case ContractMatrixType::A:
- return isCompatible32x32x8A(layout, batchRow, batchCol);
- case ContractMatrixType::B:
- return isCompatible32x32x8B(layout, batchRow, batchCol);
- default:
- return isCompatible32x32x8C(layout, batchRow, batchCol);
- }
- return false;
- }
-
- bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
- MFMAType mfmaType) const {
- std::optional<int64_t> batchRow = layout.getBatchDim(0);
- if (!batchRow)
- return false;
- std::optional<int64_t> batchCol = layout.getBatchDim(1);
- if (!batchCol)
- return false;
- switch (mfmaType) {
- case MFMAType::F16_16x16x16_F32:
- return isCompatible16x16x16(layout, matrixType, batchRow.value(),
- batchCol.value());
- case MFMAType::F16_32x32x8_F32:
- return isCompatible32x32x8(layout, matrixType, batchRow.value(),
- batchCol.value());
- default:
- return false;
- }
- return false;
- }
-
- // If we have a prior guess of the MFMA type, only evaluate that type.
- // Otherwise, evaluate all types to find a match.
- std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
- ContractMatrixType matrixType,
- std::optional<MFMAType> prior) const {
- SmallVector<MFMAType> mfmaTypes;
- if (prior) {
- mfmaTypes.push_back(prior.value());
- } else {
- mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
- }
- for (MFMAType mfmaType : mfmaTypes) {
- if (isCompatible(layout, matrixType, mfmaType))
- return mfmaType;
- }
- return std::nullopt;
- }
-
- // Inputs are LHS, RHS and ACC operands and corresponding layouts.
- // Output is inferred MFMAType or none (if layout is not compatible with any
- // MFMA layout).
- std::optional<MFMAType>
- inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
- ContractType contractType) const {
- std::optional<MFMAType> mfmaType{std::nullopt};
- SmallVector<ContractMatrixType> matrixTypes{
- ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
-
- // Canonical layouts for MFMA are transposes of each other.
- if (isOperandATransposed(contractType)) {
- matrixTypes[0] = ContractMatrixType::B;
- }
- if (isOperandBTransposed(contractType)) {
- matrixTypes[1] = ContractMatrixType::A;
- }
-
- for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
- mfmaType = inferMFMAType(layout, matrixType, mfmaType);
- if (!mfmaType)
- return std::nullopt;
- }
- return mfmaType;
- }
-
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
@@ -408,10 +130,12 @@
if (contractType == ContractType::UNSUPPORTED)
return failure();
- std::optional<MFMAType> mfmaType =
- inferCompatibleMFMAType(layouts, contractType);
- if (!mfmaType)
- return failure();
+ auto mmaAttr =
+ contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+ if (!mmaAttr) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "missing iree.amdgpu.mma intrinsic attribute");
+ }
std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
if (!rowBatch)
@@ -434,10 +158,13 @@
Value bMatrix = rewriter.create<vector::ExtractOp>(
loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
getIndices(contractType, ContractMatrixType::B, k, indices[1]));
- dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
- mfmaType.value());
+ dMatrix = mmaAttr
+ .buildMmaOperation(rewriter, loc, dMatrix.getType(),
+ aMatrix, bMatrix, dMatrix)
+ .value();
}
vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
+ return success();
};
LayoutIterator iterator(resultLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index c5f7fd2..695ecb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -29,6 +29,7 @@
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index b94f2d6..49a9e7e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -41,6 +41,7 @@
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Common::VectorLayoutAnalysis
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index ec04bc8..1757b5c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -20,6 +20,7 @@
[
"amdgpu_chained_matmul.mlir",
"amdgpu_contraction_distribution.mlir",
+ "amdgpu_set_anchor_layouts.mlir",
"attention.mlir",
"attention_mfma.mlir",
"conv_pipeline_test_cuda.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index be60c97..2ff84aa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"amdgpu_chained_matmul.mlir"
"amdgpu_contraction_distribution.mlir"
+ "amdgpu_set_anchor_layouts.mlir"
"attention.mlir"
"attention_mfma.mlir"
"cast_address_space_function.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
index b46a8d0..fd12413 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -4,6 +4,8 @@
// layoutC means and how these layouts are assigned based on the instruction
// type.
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -22,15 +24,6 @@
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
- // CHECK-LABEL: distribute_mfma_16x16x16_mmt
- // CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
- // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
- // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
- // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
- // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
- // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
- // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
- // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -42,14 +35,31 @@
return %output : vector<16x16xf32>
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
// -----
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
@@ -70,8 +80,6 @@
#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
- // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
- // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -83,14 +91,24 @@
return %output : vector<32x64xf32>
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+// CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+
// -----
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -111,15 +129,6 @@
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
- // CHECK-LABEL: distribute_mfma_32x32x8_mm
- // CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
- // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
- // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
- // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
- // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
- // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
- // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
- // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -131,14 +140,31 @@
return %output : vector<32x32xf32>
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+// CHECK-LABEL: distribute_mfma_32x32x8_mm
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
// -----
+#layout = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+
#map1 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -160,7 +186,6 @@
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
func.func @distribute_mfma_32x32x8_mtm(%a : vector<8x64xf16>, %b : vector<8x32xf16>, %c : vector<64x32xf32>) -> vector<64x32xf32> {
- // CHECK-LABEL: distribute_mfma_32x32x8_mtm
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
@@ -169,17 +194,126 @@
"__vector_layout_test_anchor_result_0" = #layout_c
}
%a, %b, %c : vector<8x64xf16>, vector<8x32xf16> into vector<64x32xf32>
- // CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
- // CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
- // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
- // CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
- // CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
- // CHECK-NOT: amdgpu.mfma
return %output : vector<64x32xf32>
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout32x32x8 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+
+// CHECK-LABEL: distribute_mfma_32x32x8_mtm
+
+// CHECK-DAG: %[[A1:.+]] = vector.extract %[[A:.+]][0, 0] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %[[B1:.+]] = vector.extract %[[B:.+]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A1]] * %[[B1]]
+// CHECK-DAG: %[[A2:.+]] = vector.extract %[[A]][0, 1] : vector<4xf16> from vector<1x2x4xf16>
+// CHECK-DAG: %{{.*}} = amdgpu.mfma %[[A2]] * %[[B1]]
+// CHECK-NOT: amdgpu.mfma
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<16x16>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<16x16>, layout = transpose(layoutB) = layoutA
+// Since the shapes are also the same, we can reuse the same layout attribute, layout_a.
+
+// C: vector<16x16>, layout = layoutC
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_wmma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_a,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+ return %output : vector<16x16xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt
+
+// CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+// CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x8xf32>
+// CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<8xf32> from vector<1x1x8xf32>
+// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x16xf16>
+// CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<16xf16> from vector<1x1x16xf16>
+// CHECK-DAG: %[[OUT:.+]] = amdgpu.wmma %[[AV]] * %[[BV]] + %[[CV]] : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+// A: vector<32x128>, layout = layoutA
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+
+// B: vector<64x128>, layout = transpose(layoutB) = layoutA
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [4, 16]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 1, 16]>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+
+// C: vector<32x64>, layout = layoutC
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [2, 8, 2, 1]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_wmma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_b,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+ return %output : vector<32x64xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: distribute_wmma_16x16x16_mmt_batch
+
+// CHECK-COUNT-64: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
new file mode 100644
index 0000000..2269dc6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s --verify-diagnostics
+
+// This tests that the compiler sets the correct layout anchors for various vector ops and shapes.
+// Currently this only tests contractions on layoutV1, but it can be expanded to other ops.
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @anchor_mfma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+ // CHECK-LABEL: anchor_mfma_16x16x16_mmt
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.0 : f16
+ %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 4, 4]>>}}
+ %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 4, 4]>>}}
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[ BATCHY, LANEX], [1, 16]>>}}
+ return %output : vector<16x16xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @anchor_mfma_16x16x16_mmt_batch(%a : memref<32x128xf16>, %b : memref<64x128xf16>, %init : vector<32x64xf32>) -> vector<32x64xf32> {
+ // CHECK-LABEL: anchor_mfma_16x16x16_mmt_batch
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.0 : f16
+ %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x128xf16>, vector<32x128xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [2, 16]>, <[ BATCHY, LANEY, VECTORX], [8, 4, 4]>>}}
+ %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [4, 16]>, <[ BATCHY, LANEY, VECTORX], [8, 4, 4]>>}}
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEY, VECTORX], [2, 4, 4]>, <[ BATCHY, LANEX], [4, 16]>>}}
+ return %output : vector<32x64xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @anchor_wmma_16x16x16_mmt(%a : memref<16x16xf16>, %b : memref<16x16xf16>, %init : vector<16x16xf32>) -> vector<16x16xf32> {
+ // CHECK-LABEL: anchor_wmma_16x16x16_mmt
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.0 : f16
+ %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+ %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>, <[ BATCHY, LANEX], [1, 16]>>}}
+ return %output : vector<16x16xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %contract = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
+ transform.iree.set_contraction_layout_attributes %contract, %layout16x16x16 : !transform.any_op, !transform.any_param
+
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 8d17162..dd52c2c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -158,6 +158,8 @@
if (isVectorDimension(name)) {
int64_t step{1};
if (name == LayoutDimension::VECTORY) {
+ assert(ranges.contains(LayoutDimension::VECTORX) &&
+ "Expected VectorX to be specified on layouts with VectorY.");
step = ranges.lookup(LayoutDimension::VECTORX).stop;
}
vecOffset = vecOffset.value_or(0) + it.getPosition() * step;