[GPU] Remove reshape by expansion in workgroup scope of combine layout pass (#21869)
This makes sense in the dispatch scope, but doing it in the workgroup
scope can create additional reshape ops and cause problems for
in-place bufferization.
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
index 193d28c..82f7f02 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
@@ -455,6 +455,15 @@
linalg::TransposeOp>(op);
}
+// This is only desirable in the dispatch scope but not in the workgroup scope.
+// Returns true when reshape-by-expansion propagation should run. Expanding
+// reshapes is only desirable in the dispatch scope; in the workgroup scope it
+// can create extra reshape ops that break in-place bufferization.
+static bool
+shouldDoReshapesByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
+  return scope == IREE::Codegen::RelayoutCombinationScope::Dispatch;
+}
+
/// Insert identity map_scatter ops after the given operation if it is a valid
/// leaf op of a relayout op chain. A relayout op chain is a sequence of
/// relayout ops (defined by `isSupportedRelayoutOp`) for which the only users
@@ -497,6 +506,7 @@
LogicalResult
combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
PadDistributionConfigFn padDistributionConfigFn,
+ bool doReshapeByExpansion,
CombineRelayoutOpsControlFnRef controlFn) {
// Sink relayout operations to the end of the funcOp.
RewritePatternSet propagationPatterns(ctx);
@@ -504,13 +514,16 @@
tensor::ExpandShapeOp::getCanonicalizationPatterns(propagationPatterns, ctx);
tensor::CollapseShapeOp::getCanonicalizationPatterns(propagationPatterns,
ctx);
- // Only sink reshape ops, so bail if the consumer operation is a reshape.
- auto controlSinkReshapesFn = [](OpOperand *operand) -> bool {
- Operation *consumer = operand->getOwner();
- return !llvm::isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(consumer);
- };
- linalg::populateFoldReshapeOpsByExpansionPatterns(propagationPatterns,
- controlSinkReshapesFn);
+ if (doReshapeByExpansion) {
+ // Only sink reshape ops, so bail if the consumer operation is a reshape.
+ auto controlSinkReshapesFn = [](OpOperand *operand) -> bool {
+ Operation *consumer = operand->getOwner();
+ return !llvm::isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
+ consumer);
+ };
+ linalg::populateFoldReshapeOpsByExpansionPatterns(propagationPatterns,
+ controlSinkReshapesFn);
+ }
// Only sink unpack ops, so bail if the producer operation is not an unpack.
// Also only sink unpack ops when new pack operations will not be created.
// This means the consumer op must have at most one additional destination
@@ -671,9 +684,11 @@
void runOnOperation() override {
CombineRelayoutOpsControlFn controlFn =
getCombineRelayoutOpsControlFn(this->scope);
- if (failed(combineLayoutTransformation(
- &getContext(), getOperation(),
- defaultPadWorkgroupDistributionConfigFn, controlFn))) {
+ bool doReshapesByExpansion = shouldDoReshapesByExpansion(this->scope);
+ if (failed(
+ combineLayoutTransformation(&getContext(), getOperation(),
+ defaultPadWorkgroupDistributionConfigFn,
+ doReshapesByExpansion, controlFn))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
index 53f9c20..4486b2a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h
@@ -94,6 +94,7 @@
LogicalResult
combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
PadDistributionConfigFn padDistributionConfigFn,
+ bool doReshapeByExpansion,
CombineRelayoutOpsControlFnRef controlFn = nullptr);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
index 1bb1b24..0add7ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp
@@ -70,9 +70,9 @@
// Workgroup scope.
CombineRelayoutOpsControlFn controlFn = getCombineRelayoutOpsControlFn(
IREE::Codegen::RelayoutCombinationScope::Dispatch);
- if (failed(combineLayoutTransformation(&getContext(), getOperation(),
- gpuPadDistributionConfigFn,
- controlFn))) {
+ if (failed(combineLayoutTransformation(
+ &getContext(), getOperation(), gpuPadDistributionConfigFn,
+ /*doReshapeByExpansion=*/true, controlFn))) {
return signalPassFailure();
}
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
index f7ed80a..bc1c850 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/combine_layout_transformation.mlir
@@ -317,6 +317,9 @@
// DISPATCH-SCOPE: %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_scatter %[[COMPUTE_OP]]
// DISPATCH-SCOPE: iree_codegen.store_to_buffer %[[MAP_SCATTER]]
+// WORKGROUP-SCOPE-LABEL: @propagate_relayout_ops
+// WORKGROUP-SCOPE: linalg.generic {{.*}} outs(%{{.*}} : tensor<?xf16>) {
+
// -----
func.func @insert_in_workgroup_forall(%2 : tensor<32xbf16>, %3 : tensor<32xbf16>, %9 : tensor<10xbf16>) -> (tensor<32xbf16>, tensor<32xbf16>) {