[GPU] Add pattern to fold fill into pad ops (#21864)

This pattern helps with fusion, as otherwise, due to the pad, we end up
with an extract slice on the fill that can't be fused in the thread loop.

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
index 0024893..d8afbf8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
@@ -6,8 +6,10 @@
 
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir::iree_compiler {
@@ -72,6 +74,12 @@
         return signalPassFailure();
       }
     });
+    MLIRContext *context = &getContext();
+    RewritePatternSet cleanupPatterns(context);
+    populateFoldFillIntoPadPattern(cleanupPatterns);
+    if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+      return signalPassFailure();
+    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
index 538af2f..4ce43ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
@@ -7,8 +7,10 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir::iree_compiler {
@@ -75,6 +77,12 @@
         return signalPassFailure();
       }
     });
+    MLIRContext *context = &getContext();
+    RewritePatternSet cleanupPatterns(context);
+    populateFoldFillIntoPadPattern(cleanupPatterns);
+    if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+      return signalPassFailure();
+    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
index 12406eb..6ba4072 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
@@ -4,8 +4,11 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
 #lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, padding_conv = [1, 8, 32, 32, 0, 0, 32]}>
-func.func @conv_2d_nhwc_fhwc(%arg0: tensor<16x26x19x287xf16>, %arg1: tensor<287x3x3x287xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
-  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x287xf16>, tensor<287x3x3x287xf16>) outs(%arg2 : tensor<16x24x17x287xf32>) attrs = {lowering_config = #lowering_config} {
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<16x26x19x287xf16>, %arg1: tensor<287x3x3x287xf16>) -> tensor<16x24x17x287xf32> {
+   %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<16x24x17x287xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x287xf16>, tensor<287x3x3x287xf16>) outs(%fill : tensor<16x24x17x287xf32>) attrs = {lowering_config = #lowering_config} {
   ^bb0(%in: f16, %in_0: f16, %out: f32):
     %1 = arith.extf %in : f16 to f32
     %2 = arith.extf %in_0 : f16 to f32
@@ -19,13 +22,12 @@
 // CHECK-LABEL: func.func @conv_2d_nhwc_fhwc
 //  CHECK-SAME:   %[[A:[A-Za-z0-9]+]]: tensor<16x26x19x287xf16>
 //  CHECK-SAME:   %[[B:[A-Za-z0-9]+]]: tensor<287x3x3x287xf16>
-//  CHECK-SAME:   %[[C:[A-Za-z0-9]+]]: tensor<16x24x17x287xf32>
-//       CHECK:   %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 15, 1]
-//       CHECK:   %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[1, 0, 0, 1]
-//       CHECK:   %[[PADDED_INIT:.+]] = tensor.pad %[[C]] low[0, 0, 0, 0] high[0, 0, 15, 1]
+//   CHECK-DAG:   %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 15, 1]
+//   CHECK-DAG:   %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[1, 0, 0, 1]
+//   CHECK-DAG:   %[[FILL:.+]] = linalg.fill {{.*}} ->  tensor<16x24x32x288xf32>
 //       CHECK:   %[[PADDED_RESULT:.+]] = linalg.generic
 //  CHECK-SAME:     ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<16x26x34x288xf16>, tensor<288x3x3x288xf16>)
-//  CHECK-SAME:     outs(%[[PADDED_INIT]] : tensor<16x24x32x288xf32>)
+//  CHECK-SAME:     outs(%[[FILL]] : tensor<16x24x32x288xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0, 0, 0] [16, 24, 17, 287] [1, 1, 1, 1]
 //  CHECK-SAME:     : tensor<16x24x32x288xf32> to tensor<16x24x17x287xf32>
 //       CHECK:   return %[[EXTRACT]] : tensor<16x24x17x287xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
index 1621865..99ce2b7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
@@ -14,13 +14,12 @@
 // CHECK-LABEL: func.func @matmul
 //  CHECK-SAME:   %[[A:[A-Za-z0-9]+]]: tensor<32x1024xf32>
 //  CHECK-SAME:   %[[B:[A-Za-z0-9]+]]: tensor<1024x128xf32>
-//       CHECK:   %[[FILL:.+]] = linalg.fill {{.*}} -> tensor<32x128xf32>
-//       CHECK:   %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0] high[1, 10]
-//       CHECK:   %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0] high[10, 5]
-//       CHECK:   %[[PADDED_INIT:.+]] = tensor.pad %[[FILL]] low[0, 0] high[1, 5]
+//   CHECK-DAG:   %[[FILL:.+]] = linalg.fill {{.*}} -> tensor<33x133xf32>
+//   CHECK-DAG:   %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0] high[1, 10]
+//   CHECK-DAG:   %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0] high[10, 5]
 //       CHECK:   %[[PADDED_RESULT:.+]] = linalg.matmul
 //  CHECK-SAME:     ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<33x1034xf32>, tensor<1034x133xf32>)
-//  CHECK-SAME:     outs(%[[PADDED_INIT]] : tensor<33x133xf32>) -> tensor<33x133xf32>
+//  CHECK-SAME:     outs(%[[FILL]] : tensor<33x133xf32>) -> tensor<33x133xf32>
 //       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0] [32, 128] [1, 1]
 //  CHECK-SAME:     : tensor<33x133xf32> to tensor<32x128xf32>
 //       CHECK:   return %[[EXTRACT]] : tensor<32x128xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 2d16c64..9839bbc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -123,59 +123,9 @@
                                                setFusedOpOperandLimit<3>);
 }
 
