[Codegen][GPU] Add pattern to drop lead unit dims of multi_mma ops (#17456)
Dropping leading unit dims is a useful step that simplifies the process
of lowering to intrinsics by removing the outer iteration space. The
typical lowering flow is to unroll the outer dims to 1 and then apply
this pattern to drop the unit outer dims.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 6b18e5c..c794d2f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -208,6 +208,18 @@
accType.getRank() - getIndexingMapsArray()[2].getNumResults();
return accType.getShape().take_back(accInnerDimRank);
}
+
+ int64_t getLhsOuterRank() {
+ return getIndexingMapsArray()[0].getNumResults();
+ }
+
+ int64_t getRhsOuterRank() {
+ return getIndexingMapsArray()[1].getNumResults();
+ }
+
+ int64_t getAccOuterRank() {
+ return getIndexingMapsArray()[2].getNumResults();
+ }
}];
let hasVerifier = 1;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index 8dcabd1..7f9141c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -26,6 +26,15 @@
}
//===---------------------------------------------------------------------===//
+// ApplyDropMultiMmaOpUnitDims
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyDropMultiMmaOpUnitDims::populatePatterns(
+ RewritePatternSet &patterns) {
+ IREE::GPU::populateIREEGPUDropUnitDimsPatterns(patterns);
+}
+
+//===---------------------------------------------------------------------===//
// ApplyLowerValueBarrierOp
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index dc69083..1a33675 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -14,6 +14,19 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def ApplyDropMultiMmaOpUnitDims : Op<Transform_Dialect,
+ "apply_patterns.iree.drop_multi_mma_unit_dims",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Populate patterns to drop the unit dims from multi_mma ops with
+ only unit iteration bounds.
+ }];
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
"apply_patterns.iree.lower_value_barrier",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index f78f6d0..4b644fa 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -18,6 +18,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "drop_multi_mma_unit_dims.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_forall.mlir",
"vectorize_multi_mma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index f3e2e40..8e3ec6d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "drop_multi_mma_unit_dims.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_forall.mlir"
"unroll_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir
new file mode 100644
index 0000000..9adbf3b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir
@@ -0,0 +1,78 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+func.func @drop_multi_mma_unit_dims(%lhs: vector<1x1x4xf16>, %rhs: vector<1x1x4xf16>, %acc: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+ } : vector<1x1x4xf16>, vector<1x1x4xf16> into vector<1x1x4xf32>
+ return %0 : vector<1x1x4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.drop_multi_mma_unit_dims
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<() -> ()>
+
+// CHECK-LABEL: func @drop_multi_mma_unit_dims
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<1x1x4xf16>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<1x1x4xf16>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<1x1x4xf32>
+// CHECK: %[[LHS_EXT:.+]] = vector.extract %[[LHS]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK: %[[RHS_EXT:.+]] = vector.extract %[[RHS]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+// CHECK: %[[ACC_EXT:.+]] = vector.extract %[[ACC]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_EXT]], %[[RHS_EXT]], %[[ACC_EXT]]
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = []
+// CHECK-SAME: kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} : vector<4xf16>, vector<4xf16> into vector<4xf32>
+// CHECK: vector.broadcast %[[MMA]] : vector<4xf32> to vector<1x1x4xf32>
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> ()>,
+ affine_map<(i) -> (i)>
+]
+func.func @drop_multi_mma_unit_dims_no_kn(%lhs: vector<1x4xf16>, %rhs: vector<4xf16>, %acc: vector<1x4xf32>) -> vector<1x4xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>],
+ kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+ } : vector<1x4xf16>, vector<4xf16> into vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.drop_multi_mma_unit_dims
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<() -> ()>
+
+// CHECK-LABEL: func @drop_multi_mma_unit_dims_no_kn
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<1x4xf16>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<1x4xf32>
+// CHECK: %[[LHS_EXT:.+]] = vector.extract %[[LHS]][0] : vector<4xf16> from vector<1x4xf16>
+// CHECK: %[[ACC_EXT:.+]] = vector.extract %[[ACC]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_EXT]], %[[RHS]], %[[ACC_EXT]]
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = []
+// CHECK-SAME: kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} : vector<4xf16>, vector<4xf16> into vector<4xf32>
+// CHECK: vector.broadcast %[[MMA]] : vector<4xf32> to vector<1x4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 5e982a8..60596a0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -165,6 +165,65 @@
}
//===----------------------------------------------------------------------===//
+// MultiMmaOp Unit Dim Folding
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct DropMultiMmaUnitDimsPattern
+ : public OpRewritePattern<IREE::GPU::MultiMmaOp> {
+ using OpRewritePattern<IREE::GPU::MultiMmaOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp mmaOp,
+ PatternRewriter &rewriter) const override {
+ if (mmaOp.hasTensorSemantics()) {
+ return rewriter.notifyMatchFailure(
+ mmaOp, "unimplemented: unit dim dropping for tensor mma ops");
+ }
+ SmallVector<int64_t> bounds;
+ mmaOp.getIterationBounds(bounds);
+ if (bounds.empty()) {
+ return rewriter.notifyMatchFailure(mmaOp, "no dimensions to fold");
+ }
+
+ // TODO: Generalize to allow only some iteration bounds to be unit. This
+ // pattern currently only supports the most common case of unrolling to the
+ // intrinsic shape.
+ if (!llvm::all_of(bounds, [](int64_t b) { return b == 1; })) {
+ return rewriter.notifyMatchFailure(mmaOp,
+ "not all iteration bounds are unit");
+ }
+
+ Location loc = mmaOp.getLoc();
+ auto dropLeadUnitDims = [&](Value operand, int64_t numDims) -> Value {
+ if (numDims == 0) {
+ return operand;
+ }
+ SmallVector<int64_t> droppedDimIndices(numDims, 0);
+ return rewriter.create<vector::ExtractOp>(loc, operand,
+ droppedDimIndices);
+ };
+
+ Value newLhs = dropLeadUnitDims(mmaOp.getLhs(), mmaOp.getLhsOuterRank());
+ Value newRhs = dropLeadUnitDims(mmaOp.getRhs(), mmaOp.getRhsOuterRank());
+ Value newAcc = dropLeadUnitDims(mmaOp.getAcc(), mmaOp.getAccOuterRank());
+
+ AffineMap empty = AffineMap::get(rewriter.getContext());
+ auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
+ loc, newLhs, newRhs, newAcc,
+ rewriter.getAffineMapArrayAttr({empty, empty, empty}),
+ rewriter.getArrayAttr({}), mmaOp.getKind());
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ mmaOp, mmaOp.getResultType(), newMmaOp);
+ return success();
+ }
+};
+} // namespace
+
+void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns) {
+ patterns.add<DropMultiMmaUnitDimsPattern>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
// MultiMmaOp Unrolling
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 0e2afa3..c706ee1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -34,13 +34,12 @@
scf::ForallOp consumer,
tensor::ExtractSliceOp slice);
+void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns);
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
void populateIREEGPUVectorUnrollPatterns(
RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);
-
void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
-void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
-
} // namespace mlir::iree_compiler::IREE::GPU
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_