[Flow] Add pattern to canonicalize away full tensor.insert_slice ops (#18941)

Additionally drops the `FoldConsecutiveConstantPadding` pattern from this
pass because it has been upstreamed.
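
For reference, here is a minimal before/after sketch of the new fold,
mirroring the tests added below: when an insert_slice has zero offsets,
unit strides, and sizes equal to those of its destination, it overwrites
the whole destination and can be replaced by its source.

```mlir
// Before: the insert covers every element of %empty.
%empty = tensor.empty(%size) : tensor<8x?xf32>
%insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1]
    : tensor<8x?xf32> into tensor<8x?xf32>

// After canonicalization: uses of %insert are replaced by %source and the
// tensor.empty becomes dead.
```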
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
index b19bccc..1978d77 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -18,80 +19,66 @@
 
 namespace {
 
-/// Folds a chain of `tensor.pad` ops with the same constant padding value.
-///
-/// Example:
-///
-/// ```mlir
-///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
-///       tensor.yield %val
-///     } : tensor<1x2xf32> to tensor<2x5xf32>
-///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
-///       tensor.yield %val
-///     } : tensor<1x5xf32> to tensor<5x7xf32>
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
-///       tensor.yield %val
-///     } : tensor<1x2xf32> to tensor<5x7xf32>
-/// ```
-///
-/// NOTE: This wasn't sent upstream as a canonicalization due to the use of
-/// the Affine dialect.
-struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
-  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+static std::optional<SmallVector<OpFoldResult>> getDefiningMixedSizes(Value v) {
+  if (auto empty = v.getDefiningOp<tensor::EmptyOp>()) {
+    return empty.getMixedSizes();
+  } else if (auto extract = v.getDefiningOp<tensor::ExtractSliceOp>()) {
+    // TODO: Support rank reducing cases.
+    if (extract.getSourceType().getRank() !=
+        extract.getResultType().getRank()) {
+      return {};
+    }
+    return extract.getMixedSizes();
+  }
+  return {};
+}
 
-  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+struct FoldFullInsertSlice : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                 PatternRewriter &rewriter) const override {
-    if (padOp.getNofold()) {
-      return failure();
+    if (!insertSliceOp.hasUnitStride() || !insertSliceOp.hasZeroOffset()) {
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "non-unit stride or non-zero offset.");
     }
-    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
-    if (!producerPad || producerPad.getNofold()) {
+
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    RankedTensorType resultType = insertSliceOp.getResultType();
+    if (sourceType != resultType) {
       return rewriter.notifyMatchFailure(
-          padOp, "producer is not a foldable tensor.pad op");
+          insertSliceOp,
+          "unimplemented: Cast-like or reshape-like insert ops.");
     }
 
-    // Fail if the tensor::PadOps padding values do not match.
-    Value consumerPadValue = padOp.getConstantPaddingValue();
-    Value producerPadValue = producerPad.getConstantPaddingValue();
-    if (!consumerPadValue || !producerPadValue ||
-        consumerPadValue != producerPadValue) {
+    std::optional<SmallVector<OpFoldResult>> mixedSizes =
+        getDefiningMixedSizes(insertSliceOp.getDest());
+    if (!mixedSizes) {
       return rewriter.notifyMatchFailure(
-          padOp, "cannot fold PadOps with different padding values");
+          insertSliceOp, "Could not find producer with list of tensor sizes.");
     }
 
-    Location loc = padOp.getLoc();
-    AffineExpr d0, d1;
-    bindDims(rewriter.getContext(), d0, d1);
-
-    // Combine the low/high paddings of the two tensor::PadOps.
-    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
-                           ArrayRef<OpFoldResult> producerPaddings) {
-      SmallVector<OpFoldResult> sumPaddings;
-      for (auto [consumerIndex, producerIndex] :
-           llvm::zip_equal(consumerPaddings, producerPaddings)) {
-        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
-            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
+    for (auto [insertSize, destSize] :
+         llvm::zip_equal(insertSliceOp.getMixedSizes(), mixedSizes.value())) {
+      if (isa<Value>(insertSize) || isa<Value>(destSize)) {
+        if (insertSize != destSize) {
+          return rewriter.notifyMatchFailure(insertSliceOp,
+                                             "dynamic size mismatch");
+        }
+        continue;
       }
-      return sumPaddings;
-    };
 
-    SmallVector<OpFoldResult> newHighPad =
-        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
-    SmallVector<OpFoldResult> newLowPad =
-        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
+      // `getMixedSizes` for different ops returns different attribute types
+      // (`index` or `i64`), so we compare the integer values directly here.
+      int64_t staticInsertSize = getConstantIntValue(insertSize).value();
+      int64_t staticDestSize = getConstantIntValue(destSize).value();
+      if (staticInsertSize != staticDestSize) {
+        return rewriter.notifyMatchFailure(insertSliceOp,
+                                           "static size mismatch");
+      }
+    }
 
-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
-        newLowPad, newHighPad, padOp.getNofold(),
-        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
-    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
-                                newPadOp.getRegion().begin());
-    rewriter.replaceOp(padOp, newPadOp.getResult());
+    rewriter.replaceOp(insertSliceOp, insertSliceOp.getSource());
     return success();
   }
 };
@@ -117,7 +104,7 @@
     // Pull in some borderline/downstream canonicalizations for the Flow
     // compilation phase.
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
-    owningPatterns.add<FoldConsecutiveConstantPadding>(context);
+    owningPatterns.add<FoldFullInsertSlice>(context);
 
     patterns =
         std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
index 81203a5..8734b85 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
@@ -1,84 +1,56 @@
 // RUN: iree-opt --iree-flow-canonicalize %s --split-input-file --mlir-print-local-scope | FileCheck %s
 
-util.func public @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @fold_full_insert_into_extract(
+    %source: tensor<8x?xf32>,
+    %dest: tensor<10x?xf32>,
+    %size: index) -> tensor<8x?xf32> {
+  %extract = tensor.extract_slice %dest [1, 1] [8, %size] [1, 1] : tensor<10x?xf32> to tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %extract [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
-// CHECK-LABEL: util.func public @merge_constant_padding
-//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
-//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
-//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
-//       CHECK:     tensor.yield %[[PADVAL]]
-//       CHECK:   util.return %[[PAD]]
+
+// CHECK-LABEL: util.func public @fold_full_insert_into_extract
+//  CHECK-SAME:   %[[SOURCE:.+]]: tensor<8x?xf32>
+//       CHECK:   util.return %[[SOURCE]]
 
 // -----
 
-util.func public @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
-  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<?x?xf32> to tensor<?x?xf32>
-  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<?x?xf32> to tensor<?x?xf32>
-  util.return %pad1 : tensor<?x?xf32>
+util.func public @fold_full_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size: index) -> tensor<8x?xf32> {
+  %empty = tensor.empty(%size) : tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
-// CHECK-LABEL: util.func public @merge_constant_padding_dynamic
-//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
-//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
-//       CHECK:   %[[HIGH:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[IDX]]]
-//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
-//       CHECK:     tensor.yield %[[PADVAL]]
-//       CHECK:   util.return %[[PAD]]
+
+// CHECK-LABEL: util.func public @fold_full_insert_into_empty
+//  CHECK-SAME:   %[[SOURCE:.+]]: tensor<8x?xf32>
+//       CHECK:   util.return %[[SOURCE]]
 
 // -----
 
-util.func public @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-    } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size1: index, %size2: index) -> tensor<8x?xf32> {
+  %empty = tensor.empty(%size1) : tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size2] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
 
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_nofold
-//       CHECK:   tensor.pad
-//       CHECK:   tensor.pad {{.*}} nofold
+// CHECK-LABEL: util.func public @dont_fold_not_full_insert_into_empty
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   util.return %[[INSERT]]
 
 // -----
 
-util.func public @dont_merge_constant_padding_different_vals(
-    %arg0: tensor<2x3xf32>,
-    %pad_value0: f32,
-    %pad_value1: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value0 : f32
-    } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value1 : f32
-    } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_static_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size: index) -> tensor<10x?xf32> {
+  %empty = tensor.empty(%size) : tensor<10x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<10x?xf32>
+  util.return %insert : tensor<10x?xf32>
 }
 
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
-//       CHECK:   tensor.pad
-//       CHECK:   tensor.pad
+// CHECK-LABEL: util.func public @dont_fold_not_full_static_insert_into_empty
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   util.return %[[INSERT]]