-//===---------------------------------------------------------------------===//
-// ApplyFoldFillIntoPadPatternsOp
-//===---------------------------------------------------------------------===//
-
-namespace {
-/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
-/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
-/// are the same.
-/// This seems generally desirable as a folding but may be too intrusive, so we
-/// only apply it selectively for now.
-// TODO: atm hardcoded on linalg.fill but we could take any result of any
-// generic that yields a constant in that result.
-struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::PadOp padOp,
-                                PatternRewriter &rewriter) const final {
-    Operation *currentOp = padOp.getSource().getDefiningOp();
-    auto maybeExtractSlice =
-        dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
-    while (currentOp && maybeExtractSlice) {
-      currentOp = maybeExtractSlice.getSource().getDefiningOp();
-      maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
-    }
-    auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
-    if (!fillOp) {
-      return rewriter.notifyMatchFailure(
-          padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
-    }
-
-    Value padValue = padOp.getConstantPaddingValue();
-    RankedTensorType resultType = padOp.getResultType();
-    if (!padValue ||
-        getAsOpFoldResult(padValue) !=
-            getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
-      return rewriter.notifyMatchFailure(
-          padOp, "not a constant value matching the fill value");
-    }
-
-    Location loc = padOp.getLoc();
-    auto emptyOp = rewriter.create<tensor::EmptyOp>(
-        loc, tensor::getMixedSizes(rewriter, loc, padOp),
-        resultType.getElementType());
-    rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
-                                                emptyOp.getResult());
-
-    return success();
-  }
-};
-} // namespace
-
 void transform_dialect::ApplyFoldFillIntoPadPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<FoldFillIntoPad>(patterns.getContext());
+  iree_compiler::populateFoldFillIntoPadPattern(patterns);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 6812681..d2b7104 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -1239,4 +1239,58 @@
   patterns.insert<HoistForallFromFor>(patterns.getContext());
 }
 
+//===---------------------------------------------------------------------===//
+// FoldFillIntoPad pattern
+//===---------------------------------------------------------------------===//
+
+namespace {
+/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
+/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
+/// are the same.
+/// This seems generally desirable as a folding but may be too intrusive, so we
+/// only apply it selectively for now.
+// TODO: atm hardcoded on linalg.fill but we could take any result of any
+// generic that yields a constant in that result.
+struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    Operation *currentOp = padOp.getSource().getDefiningOp();
+    auto maybeExtractSlice =
+        dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    while (currentOp && maybeExtractSlice) {
+      currentOp = maybeExtractSlice.getSource().getDefiningOp();
+      maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    }
+    auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
+    if (!fillOp) {
+      return rewriter.notifyMatchFailure(
+          padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
+    }
+
+    Value padValue = padOp.getConstantPaddingValue();
+    RankedTensorType resultType = padOp.getResultType();
+    if (!padValue ||
+        getAsOpFoldResult(padValue) !=
+            getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
+      return rewriter.notifyMatchFailure(
+          padOp, "not a constant value matching the fill value");
+    }
+
+    Location loc = padOp.getLoc();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, tensor::getMixedSizes(rewriter, loc, padOp),
+        resultType.getElementType());
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
+                                                emptyOp.getResult());
+
+    return success();
+  }
+};
+} // namespace
+
+void populateFoldFillIntoPadPattern(RewritePatternSet &patterns) {
+  patterns.insert<FoldFillIntoPad>(patterns.getContext());
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index 8f0b4d6..d23f199 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -90,6 +90,9 @@
 /// scf.for ops.
 void populateForallLoopHoistingPattern(RewritePatternSet &patterns);
 
+/// Populate pattern that folds fill into pad ops.
+void populateFoldFillIntoPadPattern(RewritePatternSet &patterns);
+
 using GetMinMaxExprFn =
     std::function<std::optional<std::pair<AffineExpr, AffineExpr>>(
         Value value, SmallVectorImpl<Value> &dims,