Revert "[DispatchCreation] Extend multi-use producer fusion…" (#19468)

There appears to be more problems with this change. This is causing
compilation failure with 405b due to huge program slices during multi
use fusion. `getBackwardSlice` is recursive so it is causing a stack
overflow.




Reverts iree-org/iree#19431
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
index 8454856..a78b6b8 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
@@ -7,7 +7,6 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
@@ -108,6 +107,25 @@
   return true;
 }
 
+/// Check that a given operation is "horizontal" to the group. The operation
+/// is horizontal if the `slice` of the operation does not contain any op
+/// from the group.
+static bool isHorizontalToGroup(Operation *op,
+                                const llvm::SetVector<Operation *> &currGroup,
+                                const DominanceInfo &dominanceInfo,
+                                Operation *seedOp) {
+  BackwardSliceOptions options;
+  // Limit the slice to the seed to make sure the slice is small.
+  options.filter = [&](Operation *op) {
+    return !dominanceInfo.properlyDominates(op, seedOp);
+  };
+  llvm::SetVector<Operation *> slice;
+  getBackwardSlice(op, &slice, options);
+  return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
+    return slice.contains(groupedOp);
+  });
+}
+
 /// Get user of operation that is a truncate operation.
 static std::optional<linalg::GenericOp>
 getTruncateOp(Operation *op,
@@ -131,8 +149,8 @@
     if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
       return std::nullopt;
     }
-    if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
-                             dominanceInfo, seedTruncateOp.value())) {
+    if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
+                             seedTruncateOp.value())) {
       return std::nullopt;
     }
   }
@@ -208,8 +226,7 @@
     if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
       return false;
     }
-    if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
-                             seedOp)) {
+    if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
       return false;
     }
     return true;
@@ -329,6 +346,40 @@
   return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
 }
 
+/// During horizontal fusion, there might be operands of the fused operations
+/// whose definitions are interspersed between the fused operations. For groups
+/// chosen to fuse horizontally, such operations can be moved before the
+/// seed contraction operation (where the fused operation is generated).
+template <typename T>
+static LogicalResult
+moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
+                Operation *insertionPoint, DominanceInfo &dominanceInfo,
+                ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
+  BackwardSliceOptions options;
+  llvm::DenseSet<Operation *> ignoreOperationsSet;
+  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
+  options.filter = [&](Operation *op) {
+    return !dominanceInfo.properlyDominates(op, insertionPoint) &&
+           !ignoreOperationsSet.contains(op);
+  };
+  // Set inclusive to true cause the slice is computed from the operand, and
+  // we want to include the defining op (which is the point here)
+  options.inclusive = true;
+
+  llvm::SetVector<Operation *> slice;
+  for (auto op : operations) {
+    for (auto operand : op->getOperands()) {
+      getBackwardSlice(operand, &slice, options);
+    }
+  }
+
+  mlir::topologicalSort(slice);
+  for (auto op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
 /// On finding this pattern
 /// ```
 /// %0 = linalg.matmul ins(%arg0, %arg1)
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
index d79d514..9d9d477 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
@@ -16,13 +16,9 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -49,55 +45,25 @@
     llvm::cl::desc("Maximum number of elements to try to constant fold."),
     llvm::cl::init(0));
 
-static Operation *getMostDominantUse(Operation *op,
-                                     const DominanceInfo &dominanceInfo) {
-  auto uses = op->getUses();
-  auto it = llvm::find_if(uses, [&](OpOperand &source) {
-    Operation *sourceOp = source.getOwner();
-
-    return llvm::all_of(uses, [&](OpOperand &target) {
-      Operation *targetOp = target.getOwner();
-      return dominanceInfo.dominates(sourceOp, targetOp);
-    });
-  });
-  if (it != uses.end()) {
-    return it->getOwner();
-  }
-  return nullptr;
-}
-
 /// Check if any of the use dominates all other uses of the operation.
