Remove e2e test case decorator (#3304)

Replaces compiling a module with a decorator on a child of 
`TracedModuleTestCase` with calling `tf_test_utils.compile_tf_module` 
in the overridden `__init__` method of the child class.

This will make supporting alternative compilation paths easier.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a4c8a..fa3df03 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -22,6 +22,7 @@
 #   ref: reference – for the reference CompiledModule
 #   tar: target - for one of the target CompiledModules
 
+import collections
 import copy
 import glob
 import inspect
@@ -465,7 +466,7 @@
             f"--input_file={compiled_path}",
             f"--driver={self.backend_driver}",
             f"--inputs={serialized_inputs}",
-            f"--entry_function={entry_function}"
+            f"--entry_function={entry_function}",
         ]
         with open(os.path.join(trace_dir, "flagfile"), "w") as f:
           f.writelines(line + "\n" for line in flagfile)
@@ -551,7 +552,11 @@
       return self._trace_call(module_attr, method_name=attr)
 
 
-def compile_module(
+Modules = collections.namedtuple('Modules',
+                                 ['ref_module', 'tar_modules', 'artifacts_dir'])
+
+
+def compile_tf_module(
     module_class: Type[tf.Module], exported_names: Sequence[str] = ()
 ) -> Callable[[Any], Any]:
   """CompiledModuleTestCase decorator that compiles a tf.Module.
@@ -565,80 +570,38 @@
     exported_names: optional iterable of strings representing which of
       module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
-
-  Returns:
-    Class decorator function.
   """
 
-  def decorator(cls):
-    """Decorator Function."""
-    if not issubclass(cls, TracedModuleTestCase):
-      logging.exception(
-          "The 'compile_module' decorator must be applied to a "
-          "TracedModuleTestCase derived class, which %s is not.", cls)
-    cls._module_class = module_class
-    cls._exported_names = exported_names
-    return cls
+  # Setup the directory for saving compilation artifacts and traces.
+  artifacts_dir = _setup_artifacts_dir(module_class.__name__)
 
-  return decorator
+  # Get the backend information for this test.
+  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+                                          f"{FLAGS.reference_backend}_ref")
+  tar_backend_infos = get_target_backends()
 
+  compile_backend = lambda backend_info: backend_info.compile(
+      module_class, exported_names, artifacts_dir)
 
-# Will be initialized by TracedModuleTestCase.setUpClass
-# Global variables are used because storing the compiler context on the cls
-# causes cleaning up refcounts to fail, and tf.test.TestCase wipes the variables
-# on the class instance (self.*) before each unittest.
-# TODO(#2900): Move these back to class variables when we figure out issues with
-# refcounting.
-_global_ref_module = None
-_global_tar_modules = None
+  ref_module = compile_backend(ref_backend_info)
+  tar_modules = [
+      compile_backend(backend_info) for backend_info in tar_backend_infos
+  ]
+  return Modules(ref_module, tar_modules, artifacts_dir)
 
 
 class TracedModuleTestCase(tf.test.TestCase):
   """Compiles a tf.Module to multiple backends to test their correctness."""
-  # Will be initialized by the @compile_module decorator.
-  _module_class = None
-  _exported_names = ()
-
-  @classmethod
-  def _compile(cls, backend_info: tf_utils.BackendInfo):
-    return backend_info.compile(cls._module_class, cls._exported_names,
-                                cls._artifacts_dir)
-
-  @classmethod
-  def setUpClass(cls) -> None:
-    # Ran before any of the unit tests.
-    super().setUpClass()
-    if cls._module_class is None:
-      raise AttributeError(
-          "setUpClass was called but no module was specified. Specify a module "
-          "to compile via the @tf_test_utils.compile_module decorator.")
-
-    # Setup the directory for saving compilation artifacts and traces.
-    cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
-
-    # Get the backend information for this test.
-    ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
-                                            f"{FLAGS.reference_backend}_ref")
-    tar_backend_infos = get_target_backends()
-
-    global _global_ref_module
-    global _global_tar_modules
-    _global_ref_module = cls._compile(ref_backend_info)
-    _global_tar_modules = [
-        cls._compile(backend_info) for backend_info in tar_backend_infos
-    ]
 
   def setUp(self) -> None:
     # Runs before each unit test.
     super().setUp()
-    global _global_ref_module
-    global _global_tar_modules
-    _global_ref_module.reinitialize()
-    for module in _global_tar_modules:
+    self._modules.ref_module.reinitialize()
+    for module in self._modules.tar_modules:
       module.reinitialize()
 
-  def compare_backends(self, trace_function: Callable[[TracedModule],
-                                                      None]) -> None:
+  def compare_backends(self, trace_function: Callable[[TracedModule], None],
+                       modules: Modules) -> None:
     """Run the reference and target backends on trace_function and compare them.
 
     Random seeds for tensorflow, numpy and python are set before each invocation
@@ -648,17 +611,17 @@
       trace_function: a function accepting a TracedModule as its argument.
     """
     # Create Traces for each backend.
-    ref_trace = Trace(_global_ref_module, trace_function)
+    ref_trace = Trace(modules.ref_module, trace_function)
     tar_traces = [
-        Trace(module, trace_function) for module in _global_tar_modules
+        Trace(module, trace_function) for module in modules.tar_modules
     ]
 
     # Run the traces through trace_function with their associated modules.
     tf_utils.set_random_seed()
