[Flow] Add pattern to canonicalize away full tensor.insert_slice ops (#18941)

Additionally, this drops the pad-combining pattern from this pass because
it has been upstreamed.
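
For example, an insert_slice that fully overwrites its destination is a
no-op; in the IR below (adapted from the new tests), %insert canonicalizes
to just %source:

  %empty = tensor.empty(%size) : tensor<8x?xf32>
  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1]
      : tensor<8x?xf32> into tensor<8x?xf32>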
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
index b19bccc..1978d77 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -18,80 +19,66 @@
namespace {
-/// Folds a chain of `tensor.pad` ops with the same constant padding value.
-///
-/// Example:
-///
-/// ```mlir
-/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
-/// tensor.yield %val
-/// } : tensor<1x2xf32> to tensor<2x5xf32>
-/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
-/// tensor.yield %val
-/// } : tensor<1x5xf32> to tensor<5x7xf32>
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
-/// tensor.yield %val
-/// } : tensor<1x2xf32> to tensor<5x7xf32>
-/// ```
-///
-/// NOTE: This wasn't sent upstream as a canonicalization due to the use of
-/// the Affine dialect.
-struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
- using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
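+/// Returns the mixed sizes of the tensor produced by a `tensor.empty` or a
+/// non-rank-reducing `tensor.extract_slice`, and std::nullopt for any other
+/// producer.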
+static std::optional<SmallVector<OpFoldResult>> getDefiningMixedSizes(Value v) {
+ if (auto empty = v.getDefiningOp<tensor::EmptyOp>()) {
+ return empty.getMixedSizes();
+ } else if (auto extract = v.getDefiningOp<tensor::ExtractSliceOp>()) {
+ // TODO: Support rank-reducing cases.
+ if (extract.getSourceType().getRank() !=
+ extract.getResultType().getRank()) {
+ return {};
+ }
+ return extract.getMixedSizes();
+ }
+ return {};
+}
- LogicalResult matchAndRewrite(tensor::PadOp padOp,
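+/// Folds a `tensor.insert_slice` op that fully overwrites its destination,
+/// i.e. one with zero offsets, unit strides, and sizes equal to the sizes
+/// of the destination, by replacing it with its source.
+///
+/// Example:
+///
+/// ```mlir
+/// %empty = tensor.empty(%size) : tensor<8x?xf32>
+/// %res = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1]
+///     : tensor<8x?xf32> into tensor<8x?xf32>
+/// ```
+///
+/// folds into `%source`.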
+struct FoldFullInsertSlice : public OpRewritePattern<tensor::InsertSliceOp> {
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override {
- if (padOp.getNofold()) {
- return failure();
+ if (!insertSliceOp.hasUnitStride() || !insertSliceOp.hasZeroOffset()) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "non-unit stride or non-zero offset.");
}
- auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
- if (!producerPad || producerPad.getNofold()) {
+
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ RankedTensorType resultType = insertSliceOp.getResultType();
+ if (sourceType != resultType) {
return rewriter.notifyMatchFailure(
- padOp, "producer is not a foldable tensor.pad op");
+ insertSliceOp,
+ "unimplemented: Cast-like or reshape-like insert ops.");
}
- // Fail if the tensor::PadOps padding values do not match.
- Value consumerPadValue = padOp.getConstantPaddingValue();
- Value producerPadValue = producerPad.getConstantPaddingValue();
- if (!consumerPadValue || !producerPadValue ||
- consumerPadValue != producerPadValue) {
+ std::optional<SmallVector<OpFoldResult>> mixedSizes =
+ getDefiningMixedSizes(insertSliceOp.getDest());
+ if (!mixedSizes) {
return rewriter.notifyMatchFailure(
- padOp, "cannot fold PadOps with different padding values");
+ insertSliceOp, "Could not find producer with list of tensor sizes.");
}
- Location loc = padOp.getLoc();
- AffineExpr d0, d1;
- bindDims(rewriter.getContext(), d0, d1);
-
- // Combine the low/high paddings of the two tensor::PadOps.
- auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
- ArrayRef<OpFoldResult> producerPaddings) {
- SmallVector<OpFoldResult> sumPaddings;
- for (auto [consumerIndex, producerIndex] :
- llvm::zip_equal(consumerPaddings, producerPaddings)) {
- sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
- rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
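+ // The insert is full only if every inserted size matches the
+ // corresponding destination size.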
+ for (auto [insertSize, destSize] :
+ llvm::zip_equal(insertSliceOp.getMixedSizes(), mixedSizes.value())) {
+ if (isa<Value>(insertSize) || isa<Value>(destSize)) {
+ if (insertSize != destSize) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "dynamic size mismatch");
+ }
+ continue;
}
- return sumPaddings;
- };
- SmallVector<OpFoldResult> newHighPad =
- addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
- SmallVector<OpFoldResult> newLowPad =
- addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
+ // `getMixedSizes` for different ops returns different attribute types
+ // (`index` or `i64`) so we compare the values of the ints directly here.
+ int64_t staticInsertSize = getConstantIntValue(insertSize).value();
+ int64_t staticDestSize = getConstantIntValue(destSize).value();
+ if (staticInsertSize != staticDestSize) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "static size mismatch");
+ }
+ }
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
- newLowPad, newHighPad, padOp.getNofold(),
- getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
- rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
- newPadOp.getRegion().begin());
- rewriter.replaceOp(padOp, newPadOp.getResult());
+ rewriter.replaceOp(insertSliceOp, insertSliceOp.getSource());
return success();
}
};
@@ -117,7 +104,7 @@
// Pull in some borderline/downstream canonicalizations for the Flow
// compilation phase.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
- owningPatterns.add<FoldConsecutiveConstantPadding>(context);
+ owningPatterns.add<FoldFullInsertSlice>(context);
patterns =
std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
index 81203a5..8734b85 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
@@ -1,84 +1,56 @@
// RUN: iree-opt --iree-flow-canonicalize %s --split-input-file --mlir-print-local-scope | FileCheck %s
-util.func public @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
- %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
- ^bb0(%b0: index, %b1 : index):
- tensor.yield %pad_value : f32
- } : tensor<2x3xf32> to tensor<4x4xf32>
- %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
- ^bb0(%b2: index, %b3 : index):
- tensor.yield %pad_value : f32
- } : tensor<4x4xf32> to tensor<7x8xf32>
- util.return %pad1 : tensor<7x8xf32>
+util.func public @fold_full_insert_into_extract(
+ %source: tensor<8x?xf32>,
+ %dest: tensor<10x?xf32>,
+ %size: index) -> tensor<8x?xf32> {
+ %extract = tensor.extract_slice %dest [1, 1] [8, %size] [1, 1] : tensor<10x?xf32> to tensor<8x?xf32>
+ %insert = tensor.insert_slice %source into %extract [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+ util.return %insert : tensor<8x?xf32>
}
-// CHECK-LABEL: util.func public @merge_constant_padding
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
-// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
-// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
-// CHECK: tensor.yield %[[PADVAL]]
-// CHECK: util.return %[[PAD]]
+
+// CHECK-LABEL: util.func public @fold_full_insert_into_extract
+// CHECK-SAME: %[[SOURCE:.+]]: tensor<8x?xf32>
+// CHECK: util.return %[[SOURCE]]
// -----
-util.func public @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
- %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
- ^bb0(%b0: index, %b1 : index):
- tensor.yield %pad_value : f32
- } : tensor<?x?xf32> to tensor<?x?xf32>
- %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
- ^bb0(%b2: index, %b3 : index):
- tensor.yield %pad_value : f32
- } : tensor<?x?xf32> to tensor<?x?xf32>
- util.return %pad1 : tensor<?x?xf32>
+util.func public @fold_full_insert_into_empty(
+ %source: tensor<8x?xf32>,
+ %size: index) -> tensor<8x?xf32> {
+ %empty = tensor.empty(%size) : tensor<8x?xf32>
+ %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+ util.return %insert : tensor<8x?xf32>
}
-// CHECK-LABEL: util.func public @merge_constant_padding_dynamic
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
-// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
-// CHECK: %[[HIGH:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[IDX]]]
-// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
-// CHECK: tensor.yield %[[PADVAL]]
-// CHECK: util.return %[[PAD]]
+
+// CHECK-LABEL: util.func public @fold_full_insert_into_empty
+// CHECK-SAME: %[[SOURCE:.+]]: tensor<8x?xf32>
+// CHECK: util.return %[[SOURCE]]
// -----
-util.func public @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
- %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
- ^bb0(%b0: index, %b1 : index):
- tensor.yield %pad_value : f32
- } : tensor<2x3xf32> to tensor<4x4xf32>
- %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
- ^bb0(%b2: index, %b3 : index):
- tensor.yield %pad_value : f32
- } : tensor<4x4xf32> to tensor<7x8xf32>
- util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_insert_into_empty(
+ %source: tensor<8x?xf32>,
+ %size1: index, %size2: index) -> tensor<8x?xf32> {
+ %empty = tensor.empty(%size1) : tensor<8x?xf32>
+ %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size2] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+ util.return %insert : tensor<8x?xf32>
}
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_nofold
-// CHECK: tensor.pad
-// CHECK: tensor.pad {{.*}} nofold
+// CHECK-LABEL: util.func public @dont_fold_not_full_insert_into_empty
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: util.return %[[INSERT]]
// -----
-util.func public @dont_merge_constant_padding_different_vals(
- %arg0: tensor<2x3xf32>,
- %pad_value0: f32,
- %pad_value1: f32) -> tensor<7x8xf32> {
- %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
- ^bb0(%b0: index, %b1 : index):
- tensor.yield %pad_value0 : f32
- } : tensor<2x3xf32> to tensor<4x4xf32>
- %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
- ^bb0(%b2: index, %b3 : index):
- tensor.yield %pad_value1 : f32
- } : tensor<4x4xf32> to tensor<7x8xf32>
- util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_static_insert_into_empty(
+ %source: tensor<8x?xf32>,
+ %size: index) -> tensor<10x?xf32> {
+ %empty = tensor.empty(%size) : tensor<10x?xf32>
+ %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<10x?xf32>
+ util.return %insert : tensor<10x?xf32>
}
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
-// CHECK: tensor.pad
-// CHECK: tensor.pad
+// CHECK-LABEL: util.func public @dont_fold_not_full_static_insert_into_empty
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: util.return %[[INSERT]]