[Flow] Add pattern to canonicalize consecutive pads (#17878)

This allows fusing `pad(pad)` into a single pad when both pads use the
same padding value. The pattern is added here rather than upstream because
it uses the Affine dialect, which the upstream tensor dialect does not
depend on.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
index b683b26..b19bccc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -17,6 +18,84 @@
 
 namespace {
 
+/// Folds a chain of `tensor.pad` ops with the same constant padding value.
+///
+/// Example:
+///
+/// ```mlir
+///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<1x5xf32>
+///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
+///       tensor.yield %val
+///     } : tensor<1x5xf32> to tensor<4x7xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<4x7xf32>
+/// ```
+///
+/// NOTE: This wasn't sent upstream as a canonicalization due to the use of
+/// the Affine dialect.
+struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold()) {
+      return rewriter.notifyMatchFailure(padOp, "pad op is marked nofold");
+    }
+    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!producerPad || producerPad.getNofold()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "producer is not a foldable tensor.pad op");
+    }
+
+    // Fail unless both pads yield the same constant padding value.
+    Value consumerPadValue = padOp.getConstantPaddingValue();
+    Value producerPadValue = producerPad.getConstantPaddingValue();
+    if (!consumerPadValue || !producerPadValue ||
+        consumerPadValue != producerPadValue) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with different padding values");
+    }
+
+    Location loc = padOp.getLoc();
+    AffineExpr d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
+
+    // Sum the consumer and producer paddings per dimension (d0 + d1).
+    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
+                           ArrayRef<OpFoldResult> producerPaddings) {
+      SmallVector<OpFoldResult> sumPaddings;
+      for (auto [consumerIndex, producerIndex] :
+           llvm::zip_equal(consumerPaddings, producerPaddings)) {
+        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
+            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
+      }
+      return sumPaddings;
+    };
+
+    SmallVector<OpFoldResult> newHighPad =
+        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
+    SmallVector<OpFoldResult> newLowPad =
+        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
+    // Create the merged pad on the producer's source and move the body over.
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, padOp.getResultType(), producerPad.getSource(),
+        newLowPad, newHighPad, padOp.getNofold(),
+        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
+    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+                                newPadOp.getRegion().begin());
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 /// Canonicalize operations in nested regions.
 struct CanonicalizerPass
     : public impl::CanonicalizerPassBase<CanonicalizerPass> {
@@ -35,9 +114,10 @@
     for (RegisteredOperationName op : context->getRegisteredOperations())
       op.getCanonicalizationPatterns(owningPatterns, context);
 
-    // Some Flow specific patterns we want to pull in for common
-    // canonicalization.
+    // Pull in some borderline/downstream canonicalizations for the Flow
+    // compilation phase.
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
+    owningPatterns.add<FoldConsecutiveConstantPadding>(context);
 
     patterns =
         std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index 2f39493..9764b3b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -32,6 +32,7 @@
             "dispatch_linalg_transform_dialect.mlir",
             "dispatch_region_formation_preprocessing.mlir",
             "export_benchmark_funcs.mlir",
+            "flow_canonicalize.mlir",
             "fold_unit_dims.mlir",
             "form_dispatch_regions.mlir",
             "form_dispatch_workgroups.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 44d05cf..b80b495 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -31,6 +31,7 @@
     "dispatch_linalg_transform_dialect.mlir"
     "dispatch_region_formation_preprocessing.mlir"
     "export_benchmark_funcs.mlir"
+    "flow_canonicalize.mlir"
     "fold_unit_dims.mlir"
     "form_dispatch_regions.mlir"
     "form_dispatch_workgroups.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
new file mode 100644
index 0000000..81203a5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
@@ -0,0 +1,84 @@
+// RUN: iree-opt --iree-flow-canonicalize %s --split-input-file --mlir-print-local-scope | FileCheck %s
+
+util.func public @merge_constant_padding(%source: tensor<2x3xf32>, %fill: f32) -> tensor<7x8xf32> {
+  %inner = tensor.pad %source low[1, 1] high[1, 0] {
+    ^bb0(%i: index, %j: index):
+      tensor.yield %fill : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %outer = tensor.pad %inner low[0, 2] high[3, 2] {
+    ^bb0(%k: index, %l: index):
+      tensor.yield %fill : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  util.return %outer : tensor<7x8xf32>
+}
+// CHECK-LABEL: util.func public @merge_constant_padding
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
+//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
+//       CHECK:     tensor.yield %[[PADVAL]]
+//       CHECK:   util.return %[[PAD]]
+
+// -----
+
+util.func public @merge_constant_padding_dynamic(%source: tensor<?x?xf32>, %amount: index, %fill: f32) -> tensor<?x?xf32> {
+  %inner = tensor.pad %source low[%amount, 1] high[1, 0] {
+    ^bb0(%i: index, %j: index):
+      tensor.yield %fill : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %outer = tensor.pad %inner low[0, 2] high[%amount, 2] {
+    ^bb0(%k: index, %l: index):
+      tensor.yield %fill : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  util.return %outer : tensor<?x?xf32>
+}
+// CHECK-LABEL: util.func public @merge_constant_padding_dynamic
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
+//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
+//       CHECK:   %[[HIGH:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[IDX]]]
+//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
+//       CHECK:     tensor.yield %[[PADVAL]]
+//       CHECK:   util.return %[[PAD]]
+
+// -----
+
+util.func public @dont_merge_constant_padding_nofold(%source: tensor<2x3xf32>, %fill: f32) -> tensor<7x8xf32> {
+  %inner = tensor.pad %source low[1, 1] high[1, 0] {
+    ^bb0(%i: index, %j: index):
+      tensor.yield %fill : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %outer = tensor.pad %inner nofold low[0, 2] high[3, 2] {
+    ^bb0(%k: index, %l: index):
+      tensor.yield %fill : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  util.return %outer : tensor<7x8xf32>
+}
+
+// Both pads must survive: merging them would drop the `nofold` attribute.
+
+// CHECK-LABEL: util.func public @dont_merge_constant_padding_nofold
+//       CHECK:   tensor.pad
+//       CHECK:   tensor.pad {{.*}} nofold
+
+// -----
+
+util.func public @dont_merge_constant_padding_different_vals(
+    %arg0: tensor<2x3xf32>,
+    %pad_value0: f32,
+    %pad_value1: f32) -> tensor<7x8xf32> {
+  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
+    ^bb0(%b0: index, %b1 : index):
+      tensor.yield %pad_value0 : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
+    ^bb0(%b2: index, %b3 : index):
+      tensor.yield %pad_value1 : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  util.return %pad1 : tensor<7x8xf32>
+}
+
+// Verify that folding does not happen if the padding values differ
+
+// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
+//       CHECK:   tensor.pad
+//       CHECK:   tensor.pad