Simplify module compilation and rename `SavedModelTestCase` (#2562)

- `SavedModelTestCase` is renamed to `CompiledModuleTestCase`, since we have moved away from using `SavedModel`s as inputs to our e2e tests.
- Now `tf_test_utils.compile_module` and `CompiledModule.compile` both take a `tf.Module` subclass as their input. Previously we allowed either a "`tf.Module` subclass" or a "function which returns a `tf.Module` subclass instance" to be passed under the names `module_ctor` and `constructor`, respectively. This also allows us to assume that `module_class` has a `__name__` attribute, which we can use when saving compilation/benchmarking artifacts.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index f032772..9cc1c93 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -270,18 +270,17 @@
   return VirtualBackendsClass(*reinitialized_modules)
 
 
-def compile_module(module_ctor, exported_names=()):
-  """SavedModelTestCase decorator that compiles a tf.Module.
+def compile_module(module_class, exported_names=()):
+  """CompiledModuleTestCase decorator that compiles a tf.Module.
 
   A CompiledModule is created for each backend in --target_backends. They can
   be accessed individually via self.compiled_modules.backend_name or as a union
   via self.get_module().
 
   Args:
-    module_ctor: tf.Module subclass or function which returns a tf.Module
-      subclass instance.
+    module_class: the tf.Module subclass to compile.
     exported_names: optional iterable of strings representing which of
-      module_ctor's functions to compile. If exported_names is empty all
+      module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
 
   Returns:
@@ -290,11 +289,11 @@
 
   def decorator(cls):
     """Decorator Function."""
-    if not issubclass(cls, SavedModelTestCase):
+    if not issubclass(cls, CompiledModuleTestCase):
       logging.exception(
           "The 'compile_module' decorator must be applied to a "
-          "SavedModelTestCase derived class, which %s is not.", cls)
-    cls._module_ctor = module_ctor
+          "CompiledModuleTestCase derived class, which %s is not.", cls)
+    cls._module_class = module_class
     cls._exported_names = exported_names
     return cls
 
@@ -336,11 +335,11 @@
   return backends
 
 
-class SavedModelTestCase(tf.test.TestCase):
-  """Tests against a SavedModel."""
+class CompiledModuleTestCase(tf.test.TestCase):
+  """Compiles a tf.Module to multiple backends to test their correctness."""
 
   # Will be initialized by the @compile_module decorator.
-  _module_ctor = None
+  _module_class = None
   _exported_names = ()
 
   # Will be initialized in setUpClass to a dict of
@@ -350,27 +349,31 @@
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
-    if cls._module_ctor is not None:
-      # Setup the debug directory for this test. Creates a global variable
-      # `global_debug_dir`.
-      _setup_test_debug_dir(test_name=cls.__name__)
+    if cls._module_class is None:
+      raise AttributeError(
+          "setUpClass was called but no module was specified. Specify a module "
+          "to compile via the @tf_test_utils.compile_module decorator.")
 
-      # Setup crash reproducer for the test.
-      crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
-      compiler.Context.default_crash_reproducer_path = crash_reproducer_path
+    # Setup the debug directory for this test. Creates a global variable
+    # `global_debug_dir`.
+    _setup_test_debug_dir(test_name=cls.__name__)
 
-      # Create a CompiledModule for each backend.
-      try:
-        backends = get_backends()
-        cls._compiled_backends_dict = {}
-        for backend in backends:
-          compiled_backend = tf_utils.CompiledModule.compile(
-              cls._module_ctor, backend, cls._exported_names, global_debug_dir)
-          cls._compiled_backends_dict[backend.name] = compiled_backend
-      finally:
-        # Disable crash reproducer (to avoid inadvertently overwriting this
-        # path on a subsequent interaction).
-        compiler.Context.default_crash_reproducer_path = None
+    # Setup crash reproducer for the test.
+    crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+    compiler.Context.default_crash_reproducer_path = crash_reproducer_path
+
+    # Create a CompiledModule for each backend.
+    try:
+      backends = get_backends()
+      cls._compiled_backends_dict = {}
+      for backend in backends:
+        compiled_backend = tf_utils.CompiledModule.compile(
+            cls._module_class, backend, cls._exported_names, global_debug_dir)
+        cls._compiled_backends_dict[backend.name] = compiled_backend
+    finally:
+      # Disable crash reproducer (to avoid inadvertently overwriting this
+      # path on a subsequent interaction).
+      compiler.Context.default_crash_reproducer_path = None
 
   @classmethod
   def tearDownClass(cls):
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 2752629..4ef66d1 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -136,22 +136,24 @@
   """Base class for the TF and IREE compiled module facades."""
 
   @staticmethod
