[Flow] Allow CollapseDimensions pass to fold reduction dimensions as well (#14656)

This makes the CollapseReductionDims pass redundant and can be dropped.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index a3ad95f..515bb3b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -35,7 +35,6 @@
         "CleanupTensorShapes.cpp",
         "CloneProducersIntoDispatchRegions.cpp",
         "CollapseDimensions.cpp",
-        "CollapseReductionDims.cpp",
         "Convert1X1FilterConv2DToMatmul.cpp",
         "ConvertRegionToWorkgroups.cpp",
         "ConvertToFlow.cpp",
@@ -100,6 +99,7 @@
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithUtils",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index c3b86cc..bab9e87 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -34,7 +34,6 @@
     "CleanupTensorShapes.cpp"
     "CloneProducersIntoDispatchRegions.cpp"
     "CollapseDimensions.cpp"
-    "CollapseReductionDims.cpp"
     "Convert1X1FilterConv2DToMatmul.cpp"
     "ConvertRegionToWorkgroups.cpp"
     "ConvertToFlow.cpp"
@@ -82,6 +81,7 @@
     IREELinalgTransformDialect
     LLVMSupport
     MLIRAffineDialect
+    MLIRAffineUtils
     MLIRAnalysis
     MLIRArithDialect
     MLIRArithUtils
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
index 6b0445e..19a0a43 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
@@ -12,6 +12,8 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -27,6 +29,8 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include <deque>
+
 #define DEBUG_TYPE "iree-flow-collapse-dimensions"
 
 namespace mlir {
@@ -48,47 +52,88 @@
 
 /// Searches the same sequence in all the affine maps and collapses these
 /// dimensions. It only applies these to "parallel" loops without mixing them
-/// with "reduction" types.
+/// with "reduction" types. It is expected that the `genericOp` has projected
+/// permutations only as indexing maps. (Checked using `isEligibleForCollapse`).
 static SmallVector<ReassociationIndices>
 getCollapsibleLoops(linalg::GenericOp genericOp) {
   SmallVector<ReassociationIndices> contiguousLoops;
 
-  SmallVector<unsigned> pDims;
+  SmallVector<unsigned> pDims, rDims;
   genericOp.getParallelDims(pDims);
-  if (pDims.size() < 2)
-    return contiguousLoops;
-
-  llvm::SmallDenseSet<unsigned> pLoops(pDims.begin(), pDims.end());
+  genericOp.getReductionDims(rDims);
+  llvm::SmallDenseSet<unsigned> pDimsSet, rDimsSet;
+  pDimsSet.insert(pDims.begin(), pDims.end());
+  rDimsSet.insert(rDims.begin(), rDims.end());
 
   auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+    // Check that all indexing maps of the `genericOp`
+    // - Either both `preExpr` and `nextExpr` contiguous, or
+    // - are missing in
+    // Then `preExpr` and `nextExpr` can be collapsed.
     for (AffineMap map : genericOp.getIndexingMapsArray()) {
-      bool foundSeq = false;
+      // If map has no results, no need to check.
+      if (map.getNumResults() == 0) {
+        continue;
+      }
       for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) {
+        // If we find the preExpr, we should find the nextExpr.
+        if (resultExpr == preExpr) {
+          if (index == map.getNumResults() - 1) {
+            // Reached end of list. Return false;
+            return false;
+          }
+          if (map.getResult(index + 1) != nextExpr) {
+            return false;
+          }
+        }
+        // If we find nextExpr the previous one should be `prevExpr`.
+        // This is redundant check for the most part, but is cheap enough, so
+        // #YOLO
         if (resultExpr == nextExpr) {
-          foundSeq = (index > 0 && preExpr == map.getResult(index - 1));
-          break;
+          if (index == 0) {
+            // match at beginning of the list. Return false;
+            return false;
+          }
+          if (map.getResult(index - 1) != preExpr) {
+            return false;
+          }
         }
       }
-      if (!foundSeq)
-        return false;
     }
     return true;
   };
+  auto hasSameIteratorType = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+    unsigned prePos = preExpr.cast<AffineDimExpr>().getPosition();
+    unsigned nextPos = nextExpr.cast<AffineDimExpr>().getPosition();
+    return (pDimsSet.count(prePos) && pDimsSet.count(nextPos)) ||
+           (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
+  };
 
   ReassociationIndices range;
   AffineExpr preExpr;
+  // Find the largest sequence of dimensions that are
+  // - Either preserved in all maps, or
+  // - are completely absent
+  // This sequence can be collapsed. To find the sequence,
+  // 1) Take the result expressions of one of the indexing maps
+  // 2) Find a sequence of 2 that is found in all maps
+  // 3) Then take last element of this sequence and the next
+  //    result expression, and check if this sequence of 2 is
+  //    found in all maps. If so, add to sequence (to get a sequence of 3)
+  //    and repeat till the last element of sequence and the next result
+  //    expression is not found as a sequence in all maps.
   for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
-    unsigned pos = nextExpr.cast<AffineDimExpr>().getPosition();
     if (!range.empty()) {
-      if (!hasAllMapsSameSequence(preExpr, nextExpr) || !pLoops.count(pos)) {
-        if (range.size() > 1)
+      if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
+          !hasSameIteratorType(preExpr, nextExpr)) {
+        if (range.size() > 1) {
           contiguousLoops.push_back({range.begin(), range.end()});
+        }
         range.clear();
       }
     }
+    range.push_back(nextExpr.cast<AffineDimExpr>().getPosition());
     preExpr = nextExpr;
-    if (pLoops.count(pos))
-      range.push_back(pos);
   }
   if (range.size() > 1)
     contiguousLoops.push_back(range);
