[VectorDistribution] Add pattern to distribute contractions (#16172)
This PR adds a pattern to distribute vector contraction ops to AMDGPU MFMA ops.
Currently, the 16x16x16 and 32x32x8 MFMA intrinsics are supported.
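
For reference, a condensed version of the first lit test added in this PR shows how the pattern is driven: layout anchors are attached to the `vector.contract` via the `__vector_layout_test_anchor_*` attributes, and the new `transform.iree.test_amdgpu_contraction_distribution` transform op runs the distribution on the enclosing `func.func` (see `amdgpu_contraction_distribution.mlir` below for the full tests).

```mlir
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 4, 4]>
#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 4, 4]>
#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
builtin.module attributes { transform.with_named_sequence } {
  func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
    // Anchor attributes seed the vector layout analysis for the test.
    %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>,
        "__vector_layout_test_anchor_operand_0" = #layout_a,
        "__vector_layout_test_anchor_operand_1" = #layout_c,
        "__vector_layout_test_anchor_operand_2" = #layout_c,
        "__vector_layout_test_anchor_result_0" = #layout_c
      }
      %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
    return %output : vector<16x16xf32>
  }
  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
    transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
    transform.yield
  }
}
```

With these 16x16x16 anchors, the contraction is rewritten into an `amdgpu.mfma` op with `m = n = k = 16` operating on the distributed `vector<4xf16>`/`vector<4xf32>` values, as checked in the test.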
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 84dce8d..5e9762c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -85,30 +85,6 @@
return simdIndex;
}
-/// Given the state of the iterator, compute the indices of the distributed
-/// vector that the current iterator state is iterating over. The indices
-/// are not parameterized by thread, and it is expected that the indices for
-/// all threads are same.
-static SmallVector<int64_t> computeSIMTIndex(const LayoutIterator::State &state,
- LayoutAttr layout) {
- constexpr LayoutDimension labels[] = {
- LayoutDimension::BATCHX, LayoutDimension::BATCHY,
- LayoutDimension::VECTORZ, LayoutDimension::VECTORY,
- LayoutDimension::VECTORX};
-
- SmallVector<int64_t> offset;
- for (LayoutDimension label : labels) {
- std::optional shape = findDimShape(layout, label);
- if (!shape) {
- continue;
- }
- // Get current position for the label.
- int64_t position = state.lookup(label).getPosition();
- offset.push_back(position);
- }
- return offset;
-}
-
struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -129,8 +105,8 @@
// Replace the original op with the distributed op.
Type elementType = constant.getType().getElementType();
- auto vectorType = VectorType::get(
- layout.getDistributedShape(constant.getType()), elementType);
+ auto vectorType =
+ VectorType::get(layout.getDistributedShape(), elementType);
Operation *distirbutedOp = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), vectorType,
SplatElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()));
@@ -165,9 +141,8 @@
// Distribute vector result types.
if (auto vectorResult = dyn_cast<VectorValue>(result)) {
- resultType = VectorType::get(
- resLayout.getDistributedShape(vectorResult.getType()),
- vectorResult.getType().getElementType());
+ resultType = VectorType::get(resLayout.getDistributedShape(),
+ vectorResult.getType().getElementType());
}
resultTypes.push_back(resultType);
}
@@ -248,7 +223,7 @@
iterator.apply([&](const LayoutIterator::State &state) {
SmallVector<Value> memoryIndices =
getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
- SmallVector<int64_t> accIndices = computeSIMTIndex(state, vectorLayout);
+ SmallVector<int64_t> accIndices = state.computeSIMTIndex();
accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
vectorLayout, memoryLayout, rewriter);
});
@@ -301,7 +276,6 @@
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
- VectorValue vector = readOp.getVector();
LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.results[0]);
if (!vectorLayout) {
return failure();
@@ -310,8 +284,8 @@
// TODO: Return failure if we need masking.
Type elementType = readOp.getSource().getType().getElementType();
- auto vectorType = VectorType::get(
- vectorLayout.getDistributedShape(vector.getType()), elementType);
+ auto vectorType =
+ VectorType::get(vectorLayout.getDistributedShape(), elementType);
Value zero = rewriter.create<arith::ConstantOp>(
readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
VectorValue acc = cast<VectorValue>(zero);
@@ -345,7 +319,6 @@
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
- VectorValue vector = writeOp.getVector();
LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.operands[0]);
if (!vectorLayout) {
return failure();
@@ -354,8 +327,8 @@
// TODO: Return failure if we need masking.
Type elementType = writeOp.getSource().getType().getElementType();
- auto vectorType = VectorType::get(
- vectorLayout.getDistributedShape(vector.getType()), elementType);
+ auto vectorType =
+ VectorType::get(vectorLayout.getDistributedShape(), elementType);
Value zero = rewriter.create<arith::ConstantOp>(
writeOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
VectorValue acc = cast<VectorValue>(zero);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 7de8a9b..4f98b6a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -103,8 +103,7 @@
return cast<VectorValue>(toSIMD.getInput());
}
// Create a "to_simt" op to convert the value to the distributed layout.
- SmallVector<int64_t> distributedShape =
- layout.getDistributedShape(value.getType());
+ SmallVector<int64_t> distributedShape = layout.getDistributedShape();
VectorType distributedType =
VectorType::get(distributedShape, value.getType().getElementType());
auto toSIMT = rewriter.create<IREE::VectorExt::ToSIMTOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index a7ab386..2d35a9a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -7,15 +7,15 @@
func.func @distribute_elementwise_f16(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
- // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+ // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<0.0> : vector<16x16xf16>
- // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<4xf16>
- // CHECK-DAG: %[[C:.*]] = arith.mulf %[[ROOT]], %[[B]] : vector<4xf16>
+ // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
+ // CHECK-DAG: %[[C:.*]] = arith.mulf %[[ROOT]], %[[B]] : vector<16xf16>
%c = arith.mulf %root, %b : vector<16x16xf16>
- // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<4xf16>
- // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<4xf16>
+ // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
+ // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<16xf16>
%d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<16x16xf16>
- // CHECK: iree_vector_ext.to_simd %[[D]] : vector<4xf16> -> vector<16x16xf16>
+ // CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xf16> -> vector<16x16xf16>
return %d : vector<16x16xf16>
}
@@ -23,15 +23,15 @@
func.func @distribute_elementwise_i32(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> vector<16x16xi32> {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0 : i32
- // CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<4xi32>
+ // CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<16xi32>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<0> : vector<16x16xi32>
- // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<4xi32>
- // CHECK-DAG: %[[C:.*]] = arith.muli %[[ROOT]], %[[B]] : vector<4xi32>
+ // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
+ // CHECK-DAG: %[[C:.*]] = arith.muli %[[ROOT]], %[[B]] : vector<16xi32>
%c = arith.muli %root, %b : vector<16x16xi32>
- // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<4xi32>
- // CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<4xi32>
+ // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
+ // CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<16xi32>
%d = arith.addi %c, %a : vector<16x16xi32>
- // CHECK: iree_vector_ext.to_simd %[[D]] : vector<4xi32> -> vector<16x16xi32>
+ // CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xi32> -> vector<16x16xi32>
return %d : vector<16x16xi32>
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 810b7ea..008e033 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -890,32 +890,6 @@
// TestVectorLayoutAnalysisOp
//===----------------------------------------------------------------------===//
-static void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
- Operation *root) {
- root->walk([&](Operation *op) {
- for (NamedAttribute attr : op->getAttrs()) {
- StringRef name = attr.getName().strref();
- if (name.find("__vector_layout_test_anchor_operand_") !=
- std::string::npos) {
- int operandNum;
- name.substr(name.find_last_of("_") + 1)
- .getAsInteger(/*Radix=*/10, operandNum);
- assert(operandNum < op->getNumOperands() &&
- "operand number out of range");
- analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
- }
- if (name.find("__vector_layout_test_anchor_result_") !=
- std::string::npos) {
- int resultNum;
- name.substr(name.find_last_of("_") + 1)
- .getAsInteger(/*Radix=*/10, resultNum);
- assert(resultNum < op->getNumResults() && "result number out of range");
- analysis.setAnchor(op->getResult(resultNum), attr.getValue());
- }
- }
- });
-}
-
static void emitLayoutRemarks(VectorLayoutAnalysis &analysis,
func::FuncOp funcOp) {
funcOp.walk([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 22ec064..e01c4d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -1019,3 +1019,31 @@
print(llvm::dbgs());
llvm::dbgs() << "\n";
}
+
+namespace mlir::iree_compiler {
+
+void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
+ Operation *root) {
+ root->walk([&](Operation *op) {
+ for (NamedAttribute attr : op->getAttrs()) {
+ StringRef name = attr.getName().strref();
+ if (name.contains("__vector_layout_test_anchor_operand_")) {
+ int operandNum;
+ name.substr(name.find_last_of("_") + 1)
+ .getAsInteger(/*Radix=*/10, operandNum);
+ assert(operandNum < op->getNumOperands() &&
+ "operand number out of range");
+ analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
+ }
+ if (name.contains("__vector_layout_test_anchor_result_")) {
+ int resultNum;
+ name.substr(name.find_last_of("_") + 1)
+ .getAsInteger(/*Radix=*/10, resultNum);
+ assert(resultNum < op->getNumResults() && "result number out of range");
+ analysis.setAnchor(op->getResult(resultNum), attr.getValue());
+ }
+ }
+ });
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
index 3c9ebae..0282cb2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h
@@ -146,6 +146,9 @@
DataFlowSolver solver;
};
+void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
+ Operation *root);
+
}; // namespace iree_compiler
}; // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
index 7f68c8d..274228e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
@@ -65,6 +65,7 @@
"//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
index 8491b58..0390d66 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
@@ -34,6 +34,7 @@
IREEDialectsTransforms
IREELinalgTransformDialect
LLVMSupport
+ MLIRAMDGPUDialect
MLIRAffineDialect
MLIRArithDialect
MLIRBufferizationDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 7fd04fc..0ec4618 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -8,6 +8,8 @@
#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -17,6 +19,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -56,6 +59,7 @@
// CreateAsyncGroupsOp depends on the following two dialects.
declareGeneratedDialect<gpu::GPUDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
+ declareGeneratedDialect<amdgpu::AMDGPUDialect>();
registerTransformOps<
#define GET_OP_LIST
@@ -1496,5 +1500,32 @@
transform::modifiesPayload(effects);
}
+class TestVectorLayoutOptions : public VectorLayoutOptions {
+public:
+ TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {}
+
+ void setAnchorOps(VectorLayoutAnalysis &analysis) override {
+ setAnchorOpsFromAttributes(analysis, root);
+ }
+};
+
+DiagnosedSilenceableFailure
+transform_dialect::TestAMDGPUContractionDistribution::applyToOne(
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ TestVectorLayoutOptions options(target);
+ RewritePatternSet patterns(target.getContext());
+ populateAMDGPUDistributionPatterns(patterns);
+ distributeVectorOps(target, patterns, options);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::TestAMDGPUContractionDistribution::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTarget(), effects);
+ transform::modifiesPayload(effects);
+}
+
#define GET_OP_CLASSES
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 9c0d0b4..7b9e208 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -312,8 +312,8 @@
let results = (outs);
let assemblyFormat = [{
- $target
- attr-dict
+ $target
+ attr-dict
`:` functional-type($target, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
@@ -437,9 +437,9 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let assemblyFormat = [{
- $for_op
- attr-dict
+ let assemblyFormat = [{
+ $for_op
+ attr-dict
`:` functional-type(operands, results)}];
let extraClassDeclaration = [{
@@ -472,9 +472,9 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let assemblyFormat = [{
- $for_op
- attr-dict
+ let assemblyFormat = [{
+ $for_op
+ attr-dict
`:` functional-type(operands, results)}];
let extraClassDeclaration = [{
@@ -509,9 +509,9 @@
UnitAttr:$use_mma_sync);
let results = (outs);
- let assemblyFormat = [{
- $target
- attr-dict
+ let assemblyFormat = [{
+ $target
+ attr-dict
`:` functional-type(operands, results)}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
@@ -682,4 +682,41 @@
}];
}
+def TestAMDGPUContractionDistribution :
+ Op<Transform_Dialect, "iree.test_amdgpu_contraction_distribution",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformEachOpTrait,
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Run AMDGPU Vector Contraction distribution on the target as the root.
+
+ The anchor points are set by using the attribute
+ "__vector_layout_test_anchor_operand_x" and
+ "__vector_layout_test_anchor_result_x", where "x" is the operand/result
+ number.
+
+ This op produces amdgpu MFMA ops.
+
+ #### Return modes
+
+ This transform does not consume the target handle and always returns success.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = [{ $target attr-dict `:` type($target)}];
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::func::FuncOp funcOp,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+
#endif // IREE_COMPILER_CODEGEN_LLVMGPU_TRANSFORMEXTENSIONS_LLVMGPUEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
new file mode 100644
index 0000000..cf0a0cb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/AMDGPUDistributionPatterns.cpp
@@ -0,0 +1,358 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
+#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+
+namespace mlir::iree_compiler {
+
+using namespace mlir::iree_compiler::IREE::VectorExt;
+using VectorValue = TypedValue<VectorType>;
+
+enum class ContractMatrixType { A, B, C, D };
+enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };
+
+// The naming scheme for these operators is:
+// InputType_MxNxK_OutputType.
+enum class MFMAType {
+ F16_16x16x16_F32,
+ F16_32x32x8_F32,
+};
+
+namespace {
+
+struct DistributeContractions final
+ : OpDistributionPattern<vector::ContractionOp> {
+ using OpDistributionPattern::OpDistributionPattern;
+
+ // For a MM contraction, we compute C(i, k) += A(i, j) * B(j, k).
+ // If we have an MMT contraction, we compute C(i, k) += A(i, j) * B(k, j).
+ // This function returns the appropriate indices for the A and B matrices.
+ // Given incoming indices (i, j), it either returns the same or swaps them,
+ // depending on the type of contraction and type of matrix.
+ SmallVector<int64_t> getIndices(ContractType contractType,
+ ContractMatrixType matrixType, int i,
+ int j) const {
+ SmallVector<int64_t> originalIndices{i, j};
+ SmallVector<int64_t> swappedIndices{j, i};
+ if (contractType == ContractType::MTMT)
+ return swappedIndices;
+ if ((contractType == ContractType::MTM) &&
+ (matrixType == ContractMatrixType::A))
+ return swappedIndices;
+ if ((contractType == ContractType::MMT) &&
+ (matrixType == ContractMatrixType::B))
+ return swappedIndices;
+ return originalIndices;
+ }
+
+ int64_t getReductionDimensionShape(int64_t rowBatch, int64_t colBatch,
+ ContractType contractType) const {
+ if ((contractType == ContractType::MTM) ||
+ (contractType == ContractType::MTMT)) {
+ return rowBatch;
+ }
+ return colBatch;
+ }
+
+ ContractType inferContractType(MLIRContext *ctx,
+ SmallVector<AffineMap> maps) const {
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ AffineExpr m, n, k;
+ bindDims(ctx, m, n, k);
+ if ((maps == infer({{m, k}, {k, n}, {m, n}})) ||
+ (maps == infer({{n, k}, {k, m}, {n, m}}))) {
+ return ContractType::MM;
+ }
+ if ((maps == infer({{m, k}, {n, k}, {m, n}})) ||
+ (maps == infer({{n, k}, {m, k}, {n, m}}))) {
+ return ContractType::MMT;
+ }
+ if ((maps == infer({{k, m}, {k, n}, {m, n}})) ||
+ (maps == infer({{k, n}, {k, m}, {n, m}}))) {
+ return ContractType::MTM;
+ }
+ if ((maps == infer({{k, m}, {n, k}, {m, n}})) ||
+ (maps == infer({{k, n}, {m, k}, {n, m}}))) {
+ return ContractType::MTMT;
+ }
+ return ContractType::UNSUPPORTED;
+ }
+
+ Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
+ MFMAType mfmaType) const {
+ uint32_t m, n, k, blks;
+ if (mfmaType == MFMAType::F16_16x16x16_F32) {
+ m = n = k = 16;
+ } else if (mfmaType == MFMAType::F16_32x32x8_F32) {
+ m = n = 32;
+ k = 8;
+ }
+ blks = 1;
+ return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
+ b, c);
+ }
+
+ PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
+ ArrayRef<LayoutDimension> dims,
+ ArrayRef<int64_t> shapes) const {
+ SmallVector<LayoutDimensionAttr> dimAttrs;
+ for (auto dim : dims)
+ dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
+ return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
+ }
+
+ std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
+ LayoutDimension batchRowLabel, int64_t batchRow,
+ LayoutDimension batchColLabel, int64_t batchCol) const {
+ MLIRContext *ctx = getContext();
+ PerDimLayoutAttr rowLayout = createPerDimLayout(
+ ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
+ PerDimLayoutAttr colLayout = createPerDimLayout(
+ ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
+ {batchCol, 4, 4});
+ return {rowLayout, colLayout};
+ }
+
+ bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
+ LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
+ LayoutAttr canonicalLayout =
+ LayoutAttr::get(getContext(), {rowLayout, colLayout});
+ return layout == canonicalLayout;
+ }
+
+ bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
+ LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
+ LayoutAttr canonicalLayout =
+ LayoutAttr::get(getContext(), {rowLayout, colLayout});
+ return layout == canonicalLayout;
+ }
+
+ bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ return isCompatible16x16x16B(layout, batchRow, batchCol);
+ }
+
+ std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
+ createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
+ LayoutDimension batchColLabel, int64_t batchCol,
+ ContractMatrixType matrixType) const {
+ MLIRContext *ctx = getContext();
+ PerDimLayoutAttr rowLayout = createPerDimLayout(
+ ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
+ PerDimLayoutAttr colLayout;
+ if (matrixType == ContractMatrixType::C) {
+ colLayout =
+ createPerDimLayout(ctx,
+ {batchColLabel, LayoutDimension::VECTORY,
+ LayoutDimension::LANEY, LayoutDimension::VECTORX},
+ {batchCol, 4, 2, 4});
+ } else {
+ colLayout = createPerDimLayout(
+ ctx,
+ {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
+ {batchCol, 2, 4});
+ }
+ return {rowLayout, colLayout};
+ }
+
+ bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
+ LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
+ ContractMatrixType::A);
+ LayoutAttr canonicalLayout =
+ LayoutAttr::get(getContext(), {rowLayout, colLayout});
+ return layout == canonicalLayout;
+ }
+
+ bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
+ LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
+ ContractMatrixType::B);
+ LayoutAttr canonicalLayout =
+ LayoutAttr::get(getContext(), {rowLayout, colLayout});
+ return layout == canonicalLayout;
+ }
+
+ bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
+ int64_t batchCol) const {
+ auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
+ LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
+ ContractMatrixType::C);
+ LayoutAttr canonicalLayout =
+ LayoutAttr::get(getContext(), {rowLayout, colLayout});
+ return layout == canonicalLayout;
+ }
+
+ bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
+ int64_t batchRow, int64_t batchCol) const {
+ switch (matrixType) {
+ case ContractMatrixType::A:
+ return isCompatible16x16x16A(layout, batchRow, batchCol);
+ case ContractMatrixType::B:
+ return isCompatible16x16x16B(layout, batchRow, batchCol);
+ default:
+ return isCompatible16x16x16C(layout, batchRow, batchCol);
+ }
+ return false;
+ }
+
+ bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
+ int64_t batchRow, int64_t batchCol) const {
+ switch (matrixType) {
+ case ContractMatrixType::A:
+ return isCompatible32x32x8A(layout, batchRow, batchCol);
+ case ContractMatrixType::B:
+ return isCompatible32x32x8B(layout, batchRow, batchCol);
+ default:
+ return isCompatible32x32x8C(layout, batchRow, batchCol);
+ }
+ return false;
+ }
+
+ bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
+ MFMAType mfmaType) const {
+ std::optional<int64_t> batchRow = layout.getBatchDim(0);
+ if (!batchRow)
+ return false;
+ std::optional<int64_t> batchCol = layout.getBatchDim(1);
+ if (!batchCol)
+ return false;
+ switch (mfmaType) {
+ case MFMAType::F16_16x16x16_F32:
+ return isCompatible16x16x16(layout, matrixType, batchRow.value(),
+ batchCol.value());
+ case MFMAType::F16_32x32x8_F32:
+ return isCompatible32x32x8(layout, matrixType, batchRow.value(),
+ batchCol.value());
+ default:
+ return false;
+ }
+ return false;
+ }
+
+ // If we have a prior guess of the MFMA type, only evaluate that type.
+ // Otherwise, evaluate all types to find a match.
+ std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
+ ContractMatrixType matrixType,
+ std::optional<MFMAType> prior) const {
+ SmallVector<MFMAType> mfmaTypes;
+ if (prior) {
+ mfmaTypes.push_back(prior.value());
+ } else {
+ mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
+ }
+ for (MFMAType mfmaType : mfmaTypes) {
+ if (isCompatible(layout, matrixType, mfmaType))
+ return mfmaType;
+ }
+ return std::nullopt;
+ }
+
+ // Inputs are LHS, RHS and ACC operands and corresponding layouts.
+ // Output is inferred MFMAType or none (if layout is not compatible with any
+ // MFMA layout).
+ std::optional<MFMAType>
+ inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts) const {
+ std::optional<MFMAType> mfmaType{std::nullopt};
+ SmallVector<ContractMatrixType> matrixTypes{
+ ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};
+ for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
+ mfmaType = inferMFMAType(layout, matrixType, mfmaType);
+ if (!mfmaType)
+ return std::nullopt;
+ }
+ return mfmaType;
+ }
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ constexpr int LHS = 0;
+ constexpr int RHS = 1;
+ constexpr int ACC = 2;
+ SmallVector<VectorValue> operands;
+ SmallVector<LayoutAttr> layouts;
+ for (auto [operand, layout] :
+ llvm::zip_equal(contractOp->getOperands(), signature.operands)) {
+ if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
+ if (auto vectorLayout = dyn_cast<LayoutAttr>(layout)) {
+ operands.push_back(vectorOperand);
+ layouts.push_back(vectorLayout);
+ }
+ }
+ }
+ LayoutAttr resultLayout = dyn_cast<LayoutAttr>(signature.results[0]);
+ if (!resultLayout)
+ return failure();
+ std::optional<MFMAType> mfmaType = inferCompatibleMFMAType(layouts);
+ if (!mfmaType)
+ return failure();
+
+ Type elementType =
+ llvm::cast<ShapedType>(operands[ACC].getType()).getElementType();
+ SmallVector<int64_t> vectorShape = resultLayout.getDistributedShape();
+ auto vectorType = VectorType::get(vectorShape, elementType);
+ Location loc = contractOp.getLoc();
+ Value vector = rewriter.create<arith::ConstantOp>(
+ loc, vectorType, rewriter.getZeroAttr(vectorType));
+
+ ContractType contractType = inferContractType(
+ contractOp.getContext(), contractOp.getIndexingMapsArray());
+ if (contractType == ContractType::UNSUPPORTED)
+ return failure();
+
+ std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
+ if (!rowBatch)
+ return failure();
+ std::optional<int64_t> colBatch = layouts[LHS].getBatchDim(1);
+ if (!colBatch)
+ return failure();
+
+ int K = getReductionDimensionShape(rowBatch.value(), colBatch.value(),
+ contractType);
+
+ auto contractFn = [&](const LayoutIterator::State &state) {
+ SmallVector<int64_t> indices = state.computeIteratorProjectedSIMTIndex();
+ Value dMatrix = rewriter.create<vector::ExtractOp>(
+ loc, getDistributed(rewriter, operands[ACC], layouts[ACC]), indices);
+ for (int k = 0; k < K; k++) {
+ Value aMatrix = rewriter.create<vector::ExtractOp>(
+ loc, getDistributed(rewriter, operands[LHS], layouts[LHS]),
+ getIndices(contractType, ContractMatrixType::A, indices[0], k));
+ Value bMatrix = rewriter.create<vector::ExtractOp>(
+ loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
+ getIndices(contractType, ContractMatrixType::B, k, indices[1]));
+ dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
+ mfmaType.value());
+ }
+ vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
+ };
+
+ LayoutIterator iterator(resultLayout);
+ LayoutIterator batchIterator = iterator.getBatchIterator();
+ batchIterator.apply(contractFn);
+ replaceOpWithDistributedValues(rewriter, contractOp, vector);
+ return success();
+ }
+};
+} // namespace
+
+void populateAMDGPUDistributionPatterns(RewritePatternSet &patterns) {
+ patterns.add<DistributeContractions>(patterns.getContext());
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index 62e6766..b6a328f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -17,6 +17,7 @@
iree_compiler_cc_library(
name = "Utils",
srcs = [
+ "AMDGPUDistributionPatterns.cpp",
"LLVMGPULayoutAnalysisAndDistribution.cpp",
"LLVMGPUUtils.cpp",
],
@@ -24,9 +25,13 @@
"LLVMGPUUtils.h",
],
deps = [
+ "//compiler/src/iree/compiler/Codegen/Common",
+ "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
+ "//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index 8609095..10edda1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -16,10 +16,13 @@
HDRS
"LLVMGPUUtils.h"
SRCS
+ "AMDGPUDistributionPatterns.cpp"
"LLVMGPULayoutAnalysisAndDistribution.cpp"
"LLVMGPUUtils.cpp"
DEPS
+ IREEVectorExtDialect
LLVMSupport
+ MLIRAMDGPUDialect
MLIRAffineDialect
MLIRArithDialect
MLIRFuncDialect
@@ -29,6 +32,8 @@
MLIRMemRefDialect
MLIRNVGPUDialect
MLIRVectorDialect
+ iree::compiler::Codegen::Common
+ iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 7a6a632..eb7b126 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -32,6 +32,9 @@
/// from the previous alias group before starting a new one.
void packSharedMemoryAlloc(func::FuncOp funcOp);
+// Add patterns to distribute contractions to MFMA ops.
+void populateAMDGPUDistributionPatterns(RewritePatternSet &patterns);
+
} // namespace mlir::iree_compiler
#endif
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index bd80615..3fa0c2f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -18,6 +18,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "amdgpu_contraction_distribution.mlir",
"attention.mlir",
"conv_pipeline_test.mlir",
"convert_to_nvvm.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index d20d23d..cb082fc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "amdgpu_contraction_distribution.mlir"
"attention.mlir"
"cast_address_space_function.mlir"
"config_matvec.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
new file mode 100644
index 0000000..03f4011
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --cse %s | FileCheck %s
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 4, 4]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 4, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_mfma_16x16x16_mmt(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
+ // CHECK-LABEL: distribute_mfma_16x16x16_mmt
+ // CHECK-SAME: %[[ARG0:.+]]: vector<16x16xf16>, %[[ARG1:.+]]: vector<16x16xf16>, %[[ARG2:.+]]: vector<16x16xf32>
+ // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<16x16xf32> -> vector<1x1x4xf32>
+ // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<16x16xf16> -> vector<1x1x4xf16>
+ // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<16x16xf16> -> vector<1x1x4xf16>
+ // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_c,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+ return %output : vector<16x16xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [1, 32]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [1, 2, 4]>
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [1, 2, 4]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, VECTORY, LANEY, VECTORX], [1, 4, 2, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 32]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_b = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+#layout_c = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
+ // CHECK-LABEL: distribute_mfma_32x32x8_mm
+ // CHECK-SAME: %[[ARG0:.+]]: vector<32x8xf16>, %[[ARG1:.+]]: vector<8x32xf16>, %[[ARG2:.+]]: vector<32x32xf32>
+ // CHECK-DAG: %[[C:.+]] = iree_vector_ext.to_simt %[[ARG2]] : vector<32x32xf32> -> vector<1x1x16xf32>
+ // CHECK-DAG: %[[CV:.+]] = vector.extract %[[C]][0, 0] : vector<16xf32> from vector<1x1x16xf32>
+ // CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_simt %[[ARG0]] : vector<32x8xf16> -> vector<1x1x4xf16>
+ // CHECK-DAG: %[[AV:.+]] = vector.extract %[[A]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK-DAG: %[[B:.+]] = iree_vector_ext.to_simt %[[ARG1]] : vector<8x32xf16> -> vector<1x1x4xf16>
+ // CHECK-DAG: %[[BV:.+]] = vector.extract %[[B]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+ // CHECK-DAG: %[[OUT:.+]] = amdgpu.mfma %[[AV]] * %[[BV]] + %[[CV]] {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_b,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<32x8xf16>, vector<8x32xf16> into vector<32x32xf32>
+ return %output : vector<32x32xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0)>
+#row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX], [2, 16]>
+#col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [8, 4, 4]>
+#row_layout2 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [4, 4, 4]>
+#col_layout2 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [8, 16]>
+#row_layout3 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
+#col_layout3 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [4, 16]>
+#layout_a = #iree_vector_ext.layout<#row_layout, #col_layout>
+#layout_b = #iree_vector_ext.layout<#row_layout2, #col_layout2>
+#layout_c = #iree_vector_ext.layout<#row_layout3, #col_layout3>
+builtin.module attributes { transform.with_named_sequence } {
+ func.func @distribute_mfma_16x16x16_mmt_batch(%a : vector<32x128xf16>, %b : vector<64x128xf16>, %c : vector<32x64xf32>) -> vector<32x64xf32> {
+ // CHECK-LABEL: distribute_mfma_16x16x16_mmt_batch
+ // CHECK-COUNT-64: amdgpu.mfma {{.*}}, vector<4xf32>
+ %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ "__vector_layout_test_anchor_operand_0" = #layout_a,
+ "__vector_layout_test_anchor_operand_1" = #layout_b,
+ "__vector_layout_test_anchor_operand_2" = #layout_c,
+ "__vector_layout_test_anchor_result_0" = #layout_c
+ }
+ %a, %b, %c : vector<32x128xf16>, vector<64x128xf16> into vector<32x64xf32>
+ return %output : vector<32x64xf32>
+ }
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_amdgpu_contraction_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
index 69588af..0b4562a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
@@ -109,6 +109,10 @@
);
let assemblyFormat = "`<`$layouts`>`";
let genVerifyDecl = 0;
+ let extraClassDeclaration = [{
+ std::optional<int64_t> getBatchDim(int64_t dim);
+ PerDimLayoutAttr getDimLayout(int64_t dim) const;
+ }];
}
#endif // IREE_DIALECT_VECTOREXT_BASE
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
index 743186b..6d84b14 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -39,7 +39,7 @@
/*description=*/"Get the distributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getDistributedShape",
- /*args=*/(ins "VectorType":$type)
+ /*args=*/(ins)
>
];
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index f13d2f0..ec33eb1 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -59,7 +59,6 @@
DimensionalIterator begin() const { return DimensionalIterator(start, step); }
DimensionalIterator end() const { return DimensionalIterator(stop, step); }
-private:
int64_t start, stop, step;
};
@@ -70,27 +69,70 @@
// required during distribution.
class LayoutIterator {
public:
- using State = llvm::MapVector<LayoutDimension, DimensionalIterator>;
- using DimensionMapping =
- llvm::DenseMap<int64_t, SmallVector<LayoutDimension>>;
- void maybeFreezeAndConcatenate(const LayoutIterator &frozenIterator);
+ struct State {
+ SmallVector<int64_t> computeSIMTIndex() const;
+ SmallVector<int64_t> computeIteratorProjectedSIMTIndex() const;
+ bool contains(LayoutDimension dim) const { return iterators.contains(dim); }
+ void erase(LayoutDimension dim) { iterators.erase(dim); }
+ DimensionalIterator lookup(LayoutDimension dim) const {
+ return iterators.lookup(dim);
+ }
+ DimensionalIterator &operator[](LayoutDimension dim) {
+ return iterators[dim];
+ }
+ void print() const {
+ for (const auto &[dim, it] : iterators) {
+ llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
+ std::to_string(*it) + ", ";
+ }
+ llvm::outs() << "\n";
+ }
+ llvm::MapVector<LayoutDimension, DimensionalIterator> iterators;
+ DenseMap<int64_t, DenseSet<LayoutDimension>> simdToLayoutDim;
+ llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
+ SmallVector<LayoutDimension> labels{
+ LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+ LayoutDimension::VECTORY, LayoutDimension::VECTORX};
+ };
+ void maybeFreezeAndConcatenate(const LayoutIterator::State &frozenState);
+ LayoutIterator(LayoutAttr &attr);
+ LayoutIterator(LayoutAttr &attr, int64_t simtIndex);
LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides);
+ LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides,
+ int64_t simtIndex);
LayoutIterator(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides);
void apply(std::function<void(const LayoutIterator::State &)>);
LayoutIterator &operator++();
State getState() const { return state; }
+ void erase(LayoutDimension dim);
+ LayoutIterator getBatchIterator() const;
private:
- void initialize(PerDimLayoutAttr &attr,
- DenseMap<LayoutDimension, int64_t> strides);
+ void initialize(const PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides,
+ std::optional<int64_t> simdIndex);
bool iterationComplete();
State state;
- llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
- DimensionMapping simdDimensionToLayoutDimension;
DenseSet<LayoutDimension> frozenDimensions;
+ int64_t iterations{0};
+ int64_t maxIterations{1};
};
+inline bool isBatchDimension(LayoutDimension dim) {
+ return (dim == LayoutDimension::BATCHX) || (dim == LayoutDimension::BATCHY);
+}
+
+inline bool isLaneDimension(LayoutDimension dim) {
+ return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
+ (dim == LayoutDimension::LANEZ);
+}
+
+inline bool isVectorDimension(LayoutDimension dim) {
+ return (dim == LayoutDimension::VECTORX) ||
+ (dim == LayoutDimension::VECTORY) || (dim == LayoutDimension::VECTORZ);
+}
+
} // namespace mlir::iree_compiler::IREE::VectorExt
#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index e2973a1..04cb0dc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -109,17 +109,47 @@
return LayoutAttr::get(getContext(), newLayouts);
}
-SmallVector<int64_t> LayoutAttr::getDistributedShape(VectorType) const {
- LayoutDimension labels[] = {LayoutDimension::BATCHX, LayoutDimension::BATCHY,
- LayoutDimension::VECTORX};
+// This function returns the distributed shape of the SIMT
+// vector and evaluates it in the following order:
+// BATCHX, BATCHY, VECTORY, VECTORX
+// The vector dimensions are combined into a single SIMT
+// vector dimension.
+SmallVector<int64_t> LayoutAttr::getDistributedShape() const {
+ SmallVector<LayoutDimension> labels{
+ LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+ LayoutDimension::VECTORY, LayoutDimension::VECTORX};
SmallVector<int64_t> simtVectorShape;
+ std::optional<int64_t> vectorShape;
for (LayoutDimension dim : labels) {
ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
for (PerDimLayoutAttr layout : layouts) {
if (!layout.contains(dim))
continue;
- simtVectorShape.push_back(layout.getShape(dim).value());
+ int64_t shape = layout.getShape(dim).value();
+ if (isVectorDimension(dim)) {
+ vectorShape = shape * vectorShape.value_or(1);
+ continue;
+ }
+ simtVectorShape.push_back(shape);
}
}
+ if (vectorShape)
+ simtVectorShape.push_back(vectorShape.value());
return simtVectorShape;
}
+
+PerDimLayoutAttr LayoutAttr::getDimLayout(int64_t dim) const {
+ assert(dim >= 0 && dim < getLayouts().size());
+ return getLayouts()[dim];
+}
+
+std::optional<int64_t> LayoutAttr::getBatchDim(int64_t dim) {
+ assert(dim < getLayouts().size());
+ PerDimLayoutAttr layout = getDimLayout(dim);
+ for (auto [name, shape] :
+ llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
+ if (isBatchDimension(name.getValue()))
+ return shape;
+ }
+ return std::nullopt;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 8894d69..228cc10 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -52,22 +52,9 @@
return {};
}
-LayoutIterator &LayoutIterator::operator++() {
- for (auto &[dim, it] : state) {
- if (frozenDimensions.contains(dim))
- continue;
- ++it;
- if (it < ranges[dim].end()) {
- break;
- }
- it = ranges[dim].begin();
- }
- return *this;
-}
-
void LayoutIterator::maybeFreezeAndConcatenate(
- const LayoutIterator &frozenIterator) {
- for (auto &[frozenDim, frozenIt] : frozenIterator.getState()) {
+ const LayoutIterator::State &frozenState) {
+ for (auto &[frozenDim, frozenIt] : frozenState.iterators) {
if (!state.contains(frozenDim)) {
frozenDimensions.insert(frozenDim);
state[frozenDim] = frozenIt;
@@ -75,13 +62,9 @@
}
}
-static bool isLaneDimension(LayoutDimension dim) {
- return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
- (dim == LayoutDimension::LANEZ);
-}
-
-void LayoutIterator::initialize(PerDimLayoutAttr &attr,
- DenseMap<LayoutDimension, int64_t> strides) {
+void LayoutIterator::initialize(const PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides,
+ std::optional<int64_t> simdIndex) {
auto reversedLabels = llvm::reverse(attr.getLabels());
auto reversedShapes = llvm::reverse(attr.getShapes());
for (auto [nameAttr, shape] : llvm::zip(reversedLabels, reversedShapes)) {
@@ -89,44 +72,138 @@
if (isLaneDimension(dim))
continue;
int64_t stride = strides.contains(dim) ? strides[dim] : 1;
- ranges[dim] = DimensionalRange(0, shape, stride);
- state[dim] = ranges[dim].begin();
+ state.ranges[dim] = DimensionalRange(0, shape, stride);
+ state.iterators[dim] = state.ranges[dim].begin();
+ maxIterations *= shape / stride;
+ if (simdIndex) {
+ int64_t index = simdIndex.value();
+ if (!state.simdToLayoutDim.contains(index))
+ state.simdToLayoutDim[index] = {};
+ state.simdToLayoutDim[index].insert(dim);
+ }
}
}
LayoutIterator::LayoutIterator(LayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides) {
- for (PerDimLayoutAttr perDimAttr : attr.getLayouts()) {
- initialize(perDimAttr, strides);
+ for (auto perDimAttr : llvm::enumerate(attr.getLayouts())) {
+ initialize(perDimAttr.value(), strides, perDimAttr.index());
+ }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr) {
+ DenseMap<LayoutDimension, int64_t> strides;
+ for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+ initialize(attr, strides, idx);
+ }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides,
+ int64_t simtIndex) {
+ for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+ if (idx != simtIndex)
+ continue;
+ initialize(attr, strides, idx);
+ }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr, int64_t simtIndex) {
+ DenseMap<LayoutDimension, int64_t> strides;
+ for (auto [idx, attr] : llvm::enumerate(attr.getLayouts())) {
+ if (idx != simtIndex)
+ continue;
+ initialize(attr, strides, idx);
}
}
LayoutIterator::LayoutIterator(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides) {
- initialize(attr, strides);
+ initialize(attr, strides, std::nullopt);
}
-/// The iterator is done when it returns back to
-/// its begin state.
-bool LayoutIterator::iterationComplete() {
- bool complete{true};
- for (auto &[dim, it] : state) {
+LayoutIterator &LayoutIterator::operator++() {
+ for (auto &[dim, it] : state.iterators) {
if (frozenDimensions.contains(dim))
continue;
- if (it != ranges[dim].begin()) {
- complete = false;
- break;
+ ++it;
+ if (it == state.ranges[dim].end()) {
+ it = state.ranges[dim].begin();
+ continue;
}
+ break;
}
- return complete;
+ ++iterations;
+ return *this;
}
+/// The iterator is done when all the loops are complete.
+bool LayoutIterator::iterationComplete() { return iterations == maxIterations; }
+
void LayoutIterator::apply(
std::function<void(const LayoutIterator::State &)> callback) {
- do {
+ for (; !iterationComplete(); ++(*this)) {
callback(state);
- ++(*this);
- } while (!iterationComplete());
+ }
+}
+
+// Get the offset into the SIMT vector corresponding to the incoming iterator.
+// The returned offsets will always be the same shape as the labels array.
+// Groups vector dimensions together. Assumes last dimension is vector
+// dimension.
+SmallVector<int64_t> LayoutIterator::State::computeSIMTIndex() const {
+ SmallVector<int64_t> offset;
+ std::optional<int64_t> vecOffset;
+ for (auto label : labels) {
+ for (auto [name, it] : iterators) {
+ if (name != label)
+ continue;
+ if (isBatchDimension(name)) {
+ offset.push_back(it.getPosition());
+ continue;
+ }
+ if (isVectorDimension(name)) {
+ int64_t step{1};
+ if (name == LayoutDimension::VECTORY) {
+ step = ranges.lookup(LayoutDimension::VECTORX).stop;
+ }
+ vecOffset = vecOffset.value_or(0) + it.getPosition() * step;
+ }
+ }
+ }
+ if (vecOffset)
+ offset.push_back(vecOffset.value());
+ return offset;
+}
+
+SmallVector<int64_t>
+LayoutIterator::State::computeIteratorProjectedSIMTIndex() const {
+ SmallVector<int64_t> indices = computeSIMTIndex();
+ SmallVector<int64_t> projectedIndices;
+ for (size_t i = 0, e = labels.size(); i != e; ++i) {
+ for (auto [name, it] : iterators) {
+ if (name == labels[i])
+ projectedIndices.push_back(indices[i]);
+ }
+ }
+ return projectedIndices;
+}
+
+void LayoutIterator::erase(LayoutDimension dim) {
+ if (state.contains(dim))
+ state.erase(dim);
+}
+
+LayoutIterator LayoutIterator::getBatchIterator() const {
+ LayoutIterator projectedIterator = *this;
+ for (auto [dim, it] : state.iterators) {
+ if (!isBatchDimension(dim)) {
+ DimensionalRange range = state.ranges.lookup(dim);
+ projectedIterator.maxIterations /= (range.stop / range.step);
+ projectedIterator.erase(dim);
+ }
+ }
+ return projectedIterator;
}
// clang-format off
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
index fce0953..5aacc9e 100644
--- a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
@@ -33,7 +33,7 @@
}
// Prints the layout so that LIT can test it for correctness.
static void printFn(const LayoutIterator::State &state) {
- for (const auto &[dim, it] : state) {
+ for (const auto &[dim, it] : state.iterators) {
llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
std::to_string(*it) + ", ";
}
@@ -66,7 +66,7 @@
DenseMap<LayoutDimension, int64_t> strides;
LayoutIterator iterator(layout, strides);
auto frozenIterator = createFrozenIterator(op->getContext(), strides);
- iterator.maybeFreezeAndConcatenate(frozenIterator);
+ iterator.maybeFreezeAndConcatenate(frozenIterator.getState());
iterator.apply(printFn);
}
void runOnOperation() override {