Add TFLite integration tests to run on presubmits (#8269)
* Add tflite tests to run on presubmits
* Update test_util impor
* Run yapf
diff --git a/integrations/tensorflow/test/iree_tfl_tests/cartoon_gan.run b/integrations/tensorflow/test/iree_tfl_tests/cartoon_gan.run
new file mode 100644
index 0000000..7526e4a
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/cartoon_gan.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.cartoon_gan_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tfl_tests/east_text_detector.run b/integrations/tensorflow/test/iree_tfl_tests/east_text_detector.run
new file mode 100644
index 0000000..af10e7a
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/east_text_detector.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.east_text_detector_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tfl_tests/gpt2.run b/integrations/tensorflow/test/iree_tfl_tests/gpt2.run
new file mode 100644
index 0000000..ca103dc
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/gpt2.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.gpt2_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tfl_tests/mnasnet.run b/integrations/tensorflow/test/iree_tfl_tests/mnasnet.run
new file mode 100644
index 0000000..978b8cc
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/mnasnet.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.mnasnet_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tfl_tests/mobilenet_v3.run b/integrations/tensorflow/test/iree_tfl_tests/mobilenet_v3.run
new file mode 100644
index 0000000..27fe77d
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/mobilenet_v3.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.mobilenet_v3_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tfl_tests/person_detect.run b/integrations/tensorflow/test/iree_tfl_tests/person_detect.run
new file mode 100644
index 0000000..5a0c6c1
--- /dev/null
+++ b/integrations/tensorflow/test/iree_tfl_tests/person_detect.run
@@ -0,0 +1 @@
+# RUN: %PYTHON -m iree_tfl_tests.person_detect_test -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py
new file mode 100644
index 0000000..aa67a3f
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py
@@ -0,0 +1,23 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+from . import test_util
+
+model_path = "https://tfhub.dev/sayakpaul/lite-model/cartoongan/dr/1?lite-format=tflite"
+
+
+class CartoonGanTest(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(CartoonGanTest, self).__init__(model_path, *args, **kwargs)
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py
new file mode 100644
index 0000000..b5d9f26
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py
@@ -0,0 +1,39 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+from . import test_util
+
+model_path = "https://tfhub.dev/sayakpaul/lite-model/east-text-detector/dr/1?lite-format=tflite"
+
+
+class EastTextDetectorTest(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(EastTextDetectorTest, self).__init__(model_path, *args, **kwargs)
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(EastTextDetectorTest, self).compare_results(iree_results,
+ tflite_results, details)
+ self.assertTrue(
+ numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
+
+ # The second return is extremely noisy as it is not a binary classification. To handle we
+ # check normalized correlation with an expectation of "close enough".
+ iree_norm = numpy.sqrt(iree_results[1] * iree_results[1])
+ tflite_norm = numpy.sqrt(tflite_results[1] * tflite_results[1])
+
+ correlation = numpy.average(iree_results[1] * tflite_results[1] /
+ iree_norm / tflite_norm)
+ self.assertTrue(numpy.isclose(correlation, 1.0, atol=1e-2).all())
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py
new file mode 100644
index 0000000..5d98e7e
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py
@@ -0,0 +1,41 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+from . import test_util
+
+model_path = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-64.tflite"
+
+
+# This test is a massive download and excluded due to causing timeouts.
+class GPT2Test(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(GPT2Test, self).__init__(model_path, *args, **kwargs)
+
+ # Inputs modified to be useful mobilebert inputs.
+ def generate_inputs(self, input_details):
+ args = []
+ args.append(
+ numpy.random.randint(low=0,
+ high=256,
+ size=input_details[0]["shape"],
+ dtype=input_details[0]["dtype"]))
+ return args
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(GPT2Test, self).compare_results(iree_results, tflite_results, details)
+ for i in range(len(iree_results)):
+ self.assertTrue(
+ numpy.isclose(iree_results[i], tflite_results[i], atol=5e-3).all())
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py
new file mode 100644
index 0000000..71f2044
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py
@@ -0,0 +1,24 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+from . import test_util
+
+model_path = "https://tfhub.dev/tensorflow/lite-model/mnasnet_1.0_224/1/metadata/1?lite-format=tflite"
+
+
+class MnasnetTest(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(MnasnetTest, self).__init__(model_path, *args, **kwargs)
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py
new file mode 100644
index 0000000..e77e395
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py
@@ -0,0 +1,30 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+from . import test_util
+
+model_path = "https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1?lite-format=tflite"
+
+
+class MobilenetV3Test(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(MobilenetV3Test, self).__init__(model_path, *args, **kwargs)
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(MobilenetV3Test, self).compare_results(iree_results, tflite_results,
+ details)
+ self.assertTrue(
+ numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all())
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py
new file mode 100644
index 0000000..5b7662c
--- /dev/null
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py
@@ -0,0 +1,60 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+from . import test_util
+import urllib.request
+
+from PIL import Image
+
+model_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite"
+
+
+class PersonDetectTest(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(PersonDetectTest, self).__init__(model_path, *args, **kwargs)
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(PersonDetectTest, self).compare_results(iree_results, tflite_results,
+ details)
+ self.assertTrue(
+ numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
+
+ # TFLite is broken with this model so we hardcode the input/output details.
+ def setup_tflite(self):
+ self.input_details = [{
+ "shape": [1, 96, 96, 1],
+ "dtype": numpy.int8,
+ "index": 0,
+ }]
+ self.output_details = [{
+ "shape": [1, 2],
+ "dtype": numpy.int8,
+ }]
+
+ # The input has known expected values. We hardcode this value.
+ def invoke_tflite(self, args):
+ return [numpy.array([[-113, 113]], dtype=numpy.int8)]
+
+ def generate_inputs(self, input_details):
+ img_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
+ local_path = "/".join([self.workdir, "person.bmp"])
+ urllib.request.urlretrieve(img_path, local_path)
+
+ shape = input_details[0]["shape"]
+ im = numpy.array(Image.open(local_path).resize(
+ (shape[1], shape[2]))).astype(input_details[0]["dtype"])
+ args = [im.reshape(shape)]
+ return args
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()