[Flow] Fold flow reshape with mismatching dyn dims (#18680)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 1a60d1c..6930906 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -744,15 +744,15 @@
static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
ShapedType rhsType, ValueRange rhsDynamicDims) {
- if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
- lhsType == rhsType) {
+ if (lhsType.hasStaticShape() && rhsType.hasStaticShape()) {
// Static shape equivalence means we can fast-path the check.
- return true;
+ return lhsType == rhsType;
}
if (lhsType.getRank() != rhsType.getRank()) {
return false;
}
unsigned dynamicDimIndex = 0;
+ unsigned numNonmatchingSSADims = 0;
for (unsigned i = 0; i < lhsType.getRank(); ++i) {
if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
// Static/dynamic dimension mismatch - definitely differ.
@@ -760,8 +760,7 @@
} else if (lhsType.isDynamicDim(i)) {
unsigned j = dynamicDimIndex++;
if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
- // Dynamic dimensions with different SSA values - probably differ.
- return false;
+      numNonmatchingSSADims++; // Differing SSA values may still be equal at runtime.
}
} else {
if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
@@ -770,7 +769,7 @@
}
}
}
- return true;
+  return numNonmatchingSSADims <= 1; // Reshapes preserve element count, so one differing dynamic SSA dim must still match.
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 959e398..1559c3c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -112,26 +112,38 @@
// CHECK-LABEL: @reshapeDynamicDifferent
util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index) -> tensor<4x?xf32> {
- // CHECK-NEXT: flow.tensor.reshape %arg0
+ // CHECK-NEXT: util.return %arg0 : tensor<4x?xf32>
%0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
util.return %0 : tensor<4x?xf32>
}
// -----
-// CHECK-LABEL: @flattenReshapeChain
+// CHECK-LABEL: @foldReshapeChain
// CHECK-SAME: %[[ARG:.+]]: tensor<4x?xf32>,
// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index
-util.func public @flattenReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
- // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<4x?xf32>{%[[DIM0]]} -> tensor<4x?xf32>{%[[DIM2]]}
+util.func public @foldReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
%0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
%1 = flow.tensor.reshape %0 : tensor<4x?xf32>{%dim1} -> tensor<4x?xf32>{%dim2}
- // CHECK-NEXT: util.return %[[RET]]
+ // CHECK-NEXT: util.return %[[ARG]]
util.return %1 : tensor<4x?xf32>
}
// -----
+// CHECK-LABEL: @flattenReshapeChain
+// CHECK-SAME: %[[ARG:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index, %[[DIM3:.+]]: index, %[[DIM4:.+]]: index, %[[DIM5:.+]]: index
+util.func public @flattenReshapeChain(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3 : index, %dim4 : index, %dim5 : index) -> tensor<?x?xf32> {
+ // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<?x?xf32>{%[[DIM0]], %[[DIM1]]} -> tensor<?x?xf32>{%[[DIM4]], %[[DIM5]]}
+ %0 = flow.tensor.reshape %arg0 : tensor<?x?xf32>{%dim0, %dim1} -> tensor<?x?xf32>{%dim2, %dim3}
+ %1 = flow.tensor.reshape %0 : tensor<?x?xf32>{%dim2, %dim3} -> tensor<?x?xf32>{%dim4, %dim5}
+ // CHECK-NEXT: util.return %[[RET]]
+ util.return %1 : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @flattenReshapeBitcastChain
// CHECK-SAME: %[[ARG:.+]]: tensor<4x?xi16>,
// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index