[GPU] Add pattern to fold fill into pad ops (#21864)
This pattern helps with fusion: without it, the pad leaves a
tensor.extract_slice on the linalg.fill that cannot be fused into the
thread loop.
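
For example, with the shapes from the matmul test below, the input IR
looks like this (a minimal sketch):

```mlir
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<32x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
// The fill result is padded with the same constant %cst.
%padded = tensor.pad %fill low[0, 0] high[1, 5] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<32x128xf32> to tensor<33x133xf32>
```

Since the padding constant matches the fill constant, the pattern rewrites
the pad as a fill of an empty tensor of the padded shape:

```mlir
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<33x133xf32>
%padded = linalg.fill ins(%cst : f32) outs(%empty : tensor<33x133xf32>) -> tensor<33x133xf32>
```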
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
index 0024893..d8afbf8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadConvs.cpp
@@ -6,8 +6,10 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir::iree_compiler {
@@ -72,6 +74,14 @@
return signalPassFailure();
}
});
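+ // Fold fills into the pads created above; otherwise the pad leaves an
+ // extract_slice on the fill that cannot be fused into the thread loop.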
+ MLIRContext *context = &getContext();
+ RewritePatternSet cleanupPatterns(context);
+ populateFoldFillIntoPadPattern(cleanupPatterns);
+ if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
index 538af2f..4ce43ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPadOperands.cpp
@@ -7,8 +7,10 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir::iree_compiler {
@@ -75,6 +77,14 @@
return signalPassFailure();
}
});
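+ // Fold fills into the pads created above; otherwise the pad leaves an
+ // extract_slice on the fill that cannot be fused into the thread loop.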
+ MLIRContext *context = &getContext();
+ RewritePatternSet cleanupPatterns(context);
+ populateFoldFillIntoPadPattern(cleanupPatterns);
+ if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
index 12406eb..6ba4072 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_convs.mlir
@@ -4,8 +4,11 @@
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
#lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, padding_conv = [1, 8, 32, 32, 0, 0, 32]}>
-func.func @conv_2d_nhwc_fhwc(%arg0: tensor<16x26x19x287xf16>, %arg1: tensor<287x3x3x287xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
- %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x287xf16>, tensor<287x3x3x287xf16>) outs(%arg2 : tensor<16x24x17x287xf32>) attrs = {lowering_config = #lowering_config} {
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<16x26x19x287xf16>, %arg1: tensor<287x3x3x287xf16>) -> tensor<16x24x17x287xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty = tensor.empty() : tensor<16x24x17x287xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x287xf16>, tensor<287x3x3x287xf16>) outs(%fill : tensor<16x24x17x287xf32>) attrs = {lowering_config = #lowering_config} {
^bb0(%in: f16, %in_0: f16, %out: f32):
%1 = arith.extf %in : f16 to f32
%2 = arith.extf %in_0 : f16 to f32
@@ -19,13 +22,12 @@
// CHECK-LABEL: func.func @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<16x26x19x287xf16>
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<287x3x3x287xf16>
-// CHECK-SAME: %[[C:[A-Za-z0-9]+]]: tensor<16x24x17x287xf32>
-// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 15, 1]
-// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[1, 0, 0, 1]
-// CHECK: %[[PADDED_INIT:.+]] = tensor.pad %[[C]] low[0, 0, 0, 0] high[0, 0, 15, 1]
+// CHECK-DAG: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 15, 1]
+// CHECK-DAG: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[1, 0, 0, 1]
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}} -> tensor<16x24x32x288xf32>
// CHECK: %[[PADDED_RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<16x26x34x288xf16>, tensor<288x3x3x288xf16>)
-// CHECK-SAME: outs(%[[PADDED_INIT]] : tensor<16x24x32x288xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<16x24x32x288xf32>)
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0, 0, 0] [16, 24, 17, 287] [1, 1, 1, 1]
// CHECK-SAME: : tensor<16x24x32x288xf32> to tensor<16x24x17x287xf32>
// CHECK: return %[[EXTRACT]] : tensor<16x24x17x287xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
index 1621865..99ce2b7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pad_operands.mlir
@@ -14,13 +14,12 @@
// CHECK-LABEL: func.func @matmul
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x1024xf32>
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<1024x128xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill {{.*}} -> tensor<32x128xf32>
-// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0] high[1, 10]
-// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0] high[10, 5]
-// CHECK: %[[PADDED_INIT:.+]] = tensor.pad %[[FILL]] low[0, 0] high[1, 5]
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}} -> tensor<33x133xf32>
+// CHECK-DAG: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0] high[1, 10]
+// CHECK-DAG: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0] high[10, 5]
// CHECK: %[[PADDED_RESULT:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<33x1034xf32>, tensor<1034x133xf32>)
-// CHECK-SAME: outs(%[[PADDED_INIT]] : tensor<33x133xf32>) -> tensor<33x133xf32>
+// CHECK-SAME: outs(%[[FILL]] : tensor<33x133xf32>) -> tensor<33x133xf32>
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0] [32, 128] [1, 1]
// CHECK-SAME: : tensor<33x133xf32> to tensor<32x128xf32>
// CHECK: return %[[EXTRACT]] : tensor<32x128xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 2d16c64..9839bbc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -123,59 +123,9 @@
setFusedOpOperandLimit<3>);
}
-//===---------------------------------------------------------------------===//
-// ApplyFoldFillIntoPadPatternsOp
-//===---------------------------------------------------------------------===//
-
-namespace {
-/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
-/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
-/// are the same.
-/// This seems generally desirable as a folding but may be too intrusive, so we
-/// only apply it selectively for now.
-// TODO: atm hardcoded on linalg.fill but we could take any result of any
-// generic that yields a constant in that result.
-struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::PadOp padOp,
- PatternRewriter &rewriter) const final {
- Operation *currentOp = padOp.getSource().getDefiningOp();
- auto maybeExtractSlice =
- dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
- while (currentOp && maybeExtractSlice) {
- currentOp = maybeExtractSlice.getSource().getDefiningOp();
- maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
- }
- auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
- if (!fillOp) {
- return rewriter.notifyMatchFailure(
- padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
- }
-
- Value padValue = padOp.getConstantPaddingValue();
- RankedTensorType resultType = padOp.getResultType();
- if (!padValue ||
- getAsOpFoldResult(padValue) !=
- getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
- return rewriter.notifyMatchFailure(
- padOp, "not a constant value matching the fill value");
- }
-
- Location loc = padOp.getLoc();
- auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, tensor::getMixedSizes(rewriter, loc, padOp),
- resultType.getElementType());
- rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
- emptyOp.getResult());
-
- return success();
- }
-};
-} // namespace
-
void transform_dialect::ApplyFoldFillIntoPadPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- patterns.insert<FoldFillIntoPad>(patterns.getContext());
+ iree_compiler::populateFoldFillIntoPadPattern(patterns);
}
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 6812681..d2b7104 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -1239,4 +1239,60 @@
patterns.insert<HoistForallFromFor>(patterns.getContext());
}
+//===---------------------------------------------------------------------===//
+// FoldFillIntoPadPattern
+//===---------------------------------------------------------------------===//
+
+namespace {
+/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
+/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
+/// are the same.
+/// This seems generally desirable as a folding but may be too intrusive, so we
+/// only apply it selectively for now.
+// TODO: atm hardcoded on linalg.fill but we could take any result of any
+// generic that yields a constant in that result.
+struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const final {
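+ // Look through any chain of tensor.extract_slice ops to find the
+ // underlying producer of the pad source.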
+ Operation *currentOp = padOp.getSource().getDefiningOp();
+ auto maybeExtractSlice =
+ dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+ while (currentOp && maybeExtractSlice) {
+ currentOp = maybeExtractSlice.getSource().getDefiningOp();
+ maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+ }
+ auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
+ if (!fillOp) {
+ return rewriter.notifyMatchFailure(
+ padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
+ }
+
+ Value padValue = padOp.getConstantPaddingValue();
+ RankedTensorType resultType = padOp.getResultType();
+ if (!padValue ||
+ getAsOpFoldResult(padValue) !=
+ getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
+ return rewriter.notifyMatchFailure(
+ padOp, "not a constant value matching the fill value");
+ }
+
+ Location loc = padOp.getLoc();
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(
+ loc, tensor::getMixedSizes(rewriter, loc, padOp),
+ resultType.getElementType());
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
+ emptyOp.getResult());
+
+ return success();
+ }
+};
+} // namespace
+
+void populateFoldFillIntoPadPattern(RewritePatternSet &patterns) {
+ patterns.insert<FoldFillIntoPad>(patterns.getContext());
+}
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index 8f0b4d6..d23f199 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -90,6 +90,9 @@
/// scf.for ops.
void populateForallLoopHoistingPattern(RewritePatternSet &patterns);
+/// Populate a pattern that folds linalg.fill ops into consuming tensor.pad ops.
+void populateFoldFillIntoPadPattern(RewritePatternSet &patterns);
+
using GetMinMaxExprFn =
std::function<std::optional<std::pair<AffineExpr, AffineExpr>>(
Value value, SmallVectorImpl<Value> &dims,