-  def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+  def compile(module_class,
+              backend_info,
+              exported_names=(),
+              artifacts_dir=None):
     """Compile a tf.Module using the CompiledModule subclass in backend_info.
 
     Args:
-      constructor: a tf.Module subclass or function which returns a tf.Module
-        subclass instance.
+      module_class: the tf.Module subclass to compile.
       backend_info: an element of BackendInfo corresponding to the backend to
         compile to. If a TF 'backend' is provided then the module is wrapped in
         a TfCompiledModule.
       exported_names: an optional iterable of strings representing which of the
-        tf.Module's functions to compile. If exported_names is empty all
+        module_class's functions to compile. If exported_names is empty all
         functions will be compiled.
       artifacts_dir: an optional path to save compilation artifacts to.
     """
     compile = backend_info.CompiledModule.compile
-    return compile(constructor, backend_info, exported_names, artifacts_dir)
+    return compile(module_class, backend_info, exported_names, artifacts_dir)
 
   @staticmethod
   def from_existing(module):
@@ -160,9 +162,9 @@
     from_existing = module._backend_info.CompiledModule.from_existing
     return from_existing(module)
 
-  def __init__(self, constructor, backend_info, exported_names, artifacts_dir):
+  def __init__(self, module_class, backend_info, exported_names, artifacts_dir):
     """Default constructor – use `compile` or `from_existing` instead."""
-    self._constructor = constructor
+    self._module_class = module_class
     self._backend_info = backend_info
     self._exported_names = exported_names
     self._artifacts_dir = artifacts_dir
@@ -172,45 +174,47 @@
   """Iree compiled module."""
 
   @staticmethod
-  def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+  def compile(module_class,
+              backend_info,
+              exported_names=(),
+              artifacts_dir=None):
     """Compile a tf.Module to the target backend in backend_info.
 
     Args:
-      constructor: a tf.Module subclass or function which returns a tf.Module
-        subclass instance.
+      module_class: the tf.Module subclass to compile.
       backend_info: an element of BackendInfo corresponding to the IREE backend
         to compile to.
       exported_names: an optional iterable of strings representing which of the
-        tf.Module's functions to compile. If exported_names is empty all
+        module_class's functions to compile. If exported_names is empty all
         functions will be compiled.
       artifacts_dir: an optional path to save compilation artifacts to.
     """