-    trace_function(TracedModule(_global_ref_module, ref_trace))
+    trace_function(TracedModule(modules.ref_module, ref_trace))
     if FLAGS.log_all_traces:
       logging.info(ref_trace)
-    for module, trace in zip(_global_tar_modules, tar_traces):
+    for module, trace in zip(modules.tar_modules, tar_traces):
       tf_utils.set_random_seed()
       trace_function(TracedModule(module, trace))
       if FLAGS.log_all_traces:
@@ -674,11 +637,11 @@
         failed_backend_indices.append(i)
 
     # Save the results to disk before validating.
-    ref_trace_dir = _get_trace_dir(self._artifacts_dir, ref_trace)
+    ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
     ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
     ref_trace.serialize(ref_trace_dir)
     for tar_trace in tar_traces:
-      tar_trace_dir = _get_trace_dir(self._artifacts_dir, tar_trace)
+      tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
       tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
       tar_trace.serialize(tar_trace_dir)
 
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index ff95379..7adb281 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -80,11 +80,15 @@
 backend as a source of truth. For example:
 
 ```python
-# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
-@tf_test_utils.compile_module(SimpleArithmeticModule)
 # Inherit from `TracedModuleTestCase`.
 class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SimpleArithmeticTest, self).__init__(methodName)
+    # Compile a `tf.Module` named `SimpleArithmeticModule` into
+    # `CompiledModule`s for each reference and target backend.
+    self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
   # Unit test.
   def test_simple_mul(self):
 
@@ -103,7 +107,7 @@
 
     # Calls `simple_mul` once for each backend, recording the inputs and outputs
     # to `module` and then comparing them.
-    self.compare_backends(simple_mul)
+    self.compare_backends(simple_mul, self._modules)
 ```
 
 ## Test Suites
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index e16436d..4e3c0cd 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Batch norm tests."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -39,9 +40,12 @@
         variance_epsilon=1e-4)
 
 
-@tf_test_utils.compile_module(BatchNormModule)
 class BatchNormTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BatchNormTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BatchNormModule)
+
   def test_batch_norm_inference(self):
 
     def batch_norm_inference(module):
@@ -53,10 +57,15 @@
       scale = tf_utils.uniform((16,)) * 1e-3
       module.batch_norm_inference(x, mean, variance, offset, scale)
 
-    self.compare_backends(batch_norm_inference)
+    self.compare_backends(batch_norm_inference, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 161f6ad..ef9dd10 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -37,22 +38,25 @@
     return tf.math.logical_and(x, y)
 
 
-@tf_test_utils.compile_module(MathModule)
 class BooleanTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BooleanTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MathModule)
+
   def test_constant(self):
 
     def constant(module):
       module.constant()
 
-    self.compare_backends(constant)
+    self.compare_backends(constant, self._modules)
 
   def test_greater_than(self):
 
     def greater_than(module):
       module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(greater_than)
+    self.compare_backends(greater_than, self._modules)
 
   def test_logical_and(self):
 
@@ -61,10 +65,15 @@
           np.array([True, True, False, False], dtype=np.bool),
           np.array([True, False, False, True], dtype=np.bool))
 
-    self.compare_backends(logical_and)
+    self.compare_backends(logical_and, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcast_to_test.py b/integrations/tensorflow/e2e/broadcast_to_test.py
index 6d57d6f..dc53f34 100644
--- a/integrations/tensorflow/e2e/broadcast_to_test.py
+++ b/integrations/tensorflow/e2e/broadcast_to_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
     return tf.broadcast_to(x, shape)
 
 
-@tf_test_utils.compile_module(BroadcastToModule)
 class BroadcastToTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BroadcastToTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BroadcastToModule)
+
   def test_scalar_broadcast_to(self):
 
     def scalar_broadcast_to(module):
@@ -40,10 +44,15 @@
       shape = np.array([3, 3], dtype=np.int32)
       result = module.scalar_broadcast_to(x, shape)
 
-    self.compare_backends(scalar_broadcast_to)
+    self.compare_backends(scalar_broadcast_to, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index c4f8e38..f72f3b3 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test broadcasting support."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -30,9 +31,12 @@
     return lhs + rhs
 
 
-@tf_test_utils.compile_module(BroadcastingModule)
 class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BroadcastingTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BroadcastingModule)
+
   def test_add_same_shape(self):
 
     def add_same_shape(module):
@@ -40,7 +44,7 @@
       rhs = tf_utils.uniform([4])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_same_shape)
+    self.compare_backends(add_same_shape, self._modules)
 
   def test_add_broadcast_lhs(self):
 
@@ -49,7 +53,7 @@
       rhs = tf_utils.uniform([4])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_broadcast_lhs)
+    self.compare_backends(add_broadcast_lhs, self._modules)
 
   def test_add_broadcast_rhs(self):
 
@@ -58,10 +62,15 @@
       rhs = tf_utils.uniform([1])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_broadcast_rhs)
+    self.compare_backends(add_broadcast_rhs, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
index 102396b..4832146 100644
--- a/integrations/tensorflow/e2e/complex_test.py
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -32,9 +33,12 @@
     return tf.math.real(exp)
 
 
-@tf_test_utils.compile_module(ComplexModule)
 class ComplexTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ComplexTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ComplexModule)
+
   def test_complex(self):
 
     def complex_exp(module):
