Revert "[codegen] more consumer fusion (#21521)" (#21819)

This reverts commit 4d91ffb09a8e48844b6460edce58ccba245bced7.

The reverted commit causes a compilation failure for the llama 405B fp4 model.
Tracking issue: https://github.com/iree-org/iree/issues/21814

Signed-off-by: Praveen G <praveen.g2@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
index 64de60c..21a269a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
@@ -9,7 +9,6 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
@@ -88,51 +87,16 @@
   }
 }
 
-namespace {
-// Entry for the pseudo-priority queue of consumer fusion candidates. Contains
-// the consumer (fusableUser) that can be fused and the set of slice operations
-// in the loop to fuse into that feed the consumer.
-struct ConsumerFusionQueueEntry {
-  ConsumerFusionQueueEntry(SmallVector<Operation *> &&slices,
-                           Operation *fusableUser)
-      : slices(std::move(slices)), fusableUser(fusableUser) {}
-
-  SmallVector<Operation *> slices;
-  Operation *fusableUser;
-};
-} // namespace
-
 FailureOr<std::queue<Operation *>>
-fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
                         MutableArrayRef<LoopLikeOpInterface> loops,
                         std::function<bool(Operation *)> filterFn) {
   // Collect the candidate slices which can be potential consumers that can be
-  // fused. Keep them in a vector reverse-sorted by dominance: the candidate
-  // dominating others comes last (so it can be cheaply popped from the vector).
-  // The most-dominating candidate is to be fused first since not fusing it may
-  // prevent dominated candidates to be fused:
-  //
-  //     A
-  //     |
-  //     B
-  //   / |
-  //  |  D
-  //  | /
-  //  C
-  //
-  // here, B must be fused before both C and D, and D must be fused before C.
-  // Candidates are kept in a vector rather than a priority queue since we may
-  // update them as fusion happens, in particular, more slices may need to be
-  // handled. For example, fusing B with A will create a slice of B that will
-  // need to be handled correctly.
-  SmallVector<ConsumerFusionQueueEntry> candidates;
+  // fused.
+  std::queue<SmallVector<Operation *>> candidates;
   llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
   auto addCandidateSlices = [&candidates, &allCandidates,
                              &filterFn](Operation *fusedOp) {
-    // Dominance info recreated since op creation/movement in the fusion logic
-    // invalidates it anyway.
-    DominanceInfo dominanceInfo;
-
     for (auto *userOp : fusedOp->getResults().getUsers()) {
       auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
       if (!sliceOp || allCandidates.contains(sliceOp)) {
@@ -171,38 +135,22 @@
         allCandidates.insert_range(slices);
       }
       if (!fusedSlices.empty()) {
-        ConsumerFusionQueueEntry entry(std::move(fusedSlices), fusableUser);
-
-        // Comparator that puts the dominating user last.
-        auto comp = [&](const ConsumerFusionQueueEntry &lhs,
-                        const ConsumerFusionQueueEntry &rhs) {
-          return dominanceInfo.properlyDominates(rhs.fusableUser,
-                                                 lhs.fusableUser);
-        };
-
-        // If the fusable user is already a candidate, update it with the new
-        // list of slices to handle. Otherwise, insert it into the right
-        // position based on dominance.
-        auto *it = llvm::lower_bound(candidates, entry, comp);
-        if (it != candidates.end() && it->fusableUser == fusableUser)
-          *it = std::move(entry);
-        else
-          candidates.insert(it, std::move(entry));
+        candidates.emplace(std::move(fusedSlices));
       }
     }
   };
 
-  // Add slices from all tiled ops, not only the "main" one.
-  for (Operation *tiledOp : tiledOps)
-    addCandidateSlices(tiledOp);
+  addCandidateSlices(tiledOp);
 
   std::queue<Operation *> newFusionOpportunities;
   while (!candidates.empty()) {
-    // Get the next candidate.
-    ConsumerFusionQueueEntry entry = candidates.pop_back_val();
+    // Traverse the slices in BFS fashion.
+    SmallVector<Operation *> candidateSlices = candidates.front();
+    candidates.pop();
 
     FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
-        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, entry.slices, loops);
+        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
+                                               loops);
     if (failed(fusedResult)) {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
index 303f831..690df25 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
@@ -31,11 +31,12 @@
 void collectTiledAndFusedOps(Operation *rootOp,
                              llvm::SmallDenseSet<Operation *> &result);
 
-/// Fuse all consumers of the given `tiledOps` into the surrounding `scf.forall`
-/// unless specified otherwise by `filterFn`. Returns a list of new
-/// `tensor.extract_slice` ops with new fusion opportunities.
+/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
+/// Returns a list of new `tensor.extract_slice` ops with new fusion
+/// opportunities, as well as the new surrounding `scf.forall` (because consumer
+/// fusion replaces the loop).
 FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
-    RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+    RewriterBase &rewriter, Operation *tiledOp,
     MutableArrayRef<LoopLikeOpInterface> loops,
     std::function<bool(Operation *)> filterFn = [](Operation *) {
       return true;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 978cae4..9650f73 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -210,18 +210,6 @@
 // Pass implementation.
 //===---------------------------------------------------------------------===//
 
-/// Returns true if any value produced by `producer` is used as an init value
-/// for the DPS `user`. Returns false if the user is not in DPS.
-static bool isUsedAsInit(Operation *producer, Operation *user) {
-  auto dpsIface = dyn_cast<DestinationStyleOpInterface>(user);
-  if (!dpsIface)
-    return false;
-  ValueRange results = producer->getResults();
-  return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) {
-    return llvm::is_contained(results, operand);
-  });
-}
-
 void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   auto funcOp = getOperation();
   auto *context = &getContext();
