Cherry pick D139308 (#11454)
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 25180bf..a934afa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -34,6 +34,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -413,6 +414,7 @@
// casting ops into tiled operations.
RewritePatternSet patterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+ tensor::populateFoldTensorEmptyPatterns(patterns);
populateFoldAffineMinInDistributedLoopsPatterns(patterns);
context->getOrLoadDialect<IREE::LinalgExt::IREELinalgExtDialect>()
->getCanonicalizationPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
index 94db542..ab9fe2c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -114,6 +114,7 @@
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TosaDialect",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index b22b201..86649a8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -94,6 +94,7 @@
MLIRSCFDialect
MLIRSupport
MLIRTensorDialect
+ MLIRTensorTransforms
MLIRTensorUtils
MLIRTilingInterface
MLIRTosaDialect
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index f446903..d306e85 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -265,6 +266,7 @@
linalg::GenericOp::getCanonicalizationPatterns(fusionPatterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(fusionPatterns,
context);
+ tensor::populateFoldTensorEmptyPatterns(fusionPatterns);
tensor::CollapseShapeOp::getCanonicalizationPatterns(fusionPatterns,
context);
context->getLoadedDialect<linalg::LinalgDialect>()
@@ -311,6 +313,7 @@
collapsingReshapePatterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
collapsingReshapePatterns, context);
+ tensor::populateFoldTensorEmptyPatterns(collapsingReshapePatterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(
collapsingReshapePatterns);
if (failed(applyPatternsAndFoldGreedily(
@@ -326,6 +329,17 @@
});
}
+ // Run some patterns that fold away a few operations.
+ {
+ RewritePatternSet opFoldingPatterns(&getContext());
+ tensor::populateFoldTensorEmptyPatterns(opFoldingPatterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp->getRegions(),
+ std::move(opFoldingPatterns)))) {
+ funcOp->emitError("failed to apply op folding patterns");
+ return signalPassFailure();
+ }
+ }
+
if (fuseMultiUse) {
// Run fusion of producer with consumer when producer has multiple uses.
// For now run this sequence a fixed times (2 by default). Ideally we
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 6b73d38..1821fd3 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -418,6 +418,7 @@
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TransformUtils",
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
index ee1ca25..f02241a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -29,5 +29,6 @@
MLIRLinalgTransforms
MLIRPass
MLIRSCFDialect
+ MLIRTensorTransforms
MLIRTransforms
)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index 619345a..b335776 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -567,6 +568,7 @@
RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ tensor::populateFoldTensorEmptyPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
return signalPassFailure();
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 279d294..941b064 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 279d294d26c39e86dd7baabf5cd3385676d9a7a4
+Subproject commit 941b064ffe422cb0f68c08369c2406bada649a6a