@@ -107,22 +152,6 @@
   return contiguousLoops;
 }
 
-/// Collapse possible dimension of the given linalg.generic
-static FailureOr<SmallVector<Value>>
-collapseLinalgGeneric(IRRewriter &rewriter, linalg::GenericOp genericOp,
-                      SmallVector<ReassociationIndices> &collapseIndices) {
-  rewriter.setInsertionPoint(genericOp->getParentOp());
-  FailureOr<SmallVector<Value>> replacements =
-      mlir::linalg::collapseGenericOpIterationDims(genericOp, collapseIndices,
-                                                   rewriter);
-  if (failed(replacements) || replacements->empty()) {
-    return rewriter.notifyMatchFailure(genericOp,
-                                       "failed to collapse dimensions");
-  }
-
-  return replacements;
-}
-
 /// Returns true if the given op is collapsable.
 static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
   // TODO(guray) There is no mechanism to tell the collapsed indexes to
@@ -154,101 +183,298 @@
 /// without any producers.
 static FailureOr<linalg::GenericOp>
 findRootGenericOp(DispatchRegionOp regionOp) {
-  SmallVector<Operation *> computeOps;
-  auto &ops = regionOp.getBody().front().getOperations();
-  for (Operation &op : ops) {
-    if (isa<TilingInterface>(op))
-      computeOps.push_back(&op);
+  if (!llvm::hasSingleElement(regionOp.getBody())) {
+    return failure();
   }
-  // Looking for root without producer
-  if (computeOps.size() != 1 || ops.size() != 2)
+
+  // Check the yielded value is from a single `linalg.generic`.
+  auto returnOp =
+      cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+  auto collapsibleOp = dyn_cast_or_null<linalg::GenericOp>(
+      returnOp->getOperand(0).getDefiningOp());
+  if (!collapsibleOp) {
     return failure();
-  auto genericOp = llvm::dyn_cast<linalg::GenericOp>(computeOps.front());
-  if (!genericOp)
+  }
+  for (auto returnVal : returnOp->getOperands().drop_front()) {
+    if (returnVal.getDefiningOp() != collapsibleOp.getOperation()) {
+      return failure();
+    }
+  }
+
+  // Check that the operands of the generic op are defined outside the dispatch.
+  for (OpOperand *inputOperands : collapsibleOp.getDpsInputOperands()) {
+    Operation *definingOp = inputOperands->get().getDefiningOp();
+    if (definingOp &&
+        definingOp->getParentOfType<DispatchRegionOp>() == regionOp) {
+      return failure();
+    }
+  }
+
+  // Check that the output is either a `tensor.empty` or a `linalg.fill` op by
+  // traversing the operations that define the `init` operands of the
+  // `collapsibleOp`.
+  std::deque<Operation *> worklist;
+  llvm::SmallDenseSet<Operation *> visited;
+  auto addDefiningOpToWorklist = [&](Value v) {
+    Operation *definingOp = v.getDefiningOp();
+    if (definingOp &&
+        definingOp->getParentOfType<DispatchRegionOp>() == regionOp &&
+        !visited.count(definingOp)) {
+      worklist.push_back(definingOp);
+      visited.insert(definingOp);
+    }
+  };
+  for (Value initOperand : collapsibleOp.getDpsInits()) {
+    addDefiningOpToWorklist(initOperand);
+  }
+
+  while (!worklist.empty()) {
+    Operation *op = worklist.front();
+    worklist.pop_front();
+    if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+      addDefiningOpToWorklist(fillOp.getDpsInitOperand(0)->get());
+      continue;
+    }
+    if (isa<tensor::EmptyOp>(op)) {
+      continue;
+    }
     return failure();
-  return genericOp;
+  }
+  return collapsibleOp;
 }
 
