Simplify module compilation and rename `SavedModelTestCase` (#2562)
- `SavedModelTestCase` is renamed to `CompiledModuleTestCase`, since we have moved away from using `SavedModel`s as inputs to our e2e tests.
- `tf_test_utils.compile_module` and `CompiledModule.compile` now both take a `tf.Module` subclass as their input. Previously either a `tf.Module` subclass or a function returning a `tf.Module` subclass instance could be passed (under the names `module_ctor`/`constructor`). Requiring the class also lets us assume that `module_class` has a `__name__` attribute, which we can use when saving compilation/benchmarking artifacts. Usage under the new API is shown in the sketch below.
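
For reference, a test under the new API reads roughly like this sketch. The decorator and base-class usage mirror the updated e2e tests in this change; the module body itself is illustrative rather than copied from the repo:

```python
import numpy as np
import tensorflow as tf

from pyiree.tf.support import tf_test_utils


class SimpleArithmeticModule(tf.Module):
  # Illustrative module body; only the class name and the decorator/base-class
  # usage below are taken from the updated e2e tests.

  @tf.function(input_signature=[
      tf.TensorSpec([4], tf.float32),
      tf.TensorSpec([4], tf.float32)
  ])
  def simple_mul(self, a, b):
    return a * b


# The tf.Module subclass itself (not a factory function) is passed in.
@tf_test_utils.compile_module(SimpleArithmeticModule)
class SimpleArithmeticTest(tf_test_utils.CompiledModuleTestCase):

  def test_simple_mul(self):
    a = np.array([1., 2., 3., 4.], dtype=np.float32)
    b = np.array([400., 5., 6., 7.], dtype=np.float32)
    # get_module() returns the union of the modules compiled for each backend
    # in --target_backends.
    self.get_module().simple_mul(a, b).print().assert_all_close()


if __name__ == '__main__':
  if hasattr(tf, 'enable_v2_behavior'):
    tf.enable_v2_behavior()
  tf.test.main()
```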
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index f032772..9cc1c93 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -270,18 +270,17 @@
return VirtualBackendsClass(*reinitialized_modules)
-def compile_module(module_ctor, exported_names=()):
- """SavedModelTestCase decorator that compiles a tf.Module.
+def compile_module(module_class, exported_names=()):
+ """CompiledModuleTestCase decorator that compiles a tf.Module.
A CompiledModule is created for each backend in --target_backends. They can
be accessed individually via self.compiled_modules.backend_name or as a union
via self.get_module().
Args:
- module_ctor: tf.Module subclass or function which returns a tf.Module
- subclass instance.
+ module_class: the tf.Module subclass to compile.
exported_names: optional iterable of strings representing which of
- module_ctor's functions to compile. If exported_names is empty all
+ module_class's functions to compile. If exported_names is empty all
functions will be compiled.
Returns:
@@ -290,11 +289,11 @@
def decorator(cls):
"""Decorator Function."""
- if not issubclass(cls, SavedModelTestCase):
+ if not issubclass(cls, CompiledModuleTestCase):
logging.exception(
"The 'compile_module' decorator must be applied to a "
- "SavedModelTestCase derived class, which %s is not.", cls)
- cls._module_ctor = module_ctor
+ "CompiledModuleTestCase derived class, which %s is not.", cls)
+ cls._module_class = module_class
cls._exported_names = exported_names
return cls
@@ -336,11 +335,11 @@
return backends
-class SavedModelTestCase(tf.test.TestCase):
- """Tests against a SavedModel."""
+class CompiledModuleTestCase(tf.test.TestCase):
+ """Compiles a tf.Module to multiple backends to test their correctness."""
# Will be initialized by the @compile_module decorator.
- _module_ctor = None
+ _module_class = None
_exported_names = ()
# Will be initialized in setUpClass to a dict of
@@ -350,27 +349,31 @@
@classmethod
def setUpClass(cls):
super().setUpClass()
- if cls._module_ctor is not None:
- # Setup the debug directory for this test. Creates a global variable
- # `global_debug_dir`.
- _setup_test_debug_dir(test_name=cls.__name__)
+ if cls._module_class is None:
+ raise AttributeError(
+ "setUpClass was called but no module was specified. Specify a module "
+ "to compile via the @tf_test_utils.compile_module decorator.")
- # Setup crash reproducer for the test.
- crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
- compiler.Context.default_crash_reproducer_path = crash_reproducer_path
+ # Setup the debug directory for this test. Creates a global variable
+ # `global_debug_dir`.
+ _setup_test_debug_dir(test_name=cls.__name__)
- # Create a CompiledModule for each backend.
- try:
- backends = get_backends()
- cls._compiled_backends_dict = {}
- for backend in backends:
- compiled_backend = tf_utils.CompiledModule.compile(
- cls._module_ctor, backend, cls._exported_names, global_debug_dir)
- cls._compiled_backends_dict[backend.name] = compiled_backend
- finally:
- # Disable crash reproducer (to avoid inadvertently overwriting this
- # path on a subsequent interaction).
- compiler.Context.default_crash_reproducer_path = None
+ # Setup crash reproducer for the test.
+ crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+ compiler.Context.default_crash_reproducer_path = crash_reproducer_path
+
+ # Create a CompiledModule for each backend.
+ try:
+ backends = get_backends()
+ cls._compiled_backends_dict = {}
+ for backend in backends:
+ compiled_backend = tf_utils.CompiledModule.compile(
+ cls._module_class, backend, cls._exported_names, global_debug_dir)
+ cls._compiled_backends_dict[backend.name] = compiled_backend
+ finally:
+ # Disable crash reproducer (to avoid inadvertently overwriting this
+ # path on a subsequent interaction).
+ compiler.Context.default_crash_reproducer_path = None
@classmethod
def tearDownClass(cls):
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 2752629..4ef66d1 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -136,22 +136,24 @@
"""Base class for the TF and IREE compiled module facades."""
@staticmethod
- def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+ def compile(module_class,
+ backend_info,
+ exported_names=(),
+ artifacts_dir=None):
"""Compile a tf.Module using the CompiledModule subclass in backend_info.
Args:
- constructor: a tf.Module subclass or function which returns a tf.Module
- subclass instance.
+ module_class: the tf.Module subclass to compile.
backend_info: an element of BackendInfo corresponding to the backend to
compile to. If a TF 'backend' is provided then the module is wrapped in
a TfCompiledModule.
exported_names: an optional iterable of strings representing which of the
- tf.Module's functions to compile. If exported_names is empty all
+ module_class's functions to compile. If exported_names is empty all
functions will be compiled.
artifacts_dir: an optional path to save compilation artifacts to.
"""
compile = backend_info.CompiledModule.compile
- return compile(constructor, backend_info, exported_names, artifacts_dir)
+ return compile(module_class, backend_info, exported_names, artifacts_dir)
@staticmethod
def from_existing(module):
@@ -160,9 +162,9 @@
from_existing = module._backend_info.CompiledModule.from_existing
return from_existing(module)
- def __init__(self, constructor, backend_info, exported_names, artifacts_dir):
+ def __init__(self, module_class, backend_info, exported_names, artifacts_dir):
"""Default constructor – use `compile` or `from_existing` instead."""
- self._constructor = constructor
+ self._module_class = module_class
self._backend_info = backend_info
self._exported_names = exported_names
self._artifacts_dir = artifacts_dir
@@ -172,45 +174,47 @@
"""Iree compiled module."""
@staticmethod
- def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+ def compile(module_class,
+ backend_info,
+ exported_names=(),
+ artifacts_dir=None):
"""Compile a tf.Module to the target backend in backend_info.
Args:
- constructor: a tf.Module subclass or function which returns a tf.Module
- subclass instance.
+ module_class: the tf.Module subclass to compile.
backend_info: an element of BackendInfo corresponding to the IREE backend
to compile to.
exported_names: an optional iterable of strings representing which of the
- tf.Module's functions to compile. If exported_names is empty all
+ module_class's functions to compile. If exported_names is empty all
functions will be compiled.
artifacts_dir: an optional path to save compilation artifacts to.
"""
- return IreeCompiledModule(constructor, backend_info, exported_names,
+ return IreeCompiledModule(module_class, backend_info, exported_names,
artifacts_dir)
@staticmethod
def from_existing(module):
"""Duplicates 'module' with the tf.Module's state without recompiling."""
default_args = [
- module._constructor, module._backend_info, module._exported_names,
+ module._module_class, module._backend_info, module._exported_names,
module._artifacts_dir
]
from_existing_args = [module._module_blob, module._module, module._config]
return IreeCompiledModule(*default_args, from_existing_args)
def __init__(self,
- constructor,
+ module_class,
backend_info,
exported_names,
artifacts_dir,
_from_existing_args=None):
"""Default constructor – use `compile` or `from_existing` instead."""
- super().__init__(constructor, backend_info, exported_names, artifacts_dir)
+ super().__init__(module_class, backend_info, exported_names, artifacts_dir)
if _from_existing_args is None:
# Called from IreeCompiledModule.compile(...)
self._module_blob = compile_tf_module(
- tf_module=constructor(),
+ tf_module=module_class(),
target_backends=backend_info.iree_compiler_targets,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
@@ -250,35 +254,37 @@
"""
@staticmethod
- def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+ def compile(module_class,
+ backend_info,
+ exported_names=(),
+ artifacts_dir=None):
"""Wrap a tf.Module in a TFCompiledModule facade.
Args:
- constructor: a tf.Module subclass or function which returns a tf.Module
- subclass instance.
+ module_class: the tf.Module subclass to 'compile'.
backend_info: one of the 'tf*' elements in BackendInfo.
- exported_names: an optional iterable of strings representing the which of
- the tf.Module's functions should be callable. If exported_names is empty
- then all functions are callable.
+ exported_names: an optional iterable of strings representing which of the
+ module_class's functions should be callable. If exported_names is empty
+ then all functions will be callable.
artifacts_dir: an optional path to save compilation artifacts to. Has no
effect for this subclass as nothing is compiled.
"""
- return TfCompiledModule(constructor, backend_info, exported_names,
+ return TfCompiledModule(module_class, backend_info, exported_names,
artifacts_dir)
@staticmethod
def from_existing(module):
- """Duplicates 'module's facade with the starting state of constructor."""
- duplicate_module = TfCompiledModule(module._constructor,
+ """Duplicates 'module's facade with the starting state of module_class."""
+ duplicate_module = TfCompiledModule(module._module_class,
module._backend_info,
module._exported_names,
module._artifacts_dir)
return duplicate_module
- def __init__(self, constructor, backend_info, exported_names, artifacts_dir):
+ def __init__(self, module_class, backend_info, exported_names, artifacts_dir):
"""Default constructor – use `compile` or `from_existing` instead."""
- super().__init__(constructor, backend_info, exported_names, artifacts_dir)
- self._tf_module = constructor()
+ super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ self._tf_module = module_class()
def __getattr__(self, attr):
# Try to resolve it as a function.
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index f9f8d8c..75de16d 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -39,7 +39,7 @@
@tf_test_utils.compile_module(BatchNormModule)
-class BatchNormTest(tf_test_utils.SavedModelTestCase):
+class BatchNormTest(tf_test_utils.CompiledModuleTestCase):
def test_batch_norm_inference(self):
np.random.seed(12345)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index cde2fd6..74880bd 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -29,7 +29,7 @@
@tf_test_utils.compile_module(BroadcastingModule)
-class BroadcastingTest(tf_test_utils.SavedModelTestCase):
+class BroadcastingTest(tf_test_utils.CompiledModuleTestCase):
def test_add_same_shape(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index a9f9759..b7a348c 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -51,7 +51,7 @@
@tf_test_utils.compile_module(ConcatOpsModule)
-class ConcatOpsTest(tf_test_utils.SavedModelTestCase):
+class ConcatOpsTest(tf_test_utils.CompiledModuleTestCase):
def test_concat_zero_dim(self):
tf_utils.set_random_seed()
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index 0223e8c..d8bebc1 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -39,7 +39,7 @@
@tf_test_utils.compile_module(ControlFlowModule)
-class ControlFlowTest(tf_test_utils.SavedModelTestCase):
+class ControlFlowTest(tf_test_utils.CompiledModuleTestCase):
def test_short_sequence(self):
input_array = numpy.array(9., dtype=numpy.float32)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index f72b11d..61c46cf 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -99,7 +99,7 @@
@tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.SavedModelTestCase):
+class ConvTest(tf_test_utils.CompiledModuleTestCase):
def test_id_batch_size_1(self):
i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index cdf4d1e..1e8a002 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -39,7 +39,7 @@
@tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.SavedModelTestCase):
+class ConvTest(tf_test_utils.CompiledModuleTestCase):
def test_batched_feature_unpadded(self):
i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 04de603..64c51e9 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -66,7 +66,7 @@
@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
+class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
def test_dynamic_batch(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 66f3c06..72d7f1f 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -62,7 +62,7 @@
@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
+class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
def test_dynamic_batch(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
index 903b34c..bdcdd79 100644
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ b/integrations/tensorflow/e2e/explicit_backend_test.py
@@ -30,7 +30,7 @@
@tf_test_utils.compile_module(SimpleArithmeticModule)
-class ExplicitBackendTest(tf_test_utils.SavedModelTestCase):
+class ExplicitBackendTest(tf_test_utils.CompiledModuleTestCase):
def test_explicit(self):
a = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 8ef96a9..8d912a4 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -31,7 +31,7 @@
@tf_test_utils.compile_module(FillModule)
-class FillTest(tf_test_utils.SavedModelTestCase):
+class FillTest(tf_test_utils.CompiledModuleTestCase):
def test_fill(self):
dims = np.array([2, 3], dtype=np.int32)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index 8d2a0ba..67f5acf 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -49,7 +49,7 @@
@tf_test_utils.compile_module(GatherModule)
-class GatherTest(tf_test_utils.SavedModelTestCase):
+class GatherTest(tf_test_utils.CompiledModuleTestCase):
def test_gather_axis0_scalar(self):
indices = np.array(2, dtype=np.int32)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index 0d34d97..fb7a58c 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -27,21 +27,23 @@
INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
-def lstm_module():
- tf_utils.set_random_seed()
- inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
- outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
- model = tf.keras.Model(inputs, outputs)
- module = tf.Module()
- module.m = model
- module.predict = tf.function(
- input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
- model.call)
- return module
+class LstmStatic(tf.Module):
+
+ def __init__(self):
+ super(LstmStatic, self).__init__()
+ tf_utils.set_random_seed()
+ inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
+ outputs = tf.keras.layers.LSTM(
+ units=NUM_UNITS, return_sequences=True)(
+ inputs)
+ self.m = tf.keras.Model(inputs, outputs)
+ self.predict = tf.function(
+ input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+ self.m.call)
-@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
-class LstmTest(tf_test_utils.SavedModelTestCase):
+@tf_test_utils.compile_module(LstmStatic, exported_names=["predict"])
+class LstmTest(tf_test_utils.CompiledModuleTestCase):
def test_lstm(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 671c31b..9409d04 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -24,21 +24,23 @@
INPUT_SHAPE = [None, None, NUM_UNITS]
-def lstm_module():
- tf_utils.set_random_seed()
- inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
- outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
- model = tf.keras.Model(inputs, outputs)
- module = tf.Module()
- module.m = model
- module.predict = tf.function(
- input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
- model.call)
- return module
+class Lstm(tf.Module):
+
+ def __init__(self):
+ super(Lstm, self).__init__()
+ tf_utils.set_random_seed()
+ inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+ outputs = tf.keras.layers.LSTM(
+ units=NUM_UNITS, return_sequences=True)(
+ inputs)
+ self.m = tf.keras.Model(inputs, outputs)
+ self.predict = tf.function(
+ input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+ self.m.call)
-@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
-class LstmTest(tf_test_utils.SavedModelTestCase):
+@tf_test_utils.compile_module(Lstm, exported_names=["predict"])
+class LstmTest(tf_test_utils.CompiledModuleTestCase):
def test_lstm(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 6675956..e30bd57 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -77,7 +77,7 @@
@tf_test_utils.compile_module(
ModelTrain.CreateModule, exported_names=["TrainStep"])
-class ModelTrainTest(tf_test_utils.SavedModelTestCase):
+class ModelTrainTest(tf_test_utils.CompiledModuleTestCase):
def generate_regression_data(self, size=8):
x = np.arange(size) - size // 2
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 1804739..4f3917e 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -14,6 +14,8 @@
# limitations under the License.
"""Test all applications models in Keras."""
import os
+
+from absl import app
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
@@ -72,74 +74,83 @@
}
-def get_input_shape(data, model):
- if data == 'imagenet':
- if (model == 'InceptionV3' or model == 'Xception' or
- model == 'InceptionResNetV2'):
+def get_input_shape():
+ if FLAGS.data == 'imagenet':
+ if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
return (1, 299, 299, 3)
- elif model == 'NASNetLarge':
+ elif FLAGS.model == 'NASNetLarge':
return (1, 331, 331, 3)
else:
return (1, 224, 224, 3)
- elif data == 'cifar10':
+ elif FLAGS.data == 'cifar10':
return (1, 32, 32, 3)
else:
- raise ValueError('Not supported data ', data)
+ raise ValueError(f'Data not supported: {FLAGS.data}')
-def models():
- tf.keras.backend.set_learning_phase(False)
+def load_cifar10_weights(model):
+ file_name = 'cifar10' + FLAGS.model
+ # get_file will download the model weights from a publicly available folder,
+ # save them to cache_dir=~/.keras/models/ and return a path to them.
+ url = os.path.join(
+ FLAGS.url, f'cifar10_include_top_{FLAGS.include_top}_{FLAGS.model}.h5')
+ weights_path = tf.keras.utils.get_file(file_name, url)
+ model.load_weights(weights_path)
+ return model
+
+
+def initialize_model():
tf_utils.set_random_seed()
+ tf.keras.backend.set_learning_phase(False)
- input_shape = get_input_shape(FLAGS.data, FLAGS.model)
- # keras model receives images size as input,
- # where batch size is not specified - by default it is dynamic
- if FLAGS.model in APP_MODELS:
- weights = 'imagenet' if FLAGS.data == 'imagenet' else None
+ # Keras applications models receive input shapes without a batch dimension, as
+ # the batch size is dynamic by default. This selects just the image size.
+ input_shape = get_input_shape()[1:]
- # if weights == 'imagenet' it will load weights from external tf.keras URL
- model = APP_MODELS[FLAGS.model](
- weights=weights,
- include_top=FLAGS.include_top,
- input_shape=input_shape[1:])
+ # If weights == 'imagenet', the model will load the appropriate weights from
+ # an external tf.keras URL.
+ weights = 'imagenet' if FLAGS.data == 'imagenet' else None
- if FLAGS.data == 'cifar10' and FLAGS.url:
- file_name = 'cifar10' + FLAGS.model
- # it will download model weights from publically available folder: PATH
- # and save it to cache_dir=~/.keras and return path to it
- weights_path = tf.keras.utils.get_file(
- file_name,
- os.path.join(
- FLAGS.url,
- 'cifar10_include_top_{}_{}'.format(FLAGS.include_top,
- FLAGS.model + '.h5')))
+ model = APP_MODELS[FLAGS.model](
+ weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
- model.load_weights(weights_path)
- else:
- raise ValueError('Unsupported model', FLAGS.model)
-
- module = tf.Module()
- module.m = model
- # specify input size with static batch size
- # TODO(b/142948097): with support of dynamic shape
- # replace input_shape by model.input_shape, so batch size will be dynamic (-1)
- module.predict = tf.function(input_signature=[tf.TensorSpec(input_shape)])(
- model.call)
- return module
+ if FLAGS.data == 'cifar10' and FLAGS.url:
+ model = load_cifar10_weights(model)
+ return model
-@tf_test_utils.compile_module(models, exported_names=['predict'])
-class AppTest(tf_test_utils.SavedModelTestCase):
+class VisionModule(tf.Module):
+
+ def __init__(self):
+ super(VisionModule, self).__init__()
+ self.m = initialize_model()
+ # Specify input shape with a static batch size.
+ # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+ # Replace input_shape with m.input_shape to make the batch size dynamic.
+ self.predict = tf.function(
+ input_signature=[tf.TensorSpec(get_input_shape())])(
+ self.m.call)
+
+
+@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
+class AppTest(tf_test_utils.CompiledModuleTestCase):
def test_application(self):
- input_shape = get_input_shape(FLAGS.data, FLAGS.model)
- input_data = np.random.rand(np.prod(np.array(input_shape))).astype(
- np.float32)
- input_data = input_data.reshape(input_shape)
+ input_data = np.random.rand(*get_input_shape()).astype(np.float32)
self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
-if __name__ == '__main__':
+def main(argv):
+ del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
+
+ if FLAGS.model not in APP_MODELS:
+ raise ValueError(f'Unsupported model: {FLAGS.model}')
+ # Override VisionModule's __name__ to be more specific.
+ VisionModule.__name__ = FLAGS.model
+
tf.test.main()
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index d326db5..aa49f5b 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -34,7 +34,7 @@
@tf_test_utils.compile_module(LinSpaceModule)
-class LinspaceTest(tf_test_utils.SavedModelTestCase):
+class LinspaceTest(tf_test_utils.CompiledModuleTestCase):
def test_linspace(self):
start = np.array(10., dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 4886b7a..216b1b9 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -95,7 +95,7 @@
@tf_test_utils.compile_module(MandelbrotModule)
-class MandelbrotTest(tf_test_utils.SavedModelTestCase):
+class MandelbrotTest(tf_test_utils.CompiledModuleTestCase):
def test_mandelbrot(self):
mandelbrot = self.get_module()
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index b27d1d1..a33ac7c 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -39,7 +39,7 @@
@tf_test_utils.compile_module(MathModule)
-class MathTest(tf_test_utils.SavedModelTestCase):
+class MathTest(tf_test_utils.CompiledModuleTestCase):
def test_abs(self):
a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index b29a198..d04ce3a 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -71,7 +71,7 @@
@tf_test_utils.compile_module(MatrixOpsModule)
-class MatrixOpsTest(tf_test_utils.SavedModelTestCase):
+class MatrixOpsTest(tf_test_utils.CompiledModuleTestCase):
def test_basic_matmul(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 1d703c0..8daa6cf 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -29,7 +29,7 @@
@tf_test_utils.compile_module(ResourcesOpsModule)
-class ResourcesOpsTest(tf_test_utils.SavedModelTestCase):
+class ResourcesOpsTest(tf_test_utils.CompiledModuleTestCase):
def test_add_assign(self):
result = self.get_module().add_assign(np.array(9., dtype=np.float32))
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index ea48711..3af1502 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -179,7 +179,7 @@
@tf_test_utils.compile_module(
StatefulRingBufferModule, exported_names=["predict"])
-class StatefulRingBufferTest(tf_test_utils.SavedModelTestCase):
+class StatefulRingBufferTest(tf_test_utils.CompiledModuleTestCase):
def test_stateful_ringbuffer(self):
input1 = np.array([[1.0, 2.0]], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index cdd3277..ab5ab91 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -49,7 +49,7 @@
@tf_test_utils.compile_module(ScatterUpdateModule)
-class ScatterUpdateTest(tf_test_utils.SavedModelTestCase):
+class ScatterUpdateTest(tf_test_utils.CompiledModuleTestCase):
def test_scatter_update_1D(self):
tensor = tf.ones([8], dtype=tf.int32)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index 0c5941d..d3ea327 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -37,7 +37,7 @@
@tf_test_utils.compile_module(SimpleArithmeticModule)
-class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
+class SimpleArithmeticTest(tf_test_utils.CompiledModuleTestCase):
def test_simple_mul(self):
a = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 45eba4f..24dd23e 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -33,7 +33,7 @@
@tf_test_utils.compile_module(Stateful)
-class StatefulTest(tf_test_utils.SavedModelTestCase):
+class StatefulTest(tf_test_utils.CompiledModuleTestCase):
def test_stateful(self):
m = self.get_module()
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index b663fc8..f206d86 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -76,7 +76,7 @@
@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
-class SlidingWindowTest(tf_test_utils.SavedModelTestCase):
+class SlidingWindowTest(tf_test_utils.CompiledModuleTestCase):
def test_slidingwindow(self):
input1 = np.array([[1.0, 2.0]], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index ac590ff..ce0787e 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -41,7 +41,7 @@
@tf_test_utils.compile_module(StringsModule)
-class StringsTest(tf_test_utils.SavedModelTestCase):
+class StringsTest(tf_test_utils.CompiledModuleTestCase):
def test_print_ids(self):
input_ids = np.asarray(
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index f8ea811..421760c 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -69,7 +69,7 @@
@tf_test_utils.compile_module(TensorListModule)
-class TensorListTest(tf_test_utils.SavedModelTestCase):
+class TensorListTest(tf_test_utils.CompiledModuleTestCase):
def test_identity_through_tensorlist(self):
m = self.get_module()