@@ -42,10 +46,15 @@
       imag = np.array([-1., 0.4], dtype=np.float32)
       module.complex_exp(real, imag)
 
-    self.compare_backends(complex_exp)
+    self.compare_backends(complex_exp, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 187c6cb..dc0fab8 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test concat op."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -51,9 +52,12 @@
     return tf.concat([a, b], axis=2)
 
 
-@tf_test_utils.compile_module(ConcatOpsModule)
 class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConcatOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule)
+
   def test_concat_zero_dim(self):
 
     def concat_zero_dim(module):
@@ -61,7 +65,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat_zero_dim(a, b)
 
-    self.compare_backends(concat_zero_dim)
+    self.compare_backends(concat_zero_dim, self._modules)
 
   def test_concat0axis(self):
 
@@ -70,7 +74,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat0axis(a, b)
 
-    self.compare_backends(concat0axis)
+    self.compare_backends(concat0axis, self._modules)
 
   def test_concat1axis(self):
 
@@ -79,7 +83,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat1axis(a, b)
 
-    self.compare_backends(concat1axis)
+    self.compare_backends(concat1axis, self._modules)
 
   def test_concat2axis(self):
 
@@ -88,10 +92,15 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat2axis(a, b)
 
-    self.compare_backends(concat2axis)
+    self.compare_backends(concat2axis, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index bc0a328..4ba691a 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -34,16 +35,19 @@
     return i
 
 
-@tf_test_utils.compile_module(ControlFlowModule)
 class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ControlFlowTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ControlFlowModule)
+
   def test_short_sequence(self):
 
     def short_sequence(module):
       input_array = np.array(9., dtype=np.float32)
       module.collatz(input_array)
 
-    self.compare_backends(short_sequence)
+    self.compare_backends(short_sequence, self._modules)
 
   def test_long_sequence(self):
 
@@ -51,10 +55,15 @@
       input_array = np.array(178., dtype=np.float32)
       module.collatz(input_array)
 
-    self.compare_backends(long_sequence)
+    self.compare_backends(long_sequence, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 9346a52..cff4496 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -99,9 +100,12 @@
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
 
-@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConvTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
+
   def test_id_batch_size_1(self):
 
     def id_batch_size_1(module):
@@ -109,7 +113,7 @@
       k = np.ones([1, 1, 1, 1], dtype=np.float32)
       module.conv2d_1451x1111_valid(i, k)
 
-    self.compare_backends(id_batch_size_1)
+    self.compare_backends(id_batch_size_1, self._modules)
 
   def test_id_batch_size_2(self):
 
@@ -118,7 +122,7 @@
       k = np.ones([1, 1, 1, 1], dtype=np.float32)
       module.conv2d_2451x1111_valid(i, k)
 
-    self.compare_backends(id_batch_size_2)
+    self.compare_backends(id_batch_size_2, self._modules)
 
   def test_asymmetric_kernel(self):
 
@@ -128,7 +132,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_1451x2311_valid(i, k)
 
-    self.compare_backends(asymmetric_kernel)
+    self.compare_backends(asymmetric_kernel, self._modules)
 
   def test_padding(self):
 
@@ -138,7 +142,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_1451x2311_same(i, k)
 
-    self.compare_backends(padding)
+    self.compare_backends(padding, self._modules)
 
   def test_batched_padding(self):
 
@@ -148,7 +152,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_2451x2311_same(i, k)
 
-    self.compare_backends(batched_padding)
+    self.compare_backends(batched_padding, self._modules)
 
   def test_feature_reduce(self):
 
@@ -157,7 +161,7 @@
       k = np.ones([3, 2, 2, 1], dtype=np.float32)
       module.conv2d_1452x3221_same(i, k)
 
-    self.compare_backends(feature_reduce)
+    self.compare_backends(feature_reduce, self._modules)
 
   def test_feature_inflate(self):
 
@@ -166,7 +170,7 @@
       k = tf_utils.ndarange([1, 1, 1, 2])
       module.conv2d_1451x1112_same(i, k)
 
-    self.compare_backends(feature_inflate)
+    self.compare_backends(feature_inflate, self._modules)
 
   def test_feature_mix(self):
 
@@ -175,7 +179,7 @@
       k = tf_utils.ndarange([1, 1, 2, 2])
       module.conv2d_1452x1122_same(i, k)
 
-    self.compare_backends(feature_mix)
+    self.compare_backends(feature_mix, self._modules)
 
   def test_feature_padded(self):
 
@@ -184,7 +188,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_1452x2223_same(i, k)
 
-    self.compare_backends(feature_padded)
+    self.compare_backends(feature_padded, self._modules)
 
   def test_feature_unpadded(self):
 
@@ -193,7 +197,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_1452x2223_valid(i, k)
 
-    self.compare_backends(feature_unpadded)
+    self.compare_backends(feature_unpadded, self._modules)
 
   def test_batched_feature_unpadded(self):
 
@@ -202,10 +206,15 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_2452x2223_valid(i, k)
 
-    self.compare_backends(batched_feature_unpadded)
+    self.compare_backends(batched_feature_unpadded, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index e3ba303..c928d5e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -27,45 +28,58 @@
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_valid(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "VALID", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "VALID",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_same(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "SAME",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_valid_stride_2(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 2, 2, 1], "VALID", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 2, 2, 1],
+                                  "VALID",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_same_stride_2(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 2, 2, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 2, 2, 1],
+                                  "SAME",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 4], tf.float32),
       tf.TensorSpec([2, 4, 4, 1], tf.float32),
   ])
   def conv2d_2453x2441_same_stride_1(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "SAME",
+                                  name="result")
 
 
-@tf_test_utils.compile_module(DepthConv2dModule)
 class ConvTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConvTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule)
