Flatten tuples when generating ABI signatures. (#8502)
This fixes some legacy Jax integration tests which were relying on this
behavior.
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/EmitDefaultIREEABI.cpp b/integrations/tensorflow/iree_tf_compiler/MHLO/EmitDefaultIREEABI.cpp
index b284839..10e2b45 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/EmitDefaultIREEABI.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/EmitDefaultIREEABI.cpp
@@ -41,7 +41,8 @@
}
json::Array refArgs;
- for (Type t : funcOp.getArgumentTypes()) {
+ SmallVector<Type> argTypes = flattenTypes(funcOp.getArgumentTypes());
+ for (Type t : argTypes) {
auto descriptor = mapTypeToJsonTypeRecord(t);
if (!descriptor) {
funcOp.emitWarning()
@@ -53,7 +54,8 @@
}
json::Array refReturns;
- for (Type t : funcOp.getCallableResults()) {
+ SmallVector<Type> resultTypes = flattenTypes(funcOp.getCallableResults());
+ for (Type t : resultTypes) {
auto descriptor = mapTypeToJsonTypeRecord(t);
if (!descriptor) {
funcOp.emitWarning()
@@ -76,6 +78,22 @@
funcOp->setAttr("iree.abi", builder.getStringAttr(refStr));
}
+ SmallVector<Type> flattenTypes(ArrayRef<Type> types) {
+ SmallVector<Type> flattened;
+ std::function<void(ArrayRef<Type>)> helper =
+ [&](ArrayRef<Type> types) -> void {
+ for (Type t : types) {
+ if (auto tt = t.dyn_cast<TupleType>()) {
+ helper(tt.getTypes());
+ } else {
+ flattened.push_back(t);
+ }
+ }
+ };
+ helper(types);
+ return flattened;
+ }
+
llvm::Optional<json::Value> mapTypeToJsonTypeRecord(Type type) {
if (auto shapedType = type.dyn_cast<ShapedType>()) {
json::Array record({
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/test/emit_default_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/MHLO/test/emit_default_iree_abi.mlir
index f889a51..7efe5b2 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/test/emit_default_iree_abi.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/test/emit_default_iree_abi.mlir
@@ -5,3 +5,11 @@
func @valid(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2x3xf32>) {
return %arg1, %arg0 : tensor<3xf32>, tensor<2x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @tupled
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,1,3],[\22ndarray\22,\22f32\22,2,2,3]],\22r\22:[[\22ndarray\22,\22f32\22,1,3],[\22ndarray\22,\22f32\22,2,2,3]],\22v\22:1}"
+func @tupled(%arg0: tuple<tensor<3xf32>, tensor<2x3xf32>>) -> tuple<tensor<3xf32>, tensor<2x3xf32>> {
+ return %arg0 : tuple<tensor<3xf32>, tensor<2x3xf32>>
+}