-    return IreeCompiledModule(constructor, backend_info, exported_names,
+    return IreeCompiledModule(module_class, backend_info, exported_names,
                               artifacts_dir)
 
   @staticmethod
   def from_existing(module):
     """Duplicates 'module' with the tf.Module's state without recompiling."""
     default_args = [
-        module._constructor, module._backend_info, module._exported_names,
+        module._module_class, module._backend_info, module._exported_names,
         module._artifacts_dir
     ]
     from_existing_args = [module._module_blob, module._module, module._config]
     return IreeCompiledModule(*default_args, from_existing_args)
 
   def __init__(self,
-               constructor,
+               module_class,
                backend_info,
                exported_names,
                artifacts_dir,
                _from_existing_args=None):
     """Default constructor – use `compile` or `from_existing` instead."""
-    super().__init__(constructor, backend_info, exported_names, artifacts_dir)
+    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
 
     if _from_existing_args is None:
       # Called from IreeCompiledModule.compile(...)
       self._module_blob = compile_tf_module(
-          tf_module=constructor(),
+          tf_module=module_class(),
           target_backends=backend_info.iree_compiler_targets,
           exported_names=exported_names,
           artifacts_dir=artifacts_dir)
@@ -250,35 +254,37 @@
   """
 
   @staticmethod
-  def compile(constructor, backend_info, exported_names=(), artifacts_dir=None):
+  def compile(module_class,
+              backend_info,
+              exported_names=(),
+              artifacts_dir=None):
     """Wrap a tf.Module in a TFCompiledModule facade.
 
     Args:
-      constructor: a tf.Module subclass or function which returns a tf.Module
-        subclass instance.
+      module_class: the tf.Module subclass to 'compile'.
       backend_info: one of the 'tf*' elements in BackendInfo.
-      exported_names: an optional iterable of strings representing the which of
-        the tf.Module's functions should be callable. If exported_names is empty
-        then all functions are callable.
+      exported_names: an optional iterable of strings representing which of the
+        module_class's functions should be callable. If exported_names is empty
+        then all functions will be callable.
       artifacts_dir: an optional path to save compilation artifacts to. Has no
         effect for this subclass as nothing is compiled.
     """
-    return TfCompiledModule(constructor, backend_info, exported_names,
+    return TfCompiledModule(module_class, backend_info, exported_names,
                             artifacts_dir)
 
   @staticmethod
   def from_existing(module):
-    """Duplicates 'module's facade with the starting state of constructor."""
-    duplicate_module = TfCompiledModule(module._constructor,
+    """Duplicates 'module's facade with the starting state of module_class."""
+    duplicate_module = TfCompiledModule(module._module_class,
                                         module._backend_info,
                                         module._exported_names,
                                         module._artifacts_dir)
     return duplicate_module
 
-  def __init__(self, constructor, backend_info, exported_names, artifacts_dir):
+  def __init__(self, module_class, backend_info, exported_names, artifacts_dir):
     """Default constructor – use `compile` or `from_existing` instead."""
-    super().__init__(constructor, backend_info, exported_names, artifacts_dir)
-    self._tf_module = constructor()
+    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    self._tf_module = module_class()
 
   def __getattr__(self, attr):
     # Try to resolve it as a function.
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index f9f8d8c..75de16d 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -39,7 +39,7 @@
 
 
 @tf_test_utils.compile_module(BatchNormModule)
-class BatchNormTest(tf_test_utils.SavedModelTestCase):
+class BatchNormTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_batch_norm_inference(self):
     np.random.seed(12345)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index cde2fd6..74880bd 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -29,7 +29,7 @@
 
 
 @tf_test_utils.compile_module(BroadcastingModule)
-class BroadcastingTest(tf_test_utils.SavedModelTestCase):
+class BroadcastingTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_add_same_shape(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index a9f9759..b7a348c 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -51,7 +51,7 @@
 
 
 @tf_test_utils.compile_module(ConcatOpsModule)
-class ConcatOpsTest(tf_test_utils.SavedModelTestCase):
+class ConcatOpsTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_concat_zero_dim(self):
     tf_utils.set_random_seed()
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index 0223e8c..d8bebc1 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -39,7 +39,7 @@
 
 
 @tf_test_utils.compile_module(ControlFlowModule)
-class ControlFlowTest(tf_test_utils.SavedModelTestCase):
+class ControlFlowTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_short_sequence(self):
     input_array = numpy.array(9., dtype=numpy.float32)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index f72b11d..61c46cf 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -99,7 +99,7 @@
 
 
 @tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.SavedModelTestCase):
+class ConvTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_id_batch_size_1(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index cdf4d1e..1e8a002 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -39,7 +39,7 @@
 
 
 @tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.SavedModelTestCase):
+class ConvTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_batched_feature_unpadded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 04de603..64c51e9 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -66,7 +66,7 @@
 
 
 @tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
+class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_dynamic_batch(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 66f3c06..72d7f1f 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -62,7 +62,7 @@
 
 
 @tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
+class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_dynamic_batch(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
index 903b34c..bdcdd79 100644
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ b/integrations/tensorflow/e2e/explicit_backend_test.py
@@ -30,7 +30,7 @@
 
 
 @tf_test_utils.compile_module(SimpleArithmeticModule)
-class ExplicitBackendTest(tf_test_utils.SavedModelTestCase):
+class ExplicitBackendTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_explicit(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 8ef96a9..8d912a4 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -31,7 +31,7 @@
 
 
 @tf_test_utils.compile_module(FillModule)
-class FillTest(tf_test_utils.SavedModelTestCase):
+class FillTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_fill(self):
     dims = np.array([2, 3], dtype=np.int32)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index 8d2a0ba..67f5acf 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -49,7 +49,7 @@
 
 
 @tf_test_utils.compile_module(GatherModule)
-class GatherTest(tf_test_utils.SavedModelTestCase):
+class GatherTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_gather_axis0_scalar(self):
     indices = np.array(2, dtype=np.int32)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index 0d34d97..fb7a58c 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -27,21 +27,23 @@
 INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
 
 
-def lstm_module():
-  tf_utils.set_random_seed()
-  inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
-  outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
-  model = tf.keras.Model(inputs, outputs)
-  module = tf.Module()
-  module.m = model
-  module.predict = tf.function(
-      input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
-          model.call)
-  return module
+class LstmStatic(tf.Module):
+
+  def __init__(self):
+    super(LstmStatic, self).__init__()
+    tf_utils.set_random_seed()
+    inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
+    outputs = tf.keras.layers.LSTM(
+        units=NUM_UNITS, return_sequences=True)(
+            inputs)
+    self.m = tf.keras.Model(inputs, outputs)
+    self.predict = tf.function(
+        input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+            self.m.call)
 
 
-@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
-class LstmTest(tf_test_utils.SavedModelTestCase):
+@tf_test_utils.compile_module(LstmStatic, exported_names=["predict"])
+class LstmTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_lstm(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 671c31b..9409d04 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -24,21 +24,23 @@
 INPUT_SHAPE = [None, None, NUM_UNITS]
 
 
-def lstm_module():
-  tf_utils.set_random_seed()
-  inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
-  outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
-  model = tf.keras.Model(inputs, outputs)
-  module = tf.Module()
-  module.m = model
-  module.predict = tf.function(
-      input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
-          model.call)
-  return module
+class Lstm(tf.Module):
+
+  def __init__(self):
+    super(Lstm, self).__init__()
+    tf_utils.set_random_seed()
+    inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+    outputs = tf.keras.layers.LSTM(
+        units=NUM_UNITS, return_sequences=True)(
+            inputs)
+    self.m = tf.keras.Model(inputs, outputs)
+    self.predict = tf.function(
+        input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+            self.m.call)
 
 
-@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
-class LstmTest(tf_test_utils.SavedModelTestCase):
+@tf_test_utils.compile_module(Lstm, exported_names=["predict"])
+class LstmTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_lstm(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 6675956..e30bd57 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -77,7 +77,7 @@
 
 @tf_test_utils.compile_module(
     ModelTrain.CreateModule, exported_names=["TrainStep"])
-class ModelTrainTest(tf_test_utils.SavedModelTestCase):
+class ModelTrainTest(tf_test_utils.CompiledModuleTestCase):
 
   def generate_regression_data(self, size=8):
     x = np.arange(size) - size // 2
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 1804739..4f3917e 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Test all applications models in Keras."""
 import os
+
+from absl import app
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
@@ -72,74 +74,83 @@
 }
 
 
