[DispatchCreation] Fix trailing unit dims case for collapse of expand folding (#21677)

Previous logic only prevented collapsing all unit dims in a
reassociation element if the first element of the reassociation
represented a unit dim in the input.

i.e. it worked for cases like 1x1x44x5 -> 1x44x5
but failed for 5x44x1x1 -> 5x44x1

---------

Signed-off-by: dan <danimal197@gmail.com>
Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
Co-authored-by: Ian Wood <ianwood@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
index 5e24ca5..5ae7d7f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
@@ -87,9 +87,10 @@
           continue;
         }
 
-        // If we are collapsing multiple unit dims together, at least 1 must be
-        // kept (prefer the first).
-        if (outShape[outDim] == 1 && innerIdx != 0) {
+        // If outShape[outDim] == 1, we must preserve 1 unit dim,
+        // so we drop the first. If the first is the only unit dim,
+        // we can't drop it anyway.
+        if (outShape[outDim] == 1 && innerIdx == 0) {
           continue;
         }
         toDrop.insert(inDim);
@@ -99,8 +100,13 @@
     // Remove dimensions from `toDrop` that weren't introduced by the
     // `expandOp` op.
     const auto expandReassoc = expandOp.getReassociationIndices();
-    for (const auto &[inDim, indices] : llvm::enumerate(expandReassoc)) {
-      if (indices.size() == 1) {
+    for (const auto &indices : expandReassoc) {
+      // If all of indices are in `toDrop`, we must preserve at least one
+      // to avoid an empty reassociation map during expansion.
+      // This can happen when outShape does not have a unit dimension
+      // corresponding to the unit dimensions being dropped here.
+      if (llvm::all_of(indices,
+                       [&](int64_t idx) { return toDrop.contains(idx); })) {
         toDrop.erase(indices[0]);
       }
     }
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
index 05f9758..d27c31a 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
@@ -312,3 +312,35 @@
 //       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
 //  CHECK-SAME:     tensor<1x1xf16> into tensor<f16>
 //       CHECK:   util.return %[[COLLAPSED]] : tensor<f16>
+
+// -----
+
+util.func @collapse_of_expand_trailing_unit_dims(%arg0: tensor<23040x1xbf16>) -> tensor<4x5760xbf16> {
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 5760, 1, 1] : tensor<23040x1xbf16> into tensor<4x5760x1x1xbf16>
+  %collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3]] : tensor<4x5760x1x1xbf16> into tensor<4x5760xbf16>
+  util.return %collapsed : tensor<4x5760xbf16>
+}
+// CHECK-LABEL: util.func public @collapse_of_expand_trailing_unit_dims
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<23040x1xbf16>
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<23040x1xbf16> into tensor<4x5760x1xbf16>
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+//  CHECK-SAME:     tensor<4x5760x1xbf16> into tensor<4x5760xbf16>
+//       CHECK:   util.return %[[COLLAPSE]] : tensor<4x5760xbf16>
+
+// -----
+
+// This test considers the case where we have multiple trailing unit dims but must preserve one for the output,
+// as well as an isolated unit dim that must be preserved for the collapse's reassociation dims.
+util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040xbf16>) -> tensor<4x5760x1xbf16> {
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2, 3, 4, 5]] output_shape [1, 4, 5760, 1, 1, 1] : tensor<1x23040xbf16> into tensor<1x4x5760x1x1x1xbf16>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4, 5]] : tensor<1x4x5760x1x1x1xbf16> into tensor<4x5760x1xbf16>
+  util.return %collapsed : tensor<4x5760x1xbf16>
+}
+// CHECK-LABEL: util.func public @collapse_of_expand_preserved_trailing_unit_dims
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<1x23040xbf16>
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<1x23040xbf16> into tensor<1x4x5760x1xbf16>
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+//  CHECK-SAME:     tensor<1x4x5760x1xbf16> into tensor<4x5760x1xbf16>
+//       CHECK:   util.return %[[COLLAPSE]] : tensor<4x5760x1xbf16>