Add new aggressive fusion heuristics to DispatchLinalgOnTensorsPass (#10472)
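
The `-iree-flow-enable-multi-result-dispatches` flag is removed; multi-result
dispatch formation is folded into a new `aggressive-fusion` option on the pass
(with a matching `bool` argument on `createDispatchLinalgOnTensorsPass`, driven
by the `clEnableAggressiveFusion` flag in Passes.cpp). When the option is set,
producer fusion walks a worklist instead of stopping at immediate producers,
and `linalg.generic` producers feeding `ins` operands are fused when their
outer parallel loops are compatible and the operand remains bufferizable.

Example invocation (a sketch mirroring the lit tests updated in this change;
`input.mlir` is a placeholder):

  iree-opt --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" input.mlir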

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 4a8f212..2ab384c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -60,12 +61,6 @@
                    "dispatch region"),
     llvm::cl::init(256));
 
-static llvm::cl::opt<bool> clEnableMultiResultDispatches(
-    "iree-flow-enable-multi-result-dispatches",
-    llvm::cl::desc(
-        "Enable dispatch region formation to enable multi-result dispatches"),
-    llvm::cl::init(false));
-
 static const char kRootOpAttr[] = "__root_op__";
 static const char kFusionGroupsAttr[] = "__fused_op__";
 
@@ -381,22 +376,22 @@
   int64_t groupNum = getRootNumber(rootOp);
   std::deque<Operation *> worklist;
   worklist.push_back(rootOp);
-  llvm::SmallDenseSet<Operation *, 2> movedOps;
-  movedOps.insert(rootOp);
+  llvm::SmallDenseSet<Operation *, 2> visitedOps;
+  visitedOps.insert(rootOp);
 
   while (!worklist.empty()) {
     Operation *currRoot = worklist.front();
     worklist.pop_front();
     for (auto operand : currRoot->getOperands()) {
       auto producer = operand.getDefiningOp();
-      if (movedOps.count(producer)) continue;
-      if (!producer || !isInFusionGroup(producer, groupNum)) continue;
-      movedOps.insert(producer);
+      if (!producer || visitedOps.count(producer)) continue;
+      visitedOps.insert(producer);
+      if (!isInFusionGroup(producer, groupNum)) continue;
       worklist.push_back(producer);
       dispatchOps.push_back(producer);
     }
   }
-  return dispatchOps;
+  return llvm::to_vector(llvm::reverse(orderOperations(dispatchOps)));
 }
 
 //===---------------------------------------------------------------------===//
@@ -721,27 +716,91 @@
 // Heuristics for fusing dispatchable ops with root ops using tile + fuse.
 //===----------------------------------------------------------------------===//
 
-/// Checks if the producer and consumer LinalgOps can be fused.
-static bool areFusableLinalgOps(OpOperand &use) {
-  return areLinalgOpsFusableUsingTileAndFuse(use);
+/// Returns a bit vector whose size equals the number of loops of
+/// `interfaceOp`, with the bits corresponding to the outer parallel loops set
+/// to `true`.
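+/// For example, iterator types ["parallel", "parallel", "reduction",
+/// "parallel"] yield the bit vector {1, 1, 0, 0}: only the leading run of
+/// parallel loops is counted.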
+static llvm::SmallBitVector getOuterParallelLoops(TilingInterface interfaceOp) {
+  SmallVector<StringRef> loopIteratorTypes = interfaceOp.getLoopIteratorTypes();
+  llvm::SmallBitVector parallelLoops(loopIteratorTypes.size());
+  for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) {
+    if (iteratorType.value() != getParallelIteratorTypeName()) break;
+    parallelLoops.set(iteratorType.index());
+  }
+  return parallelLoops;
 }
 
-/// Returns true if this is a fusable use.
-static bool isFusableWithConsumer(OpOperand &use) {
-  // Check for linalg producer -> consumer fusion with tile + fuse.
-  return areFusableLinalgOps(use);
+/// Returns true if `map` is an identity map with zeros, i.e. dropping the
+/// result exprs that are constant zeros turns `map` into an identity map.
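+/// For example, (d0, d1) -> (d0, 0, d1) is an identity map with zeros, while
+/// (d0, d1) -> (d1, 0, d0) is not.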
+static bool isIdentityMapWithZeros(AffineMap map) {
+  if (map.getNumSymbols() != 0) return false;
+  unsigned dimsSeen = 0;
+  for (auto result : map.getResults()) {
+    bool isValidExpr = TypeSwitch<AffineExpr, bool>(result)
+                           .Case<AffineDimExpr>([&dimsSeen](auto dimExpr) {
+                             if (dimExpr.getPosition() != dimsSeen)
+                               return false;
+                             dimsSeen++;
+                             return true;
+                           })
+                           .Case<AffineConstantExpr>([](auto constExpr) {
+                             return constExpr.getValue() == 0;
+                           })
+                           .Default([](AffineExpr) { return false; });
+    if (!isValidExpr) return false;
+  }
+  return dimsSeen == map.getNumDims();
+}
+
+/// Returns true if two linalg ops with a producer-consumer relationship
+/// through `operand` have compatible outer parallel loops.
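+/// When `allowConsumerParallelismPessimization` is true, the producer may
+/// have fewer outer parallel loops than the consumer, i.e. fusion is allowed
+/// to sacrifice some of the consumer's parallelism.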
+static bool hasCompatibleOuterParallelLoops(
+    OpOperand &operand, bool allowConsumerParallelismPessimization) {
+  auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
+  auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+  if (!producer || !consumer) return false;
+
+  llvm::SmallBitVector producerParallelLoops =
+      getOuterParallelLoops(cast<TilingInterface>(producer.getOperation()));
+  llvm::SmallBitVector consumerParallelLoops =
+      getOuterParallelLoops(cast<TilingInterface>(consumer.getOperation()));
+
+  if (allowConsumerParallelismPessimization) {
+    if (producerParallelLoops.count() > consumerParallelLoops.count())
+      return false;
+  } else if (producerParallelLoops.count() != consumerParallelLoops.count()) {
+    return false;
+  }
+
+  auto producerIndexingMap =
+      producer.getTiedIndexingMapForResult(operand.get().cast<OpResult>());
+  auto consumerIndexingMap = consumer.getTiedIndexingMap(&operand);
+  if (!producerIndexingMap.isProjectedPermutation() ||
+      !consumerIndexingMap.isProjectedPermutation()) {
+    return false;
+  }
+
+  // Project out the non-parallel dimensions.
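+  // Restricted to those parallel loops, both maps must be identity maps
+  // (modulo constant-zero results) so that the producer and consumer access
+  // the fused operand the same way along the shared outer parallel dimensions.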
+  llvm::SmallBitVector producerProjectedDims(producerParallelLoops);
+  producerProjectedDims.flip();
+  auto projectedProducerMap =
+      getProjectedMap(producerIndexingMap, producerProjectedDims);
+
+  llvm::SmallBitVector consumerProjectedDims(producerParallelLoops);
+  consumerProjectedDims.flip();
+  consumerProjectedDims.resize(consumer.getNumLoops(), true);
+  auto projectedConsumerMap =
+      getProjectedMap(consumerIndexingMap, consumerProjectedDims);
+
+  return isIdentityMapWithZeros(projectedProducerMap) &&
+         isIdentityMapWithZeros(projectedConsumerMap);
 }
 
 /// For all uses of an operation, finds the use that dominates all other uses.
 static Optional<OpOperand *> getFusableUse(Operation *op,
-                                           DominanceInfo const &dominanceInfo) {
-  if (!clEnableMultiResultDispatches) {
-    if (op->hasOneUse()) {
-      OpOperand &use = *(op->use_begin());
-      return &use;
-    }
-    return llvm::None;
-  }
+                                           DominanceInfo const &dominanceInfo,
+                                           bool fuseMultiUse) {
+  if (!fuseMultiUse && !op->hasOneUse()) return llvm::None;
+
   for (auto &use : op->getUses()) {
     Operation *user = use.getOwner();
     if (llvm::all_of(op->getUsers(), [&](Operation *c) {
@@ -753,11 +812,65 @@
   return llvm::None;
 }
 
+/// Returns true if the producer and consumer ops are fusable under the
+/// aggressive fusion heuristics.
+static bool areOpsAggresiveFusable(Operation *producer, Operation *consumer,
+                                   bool allowConsumerParallelismPessimization) {
+  // Collect all the uses from producer to consumer.
+  SmallVector<OpOperand *> allUses;
+  for (OpOperand &producerUse : producer->getUses()) {
+    if (producerUse.getOwner() != consumer) continue;
+    allUses.push_back(&producerUse);
+  }
+
+  // Check that the consumer and producer have compatible outer parallel loops.
+  if (!llvm::all_of(allUses, [&](OpOperand *operand) {
+        return hasCompatibleOuterParallelLoops(
+            *operand, allowConsumerParallelismPessimization);
+      })) {
+    return false;
+  }
+
+  // Finally only fuse if the `ins` operand can be properly bufferized.
+  // TODO(#10498): Handle the multi-result case.
+  return llvm::all_of(allUses, [](OpOperand *operand) {
+    return isInsOperandBufferizable(operand, /*aggressiveFusion=*/true);
+  });
+}
+
+/// Returns true if this is a fusable use when fusing a root with its
+/// consumer.
+static bool isFusableWithConsumer(OpOperand &fusedOperand,
+                                  bool aggressiveFusion) {
+  // Use the original fusion heuristics if aggressive fusion isn't enabled.
+  if (!aggressiveFusion)
+    return areLinalgOpsFusableUsingTileAndFuse(fusedOperand);
+
+  // Logic for the aggressive fusion heuristics.
+  Operation *producer = fusedOperand.get().getDefiningOp();
+  Operation *consumer = fusedOperand.getOwner();
+
+  if (!isa<linalg::LinalgOp>(producer) || !isa<linalg::LinalgOp>(consumer))
+    return false;
+
+  auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
+
+  // Check that the consumer is all parallel.
+  if (consumerLinalgOp.getNumLoops() !=
+      consumerLinalgOp.getNumParallelLoops()) {
+    return false;
+  }
+
+  return areOpsAggresiveFusable(producer, consumer,
+                                /*allowConsumerParallelismPessimization=*/true);
+}
+
 /// Fuses roots with their consumers. If a root is fused with its consumer, it
 /// is no longer tagged as a root to aid with the dispatch region formation.
 static void fuseRootsWithConsumers(MLIRContext *context,
                                    ArrayRef<Operation *> roots,
-                                   DominanceInfo const &dominanceInfo) {
+                                   DominanceInfo const &dominanceInfo,
+                                   bool aggressiveFusion) {
   SmallVector<Operation *> workList(roots.begin(), roots.end());
   // Fuse with consumers where possible.
   while (!workList.empty()) {
@@ -774,7 +887,8 @@
       appendToFusionGroup(currRoot, rootNumber);
     };
 
-    Optional<OpOperand *> fusableUse = getFusableUse(currRoot, dominanceInfo);
+    Optional<OpOperand *> fusableUse = getFusableUse(
+        currRoot, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion);
     if (!fusableUse) continue;
 
     // Analyse the use to see if it is fusable.
@@ -784,7 +898,7 @@
       continue;
     }
 
-    if (isFusableWithConsumer(*(fusableUse.value()))) {
+    if (isFusableWithConsumer(*(fusableUse.value()), aggressiveFusion)) {
       updateRootTo(consumerOp);
       workList.push_back(consumerOp);
     }
@@ -792,19 +906,29 @@
 }
 
 /// Method to check if the consumer of a use can be fused with its producer.
-static bool isFusableWithProducer(OpOperand &operand) {
+static bool isFusableWithProducer(OpOperand &operand, bool aggressiveFusion) {
   Operation *producer = operand.get().getDefiningOp();
   Operation *consumer = operand.getOwner();
 
-  if (isa<linalg::LinalgOp>(consumer) && isa<linalg::LinalgOp>(producer)) {
-    auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
-    auto producerLinalgOp = cast<linalg::LinalgOp>(producer);
-    if (consumerLinalgOp.isOutputTensor(&operand) &&
-        producerLinalgOp.getNumLoops() ==
-            producerLinalgOp.getNumParallelLoops()) {
-      return true;
-    }
+  if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer))
+    return false;
+
+  auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
+  auto producerLinalgOp = cast<linalg::LinalgOp>(producer);
+  if (consumerLinalgOp.isOutputTensor(&operand) &&
+      producerLinalgOp.getNumLoops() ==
+          producerLinalgOp.getNumParallelLoops()) {
+    return true;
   }
+
+  // Only fuse on inputs if both are generic ops.
+  if (aggressiveFusion && consumerLinalgOp.isInputTensor(&operand) &&
+      isa<linalg::GenericOp>(consumer) && isa<linalg::GenericOp>(producer)) {
+    return areOpsAggresiveFusable(
+        producer, consumer,
+        /*allowConsumerParallelismPessimization=*/false);
+  }
+
   return false;
 }
 
@@ -812,21 +936,28 @@
 /// in reverse to fuse with producers.
 static void fuseRootsWithProducers(MLIRContext *context, Operation *root,
                                    unsigned groupNum,
-                                   DominanceInfo const &dominanceInfo) {
-  // We probably want a worklist algorithm here, but for now just look at
-  // immediate producers.
-  for (OpOperand &operand : root->getOpOperands()) {
-    Operation *producer = operand.get().getDefiningOp();
-    if (!producer) continue;
-    if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
-      continue;
-    }
+                                   DominanceInfo const &dominanceInfo,
+                                   bool aggressiveFusion) {
+  SmallVector<Operation *> worklist;
+  worklist.push_back(root);
 
-    Optional<OpOperand *> fusableUse = getFusableUse(producer, dominanceInfo);
-    if (!fusableUse || fusableUse.value()->getOwner() != root) continue;
+  while (!worklist.empty()) {
+    Operation *candidate = worklist.pop_back_val();
+    for (OpOperand &operand : candidate->getOpOperands()) {
+      Operation *producer = operand.get().getDefiningOp();
+      if (!producer) continue;
+      if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
+        continue;
+      }
 
-    if (isFusableWithProducer(operand)) {
+      Optional<OpOperand *> fusableUse = getFusableUse(
+          producer, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion);
+      if (!fusableUse || fusableUse.value()->getOwner() != candidate) continue;
+
+      if (!isFusableWithProducer(operand, aggressiveFusion)) continue;
+
       appendToFusionGroup(producer, groupNum);
+      worklist.push_back(producer);
     }
   }
 }
@@ -840,7 +971,8 @@
 /// very simple heuristic is used below, but the mechanism should be general
 /// enough to capture any heuristic.
 static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp,
-                                       DominanceInfo const &dominanceInfo) {
+                                       DominanceInfo const &dominanceInfo,
+                                       bool aggressiveFusion) {
   unsigned numRootOps = 0;
   MLIRContext *context = funcOp->getContext();
   OpBuilder builder(context);
@@ -857,11 +989,12 @@
       unsigned newGroup = numRootOps++;
       setRootAttribute(context, &op, newGroup);
 
-      fuseRootsWithProducers(context, &op, newGroup, dominanceInfo);
+      fuseRootsWithProducers(context, &op, newGroup, dominanceInfo,
+                             aggressiveFusion);
       roots.push_back(&op);
     }
     roots = llvm::to_vector(llvm::reverse(roots));
-    fuseRootsWithConsumers(context, roots, dominanceInfo);
+    fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion);
   }
 
   // Once all root linalg ops have been tagged, put all remaining generic ops
@@ -881,7 +1014,7 @@
       roots.push_back(&op);
     }
     roots = llvm::to_vector(llvm::reverse(roots));
-    fuseRootsWithConsumers(context, roots, dominanceInfo);
+    fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion);
   }
 
   return numRootOps;
@@ -896,8 +1029,11 @@
         .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
                 scf::SCFDialect, tensor::TensorDialect>();
   }
-  DispatchLinalgOnTensorsPass() = default;
-  DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
+  DispatchLinalgOnTensorsPass(bool aggressiveFusion) {
+    this->aggressiveFusion = aggressiveFusion;
+  }
+  DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass)
+      : DispatchLinalgOnTensorsPass(pass.aggressiveFusion) {}
   void runOnOperation() override;
 
  private:
@@ -980,7 +1116,7 @@
   auto funcOp = getOperation();
   MLIRContext *context = &getContext();
   DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
-  decideFusableLinalgOps(funcOp, dominanceInfo);
+  decideFusableLinalgOps(funcOp, dominanceInfo, aggressiveFusion);
 
   LLVM_DEBUG({
     llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
@@ -1063,8 +1199,8 @@
 }
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDispatchLinalgOnTensorsPass() {
-  return std::make_unique<DispatchLinalgOnTensorsPass>();
+createDispatchLinalgOnTensorsPass(bool aggressiveFusion) {
+  return std::make_unique<DispatchLinalgOnTensorsPass>(aggressiveFusion);
 }
 
 }  // namespace Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
index 616b964..a888110 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -38,16 +38,25 @@
 /// indexing map.
 // TODO: This restriction can go away if we can vectorize always, but that has
 // a long tail of tasks.
-static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
+bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion) {
+  // Ignore the check if in-place bufferization is not required.
+  if (!clEnsureInplaceableConsumer) return true;
+
   auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
   if (!linalgOp) return false;
 
   AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
 
   auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
-    if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
-      return false;
+    AffineMap outsOperandIndexingMap = linalgOp.getTiedIndexingMap(outsOperand);
+
+    if (outsOperandIndexingMap != insOperandIndexingMap) {
+      if (!aggressiveFusion) return false;
+      // If the operand is a projected permutation, a small stack buffer might
+      // be fine.
+      if (!insOperandIndexingMap.isProjectedPermutation()) return false;
     }
+
     // TODO(#8411): Until ops are vectorized (always), we need
     // to check that the elementtype matches for the operands to be tied.
     // For now just doing this check for convolution ops since we expect
@@ -219,7 +228,7 @@
 
   // 4. In-place bufferization requirements (for now) require that the use in
   // the consumer can re-use the buffer for a result.
-  return !clEnsureInplaceableConsumer || canInsOperandTieWithOutsOperand(&use);
+  return isInsOperandBufferizable(&use, /*aggressiveFusion=*/false);
 }
 
 }  // namespace Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
index 245ad75..2f1b5f4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h
@@ -18,6 +18,10 @@
 namespace IREE {
 namespace Flow {
 
+/// Returns true if the `ins` operand can be properly bufferized after the
+/// fusion.
+bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion);
+
 /// Returns true if the `use` is from a producer linalg op that can be fused
 /// with the consumer linalg op using tile + fuse.
 bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index a5017fd..11ffd5e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -273,8 +273,11 @@
                          })
      // Only want to use the transform dialect for some dispatch regions and
      // let the DispatchLinalgOnTensorsPass handle the rest.
-      .addPredicatedPass(!clDispatchViaRegionOps,
-                         createDispatchLinalgOnTensorsPass)
+      .addPredicatedPass(
+          !clDispatchViaRegionOps,
+          []() {
+            return createDispatchLinalgOnTensorsPass(clEnableAggressiveFusion);
+          })
       // DispatchLinalgOnTensorsViaRegionsPass is a variant of
       // DispatchLinalgOnTensorsPass that lowers via DispatchRegionOps. This is
       // on an opt-in basis until the pass is stable enough to replace
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index adf7ab7..71a1226 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -143,7 +143,7 @@
 // Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
 // A dispatch region is created for each tiled loop nest.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDispatchLinalgOnTensorsPass();
+createDispatchLinalgOnTensorsPass(bool aggressiveFusion = false);
 
 // Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
 // A dispatch region is created for each tiled loop nest. (First create
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index d68df0d..3c4f6f8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -66,6 +66,10 @@
     InterfacePass<"iree-flow-dispatch-linalg-on-tensors-pass", "mlir::FunctionOpInterface"> {
   let summary = "Dispatch Linalg operations on tensors by using tile and distribute";
   let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsPass()";
+  let options = [
+    Option<"aggressiveFusion", "aggressive-fusion", "bool",
+           /*default=*/"false", "Fuse with aggressive heuristics">,
+  ];
 }
 
 def DispatchLinalgOnTensorsViaRegionOps :
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index e01d0d7..abbe416 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -24,6 +24,7 @@
             "deduplicate_executables.mlir",
             "detach_elementwise_from_named_ops.mlir",
             "dispatch_linalg_on_tensors.mlir",
+            "dispatch_linalg_on_tensors_aggressive_fusion.mlir",
             "dispatch_linalg_on_tensors_fusion.mlir",
             "dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir",
             "dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index d033ea3..e5cce09 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -22,6 +22,7 @@
     "deduplicate_executables.mlir"
     "detach_elementwise_from_named_ops.mlir"
     "dispatch_linalg_on_tensors.mlir"
+    "dispatch_linalg_on_tensors_aggressive_fusion.mlir"
     "dispatch_linalg_on_tensors_fusion.mlir"
     "dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir"
     "dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 43651b4..89cd23e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass), cse, canonicalize, cse" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true}), cse, canonicalize, cse" %s | FileCheck %s
 
 func.func @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
              %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir
new file mode 100644
index 0000000..5ffa94c
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir
@@ -0,0 +1,107 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" %s | FileCheck %s
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module {
+  func.func @softmax(%arg0: tensor<12x128x128xf32>) -> tensor<12x128x128xf32> {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %cst_1 = arith.constant -3.40282347E+38 : f32
+    %0 = linalg.init_tensor [12, 128] : tensor<12x128xf32>
+    %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32>
+    %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<12x128x128xf32>) outs(%1 : tensor<12x128xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %7 = arith.maxf %arg1, %arg2 : f32
+      linalg.yield %7 : f32
+    } -> tensor<12x128xf32>
+    %3 = linalg.init_tensor [12, 128, 128] : tensor<12x128x128xf32>
+    %4 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32>
+    %5:2 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %2 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3, %4 : tensor<12x128x128xf32>, tensor<12x128xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
+      %7 = arith.subf %arg1, %arg2 : f32
+      %8 = math.exp %7 : f32
+      %9 = arith.addf %8, %arg4 : f32
+      linalg.yield %8, %9 : f32, f32
+    } -> (tensor<12x128x128xf32>, tensor<12x128xf32>)
+    %6 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5#0, %5#1 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3 : tensor<12x128x128xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      %7 = arith.divf %cst, %arg2 : f32
+      %8 = arith.mulf %arg1, %7 : f32
+      linalg.yield %8 : f32
+    } -> tensor<12x128x128xf32>
+    return %6 : tensor<12x128x128xf32>
+  }
+}
+// CHECK-LABEL: func @softmax(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x128x128xf32>
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.workgroups
+//  CHECK-SAME:       (%[[ARG0]])
+//  CHECK-NEXT:     %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:12x128x128xf32>
+//       CHECK:     %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG1]]
+//       CHECK:     %[[FILL0:.+]] = linalg.fill
+//       CHECK:     %[[FILL1:.+]] = linalg.fill
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[LOAD0]] :
+//       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
+//  CHECK-SAME:         ins(%[[LOAD0]], %[[GENERIC0]] :
+//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC1]]#0, %[[GENERIC1]]#1 :
+//       CHECK:     flow.dispatch.tensor.store %[[GENERIC2]]
+//       CHECK:     flow.return
+//       CHECK:   return %[[DISPATCH]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+module {
+  func.func @batchnorm_training(%arg0: tensor<12xf32>, %arg1: tensor<12x12x12x12x12xf32>, %arg2: tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
+    %cst = arith.constant 1.420000e+00 : f32
+    %cst_0 = arith.constant 1.450000e+00 : f32
+    %cst_1 = arith.constant 1.300000e+00 : f32
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %0 = linalg.init_tensor [12] : tensor<12xf32>
+    %1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<12xf32>) -> tensor<12xf32>
+    %2 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg1, %arg2 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%1 : tensor<12xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %4 = arith.subf %arg3, %arg4 : f32
+      %5 = arith.mulf %4, %4 : f32
+      %6 = arith.addf %arg5, %5 : f32
+      linalg.yield %6 : f32
+    } -> tensor<12xf32>
+    %3:3 = linalg.generic {indexing_maps = [#map2, #map2, #map2, #map2, #map2], iterator_types = ["parallel"]} ins(%arg0, %2 : tensor<12xf32>, tensor<12xf32>) outs(%0, %0, %0 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
+      %4 = arith.divf %arg4, %cst_0 : f32
+      %5 = arith.addf %4, %cst_1 : f32
+      %6 = math.sqrt %5 : f32
+      %7 = arith.subf %arg3, %6 : f32
+      %8 = arith.mulf %7, %cst : f32
+      %9 = arith.subf %arg3, %8 : f32
+      linalg.yield %5, %6, %9 : f32, f32, f32
+    } -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>)
+    return %3#0, %3#1, %3#2 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>
+  }
+}
+// CHECK-LABEL: func @batchnorm_training(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<12xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<12x12x12x12x12xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<12xf32>
+//       CHECK:   %[[DISPATCH:.+]]:3 = flow.dispatch.workgroups
+//  CHECK-SAME:       (%[[ARG1]], %[[ARG2]], %[[ARG0]])
+//  CHECK-NEXT:     %[[ARG3:.+]]: !flow.dispatch.tensor<readonly:12x12x12x12x12xf32>
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:12xf32>
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:12xf32>
+//   CHECK-DAG:     %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG3]]
+//   CHECK-DAG:     %[[LOAD1:.+]] = flow.dispatch.tensor.load %[[ARG4]]
+//   CHECK-DAG:     %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG5]]
+//       CHECK:     %[[FILL:.+]] = linalg.fill
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[LOAD0]], %[[LOAD1]] :
+//       CHECK:     %[[GENERIC1:.+]]:3 = linalg.generic
+//  CHECK-SAME:         ins(%[[LOAD2]], %[[GENERIC0]] :
+//   CHECK-DAG:     flow.dispatch.tensor.store %[[GENERIC1]]#0
+//   CHECK-DAG:     flow.dispatch.tensor.store %[[GENERIC1]]#1
+//   CHECK-DAG:     flow.dispatch.tensor.store %[[GENERIC1]]#2
+//       CHECK:     flow.return
+//       CHECK:   return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
index ef08d5b..e17cd08 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass)" --canonicalize -cse %s | FileCheck %s
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" --canonicalize -cse %s | FileCheck %s
 
 func.func @fuse_conv2d_elementwise(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>, %offset: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
index a552edd..2095fe4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir
@@ -106,7 +106,7 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
-func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xf32> {
+func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xi32> {
   %cst_47 = arith.constant 0.000000e+00 : f32
   %37 = linalg.init_tensor [12, 16] : tensor<12x16xf32>
   %38 = linalg.fill ins(%cst_47 : f32) outs(%37 : tensor<12x16xf32>) -> tensor<12x16xf32>
@@ -115,13 +115,14 @@
     %780 = arith.maxf %arg3, %arg4 : f32
     linalg.yield %780 : f32
   } -> tensor<12x16xf32>
-  %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xf32>
-  %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xi32>
+  %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xi32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: i32):
     %780 = arith.subf %arg3, %arg4 : f32
-    linalg.yield %780 : f32
-  } -> tensor<12x16x32xf32>
-  return %42 : tensor<12x16x32xf32>
+    %781 = arith.fptosi %780 : f32 to i32
+    linalg.yield %781 : i32
+  } -> tensor<12x16x32xi32>
+  return %42 : tensor<12x16x32xi32>
 }
 
 // Check that two generic ops are NOT dispatched together since the input type
@@ -163,4 +164,3 @@
 // CHECK-LABEL: func.func @reduction_broadcast_elementwise_dynamic
 //      CHECK: flow.dispatch.workgroups
 //      CHECK: flow.dispatch.workgroups
-