[DispatchCreation] Fix trailing unit dims case for collapse of expand folding (#21677)
Previous logic only prevented collapsing all unit dims in a
reassociation element if the first element of the reassociation
represented a unit dim in the input.
i.e. it worked for cases like 1x1x44x5 -> 1x44x5
but failed for 5x44x1x1 -> 5x44x1
---------
Signed-off-by: dan <danimal197@gmail.com>
Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
Co-authored-by: Ian Wood <ianwood@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
index 5e24ca5..5ae7d7f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
@@ -87,9 +87,10 @@
continue;
}
- // If we are collapsing multiple unit dims together, at least 1 must be
- // kept (prefer the first).
- if (outShape[outDim] == 1 && innerIdx != 0) {
+ // If outShape[outDim] == 1, we must preserve 1 unit dim,
+ // so we drop the first. If the first is the only unit dim,
+ // we can't drop it anyway.
+ if (outShape[outDim] == 1 && innerIdx == 0) {
continue;
}
toDrop.insert(inDim);
@@ -99,8 +100,13 @@
// Remove dimensions from `toDrop` that weren't introduced by the
// `expandOp` op.
const auto expandReassoc = expandOp.getReassociationIndices();
- for (const auto &[inDim, indices] : llvm::enumerate(expandReassoc)) {
- if (indices.size() == 1) {
+ for (const auto &indices : expandReassoc) {
+ // If all of indices are in `toDrop`, we must preserve at least one
+ // to avoid an empty reassociation map during expansion.
+ // This can happen when outShape does not have a unit dimension
+ // corresponding to the unit dimensions being dropped here.
+ if (llvm::all_of(indices,
+ [&](int64_t idx) { return toDrop.contains(idx); })) {
toDrop.erase(indices[0]);
}
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
index 05f9758..d27c31a 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
@@ -312,3 +312,35 @@
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: tensor<1x1xf16> into tensor<f16>
// CHECK: util.return %[[COLLAPSED]] : tensor<f16>
+
+// -----
+
+util.func @collapse_of_expand_trailing_unit_dims(%arg0: tensor<23040x1xbf16>) -> tensor<4x5760xbf16> {
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 5760, 1, 1] : tensor<23040x1xbf16> into tensor<4x5760x1x1xbf16>
+ %collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3]] : tensor<4x5760x1x1xbf16> into tensor<4x5760xbf16>
+ util.return %collapsed : tensor<4x5760xbf16>
+}
+// CHECK-LABEL: util.func public @collapse_of_expand_trailing_unit_dims
+// CHECK-SAME: %[[ARG0:.+]]: tensor<23040x1xbf16>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: tensor<23040x1xbf16> into tensor<4x5760x1xbf16>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+// CHECK-SAME: tensor<4x5760x1xbf16> into tensor<4x5760xbf16>
+// CHECK: util.return %[[COLLAPSE]] : tensor<4x5760xbf16>
+
+// -----
+
+// This test considers the case where we have multiple trailing unit dims but must preserve one for the output,
+// as well as an isolated unit dim that must be preserved for the collapse's reassociation dims.
+util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040xbf16>) -> tensor<4x5760x1xbf16> {
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2, 3, 4, 5]] output_shape [1, 4, 5760, 1, 1, 1] : tensor<1x23040xbf16> into tensor<1x4x5760x1x1x1xbf16>
+ %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4, 5]] : tensor<1x4x5760x1x1x1xbf16> into tensor<4x5760x1xbf16>
+ util.return %collapsed : tensor<4x5760x1xbf16>
+}
+// CHECK-LABEL: util.func public @collapse_of_expand_preserved_trailing_unit_dims
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x23040xbf16>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: tensor<1x23040xbf16> into tensor<1x4x5760x1xbf16>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+// CHECK-SAME: tensor<1x4x5760x1xbf16> into tensor<4x5760x1xbf16>
+// CHECK: util.return %[[COLLAPSE]] : tensor<4x5760x1xbf16>