-def get_input_shape(data, model):
-  if data == 'imagenet':
-    if (model == 'InceptionV3' or model == 'Xception' or
-        model == 'InceptionResNetV2'):
+def get_input_shape():
+  if FLAGS.data == 'imagenet':
+    if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
       return (1, 299, 299, 3)
-    elif model == 'NASNetLarge':
+    elif FLAGS.model == 'NASNetLarge':
       return (1, 331, 331, 3)
     else:
       return (1, 224, 224, 3)
-  elif data == 'cifar10':
+  elif FLAGS.data == 'cifar10':
     return (1, 32, 32, 3)
   else:
-    raise ValueError('Not supported data ', data)
+    raise ValueError(f'Data not supported: {FLAGS.data}')
 
 
-def models():
-  tf.keras.backend.set_learning_phase(False)
+def load_cifar10_weights(model):
+  file_name = 'cifar10' + FLAGS.model
+  # get_file will download the model weights from a publicly available folder,
+  # save them to cache_dir=~/.keras/models/ and return a path to them.
+  url = os.path.join(
+      FLAGS.url, f'cifar10_include_top_{FLAGS.include_top}_{FLAGS.model}.h5')
+  weights_path = tf.keras.utils.get_file(file_name, url)
+  model.load_weights(weights_path)
+  return model
+
+
+def initialize_model():
   tf_utils.set_random_seed()