@@ -244,14 +232,10 @@
 
   llvm::DenseSet<Operation *> yieldReplacementsFor;
   for (auto op : tiledAndFusedOps) {
-    // Require replacement for values that are used after the main tilable op or
-    // by ops that will definitely not be fused. Note that if a value is used as
-    // an init of a DPS op, the user currently cannot be fused. Having a
-    // replacement for it would attempt fusion and fail, so avoid such cases.
+    // If tiledAndFused ops doesn't contain the user; add an replacement
+    // for that.
     if (llvm::any_of(op->getUsers(), [&](Operation *user) {
-          if (isUsedAsInit(op, user))
-            return false;
-          return dominanceInfo.properlyDominates(tilableOp, user) ||
+          return dominanceInfo.properlyDominates(tilableOp, user) &&
                  !tiledAndFusedOps.contains(user);
         })) {
       yieldReplacementsFor.insert(op);
@@ -333,18 +317,16 @@
       });
     }
     std::swap(tileAndFuseResult->loops, tilingLoops);
-
+    Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
     FailureOr<std::queue<Operation *>> newFusionOpportunities =
-        fuseConsumersIntoForall(
-            rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
-            tilingLoops, [&tiledAndFusedOps](Operation *op) {
-              return tiledAndFusedOps.contains(op);
-            });
+        fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
+                                [&tiledAndFusedOps](Operation *op) {
+                                  return tiledAndFusedOps.contains(op);
+                                });
     if (failed(newFusionOpportunities)) {
       // Continue the work if the failure is allowed.
       if (!verifyComputeOpsAfterDistribution(funcOp)) {
-        tileAndFuseResult->tiledAndFusedOps.front()->emitOpError(
-            "failed to fuse consumers");
+        rootTiledOp->emitOpError("failed to fuse consumers");
         return signalPassFailure();
       }
     } else {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index 5ddb2d4..e4503ff 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -1083,107 +1083,3 @@
 //       CHECK:     linalg.generic
 //       CHECK:   scf.forall.in_parallel {
 //       CHECK:   linalg.pack
-
-// -----
-
-// Adapted from layer normalization. The graph structure is as follows
-//
-//              %14
-//           /   |   \
-//         /    %15   %17
-//        |      |   / |
-//        |    [%19]   |
-//       %21     |    %22
-//        |      |     |
-//        v      v     v
-//
-// In particular, %21 and %22 are not users of the "main" tilable
-// operation but we still want them to be fused. %19, %21 and %22
-// all produce results returned from the function.
-//
-// Check that everything is fused and that there are three results
-// from the loop being produced and returned.
-//
-// CHECK-LABEL: @multi_result_consumer_fusion
-//   CHECK-NOT: linalg.generic
-//       CHECK: %[[LOOP:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (16, 256) shared_outs(%[[OUT0:.+]] = %{{.+}}, %[[OUT1:.+]] = %{{.+}}, %[[OUT2:.+]] = %{{.+}})
-//       CHECK: %[[v14:.+]] = linalg.generic
-//       CHECK:     arith.divf
-//       CHECK:   %[[v15:.+]] = linalg.generic
-//       CHECK:     arith.subf
-//       CHECK:   %[[v17:.+]] = linalg.generic
-//       CHECK:     arith.divf
-//       CHECK:     math.rsqrt
-//       CHECK:   %[[RES0:.+]] = linalg.generic
-//       CHECK:     arith.mulf
-//       CHECK:     arith.extf
-//       CHECK:     arith.mulf
-//       CHECK:     arith.extf
-//       CHECK:     arith.addf
-//       CHECK:     arith.truncf
-//       CHECK:   %[[RES1:.+]] = linalg.generic {{.*}} ins(%[[v14]] :
-//       CHECK:     arith.truncf
-//       CHECK:   %[[RES2:.+]] = linalg.generic {{.*}} ins(%[[v17]] :
-//       CHECK:     arith.truncf
-//       CHECK:   scf.forall.in_parallel
-//       CHECK:     tensor.parallel_insert_slice %[[RES0]] into %[[OUT0]]
-//       CHECK:     tensor.parallel_insert_slice %[[RES1]] into %[[OUT1]]
-//       CHECK:     tensor.parallel_insert_slice %[[RES2]] into %[[OUT2]]
-//   CHECK-NOT: linalg.generic
-//       CHECK: return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2
-func.func @multi_result_consumer_fusion(
-  %6: tensor<16x256x2048xbf16>,
-  %7: tensor<2048xbf16>,
-  %8: tensor<2048xbf16>,
-  %10: tensor<16x256x2048xf32>,
-  %13: tensor<16x256xf32>
-) -> (
-  tensor<16x256x2048xbf16>,
-  tensor<16x256xbf16>,
-  tensor<16x256xbf16>
-) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %cst_0 = arith.constant 2.048000e+03 : f32
-  %c0 = arith.constant 0 : index
-  %9 = tensor.empty() : tensor<16x256x2048xf32>
-  %11 = tensor.empty() : tensor<16x256xf32>
-  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %23 = arith.divf %in, %cst_0 : f32
-    linalg.yield %23 : f32
-  } -> tensor<16x256xf32>
-  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %14 : tensor<16x256x2048xf32>, tensor<16x256xf32>) outs(%9 : tensor<16x256x2048xf32>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f32):
-    %23 = arith.subf %in, %in_1 : f32
-    linalg.yield %23 : f32
-  } -> tensor<16x256x2048xf32>
-  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %23 = arith.divf %in, %cst_0 : f32
-    %24 = math.rsqrt %23 : f32
-    linalg.yield %24 : f32
-  } -> tensor<16x256xf32>
-  %18 = tensor.empty() : tensor<16x256x2048xbf16>
-  %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %17, %7, %8 : tensor<16x256x2048xf32>, tensor<16x256xf32>, tensor<2048xbf16>, tensor<2048xbf16>) outs(%18 : tensor<16x256x2048xbf16>) attrs =  {lowering_config = #iree_gpu.lowering_config<{lane_basis = [[1, 1, 64], [0, 1, 2]], reduction = [0, 0, 256], subgroup_basis = [[1, 1, 1], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}>} {
-  ^bb0(%in: f32, %in_1: f32, %in_2: bf16, %in_3: bf16, %out: bf16):
-    %23 = arith.mulf %in, %in_1 : f32
-    %24 = arith.extf %in_2 : bf16 to f32
-    %25 = arith.mulf %23, %24 : f32
-    %26 = arith.extf %in_3 : bf16 to f32
-    %27 = arith.addf %25, %26 : f32
-    %28 = arith.truncf %27 : f32 to bf16
-    linalg.yield %28 : bf16
-  } -> tensor<16x256x2048xbf16>
-  %20 = tensor.empty() : tensor<16x256xbf16>
-  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
-  ^bb0(%in: f32, %out: bf16):
-    %23 = arith.truncf %in : f32 to bf16
-    linalg.yield %23 : bf16
-  } -> tensor<16x256xbf16>
-  %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
-  ^bb0(%in: f32, %out: bf16):
-    %23 = arith.truncf %in : f32 to bf16
-    linalg.yield %23 : bf16
-  } -> tensor<16x256xbf16>
-  return %19, %21, %22 : tensor<16x256x2048xbf16>, tensor<16x256xbf16>, tensor<16x256xbf16>
-}