Limit fusion to ops with single uses to avoid duplication (#7103)

Also move createResolveShapedTypeResultDimsPass before fusion
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 657866f..a5e6f48 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -1127,6 +1127,11 @@
                 consumerIndexingMap.getResults()) {
           continue;
         }
+        if (llvm::any_of(
+                consumer.getOutputOperands(), [&consumer](OpOperand *operand) {
+                  return !consumer.getTiedIndexingMap(operand).isIdentity();
+                }))
+          continue;
         int64_t rootNumber = getRootNumber(op);
         setRootAttribute(context, user, rootNumber);
         removeRootOpAttribute(op);
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 27b1607..02b3b25 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -79,12 +79,35 @@
                           consumer.getOwner()->operand_end());
           if (operands.size() >= kIreeMaxOperandCount) return false;
 
-          llvm::SmallDenseSet<Operation *, 4> numUsers;
-          for (Operation *user : producer.getUsers()) {
-            if (isa<linalg::GenericOp>(user)) continue;
-            numUsers.insert(user);
+          bool isBroadcast = false;
+          if (auto genericOp =
+                  dyn_cast<linalg::GenericOp>(producer.getOwner())) {
+            bool parallelOp =
+                llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
+                  return attr.cast<StringAttr>().getValue() ==
+                         getParallelIteratorTypeName();
+                });
+            if (parallelOp) {
+              for (OpOperand *opOperand : genericOp.getInputOperands()) {
+                AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+                if (indexingMap.isProjectedPermutation() &&
+                    indexingMap.getNumDims() != indexingMap.getNumResults()) {
+                  isBroadcast = true;
+                  break;
+                }
+              }
+            }
           }
-          return numUsers.empty();
+          // Only fuse if it has a single linalg generic user. It is a
+          // simplistic heuristic to avoid duplicating ops that may be
+          // expensive.
+          // TODO: Add a cost model to allow ops to be duplicated.
+          if (!isBroadcast && !isa<ConstantOp>(producer.getOwner()) &&
+              !llvm::hasSingleElement(producer.getUsers()))
+            return false;
+          return llvm::all_of(producer.getUsers(), [](Operation *user) {
+            return isa<linalg::GenericOp>(user);
+          });
         };
     // Simple heuristic to decide if reshaope should be folded in the linalg.
     // If the source of the reshape is a linalg op fold to potentially allow the
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index d005ef7..db00679 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -117,13 +117,13 @@
       mlir::createLinalgFoldUnitExtentDimsPass());
   passManager.addNestedPass<mlir::FuncOp>(createInterchangeGenericOpsPass());
   passManager.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
   passManager.addNestedPass<mlir::FuncOp>(createFusionOfTensorOpsPass());
   passManager.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
   if (clEnableLinalgDetensorize) {
     passManager.addNestedPass<mlir::FuncOp>(
         mlir::createLinalgDetensorizePass());
   }
-  passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
   passManager.addNestedPass<mlir::FuncOp>(
       createConvertToFlowBeforeDispatchFormation());
   passManager.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());