+
   def test_batched_feature_unpadded(self):
 
     def batched_feature_unpadded(module):
@@ -73,7 +87,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_2452x2423_valid(i, k)
 
-    self.compare_backends(batched_feature_unpadded)
+    self.compare_backends(batched_feature_unpadded, self._modules)
 
   def test_batched_feature_unpadded_same(self):
 
@@ -82,7 +96,7 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_same(i, k)
 
-    self.compare_backends(batched_feature_unpadded_same)
+    self.compare_backends(batched_feature_unpadded_same, self._modules)
 
   def test_batched_feature_unpadded_same_stride_2(self):
 
@@ -91,7 +105,8 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_valid_stride_2(i, k)
 
-    self.compare_backends(batched_feature_unpadded_same_stride_2)
+    self.compare_backends(batched_feature_unpadded_same_stride_2,
+                          self._modules)
 
   def test_batched_feature_padded_same_stride_2(self):
 
@@ -100,7 +115,7 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_same_stride_2(i, k)
 
-    self.compare_backends(batched_feature_padded_same_stride_2)
+    self.compare_backends(batched_feature_padded_same_stride_2, self._modules)
 
   def test_batched_feature_padded_same_stride_1_output_1(self):
 
@@ -109,10 +124,16 @@
       k = tf_utils.ndarange([2, 4, 4, 1])
       module.conv2d_2453x2441_same_stride_1(i, k)
 
-    self.compare_backends(batched_feature_padded_same_stride_1_output_1)
+    self.compare_backends(batched_feature_padded_same_stride_1_output_1,
+                          self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 3742c6e..1e5ba08 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -17,6 +17,7 @@
 # This uses a relu instead, allowing it to get to the remaining issue
 # (unimplemented dynamic dot_general).
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -42,8 +43,8 @@
     self.input_dim = input_dim
     self.classes = classes
     self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
-    self.h2_weights = tf.Variable(
-        tf.random.normal([hidden_1_dim, hidden_2_dim]))
+    self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+                                                    hidden_2_dim]))
     self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
     self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
     self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -51,8 +52,7 @@
 
     # Compile with dynamic batch dim.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec([None, self.input_dim])])(
-            self.predict)
+        input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
 
   def mlp(self, x):
     layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -65,19 +65,28 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_module(DynamicMlpReluModule, exported_names=["predict"])
 class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(DynamicMlpReluTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DynamicMlpReluModule,
+                                                    exported_names=["predict"])
+
   def test_dynamic_batch(self):
 
     def dynamic_batch(module):
       x = tf_utils.uniform([3, 28 * 28]) * 1e-3
       module.predict(x)
 
-    self.compare_backends(dynamic_batch)
+    self.compare_backends(dynamic_batch, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index ff905a5..a0ecdc5 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -38,8 +39,8 @@
     self.input_dim = input_dim
     self.classes = classes
     self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
-    self.h2_weights = tf.Variable(
-        tf.random.normal([hidden_1_dim, hidden_2_dim]))
+    self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+                                                    hidden_2_dim]))
     self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
     self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
     self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -47,8 +48,7 @@
 
     # Compile with dynamic batch dim.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec([None, self.input_dim])])(
