[DispatchCreation] Fold extract_slice of broadcast during split reduction tiling (#23012)
This PR moves FoldExtractSliceOfBroadcast to cleanup patterns in
`FormSplitReductionDispatches`. This folds
`extract_slice(broadcast(...))` during tiling rather than in
post-processing. No new lit test is needed as this change reuses an
existing pattern (FoldExtractSliceOfBroadcast) which already has test
coverage.
This relies on an upstream LLVM changes that adds rank-reducing slices
to generatedSlices in replaceExtractSliceWithTiledProducer in
https://github.com/llvm/llvm-project/pull/174248.
ci-extra: test_torch
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormSplitReductionDispatches.cpp b/compiler/src/iree/compiler/DispatchCreation/FormSplitReductionDispatches.cpp
index 107ca7c..c8693ae 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormSplitReductionDispatches.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormSplitReductionDispatches.cpp
@@ -102,16 +102,17 @@
tileAndFuseOptions.setFusionControlFn(fusionControlFn);
tileAndFuseOptions.setTilingOptions(std::move(options));
+ MLIRContext *context = rewriter.getContext();
+ RewritePatternSet cleanupPatterns(context);
+ populateFoldExtractSliceOfBroadcastPattern(cleanupPatterns);
if (fusePad) {
- MLIRContext *context = rewriter.getContext();
- RewritePatternSet cleanupPatterns(context);
// When fusing pads we do not want to generate zeroSliceGuards.
cleanupPatterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
context,
[](tensor::ExtractSliceOp) { return /*zeroSliceGuard=*/false; });
- tileAndFuseOptions.cleanupPatterns =
- FrozenRewritePatternSet(std::move(cleanupPatterns));
}
+ tileAndFuseOptions.cleanupPatterns =
+ FrozenRewritePatternSet(std::move(cleanupPatterns));
FailureOr<scf::SCFTileAndFuseResult> result =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, op,
@@ -208,7 +209,6 @@
RewritePatternSet patterns(context);
linalg::populateSwapExtractSliceWithFillPatterns(patterns);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
- populateFoldExtractSliceOfBroadcastPattern(patterns);
GreedyRewriteConfig config;
config.setMaxIterations(GreedyRewriteConfig::kNoLimit).enableFolding(true);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {