Limit fusion to ops with single uses to avoid duplication (#7103)
Also move createResolveShapedTypeResultDimsPass before fusion
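
As a hypothetical sketch of the pattern this restricts (names, shapes, and
ops invented; dialect syntax approximated for the MLIR version in tree),
consider a producer with two linalg.generic users:

  #id = affine_map<(d0) -> (d0)>
  func @multi_use(%arg0: tensor<8xf32>, %init: tensor<8xf32>)
      -> (tensor<8xf32>, tensor<8xf32>) {
    // Producer: elementwise exp with two users below.
    %0 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%a: f32, %b: f32):
      %e = math.exp %a : f32
      linalg.yield %e : f32
    } -> tensor<8xf32>
    // First consumer of %0.
    %1 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%a: f32, %b: f32):
      %s = addf %a, %a : f32
      linalg.yield %s : f32
    } -> tensor<8xf32>
    // Second consumer of %0.
    %2 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%a: f32, %b: f32):
      %m = mulf %a, %a : f32
      linalg.yield %m : f32
    } -> tensor<8xf32>
    return %1, %2 : tensor<8xf32>, tensor<8xf32>
  }

Fusing %0 into both %1 and %2 duplicates the exp computation in each
consumer; with this change such multi-use producers are no longer fused
unless they are constants or broadcasts.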
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 657866f..a5e6f48 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -1127,6 +1127,13 @@
consumerIndexingMap.getResults()) {
continue;
}
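+ // Only move the root attribute to the consumer if every consumer
+ // output is accessed through an identity indexing map.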
+ if (llvm::any_of(
+ consumer.getOutputOperands(), [&consumer](OpOperand *operand) {
+ return !consumer.getTiedIndexingMap(operand).isIdentity();
+ }))
+ continue;
int64_t rootNumber = getRootNumber(op);
setRootAttribute(context, user, rootNumber);
removeRootOpAttribute(op);
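
The guard added above can be illustrated with a hypothetical consumer
(names and shapes invented) whose output is written through a transposed
indexing map; the root attribute is not moved onto it because the output
map is not the identity:

  #in  = affine_map<(d0, d1) -> (d0, d1)>
  #out = affine_map<(d0, d1) -> (d1, d0)>
  func @transposed_out(%src: tensor<4x8xf32>, %init: tensor<8x4xf32>)
      -> tensor<8x4xf32> {
    // The output is accessed through a transpose, so
    // getTiedIndexingMap on the output operand is not an identity map
    // and the root attribute stays on the producer.
    %1 = linalg.generic {indexing_maps = [#in, #out],
                         iterator_types = ["parallel", "parallel"]}
        ins(%src : tensor<4x8xf32>) outs(%init : tensor<8x4xf32>) {
    ^bb0(%a: f32, %b: f32):
      linalg.yield %a : f32
    } -> tensor<8x4xf32>
    return %1 : tensor<8x4xf32>
  }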
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 27b1607..02b3b25 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -79,12 +79,35 @@
consumer.getOwner()->operand_end());
if (operands.size() >= kIreeMaxOperandCount) return false;
- llvm::SmallDenseSet<Operation *, 4> numUsers;
- for (Operation *user : producer.getUsers()) {
- if (isa<linalg::GenericOp>(user)) continue;
- numUsers.insert(user);
+ bool isBroadcast = false;
+ if (auto genericOp =
+ dyn_cast<linalg::GenericOp>(producer.getOwner())) {
+ bool parallelOp =
+ llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
+ return attr.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName();
+ });
+ if (parallelOp) {
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ if (indexingMap.isProjectedPermutation() &&
+ indexingMap.getNumDims() != indexingMap.getNumResults()) {
+ isBroadcast = true;
+ break;
+ }
+ }
+ }
}
- return numUsers.empty();
+ // Only fuse a producer with a single linalg.generic user, unless it
+ // is a constant or a broadcast, which are cheap to recompute. This
+ // simplistic heuristic avoids duplicating ops that may be expensive.
+ // TODO: Add a cost model to allow other ops to be duplicated.
+ if (!isBroadcast && !isa<ConstantOp>(producer.getOwner()) &&
+ !llvm::hasSingleElement(producer.getUsers()))
+ return false;
+ return llvm::all_of(producer.getUsers(), [](Operation *user) {
+ return isa<linalg::GenericOp>(user);
+ });
};
// Simple heuristic to decide if a reshape should be folded into the linalg op.
// If the source of the reshape is a linalg op, fold to potentially allow the
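
For the isBroadcast path above, a hypothetical producer it matches (names
and shapes invented): an all-parallel linalg.generic whose input indexing
map is a projected permutation covering fewer dimensions than the
iteration space, i.e. a broadcast. Such ops stay fusable even with
multiple users because recomputing them is cheap:

  #bcast = affine_map<(d0, d1) -> (d0)>
  #id2   = affine_map<(d0, d1) -> (d0, d1)>
  func @broadcast(%vec: tensor<4xf32>, %init: tensor<4x8xf32>)
      -> tensor<4x8xf32> {
    // The input map drops d1 (a projected permutation with fewer
    // results than dims), so this producer is flagged as a broadcast.
    %0 = linalg.generic {indexing_maps = [#bcast, #id2],
                         iterator_types = ["parallel", "parallel"]}
        ins(%vec : tensor<4xf32>) outs(%init : tensor<4x8xf32>) {
    ^bb0(%a: f32, %b: f32):
      linalg.yield %a : f32
    } -> tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }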
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index d005ef7..db00679 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -117,13 +117,13 @@
mlir::createLinalgFoldUnitExtentDimsPass());
passManager.addNestedPass<mlir::FuncOp>(createInterchangeGenericOpsPass());
passManager.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
passManager.addNestedPass<mlir::FuncOp>(createFusionOfTensorOpsPass());
passManager.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
if (clEnableLinalgDetensorize) {
passManager.addNestedPass<mlir::FuncOp>(
mlir::createLinalgDetensorizePass());
}
- passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
passManager.addNestedPass<mlir::FuncOp>(
createConvertToFlowBeforeDispatchFormation());
passManager.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
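
Moving createResolveShapedTypeResultDimsPass before fusion presumably
matters because a tensor.dim taken on a producer's result counts as an
extra, non-generic use under the new heuristic; resolving it in terms of
the producer's operands first leaves the producer with a single
linalg.generic user. A hypothetical example (names and shapes invented):

  #id = affine_map<(d0) -> (d0)>
  func @dim_use(%arg0: tensor<?xf32>, %init: tensor<?xf32>)
      -> (tensor<?xf32>, index) {
    %0 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %e = math.exp %a : f32
      linalg.yield %e : f32
    } -> tensor<?xf32>
    %1 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel"]}
        ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %s = addf %a, %a : f32
      linalg.yield %s : f32
    } -> tensor<?xf32>
    %c0 = constant 0 : index
    // A second use of %0; ResolveShapedTypeResultDims rewrites it to
    // "tensor.dim %arg0, %c0", leaving %0 with one generic user.
    %d = tensor.dim %0, %c0 : tensor<?xf32>
    return %1, %d : tensor<?xf32>, index
  }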