-/// Generate a new dispatch.region and workload according with the collapsed
-/// linalg Generic Op
-static LogicalResult
-generateNewDispatchRegion(IRRewriter &rewriter, DispatchRegionOp regionOp,
-                          SmallVector<Value> collapseResults,
-                          linalg::GenericOp newGenericOp) {
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(regionOp->getParentOp());
-
-  auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, newGenericOp);
-  if (failed(maybeRegionOp))
+/// Hoist `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
+/// and `tensor.expand_shape` ops at the end of the `dispatchOp`, out of the
+/// dispatch.
+static FailureOr<DispatchRegionOp>
+hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
+                                       DispatchRegionOp dispatchOp) {
+  // Only do this for `dispatchOp` with a single operation.
+  if (!llvm::hasSingleElement(dispatchOp.getBody())) {
     return failure();
+  }
+  Block &body = dispatchOp.getBody().front();
+  auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());
 
-  // Replace old regionOp with the result of collapse
-  rewriter.replaceOp(regionOp, collapseResults);
+  // 1. Get the slice of operations within `dispatchOp` that produce the yielded
+  // value.
+  BackwardSliceOptions sliceOptions;
+  sliceOptions.filter = [&](Operation *op) {
+    return op->getParentOfType<DispatchRegionOp>();
+  };
+  SetVector<Operation *> slice;
+  getBackwardSlice(returnOp, &slice, sliceOptions);
 
