Considering duplicate operands in Linalg fusion. (#6644)
This is the follow up on https://github.com/google/iree/pull/6478
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 6689134..27b1607 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -51,7 +51,7 @@
// operations. If an operation is used in a named op, it will be computed
// anyway, so the consumers can just use that value.
linalg::ControlElementwiseOpsFusionFn controlFn =
- [](const OpResult &producer, const OpOperand &consumer) {
+ [](const OpResult &producer, OpOperand &consumer) {
// TODO(GH-5611): Enable fusion with reduction consumer for all
// targets. Currently vectorization doesn't handle generic ops with
// reduction iterators we will disable for now to allow vectorizing
@@ -68,9 +68,16 @@
// passing down to HAL. Set the number to be as same as the limit --
// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
constexpr int64_t kIreeMaxOperandCount = 32;
- auto numOperands = producer.getOwner()->getNumOperands() +
- consumer.getOwner()->getNumOperands() - 1;
- if (numOperands >= kIreeMaxOperandCount) return false;
+ DenseSet<Value> operands;
+ operands.insert(producer.getOwner()->operand_begin(),
+ producer.getOwner()->operand_end());
+ operands.insert(consumer.getOwner()->operand_begin(),
+ std::next(consumer.getOwner()->operand_begin(),
+ consumer.getOperandNumber()));
+ operands.insert(std::next(consumer.getOwner()->operand_begin(),
+ consumer.getOperandNumber() + 1),
+ consumer.getOwner()->operand_end());
+ if (operands.size() >= kIreeMaxOperandCount) return false;
llvm::SmallDenseSet<Operation *, 4> numUsers;
for (Operation *user : producer.getUsers()) {