[Flow] Fold flow reshape with mismatching dyn dims (#18680)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 1a60d1c..6930906 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -744,15 +744,15 @@
 
 static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
                                ShapedType rhsType, ValueRange rhsDynamicDims) {
-  if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
-      lhsType == rhsType) {
+  if (lhsType.hasStaticShape() && rhsType.hasStaticShape()) {
     // Static shape equivalence means we can fast-path the check.
-    return true;
+    return lhsType == rhsType;
   }
   if (lhsType.getRank() != rhsType.getRank()) {
     return false;
   }
   unsigned dynamicDimIndex = 0;
+  unsigned numNonmatchingSSADims = 0;
   for (unsigned i = 0; i < lhsType.getRank(); ++i) {
     if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
       // Static/dynamic dimension mismatch - definitely differ.
@@ -760,8 +760,7 @@
     } else if (lhsType.isDynamicDim(i)) {
       unsigned j = dynamicDimIndex++;
       if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
-        // Dynamic dimensions with different SSA values - probably differ.
-        return false;
+        numNonmatchingSSADims++;
       }
     } else {
       if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
@@ -770,7 +769,7 @@
       }
     }
   }
-  return true;
+  return numNonmatchingSSADims <= 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 959e398..1559c3c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -112,26 +112,38 @@
 
 // CHECK-LABEL: @reshapeDynamicDifferent
 util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: flow.tensor.reshape %arg0
+  // CHECK-NEXT: util.return %arg0 : tensor<4x?xf32>
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   util.return %0 : tensor<4x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @flattenReshapeChain
+// CHECK-LABEL: @foldReshapeChain
 // CHECK-SAME: %[[ARG:.+]]: tensor<4x?xf32>,
 // CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index
-util.func public @flattenReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<4x?xf32>{%[[DIM0]]} -> tensor<4x?xf32>{%[[DIM2]]}
+util.func public @foldReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   %1 = flow.tensor.reshape %0 : tensor<4x?xf32>{%dim1} -> tensor<4x?xf32>{%dim2}
-  // CHECK-NEXT: util.return %[[RET]]
+  // CHECK-NEXT: util.return %[[ARG]]
   util.return %1 : tensor<4x?xf32>
 }
 
 // -----
 
+// CHECK-LABEL: @flattenReshapeChain
+// CHECK-SAME: %[[ARG:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index, %[[DIM3:.+]]: index, %[[DIM4:.+]]: index, %[[DIM5:.+]]: index
+util.func public @flattenReshapeChain(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3 : index, %dim4 : index, %dim5 : index) -> tensor<?x?xf32> {
+  // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<?x?xf32>{%[[DIM0]], %[[DIM1]]} -> tensor<?x?xf32>{%[[DIM4]], %[[DIM5]]}
+  %0 = flow.tensor.reshape %arg0 : tensor<?x?xf32>{%dim0, %dim1} -> tensor<?x?xf32>{%dim2, %dim3}
+  %1 = flow.tensor.reshape %0 : tensor<?x?xf32>{%dim2, %dim3} -> tensor<?x?xf32>{%dim4, %dim5}
+  // CHECK-NEXT: util.return %[[RET]]
+  util.return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @flattenReshapeBitcastChain
 // CHECK-SAME: %[[ARG:.+]]: tensor<4x?xi16>,
 // CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index