Revert "[DispatchCreation] Extend multi-use producer fusion…" (#19468)
There appear to be more problems with this change. It causes a compilation
failure with 405b because the program slices computed during multi-use
fusion become huge, and `getBackwardSlice` is recursive, so it overflows
the stack.
Reverts iree-org/iree#19431
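
For context, the helpers removed by this revert build a filtered backward slice for each candidate op. A minimal sketch of that pattern (condensed from the reverted `isHorizontalToGroup` in FusionUtils.cpp, assuming MLIR's `SliceAnalysis` and `DominanceInfo` APIs; not the exact upstream code) is shown below. Because `getBackwardSlice` visits each operand's defining op recursively, a sufficiently deep producer chain turns into equally deep C++ recursion, which is where the stack overflow comes from.

```cpp
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Condensed sketch of the reverted helper: `op` is "horizontal" to the group
// if its backward slice (cut off at ops that properly dominate the seed)
// contains no member of the group. getBackwardSlice walks defining ops
// recursively, so on very large programs (e.g. the 405b slices mentioned
// above) the traversal depth can exceed the native stack.
static bool isHorizontalToGroupSketch(
    Operation *op, const llvm::SetVector<Operation *> &currGroup,
    const DominanceInfo &dominanceInfo, Operation *seedOp) {
  BackwardSliceOptions options;
  // Limit the slice to ops between the seed and `op` to keep it small.
  options.filter = [&](Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, seedOp);
  };
  llvm::SetVector<Operation *> slice;
  getBackwardSlice(op, &slice, options);
  return llvm::none_of(currGroup, [&](Operation *groupedOp) {
    return slice.contains(groupedOp);
  });
}
```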
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
index 8454856..a78b6b8 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -108,6 +107,25 @@
return true;
}
+/// Check that a given operation is "horizontal" to the group. The operation
+/// is horizontal if the `slice` of the operation does not contain any op
+/// from the group.
+static bool isHorizontalToGroup(Operation *op,
+ const llvm::SetVector<Operation *> &currGroup,
+ const DominanceInfo &dominanceInfo,
+ Operation *seedOp) {
+ BackwardSliceOptions options;
+ // Limit the slice to the seed to make sure the slice is small.
+ options.filter = [&](Operation *op) {
+ return !dominanceInfo.properlyDominates(op, seedOp);
+ };
+ llvm::SetVector<Operation *> slice;
+ getBackwardSlice(op, &slice, options);
+ return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
+ return slice.contains(groupedOp);
+ });
+}
+
/// Get user of operation that is a truncate operation.
static std::optional<linalg::GenericOp>
getTruncateOp(Operation *op,
@@ -131,8 +149,8 @@
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
return std::nullopt;
}
- if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
- dominanceInfo, seedTruncateOp.value())) {
+ if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
+ seedTruncateOp.value())) {
return std::nullopt;
}
}
@@ -208,8 +226,7 @@
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
- if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
- seedOp)) {
+ if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
return false;
}
return true;
@@ -329,6 +346,40 @@
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
}
+/// During horizontal fusion, there might be operands of the fused operations
+/// whose definitions are interspersed between the fused operations. For groups
+/// chosen to fuse horizontally, such operations can be moved before the
+/// seed contraction operation (where the fused operation is generated).
+template <typename T>
+static LogicalResult
+moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
+ Operation *insertionPoint, DominanceInfo &dominanceInfo,
+ ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
+ BackwardSliceOptions options;
+ llvm::DenseSet<Operation *> ignoreOperationsSet;
+ ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
+ options.filter = [&](Operation *op) {
+ return !dominanceInfo.properlyDominates(op, insertionPoint) &&
+ !ignoreOperationsSet.contains(op);
+ };
+ // Set inclusive to true cause the slice is computed from the operand, and
+ // we want to include the defining op (which is the point here)
+ options.inclusive = true;
+
+ llvm::SetVector<Operation *> slice;
+ for (auto op : operations) {
+ for (auto operand : op->getOperands()) {
+ getBackwardSlice(operand, &slice, options);
+ }
+ }
+
+ mlir::topologicalSort(slice);
+ for (auto op : slice) {
+ rewriter.moveOpBefore(op, insertionPoint);
+ }
+ return success();
+}
+
/// On finding this pattern
/// ```
/// %0 = linalg.matmul ins(%arg0, %arg1)
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
index d79d514..9d9d477 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp
@@ -16,13 +16,9 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -49,55 +45,25 @@
llvm::cl::desc("Maximum number of elements to try to constant fold."),
llvm::cl::init(0));
-static Operation *getMostDominantUse(Operation *op,
- const DominanceInfo &dominanceInfo) {
- auto uses = op->getUses();
- auto it = llvm::find_if(uses, [&](OpOperand &source) {
- Operation *sourceOp = source.getOwner();
-
- return llvm::all_of(uses, [&](OpOperand &target) {
- Operation *targetOp = target.getOwner();
- return dominanceInfo.dominates(sourceOp, targetOp);
- });
- });
- if (it != uses.end()) {
- return it->getOwner();
- }
- return nullptr;
-}
-
/// Check if any of the use dominates all other uses of the operation.
-static Operation *getFusableUse(Operation *op,
- const DominanceInfo &dominanceInfo) {
+static std::optional<OpOperand *> getFusableUse(Operation *op,
+ DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
- Operation *fusableUse = nullptr;
for (OpOperand &source : uses) {
Operation *sourceOp = source.getOwner();
-
- bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
+ bool dominatesAllUsers = true;
+ for (OpOperand &target : uses) {
Operation *targetOp = target.getOwner();
- return !isa<linalg::GenericOp>(targetOp) ||
- dominanceInfo.dominates(sourceOp, targetOp);
- });
- if (dominatesAllFusableOps) {
- fusableUse = sourceOp;
- break;
+ if (!dominanceInfo.dominates(sourceOp, targetOp)) {
+ dominatesAllUsers = false;
+ break;
+ }
+ }
+ if (dominatesAllUsers) {
+ return &source;
}
}
- Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
- if (!fusableUse || !mostDominantOp) {
- return nullptr;
- }
-
- // If `fusableUse` dominates all other users, there's nothing else to do.
- if (fusableUse == mostDominantOp) {
- return fusableUse;
- }
-
- SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
- return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
- ? fusableUse
- : nullptr;
+ return std::nullopt;
}
static OpOperand *getFirstUseInConsumer(Operation *producer,
@@ -125,7 +91,6 @@
/// using elementwise fusion.
static LogicalResult doMultiUseFusion(Operation *rootOp,
llvm::SetVector<Operation *> &fusableOps,
- const DominanceInfo &dominanceInfo,
RewriterBase &rewriter) {
assert(rootOp && "root op cant be null");
@@ -147,20 +112,11 @@
Operation *consumerOp = rootOp;
OpBuilder::InsertionGuard g(rewriter);
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
- Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
// Fuse all uses from producer -> consumer. It has been checked
// before that all uses are fusable.
while (OpOperand *fusedOperand =
getFirstUseInConsumer(producerOp, consumerOp)) {
rewriter.setInsertionPoint(consumerOp);
-
- if (consumerOp != mostDominantUser &&
- failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
- mostDominantUser, dominanceInfo))) {
- return rewriter.notifyMatchFailure(consumerOp,
- "failed to move operand defs");
- }
- rewriter.moveOpBefore(consumerOp, mostDominantUser);
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
linalg::fuseElementwiseOps(rewriter, fusedOperand);
if (failed(fusionResult)) {
@@ -234,8 +190,9 @@
}
// 6. Check that the `genericOp` dominates all uses of `producer`.
- Operation *fusableUse = getFusableUse(producer, dominanceInfo);
- if (!fusableUse || fusableUse != genericOp) {
+ std::optional<OpOperand *> fusableUse =
+ getFusableUse(producer, dominanceInfo);
+ if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
continue;
}
@@ -275,8 +232,7 @@
IRRewriter rewriter(context);
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
- if (failed(
- doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
+ if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
return funcOp->emitOpError("failed multi use fusion");
}
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index 3e3e653..c428091 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -10,11 +10,7 @@
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Transforms/RegionUtils.h"
namespace mlir::iree_compiler::DispatchCreation {
@@ -101,22 +97,4 @@
return true;
}
-bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
- const DominanceInfo &dominanceInfo,
- Operation *seedOp) {
- assert(dominanceInfo.properlyDominates(seedOp, op) &&
- op->getParentRegion() == seedOp->getParentRegion());
- BackwardSliceOptions options;
- options.omitUsesFromAbove = false;
- // Limit the slice to the seed to make sure the slice is small.
- options.filter = [&](Operation *op) {
- return !dominanceInfo.properlyDominates(op, seedOp);
- };
- llvm::SetVector<Operation *> slice;
- getBackwardSlice(op, &slice, options);
- return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
- return slice.contains(groupedOp);
- });
-}
-
} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index a264db9..1d9c930 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -10,10 +10,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Analysis/TopologicalSortUtils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
namespace mlir::iree_compiler::DispatchCreation {
@@ -23,45 +19,4 @@
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);
-/// Check that a given operation is "horizontal" to the group. The operation
-/// is horizontal if the program slice of the operation (from op back to seedOp)
-/// does not contain any op from the group.
-bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
- const DominanceInfo &dominanceInfo, Operation *seedOp);
-
-/// Moves the operands and transitive defs for each op in `operations` directly
-/// after `insertionPoint`. Note: this does not check if it is legal to move the
-/// operands.
-template <typename T>
-static LogicalResult
-moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
- Operation *insertionPoint, const DominanceInfo &dominanceInfo,
- ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
- BackwardSliceOptions options;
- options.omitUsesFromAbove = false;
- llvm::DenseSet<Operation *> ignoreOperationsSet;
- ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
- options.filter = [&](Operation *op) {
- return !dominanceInfo.properlyDominates(op, insertionPoint) &&
- !ignoreOperationsSet.contains(op);
- };
- // Set inclusive to true cause the slice is computed from the operand, and
- // we want to include the defining op (which is the point here)
- options.inclusive = true;
-
- llvm::SetVector<Operation *> slice;
- for (auto op : operations) {
- assert(insertionPoint->getBlock() == op->getBlock());
- for (auto operand : op->getOperands()) {
- getBackwardSlice(operand, &slice, options);
- }
- }
-
- mlir::topologicalSort(slice);
- for (auto op : slice) {
- rewriter.moveOpBefore(op, insertionPoint);
- }
- return success();
-}
-
} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
index c6af7b1..cc3e159 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir
@@ -139,84 +139,3 @@
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
- %cst = arith.constant 1.000000e+00 : f32
- %cst_0 = arith.constant 2.000000e+00 : f32
- %cst_1 = arith.constant 3.000000e+00 : f32
- %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- %8 = arith.addf %arg2, %cst : f32
- linalg.yield %8 : f32
- } -> tensor<5x5xf32>
- // expected-note @below {{prior use here}}
- %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
- %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- %8 = arith.subf %arg2, %cst_0 : f32
- linalg.yield %8 : f32
- } -> tensor<5x5xf32>
- util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
-}
-// CHECK-LABEL: util.func public @fuse_by_moving_consumer
-// CHECK: linalg.generic
-// CHECK-NOT: linalg.generic
-
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
- %cst = arith.constant 1.000000e+00 : f32
- %cst_0 = arith.constant 2.000000e+00 : f32
- %cst_1 = arith.constant 3.000000e+00 : f32
- %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %2 = arith.addf %in, %cst : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %c2 = arith.constant 2 : index
- %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
- %2 = arith.addf %extracted, %extracted : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
-}
-
-// CHECK-LABEL: util.func public @dont_fuse_use_from_above
-// CHECK: linalg.generic
-// CHECK: linalg.generic
-
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
- %cst = arith.constant 1.000000e+00 : f32
- %cst_0 = arith.constant 2.000000e+00 : f32
- %cst_1 = arith.constant 3.000000e+00 : f32
- %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %2 = arith.addf %in, %cst : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
- ^bb0(%in: f32, %out: f32):
- %c2 = arith.constant 2 : index
- %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32>
- %2 = arith.addf %extracted, %extracted : f32
- linalg.yield %2 : f32
- } -> tensor<5x5xf32>
- util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
-}
-
-// CHECK-LABEL: util.func public @do_fuse_use_from_above
-// CHECK: linalg.generic
-// CHECK-NOT: linalg.generic