[TF/TFLite] Various fixes in the tf and tflite python binding (#17278)
This fixes:
- Compilation from a string
- Import only return types
I also ran
https://github.com/iree-org/iree/blob/main/compiler/bindings/python/test/tools/testdata/generate_tflite.py
which updated the flatbuffer file.
diff --git a/build_tools/cmake/run_tf_tests.sh b/build_tools/cmake/run_tf_tests.sh
index 2d5d8d8..27380ce 100755
--- a/build_tools/cmake/run_tf_tests.sh
+++ b/build_tools/cmake/run_tf_tests.sh
@@ -62,3 +62,9 @@
echo "Some tests failed!!!"
exit 1
fi
+
+echo "***** Running TF and TFLite python api tests *****"
+
+TF_API_TEST_DIR="compiler/bindings/python/test/tools"
+
+pytest ${TF_API_TEST_DIR}/compiler_tflite_test.py ${TF_API_TEST_DIR}/compiler_tf_test.py
diff --git a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
index 574e0e5..35a7f2c 100644
--- a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
+++ b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
@@ -38,7 +38,7 @@
help="sub-command help", required=True, dest="sub_command"
)
- def add_ouptut_options(subparser):
+ def add_output_options(subparser):
subparser.add_argument(
"--emit-bytecode", action="store_true", help="Emit bytecode"
)
@@ -56,7 +56,7 @@
help="Read a file and then output it using the given options, without "
"modification",
)
- add_ouptut_options(copy_parser)
+ add_output_options(copy_parser)
copy_parser.add_argument("input_file", help="File to process")
copy_parser.add_argument(
"-o", required=True, dest="output_file", help="Output file"
@@ -70,7 +70,7 @@
"replacing them with pseudo data suitable for interactive "
"debugging of IR",
)
- add_ouptut_options(strip_data_parser)
+ add_output_options(strip_data_parser)
strip_data_parser.add_argument(
"--no-import",
action="store_true",
diff --git a/compiler/bindings/python/iree/compiler/tools/tf.py b/compiler/bindings/python/iree/compiler/tools/tf.py
index 9eb28a2..be401c7 100644
--- a/compiler/bindings/python/iree/compiler/tools/tf.py
+++ b/compiler/bindings/python/iree/compiler/tools/tf.py
@@ -9,6 +9,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
+import importlib.util
import os
import logging
import tempfile
@@ -34,10 +35,8 @@
def is_available():
- """Determine if TensorFlow and the compiler are available."""
- try:
- import tensorflow as tf
- except ModuleNotFoundError:
+ """Determine if TensorFlow and the TF frontend are available."""
+ if importlib.util.find_spec("tensorflow") is None:
logging.warn("Unable to import tensorflow")
return False
try:
diff --git a/compiler/bindings/python/iree/compiler/tools/tflite.py b/compiler/bindings/python/iree/compiler/tools/tflite.py
index 84443ad..b953c43 100644
--- a/compiler/bindings/python/iree/compiler/tools/tflite.py
+++ b/compiler/bindings/python/iree/compiler/tools/tflite.py
@@ -11,6 +11,7 @@
import logging
import os
import tempfile
+import importlib.util
from typing import List, Optional, Sequence, Set, Union
from .debugging import TempFileSaver
@@ -29,7 +30,10 @@
def is_available():
- """Determine if the TFLite frontend is available."""
+ """Determine if TensorFlow and the TFLite frontend are available."""
+ if importlib.util.find_spec("tensorflow") is None:
+ logging.warn("Unable to import tensorflow")
+ return False
try:
import iree.tools.tflite.scripts.iree_import_tflite.__main__
except ModuleNotFoundError:
@@ -107,7 +111,7 @@
if options.import_only:
if options.output_file:
return None
- with open(tfl_iree_input, "r") as f:
+ with open(tfl_iree_input, "rb") as f:
return f.read()
# Run IREE compilation pipeline
@@ -131,7 +135,7 @@
input_bytes = (
input_bytes.encode("utf-8") if isinstance(input_bytes, str) else input_bytes
)
- with tempfile.NamedTemporaryFile(mode="w") as temp_file:
- tempfile.write(input_bytes)
- tempfile.close()
- return compile_file(tempfile.name, **kwargs)
+ with tempfile.NamedTemporaryFile() as temp_file:
+ temp_file.write(input_bytes)
+ temp_file.flush() # Ensure the data is written to disk
+ return compile_file(temp_file.name, **kwargs)
diff --git a/compiler/bindings/python/test/tools/compiler_tf_test.py b/compiler/bindings/python/test/tools/compiler_tf_test.py
index da98e08..7639cef 100644
--- a/compiler/bindings/python/test/tools/compiler_tf_test.py
+++ b/compiler/bindings/python/test/tools/compiler_tf_test.py
@@ -43,12 +43,12 @@
# TODO(laurenzo): More test cases needed (may need additional files).
# Specifically, figure out how to test v1 models.
-class TfCompilerTest(tf.test.TestCase):
+class TfCompilerTest(unittest.TestCase):
def testImportSavedModel(self):
import_mlir = iree.compiler.tools.tf.compile_saved_model(
self.smdir, import_only=True, output_generic_mlir=True
- ).decode("utf-8")
- self.assertIn('sym_name = "simple_matmul"', import_mlir)
+ )
+ self.assertIn(b"simple_matmul", import_mlir)
def testCompileSavedModel(self):
binary = iree.compiler.tools.tf.compile_saved_model(
@@ -82,4 +82,4 @@
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
- tf.test.main()
+ unittest.main()
diff --git a/compiler/bindings/python/test/tools/compiler_tflite_test.py b/compiler/bindings/python/test/tools/compiler_tflite_test.py
index 8a99e32..2dec1b3 100644
--- a/compiler/bindings/python/test/tools/compiler_tflite_test.py
+++ b/compiler/bindings/python/test/tools/compiler_tflite_test.py
@@ -10,6 +10,8 @@
import tempfile
import unittest
+from iree.compiler.tools.ir_tool import __main__ as ir_tool
+
# TODO: No idea why pytype cannot find names from this module.
# pytype: disable=name-error
import iree.compiler.tools.tflite
@@ -22,12 +24,25 @@
sys.exit(0)
+def mlir_bytecode_file_to_text(bytecode_file):
+ with tempfile.NamedTemporaryFile() as temp_file:
+ args = ir_tool.parse_arguments(["copy", bytecode_file, "-o", temp_file.name])
+ ir_tool.main(args)
+ return str(temp_file.read())
+
+
+def mlir_bytecode_to_text(bytecode):
+ with tempfile.NamedTemporaryFile("wb") as temp_bytecode_file:
+ temp_bytecode_file.write(bytecode)
+ temp_bytecode_file.flush()
+ return mlir_bytecode_file_to_text(temp_bytecode_file.name)
+
+
class CompilerTest(unittest.TestCase):
def testImportBinaryPbFile(self):
path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb")
- text = iree.compiler.tools.tflite.compile_file(path, import_only=True).decode(
- "utf-8"
- )
+ bytecode = iree.compiler.tools.tflite.compile_file(path, import_only=True)
+ text = mlir_bytecode_to_text(bytecode)
logging.info("%s", text)
self.assertIn("tosa.mul", text)
@@ -48,10 +63,11 @@
path, import_only=True, output_file=f.name
)
self.assertIsNone(output)
- with open(f.name, "rt") as f_read:
- text = f_read.read()
+ with open(f.name, "rb") as f_read:
+ bytecode = f_read.read()
finally:
os.remove(f.name)
+ text = mlir_bytecode_to_text(bytecode)
logging.info("%s", text)
self.assertIn("tosa.mul", text)
@@ -77,9 +93,8 @@
path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb")
with open(path, "rb") as f:
content = f.read()
- text = iree.compiler.tools.tflite.compile_str(content, import_only=True).decode(
- "utf-8"
- )
+ bytecode = iree.compiler.tools.tflite.compile_str(content, import_only=True)
+ text = mlir_bytecode_to_text(bytecode)
logging.info("%s", text)
self.assertIn("tosa.mul", text)
diff --git a/compiler/bindings/python/test/tools/testdata/tflite_sample.fb b/compiler/bindings/python/test/tools/testdata/tflite_sample.fb
index 52cb9e4..d4a65ed 100644
--- a/compiler/bindings/python/test/tools/testdata/tflite_sample.fb
+++ b/compiler/bindings/python/test/tools/testdata/tflite_sample.fb
Binary files differ