-  return success();
+  // 2. Get the leaf operations that are tensor.collapse_shape ops.
+  SmallVector<tensor::CollapseShapeOp> leafs;
+  for (Operation *op : slice) {
+    auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op);
+    if (!collapseShapeOp) {
+      continue;
+    }
+    if (llvm::all_of(op->getOperands(), [&](Value operand) {
+          Operation *definingOp = operand.getDefiningOp();
+          return !definingOp || slice.count(definingOp) == 0;
+        })) {
+      leafs.push_back(collapseShapeOp);
+    }
+  }
+
+  // 3. Clone the leaf `tensor.collapse_shape` ops outside the dispatch.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(dispatchOp);
+  for (auto reshapeOp : leafs) {
+    Operation *clonedOp = rewriter.clone(*reshapeOp.getOperation());
+    rewriter.replaceOp(reshapeOp, clonedOp->getResults());
+  }
+
+  // 4. From the yielded values find any that are produced by
+  //    `tensor.expand_shape` operation and move them out of the dispatch. For
+  //    this a new `DispatchRegionOp` is needed. For values that are yielded and
+  //    produced from `tensor.expand_shape`, the type of the result changes. The
+  //    dynamic dimensions of the result type also need to be updated.
+  SmallVector<Type> newReturnTypes;
+  SmallVector<Value> newDynamicDims;
+  SmallVector<Value> newYieldVals;
+  SmallVector<SmallVector<ReassociationIndices>> allReassociationIndices;
+  ValueRange dynamicDimsList = dispatchOp.getResultDims();
+  Location loc = dispatchOp.getLoc();
+  for (Value yieldedValue : returnOp->getOperands()) {
+    auto expandShapeOp = yieldedValue.getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      // 4a. Keep the same yield value if the producer is not a
+      // `tensor.expand_shape` op.
+      newReturnTypes.push_back(yieldedValue.getType());
+      newYieldVals.push_back(yieldedValue);
+      continue;
+    }
+
+    // 4b. The return type is same as the type of the source of the
+    // `tensor.expand_shape`.
+    RankedTensorType collapsedShapeType = expandShapeOp.getSrcType();
+    newReturnTypes.push_back(collapsedShapeType);
+    newYieldVals.push_back(expandShapeOp.getSrc());
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
+    ArrayRef<int64_t> expandedShape = expandShapeOp.getResultType().getShape();
+
+    // 4c. Dynamic dims of the result shape is obtained by taking the static
+    //     shape + dynamic dims and collapsing them using the same reassociation
+    //     map as the `tensor.expand_shape`.
+    for (auto [index, shape] : llvm::enumerate(collapsedShapeType.getShape())) {
+      int64_t staticCollapsedShape = 1;
+      SmallVector<OpFoldResult> dynamicCollapsedDims;
+      for (auto collapsedDim : reassociation[index]) {
+        if (expandedShape[collapsedDim] == ShapedType::kDynamic) {
+          dynamicCollapsedDims.push_back(dynamicDimsList.front());
+          dynamicDimsList = dynamicDimsList.drop_front();
+        } else {
+          staticCollapsedShape *= expandedShape[collapsedDim];
+        }
+      }
+
+      if (dynamicCollapsedDims.empty()) {
+        // If there are no dynamic dims, there is nothing to do.
+        continue;
+      }
+      SmallVector<AffineExpr> exprs(dynamicCollapsedDims.size());
+      bindSymbolsList(rewriter.getContext(),
+                      MutableArrayRef<AffineExpr>(exprs));
+      AffineExpr multiplyAll = exprs.front();
+      for (auto expr : ArrayRef<AffineExpr>(exprs).drop_front()) {
+        multiplyAll = multiplyAll * expr;
+      }
+      if (staticCollapsedShape != 1) {
+        multiplyAll = multiplyAll * staticCollapsedShape;
+      }
+      OpFoldResult collapsedShape = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, multiplyAll, dynamicCollapsedDims);
+      newDynamicDims.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, collapsedShape));
+    }
+    allReassociationIndices.emplace_back(std::move(reassociation));
+  }
+
+  // 5. Create the new dispatch op.
+  auto newDispatchOp = rewriter.create<DispatchRegionOp>(
+      loc, newReturnTypes, newDynamicDims, dispatchOp.getWorkload());
+
+  // 5a. Move the body over, but replace the `flow.return` to use the new yield
+  // values.
+  Region &newBody = newDispatchOp.getBody();
+  rewriter.inlineRegionBefore(dispatchOp.getBody(), newBody, newBody.begin());
+  {
+    Operation *terminator = newBody.front().getTerminator();
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(terminator);
+    rewriter.replaceOpWithNewOp<Flow::ReturnOp>(terminator, newYieldVals);
+  }
+
+  // 5b. Move the workgroup count region over.
+  Region &workgroupCountRegion = dispatchOp.getWorkgroupCount();
+  if (!workgroupCountRegion.empty()) {
+    Region &newWorkgroupCountRegion = newDispatchOp.getWorkgroupCount();
+    rewriter.inlineRegionBefore(workgroupCountRegion, newWorkgroupCountRegion,
+                                newWorkgroupCountRegion.begin());
+  }
+
+  // 6. Map the modified result values back to their original shape using
+  //    `tensor.expand_shape` operations.
+  ArrayRef<SmallVector<ReassociationIndices>> allReassociationIndicesRef(
+      allReassociationIndices);
+  for (auto [index, returnValue] :
+       llvm::enumerate(newDispatchOp.getResults())) {
+    Value origResult = dispatchOp->getResult(index);
+    if (returnValue.getType() == origResult.getType()) {
+      rewriter.replaceAllUsesWith(origResult, returnValue);
+      continue;
+    }
+    auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, origResult.getType(), returnValue,
+        allReassociationIndicesRef.front());
+    allReassociationIndicesRef = allReassociationIndicesRef.drop_front();
+    rewriter.replaceAllUsesWith(origResult, newExpandShapeOp.getResult());
+  }
+  rewriter.eraseOp(dispatchOp);
+  return newDispatchOp;
 }
 
 /// Traverses DispatchRegionOps to find linalg genericOps that has no
 /// producers and tries to collapse its dimensions.
-static LogicalResult collapseDimensions(IRRewriter &rewriter,
-                                        DispatchRegionOp &regionOp) {
+static bool collapseDimensions(IRRewriter &rewriter,
+                               DispatchRegionOp &regionOp) {
   // Step 1. Find the root linalg.generic Op with no producer
   std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
   if (!genericOp.has_value())
-    return success();
+    return false;
 
   // Step 2. Check whether it is possible to collapse
   if (!isEligibleForCollapse(genericOp.value()))
-    return success();
+    return false;
   SmallVector<ReassociationIndices> collapseIndices;
   collapseIndices = getCollapsibleLoops(genericOp.value());
   if (collapseIndices.empty())
-    return success();
+    return false;
 
   // Step 3. Collapse dimensions
-  auto maybeReplacements =
-      collapseLinalgGeneric(rewriter, genericOp.value(), collapseIndices);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp.value());
+
+  FailureOr<SmallVector<Value>> maybeReplacements =
+      mlir::linalg::collapseGenericOpIterationDims(genericOp.value(),
+                                                   collapseIndices, rewriter);
   if (failed(maybeReplacements))
-    return failure();
-  auto expandshapeOp =
-      maybeReplacements->front().getDefiningOp<tensor::ExpandShapeOp>();
-  if (!expandshapeOp)
-    return failure();
-  auto newGenericOp =
-      expandshapeOp.getOperand().getDefiningOp<linalg::GenericOp>();
-  if (!newGenericOp)
-    return failure();
-
-  // Step 4. Generate new dispatch region and replace old one users
-  if (failed(generateNewDispatchRegion(rewriter, regionOp, *maybeReplacements,
-                                       newGenericOp)))
-    return failure();
-
-  return success();
+    return false;
+  rewriter.replaceOp(genericOp.value(), maybeReplacements.value());
+  return true;
 }
 
 void CollapseDimensionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
