[codegen] more consumer fusion (#21521)

Enable consumer fusion to fuse more trailing operations, in particular
all operations used in layer normalization. Specifically, consumers of
any fused operation can now be fused, not only consumers of the "main"
operation into which fusion happens.

This required three modifications:

1. Generate replacement values for any fused operation whose results are
used by operations dominated by the main operation (the one that turns
into a loop). This is needed because consumer fusion finds candidates
based on `insert_slice` operations in the loop, and those are not
produced unless the replacement values are generated. In the case of
layer normalization, we also need the actual replacement values so we
can return them. Spurious values are cleaned up by canonicalization
patterns in the same pass.
2. Avoid generating replacement values for values used as DPS inits,
since such uses cannot (currently) be fused.
3. Sort fusion candidates by dominance and fuse the most-dominating
consumer first, to avoid triangular-use situations that prevent fusion
(see the sketch after this list).
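
To illustrate point 3, here is a minimal standalone sketch of the
candidate ordering. The names `Entry`, `properlyDominates`, and
`addCandidate` are stand-ins and MLIR types are stubbed out: an op is
modeled as its program position, and in straight-line code an earlier
op properly dominates every later op.

    #include <algorithm>
    #include <cassert>
    #include <utility>
    #include <vector>

    // Stand-in for a fusion candidate; `fusableUser` models Operation * as
    // a program position in straight-line code.
    struct Entry {
      int fusableUser;
    };

    // Stub for DominanceInfo::properlyDominates: earlier dominates later.
    static bool properlyDominates(int a, int b) { return a < b; }

    // Keep the vector reverse-sorted by dominance: the most-dominating user
    // sits at the back, where popping is cheap. If the user is already a
    // candidate, replace its entry (the real code refreshes its slice list)
    // instead of inserting a duplicate.
    static void addCandidate(std::vector<Entry> &candidates, Entry entry) {
      auto comp = [](const Entry &lhs, const Entry &rhs) {
        return properlyDominates(rhs.fusableUser, lhs.fusableUser);
      };
      auto it =
          std::lower_bound(candidates.begin(), candidates.end(), entry, comp);
      if (it != candidates.end() && it->fusableUser == entry.fusableUser)
        *it = std::move(entry);
      else
        candidates.insert(it, std::move(entry));
    }

    int main() {
      // Triangular shape from above: B (position 1) feeds D (2) and C (3),
      // and D also feeds C. Discovery order does not matter.
      std::vector<Entry> candidates;
      addCandidate(candidates, {/*C*/ 3});
      addCandidate(candidates, {/*B*/ 1});
      addCandidate(candidates, {/*D*/ 2});
      // Fusion pops from the back: B first, then D, then C.
      assert(candidates.back().fusableUser == 1);
      return 0;
    }

A sorted vector is used rather than an immutable priority queue because
fusing one candidate can add slices to another candidate's entry, which
then has to be updated in place.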

This generally makes fusion more aggressive on the consumer side.
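
For callers, `fuseConsumersIntoForall` now takes the full list of tiled
ops instead of a single root op. For reference, this restates the
updated call site from TileDispatchUsingForall.cpp in this diff:

    FailureOr<std::queue<Operation *>> newFusionOpportunities =
        fuseConsumersIntoForall(
            rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
            tilingLoops, [&tiledAndFusedOps](Operation *op) {
              return tiledAndFusedOps.contains(op);
            });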

Signed-off-by: Alex Zinenko <git@ozinenko.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
index 21a269a..64de60c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
@@ -9,6 +9,7 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
@@ -87,16 +88,51 @@
   }
 }
 
+namespace {
+// Entry for the pseudo-priority queue of consumer fusion candidates. Contains
+// the consumer (fusableUser) that can be fused and the slice operations
+// inside the loop that feed that consumer.
+struct ConsumerFusionQueueEntry {
+  ConsumerFusionQueueEntry(SmallVector<Operation *> &&slices,
+                           Operation *fusableUser)
+      : slices(std::move(slices)), fusableUser(fusableUser) {}
+
+  SmallVector<Operation *> slices;
+  Operation *fusableUser;
+};
+} // namespace
+
 FailureOr<std::queue<Operation *>>
-fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
+fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
                         MutableArrayRef<LoopLikeOpInterface> loops,
                         std::function<bool(Operation *)> filterFn) {
   // Collect the candidate slices which can be potential consumers that can be
-  // fused.
-  std::queue<SmallVector<Operation *>> candidates;
+  // fused. Keep them in a vector reverse-sorted by dominance: the candidate
+  // dominating others comes last (so it can be cheaply popped from the vector).
+  // The most-dominating candidate is fused first since not fusing it may
+  // prevent dominated candidates from being fused:
+  //
+  //     A
+  //     |
+  //     B
+  //   / |
+  //  |  D
+  //  | /
+  //  C
+  //
+  // here, B must be fused before both C and D, and D must be fused before C.
+  // Candidates are kept in a vector rather than a priority queue since they
+  // may need updating as fusion happens; in particular, more slices may need
+  // to be handled. For example, fusing B with A creates a slice of B that
+  // must then be handled as well.
+  SmallVector<ConsumerFusionQueueEntry> candidates;
   llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
   auto addCandidateSlices = [&candidates, &allCandidates,
                              &filterFn](Operation *fusedOp) {
+    // Dominance info is recreated on each call since op creation/movement in
+    // the fusion logic invalidates it anyway.
+    DominanceInfo dominanceInfo;
+
     for (auto *userOp : fusedOp->getResults().getUsers()) {
       auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
       if (!sliceOp || allCandidates.contains(sliceOp)) {
@@ -135,22 +171,38 @@
         allCandidates.insert_range(slices);
       }
       if (!fusedSlices.empty()) {
-        candidates.emplace(std::move(fusedSlices));
+        ConsumerFusionQueueEntry entry(std::move(fusedSlices), fusableUser);
+
+        // Comparator that puts the dominating user last.
+        auto comp = [&](const ConsumerFusionQueueEntry &lhs,
+                        const ConsumerFusionQueueEntry &rhs) {
+          return dominanceInfo.properlyDominates(rhs.fusableUser,
+                                                 lhs.fusableUser);
+        };
+
+        // If the fusable user is already a candidate, update it with the new
+        // list of slices to handle. Otherwise, insert it into the right
+        // position based on dominance.
+        auto *it = llvm::lower_bound(candidates, entry, comp);
+        if (it != candidates.end() && it->fusableUser == fusableUser)
+          *it = std::move(entry);
+        else
+          candidates.insert(it, std::move(entry));
       }
     }
   };
 
-  addCandidateSlices(tiledOp);
+  // Add slices from all tiled ops, not only the "main" one.
+  for (Operation *tiledOp : tiledOps)
+    addCandidateSlices(tiledOp);
 
   std::queue<Operation *> newFusionOpportunities;
   while (!candidates.empty()) {
-    // Traverse the slices in BFS fashion.
-    SmallVector<Operation *> candidateSlices = candidates.front();
-    candidates.pop();
+    // Get the next candidate.
+    ConsumerFusionQueueEntry entry = candidates.pop_back_val();
 
     FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
-        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
-                                               loops);
+        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, entry.slices, loops);
     if (failed(fusedResult)) {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
index 690df25..303f831 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
@@ -31,12 +31,11 @@
 void collectTiledAndFusedOps(Operation *rootOp,
                              llvm::SmallDenseSet<Operation *> &result);
 
-/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
-/// Returns a list of new `tensor.extract_slice` ops with new fusion
-/// opportunities, as well as the new surrounding `scf.forall` (because consumer
-/// fusion replaces the loop).
+/// Fuse all consumers of the given `tiledOps` into the surrounding `scf.forall`
+/// unless specified otherwise by `filterFn`. Returns a list of new
+/// `tensor.extract_slice` ops with new fusion opportunities.
 FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
-    RewriterBase &rewriter, Operation *tiledOp,
+    RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
     MutableArrayRef<LoopLikeOpInterface> loops,
     std::function<bool(Operation *)> filterFn = [](Operation *) {
       return true;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 9650f73..978cae4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -210,6 +210,18 @@
 // Pass implementation.
 //===---------------------------------------------------------------------===//
 
+/// Returns true if any value produced by `producer` is used as an init value
+/// for the DPS `user`. Returns false if the user is not a DPS op.
+static bool isUsedAsInit(Operation *producer, Operation *user) {
+  auto dpsIface = dyn_cast<DestinationStyleOpInterface>(user);
+  if (!dpsIface)
+    return false;
+  ValueRange results = producer->getResults();
+  return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) {
+    return llvm::is_contained(results, operand);
+  });
+}
+
 void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   auto funcOp = getOperation();
   auto *context = &getContext();
@@ -232,10 +244,14 @@
 
   llvm::DenseSet<Operation *> yieldReplacementsFor;
   for (auto op : tiledAndFusedOps) {
-    // If tiledAndFused ops doesn't contain the user; add an replacement
-    // for that.
+    // Require replacements for values that are used after the main tilable op
+    // or by ops that will definitely not be fused. Note that a value used as
+    // a DPS init cannot (currently) be fused into that user, and yielding a
+    // replacement for it would trigger a failing fusion attempt, so skip it.
     if (llvm::any_of(op->getUsers(), [&](Operation *user) {
-          return dominanceInfo.properlyDominates(tilableOp, user) &&
+          if (isUsedAsInit(op, user))
+            return false;
+          return dominanceInfo.properlyDominates(tilableOp, user) ||
                  !tiledAndFusedOps.contains(user);
         })) {
       yieldReplacementsFor.insert(op);
@@ -317,16 +333,18 @@
       });
     }
     std::swap(tileAndFuseResult->loops, tilingLoops);
-    Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
+
     FailureOr<std::queue<Operation *>> newFusionOpportunities =
-        fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
-                                [&tiledAndFusedOps](Operation *op) {
-                                  return tiledAndFusedOps.contains(op);
-                                });
+        fuseConsumersIntoForall(
+            rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
+            tilingLoops, [&tiledAndFusedOps](Operation *op) {
+              return tiledAndFusedOps.contains(op);
+            });
     if (failed(newFusionOpportunities)) {
       // Continue the work if the failure is allowed.
       if (!verifyComputeOpsAfterDistribution(funcOp)) {
-        rootTiledOp->emitOpError("failed to fuse consumers");
+        tileAndFuseResult->tiledAndFusedOps.front()->emitOpError(
+            "failed to fuse consumers");
         return signalPassFailure();
       }
     } else {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index e4503ff..5ddb2d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -1083,3 +1083,107 @@
 //       CHECK:     linalg.generic
 //       CHECK:   scf.forall.in_parallel {
 //       CHECK:   linalg.pack
+
+// -----
+
+// Adapted from layer normalization. The graph structure is as follows
+//
+//              %14
+//           /   |   \
+//         /    %15   %17
+//        |      |   / |
+//        |    [%19]   |
+//       %21     |    %22
+//        |      |     |
+//        v      v     v
+//
+// In particular, %21 and %22 are not users of the "main" tilable
+// operation but we still want them to be fused. %19, %21 and %22
+// all produce results returned from the function.
+//
+// Check that everything is fused and that there are three results
+// from the loop being produced and returned.
+//
+// CHECK-LABEL: @multi_result_consumer_fusion
+//   CHECK-NOT: linalg.generic
+//       CHECK: %[[LOOP:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (16, 256) shared_outs(%[[OUT0:.+]] = %{{.+}}, %[[OUT1:.+]] = %{{.+}}, %[[OUT2:.+]] = %{{.+}})
+//       CHECK: %[[v14:.+]] = linalg.generic
+//       CHECK:     arith.divf
+//       CHECK:   %[[v15:.+]] = linalg.generic
+//       CHECK:     arith.subf
+//       CHECK:   %[[v17:.+]] = linalg.generic
+//       CHECK:     arith.divf
+//       CHECK:     math.rsqrt
+//       CHECK:   %[[RES0:.+]] = linalg.generic
+//       CHECK:     arith.mulf
+//       CHECK:     arith.extf
+//       CHECK:     arith.mulf
+//       CHECK:     arith.extf
+//       CHECK:     arith.addf
+//       CHECK:     arith.truncf
+//       CHECK:   %[[RES1:.+]] = linalg.generic {{.*}} ins(%[[v14]] :
+//       CHECK:     arith.truncf
+//       CHECK:   %[[RES2:.+]] = linalg.generic {{.*}} ins(%[[v17]] :
+//       CHECK:     arith.truncf
+//       CHECK:   scf.forall.in_parallel
+//       CHECK:     tensor.parallel_insert_slice %[[RES0]] into %[[OUT0]]
+//       CHECK:     tensor.parallel_insert_slice %[[RES1]] into %[[OUT1]]
+//       CHECK:     tensor.parallel_insert_slice %[[RES2]] into %[[OUT2]]
+//   CHECK-NOT: linalg.generic
+//       CHECK: return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2
+func.func @multi_result_consumer_fusion(
+  %6: tensor<16x256x2048xbf16>,
+  %7: tensor<2048xbf16>,
+  %8: tensor<2048xbf16>,
+  %10: tensor<16x256x2048xf32>,
+  %13: tensor<16x256xf32>
+) -> (
+  tensor<16x256x2048xbf16>,
+  tensor<16x256xbf16>,
+  tensor<16x256xbf16>
+) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 2.048000e+03 : f32
+  %c0 = arith.constant 0 : index
+  %9 = tensor.empty() : tensor<16x256x2048xf32>
+  %11 = tensor.empty() : tensor<16x256xf32>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %23 = arith.divf %in, %cst_0 : f32
+    linalg.yield %23 : f32
+  } -> tensor<16x256xf32>
+  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %14 : tensor<16x256x2048xf32>, tensor<16x256xf32>) outs(%9 : tensor<16x256x2048xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %23 = arith.subf %in, %in_1 : f32
+    linalg.yield %23 : f32
+  } -> tensor<16x256x2048xf32>
+  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %23 = arith.divf %in, %cst_0 : f32
+    %24 = math.rsqrt %23 : f32
+    linalg.yield %24 : f32
+  } -> tensor<16x256xf32>
+  %18 = tensor.empty() : tensor<16x256x2048xbf16>
+  %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %17, %7, %8 : tensor<16x256x2048xf32>, tensor<16x256xf32>, tensor<2048xbf16>, tensor<2048xbf16>) outs(%18 : tensor<16x256x2048xbf16>) attrs =  {lowering_config = #iree_gpu.lowering_config<{lane_basis = [[1, 1, 64], [0, 1, 2]], reduction = [0, 0, 256], subgroup_basis = [[1, 1, 1], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}>} {
+  ^bb0(%in: f32, %in_1: f32, %in_2: bf16, %in_3: bf16, %out: bf16):
+    %23 = arith.mulf %in, %in_1 : f32
+    %24 = arith.extf %in_2 : bf16 to f32
+    %25 = arith.mulf %23, %24 : f32
+    %26 = arith.extf %in_3 : bf16 to f32
+    %27 = arith.addf %25, %26 : f32
+    %28 = arith.truncf %27 : f32 to bf16
+    linalg.yield %28 : bf16
+  } -> tensor<16x256x2048xbf16>
+  %20 = tensor.empty() : tensor<16x256xbf16>
+  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %23 = arith.truncf %in : f32 to bf16
+    linalg.yield %23 : bf16
+  } -> tensor<16x256xbf16>
+  %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %23 = arith.truncf %in : f32 to bf16
+    linalg.yield %23 : bf16
+  } -> tensor<16x256xbf16>
+  return %19, %21, %22 : tensor<16x256x2048xbf16>, tensor<16x256xbf16>, tensor<16x256xbf16>
+}