Fix rank-mismatch issue in flow.tensor.reshape folders. (#5475)
If the ranks mismatch, we can `return false` ealier. It would cause
crash when the rank of lhs is greater than the rnak of rhs.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 6fe680b..3ccd458 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -571,6 +571,9 @@
// Static shape equivalence means we can fast-path the check.
return true;
}
+ if (lhsType.getRank() != rhsType.getRank()) {
+ return false;
+ }
unsigned dynamicDimIndex = 0;
for (unsigned i = 0; i < lhsType.getRank(); ++i) {
if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index a31ab03..27515c5 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -20,6 +20,15 @@
// -----
+// CHECK-LABEL: @reshapeRankDifferent
+func @reshapeRankDifferent(%arg0: tensor<1xf32>) -> tensor<f32> {
+ // CHECK-NEXT: flow.tensor.reshape %arg0
+ %0 = flow.tensor.reshape %arg0 : tensor<1xf32> -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: @reshapeStaticDifferent
func @reshapeStaticDifferent(%arg0: tensor<1x4xf32>) -> tensor<4x1xf32> {
// CHECK-NEXT: flow.tensor.reshape %arg0