Update tflite bindings to accept both identifier forms (#13195)
Keep the existing iree.identifier support but also add ml_program one
for upstream change.
diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index d00e272..66332d3 100644
--- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -63,6 +63,19 @@
"bindings";
}
+ static StringAttr getArgId(func::FuncOp funcOp, int i) {
+ StringAttr id =
+ funcOp.getArgAttrOfType<StringAttr>(i, "ml_program.identifier");
+ return id ? id : funcOp.getArgAttrOfType<StringAttr>(i, "iree.identifier");
+ }
+
+ static StringAttr getResultId(func::FuncOp funcOp, int i) {
+ StringAttr id =
+ funcOp.getResultAttrOfType<StringAttr>(i, "ml_program.identifier");
+ return id ? id
+ : funcOp.getResultAttrOfType<StringAttr>(i, "iree.identifier");
+ }
+
void runOnOperation() override {
auto moduleOp = getOperation();
@@ -140,8 +153,7 @@
SmallVector<std::string, 4> inputNames;
SmallVector<std::string, 4> outputNames;
for (unsigned i = 0; i < funcType.getNumInputs(); ++i) {
- auto identifier =
- funcOp.getArgAttrOfType<StringAttr>(i, "iree.identifier");
+ auto identifier = getArgId(funcOp, i);
if (identifier) {
inputNames.push_back(identifier.getValue().str());
} else {
@@ -149,8 +161,7 @@
}
}
for (unsigned i = 0; i < funcType.getNumResults(); ++i) {
- auto identifier =
- funcOp.getResultAttrOfType<StringAttr>(i, "iree.identifier");
+ auto identifier = getResultId(funcOp, i);
if (identifier) {
outputNames.push_back(identifier.getValue().str());
} else {
@@ -609,12 +620,11 @@
// Constructs an attribute containing all of the input and output identifiers:
// tfl.io.names=arg0;arg1;ret0;ret1
//
- // Default names will be used if no iree.identifiers are set on the function.
+ // Default names will be used if no identifiers are set on the function.
NamedAttribute buildIONamesAttr(mlir::func::FuncOp entryFuncOp) {
SmallVector<std::string, 4> pieces;
for (int i = 0; i < entryFuncOp.getNumArguments(); ++i) {
- auto identifierAttr =
- entryFuncOp.getArgAttrOfType<StringAttr>(i, "iree.identifier");
+ auto identifierAttr = getArgId(entryFuncOp, i);
if (!identifierAttr || identifierAttr.getValue().empty()) {
pieces.push_back("arg" + std::to_string(i));
} else {
@@ -622,8 +632,7 @@
}
}
for (int i = 0; i < entryFuncOp.getNumResults(); ++i) {
- auto identifierAttr =
- entryFuncOp.getResultAttrOfType<StringAttr>(i, "iree.identifier");
+ auto identifierAttr = getResultId(entryFuncOp, i);
if (!identifierAttr || identifierAttr.getValue().empty()) {
pieces.push_back("ret" + std::to_string(i));
} else {