Added TFLite end-to-end integration test for mobilenet v1 (#7953)
Committed the initial unittest for mobilenet. This should handle long running
integration tests for the tflite frontend.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 84914be..b4a0e3c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,7 +75,9 @@
option(IREE_BUILD_EXPERIMENTAL_REMOTING "Builds experimental remoting support." OFF)
option(IREE_BUILD_EXPERIMENTAL_ROCM "Builds the experimental ROCm Backend." OFF)
-option(IREE_ENABLE_NEW_INTEGRATION_TESTS "Enables new integration tests and disables old." OFF)
+option(IREE_ENABLE_OLD_INTEGRATION_TESTS "Enables old integration tests." OFF)
+option(IREE_ENABLE_NEW_INTEGRATION_TESTS "Enables new integration tests." OFF)
+option(IREE_ENABLE_TFLITE_INTEGRATION_TESTS "Enables tflite integration tests." OFF)
#-------------------------------------------------------------------------------
# Derived flags based on primary options
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
index 46a1180..dfad1ab 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
@@ -61,7 +61,10 @@
-DIREE_BUILD_SAMPLES=OFF \
-DIREE_BUILD_XLA_COMPILER=ON \
-DIREE_BUILD_TFLITE_COMPILER=ON \
- -DIREE_BUILD_TENSORFLOW_COMPILER=ON .
+ -DIREE_BUILD_TENSORFLOW_COMPILER=ON \
+ -DIREE_ENABLE_OLD_INTEGRATION_TESTS=ON \
+ -DIREE_ENABLE_TFLITE_INTEGRATION_TESTS=ON \
+ .
echo "Building with Ninja"
cd "${CMAKE_BUILD_DIR?}"
diff --git a/integrations/tensorflow/CMakeLists.txt b/integrations/tensorflow/CMakeLists.txt
index 71c188e..6e8f443 100644
--- a/integrations/tensorflow/CMakeLists.txt
+++ b/integrations/tensorflow/CMakeLists.txt
@@ -18,9 +18,8 @@
endif()
if(${IREE_BUILD_TESTS} AND ${IREE_BUILD_PYTHON_BINDINGS})
- if(${IREE_ENABLE_NEW_INTEGRATION_TESTS})
- add_subdirectory(test)
- else()
+ add_subdirectory(test)
+ if(${IREE_ENABLE_OLD_INTEGRATION_TESTS})
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/e2e/")
include(iree_e2e_cartesian_product_test_suite)
add_subdirectory(e2e)
diff --git a/integrations/tensorflow/test/CMakeLists.txt b/integrations/tensorflow/test/CMakeLists.txt
index 6345d63..17b84f6 100644
--- a/integrations/tensorflow/test/CMakeLists.txt
+++ b/integrations/tensorflow/test/CMakeLists.txt
@@ -4,4 +4,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-add_subdirectory(tf_integration)
+
+if(${IREE_ENABLE_NEW_INTEGRATION_TESTS})
+ add_subdirectory(tf_integration)
+endif()
+
+if(${IREE_ENABLE_TFLITE_INTEGRATION_TESTS})
+ add_subdirectory(tflite_integration)
+endif()
diff --git a/integrations/tensorflow/test/tflite_integration/CMakeLists.txt b/integrations/tensorflow/test/tflite_integration/CMakeLists.txt
new file mode 100644
index 0000000..a8ebe5f
--- /dev/null
+++ b/integrations/tensorflow/test/tflite_integration/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+iree_py_library(
+ NAME
+ test_util
+ SRCS
+ "test_util.py"
+)
+
+iree_py_test(
+ NAME
+ mobilenet_v1
+ SRCS
+ "mobilenet_v1_test.py"
+)
+
+iree_py_test(
+ NAME
+ posenet_i8
+ SRCS
+ "posenet_i8_test.py"
+)
diff --git a/integrations/tensorflow/test/tflite_integration/mobilenet_v1_test.py b/integrations/tensorflow/test/tflite_integration/mobilenet_v1_test.py
new file mode 100644
index 0000000..1d1ba6f
--- /dev/null
+++ b/integrations/tensorflow/test/tflite_integration/mobilenet_v1_test.py
@@ -0,0 +1,30 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+import test_util
+
+model_path = "https://storage.googleapis.com/iree-model-artifacts/tflite-integration-tests/mobilenet_v1.tflite"
+
+
+class MobilenetV1Test(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(MobilenetV1Test, self).__init__(model_path, *args, **kwargs)
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(MobilenetV1Test, self).compare_results(iree_results, tflite_results,
+ details)
+ self.assertTrue(
+ numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all())
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/tflite_integration/posenet_i8_test.py b/integrations/tensorflow/test/tflite_integration/posenet_i8_test.py
new file mode 100644
index 0000000..e86a09b
--- /dev/null
+++ b/integrations/tensorflow/test/tflite_integration/posenet_i8_test.py
@@ -0,0 +1,51 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import absl.testing
+import numpy
+import test_util
+import urllib.request
+
+from PIL import Image
+
+model_path = "https://storage.googleapis.com/iree-model-artifacts/tflite-integration-tests/posenet_i8.tflite"
+model_input = "https://storage.googleapis.com/iree-model-artifacts/tflite-integration-tests/posenet_i8_input.jpg"
+
+
+class PosenetI8Test(test_util.TFLiteModelTest):
+
+ def __init__(self, *args, **kwargs):
+ super(PosenetI8Test, self).__init__(model_path, *args, **kwargs)
+
+ def compare_results(self, iree_results, tflite_results, details):
+ super(PosenetI8Test, self).compare_results(iree_results, tflite_results,
+ details)
+ # This value is a discretized location of the persons joints. If we are
+ # *close* to the expected position we can consider this good enough.
+ self.assertTrue(
+ numpy.isclose(iree_results[0][:, :, :, 0],
+ tflite_results[0][:, :, :, 0],
+ atol=25e-3).all())
+ self.assertTrue(
+ numpy.isclose(iree_results[0][:, :, :, 1],
+ tflite_results[0][:, :, :, 1],
+ atol=25e-3).all())
+
+ def generate_inputs(self, input_details):
+ local_path = "/".join([self.workdir, "person.jpg"])
+ urllib.request.urlretrieve(model_input, local_path)
+
+ shape = input_details[0]["shape"]
+ im = numpy.array(Image.open(local_path).resize((shape[1], shape[2])))
+ args = [im.reshape(shape)]
+ return args
+
+ def test_compile_tflite(self):
+ self.compile_and_execute()
+
+
+if __name__ == '__main__':
+ absl.testing.absltest.main()
diff --git a/integrations/tensorflow/test/tflite_integration/test_util.py b/integrations/tensorflow/test/tflite_integration/test_util.py
new file mode 100644
index 0000000..752a550
--- /dev/null
+++ b/integrations/tensorflow/test/tflite_integration/test_util.py
@@ -0,0 +1,150 @@
+# Lint as: python3
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Test architecture for a set of tflite tests."""
+
+import absl
+from absl.flags import FLAGS
+import absl.testing as testing
+import iree.compiler.tflite as iree_tflite_compile
+import iree.runtime as iree_rt
+import numpy as np
+import os
+import sys
+import tempfile
+import tensorflow.compat.v2 as tf
+import time
+import urllib.request
+
+targets = {
+ 'dylib': 'dylib-llvm-aot',
+ 'vulkan': 'vulkan-spirv',
+}
+
+configs = {
+ 'dylib': 'dylib',
+ 'vulkan': 'vulkan',
+}
+
+absl.flags.DEFINE_string('config', 'dylib', 'model path to execute')
+
+
+class TFLiteModelTest(testing.absltest.TestCase):
+
+ def __init__(self, model_path, *args, **kwargs):
+ super(TFLiteModelTest, self).__init__(*args, **kwargs)
+ self.model_path = model_path
+
+ def setUp(self):
+ if self.model_path is None:
+ return
+ exe_basename = os.path.basename(sys.argv[0])
+ self.workdir = tempfile.mkdtemp(dir=testing.absltest.TEST_TMPDIR.value)
+ print(f"TMP_DIR = {self.workdir}")
+ self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
+ self.tflite_ir = '/'.join([self.workdir, 'tflite.mlir'])
+ self.iree_ir = '/'.join([self.workdir, 'tosa.mlir'])
+ if os.path.exists(self.model_path):
+ self.tflite_file = self.model_path
+ else:
+ urllib.request.urlretrieve(self.model_path, self.tflite_file)
+ self.binary = '/'.join([self.workdir, 'module.bytecode'])
+
+ def generate_inputs(self, input_details):
+ args = []
+ for input in input_details:
+ absl.logging.info("\t%s, %s", str(input["shape"]),
+ input["dtype"].__name__)
+ args.append(np.zeros(shape=input["shape"], dtype=input["dtype"]))
+ return args
+
+ def compare_results(self, iree_results, tflite_results, details):
+ self.assertEqual(len(iree_results), len(tflite_results),
+ "Number of results do not match")
+
+ for i in range(len(details)):
+ iree_result = iree_results[i]
+ tflite_result = tflite_results[i]
+ iree_result = iree_result.astype(np.single)
+ tflite_result = tflite_result.astype(np.single)
+ self.assertEqual(iree_result.shape, tflite_result.shape)
+ maxError = np.max(np.abs(iree_result - tflite_result))
+ absl.logging.info("Max error (%d): %f", i, maxError)
+
+ def setup_tflite(self):
+ absl.logging.info("Setting up tflite interpreter")
+ self.tflite_interpreter = tf.lite.Interpreter(model_path=self.tflite_file)
+ self.tflite_interpreter.allocate_tensors()
+ self.input_details = self.tflite_interpreter.get_input_details()
+ self.output_details = self.tflite_interpreter.get_output_details()
+
+ def setup_iree(self):
+ absl.logging.info("Setting up iree runtime")
+ with open(self.binary, 'rb') as f:
+ config = iree_rt.Config(configs[absl.flags.FLAGS.config])
+ self.iree_context = iree_rt.SystemContext(config=config)
+ vm_module = iree_rt.VmModule.from_flatbuffer(f.read())
+ self.iree_context.add_vm_module(vm_module)
+
+ def invoke_tflite(self, args):
+ for i, input in enumerate(args):
+ self.tflite_interpreter.set_tensor(self.input_details[i]['index'], input)
+ start = time.perf_counter()
+ self.tflite_interpreter.invoke()
+ end = time.perf_counter()
+ tflite_results = []
+ absl.logging.info(f"Invocation time: {end - start:0.4f} seconds")
+ for output_detail in self.output_details:
+ tflite_results.append(
+ np.array(self.tflite_interpreter.get_tensor(output_detail['index'])))
+
+ for i in range(len(self.output_details)):
+ dtype = self.output_details[i]["dtype"]
+ tflite_results[i] = tflite_results[i].astype(dtype)
+ return tflite_results
+
+ def invoke_iree(self, args):
+ invoke = self.iree_context.modules.module["main"]
+ start = time.perf_counter()
+ iree_results = invoke(*args)
+ end = time.perf_counter()
+ absl.logging.info(f"Invocation time: {end - start:0.4f} seconds")
+ if not isinstance(iree_results, tuple):
+ iree_results = (iree_results,)
+ return iree_results
+
+ def compile_and_execute(self):
+ self.assertIsNotNone(self.model_path)
+
+ absl.logging.info("Setting up for IREE")
+ iree_tflite_compile.compile_file(
+ self.tflite_file,
+ input_type="tosa",
+ output_file=self.binary,
+ save_temp_tfl_input=self.tflite_ir,
+ save_temp_iree_input=self.iree_ir,
+ target_backends=[targets[absl.flags.FLAGS.config]],
+ import_only=False)
+
+ self.setup_tflite()
+ self.setup_iree()
+
+ absl.logging.info("Setting up test inputs")
+ args = self.generate_inputs(self.input_details)
+
+ absl.logging.info("Invoking TFLite")
+ tflite_results = self.invoke_tflite(args)
+
+ absl.logging.info("Invoke IREE")
+ iree_results = self.invoke_iree(args)
+
+ # Fix type information for unsigned cases.
+ iree_results = list(iree_results)
+ for i in range(len(self.output_details)):
+ dtype = self.output_details[i]["dtype"]
+ iree_results[i] = iree_results[i].astype(dtype)
+
+ self.compare_results(iree_results, tflite_results, self.output_details)
diff --git a/scripts/update_tflite_models.py b/scripts/update_tflite_models.py
new file mode 100644
index 0000000..e2ea887
--- /dev/null
+++ b/scripts/update_tflite_models.py
@@ -0,0 +1,70 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This tool handles mirroring tflite testing files from their source to the
+# the iree-model-artifacts test bucket. This avoids taking dependency on
+# external test files that may change or no longer be available.
+#
+# To update all files:
+# python update_tflite_models.py --file all
+#
+# To update a specific file:
+# python update_tflite_models.py --file posenet_i8_input.jpg
+#
+# Note you must have write permission to the iree-model-artifacts GCS bucket
+# with local gcloud authentication.
+
+from absl import app
+from absl import flags
+from google.cloud import storage
+from google_auth_oauthlib import flow
+
+import tempfile
+import urllib
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('file', '', 'file to update')
+
+file_dict = dict({
+ "mobilenet_v1.tflite":
+ "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
+ "posenet_i8.tflite":
+ "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
+ "posenet_i8_input.jpg":
+ "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg"
+})
+
+BUCKET_NAME = "iree-model-artifacts"
+FOLDER_NAME = "tflite-integration-tests"
+
+
+def upload_model(source, destination, tmpfile):
+ """Uploads a file to the bucket."""
+ urllib.request.urlretrieve(source, tmpfile)
+
+ storage_client = storage.Client()
+ bucket = storage_client.get_bucket(BUCKET_NAME)
+ blob = bucket.blob("/".join([FOLDER_NAME, destination]))
+ blob.upload_from_filename(tmpfile)
+
+
+def main(argv):
+ tf = tempfile.NamedTemporaryFile()
+
+ items = file_dict.items()
+
+ if FLAGS.file in file_dict:
+ items = [(FLAGS.file, file_dict[FLAGS.file])]
+ elif FLAGS.file != "all":
+ print('Unknown file to upload: ', "\"" + FLAGS.file + "\"")
+ exit()
+
+ for dst, src in items:
+ upload_model(src, dst, tf.name)
+
+
+if __name__ == '__main__':
+ app.run(main)