-  IRRewriter rewriter(funcOp->getContext());
+  MLIRContext *context = funcOp->getContext();
+  IRRewriter rewriter(context);
 
-  auto walkResult = funcOp->walk([&](DispatchRegionOp regionOp) {
-    if (failed(collapseDimensions(rewriter, regionOp)))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
+  SmallVector<DispatchRegionOp> modifiedDispatchOps;
+  funcOp->walk([&](DispatchRegionOp dispatchOp) {
+    if (collapseDimensions(rewriter, dispatchOp)) {
+      modifiedDispatchOps.push_back(dispatchOp);
+    }
   });
-  if (walkResult.wasInterrupted()) {
-    funcOp->emitOpError("failed in collapsing dimensions pass");
-    return signalPassFailure();
-  }
 
-  RewritePatternSet canonicalizationPatterns(&getContext());
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(
-      canonicalizationPatterns);
-  tensor::populateFoldTensorEmptyPatterns(canonicalizationPatterns);
-  if (failed(applyPatternsAndFoldGreedily(
-          funcOp, std::move(canonicalizationPatterns)))) {
-    funcOp->emitOpError("failed to apply cleanup patterns");
-    return signalPassFailure();
+  LLVM_DEBUG({
+    llvm::dbgs() << "[CollapseDims] : After collapsing generic ops: \n";
+    funcOp.print(llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  // Move all the `tensor.collapse_shape` leafs  and `tensor.expand_shape` roots
+  // of the modified dispatches out of the dispatch.
+  for (auto dispatchOp : modifiedDispatchOps) {
+    Region &body = dispatchOp.getBody();
+    assert(llvm::hasSingleElement(body) && "expected op with a single body");
+    Block &block = body.front();
+    RewritePatternSet moveReshapeOps(&getContext());
+    linalg::FillOp::getCanonicalizationPatterns(moveReshapeOps, context);
+    memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
+    tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
+    SmallVector<Operation *> candidateOps;
+    block.walk([&](Operation *op) {
+      if (isa<tensor::CollapseShapeOp>(op)) {
+        candidateOps.push_back(op);
+      }
+    });
+    if (failed(
+            applyOpPatternsAndFold(candidateOps, std::move(moveReshapeOps)))) {
+      funcOp.emitOpError(
+          "failed to propagate reshape ops introduced during collapse");
+      return signalPassFailure();
+    }
+
+    if (failed(hoistTensorReshapesOutOfDispatchRegion(
+            rewriter, cast<DispatchRegionOp>(dispatchOp)))) {
+      dispatchOp->emitOpError("failed to hoist reshapes out of dispatch");
+      return signalPassFailure();
+    }
   }
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
deleted file mode 100644
index 407777b..0000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Flow {
-
-namespace {
-
-/// Check whether the given dimensions are contiguous in the result map.
-/// If non of the dimension are present in the map return true as well.
-static bool hasContiguousDims(AffineMap map, ArrayRef<unsigned> dims) {
-  if (!map.isProjectedPermutation())
-    return false;
-  llvm::SmallDenseSet<unsigned> existingDims(dims.begin(), dims.end());
-  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-    if (map.getDimPosition(i) != dims[0]) {
-      if (existingDims.count(map.getDimPosition(i))) {
-        return false;
-      }
-      continue;
-    }
-    // Check that the following dimensions are match the order of `dims`
-    for (unsigned j = 1, numDims = dims.size(); j < numDims; j++) {
-      unsigned pos = i + j;
-      if (pos >= map.getNumResults() || map.getDimPosition(pos) != dims[j]) {
-        return false;
-      }
-    }
-    break;
-  }
-  return true;
-}
-
-static SmallVector<ReassociationIndices>
-collapseDimensions(linalg::GenericOp genericOp) {
-  SmallVector<ReassociationIndices> collapseIndices;
-
-  if (!isNonNullAndOutsideDispatch(genericOp)) {
-    return collapseIndices;
-  }
-
-  SmallVector<unsigned> reductionDims;
-  genericOp.getReductionDims(reductionDims);
-  if (reductionDims.size() < 2)
-    return collapseIndices;
-
-  for (AffineMap map : genericOp.getIndexingMapsArray()) {
-    if (!hasContiguousDims(map, reductionDims))
-      return collapseIndices;
-  }
-  ReassociationIndices indices;
-  for (unsigned dim : reductionDims) {
-    indices.push_back(int64_t(dim));
-  }
-  collapseIndices.push_back(indices);
-  return collapseIndices;
-}
-
-struct CollapseDimsPass : public CollapseDimsBase<CollapseDimsPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect>();
-  }
-
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    linalg::populateCollapseDimensions(patterns, collapseDimensions);
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-
-} // namespace
-
-std::unique_ptr<Pass> createCollapseDimsPass() {
-  return std::make_unique<CollapseDimsPass>();
-}
-
-} // namespace Flow
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 20746e7..a3f1f67 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -114,10 +114,16 @@
   //      broadcast this ends up redundantly computing operations without more
   //      parallelism.
   if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
-    return linalgConsumerOp.getNumParallelLoops() ==
-               linalgConsumerOp.getNumLoops() ||
-           linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
-               .isPermutation();
+    if (linalgConsumerOp.getNumParallelLoops() ==
+        linalgConsumerOp.getNumLoops()) {
+      return true;
+    }
+    if (linalgConsumerOp.getNumReductionLoops() != 1 ||
+        !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+             .isPermutation()) {
+      return false;
+    }
+    return true;
   }
 
   // All other cases dont fuse.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index beece70..524977e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -129,7 +129,6 @@
       // Preprocess the input to a form more amenable for fusion
       .addPass(createRaiseSpecialOps)
       .addPass(createInterchangeGenericOpsPass)
-      .addPass(createCollapseDimsPass)
       .addPass(memref::createResolveShapedTypeResultDimsPass)
       .addPass(mlir::createCanonicalizerPass)
       .addPass(mlir::createCSEPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index b8db1df..16f4c1e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -239,9 +239,6 @@
 // Create a pass to split reduction dimension.
 std::unique_ptr<Pass> createSplitReductionPass();
 
-// Create a pass to collapse reduction dimensions
-std::unique_ptr<Pass> createCollapseDimsPass();
-
 //===----------------------------------------------------------------------===//
 // Module Analysis and Finalization
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 2aa8ef3..52fa8fb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -26,12 +26,6 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createCleanupNumericNarrowingPass()";
 }
 
-def CollapseDims :
-    Pass<"iree-flow-collapse-dims", ""> {
-  let summary = "Collapse reduction dimensions when possible.";
-  let constructor = "mlir::iree_compiler::IREE::Flow::createCollapseDimsPass()";
-}
-
 def Convert1X1FilterConv2DToMatmul:
     Pass<"iree-flow-convert-1x1-filter-conv2d-to-matmul", ""> {
   let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index e506ec9..b8cd4bc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -20,7 +20,6 @@
             "cleanup_numeric_narrowing.mlir",
             "cleanup_tensor_shapes.mlir",
             "clone_producers_into_dispatch_regions.mlir",
-            "collapse_reduction.mlir",
             "conv1x1_to_matmul.mlir",
             "convert_region_to_workgroups.mlir",
             "deduplicate_executables.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 9c72420..2bf5af2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -19,7 +19,6 @@
     "cleanup_tensor_shapes.mlir"
     "clone_producers_into_dispatch_regions.mlir"
     "collapse_linalg_generic_on_tensors.mlir"
-    "collapse_reduction.mlir"
     "conv1x1_to_matmul.mlir"
     "convert_region_to_workgroups.mlir"
     "deduplicate_executables.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
index 21b6fe9..5e4ec9b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-collapse-dimensions))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-collapse-dimensions, cse))" %s | FileCheck %s
 !type = tensor<2x4x8x16x32x64xf32>
 util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
 