-static Operation *getFusableUse(Operation *op,
-                                const DominanceInfo &dominanceInfo) {
+static std::optional<OpOperand *> getFusableUse(Operation *op,
+                                                DominanceInfo &dominanceInfo) {
   auto uses = op->getUses();
-  Operation *fusableUse = nullptr;
   for (OpOperand &source : uses) {
     Operation *sourceOp = source.getOwner();
-
-    bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
+    bool dominatesAllUsers = true;
+    for (OpOperand &target : uses) {
       Operation *targetOp = target.getOwner();
-      return !isa<linalg::GenericOp>(targetOp) ||
-             dominanceInfo.dominates(sourceOp, targetOp);
-    });
-    if (dominatesAllFusableOps) {
-      fusableUse = sourceOp;
-      break;
+      if (!dominanceInfo.dominates(sourceOp, targetOp)) {
+        dominatesAllUsers = false;
+        break;
+      }
+    }
+    if (dominatesAllUsers) {
+      return &source;
     }
   }
-  Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
-  if (!fusableUse || !mostDominantOp) {
-    return nullptr;
-  }
-
-  // If `fusableUse` dominates all other users, there's nothing else to do.
-  if (fusableUse == mostDominantOp) {
-    return fusableUse;
-  }
-
-  SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
-  return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
-             ? fusableUse
-             : nullptr;
+  return std::nullopt;
 }
 
 static OpOperand *getFirstUseInConsumer(Operation *producer,
@@ -125,7 +91,6 @@
 /// using elementwise fusion.
 static LogicalResult doMultiUseFusion(Operation *rootOp,
                                       llvm::SetVector<Operation *> &fusableOps,
-                                      const DominanceInfo &dominanceInfo,
                                       RewriterBase &rewriter) {
   assert(rootOp && "root op cant be null");
 
@@ -147,20 +112,11 @@
   Operation *consumerOp = rootOp;
   OpBuilder::InsertionGuard g(rewriter);
   for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
-    Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
     // Fuse all uses from producer -> consumer. It has been checked
     // before that all uses are fusable.
     while (OpOperand *fusedOperand =
                getFirstUseInConsumer(producerOp, consumerOp)) {
       rewriter.setInsertionPoint(consumerOp);
-
-      if (consumerOp != mostDominantUser &&
-          failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
-                                 mostDominantUser, dominanceInfo))) {
-        return rewriter.notifyMatchFailure(consumerOp,
-                                           "failed to move operand defs");
-      }
-      rewriter.moveOpBefore(consumerOp, mostDominantUser);
       FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
           linalg::fuseElementwiseOps(rewriter, fusedOperand);
       if (failed(fusionResult)) {
@@ -234,8 +190,9 @@
           }
 
           // 6. Check that the `genericOp` dominates all uses of `producer`.
-          Operation *fusableUse = getFusableUse(producer, dominanceInfo);
-          if (!fusableUse || fusableUse != genericOp) {
+          std::optional<OpOperand *> fusableUse =
+              getFusableUse(producer, dominanceInfo);
+          if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
             continue;
           }
 
@@ -275,8 +232,7 @@
 
   IRRewriter rewriter(context);
   for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
-    if (failed(
-            doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
+    if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
       return funcOp->emitOpError("failed multi use fusion");
     }
   }
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index 3e3e653..c428091 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -10,11 +10,7 @@
 #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
 #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
 
@@ -101,22 +97,4 @@
   return true;
 }
 
-bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
-                         const DominanceInfo &dominanceInfo,
-                         Operation *seedOp) {
-  assert(dominanceInfo.properlyDominates(seedOp, op) &&
-         op->getParentRegion() == seedOp->getParentRegion());
-  BackwardSliceOptions options;
-  options.omitUsesFromAbove = false;
-  // Limit the slice to the seed to make sure the slice is small.
-  options.filter = [&](Operation *op) {
-    return !dominanceInfo.properlyDominates(op, seedOp);
-  };
-  llvm::SetVector<Operation *> slice;
-  getBackwardSlice(op, &slice, options);
-  return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
-    return slice.contains(groupedOp);
-  });
-}
-
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index a264db9..1d9c930 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -10,10 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Analysis/TopologicalSortUtils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 
 namespace mlir::iree_compiler::DispatchCreation {
@@ -23,45 +19,4 @@
 bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
                                 bool fuseMultiReduction);
 
