[GPU] Add chained reshape support for scf.forall expand destination pattern (#19597)

Currently, when expanding the destination of an scf.forall, we create
trivially foldable `tensor.expand_shape` ops (same source and
destination rank) with a wrong reassociation map. If such an op has
another `tensor.expand_shape` consumer, the upstream
`ComposeReassociativeReshapeOps` pattern can merge the two into a single
`tensor.expand_shape` with a wrong reassociation map, which leads to an
error. To avoid this, we now replace the uses of these trivially
foldable expand ops directly with the new `tensor.extract_slice`.
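For context, `ComposeReassociativeReshapeOps` folds back-to-back
`tensor.expand_shape` ops into a single op by composing their
reassociation maps. A minimal, well-formed sketch of such a chain
(shapes borrowed from the new test below; the function name is purely
illustrative):

  func.func @chained_expand_sketch(%src: tensor<1x16x16xf32>) -> tensor<1x8x2x2x4x2xf32> {
    %e0 = tensor.expand_shape %src [[0], [1], [2, 3, 4]] output_shape [1, 16, 2, 4, 2]
        : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
    %e1 = tensor.expand_shape %e0 [[0], [1, 2], [3], [4], [5]] output_shape [1, 8, 2, 2, 4, 2]
        : tensor<1x16x2x4x2xf32> into tensor<1x8x2x2x4x2xf32>
    // ComposeReassociativeReshapeOps rewrites the two ops above into one
    // expand_shape with the composed reassociation [[0], [1, 2], [3, 4, 5]];
    // if %e0 is left behind with a stale reassociation map, that map gets
    // baked into the composed op.
    return %e1 : tensor<1x8x2x2x4x2xf32>
  }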

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
index 5de9e3d..b1b9935 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
@@ -88,14 +88,15 @@
       return failure();
     if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
       return failure();
-    auto expandShapeOp =
-        dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
-    if (!expandShapeOp)
-      return failure();
-    SmallVector<ReassociationIndices> expandReIndices =
-        expandShapeOp.getReassociationIndices();
-    if (reIndices != expandReIndices)
-      return failure();
+    for (Operation *user : extractSliceOp->getUsers()) {
+      auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
+      if (!expandShapeOp)
+        return failure();
+      SmallVector<ReassociationIndices> expandReIndices =
+          expandShapeOp.getReassociationIndices();
+      if (reIndices != expandReIndices)
+        return failure();
+    }
     expandableUsers.push_back(extractSliceOp);
   }
   return success();
@@ -155,9 +156,14 @@
       expandedOffsets, expandedSizes, expandedStrides);
   for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
     rewriter.setInsertionPoint(extractSliceOp);
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        extractSliceOp, resultType, extractSliceOp.getSource(), expandedOffsets,
-        expandedSizes, expandedStrides);
+    auto newExtractSliceOp =
+        rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+            extractSliceOp, resultType, extractSliceOp.getSource(),
+            expandedOffsets, expandedSizes, expandedStrides);
+    for (Operation *user : newExtractSliceOp->getUsers()) {
+      auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
+      expandShapeOp->replaceAllUsesWith(newExtractSliceOp);
+    }
   }
   return;
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index 88abd0c..abced17 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -337,3 +337,54 @@
 //       CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
 //  CHECK-SAME:   offsets = [1], sizes = [32], strides = [1] : tensor<32xf32>
 //  CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<34xf32>>
+
+// -----
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+    #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @expand_dest_forall_chained() {
+  %cst = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %index = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+  %1 = tensor.empty(%index) : tensor<?x64x32xf32>
+  %extra = tensor.empty() : tensor<32x32xf32>
+  %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
+    shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+         : tensor<?x64x32xf32> to tensor<1x16x16xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
+              output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
+    %expanded2 = tensor.expand_shape %expanded [[0], [1, 2], [3], [4], [5]]
+              output_shape [1, 8, 2, 2, 4, 2] : tensor<1x16x2x4x2xf32> into tensor<1x8x2x2x4x2xf32>
+    %expanded_barrier = util.optimization_barrier %expanded2 : tensor<1x8x2x2x4x2xf32>
+    %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1, 2], [3], [4], [5]] :  tensor<1x8x2x2x4x2xf32> into tensor<1x16x2x4x2xf32>
+    %collapsed2 = tensor.collapse_shape %collapsed [[0], [1], [2, 3, 4]] : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %collapsed2 into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+        : tensor<1x16x16xf32> into tensor<?x64x32xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+  flow.dispatch.tensor.store %2, %0, offsets = [0, 0, 0], sizes = [%index, 64, 32], strides = [1, 1, 1]
+     : tensor<?x64x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+  return
+}
+
+// CHECK-LABEL: func @expand_dest_forall_chained(
+//       CHECK:   %[[LOAD_CONST:.+]] = hal.interface.constant.load
+//       CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[LOAD_CONST]]) : tensor<?x32x2x4x4x2xf32>
+//       CHECK:   %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
+//  CHECK-SAME:       shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<?x32x2x4x4x2xf32>) {
+//   CHECK-DAG:     %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG1]]]
+//   CHECK-DAG:     %[[OFFSET1:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[ARG0]]]
+//       CHECK:     %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
+//  CHECK-SAME:         [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:         tensor<?x32x2x4x4x2xf32> to tensor<1x8x2x2x4x2xf32>
+//       CHECK:     %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<1x8x2x2x4x2xf32>
+//       CHECK:     tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]]
+//  CHECK-SAME:         [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:         tensor<1x8x2x2x4x2xf32> into tensor<?x32x2x4x4x2xf32>
+//       CHECK:   flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+//  CHECK-SAME:   offsets = [0, 0, 0, 0, 0, 0], sizes = [%[[LOAD_CONST]], 32, 2, 4, 4, 2], strides = [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:   !flow.dispatch.tensor<writeonly:tensor<?x32x2x4x4x2xf32>>{%[[LOAD_CONST]]}