[Codegen][GPU] Add pattern to drop lead unit dims of multi_mma ops (#17456)

Dropping leading unit dims is a useful step that simplifies the process
of lowering to intrinsics by removing the outer iteration space. The
typical lowering flow is to unroll the outer dims to 1 and then apply
this pattern to drop the unit outer dims.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 6b18e5c..c794d2f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -208,6 +208,18 @@
         accType.getRank() - getIndexingMapsArray()[2].getNumResults();
       return accType.getShape().take_back(accInnerDimRank);
     }
+
+    int64_t getLhsOuterRank() { // Outer rank of LHS: result count of map 0.
+      return getIndexingMapsArray()[0].getNumResults();
+    }
+
+    int64_t getRhsOuterRank() { // Outer rank of RHS: result count of map 1.
+      return getIndexingMapsArray()[1].getNumResults();
+    }
+
+    int64_t getAccOuterRank() { // Outer rank of ACC: result count of map 2.
+      return getIndexingMapsArray()[2].getNumResults();
+    }
   }];
 
   let hasVerifier = 1;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index 8dcabd1..7f9141c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -26,6 +26,15 @@
 }
 
 //===---------------------------------------------------------------------===//
+// ApplyDropMultiMmaOpUnitDims
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyDropMultiMmaOpUnitDims::populatePatterns(
+    RewritePatternSet &patterns) { // Delegates to the IREE::GPU populate fn.
+  IREE::GPU::populateIREEGPUDropUnitDimsPatterns(patterns);
+}
+
+//===---------------------------------------------------------------------===//
 // ApplyLowerValueBarrierOp
 //===---------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index dc69083..1a33675 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -14,6 +14,19 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyDropMultiMmaOpUnitDims : Op<Transform_Dialect,
+    "apply_patterns.iree.drop_multi_mma_unit_dims",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Populate patterns to drop the unit dims from multi_mma ops with
+    only unit iteration bounds.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict"; // Printed form carries only the attr dict.
+}
+
 def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
     "apply_patterns.iree.lower_value_barrier",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index f78f6d0..4b644fa 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -18,6 +18,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "drop_multi_mma_unit_dims.mlir",
             "lower_vector_barrier.mlir",
             "transform_fuse_forall.mlir",
             "vectorize_multi_mma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index f3e2e40..8e3ec6d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "drop_multi_mma_unit_dims.mlir"
     "lower_vector_barrier.mlir"
     "transform_fuse_forall.mlir"
     "unroll_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir
new file mode 100644
index 0000000..9adbf3b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/drop_multi_mma_unit_dims.mlir
@@ -0,0 +1,78 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>, // lhs
+ affine_map<(i, j, k) -> (k, j)>, // rhs
+ affine_map<(i, j, k) -> (i, j)> // acc
+]
+func.func @drop_multi_mma_unit_dims(%lhs: vector<1x1x4xf16>, %rhs: vector<1x1x4xf16>, %acc: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+  } : vector<1x1x4xf16>, vector<1x1x4xf16> into vector<1x1x4xf32>
+  return %0 : vector<1x1x4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func { // Scope: the matched func.func ops.
+      transform.apply_patterns.iree.drop_multi_mma_unit_dims // Under test.
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP:.+]] =  affine_map<() -> ()>
+
+// CHECK-LABEL: func @drop_multi_mma_unit_dims
+//  CHECK-SAME:   %[[LHS:[A-Za-z0-9]+]]: vector<1x1x4xf16>
+//  CHECK-SAME:   %[[RHS:[A-Za-z0-9]+]]: vector<1x1x4xf16>
+//  CHECK-SAME:   %[[ACC:[A-Za-z0-9]+]]: vector<1x1x4xf32>
+//       CHECK:   %[[LHS_EXT:.+]] = vector.extract %[[LHS]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+//       CHECK:   %[[RHS_EXT:.+]] = vector.extract %[[RHS]][0, 0] : vector<4xf16> from vector<1x1x4xf16>
+//       CHECK:   %[[ACC_EXT:.+]] = vector.extract %[[ACC]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_EXT]], %[[RHS_EXT]], %[[ACC_EXT]]
+//  CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = []
+//  CHECK-SAME:     kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} : vector<4xf16>, vector<4xf16> into vector<4xf32>
+//       CHECK:   vector.broadcast %[[MMA]] : vector<4xf32> to vector<1x1x4xf32>
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i) -> (i)>, // lhs
+ affine_map<(i) -> ()>, // rhs: no outer dims, so nothing to extract
+ affine_map<(i) -> (i)> // acc
+]
+func.func @drop_multi_mma_unit_dims_no_kn(%lhs: vector<1x4xf16>, %rhs: vector<4xf16>, %acc: vector<1x4xf32>) -> vector<1x4xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>],
+    kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+  } : vector<1x4xf16>, vector<4xf16> into vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func { // Scope: the matched func.func ops.
+      transform.apply_patterns.iree.drop_multi_mma_unit_dims // Under test.
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP:.+]] =  affine_map<() -> ()>
+
+// CHECK-LABEL: func @drop_multi_mma_unit_dims_no_kn
+//  CHECK-SAME:   %[[LHS:[A-Za-z0-9]+]]: vector<1x4xf16>
+//  CHECK-SAME:   %[[RHS:[A-Za-z0-9]+]]: vector<4xf16>
+//  CHECK-SAME:   %[[ACC:[A-Za-z0-9]+]]: vector<1x4xf32>
+//       CHECK:   %[[LHS_EXT:.+]] = vector.extract %[[LHS]][0] : vector<4xf16> from vector<1x4xf16>
+//       CHECK:   %[[ACC_EXT:.+]] = vector.extract %[[ACC]][0] : vector<4xf32> from vector<1x4xf32>
+//       CHECK:   %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_EXT]], %[[RHS]], %[[ACC_EXT]]
+//  CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = []
+//  CHECK-SAME:     kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} : vector<4xf16>, vector<4xf16> into vector<4xf32>
+//       CHECK:   vector.broadcast %[[MMA]] : vector<4xf32> to vector<1x4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 5e982a8..60596a0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -165,6 +165,65 @@
 }
 
 //===----------------------------------------------------------------------===//
