Open source slim vision model tests (#3336)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
new file mode 100644
index 0000000..14183cc
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -0,0 +1,231 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. A coverage table generated from this file can be viewed here:
+# https://google.github.io/iree/tf-e2e-coverage
+# Updates made to test suite names should also be reflected here:
+# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_binary",
+)
+load(
+ "//integrations/tensorflow/e2e/slim_vision_models:iree_slim_vision_test_suite.bzl",
+ "iree_slim_vision_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Create binaries for all test srcs to allow them to be run manually.
+iree_py_binary(
+ name = "slim_vision_model_test_manual",
+ srcs = ["slim_vision_model_test.py"],
+ args = ["--tf_hub_url=https://tfhub.dev/google/imagenet/"],
+ main = "slim_vision_model_test.py",
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_slim_vision_test_suite(
+ name = "slim_vision_tests",
+ backends = [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ failing_configurations = [
+ {
+ # Failing all but tf and vmla:
+ "models": [
+ "inception_resnet_v2",
+ # tflite: RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (38400 != -1571481807)Node number 333 (RESHAPE) failed to prepare.
+ # llvmjit: *** Received signal 6 *** (mangled stack trace)
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 1.3961234, atol: 2e-05, max relative diff: 0.27304956, rtol: 1e-06
+ "resnet_v2_101",
+ # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
+ # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 11.668068, atol: 2e-05, max relative diff: 0.93950737, rtol: 1e-06
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 11.668067, atol: 2e-05, max relative diff: 0.93950737, rtol: 1e-06
+ "resnet_v2_152",
+ # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
+ # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 7.080696, atol: 2e-05, max relative diff: 0.97750616, rtol: 1e-06
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 7.08069, atol: 2e-05, max relative diff: 0.97750485, rtol: 1e-06
+ ],
+ "backends": [
+ "tflite",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ {
+ # Failing llvmjit and vulkan:
+ "models": [
+ "inception_v2",
+ # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 1.0769763, atol: 2e-05, max relative diff: 0.19576924, rtol: 1e-06
+ "inception_v3",
+ # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 2.5201874, atol: 2e-05, max relative diff: 0.53700095, rtol: 1e-06
+ "nasnet_mobile",
+ # llvmjit: corrupted size vs. prev_size; *** Received signal 6 ***
+ # vulkan: *** Received signal 11 ***
+ "resnet_v2_50",
+ # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 5.8187943, atol: 2e-05, max relative diff: 0.7946711, rtol: 1e-06
+ # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 5.8187933, atol: 2e-05, max relative diff: 0.79467094, rtol: 1e-06
+ ],
+ "backends": [
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ models = [
+ "inception_resnet_v2",
+ "inception_v1",
+ "inception_v2",
+ "inception_v3",
+ "mobilenet_v1_025_128",
+ "mobilenet_v1_025_160",
+ "mobilenet_v1_025_192",
+ "mobilenet_v1_025_224",
+ "mobilenet_v1_050_128",
+ "mobilenet_v1_050_160",
+ "mobilenet_v1_050_192",
+ "mobilenet_v1_050_224",
+ "mobilenet_v1_075_128",
+ "mobilenet_v1_075_160",
+ "mobilenet_v1_075_192",
+ "mobilenet_v1_075_224",
+ "mobilenet_v1_100_128",
+ "mobilenet_v1_100_160",
+ "mobilenet_v1_100_192",
+ "mobilenet_v1_100_224",
+ "mobilenet_v2_035_224",
+ "mobilenet_v2_050_224",
+ "mobilenet_v2_075_224",
+ "mobilenet_v2_100_224",
+ "mobilenet_v2_130_224",
+ "mobilenet_v2_140_224",
+ "nasnet_mobile",
+ "resnet_v1_101",
+ "resnet_v1_152",
+ "resnet_v1_50",
+ "resnet_v2_101",
+ "resnet_v2_152",
+ "resnet_v2_50",
+ ],
+ reference_backend = "tf",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "no-remote",
+ "nokokoro",
+ "notap",
+ ],
+ tf_hub_url = "https://tfhub.dev/google/imagenet/",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# TODO(meadowlark): Get these working on tf.
+iree_slim_vision_test_suite(
+ name = "slim_vision_ill_configured_tests",
+ backends = [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ failing_configurations = [
+ {
+ # Failing on all backends:
+ "models": [
+ "amoebanet_a_n18_f448",
+ "mobilenet_v2_035_128",
+ "mobilenet_v2_035_160",
+ "mobilenet_v2_035_192",
+ "mobilenet_v2_035_96",
+ "mobilenet_v2_050_128",
+ "mobilenet_v2_050_160",
+ "mobilenet_v2_050_192",
+ "mobilenet_v2_050_96",
+ "mobilenet_v2_075_128",
+ "mobilenet_v2_075_160",
+ "mobilenet_v2_075_192",
+ "mobilenet_v2_075_96",
+ "mobilenet_v2_100_128",
+ "mobilenet_v2_100_160",
+ "mobilenet_v2_100_192",
+ "mobilenet_v2_100_96",
+ "nasnet_large",
+ "pnasnet_large",
+ ],
+ "backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ models = [
+ "amoebanet_a_n18_f448",
+ "mobilenet_v2_035_128",
+ "mobilenet_v2_035_160",
+ "mobilenet_v2_035_192",
+ "mobilenet_v2_035_96",
+ "mobilenet_v2_050_128",
+ "mobilenet_v2_050_160",
+ "mobilenet_v2_050_192",
+ "mobilenet_v2_050_96",
+ "mobilenet_v2_075_128",
+ "mobilenet_v2_075_160",
+ "mobilenet_v2_075_192",
+ "mobilenet_v2_075_96",
+ "mobilenet_v2_100_128",
+ "mobilenet_v2_100_160",
+ "mobilenet_v2_100_192",
+ "mobilenet_v2_100_96",
+ "nasnet_large",
+ "pnasnet_large",
+ ],
+ reference_backend = "tf",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "no-remote",
+ "nokokoro",
+ "notap",
+ ],
+ tf_hub_url = "https://tfhub.dev/google/imagenet/",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/README.md b/integrations/tensorflow/e2e/slim_vision_models/README.md
new file mode 100644
index 0000000..0a0b5dd
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/README.md
@@ -0,0 +1,10 @@
+# Slim Vision Model Tests
+
+These tests require an additional python dependency on `tensorflow_hub`, which
+can be installed as follows:
+
+```shell
+python3 -m pip install tensorflow_hub
+```
+
+Like the `vision_external_tests`, these tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
new file mode 100644
index 0000000..857c1b4
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
@@ -0,0 +1,159 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e keras vision model tests."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+load("@bazel_skylib//lib:new_sets.bzl", "sets")
+
+def iree_slim_vision_test_suite(
+ name,
+ models,
+ backends,
+ reference_backend,
+ failing_configurations = None,
+ tf_hub_url = None,
+ tags = None,
+ deps = None,
+ size = "large",
+ python_version = "PY3",
+ **kwargs):
+ """Creates a test for each configuration and bundles a succeeding and failing test suite.
+
+ Creates one test per model and backend. Tests indicated in
+ `failing_configurations` are bundled into a suite suffixed with "_failing"
+ tagged to be excluded from CI and wildcard builds. All other tests are
+ bundled into a suite with the same name as the macro.
+
+ Args:
+ name:
+ name of the generated passing test suite. If failing_configurations is
+ not `None` then a test suite named name_failing will also be generated.
+ models:
+ an iterable of slim vision model tags to generate targets for.
+ backends:
+ an iterable of targets backends to generate targets for.
+ reference_backend:
+ the backend to use as a source of truth for the expected output results.
+ failing_configurations:
+ an iterable of dictionaries with the keys `models` and `backends`. Each
+ key points to a string or iterable of strings specifying a set of models
+ and backends that are failing.
+ tf_hub_url:
+ a string pointing to the TF Hub base url of the models to test.
+ tags:
+ tags to apply to the test. Note that as in standard test suites, manual
+ is treated specially and will also apply to the test suite itself.
+ deps:
+ test dependencies.
+ size:
+ size of the tests. Default: "large".
+ python_version:
+ the python version to run the tests with. Uses python3 by default.
+ **kwargs:
+ any additional arguments that will be passed to the underlying tests and
+ test_suite.
+ """
+ failing_set = sets.make([])
+ if failing_configurations != None:
+ # Parse failing configurations.
+ for configuration in failing_configurations:
+ # Normalize configuration input.
+ # {backend: "iree_llvmjit"} -> {backend: ["iree_llvmjit"]}
+ for key, value in configuration.items():
+ if type(value) == type(""):
+ configuration[key] = [value]
+
+ for model in configuration["models"]:
+ for backend in configuration["backends"]:
+ sets.insert(failing_set, (model, backend))
+
+ tests = []
+ for model in models:
+ for backend in backends:
+ # Check if this is a failing configuration.
+ failing = sets.contains(failing_set, (model, backend))
+
+ # Append "_failing" to name if this is a failing configuration.
+ test_name = name if not failing else name + "_failing"
+ test_name = "{}_{}__{}__{}".format(
+ test_name,
+ model,
+ reference_backend,
+ backend,
+ )
+ tests.append(test_name)
+
+ args = [
+ "--model={}".format(model),
+ "--tf_hub_url={}".format(tf_hub_url),
+ "--reference_backend={}".format(reference_backend),
+ "--target_backends={}".format(backend),
+ ]
+
+ # TODO(GH-2175): Simplify this after backend names are
+ # standardized.
+ # "iree_<driver>" --> "<driver>"
+ driver = backend.replace("iree_", "")
+ if driver == "llvmjit":
+ driver = "llvm"
+ py_test_tags = ["driver={}".format(driver)]
+ if tags != None: # `is` is not supported.
+ py_test_tags += tags
+
+ # Add additional tags if this is a failing configuration.
+ if failing:
+ py_test_tags += [
+ "failing", # Only used for test_suite filtering below.
+ "manual",
+ "nokokoro",
+ "notap",
+ ]
+
+ iree_py_test(
+ name = test_name,
+ main = "slim_vision_model_test.py",
+ srcs = ["slim_vision_model_test.py"],
+ args = args,
+ tags = py_test_tags,
+ deps = deps,
+ size = size,
+ python_version = python_version,
+ **kwargs
+ )
+
+ native.test_suite(
+ name = name,
+ tests = tests,
+ # Add "-failing" to exclude tests in `tests` that have the "failing"
+ # tag.
+ tags = tags + ["-failing"],
+ # If there are kwargs that need to be passed here which only apply to
+ # the generated tests and not to test_suite, they should be extracted
+ # into separate named arguments.
+ **kwargs
+ )
+
+ if failing_configurations != None:
+ native.test_suite(
+ name = name + "_failing",
+ tests = tests,
+ # Add "+failing" to only include tests in `tests` that have the
+ # "failing" tag.
+ tags = tags + ["+failing"],
+ # If there are kwargs that need to be passed here which only apply
+ # to the generated tests and not to test_suite, they should be
+ # extracted into separate named arguments.
+ **kwargs
+ )
diff --git a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
new file mode 100644
index 0000000..fcc9ff6
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
@@ -0,0 +1,88 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all vision models from slim lib."""
+
+import posixpath
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+import tensorflow_hub as hub
+
+FLAGS = flags.FLAGS
+
+# Testing vision models from
+# https://github.com/tensorflow/models/tree/master/research/slim
+# slim models were designed with tf v1 and then coverted to SavedModel
+# they are stored at tensorflow_hub.
+flags.DEFINE_string(
+ 'model', 'mobilenet_v1_100_224', 'example model names: '
+ '[resnet_v1_50, resnet_v1_101, resnet_v2_50, resnet_v2_101, '
+ 'mobilenet_v1_100_224, mobilenet_v1_025_224, mobilenet_v2_100_224, '
+ 'mobilenet_v2_035_224]\nAt least a subset can be viewed here:\n'
+ 'https://tfhub.dev/s?dataset=imagenet&module-type=image-classification,image-classifier'
+)
+flags.DEFINE_string('tf_hub_url', None,
+ 'Base URL for the models to test. URL at the time of '
+ 'writing:\nhttps://tfhub.dev/google/imagenet/')
+
+# Classification mode; 4 - is a format of the model (SavedModel TF v2).
+MODE = 'classification/4'
+INPUT_SHAPE = (1, 224, 224, 3)
+
+
+class SlimVisionModule(tf.Module):
+
+ def __init__(self):
+ super(SlimVisionModule, self).__init__()
+ tf_utils.set_random_seed()
+ model_path = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, MODE)
+ hub_layer = hub.KerasLayer(model_path)
+ self.m = tf.keras.Sequential([hub_layer])
+ self.m.build(INPUT_SHAPE)
+ self.predict = tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE)])(
+ self.m.call)
+
+
+class SlimVisionTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, methodName="runTest"):
+ super(SlimVisionTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SlimVisionModule,
+ exported_names=['predict'])
+
+ def test_predict(self):
+
+ def predict(module):
+ input_data = np.random.rand(*INPUT_SHAPE).astype(np.float32)
+ module.predict(input_data, atol=2e-5)
+
+ self.compare_backends(predict, self._modules)
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+
+ SlimVisionModule.__name__ = FLAGS.model
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 7b44d2d..15280f4 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -41,6 +41,8 @@
'End to end tests written using tf.keras',
'//integrations/tensorflow/e2e/keras:vision_external_tests':
'End to end tests of tf.keras.applications vision models',
+ '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+ 'End to end tests of TensorFlow slim vision models',
}
# Some test suites are generated from a single source. This allows us to point
@@ -48,6 +50,8 @@
SINGLE_SOURCE_SUITES = {
'//integrations/tensorflow/e2e/keras:vision_external_tests':
'vision_model_test',
+ '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+ 'slim_vision_model_test',
}
# The symbols to show in the table if the operation is supported or not.