[Codegen][GPU] Move conversion to multi_mma to PackToIntrinsics (#18141)

Now that `iree_gpu.multi_mma` has a tiling interface implementation, the
conversion from linalg to it can happen before other levels of tiling.
This allows the inner dimensions to be reshaped freely before reduction
tiling, and the resulting reshapes to be propagated to nearby ops without
needing to hoist them out of tiling constructs.
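
For example (a minimal sketch adapted from the updated
distribute_mma_to_lanes.mlir test below; SSA value names are illustrative),
after PackToIntrinsics the contraction already exists as an
`iree_gpu.multi_mma` at function scope:

    // Contraction over the outer dims; inner 16x16 tiles match the intrinsic.
    %mm = iree_gpu.multi_mma %lhs, %rhs, %acc {
      indexing_maps = [
        affine_map<(i, j, k) -> (i, k)>,
        affine_map<(i, j, k) -> (k, j)>,
        affine_map<(i, j, k) -> (i, j)>
      ],
      iterator_types = [#iree_gpu.iterator_type<parallel>,
                        #iree_gpu.iterator_type<parallel>,
                        #iree_gpu.iterator_type<reduction>],
      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
      rhs_permutation = array<i64: 1, 0>
    } : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16> into tensor<2x2x16x16xf32>

so reshapes of the inner 16x16 dimensions can be introduced and propagated
across neighboring ops at this level, rather than landing inside the scf.for
produced by reduction tiling.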

Additionally, this is closer to the flow required for data tiling, where we
need to generate the `iree_gpu.multi_mma` op before any tiling.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp
index ef6d5cc..577e005 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeMmaToLanes.cpp
@@ -35,26 +35,6 @@
 };
 } // namespace
 
-struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
-                                PatternRewriter &rewriter) const override {
-    auto loweringConfig =
-        getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
-    if (!loweringConfig) {
-      return failure();
-    }
-    IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
-    if (!kind) {
-      return failure();
-    }
-    if (failed(convertContractionToMultiMma(rewriter, linalgOp, kind))) {
-      return failure();
-    }
-    return success();
-  }
-};
-
 LogicalResult fuseProducersGreedily(RewriterBase &rewriter,
                                     scf::ForallOp laneForall) {
 
@@ -100,17 +80,7 @@
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
 
-  // Step 1. Convert configured linalg ops to multi_mma.
-  {
-    RewritePatternSet patterns(context);
-    patterns.add<ConvertToMultiMma>(context);
-    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-      funcOp.emitError() << "failed to convert linalg to multi_mma";
-      return signalPassFailure();
-    }
-  }
-
-  // Step 2. Distribute multi_mma ops to lanes and greedily fuse producers.
+  // Distribute multi_mma ops to lanes and greedily fuse producers.
   SmallVector<IREE::GPU::MultiMmaOp> mmaOps;
   funcOp.walk([&](IREE::GPU::MultiMmaOp mmaOp) { mmaOps.push_back(mmaOp); });
   IRRewriter rewriter(funcOp);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp
index df0ff1c..f79365c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/PackToIntrinsics.cpp
@@ -70,9 +70,31 @@
   return success();
 }
 
+struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    auto loweringConfig =
+        getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
+    if (!loweringConfig) {
+      return failure();
+    }
+    IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
+    if (!kind) {
+      return failure();
+    }
+    if (failed(convertContractionToMultiMma(rewriter, linalgOp, kind))) {
+      return failure();
+    }
+    return success();
+  }
+};
+
 void PackToIntrinsicsPass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
+
+  // Step 1. Pack candidate linalg ops to specified shapes.
   IRRewriter rewriter(funcOp);
   SmallVector<linalg::LinalgOp> packingCandidates;
   funcOp->walk([&](linalg::LinalgOp linalgOp) {
@@ -95,7 +117,18 @@
     }
   }
 
-  // Run layout propagation patterns to pull in adjacent un-configured ops.
+  // Step 2. Convert configured linalg ops to multi_mma.
+  {
+    RewritePatternSet patterns(context);
+    patterns.add<ConvertToMultiMma>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      funcOp.emitError() << "failed to convert linalg to multi_mma";
+      return signalPassFailure();
+    }
+  }
+
+  // Step 3. Run layout propagation patterns to pull in adjacent un-configured
+  // ops.
   RewritePatternSet patterns(context);
   linalg::ControlPropagationFn control = [](OpOperand *opOperand) -> bool {
     Operation *producer = opOperand->get().getDefiningOp();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 6cc7f11..a882b83 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -11,12 +11,11 @@
 
 def DistributeMmaToLanesPass :
     InterfacePass<"iree-gpu-distribute-mma-to-lanes", "mlir::FunctionOpInterface"> {
-  let summary = "Converts and distributes linalg ops with mma kinds to lanes";
+  let summary = "Distributes iree_gpu.multi_mma ops to lanes";
   let dependentDialects = [
     "::mlir::arith::ArithDialect",
     "::mlir::affine::AffineDialect",
     "::mlir::scf::SCFDialect",
-    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
   ];
 }
 
@@ -58,7 +57,7 @@
 
 def PackToIntrinsicsPass :
     InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
-  let summary = "Packs matmul like operations to specified intrinsic shapes";
+  let summary = "Packs matmul like operations and converts to iree_gpu.multi_mma";
   let dependentDialects = [
     "::mlir::tensor::TensorDialect",
     "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 389ffc3..45d8ad1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -442,10 +443,16 @@
     accPerm = accInnerPerm;
   }
 
+  IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig =
+      getLoweringConfig(linalgOp);
+
   auto newMmaOp = rewriter.replaceOpWithNewOp<IREE::GPU::MultiMmaOp>(
       linalgOp, inputs[0], inputs[1], inputs[2],
       ArrayRef<AffineMap>{outerLhsMap, outerRhsMap, outerAccMap}, iteratorTypes,
       mmaKind, lhsPerm, rhsPerm, accPerm);
+  if (maybeLoweringConfig) {
+    setLoweringConfig(newMmaOp, maybeLoweringConfig);
+  }
   return newMmaOp;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
index b98ff42..214b432 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -1,25 +1,20 @@
 // RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-distribute-mma-to-lanes, canonicalize, cse))' --split-input-file | FileCheck %s
 
-#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
 module {
   func.func @matmul_16x16x16(%arg0: tensor<8x2x16x16xf16>, %arg1: tensor<8x2x16x16xf16>, %arg2: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
     %empty = tensor.empty() : tensor<2x8x16x16xf16>
     %lhs_transpose = linalg.transpose ins(%arg0: tensor<8x2x16x16xf16>) outs(%empty: tensor<2x8x16x16xf16>) permutation = [1, 0, 2, 3]
-    %mm = linalg.generic {
-      indexing_maps = [#map, #map1, #map2],
-      iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-      ins(%lhs_transpose, %arg1 : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16>)
-      outs(%arg2 : tensor<2x2x16x16xf32>)
-    attrs =  {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}>} {
-    ^bb0(%in: f16, %in_2: f16, %out: f32):
-      %4 = arith.extf %in : f16 to f32
-      %5 = arith.extf %in_2 : f16 to f32
-      %6 = arith.mulf %4, %5 : f32
-      %7 = arith.addf %out, %6 : f32
-      linalg.yield %7 : f32
-    } -> tensor<2x2x16x16xf32>
+    %mm = iree_gpu.multi_mma %lhs_transpose, %arg1, %arg2 {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+      rhs_permutation = array<i64: 1, 0>
+    } : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16> into tensor<2x2x16x16xf32>
     return %mm : tensor<2x2x16x16xf32>
   }
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir
index 7da25ab..30153f3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/pack_to_intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s
+// RUN: iree-opt %s --mlir-print-local-scope --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s
 
 #config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}>
 module {
@@ -15,10 +15,15 @@
 //   CHECK-DAG:   %[[A_PACK:.+]] = tensor.pack %[[A]] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
 //   CHECK-DAG:   %[[B_PACK:.+]] = tensor.pack %[[B]] inner_dims_pos = [1, 0] inner_tiles = [32, 8]
 //   CHECK-DAG:   %[[C_PACK:.+]] = tensor.pack %[[C]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
-//       CHECK:   %[[PACKED_MM:.+]] = linalg.generic
-//  CHECK-SAME:     ins(%[[A_PACK]], %[[B_PACK]] : tensor<2x8x32x8xf16>, tensor<8x2x32x8xf16>)
-//  CHECK-SAME:     outs(%[[C_PACK]] : tensor<2x2x32x32xf32>)
+//       CHECK:   iree_gpu.multi_mma %[[A_PACK]], %[[B_PACK]], %[[C_PACK]]
+//  CHECK-SAME:     indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2) -> (d0, d2)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2) -> (d2, d1)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2) -> (d0, d1)>
+//  CHECK-SAME:     iterator_types = {{.*}}parallel{{.*}}parallel{{.*}}reduction
+//  CHECK-SAME:     kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
 //  CHECK-SAME:     lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}>
+//  CHECK-SAME:     rhs_permutation = array<i64: 1, 0>
 
 // -----
 
@@ -45,13 +50,11 @@
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d3, d4, d5, d7)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3, d4, d6, d7)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d6)>
-
 // CHECK-LABEL: func.func @matmul_16x16x16
-//       CHECK:   %[[PACKED_MM:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-//  CHECK-SAME:     ins({{.*}} : tensor<?x?x?x16x16xf16>, tensor<?x?x?x?x16x16xf16>)
-//  CHECK-SAME:     outs({{.*}} : tensor<?x?x?x16x16xf32>)
+//       CHECK:   iree_gpu.multi_mma
+//  CHECK-SAME:     indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 //  CHECK-SAME:     lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}>
+//  CHECK-SAME:     : tensor<?x?x?x16x16xf16>, tensor<?x?x?x?x16x16xf16> into tensor<?x?x?x16x16xf32>