Clean up Flow pipeline. (#8395)

Reorders the passes added before dispatch region formation so that each
pass is followed by canonicalization + CSE. Also, instead of using a
separate pass to fold tensor.dim operations, apply the folding patterns
directly in the passes where tensor.dim ops could be introduced. That way
the tensor.dim operations may be folded away as soon as they are created.
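
As a rough sketch of the second point (not part of the patch; the helper
below is hypothetical, only the
memref::populateResolveRankedShapeTypeResultDimsPatterns call is the
upstream MLIR API the diff relies on), the idea is to mix the tensor.dim
resolution patterns into whatever pattern set may introduce tensor.dim
ops, so they get folded in the same greedy application instead of by a
separate memref::createResolveShapedTypeResultDimsPass run later:

#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: apply the pass-specific rewrites together with the
// tensor.dim resolution patterns in one greedy application, so tensor.dim
// ops created by the rewrites never survive to a later cleanup pass.
static mlir::LogicalResult applyWithDimResolution(
    mlir::Operation *op, mlir::RewritePatternSet &&patterns) {
  mlir::memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}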
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
index 534e9b8..881e52e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
@@ -113,6 +114,7 @@
                 LinalgTensorReshapeToFlowTensorReshape<tensor::ExpandShapeOp>>(
             context);
     populateTensorToFlowPatternsBeforeDispatchFormation(context, patterns);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
@@ -137,6 +139,7 @@
 
     patterns.insert<LinalgFillToFlowTensorSplat>(context);
     populateTensorToFlowPatternsAfterDispatchFormation(context, patterns);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 9bf008f..fb1655a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -923,6 +924,8 @@
     RewritePatternSet canonicalizationPatterns(context);
     linalg::populateLinalgTilingCanonicalizationPatterns(
         canonicalizationPatterns);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(
+        canonicalizationPatterns);
     if (failed(applyPatternsAndFoldGreedily(
             funcOp, std::move(canonicalizationPatterns)))) {
       return failure();
@@ -935,24 +938,6 @@
     llvm::dbgs() << "\n\n";
   });
 
-  // Run necessary canonicalization patterns before rewrite destructive updates.
-  {
-    RewritePatternSet patterns(context);
-    // Resolve `tensor.dim` of result of operations into operations on its
-    // operands using the `ReifyRankedShapedTypeOpInterface`.
-    memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
-    // This is needed because tiling and distribution may create
-    // subtensor_insert ops whose source operands come from tensor.cast ops.
-    // Those tensor.cast ops cast tensors into a more dynamic shape, in order
-    // to guarantee type match during transformation. Later in destructive
-    // update subtensor_insert ops will be turned into flow dispatch output
-    // store ops.
-    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
-    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-      return failure();
-    }
-  }
-
   // After outlining in dispatch region we can rewrite the dispatch ops with
   // proper captures to make it isolated from above.
   if (funcOp
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index c1e6a80..cb30e9a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -122,6 +123,7 @@
         linalg::LinalgElementwiseFusionOptions()
             .setControlFoldingReshapes(foldReshapeBetweenLinalgFn)
             .setControlElementwiseOpsFusionFn(controlFn));
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
 
     if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
                                             std::move(fusionPatterns)))) {
@@ -139,6 +141,7 @@
                                                       context);
     linalg::FillOp::getCanonicalizationPatterns(reshapeCanonicalizations,
                                                 context);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(reshapeCanonicalizations);
     if (failed(applyPatternsAndFoldGreedily(
             op->getRegions(), std::move(reshapeCanonicalizations)))) {
       return signalPassFailure();
@@ -154,6 +157,7 @@
     linalg::InitTensorOp::getCanonicalizationPatterns(pushReshapePatterns,
                                                       context);
     linalg::FillOp::getCanonicalizationPatterns(pushReshapePatterns, context);
+    memref::populateResolveRankedShapeTypeResultDimsPatterns(pushReshapePatterns);
     if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
                                             std::move(pushReshapePatterns)))) {
       return signalPassFailure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index e536c6d..f743187 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -152,33 +152,40 @@
       // Pad tensors.
       .addPass(IREE::Flow::createPadTensorToSubTensorInsertPass)
 
-      // Elementwise, fusion, tiling and distribution.
+      // Preprocess the input to a form more amenable to fusion:
+      // - Convert all elementwise ops to Linalg
+      // - Remove unit-extent dimensions.
       .addPass(mlir::createConvertElementwiseToLinalgPass)
       .addPass(mlir::createLinalgFoldUnitExtentDimsPass)
-      .addPass(IREE::Flow::createInterchangeGenericOpsPass)
-      .addPass(mlir::createCanonicalizerPass)
+      .addPass(createInterchangeGenericOpsPass)
       .addPass(memref::createResolveShapedTypeResultDimsPass)
-
-      // Fusion.
-      .addPass(IREE::Flow::createFusionOfTensorOpsPass)
+      .addPass(mlir::createCanonicalizerPass)
       .addPass(mlir::createCSEPass)
+
+      // Elementwise fusion.
+      .addPass(createFusionOfTensorOpsPass)
       .addPredicatedPass(clEnableLinalgDetensorize,
                          mlir::createLinalgDetensorizePass)
 
-      // Dispatch region formation.
-      .addPass(IREE::Flow::createConvertToFlowBeforeDispatchFormation)
       .addPass(mlir::createCanonicalizerPass)
-      .addPass(IREE::Flow::createDispatchLinalgOnTensorsPass)
-      .addPass(memref::createResolveShapedTypeResultDimsPass)
-      .addPass(IREE::Flow::createCaptureDispatchDynamicDimsPass)
-      .addPass(IREE::Flow::createConvertToFlowAfterDispatchFormation)
-      .addPass(mlir::createCanonicalizerPass)
-      .addPass(memref::createResolveShapedTypeResultDimsPass)
+      .addPass(mlir::createCSEPass)
 
-      // NOTE: required because the current dispatch-linalg-on-tensors pass
-      // creates a lot of dead IR that needs to be cleaned up.
-      .addPass(IREE::Flow::createConvertToFlowAfterDispatchFormation)
-      .addPass(mlir::createCanonicalizerPass);
+      // Dispatch region formation.
+      // TODO(ravishankarm): Fold ConvertToFlowBefore/ConvertToFlowAfter into
+      // dispatch region formation pass.
+      .addPass(createConvertToFlowBeforeDispatchFormation)
+      .addPass(mlir::createCanonicalizerPass)
+      .addPass(mlir::createCSEPass)
+      .addPass(createDispatchLinalgOnTensorsPass)
+      .addPass(createCaptureDispatchDynamicDimsPass)
+      .addPass(mlir::createCanonicalizerPass)
+      .addPass(mlir::createCSEPass)
+
+      // Convert remaining ops to Flow ops; after this stage no Linalg ops
+      // should remain.
+      .addPass(createConvertToFlowAfterDispatchFormation)
+      .addPass(mlir::createCanonicalizerPass)
+      .addPass(mlir::createCSEPass);
 
   // Module pass to outline the dispatch regions into their own functions
   // wrapped in executables.
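
For reference, the ordering Passes.cpp converges on is the common
"run a pass, then canonicalize, then CSE" cleanup idiom. Below is a
minimal sketch of that idiom with a plain MLIR OpPassManager; the actual
file builds the pipeline through IREE's own pass-nest helpers, so the
addPassWithCleanup helper is purely illustrative:

#include <memory>

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Illustrative helper: run the given pass, then immediately clean up the
// IR it produced with canonicalization + CSE so the next pass in the
// pipeline starts from simplified IR.
static void addPassWithCleanup(mlir::OpPassManager &pm,
                               std::unique_ptr<mlir::Pass> pass) {
  pm.addPass(std::move(pass));
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
}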