[GPU] Add chained reshape support for scf.forall expand destination pattern (#19597)
Currently, when expanding the destination of an scf.forall, we create
trivially foldable tensor.expand_shape ops (same source and destination
rank) that carry a wrong reassociation map. On their own these fold away
harmlessly, but if another tensor.expand_shape consumes them, the upstream
`ComposeReassociativeReshapeOps` pattern can compose the chain into a
single tensor.expand_shape with the wrong reassociation, which leads to an
error. To avoid this, we replace the uses of these expand ops with the new
expanded tensor.extract_slice directly.
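
For context, here is a reduced, hypothetical example (not taken from this
PR) of the composition that `ComposeReassociativeReshapeOps` performs on
chained tensor.expand_shape ops; if the inner op's reassociation map is
wrong, the composed op silently inherits the error:

  %a = tensor.expand_shape %src [[0], [1, 2]] output_shape [4, 2, 4]
      : tensor<4x8xf32> into tensor<4x2x4xf32>
  %b = tensor.expand_shape %a [[0, 1], [2], [3]] output_shape [2, 2, 2, 4]
      : tensor<4x2x4xf32> into tensor<2x2x2x4xf32>
  // The two ops above compose into a single op with the merged map:
  // %b = tensor.expand_shape %src [[0, 1], [2, 3]] output_shape [2, 2, 2, 4]
  //     : tensor<4x8xf32> into tensor<2x2x2x4xf32>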
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
index 5de9e3d..b1b9935 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
@@ -88,14 +88,18 @@
return failure();
if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
return failure();
- auto expandShapeOp =
- dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
- if (!expandShapeOp)
- return failure();
- SmallVector<ReassociationIndices> expandReIndices =
- expandShapeOp.getReassociationIndices();
- if (reIndices != expandReIndices)
- return failure();
+  // Every user of the extract_slice must be a tensor.expand_shape whose
+  // reassociation matches the scf.forall destination expansion; otherwise
+  // the rewrite is not safe to apply.
+  for (Operation *user : extractSliceOp->getUsers()) {
+ auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
+ if (!expandShapeOp)
+ return failure();
+ SmallVector<ReassociationIndices> expandReIndices =
+ expandShapeOp.getReassociationIndices();
+ if (reIndices != expandReIndices)
+ return failure();
+ }
expandableUsers.push_back(extractSliceOp);
}
return success();
@@ -155,9 +156,18 @@
expandedOffsets, expandedSizes, expandedStrides);
for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
rewriter.setInsertionPoint(extractSliceOp);
- rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
- extractSliceOp, resultType, extractSliceOp.getSource(), expandedOffsets,
- expandedSizes, expandedStrides);
+      auto newExtractSliceOp =
+          rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+              extractSliceOp, resultType, extractSliceOp.getSource(),
+              expandedOffsets, expandedSizes, expandedStrides);
+      // All users of the original extract_slice were verified above to be
+      // tensor.expand_shape ops with a matching reassociation, so forward
+      // their uses directly to the new expanded extract_slice instead of
+      // leaving behind trivially foldable expand_shape ops.
+      for (Operation *user : newExtractSliceOp->getUsers()) {
+        if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
+          expandShapeOp->replaceAllUsesWith(newExtractSliceOp);
+      }
}
return;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index 88abd0c..abced17 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -337,3 +337,54 @@
// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
// CHECK-SAME: offsets = [1], sizes = [32], strides = [1] : tensor<32xf32>
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<34xf32>>
+
+// -----
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+ #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @expand_dest_forall_chained() {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %index = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+ flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+ %1 = tensor.empty(%index) : tensor<?x64x32xf32>
+ %extra = tensor.empty() : tensor<32x32xf32>
+ %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
+ shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<?x64x32xf32> to tensor<1x16x16xf32>
+ %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
+ output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
+ %expanded2 = tensor.expand_shape %expanded [[0], [1, 2], [3], [4], [5]]
+ output_shape [1, 8, 2, 2, 4, 2] : tensor<1x16x2x4x2xf32> into tensor<1x8x2x2x4x2xf32>
+ %expanded_barrier = util.optimization_barrier %expanded2 : tensor<1x8x2x2x4x2xf32>
+ %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1, 2], [3], [4], [5]] : tensor<1x8x2x2x4x2xf32> into tensor<1x16x2x4x2xf32>
+ %collapsed2 = tensor.collapse_shape %collapsed [[0], [1], [2, 3, 4]] : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %collapsed2 into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<1x16x16xf32> into tensor<?x64x32xf32>
+ }
+ } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+ flow.dispatch.tensor.store %2, %0, offsets = [0, 0, 0], sizes = [%index, 64, 32], strides = [1, 1, 1]
+ : tensor<?x64x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x64x32xf32>>{%index}
+ return
+}
+
+// CHECK-LABEL: func @expand_dest_forall_chained(
+// CHECK: %[[LOAD_CONST:.+]] = hal.interface.constant.load
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[LOAD_CONST]]) : tensor<?x32x2x4x4x2xf32>
+// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0)
+// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<?x32x2x4x4x2xf32>) {
+// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG1]]]
+// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[ARG0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]]
+// CHECK-SAME: [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: tensor<?x32x2x4x4x2xf32> to tensor<1x8x2x2x4x2xf32>
+// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<1x8x2x2x4x2xf32>
+// CHECK: tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]]
+// CHECK-SAME: [0, %[[OFFSET1]], 0, %[[OFFSET0]], 0, 0] [1, 8, 2, 2, 4, 2] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: tensor<1x8x2x2x4x2xf32> into tensor<?x32x2x4x4x2xf32>
+// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]]
+// CHECK-SAME: offsets = [0, 0, 0, 0, 0, 0], sizes = [%[[LOAD_CONST]], 32, 2, 4, 4, 2], strides = [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x32x2x4x4x2xf32>>{%[[LOAD_CONST]]}