[GPU] Add pattern to sink extract_slice through generic ops (#21796)

This is a useful propagation pattern that helps fuse consumer generic
ops into the tiled loops.
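
For example, given an elementwise generic that consumes a slice (a simplified
sketch mirroring the `@propagate_extract_basic` test added below; the value
names here are illustrative):

```mlir
%slice = tensor.extract_slice %src[%off] [%sz] [1]
    : tensor<128xf32> to tensor<?xf32>
%res = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%slice : tensor<?xf32>) outs(%init : tensor<?xbf16>) {
^bb0(%in: f32, %out: bf16):
  %t = arith.truncf %in : f32 to bf16
  linalg.yield %t : bf16
} -> tensor<?xbf16>
```

the pattern sinks the slice below the generic, so the generic computes on the
full tensor and the slice is taken from its result:

```mlir
%init_full = tensor.empty() : tensor<128xbf16>
%full = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%src : tensor<128xf32>) outs(%init_full : tensor<128xbf16>) {
^bb0(%in: f32, %out: bf16):
  %t = arith.truncf %in : f32 to bf16
  linalg.yield %t : bf16
} -> tensor<128xbf16>
%res = tensor.extract_slice %full[%off] [%sz] [1]
    : tensor<128xbf16> to tensor<?xbf16>
```

With the extract_slice now directly on the result of the generic, the generic
can be fused as a consumer into the loops produced by tiling.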

---------

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
index e3b0d28..3e639e8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
@@ -289,7 +289,33 @@
     return !getLoweringConfig(producer) && !getLoweringConfig(consumer);
   };
 
+  // Additionally, we do not sink an extract_slice through a generic if the
+  // slice source is a block argument, or if the source, slice, and generic
+  // are in different blocks, as this would affect how tiling uses the slice.
+
+  linalg::ControlPropagationFn controlExtract =
+      [](OpOperand *opOperand) -> bool {
+    Operation *producer = opOperand->get().getDefiningOp();
+    Operation *consumer = opOperand->getOwner();
+    auto sliceOp = dyn_cast_if_present<tensor::ExtractSliceOp>(producer);
+    if (!sliceOp) {
+      return false;
+    }
+    if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+      return false;
+    }
+    Operation *producerSrc = sliceOp.getSource().getDefiningOp();
+    // If the source is not an op result, e.g., a block argument, return false.
+    if (!producerSrc) {
+      return false;
+    }
+
+    return producerSrc->getBlock() == producer->getBlock() &&
+           consumer->getBlock() == producer->getBlock();
+  };
+
   linalg::populateDataLayoutPropagationPatterns(patterns, control);
+  linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
   patterns.add<PackDestinationForOp>(context);
   linalg::UnPackOp::getCanonicalizationPatterns(patterns, context);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_intrinsics.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_intrinsics.mlir
index ce2fa20..e4565a0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_intrinsics.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_intrinsics.mlir
@@ -309,3 +309,80 @@
 //       CHECK:     scf.yield %[[INNER_FOR_RESULT]] : tensor<1x1x4x2x16x16xf32>
 //       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
 //       CHECK:   linalg.unpack %[[OUTER_FOR_RESULT]] inner_dims_pos = [2, 3] inner_tiles = [16, 16] into %[[EMPTY]] : tensor<1x1x4x2x16x16xf32> -> tensor<1x1x64x32xf32>
+
+// -----
+
+func.func @propagate_extract_basic(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+  %source = util.optimization_barrier %input : tensor<128xf32>
+  %extracted_slice = tensor.extract_slice %source[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+  %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?xbf16>
+  return %generic : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @propagate_extract_basic
+//       CHECK:  linalg.generic
+//       CHECK:  tensor.extract_slice
+
+// -----
+
+func.func @no_propagate_extract_blockargument(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %input[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+  %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?xbf16>
+  return %generic : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_blockargument
+//       CHECK:  tensor.extract_slice
+//       CHECK:  linalg.generic
+
+// -----
+
+func.func @no_propagate_extract_differentblock_1(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+  %source = util.optimization_barrier %input : tensor<128xf32>
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
+    %extracted_slice = tensor.extract_slice %source[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+    %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
+    ^bb0(%in: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?xbf16>
+    scf.yield %generic : tensor<?xbf16>
+  }
+  return %for : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_differentblock_1
+//       CHECK:  tensor.extract_slice
+//       CHECK:  linalg.generic
+
+// -----
+
+func.func @no_propagate_extract_differentblock_2(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
+  %source = util.optimization_barrier %input : tensor<128xf32>
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %extracted_slice = tensor.extract_slice %source[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
+  %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
+    %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
+    ^bb0(%in: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?xbf16>
+    scf.yield %generic : tensor<?xbf16>
+  }
+  return %for : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @no_propagate_extract_differentblock_2
+//       CHECK:  tensor.extract_slice
+//       CHECK:  linalg.generic