[GPU] Remove reshape by expansion in workgroup scope of combine layout pass (#21869)

Reshape by expansion makes sense in the dispatch scope, but doing it in
the workgroup scope can create more reshape ops and cause problems for
in-place bufferization.
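
As a rough illustration (hypothetical IR, not taken from this patch),
sinking a reshape by expansion trades a reshape on an input of a compute
op for a new reshape on its result:

  // Before: a collapse_shape feeds the compute op.
  %c = tensor.collapse_shape %in [[0, 1]]
      : tensor<4x8xf16> into tensor<32xf16>
  %r = linalg.generic ... ins(%c : tensor<32xf16>) ... -> tensor<32xf16>

  // After: the generic is expanded to the uncollapsed shape, and a new
  // collapse_shape is created on its result.
  %e = linalg.generic ... ins(%in : tensor<4x8xf16>) ... -> tensor<4x8xf16>
  %r = tensor.collapse_shape %e [[0, 1]]
      : tensor<4x8xf16> into tensor<32xf16>

At dispatch scope this usefully sinks reshapes toward the map_scatter
leaf of the relayout chain; inside a workgroup scope, the newly created
reshapes can get in the way of in-place bufferization.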

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
index 193d28c..82f7f02 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
@@ -455,6 +455,14 @@
              linalg::TransposeOp>(op);
 }
 
+// Reshape by expansion is only desirable in the dispatch scope; in the
+// workgroup scope it can create extra reshapes that hinder in-place
+// bufferization.
+static bool
+shouldDoReshapeByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
+  return scope == IREE::Codegen::RelayoutCombinationScope::Dispatch;
+}
+
 /// Insert identity map_scatter ops after the given operation if it is a valid
 /// leaf op of a relayout op chain. A relayout op chain is a sequence of
 /// relayout ops (defined by `isSupportedRelayoutOp`) for which the only users
@@ -497,6 +506,7 @@
 LogicalResult
 combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
                             PadDistributionConfigFn padDistributionConfigFn,
+                            bool doReshapeByExpansion,
                             CombineRelayoutOpsControlFnRef controlFn) {
   // Sink relayout operations to the end of the funcOp.
   RewritePatternSet propagationPatterns(ctx);
@@ -504,13 +514,16 @@
   tensor::ExpandShapeOp::getCanonicalizationPatterns(propagationPatterns, ctx);
   tensor::CollapseShapeOp::getCanonicalizationPatterns(propagationPatterns,
                                                        ctx);
-  // Only sink reshape ops, so bail if the consumer operation is a reshape.
-  auto controlSinkReshapesFn = [](OpOperand *operand) -> bool {
-    Operation *consumer = operand->getOwner();
-    return !llvm::isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(consumer);
-  };
-  linalg::populateFoldReshapeOpsByExpansionPatterns(propagationPatterns,
-                                                    controlSinkReshapesFn);
+  if (doReshapeByExpansion) {
+    // Only sink reshape ops, so bail if the consumer operation is a reshape.
+    auto controlSinkReshapesFn = [](OpOperand *operand) -> bool {
+      Operation *consumer = operand->getOwner();
+      return !llvm::isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
+          consumer);
+    };
+    linalg::populateFoldReshapeOpsByExpansionPatterns(propagationPatterns,
+                                                      controlSinkReshapesFn);
+  }
   // Only sink unpack ops, so bail if the producer operation is not an unpack.
   // Also only sink unpack ops when new pack operations will not be created.
   // This means the consumer op must have at most one additional destination
@@ -671,9 +684,11 @@
   void runOnOperation() override {
     CombineRelayoutOpsControlFn controlFn =
         getCombineRelayoutOpsControlFn(this->scope);
-    if (failed(combineLayoutTransformation(
-            &getContext(), getOperation(),
-            defaultPadWorkgroupDistributionConfigFn, controlFn))) {
+    bool doReshapeByExpansion = shouldDoReshapeByExpansion(this->scope);
+    if (failed(
+            combineLayoutTransformation(&getContext(), getOperation(),
+                                        defaultPadWorkgroupDistributionConfigFn,
+                                        doReshapeByExpansion, controlFn))) {
       return signalPassFailure();
     }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
index 53f9c20..4486b2a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
@@ -94,6 +94,7 @@
 LogicalResult
 combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
                             PadDistributionConfigFn padDistributionConfigFn,
+                            bool doReshapeByExpansion,
                             CombineRelayoutOpsControlFnRef controlFn = nullptr);
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
index 1bb1b24..0add7ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
@@ -70,9 +70,9 @@
     // Workgroup scope.
     CombineRelayoutOpsControlFn controlFn = getCombineRelayoutOpsControlFn(
         IREE::Codegen::RelayoutCombinationScope::Dispatch);
-    if (failed(combineLayoutTransformation(&getContext(), getOperation(),
-                                           gpuPadDistributionConfigFn,
-                                           controlFn))) {
+    if (failed(combineLayoutTransformation(
+            &getContext(), getOperation(), gpuPadDistributionConfigFn,
+            /*doReshapeByExpansion=*/true, controlFn))) {
       return signalPassFailure();
     }
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
index f7ed80a..bc1c850 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
@@ -317,6 +317,9 @@
 //       DISPATCH-SCOPE:   %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_scatter %[[COMPUTE_OP]]
 //       DISPATCH-SCOPE:   iree_codegen.store_to_buffer %[[MAP_SCATTER]]
 
+// WORKGROUP-SCOPE-LABEL: @propagate_relayout_ops
+//       WORKGROUP-SCOPE: linalg.generic {{.*}} outs(%{{.*}} : tensor<?xf16>) {
+
 // -----
 
 func.func @insert_in_workgroup_forall(%2 : tensor<32xbf16>, %3 : tensor<32xbf16>, %9 : tensor<10xbf16>) -> (tensor<32xbf16>, tensor<32xbf16>) {