-/// Check that a given operation is "horizontal" to the group. The operation
-/// is horizontal if the program slice of the operation (from op back to seedOp)
-/// does not contain any op from the group.
-bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
-                         const DominanceInfo &dominanceInfo, Operation *seedOp);
-
-/// Moves the operands and transitive defs for each op in `operations` directly
-/// after `insertionPoint`. Note: this does not check if it is legal to move the
-/// operands.
-template <typename T>
-static LogicalResult
-moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
-                Operation *insertionPoint, const DominanceInfo &dominanceInfo,
-                ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
-  BackwardSliceOptions options;
-  options.omitUsesFromAbove = false;
-  llvm::DenseSet<Operation *> ignoreOperationsSet;
-  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
-  options.filter = [&](Operation *op) {
-    return !dominanceInfo.properlyDominates(op, insertionPoint) &&
-           !ignoreOperationsSet.contains(op);
-  };
-  // Set inclusive to true cause the slice is computed from the operand, and
-  // we want to include the defining op (which is the point here)
-  options.inclusive = true;
-
-  llvm::SetVector<Operation *> slice;
-  for (auto op : operations) {
-    assert(insertionPoint->getBlock() == op->getBlock());
-    for (auto operand : op->getOperands()) {
-      getBackwardSlice(operand, &slice, options);
-    }
-  }
-
-  mlir::topologicalSort(slice);
-  for (auto op : slice) {
-    rewriter.moveOpBefore(op, insertionPoint);
-  }
-  return success();
-}
-
 } // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
index c6af7b1..cc3e159 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
@@ -139,84 +139,3 @@
 //       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
 //   CHECK-DAG:   check.expect_almost_eq(%[[GENERIC]]#0,
 //   CHECK-DAG:   check.expect_almost_eq(%[[GENERIC]]#1,
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
-  %cst = arith.constant 1.000000e+00 : f32
-  %cst_0 = arith.constant 2.000000e+00 : f32
-  %cst_1 = arith.constant 3.000000e+00 : f32
-  %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):
-    %8 = arith.addf %arg2, %cst : f32
-    linalg.yield %8 : f32
-  } -> tensor<5x5xf32>
-  // expected-note @below {{prior use here}}
-  %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
-  %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):
-    %8 = arith.subf %arg2, %cst_0 : f32
-    linalg.yield %8 : f32
-  } -> tensor<5x5xf32>
-  util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
-}
-// CHECK-LABEL: util.func public @fuse_by_moving_consumer
-// CHECK:         linalg.generic
-// CHECK-NOT:     linalg.generic
-
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
-  %cst = arith.constant 1.000000e+00 : f32
-  %cst_0 = arith.constant 2.000000e+00 : f32
-  %cst_1 = arith.constant 3.000000e+00 : f32
-  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %2 = arith.addf %in, %cst : f32
-    linalg.yield %2 : f32
-  } -> tensor<5x5xf32>
-  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %c2 = arith.constant 2 : index
-    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
-    %2 = arith.addf %extracted, %extracted : f32
-    linalg.yield %2 : f32
-  } -> tensor<5x5xf32>
-  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
-}
-
-// CHECK-LABEL: util.func public @dont_fuse_use_from_above
-// CHECK:         linalg.generic
-// CHECK:         linalg.generic
-
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
-  %cst = arith.constant 1.000000e+00 : f32
-  %cst_0 = arith.constant 2.000000e+00 : f32
-  %cst_1 = arith.constant 3.000000e+00 : f32
-  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %2 = arith.addf %in, %cst : f32
-    linalg.yield %2 : f32
-  } -> tensor<5x5xf32>
-  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %c2 = arith.constant 2 : index
-    %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32>
-    %2 = arith.addf %extracted, %extracted : f32
-    linalg.yield %2 : f32
-  } -> tensor<5x5xf32>
-  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
-}
-
-// CHECK-LABEL: util.func public @do_fuse_use_from_above
-// CHECK:         linalg.generic
-// CHECK-NOT:     linalg.generic