[LLVMGPU] Generalize VectorContractOpInfo based on indexing maps (#17625)
This patch generalizes VectorContractOpInfo to work on any kind of
vector.contract. The "kind" field was not being used by any pass anyway;
passes were instead relying directly on the inferred m, n, and k dims.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index 910c5ad..a52ead0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -67,9 +67,6 @@
// Infer the contract kind so that we know know to correlate M/N/K dims.
VectorContractOpInfo opDetail(contractOp);
- if (opDetail.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
- return rewriter.notifyMatchFailure(contractOp, "unknown contract kind");
- }
SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
LLVM_DEBUG({
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d4fd9fc..5df8a36 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -833,9 +833,8 @@
llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
llvm::errs() << "For schedule: " << *this << "\n";
});
- if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
- LLVM_DEBUG({ llvm::errs() << "Unknown contraction kind\n"; });
- return failure();
+ if (opInfo.getKDims().size() != 1) {
+ return contractOp->emitError("Unimplemented: > 1 k dims");
}
auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 3aead0a..84725a0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -32,9 +32,6 @@
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
VectorContractOpInfo opInfo(contractOp);
- if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
- return rewriter.notifyMatchFailure(contractOp, "unhandled contract kind");
- }
auto srcCType = dyn_cast<VectorType>(contractOp.getAccType());
if (!srcCType) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 6dfaf38..a3d13d0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -54,6 +54,34 @@
// -----
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<64x96xf32>) -> vector<64x96xf32> attributes { translation_info = #translation } {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, k) -> (n, m)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<64x96xf32>
+ return %0 : vector<64x96xf32>
+}
+
+// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
+// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
+// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 3], outers_per_batch = [1, 4], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
+// CHECK-SAME: subgroup_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 2, 1]
subgroup_size = 64,
{mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
index 78007c6..0bf721c 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
@@ -26,33 +26,11 @@
return std::make_pair(outMDims.back(), outNDims.back());
}
-VectorContractOpInfo::OpKind
-VectorContractOpInfo::inferOpKind(MLIRContext *ctx,
- SmallVector<AffineMap> maps) {
- if (contractionDims.k.size() != 1) {
- return OpKind::UNKNOWN;
- }
- if (!contractionDims.batch.empty()) {
- if (contractionDims.batch.size() > 1 || contractionDims.batch[0] != 0) {
- return OpKind::UNKNOWN;
- }
- if (*maps[0].getResultPosition(getAffineDimExpr(0, ctx)) != 0 ||
- *maps[1].getResultPosition(getAffineDimExpr(0, ctx)) != 0 ||
- *maps[2].getResultPosition(getAffineDimExpr(0, ctx)) != 0) {
- return OpKind::UNKNOWN;
- }
- }
+VectorContractOpInfo::VectorContractOpInfo(vector::ContractionOp op) {
+ contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
- int64_t innerM = contractionDims.m.back();
- int64_t innerN = contractionDims.n.back();
- int64_t k = contractionDims.k.back();
-
- int64_t lhsM = *maps[0].getResultPosition(getAffineDimExpr(innerM, ctx));
- lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
- int64_t rhsN = *maps[1].getResultPosition(getAffineDimExpr(innerN, ctx));
- rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
- int64_t outM = *maps[2].getResultPosition(getAffineDimExpr(innerM, ctx));
- int64_t outN = *maps[2].getResultPosition(getAffineDimExpr(innerN, ctx));
+ SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+ MLIRContext *ctx = op.getContext();
for (auto m : contractionDims.m) {
lhsMDims.push_back(*maps[0].getResultPosition(getAffineDimExpr(m, ctx)));
@@ -63,15 +41,9 @@
outNDims.push_back(*maps[2].getResultPosition(getAffineDimExpr(n, ctx)));
}
- if (outM < outN) {
- if (lhsM < lhsKDim) {
- if (rhsN < rhsKDim) {
- return OpKind::MK_NK_MN;
- }
- return OpKind::MK_KN_MN;
- }
- }
- return OpKind::UNKNOWN;
+ int64_t k = contractionDims.k.back();
+ lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx));
+ rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx));
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
index 64cfc0a..f8e0cf9 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h
@@ -6,21 +6,13 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
namespace mlir::iree_compiler {
/// A class for querying information about a contract op.
class VectorContractOpInfo {
public:
- enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };
-
- explicit VectorContractOpInfo(vector::ContractionOp op) {
- contractionDims = *linalg::inferContractionDims(op.getIndexingMapsArray());
- opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
- }
-
- OpKind getOpKind() const { return opKind; }
+ explicit VectorContractOpInfo(vector::ContractionOp op);
// Returns the (LHS M, RHS N) dimension index pair.
std::pair<int, int> getOperandMNIndex() const;
@@ -58,11 +50,6 @@
SmallVector<int64_t> outNDims;
private:
- // Gets the kind of a contract op with the given indexing |maps|.
- OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps);
-
- OpKind opKind = OpKind::UNKNOWN;
-
linalg::ContractionDimensions contractionDims;
};