@@ -23,14 +23,14 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
-//      CHECK-LABEL: func.func @collapse1
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
-//      CHECK: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse1
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
 
 // -----
 
@@ -58,15 +58,15 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
-//      CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-//      CHECK-LABEL: func.func @collapse2
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
-//      CHECK: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
+//       CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @collapse2
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
 
 // -----
 !type = tensor<2x4x8x16x32x64x128x256xf32>
@@ -93,14 +93,14 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//      CHECK-LABEL: func.func @collapse3
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
-//      CHECK: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @collapse3
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
 
 // -----
 
@@ -127,15 +127,15 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-//      CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-//      CHECK-LABEL: func.func @collapse4
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-//      CHECK: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//       CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse4
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
 
 // -----
 
@@ -167,18 +167,18 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-//      CHECK: #[[$MAP1:.+]] =  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
-//      CHECK: #[[$MAP2:.+]] =  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
-//      CHECK-LABEL: func.func @collapse5
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-//      CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-//      CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
-//      CHECK: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//       CHECK: #[[$MAP1:.+]] =  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
+//       CHECK: #[[$MAP2:.+]] =  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
+// CHECK-LABEL: func.func @collapse5
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+//       CHECK:   %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+//       CHECK:   %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
+//       CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
 
 // -----
 
