Add new tf.keras.applications models (#3958)

Adds the `EfficientNet*` and `MobileNetV3*` models and moves the tests into `.../e2e/keras/applications`.
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 6ecd6b5..89f92ad 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -28,10 +28,6 @@
     "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
     "iree_e2e_cartesian_product_test_suite",
 )
-load(
-    "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
-    "iree_e2e_test_suite",
-)
 
 package(
     default_visibility = ["//visibility:public"],
@@ -39,33 +35,6 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-# @unused
-DOC = """
-vision_model_test_manual is for manual testing of all keras vision models.
-Test will run only manually with all parameters specified manually, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla \
---data=imagenet \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla
---data: can be 'imagenet' or 'cifar10'.
-    imagenet - input image size (1, 224, 224, 3)
-    cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
-            and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
---include_top: Whether or not to include the final (top) layers of the model.
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
-       imagenet pretrained weights url is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
-    ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
-    InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
-    DenseNet201, NASNetMobile, NASNetLarge
-    All above models works with 'imagenet' data sets.
-    ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
 [
     iree_py_binary(
         name = src.replace(".py", "_manual"),
@@ -82,308 +51,6 @@
     )
 ]
 
-SPECIAL_CASES = [
-    "keyword_spotting_streaming_test.py",
-    "vision_model_test.py",
-]
-
-TFLITE_FAILING = []
-
-VMLA_FAILING = []
-
-LLVM_FAILING = []
-
-VULKAN_FAILING = []
-
-TF_PASSING = glob(
-    ["*_test.py"],
-    exclude = SPECIAL_CASES,
-)
-
-TFLITE_PASSING = glob(
-    ["*_test.py"],
-    exclude = TFLITE_FAILING + SPECIAL_CASES,
-)
-
-VMLA_PASSING = glob(
-    ["*_test.py"],
-    exclude = VMLA_FAILING + SPECIAL_CASES,
-)
-
-LLVM_PASSING = glob(
-    ["*_test.py"],
-    exclude = LLVM_FAILING + SPECIAL_CASES,
-)
-
-VULKAN_PASSING = glob(
-    ["*_test.py"],
-    exclude = VULKAN_FAILING + SPECIAL_CASES,
-)
-
-iree_e2e_test_suite(
-    name = "keras_tests",
-    backends_to_srcs = {
-        "tf": TF_PASSING,
-        "tflite": TFLITE_PASSING,
-        "iree_vmla": VMLA_PASSING,
-        "iree_llvmjit": LLVM_PASSING,
-        "iree_vulkan": VULKAN_PASSING,
-    },
-    reference_backend = "tf",
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_test_suite(
-    name = "keras_tests_failing",
-    backends_to_srcs = {
-        "tflite": TFLITE_FAILING,
-        "iree_vmla": VMLA_FAILING,
-        "iree_llvmjit": LLVM_FAILING,
-        "iree_vulkan": VULKAN_FAILING,
-    },
-    reference_backend = "tf",
-    tags = [
-        "failing",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_cartesian_product_test_suite(
-    name = "large_cifar10_tests",
-    size = "large",
-    srcs = ["vision_model_test.py"],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "model": [
-            # All models with runtime shorter than ResNet50.
-            "MobileNet",  # Max: Vulkan 61.0s
-            "MobileNetV2",  # Max: LLVM 96.3s
-            "ResNet50",  # Max: LLVM 145.6s
-            "VGG16",  # Max: LLVM 89.5s
-            "VGG19",  # Max: LLVM 94.7s
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = ["manual"],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_cartesian_product_test_suite(
-    name = "enormous_cifar10_tests",
-    size = "enormous",
-    srcs = ["vision_model_test.py"],
-    failing_configurations = [
-        {
-            # Failing on vmla with negative inputs.
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-            ],
-            "target_backends": "iree_vmla",
-        },
-        {
-            # Failing on llvm and vulkan:
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-                "ResNet50V2",
-                "ResNet101V2",
-                "ResNet152V2",
-            ],
-            "target_backends": [
-                "iree_llvmjit",
-                "iree_vulkan",
-            ],
-        },
-    ],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "model": [
-            "DenseNet121",
-            "DenseNet169",
-            "DenseNet201",
-            "NASNetLarge",
-            "NASNetMobile",
-            "ResNet50V2",
-            "ResNet101",
-            "ResNet101V2",
-            "ResNet152",
-            "ResNet152V2",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "guitar",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
-    name = "cifar10_non_hermetic_tests",
-    size = "large",
-    srcs = ["vision_model_test.py"],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "url": "https://storage.googleapis.com/iree_models/",
-        "use_external_weights": True,
-        "model": [
-            "MobileNet",
-            "MobileNetV2",
-            "ResNet50",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "external",
-        "guitar",
-        "manual",
-        "no-remote",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
-    name = "imagenet_non_hermetic_tests",
-    size = "enormous",
-    srcs = ["vision_model_test.py"],
-    failing_configurations = [
-        {
-            # Failing on vmla with negative inputs.
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-            ],
-            "target_backends": "iree_vmla",
-        },
-        {
-            # Failing vulkan:
-            "model": [
-                "InceptionResNetV2",
-                "InceptionV3",
-            ],
-            "target_backends": [
-                "iree_vulkan",
-            ],
-        },
-        {
-            # Failing llvm and vulkan:
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-                "ResNet50V2",
-                "ResNet101V2",
-                "ResNet152V2",
-                "Xception",
-            ],
-            "target_backends": [
-                "iree_llvmjit",
-                "iree_vulkan",
-            ],
-        },
-    ],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "imagenet",
-        "use_external_weights": True,
-        "model": [
-            "DenseNet121",
-            "DenseNet169",
-            "DenseNet201",
-            "InceptionResNetV2",
-            "InceptionV3",
-            "MobileNet",
-            "MobileNetV2",
-            "NASNetLarge",
-            "NASNetMobile",
-            "ResNet50",
-            "ResNet50V2",
-            "ResNet101",
-            "ResNet101V2",
-            "ResNet152",
-            "ResNet152V2",
-            "VGG16",
-            "VGG19",
-            "Xception",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "external",
-        "guitar",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# It is used to produce weights for keras vision models with input image size
-# 32x32. These models are not optimized for accuracy or latency (they are for
-# debugging only). They have the same neural net topology with keras vision
-# models trained on imagenet data sets
-iree_py_binary(
-    name = "train_vision_models_on_cifar",
-    srcs = ["train_vision_models_on_cifar.py"],
-    python_version = "PY3",
-    srcs_version = "PY2AND3",
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
 # Keyword Spotting Tests:
 KEYWORD_SPOTTING_MODELS = [
     "svdf",
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
new file mode 100644
index 0000000..a9df330
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -0,0 +1,320 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. Coverage tables generated from this file can be viewed here:
+#   https://google.github.io/iree/tensorflow-coverage/vision-coverage
+# Updates made to test suite names should also be reflected here:
+#   https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+    "//bindings/python:build_defs.oss.bzl",
+    "INTREE_TENSORFLOW_PY_DEPS",
+    "NUMPY_DEPS",
+    "iree_py_binary",
+)
+load(
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# @unused
+DOC = """
+applications_test_manual is for manual testing of all keras vision models.
+Test will run only manually with all parameters specified manually, for example:
+bazel run -c opt integrations/tensorflow/e2e/keras/applications:applications_test_manual -- \
+--target_backends=tf,iree_vmla \
+--data=imagenet \
+--url=https://storage.googleapis.com/iree_models/ \
+--model=ResNet50
+
+Command arguments description:
+--target_backends: can be combination of these: tf,iree_vmla
+--data: can be 'imagenet' or 'cifar10'.
+    imagenet - input image size (1, 224, 224, 3)
+    cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
+            and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
+--include_top: Whether or not to include the final (top) layers of the model.
+--url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
+       imagenet pretrained weights url is specified by keras
+--model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
+    ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
+    InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169, DenseNet201,
+    NASNetMobile, NASNetLarge, EfficientNetB0-B7, MobileNetV3Large, MobileNetV3Small
+    All above models work with 'imagenet' data sets.
+    ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
+"""
+
+[
+    iree_py_binary(
+        name = src.replace(".py", "_manual"),
+        srcs = [src],
+        main = src,
+        python_version = "PY3",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for src in glob(
+        ["*_test.py"],
+    )
+]
+
+# Models passed as the "model" flag values of imagenet_non_hermetic_tests.
+KERAS_APPLICATIONS_MODELS = [
+    "DenseNet121",
+    "DenseNet169",
+    "DenseNet201",
+    "EfficientNetB0",
+    "EfficientNetB1",
+    "EfficientNetB2",
+    "EfficientNetB3",
+    "EfficientNetB4",
+    "EfficientNetB5",
+    "EfficientNetB6",
+    "EfficientNetB7",
+    "InceptionResNetV2",
+    "InceptionV3",
+    "MobileNet",
+    "MobileNetV2",
+    "MobileNetV3Large",
+    "MobileNetV3Small",
+    "NASNetLarge",
+    "NASNetMobile",
+    "ResNet101",
+    "ResNet101V2",
+    "ResNet152",
+    "ResNet152V2",
+    "ResNet50",
+    "ResNet50V2",
+    "VGG16",
+    "VGG19",
+    "Xception",
+]
+
+iree_e2e_cartesian_product_test_suite(
+    name = "large_cifar10_tests",
+    size = "large",
+    srcs = ["applications_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            # All models with runtime shorter than ResNet50.
+            "MobileNet",  # Max: Vulkan 61.0s
+            "MobileNetV2",  # Max: LLVM 96.3s
+            "ResNet50",  # Max: LLVM 145.6s
+            "VGG16",  # Max: LLVM 89.5s
+            "VGG19",  # Max: LLVM 94.7s
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = ["manual"],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+    name = "enormous_cifar10_tests",
+    size = "enormous",
+    srcs = ["applications_test.py"],
+    failing_configurations = [
+        {
+            # Failing on vmla with negative inputs.
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+            ],
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing on llvm and vulkan:
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+                "ResNet50V2",
+                "ResNet101V2",
+                "ResNet152V2",
+            ],
+            "target_backends": [
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            "DenseNet121",
+            "DenseNet169",
+            "DenseNet201",
+            "NASNetLarge",
+            "NASNetMobile",
+            "ResNet50V2",
+            "ResNet101",
+            "ResNet101V2",
+            "ResNet152",
+            "ResNet152V2",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "guitar",
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+    name = "cifar10_non_hermetic_tests",
+    size = "large",
+    srcs = ["applications_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "url": "https://storage.googleapis.com/iree_models/",
+        "use_external_weights": True,
+        "model": [
+            "MobileNet",
+            "MobileNetV2",
+            "ResNet50",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "no-remote",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+    name = "imagenet_non_hermetic_tests",
+    size = "enormous",
+    srcs = ["applications_test.py"],
+    failing_configurations = [
+        {
+            # Failing on vmla with negative inputs.
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+            ],
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing vulkan:
+            "model": [
+                "InceptionResNetV2",
+                "InceptionV3",
+            ],
+            "target_backends": [
+                "iree_vulkan",
+            ],
+        },
+        {
+            # Failing llvm and vulkan:
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+                "ResNet50V2",
+                "ResNet101V2",
+                "ResNet152V2",
+                "Xception",
+            ],
+            "target_backends": [
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "imagenet",
+        "use_external_weights": True,
+        "model": KERAS_APPLICATIONS_MODELS,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# It is used to produce weights for keras vision models with input image size
+# 32x32. These models are not optimized for accuracy or latency (they are for
+# debugging only). They have the same neural net topology as the keras vision
+# models trained on the imagenet data set.
+iree_py_binary(
+    name = "train_vision_models_on_cifar",
+    srcs = ["train_vision_models_on_cifar.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
diff --git a/integrations/tensorflow/e2e/keras/applications/applications_test.py b/integrations/tensorflow/e2e/keras/applications/applications_test.py
new file mode 100644
index 0000000..8c0e77c
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/applications_test.py
@@ -0,0 +1,121 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all models in tf.keras.applications."""
+
+import os
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# Testing every applications model in a single run takes a long time, so each
+# run tests one model, selected with the --model flag (e.g. --model=MobileNet).
+flags.DEFINE_string("model", "ResNet50", "model name")
+flags.DEFINE_string(
+    "url", "", "url with model weights "
+    "for example https://storage.googleapis.com/iree_models/")
+flags.DEFINE_bool("use_external_weights", False,
+                  "Whether or not to load external weights from the web")
+flags.DEFINE_enum("data", "cifar10", ["cifar10", "imagenet"],
+                  "data sets on which model was trained: imagenet, cifar10")
+flags.DEFINE_bool(
+    "include_top", True,
+    "Whether or not to include the final (top) layers of the model.")
+
+BATCH_SIZE = 1
+IMAGE_DIM = 224
+
+
+def load_cifar10_weights(model):
+  file_name = "cifar10" + FLAGS.model
+  # get_file will download the model weights from a publicly available folder,
+  # cache them (under ~/.keras/datasets/ by default) and return a path to them.
+  url = os.path.join(
+      FLAGS.url, f"cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5")
+  weights_path = tf.keras.utils.get_file(file_name, url)
+  model.load_weights(weights_path)
+  return model
+
+
+def initialize_model():
+  # If weights == "imagenet", the model will load the appropriate weights from
+  # an external tf.keras URL.
+  weights = None
+  if FLAGS.use_external_weights and FLAGS.data == "imagenet":
+    weights = "imagenet"
+
+  model_class = getattr(tf.keras.applications, FLAGS.model)
+  model = model_class(weights=weights, include_top=FLAGS.include_top)
+
+  if FLAGS.use_external_weights and FLAGS.data == "cifar10":
+    if not FLAGS.url:
+      raise ValueError(
+          "cifar10 weights cannot be loaded without the `--url` flag.")
+    model = load_cifar10_weights(model)
+  return model
+
+
+class ApplicationsModule(tf_test_utils.TestModule):
+
+  def __init__(self):
+    super().__init__()
+    self.m = initialize_model()
+
+    input_shape = list([BATCH_SIZE] + self.m.inputs[0].shape[1:])
+
+    # Some models accept dynamic image dimensions by default, so we use
+    # IMAGE_DIM as a stand-in.
+    for i, dim in enumerate(input_shape):
+      if dim is None:
+        input_shape[i] = IMAGE_DIM
+
+    # Specify input shape with a static batch size.
+    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+    self.call = tf_test_utils.tf_function_unit_test(
+        input_signature=[tf.TensorSpec(input_shape)],
+        name="call",
+        rtol=1e-5,
+        atol=1e-5)(lambda x: self.m(x, training=False))
+
+
+class ApplicationsTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        ApplicationsModule,
+        exported_names=ApplicationsModule.get_tf_function_unit_tests(),
+        relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
+
+
+def main(argv):
+  del argv  # Unused.
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+
+  if not hasattr(tf.keras.applications, FLAGS.model):
+    raise ValueError(f"Unsupported model: {FLAGS.model}")
+
+  ApplicationsTest.generate_unit_tests(ApplicationsModule)
+  tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
similarity index 63%
rename from integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
rename to integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
index 6cfa854..28cd296 100644
--- a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
+++ b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
@@ -28,50 +28,12 @@
     'include_top', True,
     'Whether or not to include the final (top) layers of the model.')
 
