[GPU] Add pattern to sink extract_slice through generic ops (#21796)
This is a useful propagation pattern that helps fuse consumer generic
ops into the tiled loops.
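
Roughly, the pattern rewrites a generic op that consumes an
extract_slice into a generic op on the full source tensor followed by
an extract_slice of its result. A minimal sketch for an element-wise
truncation (illustrative IR; names, maps, and the elided region are
made up, not taken from this patch):

  // Before: the slice feeds the generic.
  %slice = tensor.extract_slice %src[%off] [%sz] [1]
      : tensor<128xf32> to tensor<?xf32>
  %res = linalg.generic ... ins(%slice : tensor<?xf32>)
      outs(%init : tensor<?xbf16>) {...} -> tensor<?xbf16>

  // After: the generic computes on the full tensor and the slice
  // sinks below it (a full-sized init is materialized).
  %full = linalg.generic ... ins(%src : tensor<128xf32>)
      outs(%full_init : tensor<128xbf16>) {...} -> tensor<128xbf16>
  %res = tensor.extract_slice %full[%off] [%sz] [1]
      : tensor<128xbf16> to tensor<?xbf16>

With the slice sunk below it, the generic directly consumes the
producer's full result, so it can be fused as a consumer into the
producer's tiled loops.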
---------
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
index e3b0d28..3e639e8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
@@ -289,7 +289,33 @@
return !getLoweringConfig(producer) && !getLoweringConfig(consumer);
};
+ // Additionally, do not sink an extract_slice through a generic op if the
+ // slice source is a block argument, or if the source, slice, and generic are
+ // not all in the same block, as that would affect how tiling uses
+ // extract_slice ops.
+
+ linalg::ControlPropagationFn controlExtract =
+ [](OpOperand *opOperand) -> bool {
+ Operation *producer = opOperand->get().getDefiningOp();
+ Operation *consumer = opOperand->getOwner();
+ auto sliceOp = dyn_cast_if_present<tensor::ExtractSliceOp>(producer);
+ if (!sliceOp) {
+ return false;
+ }
+ if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+ return false;
+ }
+ Operation *producerSrc = sliceOp.getSource().getDefiningOp();
+ // If the source is not defined by an op, e.g. it is a block argument, return false.
+ if (!producerSrc) {
+ return false;
+ }
+
+ return producerSrc->getBlock() == producer->getBlock() &&
+ consumer->getBlock() == producer->getBlock();
+ };
+
linalg::populateDataLayoutPropagationPatterns(patterns, control);
+ linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
patterns.add<PackDestinationForOp>(context);
linalg::UnPackOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir
index ce2fa20..e4565a0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir
@@ -309,3 +309,81 @@
// CHECK: scf.yield %[[INNER_FOR_RESULT]] : tensor<1x1x4x2x16x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
// CHECK: linalg.unpack %[[OUTER_FOR_RESULT]] inner_dims_pos = [2, 3] inner_tiles = [16, 16] into %[[EMPTY:.+]] : tensor<1x1x4x2x16x16xf32> -> tensor<1x1x64x32xf32>
+
+// -----
+
+func.func @propagate_extract_basic(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+ %empty = util.optimization_barrier %input : tensor<128xf32>
+ %extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+ %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?xbf16>
+ return %generic : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @propagate_extract_basic
+// CHECK: linalg.generic
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_propagate_extract_blockargument(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %input[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+ %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?xbf16>
+ return %generic : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_blockargument
+// CHECK: tensor.extract_slice
+// CHECK: linalg.generic
+
+
+// -----
+
+func.func @no_propagate_extract_differentblock_1(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+ %empty = util.optimization_barrier %input : tensor<128xf32>
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+ %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?xbf16>
+ scf.yield %generic : tensor<?xbf16>
+ }
+ return %for : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_differentblock_1
+// CHECK: tensor.extract_slice
+// CHECK: linalg.generic
+
+// -----
+
+func.func @no_propagate_extract_differentblock_2(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+ %empty = util.optimization_barrier %input : tensor<128xf32>
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+ %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
+ %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?xbf16>
+ scf.yield %generic : tensor<?xbf16>
+ }
+ return %for : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_differentblock_2
+// CHECK: tensor.extract_slice
+// CHECK: linalg.generic