-            self.predict)
+        input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
 
   def mlp(self, x):
     layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -61,19 +61,28 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_module(DynamicMlpModule, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(DynamicMlpTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DynamicMlpModule,
+                                                    exported_names=["predict"])
+
   def test_dynamic_batch(self):
 
     def dynamic_batch(module):
       x = tf_utils.uniform([3, 28 * 28]) * 1e-3
       module.predict(x)
 
-    self.compare_backends(dynamic_batch)
+    self.compare_backends(dynamic_batch, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 050eb47..adaeac9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
     return tf.fill(dims, value)
 
 
-@tf_test_utils.compile_module(FillModule)
 class FillTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(FillTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(FillModule)
+
   def test_fill(self):
 
     def fill(module):
@@ -40,10 +44,15 @@
       value = np.array(9., dtype=np.float32)
       module.fill(dims, value)
 
-    self.compare_backends(fill)
+    self.compare_backends(fill, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
index ff62f3a..1f53406 100644
--- a/integrations/tensorflow/e2e/finite_test.py
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -24,18 +25,26 @@
     return tf.math.is_finite(x)
 
 
-@tf_test_utils.compile_module(FiniteModule)
 class FiniteTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(FiniteTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(FiniteModule)
+
   def test_finite(self):
 
     def finite(module):
       module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
 
-    self.compare_backends(finite)
+    self.compare_backends(finite, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index fc0ec13..da71d4b 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -65,9 +66,12 @@
     return tf.gather(params, indices, axis=1, batch_dims=1)
 
 
-@tf_test_utils.compile_module(GatherModule)
 class GatherTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(GatherTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(GatherModule)
+
   def test_gather_axis0_scalar(self):
 
     def gather_axis0_scalar(module):
@@ -75,7 +79,7 @@
       params = tf_utils.ndarange([4, 8])
       module.gather_axis0_scalar(params, indices)
 
-    self.compare_backends(gather_axis0_scalar)
+    self.compare_backends(gather_axis0_scalar, self._modules)
 
   def test_gather_axis0_batch0(self):
 
@@ -84,7 +88,7 @@
       params = tf_utils.ndarange([4, 8])
       module.gather_axis0_batch0(params, indices)
 
-    self.compare_backends(gather_axis0_batch0)
+    self.compare_backends(gather_axis0_batch0, self._modules)
 
   def test_gather_axis1_batch0(self):
 
@@ -93,7 +97,7 @@
       params = tf_utils.ndarange([4, 7, 8])
       module.gather_axis1_batch0(params, indices)
 
-    self.compare_backends(gather_axis1_batch0)
+    self.compare_backends(gather_axis1_batch0, self._modules)
 
   def test_gather_axis2_batch1(self):
 
@@ -102,7 +106,7 @@
       params = tf_utils.ndarange([4, 7, 8, 2])
       module.gather_axis2_batch1(params, indices)
 
-    self.compare_backends(gather_axis2_batch1)
+    self.compare_backends(gather_axis2_batch1, self._modules)
 
   def test_gather_axis1_batch1(self):
 
@@ -111,7 +115,7 @@
       params = tf_utils.ndarange([4, 7, 8, 2])
       module.gather_axis1_batch1(params, indices)
 
-    self.compare_backends(gather_axis1_batch1)
+    self.compare_backends(gather_axis1_batch1, self._modules)
 
   def test_gather_axis2_batch2(self):
 
@@ -120,11 +124,16 @@
       values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32)
       module.gather_axis2_batch2(values, indices)
 
-    self.compare_backends(gather_axis2_batch2)
+    self.compare_backends(gather_axis2_batch2, self._modules)
 
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index db8f86e..5cb1827 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -16,6 +16,7 @@
 # This test is the same as keras_lstm_test, but all shapes are static.
 # This stresses the TensorList lowering more specifically.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -42,19 +43,28 @@
             self.m.call)
 
 
-@tf_test_utils.compile_module(LstmStaticModule, exported_names=["predict"])
 class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LstmStaticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LstmStaticModule,
+                                                    exported_names=["predict"])
+
   def test_lstm(self):
 
     def predict(module):
       inputs = tf_utils.ndarange(INPUT_SHAPE)
       module.predict(inputs, rtol=1e-5, atol=1e-5)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 112e32a..747cc6f 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -31,28 +32,36 @@
     super(LstmModule, self).__init__()
     tf_utils.set_random_seed()
     inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
-    outputs = tf.keras.layers.LSTM(
-        units=NUM_UNITS, return_sequences=True)(
-            inputs)
+    outputs = tf.keras.layers.LSTM(units=NUM_UNITS,
+                                   return_sequences=True)(inputs)
     self.m = tf.keras.Model(inputs, outputs)
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
-            self.m.call)
+        input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(self.m.call)
 
 
-@tf_test_utils.compile_module(LstmModule, exported_names=["predict"])
 class LstmTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LstmTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LstmModule,
+                                                    exported_names=["predict"])
+
   def test_lstm(self):
 
     def predict(module):
       inputs = tf_utils.ndarange(INPUT_SHAPE)
       module.predict(inputs)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
+
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 75286ad..49a3949 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -75,8 +75,6 @@
     return loss_value
 
 
-@tf_test_utils.compile_module(
-    ModelTrain.CreateModule, exported_names=["train_step"])
 class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
 
   def generate_regression_data(self, size=8):
@@ -104,11 +102,12 @@
       # Run one iteration of training step.
       module.train_step(inputs, targets)
 
-    self.compare_backends(train_step)
+    self.compare_backends(train_step, self._modules)
 
 
 if __name__ == "__main__":
   if hasattr(tf, "enable_v2_behavior"):
     tf.enable_v2_behavior()
-
+  tf_test_utils.compile_tf_module(ModelTrain.CreateModule,
+                                  exported_names=["train_step"])
   tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 10663ec..4e1ca64 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -113,8 +113,9 @@
   # an external tf.keras URL.
   weights = 'imagenet' if FLAGS.data == 'imagenet' else None
 
-  model = APP_MODELS[FLAGS.model](
-      weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
+  model = APP_MODELS[FLAGS.model](weights=weights,
+                                  include_top=FLAGS.include_top,
+                                  input_shape=input_shape)
 
   if FLAGS.data == 'cifar10' and FLAGS.url:
     model = load_cifar10_weights(model)
@@ -131,19 +132,22 @@
     # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
     # Replace input_shape with m.input_shape to make the batch size dynamic.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec(get_input_shape())])(
-            self.m.predict)
+        input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
 
 
-@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
 class AppTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(AppTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(VisionModule,
+                                                    exported_names=['predict'])
+
   def test_application(self):
 
     def predict(module):
       module.predict(tf_utils.uniform(get_input_shape()))
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
 def main(argv):
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index b535021..4acc170 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
 
-class LinSpaceModule(tf.Module):
+class LinspaceModule(tf.Module):
 
   def __init__(self):
     pass
@@ -33,9 +34,12 @@
     return tf.linspace(start, stop, num)
 
 
-@tf_test_utils.compile_module(LinSpaceModule)
 class LinspaceTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LinspaceTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LinspaceModule)
+
   def test_linspace(self):
 
     def linspace(module):
@@ -43,10 +47,15 @@
       stop = np.array(12., dtype=np.float32)
       module.linspace(start, stop)
 
-    self.compare_backends(linspace)
+    self.compare_backends(linspace, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
index ea3fd7a..ee83a95 100644
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ b/integrations/tensorflow/e2e/logical_ops_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module that specifically handle logical ops."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -46,9 +47,12 @@
     return tf.math.logical_not(x)
 
 
-@tf_test_utils.compile_module(LogicalOpsModule)
 class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LogicalOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
+
   def test_logical_and(self):
 
     def logical_and(module):
@@ -56,7 +60,7 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_and)
+    self.compare_backends(logical_and, self._modules)
 
   def test_logical_or(self):
 
@@ -65,7 +69,7 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_or)
+    self.compare_backends(logical_or, self._modules)
 
   def test_logical_xor(self):
 
@@ -74,17 +78,22 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_xor)
+    self.compare_backends(logical_xor, self._modules)
 
   def test_logical_not(self):
 
     def logical_not(module):
       module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_not)