-APP_MODELS = {
-    'ResNet50':
-        tf.keras.applications.resnet.ResNet50,
-    'ResNet101':
-        tf.keras.applications.resnet.ResNet101,
-    'ResNet152':
-        tf.keras.applications.resnet.ResNet152,
-    'ResNet50V2':
-        tf.keras.applications.resnet_v2.ResNet50V2,
-    'ResNet101V2':
-        tf.keras.applications.resnet_v2.ResNet101V2,
-    'ResNet152V2':
-        tf.keras.applications.resnet_v2.ResNet152V2,
-    'VGG16':
-        tf.keras.applications.vgg16.VGG16,
-    'VGG19':
-        tf.keras.applications.vgg19.VGG19,
-    'Xception':
-        tf.keras.applications.xception.Xception,
-    'InceptionV3':
-        tf.keras.applications.inception_v3.InceptionV3,
-    'InceptionResNetV2':
-        tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
-    'MobileNet':
-        tf.keras.applications.mobilenet.MobileNet,
-    'MobileNetV2':
-        tf.keras.applications.mobilenet_v2.MobileNetV2,
-    'DenseNet121':
-        tf.keras.applications.densenet.DenseNet121,
-    'DenseNet169':
-        tf.keras.applications.densenet.DenseNet169,
-    'DenseNet201':
-        tf.keras.applications.densenet.DenseNet201,
-    'NASNetMobile':
-        tf.keras.applications.nasnet.NASNetMobile,
-    'NASNetLarge':
-        tf.keras.applications.nasnet.NASNetLarge,
-}
-
 # minimum size for keras vision models
 INPUT_SHAPE = [1, 32, 32, 3]
 
 
