Add new tf.keras.applications models (#3958)
Add's the `EfficientNet*` and `MobileNetV3*` models and moves the tests into `.../e2e/keras/applications`.
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 6ecd6b5..89f92ad 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -28,10 +28,6 @@
"//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
"iree_e2e_cartesian_product_test_suite",
)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
- "iree_e2e_test_suite",
-)
package(
default_visibility = ["//visibility:public"],
@@ -39,33 +35,6 @@
licenses = ["notice"], # Apache 2.0
)
-# @unused
-DOC = """
-vision_model_test_manual is for manual testing of all keras vision models.
-Test will run only manually with all parameters specified manually, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla \
---data=imagenet \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla
---data: can be 'imagenet' or 'cifar10'.
- imagenet - input image size (1, 224, 224, 3)
- cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
- and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
---include_top: Whether or not to include the final (top) layers of the model.
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
- imagenet pretrained weights url is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
- ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
- InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
- DenseNet201, NASNetMobile, NASNetLarge
- All above models works with 'imagenet' data sets.
- ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
[
iree_py_binary(
name = src.replace(".py", "_manual"),
@@ -82,308 +51,6 @@
)
]
-SPECIAL_CASES = [
- "keyword_spotting_streaming_test.py",
- "vision_model_test.py",
-]
-
-TFLITE_FAILING = []
-
-VMLA_FAILING = []
-
-LLVM_FAILING = []
-
-VULKAN_FAILING = []
-
-TF_PASSING = glob(
- ["*_test.py"],
- exclude = SPECIAL_CASES,
-)
-
-TFLITE_PASSING = glob(
- ["*_test.py"],
- exclude = TFLITE_FAILING + SPECIAL_CASES,
-)
-
-VMLA_PASSING = glob(
- ["*_test.py"],
- exclude = VMLA_FAILING + SPECIAL_CASES,
-)
-
-LLVM_PASSING = glob(
- ["*_test.py"],
- exclude = LLVM_FAILING + SPECIAL_CASES,
-)
-
-VULKAN_PASSING = glob(
- ["*_test.py"],
- exclude = VULKAN_FAILING + SPECIAL_CASES,
-)
-
-iree_e2e_test_suite(
- name = "keras_tests",
- backends_to_srcs = {
- "tf": TF_PASSING,
- "tflite": TFLITE_PASSING,
- "iree_vmla": VMLA_PASSING,
- "iree_llvmjit": LLVM_PASSING,
- "iree_vulkan": VULKAN_PASSING,
- },
- reference_backend = "tf",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_test_suite(
- name = "keras_tests_failing",
- backends_to_srcs = {
- "tflite": TFLITE_FAILING,
- "iree_vmla": VMLA_FAILING,
- "iree_llvmjit": LLVM_FAILING,
- "iree_vulkan": VULKAN_FAILING,
- },
- reference_backend = "tf",
- tags = [
- "failing",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "large_cifar10_tests",
- size = "large",
- srcs = ["vision_model_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- # All models with runtime shorter than ResNet50.
- "MobileNet", # Max: Vulkan 61.0s
- "MobileNetV2", # Max: LLVM 96.3s
- "ResNet50", # Max: LLVM 145.6s
- "VGG16", # Max: LLVM 89.5s
- "VGG19", # Max: LLVM 94.7s
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = ["manual"],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "enormous_cifar10_tests",
- size = "enormous",
- srcs = ["vision_model_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "cifar10_non_hermetic_tests",
- size = "large",
- srcs = ["vision_model_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "url": "https://storage.googleapis.com/iree_models/",
- "use_external_weights": True,
- "model": [
- "MobileNet",
- "MobileNetV2",
- "ResNet50",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "no-remote",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "imagenet_non_hermetic_tests",
- size = "enormous",
- srcs = ["vision_model_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing vulkan:
- "model": [
- "InceptionResNetV2",
- "InceptionV3",
- ],
- "target_backends": [
- "iree_vulkan",
- ],
- },
- {
- # Failing llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- "Xception",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "imagenet",
- "use_external_weights": True,
- "model": [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "InceptionResNetV2",
- "InceptionV3",
- "MobileNet",
- "MobileNetV2",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- "VGG16",
- "VGG19",
- "Xception",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# It is used to produce weights for keras vision models with input image size
-# 32x32. These models are not optimized for accuracy or latency (they are for
-# debugging only). They have the same neural net topology with keras vision
-# models trained on imagenet data sets
-iree_py_binary(
- name = "train_vision_models_on_cifar",
- srcs = ["train_vision_models_on_cifar.py"],
- python_version = "PY3",
- srcs_version = "PY2AND3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
# Keyword Spotting Tests:
KEYWORD_SPOTTING_MODELS = [
"svdf",
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
new file mode 100644
index 0000000..a9df330
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -0,0 +1,318 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/vision-coverage
+# Updates made to test suite names should also be reflected here:
+# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_binary",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# @unused
+DOC = """
+applications_test_manual is for manual testing of all keras vision models.
+Test will run only manually with all parameters specified manually, for example:
+bazel run -c opt integrations/tensorflow/e2e/keras/applications:applications_test_manual -- \
+--target_backends=tf,iree_vmla \
+--data=imagenet \
+--url=https://storage.googleapis.com/iree_models/ \
+--model=ResNet50
+
+Command arguments description:
+--target_backends: can be combination of these: tf,iree_vmla
+--data: can be 'imagenet' or 'cifar10'.
+ imagenet - input image size (1, 224, 224, 3)
+ cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
+ and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
+--include_top: Whether or not to include the final (top) layers of the model.
+--url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
+ imagenet pretrained weights url is specified by keras
+--model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
+ ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
+ InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
+ DenseNet201, NASNetMobile, NASNetLarge
+ All above models works with 'imagenet' data sets.
+ ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
+"""
+
+[
+ iree_py_binary(
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(
+ ["*_test.py"],
+ )
+]
+
+KERAS_APPLICATIONS_MODELS = [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "EfficientNetB0",
+ "EfficientNetB1",
+ "EfficientNetB2",
+ "EfficientNetB3",
+ "EfficientNetB4",
+ "EfficientNetB5",
+ "EfficientNetB6",
+ "EfficientNetB7",
+ "InceptionResNetV2",
+ "InceptionV3",
+ "MobileNet",
+ "MobileNetV2",
+ "MobileNetV3Large",
+ "MobileNetV3Small",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ "ResNet50",
+ "ResNet50V2",
+ "VGG16",
+ "VGG19",
+]
+
+iree_e2e_cartesian_product_test_suite(
+ name = "large_cifar10_tests",
+ size = "large",
+ srcs = ["applications_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ # All models with runtime shorter than ResNet50.
+ "MobileNet", # Max: Vulkan 61.0s
+ "MobileNetV2", # Max: LLVM 96.3s
+ "ResNet50", # Max: LLVM 145.6s
+ "VGG16", # Max: LLVM 89.5s
+ "VGG19", # Max: LLVM 94.7s
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = ["manual"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+ name = "enormous_cifar10_tests",
+ size = "enormous",
+ srcs = ["applications_test.py"],
+ failing_configurations = [
+ {
+ # Failing on vmla with negative inputs.
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ ],
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm and vulkan:
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101V2",
+ "ResNet152V2",
+ ],
+ "target_backends": [
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "guitar",
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+ name = "cifar10_non_hermetic_tests",
+ size = "large",
+ srcs = ["applications_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "url": "https://storage.googleapis.com/iree_models/",
+ "use_external_weights": True,
+ "model": [
+ "MobileNet",
+ "MobileNetV2",
+ "ResNet50",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "no-remote",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+ name = "imagenet_non_hermetic_tests",
+ size = "enormous",
+ srcs = ["applications_test.py"],
+ failing_configurations = [
+ {
+ # Failing on vmla with negative inputs.
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ ],
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing vulkan:
+ "model": [
+ "InceptionResNetV2",
+ "InceptionV3",
+ ],
+ "target_backends": [
+ "iree_vulkan",
+ ],
+ },
+ {
+ # Failing llvm and vulkan:
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101V2",
+ "ResNet152V2",
+ "Xception",
+ ],
+ "target_backends": [
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "imagenet",
+ "use_external_weights": True,
+ "model": KERAS_APPLICATIONS_MODELS,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# It is used to produce weights for keras vision models with input image size
+# 32x32. These models are not optimized for accuracy or latency (they are for
+# debugging only). They have the same neural net topology with keras vision
+# models trained on imagenet data sets
+iree_py_binary(
+ name = "train_vision_models_on_cifar",
+ srcs = ["train_vision_models_on_cifar.py"],
+ python_version = "PY3",
+ srcs_version = "PY2AND3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/applications/applications_test.py b/integrations/tensorflow/e2e/keras/applications/applications_test.py
new file mode 100644
index 0000000..8c0e77c
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/applications_test.py
@@ -0,0 +1,121 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all models in tf.keras.applications."""
+
+import os
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# Testing all applications models automatically can take time
+# so we test it one by one, with argument --model=MobileNet
+flags.DEFINE_string("model", "ResNet50", "model name")
+flags.DEFINE_string(
+ "url", "", "url with model weights "
+ "for example https://storage.googleapis.com/iree_models/")
+flags.DEFINE_bool("use_external_weights", False,
+ "Whether or not to load external weights from the web")
+flags.DEFINE_enum("data", "cifar10", ["cifar10", "imagenet"],
+ "data sets on which model was trained: imagenet, cifar10")
+flags.DEFINE_bool(
+ "include_top", True,
+ "Whether or not to include the final (top) layers of the model.")
+
+BATCH_SIZE = 1
+IMAGE_DIM = 224
+
+
+def load_cifar10_weights(model):
+ file_name = "cifar10" + FLAGS.model
+ # get_file will download the model weights from a publicly available folder,
+ # save them to cache_dir=~/.keras/models/ and return a path to them.
+ url = os.path.join(
+ FLAGS.url, f"cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5")
+ weights_path = tf.keras.utils.get_file(file_name, url)
+ model.load_weights(weights_path)
+ return model
+
+
+def initialize_model():
+ # If weights == "imagenet", the model will load the appropriate weights from
+ # an external tf.keras URL.
+ weights = None
+ if FLAGS.use_external_weights and FLAGS.data == "imagenet":
+ weights = "imagenet"
+
+ model_class = getattr(tf.keras.applications, FLAGS.model)
+ model = model_class(weights=weights, include_top=FLAGS.include_top)
+
+ if FLAGS.use_external_weights and FLAGS.data == "cifar10":
+ if not FLAGS.url:
+ raise ValueError(
+ "cifar10 weights cannot be loaded without the `--url` flag.")
+ model = load_cifar10_weights(model)
+ return model
+
+
+class ApplicationsModule(tf_test_utils.TestModule):
+
+ def __init__(self):
+ super().__init__()
+ self.m = initialize_model()
+
+ input_shape = list([BATCH_SIZE] + self.m.inputs[0].shape[1:])
+
+ # Some models accept dynamic image dimensions by default, so we use
+ # IMAGE_DIM as a stand-in.
+ for i, dim in enumerate(input_shape):
+ if dim is None:
+ input_shape[i] = IMAGE_DIM
+
+ # Specify input shape with a static batch size.
+ # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+ self.call = tf_test_utils.tf_function_unit_test(
+ input_signature=[tf.TensorSpec(input_shape)],
+ name="call",
+ rtol=1e-5,
+ atol=1e-5)(lambda x: self.m(x, training=False))
+
+
+class ApplicationsTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(
+ ApplicationsModule,
+ exported_names=ApplicationsModule.get_tf_function_unit_tests(),
+ relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+
+ if not hasattr(tf.keras.applications, FLAGS.model):
+ raise ValueError(f"Unsupported model: {FLAGS.model}")
+
+ ApplicationsTest.generate_unit_tests(ApplicationsModule)
+ tf.test.main()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
similarity index 63%
rename from integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
rename to integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
index 6cfa854..28cd296 100644
--- a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
+++ b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
@@ -28,50 +28,12 @@
'include_top', True,
'Whether or not to include the final (top) layers of the model.')
-APP_MODELS = {
- 'ResNet50':
- tf.keras.applications.resnet.ResNet50,
- 'ResNet101':
- tf.keras.applications.resnet.ResNet101,
- 'ResNet152':
- tf.keras.applications.resnet.ResNet152,
- 'ResNet50V2':
- tf.keras.applications.resnet_v2.ResNet50V2,
- 'ResNet101V2':
- tf.keras.applications.resnet_v2.ResNet101V2,
- 'ResNet152V2':
- tf.keras.applications.resnet_v2.ResNet152V2,
- 'VGG16':
- tf.keras.applications.vgg16.VGG16,
- 'VGG19':
- tf.keras.applications.vgg19.VGG19,
- 'Xception':
- tf.keras.applications.xception.Xception,
- 'InceptionV3':
- tf.keras.applications.inception_v3.InceptionV3,
- 'InceptionResNetV2':
- tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
- 'MobileNet':
- tf.keras.applications.mobilenet.MobileNet,
- 'MobileNetV2':
- tf.keras.applications.mobilenet_v2.MobileNetV2,
- 'DenseNet121':
- tf.keras.applications.densenet.DenseNet121,
- 'DenseNet169':
- tf.keras.applications.densenet.DenseNet169,
- 'DenseNet201':
- tf.keras.applications.densenet.DenseNet201,
- 'NASNetMobile':
- tf.keras.applications.nasnet.NASNetMobile,
- 'NASNetLarge':
- tf.keras.applications.nasnet.NASNetLarge,
-}
-
# minimum size for keras vision models
INPUT_SHAPE = [1, 32, 32, 3]
-def main(_):
+def main(argv):
+ del argv # Unused.
# prepare training and testing data
(train_images,
@@ -89,10 +51,10 @@
train_labels = train_labels[:4000]
# It is a toy model for debugging (not optimized for accuracy or speed).
-
- model = APP_MODELS[FLAGS.model](weights=None,
- include_top=FLAGS.include_top,
- input_shape=INPUT_SHAPE[1:])
+ model_class = getattr(tf.keras.applications, FLAGS.model)
+ model = model_class(weights=None,
+ include_top=FLAGS.include_top,
+ input_shape=INPUT_SHAPE[1:])
model.summary()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
deleted file mode 100644
index 689a46d..0000000
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test all applications models in Keras."""
-
-import os
-
-from absl import app
-from absl import flags
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
-import tensorflow.compat.v2 as tf
-
-FLAGS = flags.FLAGS
-
-# Testing all applications models automatically can take time
-# so we test it one by one, with argument --model=MobileNet
-flags.DEFINE_string('model', 'ResNet50', 'model name')
-flags.DEFINE_string(
- 'url', '', 'url with model weights '
- 'for example https://storage.googleapis.com/iree_models/')
-flags.DEFINE_bool('use_external_weights', False,
- 'Whether or not to load external weights from the web')
-flags.DEFINE_enum('data', 'cifar10', ['cifar10', 'imagenet'],
- 'data sets on which model was trained: imagenet, cifar10')
-flags.DEFINE_bool(
- 'include_top', True,
- 'Whether or not to include the final (top) layers of the model.')
-
-APP_MODELS = {
- 'ResNet50':
- tf.keras.applications.resnet.ResNet50,
- 'ResNet101':
- tf.keras.applications.resnet.ResNet101,
- 'ResNet152':
- tf.keras.applications.resnet.ResNet152,
- 'ResNet50V2':
- tf.keras.applications.resnet_v2.ResNet50V2,
- 'ResNet101V2':
- tf.keras.applications.resnet_v2.ResNet101V2,
- 'ResNet152V2':
- tf.keras.applications.resnet_v2.ResNet152V2,
- 'VGG16':
- tf.keras.applications.vgg16.VGG16,
- 'VGG19':
- tf.keras.applications.vgg19.VGG19,
- 'Xception':
- tf.keras.applications.xception.Xception,
- 'InceptionV3':
- tf.keras.applications.inception_v3.InceptionV3,
- 'InceptionResNetV2':
- tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
- 'MobileNet':
- tf.keras.applications.mobilenet.MobileNet,
- 'MobileNetV2':
- tf.keras.applications.mobilenet_v2.MobileNetV2,
- 'DenseNet121':
- tf.keras.applications.densenet.DenseNet121,
- 'DenseNet169':
- tf.keras.applications.densenet.DenseNet169,
- 'DenseNet201':
- tf.keras.applications.densenet.DenseNet201,
- 'NASNetMobile':
- tf.keras.applications.nasnet.NASNetMobile,
- 'NASNetLarge':
- tf.keras.applications.nasnet.NASNetLarge,
-}
-
-
-def get_input_shape():
- if FLAGS.data == 'imagenet':
- if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
- return (1, 299, 299, 3)
- elif FLAGS.model == 'NASNetLarge':
- return (1, 331, 331, 3)
- else:
- return (1, 224, 224, 3)
- elif FLAGS.data == 'cifar10':
- return (1, 32, 32, 3)
- else:
- raise ValueError(f'Data not supported: {FLAGS.data}')
-
-
-def load_cifar10_weights(model):
- file_name = 'cifar10' + FLAGS.model
- # get_file will download the model weights from a publicly available folder,
- # save them to cache_dir=~/.keras/models/ and return a path to them.
- url = os.path.join(
- FLAGS.url, f'cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5')
- weights_path = tf.keras.utils.get_file(file_name, url)
- model.load_weights(weights_path)
- return model
-
-
-def initialize_model():
- tf_utils.set_random_seed()
-
- # Keras applications models receive input shapes without a batch dimension, as
- # the batch size is dynamic by default. This selects just the image size.
- input_shape = get_input_shape()[1:]
-
- # If weights == 'imagenet', the model will load the appropriate weights from
- # an external tf.keras URL.
- weights = None
- if FLAGS.use_external_weights and FLAGS.data == 'imagenet':
- weights = 'imagenet'
-
- model = APP_MODELS[FLAGS.model](weights=weights,
- include_top=FLAGS.include_top,
- input_shape=input_shape)
-
- if FLAGS.use_external_weights and FLAGS.data == 'cifar10':
- if not FLAGS.url:
- raise ValueError(
- 'cifar10 weights cannot be loaded without the `--url` flag.')
- model = load_cifar10_weights(model)
- return model
-
-
-class VisionModule(tf.Module):
-
- def __init__(self):
- super().__init__()
- self.m = initialize_model()
- self.m.predict = lambda x: self.m.call(x, training=False)
- # Specify input shape with a static batch size.
- # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
- # Replace input_shape with m.input_shape to make the batch size dynamic.
- self.predict = tf.function(
- input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-
-
-class AppTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(
- VisionModule,
- exported_names=['predict'],
- relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
-
- def test_predict(self):
-
- def predict(module):
- module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
-
- self.compare_backends(predict, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
-
- if FLAGS.model not in APP_MODELS:
- raise ValueError(f'Unsupported model: {FLAGS.model}')
-
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 2440ab7..21d9259 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -41,8 +41,6 @@
'//integrations/tensorflow/e2e:e2e_tests',
'mobile_bert_squad_tests':
'//integrations/tensorflow/e2e:mobile_bert_squad_tests',
- 'keras_tests':
- '//integrations/tensorflow/e2e/keras:keras_tests',
'layers_tests':
'//integrations/tensorflow/e2e/keras/layers:layers_tests',
'layers_dynamic_batch_tests':
@@ -54,7 +52,7 @@
'keyword_spotting_internal_streaming_tests':
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
'imagenet_non_hermetic_tests':
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
'slim_vision_tests':
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
}
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 3b5509c..fba79b2 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -60,7 +60,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
],
'vision_coverage': [
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
],
}
@@ -116,7 +116,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
f'End to end tests of {KWS_LINK} models in internal streaming mode',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'End to end tests of tf.keras.applications vision models on Imagenet',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'End to end tests of TensorFlow slim vision models',
@@ -162,7 +162,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'model',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'model',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'model',
@@ -193,7 +193,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'keyword_spotting_streaming_test',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'vision_model_test',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'slim_vision_model_test',