Open source slim vision model tests (#3336)

diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
new file mode 100644
index 0000000..14183cc
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -0,0 +1,231 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. A coverage table generated from this file can be viewed here:
+#   https://google.github.io/iree/tf-e2e-coverage
+# Updates made to test suite names should also be reflected here:
+#   https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+    "//bindings/python:build_defs.oss.bzl",
+    "INTREE_TENSORFLOW_PY_DEPS",
+    "NUMPY_DEPS",
+    "iree_py_binary",
+)
+load(
+    "//integrations/tensorflow/e2e/slim_vision_models:iree_slim_vision_test_suite.bzl",
+    "iree_slim_vision_test_suite",
+)
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Create binaries for all test srcs to allow them to be run manually.
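+# For example (model and backend flags shown for illustration):
+#   bazel run //integrations/tensorflow/e2e/slim_vision_models:slim_vision_model_test_manual -- \
+#     --model=mobilenet_v1_100_224 --reference_backend=tf --target_backends=iree_vmla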
+iree_py_binary(
+    name = "slim_vision_model_test_manual",
+    srcs = ["slim_vision_model_test.py"],
+    args = ["--tf_hub_url=https://tfhub.dev/google/imagenet/"],
+    main = "slim_vision_model_test.py",
+    python_version = "PY3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+iree_slim_vision_test_suite(
+    name = "slim_vision_tests",
+    backends = [
+        "tf",
+        "tflite",
+        "iree_vmla",
+        "iree_llvmjit",
+        "iree_vulkan",
+    ],
+    failing_configurations = [
+        {
+            # Failing all but tf and vmla:
+            "models": [
+                "inception_resnet_v2",
+                # tflite: RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (38400 != -1571481807)Node number 333 (RESHAPE) failed to prepare.
+                # llvmjit: *** Received signal 6 *** (mangled stack trace)
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 1.3961234, atol: 2e-05, max relative diff: 0.27304956, rtol: 1e-06
+                "resnet_v2_101",
+                # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
+                # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 11.668068, atol: 2e-05, max relative diff: 0.93950737, rtol: 1e-06
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 11.668067, atol: 2e-05, max relative diff: 0.93950737, rtol: 1e-06
+                "resnet_v2_152",
+                # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
+                # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 7.080696, atol: 2e-05, max relative diff: 0.97750616, rtol: 1e-06
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 7.08069, atol: 2e-05, max relative diff: 0.97750485, rtol: 1e-06
+            ],
+            "backends": [
+                "tflite",
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+        {
+            # Failing llvmjit and vulkan:
+            "models": [
+                "inception_v2",
+                # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 1.0769763, atol: 2e-05, max relative diff: 0.19576924, rtol: 1e-06
+                "inception_v3",
+                # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 2.5201874, atol: 2e-05, max relative diff: 0.53700095, rtol: 1e-06
+                "nasnet_mobile",
+                # llvmjit: corrupted size vs. prev_size; *** Received signal 6 ***
+                # vulkan: *** Received signal 11 ***
+                "resnet_v2_50",
+                # llvmjit: Floating point difference between ref and tar was too large. Max abs diff: 5.8187943, atol: 2e-05, max relative diff: 0.7946711, rtol: 1e-06
+                # vulkan: Floating point difference between ref and tar was too large. Max abs diff: 5.8187933, atol: 2e-05, max relative diff: 0.79467094, rtol: 1e-06
+            ],
+            "backends": [
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    models = [
+        "inception_resnet_v2",
+        "inception_v1",
+        "inception_v2",
+        "inception_v3",
+        "mobilenet_v1_025_128",
+        "mobilenet_v1_025_160",
+        "mobilenet_v1_025_192",
+        "mobilenet_v1_025_224",
+        "mobilenet_v1_050_128",
+        "mobilenet_v1_050_160",
+        "mobilenet_v1_050_192",
+        "mobilenet_v1_050_224",
+        "mobilenet_v1_075_128",
+        "mobilenet_v1_075_160",
+        "mobilenet_v1_075_192",
+        "mobilenet_v1_075_224",
+        "mobilenet_v1_100_128",
+        "mobilenet_v1_100_160",
+        "mobilenet_v1_100_192",
+        "mobilenet_v1_100_224",
+        "mobilenet_v2_035_224",
+        "mobilenet_v2_050_224",
+        "mobilenet_v2_075_224",
+        "mobilenet_v2_100_224",
+        "mobilenet_v2_130_224",
+        "mobilenet_v2_140_224",
+        "nasnet_mobile",
+        "resnet_v1_101",
+        "resnet_v1_152",
+        "resnet_v1_50",
+        "resnet_v2_101",
+        "resnet_v2_152",
+        "resnet_v2_50",
+    ],
+    reference_backend = "tf",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "no-remote",
+        "nokokoro",
+        "notap",
+    ],
+    tf_hub_url = "https://tfhub.dev/google/imagenet/",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
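+
+# The macro above also generates a companion "slim_vision_tests_failing" suite
+# bundling the configurations listed in `failing_configurations`.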
+
+# TODO(meadowlark): Get these working on tf.
+iree_slim_vision_test_suite(
+    name = "slim_vision_ill_configured_tests",
+    backends = [
+        "tf",
+        "tflite",
+        "iree_vmla",
+        "iree_llvmjit",
+        "iree_vulkan",
+    ],
+    failing_configurations = [
+        {
+            # Failing on all backends:
+            "models": [
+                "amoebanet_a_n18_f448",
+                "mobilenet_v2_035_128",
+                "mobilenet_v2_035_160",
+                "mobilenet_v2_035_192",
+                "mobilenet_v2_035_96",
+                "mobilenet_v2_050_128",
+                "mobilenet_v2_050_160",
+                "mobilenet_v2_050_192",
+                "mobilenet_v2_050_96",
+                "mobilenet_v2_075_128",
+                "mobilenet_v2_075_160",
+                "mobilenet_v2_075_192",
+                "mobilenet_v2_075_96",
+                "mobilenet_v2_100_128",
+                "mobilenet_v2_100_160",
+                "mobilenet_v2_100_192",
+                "mobilenet_v2_100_96",
+                "nasnet_large",
+                "pnasnet_large",
+            ],
+            "backends": [
+                "tf",
+                "tflite",
+                "iree_vmla",
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    models = [
+        "amoebanet_a_n18_f448",
+        "mobilenet_v2_035_128",
+        "mobilenet_v2_035_160",
+        "mobilenet_v2_035_192",
+        "mobilenet_v2_035_96",
+        "mobilenet_v2_050_128",
+        "mobilenet_v2_050_160",
+        "mobilenet_v2_050_192",
+        "mobilenet_v2_050_96",
+        "mobilenet_v2_075_128",
+        "mobilenet_v2_075_160",
+        "mobilenet_v2_075_192",
+        "mobilenet_v2_075_96",
+        "mobilenet_v2_100_128",
+        "mobilenet_v2_100_160",
+        "mobilenet_v2_100_192",
+        "mobilenet_v2_100_96",
+        "nasnet_large",
+        "pnasnet_large",
+    ],
+    reference_backend = "tf",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "no-remote",
+        "nokokoro",
+        "notap",
+    ],
+    tf_hub_url = "https://tfhub.dev/google/imagenet/",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/README.md b/integrations/tensorflow/e2e/slim_vision_models/README.md
new file mode 100644
index 0000000..0a0b5dd
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/README.md
@@ -0,0 +1,10 @@
+# Slim Vision Model Tests
+
+These tests require an additional Python dependency on `tensorflow_hub`, which
+can be installed as follows:
+
+```shell
+python3 -m pip install tensorflow_hub
+```
+
+Like the `vision_external_tests`, these tests are not run by the OSS CI.
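+
+The generated test targets follow the naming scheme
+`<suite_name>_<model>__<reference_backend>__<target_backend>` and can be run
+individually, for example (model and backend chosen for illustration):
+
+```shell
+bazel test //integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests_mobilenet_v1_100_224__tf__iree_vmla
+```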
diff --git a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
new file mode 100644
index 0000000..857c1b4
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
@@ -0,0 +1,159 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e keras vision model tests."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+load("@bazel_skylib//lib:new_sets.bzl", "sets")
+
+def iree_slim_vision_test_suite(
+        name,
+        models,
+        backends,
+        reference_backend,
+        failing_configurations = None,
+        tf_hub_url = None,
+        tags = None,
+        deps = None,
+        size = "large",
+        python_version = "PY3",
+        **kwargs):
+    """Creates a test for each configuration and bundles a succeeding and failing test suite.
+
+    Creates one test per model and backend. Tests indicated in
+    `failing_configurations` are bundled into a suite suffixed with "_failing"
+    tagged to be excluded from CI and wildcard builds. All other tests are
+    bundled into a suite with the same name as the macro.
+
+    Args:
+      name:
+        name of the generated passing test suite. If `failing_configurations`
+        is not `None`, then a test suite named `<name>_failing` will also be
+        generated.
+      models:
+        an iterable of slim vision model names to generate targets for.
+      backends:
+        an iterable of target backends to generate targets for.
+      reference_backend:
+        the backend to use as a source of truth for the expected output results.
+      failing_configurations:
+        an iterable of dictionaries with the keys `models` and `backends`. Each
+        key points to a string or iterable of strings specifying a set of models
+        and backends that are failing.
+      tf_hub_url:
+        a string pointing to the TF Hub base url of the models to test.
+      tags:
+        tags to apply to the test. Note that as in standard test suites, manual
+        is treated specially and will also apply to the test suite itself.
+      deps:
+        test dependencies.
+      size:
+        size of the tests. Default: "large".
+      python_version:
+        the python version to run the tests with. Uses python3 by default.
+      **kwargs:
+        any additional arguments that will be passed to the underlying tests and
+        test_suite.
+    """
+    failing_set = sets.make([])
+    if failing_configurations != None:
+        # Parse failing configurations.
+        for configuration in failing_configurations:
+            # Normalize configuration input.
+            # {"backends": "iree_llvmjit"} -> {"backends": ["iree_llvmjit"]}
+            for key, value in configuration.items():
+                if type(value) == type(""):
+                    configuration[key] = [value]
+
+            for model in configuration["models"]:
+                for backend in configuration["backends"]:
+                    sets.insert(failing_set, (model, backend))
+
+    tests = []
+    for model in models:
+        for backend in backends:
+            # Check if this is a failing configuration.
+            failing = sets.contains(failing_set, (model, backend))
+
+            # Append "_failing" to name if this is a failing configuration.
+            test_name = name if not failing else name + "_failing"
+            test_name = "{}_{}__{}__{}".format(
+                test_name,
+                model,
+                reference_backend,
+                backend,
+            )
+            tests.append(test_name)
+
+            args = [
+                "--model={}".format(model),
+                "--tf_hub_url={}".format(tf_hub_url),
+                "--reference_backend={}".format(reference_backend),
+                "--target_backends={}".format(backend),
+            ]
+
+            # TODO(GH-2175): Simplify this after backend names are
+            # standardized.
+            # "iree_<driver>" --> "<driver>"
+            driver = backend.replace("iree_", "")
+            if driver == "llvmjit":
+                driver = "llvm"
+            py_test_tags = ["driver={}".format(driver)]
+            if tags != None:  # `is` is not supported.
+                py_test_tags += tags
+
+            # Add additional tags if this is a failing configuration.
+            if failing:
+                py_test_tags += [
+                    "failing",  # Only used for test_suite filtering below.
+                    "manual",
+                    "nokokoro",
+                    "notap",
+                ]
+
+            iree_py_test(
+                name = test_name,
+                main = "slim_vision_model_test.py",
+                srcs = ["slim_vision_model_test.py"],
+                args = args,
+                tags = py_test_tags,
+                deps = deps,
+                size = size,
+                python_version = python_version,
+                **kwargs
+            )
+
+    native.test_suite(
+        name = name,
+        tests = tests,
+        # Add "-failing" to exclude tests in `tests` that have the "failing"
+        # tag.
+        tags = (tags or []) + ["-failing"],
+        # If there are kwargs that need to be passed here which only apply to
+        # the generated tests and not to test_suite, they should be extracted
+        # into separate named arguments.
+        **kwargs
+    )
+
+    if failing_configurations != None:
+        native.test_suite(
+            name = name + "_failing",
+            tests = tests,
+            # Add "+failing" to only include tests in `tests` that have the
+            # "failing" tag.
+            tags = (tags or []) + ["+failing"],
+            # If there are kwargs that need to be passed here which only apply
+            # to the generated tests and not to test_suite, they should be
+            # extracted into separate named arguments.
+            **kwargs
+        )
diff --git a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
new file mode 100644
index 0000000..fcc9ff6
--- /dev/null
+++ b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
@@ -0,0 +1,88 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all vision models from slim lib."""
+
+import posixpath
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+import tensorflow_hub as hub
+
+FLAGS = flags.FLAGS
+
+# Tests vision models from the TF slim research library:
+# https://github.com/tensorflow/models/tree/master/research/slim
+# The slim models were designed with TF v1, converted to SavedModels, and are
+# now hosted on TensorFlow Hub.
+flags.DEFINE_string(
+    'model', 'mobilenet_v1_100_224', 'Example model names: '
+    '[resnet_v1_50, resnet_v1_101, resnet_v2_50, resnet_v2_101, '
+    'mobilenet_v1_100_224, mobilenet_v1_025_224, mobilenet_v2_100_224, '
+    'mobilenet_v2_035_224]\nA subset of these models can be browsed here:\n'
+    'https://tfhub.dev/s?dataset=imagenet&module-type=image-classification,image-classifier'
+)
+flags.DEFINE_string('tf_hub_url', None,
+                    'Base URL for the models to test. URL at the time of '
+                    'writing:\nhttps://tfhub.dev/google/imagenet/')
+
+# Classification variant; module version 4 is in TF v2 SavedModel format.
+MODE = 'classification/4'
+INPUT_SHAPE = (1, 224, 224, 3)
+
+
+class SlimVisionModule(tf.Module):
+
+  def __init__(self):
+    super(SlimVisionModule, self).__init__()
+    tf_utils.set_random_seed()
+    model_path = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, MODE)
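+    # e.g. https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/4
+    # with the flags used in the BUILD file. hub.KerasLayer downloads and
+    # caches the SavedModel and wraps it as a Keras layer.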
+    hub_layer = hub.KerasLayer(model_path)
+    self.m = tf.keras.Sequential([hub_layer])
+    self.m.build(INPUT_SHAPE)
+    self.predict = tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE)])(
+        self.m.call)
+
+
+class SlimVisionTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, methodName="runTest"):
+    super(SlimVisionTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SlimVisionModule,
+                                                    exported_names=['predict'])
+
+  def test_predict(self):
+
+    def predict(module):
+      input_data = np.random.rand(*INPUT_SHAPE).astype(np.float32)
+      module.predict(input_data, atol=2e-5)
+
+    self.compare_backends(predict, self._modules)
+
+
+def main(argv):
+  del argv  # Unused.
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+
+  SlimVisionModule.__name__ = FLAGS.model
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 7b44d2d..15280f4 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -41,6 +41,8 @@
         'End to end tests written using tf.keras',
     '//integrations/tensorflow/e2e/keras:vision_external_tests':
         'End to end tests of tf.keras.applications vision models',
+    '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+        'End to end tests of TensorFlow slim vision models',
 }
 
 # Some test suites are generated from a single source. This allows us to point
@@ -48,6 +50,8 @@
 SINGLE_SOURCE_SUITES = {
     '//integrations/tensorflow/e2e/keras:vision_external_tests':
         'vision_model_test',
+    '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+        'slim_vision_model_test',
 }
 
 # The symbols to show in the table if the operation is supported or not.