[Codegen] Add patterns for folding away no-op slices (#18419)
Adds patterns to OptimizeTensorInsertExtractSlices that fold away no-op
`tensor.extract_slice`/`tensor.insert_slice` ops. The patterns call an
upstream utility that uses the ValueBoundsOpInterface to determine whether
a slice spans its entire source/destination and is therefore a no-op.
This is kept out of the static canonicalization patterns because the
ValueBoundsOpInterface analysis can be quite expensive: it may walk
arbitrarily far up use-def chains.
The folding is gated behind a pass option because some other pipelines
are sensitive to the exact insert/extract_slice structure.
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
index 82e074e..ac9334b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
@@ -29,6 +31,10 @@
class OptimizeTensorInsertExtractSlicesPass final
: public impl::OptimizeTensorInsertExtractSlicesPassBase<
OptimizeTensorInsertExtractSlicesPass> {
+ using impl::OptimizeTensorInsertExtractSlicesPassBase<
+ OptimizeTensorInsertExtractSlicesPass>::
+ OptimizeTensorInsertExtractSlicesPassBase;
+
public:
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<scf::SCFDialect, vector::VectorDialect>();
@@ -200,6 +206,38 @@
}
}
+namespace {
+struct CastLikeExtractSliceOpFolder final
+ : OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!tensor::isCastLikeExtractSliceOp(sliceOp) ||
+ sliceOp.getSourceType() != sliceOp.getResultType()) {
+ return failure();
+ }
+ rewriter.replaceOp(sliceOp, sliceOp.getSource());
+ return success();
+ }
+};
+
+struct CastLikeInsertSliceOpFolder final
+ : OpRewritePattern<tensor::InsertSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!tensor::isCastLikeInsertSliceOp(sliceOp) ||
+ sliceOp.getSourceType() != sliceOp.getResultType()) {
+ return failure();
+ }
+ rewriter.replaceOp(sliceOp, sliceOp.getSource());
+ return success();
+ }
+};
+} // namespace
+
void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {
auto funcOp = getOperation();
linalg::hoistRedundantVectorTransfers(cast<func::FuncOp>(funcOp));
@@ -223,6 +261,10 @@
populateVectorTransferTensorSliceTransforms(patterns);
scf::ForOp::getCanonicalizationPatterns(patterns, context);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+ if (foldIdentitySlices) {
+ patterns.add<CastLikeExtractSliceOpFolder>(context);
+ patterns.add<CastLikeInsertSliceOpFolder>(context);
+ }
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index e1c58ab..eb31e94 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -301,6 +301,10 @@
"mlir::FunctionOpInterface"> {
let summary = "Optimize tensor.insert_slice/tensor.extract_slice operations "
"(e.g. hoist and fold)";
+ let options = [
+ Option<"foldIdentitySlices", "fold-identity-slices", "bool", "false",
+ "Enable folding of identity tensor.*_slice ops.">
+ ];
}
def HoistUnrolledVectorExtractInsertSlicePass :
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index 5b072cb..dabd285 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-optimize-tensor-insert-extract-slices))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-optimize-tensor-insert-extract-slices{fold-identity-slices=true}))" --split-input-file %s | FileCheck %s
func.func @fold_extract_slice_consumer_into_xfer_write(%arg0: vector<1x64x128xf16>, %arg1: index) -> tensor<1x?x128xf16> {
%c0 = arith.constant 0 : index
@@ -308,3 +308,16 @@
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK-NOT: tensor.insert_slice
+
+// -----
+
+func.func @fold_identity_extract_slice(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %slice = tensor.extract_slice %arg0[0][%dim][1] : tensor<?xf32> to tensor<?xf32>
+ return %slice : tensor<?xf32>
+}
+
+// CHECK-LABEL: @fold_identity_extract_slice
+// CHECK: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK: return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d33ba8b..f1c7ac1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -395,6 +395,12 @@
// hoisting and fusion pass, as well as a lack of a fallback distribution
// pass.
funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+ {
+ OptimizeTensorInsertExtractSlicesPassOptions options;
+ options.foldIdentitySlices = true;
+ funcPassManager.addPass(
+ createOptimizeTensorInsertExtractSlicesPass(options));
+ }
// Step 5. Greedily fuse parallel loops and hoist from serial loops.
funcPassManager.addPass(IREE::GPU::createFuseAndHoistParallelLoopsPass());