Clean up the Flow pass pipeline. (#8395)

Reorder the passes added before dispatch region formation to follow a
pass + canonicalization + CSE sequence. Also, instead of running a
separate pass to fold tensor.dim operations, apply the folding patterns
directly in the passes where tensor.dim operations could be introduced,
so those operations are folded away at the point they appear.
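As a minimal sketch of the mechanical change applied in each pass below
(runCleanupPatterns is a hypothetical stand-in for the pass bodies in
this diff; the populate/apply entry points are the MLIR ones the diff
already uses, assuming the MLIR revision this patch targets, where
FuncOp still lives in BuiltinOps.h):

#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Append the tensor.dim resolution patterns to the pattern set a pass
// already applies greedily, instead of scheduling a separate
// resolve-shaped-type-result-dims pass afterwards.
static LogicalResult runCleanupPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // ... the pass's own rewrite patterns are populated here ...
  // Resolve tensor.dim of an op's results into tensor.dim of its operands
  // via ReifyRankedShapedTypeOpInterface, folding tensor.dim ops at the
  // point where they would otherwise be introduced.
  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}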
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
index 534e9b8..881e52e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
@@ -113,6 +114,7 @@
LinalgTensorReshapeToFlowTensorReshape<tensor::ExpandShapeOp>>(
context);
populateTensorToFlowPatternsBeforeDispatchFormation(context, patterns);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
@@ -137,6 +139,7 @@
patterns.insert<LinalgFillToFlowTensorSplat>(context);
populateTensorToFlowPatternsAfterDispatchFormation(context, patterns);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 9bf008f..fb1655a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -923,6 +924,8 @@
RewritePatternSet canonicalizationPatterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(
canonicalizationPatterns);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(
+ canonicalizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(canonicalizationPatterns)))) {
return failure();
@@ -935,24 +938,6 @@
llvm::dbgs() << "\n\n";
});
- // Run necessary canonicalization patterns before rewrite destructive updates.
- {
- RewritePatternSet patterns(context);
- // Resolve `tensor.dim` of result of operations into operations on its
- // operands using the `ReifyRankedShapedTypeOpInterface`.
- memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
- // This is needed because tiling and distribution may create
- // subtensor_insert ops whose source operands come from tensor.cast ops.
- // Those tensor.cast ops cast tensors into a more dynamic shape, in order
- // to guarantee type match during transformation. Later in destructive
- // update subtensor_insert ops will be turned into flow dispatch output
- // store ops.
- tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return failure();
- }
- }
-
// After outlining in dispatch region we can rewrite the dispatch ops with
// proper captures to make it isolated from above.
if (funcOp
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index c1e6a80..cb30e9a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
@@ -122,6 +123,7 @@
linalg::LinalgElementwiseFusionOptions()
.setControlFoldingReshapes(foldReshapeBetweenLinalgFn)
.setControlElementwiseOpsFusionFn(controlFn));
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns);
if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
std::move(fusionPatterns)))) {
@@ -139,6 +141,7 @@
context);
linalg::FillOp::getCanonicalizationPatterns(reshapeCanonicalizations,
context);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(reshapeCanonicalizations);
if (failed(applyPatternsAndFoldGreedily(
op->getRegions(), std::move(reshapeCanonicalizations)))) {
return signalPassFailure();
@@ -154,6 +157,7 @@
linalg::InitTensorOp::getCanonicalizationPatterns(pushReshapePatterns,
context);
linalg::FillOp::getCanonicalizationPatterns(pushReshapePatterns, context);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(pushReshapePatterns);
if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
std::move(pushReshapePatterns)))) {
return signalPassFailure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index e536c6d..f743187 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -152,33 +152,40 @@
// Pad tensors.
.addPass(IREE::Flow::createPadTensorToSubTensorInsertPass)
- // Elementwise, fusion, tiling and distribution.
+ // Preprocess the input to a form more amenable to fusion:
+ // - Convert all elementwise ops to Linalg
+ // - Remove unit-extent dimensions.
.addPass(mlir::createConvertElementwiseToLinalgPass)
.addPass(mlir::createLinalgFoldUnitExtentDimsPass)
- .addPass(IREE::Flow::createInterchangeGenericOpsPass)
- .addPass(mlir::createCanonicalizerPass)
+ .addPass(createInterchangeGenericOpsPass)
.addPass(memref::createResolveShapedTypeResultDimsPass)
-
- // Fusion.
- .addPass(IREE::Flow::createFusionOfTensorOpsPass)
+ .addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
+
+ // Elementwise fusion.
+ .addPass(createFusionOfTensorOpsPass)
.addPredicatedPass(clEnableLinalgDetensorize,
mlir::createLinalgDetensorizePass)
- // Dispatch region formation.
- .addPass(IREE::Flow::createConvertToFlowBeforeDispatchFormation)
.addPass(mlir::createCanonicalizerPass)
- .addPass(IREE::Flow::createDispatchLinalgOnTensorsPass)
- .addPass(memref::createResolveShapedTypeResultDimsPass)
- .addPass(IREE::Flow::createCaptureDispatchDynamicDimsPass)
- .addPass(IREE::Flow::createConvertToFlowAfterDispatchFormation)
- .addPass(mlir::createCanonicalizerPass)
- .addPass(memref::createResolveShapedTypeResultDimsPass)
+ .addPass(mlir::createCSEPass)
- // NOTE: required because the current dispatch-linalg-on-tensors pass
- // creates a lot of dead IR that needs to be cleaned up.
- .addPass(IREE::Flow::createConvertToFlowAfterDispatchFormation)
- .addPass(mlir::createCanonicalizerPass);
+ // Dispatch region formation.
+ // TODO(ravishankarm): Fold ConvertToFlowBefore/ConvertToFlowAfter into
+ // dispatch region formation pass.
+ .addPass(createConvertToFlowBeforeDispatchFormation)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+ .addPass(createDispatchLinalgOnTensorsPass)
+ .addPass(createCaptureDispatchDynamicDimsPass)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+ // Convert remaining ops to Flow ops; after this stage, no Linalg ops
+ // should remain.
+ .addPass(createConvertToFlowAfterDispatchFormation)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
// Module pass to outline the dispatch regions into their own functions
// wrapped in executables.