+    self.compare_backends(logical_not, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index b5a3929..51c1509 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
@@ -90,9 +91,12 @@
     return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
 
 
-@tf_test_utils.compile_module(MandelbrotModule)
 class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MandelbrotTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MandelbrotModule)
+
   def test_mandelbrot(self):
 
     def mandelbrot(module):
@@ -101,10 +105,15 @@
       # This is a much more detailed view, so more iterations are needed.
       module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
 
-    self.compare_backends(mandelbrot)
+    self.compare_backends(mandelbrot, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index 48bcbb6..a96193a 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -42,46 +43,54 @@
     return tf.math.mod(x, 2.0)
 
 
-@tf_test_utils.compile_module(MathModule)
 class MathTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MathTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MathModule)
+
   def test_abs(self):
 
     def abs(module):
       module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(abs)
+    self.compare_backends(abs, self._modules)
 
   def test_ceil(self):
 
     def ceil(module):
       module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(ceil)
+    self.compare_backends(ceil, self._modules)
 
   def test_cos(self):
 
     def cos(module):
       module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(cos)
+    self.compare_backends(cos, self._modules)
 
   def test_log(self):
 
     def log(module):
       module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(log)
+    self.compare_backends(log, self._modules)
 
   def test_mod(self):
 
     def mod(module):
       module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(mod)
+    self.compare_backends(mod, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
index 9c87b7b..4a3daf7 100644
--- a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test matrix ops."""
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
@@ -43,16 +44,19 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_module(MatrixOpsDynamicModule)
 class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MatrixOpsDynamicTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule)
+
   def test_matmul_high_rank_batch(self):
 
     def matmul_high_rank_batch(module):
       module.matmul_high_rank_batch(
           tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
 
-    self.compare_backends(matmul_high_rank_batch)
+    self.compare_backends(matmul_high_rank_batch, self._modules)
 
   def test_matmul_dynamic_matching_batch(self):
 
@@ -60,7 +64,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_matching_batch)
+    self.compare_backends(matmul_dynamic_matching_batch, self._modules)
 
   def test_matmul_dynamic_broadcast_lhs(self):
 
@@ -68,7 +72,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_broadcast_lhs)
+    self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules)
 
   def test_matmul_dynamic_broadcast_rhs(self):
 
@@ -76,7 +80,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_broadcast_rhs)
+    self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules)
 
   def test_matmul_dynamic_rank_broadcasting(self):
 
@@ -84,10 +88,15 @@
       module.matmul_dynamic_lhs_batch(
           tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
 
-    self.compare_backends(matmul_dynamic_rank_broadcasting)
+    self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_static_test.py b/integrations/tensorflow/e2e/matrix_ops_static_test.py
index 82fa482..3a7fe47 100644
--- a/integrations/tensorflow/e2e/matrix_ops_static_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_static_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test matrix ops."""
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
@@ -55,16 +56,19 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_module(MatrixOpsStaticModule)
 class MatrixOpsStaticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MatrixOpsStaticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule)
+
   def test_basic_matmul(self):
 
     def basic_matmul(module):
       module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]),
                           tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(basic_matmul)
+    self.compare_backends(basic_matmul, self._modules)
 
   def test_matmul_lhs_batch(self):
 
@@ -73,7 +77,7 @@
           tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_lhs_batch)
+    self.compare_backends(matmul_lhs_batch, self._modules)
 
   def test_matmul_rhs_batch(self):
 