@@ -205,15 +205,15 @@
 
 }
 
-//      CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-//      CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-//      CHECK-LABEL: func.func @collapse6
-//      CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
-//      CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
-//      CHECK: %[[RES:.+]] = flow.dispatch.region
-//      CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-//      CHECK: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
-//      CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//       CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse6
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
 
 // -----
 
@@ -239,24 +239,23 @@
   return %result: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] =  affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//       CHECK: #[[$MAP:.+]] =  affine_map<(d0, d1) -> (d1)>
+//       CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-LABEL: func.func @collapse7
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
-// CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
+//       CHECK:   %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
+//  CHECK-SAME:         ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
 
 // -----
 
 !type_in = tensor<16x4x32x2xf32>
 !type_out = tensor<8x16x4x32x8x2xf32>
-func.func @collapse8() -> !type_out {
+func.func @collapse8(%input : !type_in) -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
   // Can collapse (d3, d0, d1)
   %6 = linalg.generic { indexing_maps = [ 
@@ -272,15 +271,16 @@
   return %6: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+//       CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @collapse8
-// CHECK: %[[IN:.+]] = tensor.empty() : tensor<2048x2xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
-// CHECK:  tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
+//  CHECK-SAME:     (%[[IN:.+]]: tensor<16x4x32x2xf32>)
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1, 2], [3]{{\]}}
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+//  CHECK-SAME:         ins(%[[COLLAPSE]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
+//       CHECK:   tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
 
 // -----
 
@@ -304,7 +304,7 @@
   return %6: !type_out
 }
 // CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
+//       CHECK:   linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
 
 // -----
 
@@ -333,11 +333,11 @@
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//       CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
 // CHECK-LABEL: func.func @collapse9
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
 
 
 // -----
@@ -345,10 +345,9 @@
 !type_in = tensor<10x10x30xf32>
 !type_out = tensor<20x10x10x30x20xf32>
 
