blob: 50b64818b80bc9c3a389f21841344e0701d0123b [file] [log] [blame]
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test all models in tf.keras.applications."""
import os
from absl import app
from absl import flags
from iree.tf.support import tf_test_utils
from iree.tf.support import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS

# Exercising every applications model in a single run is slow, so one model is
# selected per invocation via a flag, e.g. `--model=MobileNet`.
flags.DEFINE_string(name="model", default="ResNet50", help="model name")
flags.DEFINE_string(
    name="url",
    default="",
    help=("url with model weights "
          "for example https://storage.googleapis.com/iree_models/"))
flags.DEFINE_bool(
    name="use_external_weights",
    default=False,
    help="Whether or not to load external weights from the web")
flags.DEFINE_enum(
    name="data",
    default="cifar10",
    enum_values=["cifar10", "imagenet"],
    help="data sets on which model was trained: imagenet, cifar10")
flags.DEFINE_bool(
    name="include_top",
    default=True,
    help="Whether or not to include the final (top) layers of the model.")

# Model inputs use a static batch size of BATCH_SIZE; IMAGE_DIM is substituted
# for any input dimension the model leaves unspecified (None).
BATCH_SIZE = 1
IMAGE_DIM = 224
def load_cifar10_weights(model):
  """Downloads cifar10-trained weights from `--url` and loads them into `model`.

  The remote file is expected at
  `<--url>/cifar10_include_top_<0|1>_<--model>.h5`, where the 0/1 is the
  `--include_top` flag formatted as an integer.

  Args:
    model: The Keras model to load the weights into.

  Returns:
    `model`, after `model.load_weights` has been applied to it.
  """
  # The weights differ depending on whether the top layers are included, so
  # the include_top bit must be part of the local cache name as well as the
  # remote name: get_file reuses an already-cached file with the same name
  # instead of downloading, which previously let a stale file be picked up
  # after toggling --include_top. Keeping the `.h5` suffix also lets
  # load_weights recognize the HDF5 format from the path.
  file_name = f"cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5"
  # Join with an explicit "/" rather than os.path.join, which would insert the
  # host OS path separator (a backslash on Windows) into the URL.
  url = f"{FLAGS.url.rstrip('/')}/{file_name}"
  # get_file downloads the weights from the publicly available folder (unless
  # already cached), saves them under the Keras cache dir (by default
  # ~/.keras/datasets/) and returns the local path.
  weights_path = tf.keras.utils.get_file(file_name, url)
  model.load_weights(weights_path)
  return model
def initialize_model():
  """Instantiates the `tf.keras.applications` model selected by `--model`.

  The weight source depends on `--use_external_weights` and `--data`:
    * without `--use_external_weights` the model is built with
      `weights=None` (random initialization);
    * with `--data=imagenet` Keras fetches its own pretrained weights
      (`weights="imagenet"`);
    * with `--data=cifar10` weights are loaded from `--url` through
      `load_cifar10_weights`.

  Returns:
    The constructed (and possibly weight-loaded) Keras model.

  Raises:
    ValueError: if cifar10 weights are requested but `--url` is empty.
  """
  use_imagenet_weights = (
      FLAGS.use_external_weights and FLAGS.data == "imagenet")
  use_cifar10_weights = FLAGS.use_external_weights and FLAGS.data == "cifar10"

  # Validate the flag combination before building the model: constructing a
  # full application network is expensive, so fail fast on bad flags.
  if use_cifar10_weights and not FLAGS.url:
    raise ValueError(
        "cifar10 weights cannot be loaded without the `--url` flag.")

  # If weights == "imagenet", the model will load the appropriate weights from
  # an external tf.keras URL.
  weights = "imagenet" if use_imagenet_weights else None
  model_class = getattr(tf.keras.applications, FLAGS.model)
  model = model_class(weights=weights, include_top=FLAGS.include_top)
  if use_cifar10_weights:
    model = load_cifar10_weights(model)
  return model
class ApplicationsModule(tf_test_utils.TestModule):
  """Exposes the selected Keras application model as a unit-testable module.

  Registers a single `call` unit test that runs the model with
  `training=False` on an input of fully static shape.
  """

  def __init__(self):
    super().__init__()
    self.m = initialize_model()
    # Convert the per-example dims to a list explicitly rather than relying on
    # `list + TensorShape` dispatching to `TensorShape.__radd__`; this also
    # works when `.shape` is a plain tuple.
    example_dims = list(self.m.inputs[0].shape[1:])
    # Some models accept dynamic image dimensions by default, so we use
    # IMAGE_DIM as a stand-in for every unknown (None) dimension.
    input_shape = [BATCH_SIZE] + [
        IMAGE_DIM if dim is None else dim for dim in example_dims
    ]
    # Specify input shape with a static batch size.
    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
    self.call = tf_test_utils.tf_function_unit_test(
        input_signature=[tf.TensorSpec(input_shape)],
        name="call",
        rtol=1e-5,
        atol=1e-5)(lambda x: self.m(x, training=False))
class ApplicationsTest(tf_test_utils.TracedModuleTestCase):
  """Traced test case that compiles `ApplicationsModule` for its unit tests.

  Every unit test registered on the module is exported, and compilation
  artifacts are grouped under the relative directory `<--model>/<--data>`.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    unit_test_names = ApplicationsModule.get_tf_function_unit_tests()
    artifacts_subdir = os.path.join(FLAGS.model, FLAGS.data)
    compiled_modules = tf_test_utils.compile_tf_module(
        ApplicationsModule,
        exported_names=unit_test_names,
        relative_artifacts_dir=artifacts_subdir)
    self._modules = compiled_modules
def main(argv):
  """Validates `--model`, then generates and runs the unit tests.

  Args:
    argv: Positional command-line arguments passed by `app.run`; unused.

  Raises:
    ValueError: if `--model` does not name a callable model constructor in
      `tf.keras.applications`.
  """
  del argv  # Unused.
  if hasattr(tf, "enable_v2_behavior"):
    tf.enable_v2_behavior()
  # A bare `hasattr` check also accepts submodules of tf.keras.applications
  # (e.g. `resnet50`, `imagenet_utils`), which would only fail later with an
  # opaque "not callable" error when the model is instantiated.
  model_constructor = getattr(tf.keras.applications, FLAGS.model, None)
  if not callable(model_constructor):
    raise ValueError(f"Unsupported model: {FLAGS.model}")
  ApplicationsTest.generate_unit_tests(ApplicationsModule)
  tf.test.main()


if __name__ == "__main__":
  app.run(main)