@@ -82,7 +86,7 @@
           tf_utils.uniform([LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_rhs_batch)
+    self.compare_backends(matmul_rhs_batch, self._modules)
 
   def test_matmul_broadcast_singleton_dimension(self):
 
@@ -91,10 +95,15 @@
           tf_utils.uniform([1, LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_broadcast_singleton_dimension)
+    self.compare_backends(matmul_broadcast_singleton_dimension, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index e8d3997..4f8daf5 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -81,10 +81,14 @@
     return self.inference_func(**inputs)
 
 
-@tf_test_utils.compile_module(MobileBertSquad, exported_names=["predict"])
 class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
   """Tests of MobileBertSquad."""
 
+  def __init__(self, methodName="runTest"):
+    super(MobileBertSquadTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
+                                                    exported_names=["predict"])
+
   def test_predict(self):
 
     def predict(module):
@@ -94,11 +98,15 @@
 
       module.predict(input_ids, input_mask, segment_ids, atol=1e0)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
-
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/range_test.py b/integrations/tensorflow/e2e/range_test.py
index f1e093c..f4e3e97 100644
--- a/integrations/tensorflow/e2e/range_test.py
+++ b/integrations/tensorflow/e2e/range_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -31,9 +32,12 @@
     return tf.range(start, stop, delta)
 
 
-@tf_test_utils.compile_module(RangeModule)
 class RangeTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(RangeTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(RangeModule)
+
   def test_range(self):
 
     def range(module):
@@ -42,10 +46,15 @@
       delta = np.array(3, dtype=np.float32)
       result = module.range(start, stop, delta)
 
-    self.compare_backends(range)
+    self.compare_backends(range, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index dd5ad6d..d387e86 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -28,18 +29,26 @@
     return self.counter.assign_add(value)
 
 
-@tf_test_utils.compile_module(ResourcesOpsModule)
 class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ResourcesOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)
+
   def test_add_assign(self):
 
     def add_assign(module):
       module.add_assign(np.array(9., dtype=np.float32))
 
-    self.compare_backends(add_assign)
+    self.compare_backends(add_assign, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 8e437ea..c56bf31 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -31,16 +32,20 @@
 
     # buffer has size [buffer_size, dims]
     # only the first dimension is used for updating buffer in a ring manner
-    self._buffer = tf.Variable(
-        tf.zeros((self._buffer_size,) + dims, dtype=dtype),
-        trainable=False,
-        name="RingBuffer")
+    self._buffer = tf.Variable(tf.zeros((self._buffer_size,) + dims,
+                                        dtype=dtype),
+                               trainable=False,
+                               name="RingBuffer")
     # Size of the data available for reading
-    self._data_size = tf.Variable(
-        0, trainable=False, dtype=tf.int32, name="FramerBuffer/Size")
+    self._data_size = tf.Variable(0,
+                                  trainable=False,
+                                  dtype=tf.int32,
+                                  name="FramerBuffer/Size")
     # The index pointing to the head of the data available for reading
-    self._read_head = tf.Variable(
-        0, trainable=False, dtype=tf.int32, name="FramerBuffer/Head")
+    self._read_head = tf.Variable(0,
+                                  trainable=False,
+                                  dtype=tf.int32,
+                                  name="FramerBuffer/Head")
 
   @property
   def dtype(self):
@@ -82,8 +87,8 @@
     start = tf.math.floormod(
         self._read_head.read_value() + self._data_size.read_value(),
         self._buffer_size)
-    indices = tf.math.floormod(
-        tf.range(start, limit=start + elements_size), self._buffer_size)
+    indices = tf.math.floormod(tf.range(start, limit=start + elements_size),
+                               self._buffer_size)
 
     tf.compat.v1.scatter_update(self._buffer, indices, elements)
 
@@ -118,8 +123,8 @@
       Tensor of elements with shape [length, dims...].
     """
     start = self._read_head + offset
-    indices = tf.math.floormod(
-        tf.range(start, limit=start + length), self._buffer_size)
+    indices = tf.math.floormod(tf.range(start, limit=start + length),
+                               self._buffer_size)
     result = tf.gather(self._buffer, indices)
     if consume:
       self.consume(length, offset)
@@ -148,8 +153,9 @@
   def build(self, input_shape):
     super(StatefulRingBuffer, self).build(input_shape)
     buffer_size = self.state_shape[1]
-    self.rb = RingBuffer(
-        buffer_size=buffer_size, dims=(self.state_shape[2],), dtype=tf.float32)
+    self.rb = RingBuffer(buffer_size=buffer_size,
+                         dims=(self.state_shape[2],),
+                         dtype=tf.float32)
 
   def call(self, inputs):
     self.rb.write(inputs)
@@ -177,10 +183,13 @@
     return self.rb(x)
 
 
-@tf_test_utils.compile_module(
-    StatefulRingBufferModule, exported_names=["predict"])
 class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StatefulRingBufferTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(StatefulRingBufferModule,
+                                                    exported_names=["predict"])
+
   def test_stateful_ringbuffer(self):
 
     def stateful_ringbuffer(module):
@@ -198,10 +207,15 @@
       module.predict(input3)
       # output = np.array([[3.0, 4.0]], dtype=np.float32)
 
-    self.compare_backends(stateful_ringbuffer)
+    self.compare_backends(stateful_ringbuffer, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index 8b43e3a..8f34002 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Test scatter update behavior for tensorflow."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
-"""Test scatter update behavior for tensorflow."""
 
 
 class ScatterUpdateModule(tf.Module):
@@ -48,9 +49,12 @@
     return tf.tensor_scatter_nd_update(tensor, indices, updates)
 
 
-@tf_test_utils.compile_module(ScatterUpdateModule)
 class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ScatterUpdateTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule)
+
   def test_scatter_update_1D(self):
 
     def scatter_update_1D(module):
@@ -59,7 +63,7 @@
       updates = np.array([9, 10, 11], dtype=np.int32)
       module.scatter_update_1D(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_1D)
+    self.compare_backends(scatter_update_1D, self._modules)
 
   def test_scatter_update_2D(self):
 
@@ -69,7 +73,7 @@
       updates = np.array([2, 5, 8], dtype=np.int32)
       module.scatter_update_2D(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_2D)
+    self.compare_backends(scatter_update_2D, self._modules)
 
   def test_scatter_update_2D_slice(self):
 
@@ -79,10 +83,15 @@
       updates = np.array([[2, 3, 4]], dtype=np.int32)
       module.scatter_update_2D_slice(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_2D_slice)
+    self.compare_backends(scatter_update_2D_slice, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index aaec578..cc9ca09 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Several baseline e2e simple arithmetic tests."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -37,9 +38,12 @@
     return tf.matmul(a, b)
 
 
-@tf_test_utils.compile_module(SimpleArithmeticModule)
 class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SimpleArithmeticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
   def test_simple_mul(self):
 
     def simple_mul(module):
@@ -48,7 +52,7 @@
       c = module.simple_mul(a, b)
       module.simple_mul(a, c)
 
-    self.compare_backends(simple_mul)
+    self.compare_backends(simple_mul, self._modules)
 
   def test_simple_matmul(self):
 
@@ -58,10 +62,15 @@
       b = tf_utils.uniform((3072, 256)) * 1e-3
       module.simple_matmul(a, b)
 
-    self.compare_backends(simple_matmul)
+    self.compare_backends(simple_matmul, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 1120a4f..e4140d4 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -33,19 +34,27 @@
     return self.counter
 
 
-@tf_test_utils.compile_module(SimpleStatefulModule)
 class StatefulTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StatefulTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule)
+
   def test_stateful(self):
 
     def get_state(module):
       module.inc_by(np.array(1., dtype=np.float32))
       module.get_state()
 
-    self.compare_backends(get_state)
+    self.compare_backends(get_state, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index 513aa97..413ffb2 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -75,9 +76,13 @@
     return self.sw(x)
 
 
-@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
 class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SlidingWindowTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SlidingWindowModule,
+                                                    exported_names=["predict"])
+
   def test_sliding_window(self):
 
     def sliding_window(module):
@@ -89,10 +94,15 @@
       result2 = module.predict(input2)
       # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
 
-    self.compare_backends(sliding_window)
+    self.compare_backends(sliding_window, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index 206b33c..70f6468 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import string
@@ -40,9 +41,12 @@
     return tf.strings.reduce_join(wps, 1)
 
 
-@tf_test_utils.compile_module(StringsModule)
 class StringsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StringsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(StringsModule)
+
   def test_print_ids(self):
 
     def print_ids(module):
@@ -51,7 +55,7 @@
            [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
       module.print_ids(input_ids)
 
-    self.compare_backends(print_ids)
+    self.compare_backends(print_ids, self._modules)
 
   def test_strings_to_ids(self):
 
@@ -61,10 +65,15 @@
            [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
       module.strings_to_ids(input_ids)
 
-    self.compare_backends(strings_to_ids)
+    self.compare_backends(strings_to_ids, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 440bd43..62babc5 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -50,8 +51,9 @@
   @tf.function(
       input_signature=[tf.TensorSpec([STATIC_SIZE, STATIC_SIZE], tf.float32)])
   def slice_first_element_with_from_tensor_high_rank(self, t):
-    ta = tf.TensorArray(
-        dtype=tf.float32, size=STATIC_SIZE, element_shape=[STATIC_SIZE])
+    ta = tf.TensorArray(dtype=tf.float32,
+                        size=STATIC_SIZE,
+                        element_shape=[STATIC_SIZE])
     ta = ta.unstack(t)
     return ta.read(0)
 
@@ -66,23 +68,26 @@
     return ta.stack()
 
 
-@tf_test_utils.compile_module(TensorListModule)
 class TensorListTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(TensorListTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(TensorListModule)
+
   def test_identity_through_tensorlist(self):
 
     def identity_through_tensorlist(module):
       module.identity_through_tensorlist(np.array(42., dtype=np.float32))
 
-    self.compare_backends(identity_through_tensorlist)
+    self.compare_backends(identity_through_tensorlist, self._modules)
 
   def test_add_through_tensorlist(self):
 
     def add_through_tensorlist(module):
-      module.add_through_tensorlist(
-          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+      module.add_through_tensorlist(np.array(42., dtype=np.float32),
+                                    np.array(43., dtype=np.float32))
 
-    self.compare_backends(add_through_tensorlist)
+    self.compare_backends(add_through_tensorlist, self._modules)
 
   def test_slice_first_element_with_from_tensor(self):
 
@@ -90,7 +95,7 @@
       module.slice_first_element_with_from_tensor(
           np.arange(STATIC_SIZE, dtype=np.float32))
 
-    self.compare_backends(slice_first_element_with_from_tensor)
+    self.compare_backends(slice_first_element_with_from_tensor, self._modules)
 
   def test_slice_first_element_with_from_tensor_high_rank(self):
 
@@ -98,18 +103,24 @@
       module.slice_first_element_with_from_tensor_high_rank(
           tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
 
-    self.compare_backends(slice_first_element_with_from_tensor_high_rank)
+    self.compare_backends(slice_first_element_with_from_tensor_high_rank,
+                          self._modules)
 
   def test_concat_with_tensorlist_stack(self):
 
     def concat_with_tensorlist_stack(module):
-      module.concat_with_tensorlist_stack(
-          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+      module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
+                                          np.array(43., dtype=np.float32))
 
-    self.compare_backends(concat_with_tensorlist_stack)
+    self.compare_backends(concat_with_tensorlist_stack, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)