-func.func @collapse10() -> !type_out {
+func.func @collapse10(%input : !type_in) -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
 
   // Can collapse (d1, d3, d0)
@@ -364,21 +363,18 @@
   return %result: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 // CHECK-LABEL: func.func @collapse10
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
 
 // -----
 
 !type_in =  tensor<10x20xf32>
 !type_out =  tensor<10x20xf32>
 
-func.func @collapse11() -> !type_out {
+func.func @collapse11(%input : !type_in) -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
 
   // Can collapse (d1, d0)
@@ -394,10 +390,10 @@
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @collapse11
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
 
 // -----
 
@@ -420,7 +416,7 @@
 }
 
 // CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
+//       CHECK:   linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
 
 // -----
 
@@ -456,8 +452,146 @@
   return %6, %7, %8, %9  : !type,!type,!type,!type
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @collapse12
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+//       CHECK:   %[[RES:.+]] = flow.dispatch.region
+//       CHECK:     linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
 
+// -----
+
+func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  %cst = arith.constant -0.000000e+00 : f32
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
+  %1 = tensor.empty() : tensor<2x32xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %6 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<2x32xf32>
+  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
+  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
+  return %5 : !hal.buffer_view
+}
+
+// Check that we collapse dimensions.
+// CHECK-LABEL: @multi_reduce_dim(
+//   CHECK-DAG:   %[[ARG0:.+]] = hal.tensor.import
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+//       CHECK:     %[[FILL:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[EMPTY]] :
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[COLLAPSE]] :
+//  CHECK-SAME:         outs(%[[FILL]] :
+//       CHECK:      flow.return %[[GENERIC]]
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]] {{\[}}[0, 1]{{\]}}
+
+// -----
+
+// Collapsing is not supported when an input is broadcasted; we can't collapse
+// the input from tensor<4xf32> to tensor<32xf32> for example.
+
+func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
+  %empty = tensor.empty() : tensor<f32>
+  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
+    %div = arith.divf %arg2, %arg3 : f32
+    %add = arith.addf %out, %div : f32
+    linalg.yield %add : f32
+  } -> tensor<f32>
+  return %reduce : tensor<f32>
+}
+
+//     CHECK: @input_broadcast
+// CHECK-NOT: tensor.collapse_shape
+
+// -----
+
+// Do nothing if the dispatch is not a single elementwise op (with tensor.empty/linalg.fill producers)
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+module {
+  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
+    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+    %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
+      %cst_1 = arith.constant 0.000000e+00 : f32
+      %1 = tensor.empty() : tensor<1x1x4096xf32>
+      %2 = tensor.empty() : tensor<4096x32x128xf32>
+      %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
+      %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
+      ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
+        %6 = arith.extui %in : i8 to i32
+        %7 = arith.uitofp %6 : i32 to f32
+        %8 = arith.subf %7, %in_3 : f32
+        %9 = arith.mulf %8, %in_2 : f32
+        linalg.yield %9 : f32
+      } -> tensor<4096x32x128xf32>
+      %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %6 = arith.mulf %in, %in_2 : f32
+        %7 = arith.addf %6, %out : f32
+        linalg.yield %7 : f32
+      } -> tensor<1x1x4096xf32>
+      flow.return %5 : tensor<1x1x4096xf32>
+    }
+    return %0 : tensor<1x1x4096xf32>
+  }
+}
+
+// CHECK-LABEL:  func.func @quantized_matmul
+//       CHECK:    %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:      linalg.generic
+//  CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel"]
+//       CHECK:      linalg.generic
+//  CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+//       CHECK:      flow.return
+//       CHECK:    return %[[DISPATCH]]
+
+// -----
+
+module {
+  func.func @batchnorm_failure_repro(%arg0 : tensor<2x4xf32>, %arg1 : tensor<4xf32>) -> tensor<2x4xf32> {
+    %0 = tensor.empty() : tensor<2x4xf32>
+    %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg1 : tensor<2x4xf32>, tensor<4xf32>) outs(%0 : tensor<2x4xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %2 = arith.addf %b0, %b1 : f32
+        linalg.yield %2 : f32
+    } -> tensor<2x4xf32>
+    return %1 : tensor<2x4xf32>
+  }
+}
+// CHECK-LABEL: func @batchnorm_failure_repro
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         iterator_types = ["parallel", "parallel"]
+//       CHECK:     flow.return %[[GENERIC]]
+//       CHECK:   return %[[DISPATCH]]
+
+// -----
+
+module {
+  func.func @catch_invalid_collapse(%arg0 : tensor<10x20x30xf32>) -> tensor<10x30x40xf32> {
+    %0 = tensor.empty() : tensor<10x30x40xf32>
+    %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%arg0 : tensor<10x20x30xf32>) outs(%0 : tensor<10x30x40xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        linalg.yield %b0 : f32
+    } -> tensor<10x30x40xf32>
+    return %1 : tensor<10x30x40xf32>
+  }
+}
+// CHECK-LABEL: func @catch_invalid_collapse
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
deleted file mode 100644
index 5409dfb..0000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
+++ /dev/null
@@ -1,64 +0,0 @@
-// RUN: iree-opt --split-input-file -iree-flow-collapse-dims %s | FileCheck %s
-
-func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
-  %cst = arith.constant -0.000000e+00 : f32
-  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
-  %1 = tensor.empty() : tensor<2x32xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
-  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    %6 = arith.addf %arg1, %arg2 : f32
-    linalg.yield %6 : f32
-  } -> tensor<2x32xf32>
-  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
-  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
-  return %5 : !hal.buffer_view
-}
-
-// Check that we collapse dimensions.
-// CHECK: @multi_reduce_dim
-// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction"]
-
-// -----
-
-// Collapsing is not supported when an input is broadcasted; we can't collapse
-// the input from tensor<4xf32> to tensor<32xf32> for example.
-
-func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
-  %empty = tensor.empty() : tensor<f32>
-  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
-  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
-    %div = arith.divf %arg2, %arg3 : f32
-    %add = arith.addf %out, %div : f32
-    linalg.yield %add : f32
-  } -> tensor<f32>
-  return %reduce : tensor<f32>
-}
-
-// CHECK: @input_broadcast
-// CHECK-NOT: tensor.collapse_shape
-
-// -----
-
-// Collapsing should not happen to ops in flow.dispatch.region or flow.dispatch.workgroups
-
-func.func @multi_reduce_dim_dispatch(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
-  %cst = arith.constant -0.000000e+00 : f32
-  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
-  %1 = tensor.empty() : tensor<2x32xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
-  %3 = flow.dispatch.region -> (tensor<2x32xf32>) {
-    %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):
-      %7 = arith.addf %arg1, %arg2 : f32
-      linalg.yield %7 : f32
-    } -> tensor<2x32xf32>
-    flow.return %6 : tensor<2x32xf32>
-  }
-  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
-  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
-  return %5 : !hal.buffer_view
-}
-
-// CHECK: @multi_reduce_dim_dispatch
-// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction", "reduction"]