+  tf.keras.backend.set_learning_phase(False)
 
-  input_shape = get_input_shape(FLAGS.data, FLAGS.model)
-  # keras model receives images size as input,
-  # where batch size is not specified - by default it is dynamic
-  if FLAGS.model in APP_MODELS:
-    weights = 'imagenet' if FLAGS.data == 'imagenet' else None
+  # Keras applications models receive input shapes without a batch dimension, as
+  # the batch size is dynamic by default. This selects just the image size.
+  input_shape = get_input_shape()[1:]
 
-    # if weights == 'imagenet' it will load weights from external tf.keras URL
-    model = APP_MODELS[FLAGS.model](
-        weights=weights,
-        include_top=FLAGS.include_top,
-        input_shape=input_shape[1:])
+  # If weights == 'imagenet', the model will load the appropriate weights from
+  # an external tf.keras URL.
+  weights = 'imagenet' if FLAGS.data == 'imagenet' else None
 
-    if FLAGS.data == 'cifar10' and FLAGS.url:
-      file_name = 'cifar10' + FLAGS.model
-      # it will download model weights from publically available folder: PATH
-      # and save it to cache_dir=~/.keras and return path to it
-      weights_path = tf.keras.utils.get_file(
-          file_name,
-          os.path.join(
-              FLAGS.url,
-              'cifar10_include_top_{}_{}'.format(FLAGS.include_top,
-                                                 FLAGS.model + '.h5')))
+  model = APP_MODELS[FLAGS.model](
+      weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
 
-      model.load_weights(weights_path)
-  else:
-    raise ValueError('Unsupported model', FLAGS.model)
-
-  module = tf.Module()
-  module.m = model
-  # specify input size with static batch size
-  # TODO(b/142948097): with support of dynamic shape
-  # replace input_shape by model.input_shape, so batch size will be dynamic (-1)
-  module.predict = tf.function(input_signature=[tf.TensorSpec(input_shape)])(
-      model.call)
-  return module
+  if FLAGS.data == 'cifar10' and FLAGS.url:
+    model = load_cifar10_weights(model)
+  return model
 
 
-@tf_test_utils.compile_module(models, exported_names=['predict'])
-class AppTest(tf_test_utils.SavedModelTestCase):
+class VisionModule(tf.Module):
+
+  def __init__(self):
+    super(VisionModule, self).__init__()
+    self.m = initialize_model()
+    # Specify input shape with a static batch size.
+    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+    # Replace input_shape with m.input_shape to make the batch size dynamic.
+    self.predict = tf.function(
+        input_signature=[tf.TensorSpec(get_input_shape())])(
+            self.m.call)
+
+
+@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
+class AppTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_application(self):
-    input_shape = get_input_shape(FLAGS.data, FLAGS.model)
-    input_data = np.random.rand(np.prod(np.array(input_shape))).astype(
-        np.float32)
-    input_data = input_data.reshape(input_shape)
+    input_data = np.random.rand(*get_input_shape()).astype(np.float32)
     self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
 
 
-if __name__ == '__main__':
+def main(argv):
+  del argv  # Unused
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
+
+  if FLAGS.model not in APP_MODELS:
+    raise ValueError(f'Unsupported model: {FLAGS.model}')
+  # Override VisionModule's __name__ to be more specific.
+  VisionModule.__name__ = FLAGS.model
+
   tf.test.main()
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index d326db5..aa49f5b 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -34,7 +34,7 @@
 
 
 @tf_test_utils.compile_module(LinSpaceModule)
-class LinspaceTest(tf_test_utils.SavedModelTestCase):
+class LinspaceTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_linspace(self):
     start = np.array(10., dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 4886b7a..216b1b9 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -95,7 +95,7 @@
 
 
 @tf_test_utils.compile_module(MandelbrotModule)
-class MandelbrotTest(tf_test_utils.SavedModelTestCase):
+class MandelbrotTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_mandelbrot(self):
     mandelbrot = self.get_module()
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index b27d1d1..a33ac7c 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -39,7 +39,7 @@
 
 
 @tf_test_utils.compile_module(MathModule)
-class MathTest(tf_test_utils.SavedModelTestCase):
+class MathTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_abs(self):
     a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index b29a198..d04ce3a 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -71,7 +71,7 @@
 
 
 @tf_test_utils.compile_module(MatrixOpsModule)
-class MatrixOpsTest(tf_test_utils.SavedModelTestCase):
+class MatrixOpsTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_basic_matmul(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 1d703c0..8daa6cf 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -29,7 +29,7 @@
 
 
 @tf_test_utils.compile_module(ResourcesOpsModule)
-class ResourcesOpsTest(tf_test_utils.SavedModelTestCase):
+class ResourcesOpsTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_add_assign(self):
     result = self.get_module().add_assign(np.array(9., dtype=np.float32))
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index ea48711..3af1502 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -179,7 +179,7 @@
 
 @tf_test_utils.compile_module(
     StatefulRingBufferModule, exported_names=["predict"])
-class StatefulRingBufferTest(tf_test_utils.SavedModelTestCase):
+class StatefulRingBufferTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_stateful_ringbuffer(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index cdd3277..ab5ab91 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -49,7 +49,7 @@
 
 
 @tf_test_utils.compile_module(ScatterUpdateModule)
-class ScatterUpdateTest(tf_test_utils.SavedModelTestCase):
+class ScatterUpdateTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_scatter_update_1D(self):
     tensor = tf.ones([8], dtype=tf.int32)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index 0c5941d..d3ea327 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -37,7 +37,7 @@
 
 
 @tf_test_utils.compile_module(SimpleArithmeticModule)
-class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
+class SimpleArithmeticTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_simple_mul(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 45eba4f..24dd23e 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -33,7 +33,7 @@
 
 
 @tf_test_utils.compile_module(Stateful)
-class StatefulTest(tf_test_utils.SavedModelTestCase):
+class StatefulTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_stateful(self):
     m = self.get_module()
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index b663fc8..f206d86 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -76,7 +76,7 @@
 
 
 @tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
-class SlidingWindowTest(tf_test_utils.SavedModelTestCase):
+class SlidingWindowTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_slidingwindow(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index ac590ff..ce0787e 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -41,7 +41,7 @@
 
 
 @tf_test_utils.compile_module(StringsModule)
-class StringsTest(tf_test_utils.SavedModelTestCase):
+class StringsTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_print_ids(self):
     input_ids = np.asarray(
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index f8ea811..421760c 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -69,7 +69,7 @@
 
 
 @tf_test_utils.compile_module(TensorListModule)
-class TensorListTest(tf_test_utils.SavedModelTestCase):
+class TensorListTest(tf_test_utils.CompiledModuleTestCase):
 
   def test_identity_through_tensorlist(self):
     m = self.get_module()