Correctly handle 1:1 conversions for calls in FunctionSignatureExpansionPass (#6997)
The code was copypasted from upstream DecomposeCallGraphTypes, which is
unused and had a latent bug for materializing 1:1 conversions. It turns
out the code can be simplified.
diff --git a/iree/compiler/Dialect/Shape/Transforms/FunctionSignatureExpansionPass.cpp b/iree/compiler/Dialect/Shape/Transforms/FunctionSignatureExpansionPass.cpp
index edd0814..7b5b841 100644
--- a/iree/compiler/Dialect/Shape/Transforms/FunctionSignatureExpansionPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/FunctionSignatureExpansionPass.cpp
@@ -65,16 +65,9 @@
auto decomposedValues = llvm::to_vector<6>(
llvm::map_range(expandedResultIndices[i],
[&](unsigned i) { return newCallOp.getResult(i); }));
- if (decomposedValues.empty()) {
- // No replacement is required.
- replacedValues.push_back(nullptr);
- } else if (decomposedValues.size() == 1) {
- replacedValues.push_back(decomposedValues.front());
- } else {
- Value materialized = typeExpander.castToSource(loc, callOp.getType(i),
- decomposedValues, builder);
- replacedValues.push_back(materialized);
- }
+ Value materialized = typeExpander.castToSource(loc, callOp.getType(i),
+ decomposedValues, builder);
+ replacedValues.push_back(materialized);
}
callOp.replaceAllUsesWith(replacedValues);
return success();
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/expand_function_ranked_shape_dims.mlir b/iree/compiler/Dialect/Shape/Transforms/test/expand_function_ranked_shape_dims.mlir
index 9ad16d5..5e29f4b 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/expand_function_ranked_shape_dims.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/expand_function_ranked_shape_dims.mlir
@@ -43,3 +43,18 @@
%0 = std.call @calls(%arg0) : (!shapex.ranked_shape<[?,?]>) -> !shapex.ranked_shape<[?,?]>
return %0 : !shapex.ranked_shape<[?,?]>
}
+
+// -----
+// CHECK-LABEL: func @oneUnknownDimension(
+// CHECK-SAME: %[[ARG:.*]]: index) -> index {
+// CHECK: %[[ARG_RS:.*]] = shapex.make_ranked_shape %[[ARG]] : (index) -> !shapex.ranked_shape<[?]>
+// CHECK: %[[ARG_DIM0:.*]] = shapex.ranked_dim %[[ARG_RS]][0] : !shapex.ranked_shape<[?]> -> index
+// CHECK: %[[CALL:.*]] = call @oneUnknownDimension(%[[ARG_DIM0]]) : (index) -> index
+// CHECK: %[[CALL_SHAPE:.*]] = shapex.make_ranked_shape %[[CALL]] : (index) -> !shapex.ranked_shape<[?]>
+// CHECK: %[[CALL_DIM0:.*]] = shapex.ranked_dim %[[CALL_SHAPE]][0] : !shapex.ranked_shape<[?]> -> index
+// CHECK: return %[[CALL_DIM0]] : index
+
+func @oneUnknownDimension(%arg0 :!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]> {
+ %0 = std.call @oneUnknownDimension(%arg0) : (!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+ return %0 : !shapex.ranked_shape<[?]>
+}