[Codegen] Don't require full slice to decompose boundary pack and unpack ops (#18906)
This PR loosens the restrictions on decomposing boundary pack and unpack
ops. The current restriction is that the dispatch.tensor.load/store ops
are full slices, but this is not necessary for the current use case in
the TileAndFuse pipeline.
Instead, it is better for the time being to decompose non-padded
pack/unpack ops at function boundaries regardless of the
dispatch.tensor.load/store ops being full slices, because decomposing
such ops later on can cause issues with DPS. The DPS issues are tracked
in https://github.com/iree-org/iree/issues/18902, but we can loosen the
restrictions regardless, since it does not pose any issues to decompose
in such cases.
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
index f816941..fed4470 100644
--- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
@@ -314,51 +314,31 @@
}
/// Control function for decomposing pack and unpack ops. Returns true if the
-/// op is a pack or unpack op, and its reshapes can be folded with a producer
-/// or consumer interface tensor op. To be foldable, the following conditions
-/// must be met:
-///
+/// op is an unpadded pack or unpack op, and it is at the boundary of a
+/// dispatch. The following conditions need to be met:
/// 1. The PackOp or UnPackOp must have no padding.
/// 2. If the op is a PackOp, then its producer must be a dispatch tensor load.
/// 3. If the op is an UnPackOp, then all of its consumers must be dispatch
/// tensor stores.
-/// 4. Any dispatch tensor load producers or dispatch tensor store consumers
-/// must be full slices.
-static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) {
- // Full slice means zero offsets, unit strides, and sizes match full tensor
- // shape.
- auto isFullSlice =
- [](ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides, ArrayRef<int64_t> fullTensorShape) {
- return areAllConstantIntValue(offsets, 0) &&
- areAllConstantIntValue(strides, 1) &&
- areConstantIntValues(sizes, fullTensorShape);
- };
- if (!isa<tensor::PackOp, tensor::UnPackOp>(op)) {
+static LogicalResult isUnpaddedAndAtBoundary(Operation *op) {
+ if (!isa<tensor::PackOp>(op) && !isa<tensor::UnPackOp>(op)) {
return failure();
}
if (hasPadding(op)) {
return failure();
}
- // If the producer is a full slice dispatch tensor load, then the `op` is
- // foldable if it is a PackOp.
- auto load = dyn_cast<IREE::Flow::DispatchTensorLoadOp>(
- op->getOperand(0).getDefiningOp());
- if (isa<tensor::PackOp>(op) && load &&
- isFullSlice(load.getMixedOffsets(), load.getMixedSizes(),
- load.getMixedStrides(), load.getSourceType().getShape())) {
+ // If the producer is a dispatch tensor load, then the `op` is decomposable
+ // if it is a PackOp.
+ if (isa<tensor::PackOp>(op) &&
+ op->getOperand(0).getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
return success();
}
- // If all consumers are full slice dispatch tensor stores, then the `op` is
- // foldable if it is an UnPackOp.
+ // If all consumers are dispatch tensor stores, then the `op` is decomposable
+ // if it is an UnPackOp.
if (isa<tensor::UnPackOp>(op) &&
llvm::all_of(op->getUsers(), [&](Operation *user) {
- auto store = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(user);
- return store &&
- isFullSlice(store.getMixedOffsets(), store.getMixedSizes(),
- store.getMixedStrides(),
- store.getTargetType().getShape());
+ return isa<IREE::Flow::DispatchTensorStoreOp>(user);
})) {
return success();
}
@@ -368,7 +348,7 @@
void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() {
if (failed(commonRunOnOperation(&getContext(), getOperation(),
/*useOnlyReshapes=*/true, tileOuterToOne,
- isFoldableIntoInterfaceTensor))) {
+ isUnpaddedAndAtBoundary))) {
return signalPassFailure();
}
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir
index 096043b..6ff5bed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir
@@ -133,7 +133,7 @@
return
}
// CHECK-LABEL: func.func @load_non_full_slice
-// CHECK: tensor.pack
+// CHECK-NOT: tensor.pack
// -----
@@ -152,7 +152,7 @@
return
}
// CHECK-LABEL: func.func @store_non_full_slice
-// CHECK: tensor.unpack
+// CHECK-NOT: tensor.unpack
// -----