[GPU] Don't swap expand with slice in the same block (#23267)
Prevents swapping tensor.extract_slice ops with tensor.expand_shape ops
when they are in the same block during GPUFuseAndHoistParallelLoops.
This pattern is intended to be a tiling fusion pattern for
tensor.expand_shape operations, which always begin with the ops in
different blocks. When the ops are in the same block, this does not open
up more fusions, and instead moves the extract_slice further compute ops
that produce loads. This can block later optimizations that fold the
slice into loading operations like vector.transfer_read.
This restriction is done through a control function to maintain the
ability to swap same-block slice + expand ops, in case it is needed for
other use cases.
ci-extra: test_torch
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
index 48f4e11..d8230ef 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
@@ -383,7 +383,17 @@
populateFuseTilableForallConsumersPattern(patterns);
patterns.add<FuseCollapseShapeConsumers>(context);
patterns.add<FuseExtractSliceConsumers>(context);
- populateSwapExtractWithExpandPattern(patterns);
+ // Only swap expand_shape with extract_slice when they are in different
+ // blocks. This acts as a fusion pattern for the expand_shape, which should
+ // only apply when they are in different blocks (i.e., one inside a loop,
+ // and one outside). Swapping in other cases can interfere with later
+ // optimizations that fold the slice into consumer load operations.
+ auto swapControlFn = [](OpOperand *operand) {
+ Operation *producer = operand->get().getDefiningOp();
+ Operation *consumer = operand->getOwner();
+ return producer->getBlock() != consumer->getBlock();
+ };
+ populateSwapExtractWithExpandPattern(patterns, swapControlFn);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
index 278dcdb..baaa99c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
@@ -944,3 +944,20 @@
// CHECK: linalg.copy
// CHECK: scf.forall.in_parallel
// CHECK: return
+
+// -----
+
+#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+func.func @no_swap_same_block_expand_slice(%arg0: tensor<64xf16>) -> tensor<4x4xf16>
+ attributes {translation_info = #translation_info} {
+ %expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [8, 8]
+ : tensor<64xf16> into tensor<8x8xf16>
+ %extracted = tensor.extract_slice %expanded[0, 0] [4, 4] [1, 1]
+ : tensor<8x8xf16> to tensor<4x4xf16>
+ return %extracted : tensor<4x4xf16>
+}
+
+// CHECK-LABEL: func @no_swap_same_block_expand_slice
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]]
+// CHECK: return %[[SLICE]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index 945ffdf..311bd8e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -335,7 +335,11 @@
struct SwapExpandShapeWithSlicePattern
: public OpRewritePattern<tensor::ExtractSliceOp> {
- using Base::Base;
+ SwapExpandShapeWithSlicePattern(MLIRContext *context,
+ linalg::ControlFusionFn controlFn,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExtractSliceOp>(context, benefit),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
@@ -349,14 +353,27 @@
"unsupported: non-unit stride");
}
+ // The control function receives the source operand of the extract_slice
+ // (which uses the expand_shape result).
+ OpOperand &srcOperand = sliceOp.getSourceMutable();
+ if (controlFn && !controlFn(&srcOperand)) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "rejected by control function");
+ }
+
return swapExpandShapeWithSlice(rewriter, expandOp, sliceOp);
}
+
+private:
+ linalg::ControlFusionFn controlFn;
};
} // namespace
-void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns) {
- patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
+void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns,
+ linalg::ControlFusionFn controlFn) {
+ patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext(),
+ std::move(controlFn));
}
namespace {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 413cfb1..5d77f55 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -190,8 +190,10 @@
void populateReplaceSlowMinMaxOpsPatterns(RewritePatternSet &patterns);
/// Populate pattern to convert `tensor.extract_slice(tensor.expand_shape)` to
-/// `tensor.expand_shape(tensor.extract_slice)`.
-void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns);
+/// `tensor.expand_shape(tensor.extract_slice)`. The optional `controlFn` can be
+/// used to restrict when the pattern applies.
+void populateSwapExtractWithExpandPattern(
+ RewritePatternSet &patterns, linalg::ControlFusionFn controlFn = nullptr);
/// Populate pattern to fold `tensor.extract_slice(linalg.broadcast)` into the
/// broadcast input when the extract_slice undoes the broadcast.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
index ea58e80..2ff6547 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
@@ -154,7 +154,7 @@
// CHECK-DAG: vector.transfer_read {{.*}} vector<4xf16>
// CHECK-DAG: vector.transfer_read {{.*}} vector<4xf16>
// CHECK-COUNT-1: amdgpu.mfma 16x16x16
-// CHECK: %[[LOOP_T:.+]] = vector.shape_cast %[[LOOP]] : vector<1x1x1x1x4x1xf32> to vector<4xf32>
+// CHECK: %[[LOOP_T:.+]] = vector.shape_cast %[[LOOP]]
// CHECK: vector.transfer_write %[[LOOP_T]]
// Note there is a writeback loop here that is skipped to simplify the test.
// CHECK: memref.copy {{.*}}#gpu.address_space<workgroup>> to {{.*}}#amdgpu.address_space<fat_raw_buffer>