-def main(_):
+def main(argv):
+  del argv  # Unused.
 
   # prepare training and testing data
   (train_images,
@@ -89,10 +51,10 @@
   train_labels = train_labels[:4000]
 
   # It is a toy model for debugging (not optimized for accuracy or speed).
-
-  model = APP_MODELS[FLAGS.model](weights=None,
-                                  include_top=FLAGS.include_top,
-                                  input_shape=INPUT_SHAPE[1:])
+  model_class = getattr(tf.keras.applications, FLAGS.model)
+  model = model_class(weights=None,
+                      include_top=FLAGS.include_top,
+                      input_shape=INPUT_SHAPE[1:])
   model.summary()
   model.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
deleted file mode 100644
index 689a46d..0000000
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test all applications models in Keras."""
-
-import os
-
-from absl import app
-from absl import flags
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
-import tensorflow.compat.v2 as tf
-
-FLAGS = flags.FLAGS
-
-# Testing all applications models automatically can take time
-# so we test it one by one, with argument --model=MobileNet
-flags.DEFINE_string('model', 'ResNet50', 'model name')
-flags.DEFINE_string(
-    'url', '', 'url with model weights '
-    'for example https://storage.googleapis.com/iree_models/')
-flags.DEFINE_bool('use_external_weights', False,
-                  'Whether or not to load external weights from the web')
-flags.DEFINE_enum('data', 'cifar10', ['cifar10', 'imagenet'],
-                  'data sets on which model was trained: imagenet, cifar10')
-flags.DEFINE_bool(
-    'include_top', True,
-    'Whether or not to include the final (top) layers of the model.')
-
-APP_MODELS = {
-    'ResNet50':
-        tf.keras.applications.resnet.ResNet50,
-    'ResNet101':
-        tf.keras.applications.resnet.ResNet101,
-    'ResNet152':
-        tf.keras.applications.resnet.ResNet152,
-    'ResNet50V2':
-        tf.keras.applications.resnet_v2.ResNet50V2,
-    'ResNet101V2':
-        tf.keras.applications.resnet_v2.ResNet101V2,
-    'ResNet152V2':
-        tf.keras.applications.resnet_v2.ResNet152V2,
-    'VGG16':
-        tf.keras.applications.vgg16.VGG16,
-    'VGG19':
-        tf.keras.applications.vgg19.VGG19,
-    'Xception':
-        tf.keras.applications.xception.Xception,
-    'InceptionV3':
-        tf.keras.applications.inception_v3.InceptionV3,
-    'InceptionResNetV2':
-        tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
-    'MobileNet':
-        tf.keras.applications.mobilenet.MobileNet,
-    'MobileNetV2':
-        tf.keras.applications.mobilenet_v2.MobileNetV2,
-    'DenseNet121':
-        tf.keras.applications.densenet.DenseNet121,
-    'DenseNet169':
-        tf.keras.applications.densenet.DenseNet169,
-    'DenseNet201':
-        tf.keras.applications.densenet.DenseNet201,
-    'NASNetMobile':
-        tf.keras.applications.nasnet.NASNetMobile,
-    'NASNetLarge':
-        tf.keras.applications.nasnet.NASNetLarge,
-}
-
-
-def get_input_shape():
-  if FLAGS.data == 'imagenet':
-    if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
-      return (1, 299, 299, 3)
-    elif FLAGS.model == 'NASNetLarge':
-      return (1, 331, 331, 3)
-    else:
-      return (1, 224, 224, 3)
-  elif FLAGS.data == 'cifar10':
-    return (1, 32, 32, 3)
-  else:
-    raise ValueError(f'Data not supported: {FLAGS.data}')
-
-
-def load_cifar10_weights(model):
-  file_name = 'cifar10' + FLAGS.model
-  # get_file will download the model weights from a publicly available folder,
-  # save them to cache_dir=~/.keras/models/ and return a path to them.
-  url = os.path.join(
-      FLAGS.url, f'cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5')
-  weights_path = tf.keras.utils.get_file(file_name, url)
-  model.load_weights(weights_path)
-  return model
-
-
-def initialize_model():
-  tf_utils.set_random_seed()
-
-  # Keras applications models receive input shapes without a batch dimension, as
-  # the batch size is dynamic by default. This selects just the image size.
-  input_shape = get_input_shape()[1:]
-
-  # If weights == 'imagenet', the model will load the appropriate weights from
-  # an external tf.keras URL.
-  weights = None
-  if FLAGS.use_external_weights and FLAGS.data == 'imagenet':
-    weights = 'imagenet'
-
-  model = APP_MODELS[FLAGS.model](weights=weights,
-                                  include_top=FLAGS.include_top,
-                                  input_shape=input_shape)
-
-  if FLAGS.use_external_weights and FLAGS.data == 'cifar10':
-    if not FLAGS.url:
-      raise ValueError(
-          'cifar10 weights cannot be loaded without the `--url` flag.')
-    model = load_cifar10_weights(model)
-  return model
-
-
-class VisionModule(tf.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.m = initialize_model()
-    self.m.predict = lambda x: self.m.call(x, training=False)
-    # Specify input shape with a static batch size.
-    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
-    # Replace input_shape with m.input_shape to make the batch size dynamic.
-    self.predict = tf.function(
-        input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-
-
-class AppTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(
-        VisionModule,
-        exported_names=['predict'],
-        relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
-
-  def test_predict(self):
-
-    def predict(module):
-      module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
-
-    self.compare_backends(predict, self._modules)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-
-  if FLAGS.model not in APP_MODELS:
-    raise ValueError(f'Unsupported model: {FLAGS.model}')
-
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 2440ab7..21d9259 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -41,8 +41,6 @@
         '//integrations/tensorflow/e2e:e2e_tests',
     'mobile_bert_squad_tests':
         '//integrations/tensorflow/e2e:mobile_bert_squad_tests',
-    'keras_tests':
-        '//integrations/tensorflow/e2e/keras:keras_tests',
     'layers_tests':
         '//integrations/tensorflow/e2e/keras/layers:layers_tests',
     'layers_dynamic_batch_tests':
@@ -54,7 +52,7 @@
     'keyword_spotting_internal_streaming_tests':
         '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
     'imagenet_non_hermetic_tests':
-        '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+        '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
     'slim_vision_tests':
         '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
 }
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 3b5509c..fba79b2 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -60,7 +60,7 @@
         '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
     ],
     'vision_coverage': [
-        '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+        '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
         '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
     ],
 }
@@ -116,7 +116,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         f'End to end tests of {KWS_LINK} models in internal streaming mode',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'End to end tests of tf.keras.applications vision models on Imagenet',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'End to end tests of TensorFlow slim vision models',
@@ -162,7 +162,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'model',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'model',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'model',
@@ -193,7 +193,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'keyword_spotting_streaming_test',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'vision_model_test',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'slim_vision_model_test',