[Codegen][GPU] Add tiling cleanup pattern to fuse pad without zero-slice guard (#18748)
This PR adds a way to fuse tensor.pad in GPUApplyTilingLevel when we
know the pad will never receive an empty slice. This is useful when the
tensor.pad pads up to the tile size we are tiling with, and therefore
can never produce an empty slice. The guarded and unguarded forms are
sketched below.
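For context, a minimal hand-written sketch of the two shapes the fused IR can
take (the function names, static sizes, and the %empty flag are made up for
illustration; in the real rewrite the emptiness condition is computed from the
slice sizes rather than passed in):

// Guarded form (default fusion): the pad is wrapped in an scf.if, which is
// hard to vectorize.
func.func @guarded(%src: tensor<?x17xf32>, %empty: i1) -> tensor<8x8xf32> {
  %cst = arith.constant 0.0 : f32
  %r = scf.if %empty -> (tensor<8x8xf32>) {
    // The slice would be empty: materialize the padding value directly.
    %g = tensor.generate {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<8x8xf32>
    scf.yield %g : tensor<8x8xf32>
  } else {
    %s = tensor.extract_slice %src[0, 0] [4, 4] [1, 1] : tensor<?x17xf32> to tensor<4x4xf32>
    %p = tensor.pad %s low[0, 0] high[4, 4] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<4x4xf32> to tensor<8x8xf32>
    scf.yield %p : tensor<8x8xf32>
  }
  return %r : tensor<8x8xf32>
}

// Unguarded form (allow-zero-slices=true): a plain pad of the slice, which
// vectorizes cleanly.
func.func @unguarded(%src: tensor<?x17xf32>) -> tensor<8x8xf32> {
  %cst = arith.constant 0.0 : f32
  %s = tensor.extract_slice %src[0, 0] [4, 4] [1, 1] : tensor<?x17xf32> to tensor<4x4xf32>
  %p = tensor.pad %s low[0, 0] high[4, 4] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x4xf32> to tensor<8x8xf32>
  return %p : tensor<8x8xf32>
}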
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index a6b6bf8..2d7dbd1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -13,6 +13,7 @@
#include "llvm/ADT/STLForwardCompat.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -65,10 +66,9 @@
/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTileAndFuseToEachRoot(RewriterBase &rewriter,
- llvm::SmallDenseSet<TilingInterface> &payloadOps,
- IREE::GPU::TilingLevel tilingLevel) {
+static LogicalResult applyTileAndFuseToEachRoot(
+ RewriterBase &rewriter, llvm::SmallDenseSet<TilingInterface> &payloadOps,
+ IREE::GPU::TilingLevel tilingLevel, bool allowZeroSlices) {
MLIRContext *context = rewriter.getContext();
for (TilingInterface tilingInterfaceOp : payloadOps) {
mlir::DominanceInfo dominanceInfo(tilingInterfaceOp);
@@ -137,7 +137,8 @@
Operation *owner = originalProducer.getOwner();
if (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
- // Do not fuse pad in reduction and subgroup tiling.
+ // Do not fuse pad in reduction and subgroup tiling. We instead fuse
+ // the pad without the zero-slice guard as a cleanup pattern.
if (isa<tensor::PadOp>(owner)) {
return std::nullopt;
}
@@ -161,6 +162,22 @@
};
tileAndFuseOptions.setFusionControlFn(controlFn);
+ RewritePatternSet cleanupPatterns(context);
+
+ if (allowZeroSlices) {
+ // Add a pattern to fuse pad operations without the zero-slice guard,
+ // since we know no zero-size slices will be generated.
+ auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional<bool> {
+ // Do not use the zero-slice guard.
+ return false;
+ };
+ cleanupPatterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
+ context, zeroSliceGuard);
+ }
+
+ tileAndFuseOptions.cleanupPatterns =
+ FrozenRewritePatternSet(std::move(cleanupPatterns));
+
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
tileAndFuseOptions);
@@ -221,7 +238,8 @@
getTiledOps(funcOp, tilingLevel);
IRRewriter rewriter(funcOp);
- if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, tilingLevel))) {
+ if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, tilingLevel,
+ allowZeroSlices))) {
funcOp.emitError() << "tiling of level "
<< IREE::GPU::stringifyEnum(tilingLevel) << " failed\n";
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index d339adf..937f04e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -205,6 +205,9 @@
clEnumValN(IREE::GPU::TilingLevel::Subgroup, "subgroup",
"Tile and fuse all annotated ops to threads")
)}]>,
+ Option<"allowZeroSlices", "allow-zero-slices", "bool",
+ /*default=*/"false",
+ "Allow pad fusion to generate zero size slices">
];
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 3cec8a9..3350078 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -1,4 +1,5 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level, canonicalize, cse))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{allow-zero-slices=true}, canonicalize, cse))" %s | FileCheck %s --check-prefix=NOZERO
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP
@@ -474,3 +475,51 @@
// SUBGROUP: scf.forall.in_parallel
// SUBGROUP: tensor.parallel_insert_slice %[[MMA]] into %[[INIT]]
// SUBGROUP: return
+
+// -----
+
+// This test checks whether a tensor.pad gets fused when tiling. By default,
+// tensor.pad fusion is disabled because it generates a guard against empty
+// slices, which is hard to vectorize.
+//
+// However, if we already know that no zero-size slices will be generated, we
+// can fuse the pad directly.
+
+#map = affine_map<()[s0] -> (s0 * -16 + 19, 16)>
+#map1 = affine_map<()[s0] -> (-s0 + 16)>
+module {
+ func.func @fuse_pad_no_zero_slice(%arg0: tensor<?x17xf32>, %arg1: tensor<17x17xf32>, %arg2: index, %arg3: index) -> tensor<?x17xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = affine.min #map()[%arg2]
+ %1 = tensor.empty() : tensor<16x32xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<16x32xf32>) -> tensor<16x32xf32>
+ %3 = affine.apply #map1()[%0]
+ %padded = tensor.pad %arg0 low[0, 0] high[%3, 7] {
+ ^bb0(%arg4: index, %arg5: index):
+ tensor.yield %cst : f32
+ } : tensor<?x17xf32> to tensor<16x24xf32>
+ %padded_0 = tensor.pad %arg1 low[0, 0] high[7, 15] {
+ ^bb0(%arg4: index, %arg5: index):
+ tensor.yield %cst : f32
+ } : tensor<17x17xf32> to tensor<24x32xf32>
+ %4 = linalg.matmul {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 8]}>} ins(%padded, %padded_0 : tensor<16x24xf32>, tensor<24x32xf32>) outs(%2 : tensor<16x32xf32>) -> tensor<16x32xf32>
+ %extracted_slice = tensor.extract_slice %4[0, 0] [%0, 17] [1, 1] : tensor<16x32xf32> to tensor<?x17xf32>
+ return %extracted_slice : tensor<?x17xf32>
+ }
+}
+
+// Only fuse the pad when allow-zero-slices is true.
+
+// CHECK-LABEL: @fuse_pad_no_zero_slice
+// CHECK: tensor.pad
+// CHECK: tensor.pad
+// CHECK: scf.for
+// CHECK-NOT: tensor.pad
+// CHECK: linalg.matmul
+
+// NOZERO-LABEL: @fuse_pad_no_zero_slice
+// NOZERO-NOT: tensor.pad
+// NOZERO: scf.for
+// NOZERO: tensor.pad
+// NOZERO: tensor.pad
+// NOZERO: linalg.matmul