+// MultiMmaOp Unit Dim Folding
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct DropMultiMmaUnitDimsPattern // Drops all-unit outer dims of multi_mma.
+    : public OpRewritePattern<IREE::GPU::MultiMmaOp> {
+  using OpRewritePattern<IREE::GPU::MultiMmaOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp mmaOp,
+                                PatternRewriter &rewriter) const override {
+    if (mmaOp.hasTensorSemantics()) { // Tensor form is not supported yet.
+      return rewriter.notifyMatchFailure(
+          mmaOp, "unimplemented: unit dim dropping for tensor mma ops");
+    }
+    SmallVector<int64_t> bounds;
+    mmaOp.getIterationBounds(bounds);
+    if (bounds.empty()) { // No iteration dims remain, so nothing to drop.
+      return rewriter.notifyMatchFailure(mmaOp, "no dimensions to fold");
+    }
+
+    // TODO: Generalize to allow only some iteration bounds to be unit. This
+    // pattern currently only supports the most common case of unrolling to the
+    // intrinsic shape.
+    if (!llvm::all_of(bounds, [](int64_t b) { return b == 1; })) {
+      return rewriter.notifyMatchFailure(mmaOp,
+                                         "not all iteration bounds are unit");
+    }
+
+    Location loc = mmaOp.getLoc();
+    auto dropLeadUnitDims = [&](Value operand, int64_t numDims) -> Value {
+      if (numDims == 0) { // Nothing to strip; reuse the operand as is.
+        return operand;
+      }
+      SmallVector<int64_t> droppedDimIndices(numDims, 0); // All positions 0.
+      return rewriter.create<vector::ExtractOp>(loc, operand,
+                                                droppedDimIndices);
+    };
+
+    Value newLhs = dropLeadUnitDims(mmaOp.getLhs(), mmaOp.getLhsOuterRank());
+    Value newRhs = dropLeadUnitDims(mmaOp.getRhs(), mmaOp.getRhsOuterRank());
+    Value newAcc = dropLeadUnitDims(mmaOp.getAcc(), mmaOp.getAccOuterRank());
+
+    AffineMap empty = AffineMap::get(rewriter.getContext()); // () -> ()
+    auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
+        loc, newLhs, newRhs, newAcc,
+        rewriter.getAffineMapArrayAttr({empty, empty, empty}),
+        rewriter.getArrayAttr({}), mmaOp.getKind()); // No iterator types.
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        mmaOp, mmaOp.getResultType(), newMmaOp); // Restore the unit dims.
+    return success();
+  }
+};
+} // namespace
+/// Adds DropMultiMmaUnitDimsPattern to `patterns`, using its MLIRContext.
+void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns) {
+  patterns.add<DropMultiMmaUnitDimsPattern>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
 // MultiMmaOp Unrolling
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 0e2afa3..c706ee1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -34,13 +34,12 @@
                                   scf::ForallOp consumer,
                                   tensor::ExtractSliceOp slice);
 
+void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns);
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
 void populateIREEGPUVectorUnrollPatterns(
     RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);
-
 void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
 
-void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
-
 } // namespace mlir::iree_compiler::IREE::GPU
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_