Refactors tf_test_utils, creating tf_utils and simplifying module specification. (#2452)

- Factors `set_random_seed` and `save_and_compile_tf_module` (renamed to `compile_tf_module`) out of `tf_test_utils.py` into the new `tf_utils.py`.
  - I set the default behavior of `compile_tf_module` to not keep the SavedModel, since for most purposes it is an implementation detail. The main exceptions are comparing IREE benchmarks against other frameworks and providing debugging info for tests.
- Simplifies the `tf_test_utils.compile_modules` API (now `compile_module`) by removing support for tests that compile multiple modules.
  - This feature was unused both internally and in OSS, and its implementation did not fully support debugging.
- Simplifies the way that modules are referenced in test files.
  - Previously, the union of `CompiledModule`s was referenced via `self.modules.module_name.all`; it is now referenced via `self.get_module()`.
  - Specific `CompiledModule`s can still be referenced via `self.compiled_modules.backend_name` (see the usage sketch after this list).
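
A minimal sketch of the new test pattern (class and test names here are illustrative; the calls mirror the updated e2e tests in this change):

```python
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf


class SimpleArithmeticModule(tf.Module):

  @tf.function(input_signature=[
      tf.TensorSpec([4], tf.float32),
      tf.TensorSpec([4], tf.float32)
  ])
  def simple_mul(self, a, b):
    return a * b


# One CompiledModule is created for each backend in --target_backends.
@tf_test_utils.compile_module(SimpleArithmeticModule)
class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):

  def test_simple_mul(self):
    a = np.array([1., 2., 3., 4.], dtype=np.float32)
    b = np.array([400., 5., 6., 7.], dtype=np.float32)
    # Run against the union of all compiled backends.
    self.get_module().simple_mul(a, b).print().assert_all_close()
    # Or target a single backend by name (assuming 'tf' is in --target_backends).
    tf_c = self.compiled_modules.tf.simple_mul(a, b)


if __name__ == "__main__":
  if hasattr(tf, "enable_v2_behavior"):
    tf.enable_v2_behavior()
  tf.test.main()
```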

Additionally, whether compilation artifacts were saved was inconsistently controlled by `FLAGS.debug_dir`: in its absence, some compilation artifacts were saved to `--test_tmpdir`, while other compilation artifacts and the results of the test invocations weren't saved at all.

I changed the behavior here to always save all compilation artifacts and results to `FLAGS.debug_dir/test_name` if `FLAGS.debug_dir` is provided, and to `--test_tmpdir/test_name` otherwise.
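
For reference, a minimal sketch of driving `tf_utils.compile_tf_module` directly and saving its artifacts (the toy module is illustrative; the artifact names are the ones `tf_utils_test.py` checks for):

```python
import os
import tempfile

from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf


class ConstantModule(tf.Module):

  @tf.function(input_signature=[])
  def meaning(self):
    return tf.constant([42.])


with tempfile.TemporaryDirectory() as artifacts_dir:
  compiled_blob = tf_utils.compile_tf_module(
      ConstantModule(),
      target_backends=["vmla"],
      artifacts_dir=artifacts_dir)
  # Expect saved_model/, tf_input__vmla.mlir, iree_input__vmla.mlir, and
  # compiled__vmla.vmfb alongside each other in artifacts_dir.
  print(sorted(os.listdir(artifacts_dir)))
```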

Incidental changes:
- Deleted `exported_names_test.py` as it was redundant.
- Fixed a bug where `iree_compiler_targets=["vulkan-*"]` created artifact names like `compiled__vulkan_.vmfb` (the fix is sketched below).
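
The fix normalizes backend names before building artifact filenames. A standalone sketch of that normalization (the helper name is hypothetical; the regex and strip mirror the logic now in `tf_utils.compile_tf_module`):

```python
import re


def normalize_backend_names(target_backends):
  """Builds the '__'-joined backend string used in artifact filenames."""
  normalized = []
  for backend in target_backends:
    # Replace unusual characters and ensure names don't start or end in "_".
    normalized.append(re.sub("[^0-9a-zA-Z_]+", "_", backend).strip("_"))
  return "__".join(normalized)


assert normalize_backend_names(["vulkan-*"]) == "vulkan"
assert normalize_backend_names(["vmla", "llvm-ir"]) == "vmla__llvm_ir"
```
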
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index c27d3d0..e9c7e49 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -30,6 +30,7 @@
         "__init__.py",
         "tf_test_driver.py",
         "tf_test_utils.py",
+        "tf_utils.py",
     ],
     deps = INTREE_TENSORFLOW_PY_DEPS + [
         "//integrations/tensorflow/bindings/python:pathsetup",  # build_cleaner: keep
@@ -49,3 +50,15 @@
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
+
+iree_py_test(
+    name = "tf_utils_test",
+    srcs = [
+        "tf_utils.py",
+        "tf_utils_test.py",
+    ],
+    python_version = "PY3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 3b55853..22e0392 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -21,15 +21,14 @@
 
 import collections
 import os
-import random
 import re
-import tempfile
 
 from absl import flags
 from absl import logging
 import numpy as np
 from pyiree import rt
 from pyiree.tf import compiler
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 flags.DEFINE_string("target_backends", None,
@@ -40,115 +39,20 @@
     "--test_tmpdir")
 FLAGS = flags.FLAGS
 
-ORIGINAL_SAVED_MODEL_PATH_ATTR = "_ORIGINAL_SAVED_MODEL_PATH"
 
-# Per test directory where debug artifacts are dumped.
-global_debug_dir = None
+def _setup_test_debug_dir(test_name):
+  global global_debug_dir
 
+  # Use test_tempdir (which defaults to '/tmp/absl_testing/') if FLAGS.debug_dir
+  # is not provided.
+  parent = FLAGS.debug_dir if FLAGS.debug_dir is not None else FLAGS.test_tmpdir
+  global_debug_dir = os.path.join(parent, test_name)
 
-def set_random_seed(seed=0):
-  """Set random seed for tf, np and random."""
-  tf.random.set_seed(seed)
-  random.seed(seed)
-  np.random.seed(seed)
-
-
-def save_and_compile_tf_module(tf_module, exported_names=(),
-                               target_backends=()):
-  """Saves and compiles a TensorFlow tf.Module.
-
-  Note that if the module has the special _ORIGINAL_SAVED_MODEL_PATH attribute,
-  then it will be compiled directly from that path instead of saved and then
-  loaded.
-
-  Args:
-    tf_module: A tf.Module.
-    exported_names: Iterable of dotted function names to consider for
-      compilation.
-    target_backends: Iterable of string backend names to compile for.
-
-  Returns:
-    An _IreeCompiledModule.
-  """
-
-  def compile_from_path(sm_path):
-    compiler_context = compiler.Context()
-    # Break up the compilation so we can save debug artifacts.
-    compiler_module = compiler.tf_load_saved_model(
-        sm_path,
-        exported_names=exported_names,
-        compiler_context=compiler_context,
-        pass_pipeline=())
-
-    # Save the input MLIR module.
-    flattened_target_backends = re.sub("[^0-9a-zA-Z_]+", "_",
-                                       "__".join(target_backends))
-    if global_debug_dir:
-      mlir_path = os.path.join(global_debug_dir,
-                               "raw__%s.mlir" % flattened_target_backends)
-      logging.info("Saving raw TF input MLIR to: %s", mlir_path)
-      with open(mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    # Now run the passes manually that tf_load_saved_model would usually do.
-    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
-    if global_debug_dir:
-      mlir_path = os.path.join(global_debug_dir,
-                               "input__%s.mlir" % flattened_target_backends)
-      logging.info("Saving IREE input MLIR to: %s", mlir_path)
-      with open(mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    compiled_module = compiler_module.compile(target_backends=target_backends)
-    if global_debug_dir:
-      compiled_path = os.path.join(
-          global_debug_dir, "compiled__%s.vmfb" % flattened_target_backends)
-      logging.info("Saving compiled IREE module to: %s", compiled_path)
-      with open(compiled_path, "wb") as f:
-        f.write(compiled_module)
-
-    return compiled_module
-
-  if hasattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR):
-    # Compile directly from the original path.
-    sm_path = getattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR)
-    logging.info(
-        "Compiling from original saved_model path (not round-tripping): %s",
-        sm_path)
-    return compile_from_path(sm_path)
-  else:
-    options = tf.saved_model.SaveOptions(save_debug_info=True)
-    if FLAGS.debug_dir is None:
-      # Round-trip through a temporary directory.
-      with tempfile.TemporaryDirectory() as sm_path:
-        tf.saved_model.save(tf_module, sm_path, options=options)
-        return compile_from_path(sm_path)
-    else:
-      # Use the supplied directory.
-      sm_path = os.path.join(FLAGS.debug_dir, "SavedModel")
-      tf.saved_model.save(tf_module, sm_path, options=options)
-      return compile_from_path(sm_path)
-
-
-def load_tf_module(path):
-  """Wrapper around tf.saved_model.load which preserves the path.
-
-  Args:
-    path: The path to load from.
-
-  Returns:
-    The loaded module with an extra property _ORIGINAL_SAVED_MODEL_PATH added.
-    This is used on subsequent compiles to load directly from the original
-    path, which gives us unmolested access to the original debug information,
-    which TensorFlow tends to lose on round-trip.
-  """
-  tf_module = tf.saved_model.load(path)
-  assert not hasattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR), (
-      "Saved model (%s) already has attribute %s" %
-      (path, ORIGINAL_SAVED_MODEL_PATH_ATTR))
-  setattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR, path)
-  return tf_module
+  # Create the directory.
+  try:
+    os.makedirs(global_debug_dir)
+  except IOError:
+    logging.exception("Error creating debug dir for: %s", global_debug_dir)
 
 
 class CompiledModule(object):
@@ -226,10 +130,11 @@
 
   def __init__(self, ctor, exported_names, backend):
     super().__init__(ctor, exported_names, backend)
-    self._iree_module_blob = save_and_compile_tf_module(
+    self._iree_module_blob = tf_utils.compile_tf_module(
         ctor(),
         exported_names=exported_names,
-        target_backends=backend.iree_compiler_targets)
+        target_backends=backend.iree_compiler_targets,
+        artifacts_dir=global_debug_dir)
     self._iree_module = rt.VmModule.from_flatbuffer(self._iree_module_blob)
 
   def instantiate(self):
@@ -438,98 +343,75 @@
       return self
 
     def save(self):
-      if FLAGS.debug_dir:
-        for i in range(len(self)):
-          result = self[i]  # output generated by a model
-          field = self._fields[i]  # backend name
-          fname = os.path.join(FLAGS.debug_dir, "output_{}".format(field))
-          with open(fname, "w") as file:
-            # content of txt file can be converted to py objects by eval(txt)
-            file.write(str(result))
+      for i in range(len(self)):
+        result = self[i]  # output generated by a model
+        field = self._fields[i]  # backend name
+        fname = os.path.join(global_debug_dir, "output_{}".format(field))
+        with open(fname, "w") as file:
+          # content of txt file can be converted to py objects by eval(txt)
+          file.write(str(result))
       return self
 
   return MultiResults
 
 
-def _instantiate_modules(compiled_modules_dict):
-  """Given a dict of modules, instantiates them.
+def _instantiate_backends(compiled_backends):
+  """Creates a VirtualBackend namedtuple class for a dict.
 
   Args:
-    compiled_modules_dict: Dictionary of
-        {module_name:{backend_name:CompiledModule}} that should be instantiated.
+    compiled_backends: Dictionary of backend_name:ModuleInstance.
 
   Returns:
-    namedtuple mapping module_key:VirtualBackendsClass for every module
-    in compiled_modules_dict. The VirtualBackendsClass is a dynamically
+    a VirtualBackendsClass instance. The VirtualBackendsClass is a dynamically
     generated namedtuple mapping backend_name:ModuleInstance, where the
     ModuleInstance allows attribute resolution of public functions on the
-    module. The VirtualBackendsClass also contributes some convenience
-    methods for selecting all or a subset of matching backend modules.
+    module. The VirtualBackendsClass also contributes some convenience methods
+    for selecting all or a subset of matching backend modules.
   """
+  tuple_class = collections.namedtuple("VirtualBackendsTuple",
+                                       compiled_backends.keys())
 
-  def instantiate_backends(module_dict):
-    """Creates a VirtualBackend namedtuple class for a dict.
+  class VirtualBackendsClass(tuple_class):
+    """Adds a __call__ method that creates a virtual module."""
 
-    Args:
-      module_dict: Dictionary of backend_name:ModuleInstance.
+    def multi(self, match_spec="."):
+      """Selects multiple backends that match a regular expression."""
+      return _VirtualModuleInstance(self._asdict(), match_spec)
 
-    Returns:
-      namedtuple subclass with a field for every backend and special
-      all and multi() helpers.
-    """
-    tuple_class = collections.namedtuple("VirtualBackendsTuple",
-                                         module_dict.keys())
+    @property
+    def all(self):
+      """Shorthand for multi() which selects all backends."""
+      return self.multi()
 
-    class VirtualBackendsClass(tuple_class):
-      """Adds a __call__ method that creates a virtual module."""
-
-      def multi(self, match_spec="."):
-        """Selects multiple backends that match a regular expression."""
-        return _VirtualModuleInstance(self._asdict(), match_spec)
-
-      @property
-      def all(self):
-        """Shorthand for multi() which selects all backends."""
-        return self.multi()
-
-    return VirtualBackendsClass(
-        *[m.instantiate() for m in module_dict.values()])
-
-  module_keys = [k for (k, _) in compiled_modules_dict.items()]
-  module_insts = [
-      instantiate_backends(module_dict)
-      for (_, module_dict) in compiled_modules_dict.items()
-  ]
-  tuple_class = collections.namedtuple("Modules", module_keys)
-  return tuple_class(*module_insts)
+  return VirtualBackendsClass(
+      *[m.instantiate() for m in compiled_backends.values()])
 
 
-def compile_modules(**kwargs):
-  """Decorator applied to a SavedModelTestCase subclass to compile modules.
+def compile_module(module_ctor, exported_names=()):
+  """SavedModelTestCase decorator that compiles a tf.Module.
+
+  A CompiledModule is created for each backend in --target_backends. They can
+  be accessed individually via self.compiled_modules.backend_name or as a union
+  via self.get_module().
 
   Args:
-    **kwargs: name/Module constructor mappings. Each such arg will be added to
-      the classes 'compiled_modules' field.
+    module_ctor: tf.Module subclass or function which returns a tf.Module
+      subclass instance.
+    exported_names: optional iterable of strings representing the exported names
+      to keep. Used primarily for Keras models (e.g. exported_names=["predict"])
 
   Returns:
     Class decorator function.
   """
 
   def decorator(cls):
-    """Decorator function."""
-    assert issubclass(cls, SavedModelTestCase), (
-        "The 'compile_modules' decorator must be applied to a "
-        "SavedModelTestCase derived class.")
-    if not cls._modules_to_compile:
-      cls._modules_to_compile = {}
-    for name, ctor in kwargs.items():
-      assert name not in cls._modules_to_compile, (
-          "@compile_modules called with duplicate module names '%s'" % (name,))
-      exported_names = ()
-      if isinstance(ctor, tuple):
-        ctor, exported_names = ctor
-      cls._modules_to_compile[name] = (ctor, exported_names)
-
+    """Decorator Function."""
+    if not issubclass(cls, SavedModelTestCase):
+      logging.exception(
+          "The 'compile_module' decorator must be applied to a "
+          "SavedModelTestCase derived class, which %s is not.", cls)
+    cls._module_ctor = module_ctor
+    cls._exported_names = exported_names
     return cls
 
   return decorator
@@ -617,13 +499,13 @@
 class SavedModelTestCase(tf.test.TestCase):
   """Tests against a SavedModel."""
 
-  # Will be initialized to a dict by the @compile_modules decorator.
-  # The dict maps module name to (ctor, exported_names, backend_names).
-  _modules_to_compile = None
+  # Will be initialized by the @compile_module decorator.
+  _module_ctor = None
+  _exported_names = ()
 
-  # Will be initialized in setUpClass to a dict of (name, CompiledModule)
-  # instances mirroring _modules_to_compile.
-  compiled_modules = None
+  # Will be initialized in setUpClass to a dict of
+  # {backend_name: CompiledModule}.
+  _compiled_backends_dict = None
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
@@ -632,42 +514,27 @@
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
-    cls.compiled_modules = {}
-    if cls._modules_to_compile:
-      for name, (ctor, exported_names) in cls._modules_to_compile.items():
+    if cls._module_ctor is not None:
+      # Setup the debug directory for this test. Creates a global variable
+      # `global_debug_dir`.
+      _setup_test_debug_dir(test_name=cls.__name__)
 
-        # Setup the debug directory.
-        debug_parent_dir = FLAGS.debug_dir
-        if not debug_parent_dir:
-          debug_parent_dir = FLAGS.test_tmpdir
-        debug_parent_dir = os.path.join(debug_parent_dir, cls.__name__)
+      # Setup crash reproducer for the test.
+      crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+      compiler.Context.default_crash_reproducer_path = crash_reproducer_path
 
-        try:
-          os.makedirs(debug_parent_dir)
-        except IOError:
-          logging.exception("Error creating crash reproducer dir for: %s",
-                            debug_parent_dir)
+      # Create a CompiledModule for each backend.
+      try:
+        backends = get_backends()
+        cls._compiled_backends_dict = {}
+        for backend in backends:
+          cls._compiled_backends_dict[backend.name] = CompiledModule.create(
+              cls._module_ctor, cls._exported_names, backend)
 
-        # Setup crash reproducer and global debug dir.
-        crash_reproducer_path = os.path.join(debug_parent_dir,
-                                             name + "_reproducer.mlir")
-        compiler.Context.default_crash_reproducer_path = crash_reproducer_path
-        global global_debug_dir
-        global_debug_dir = debug_parent_dir
-
-        try:
-          # Compile.
-          backends = get_backends()
-          cls.compiled_modules[name] = dict([
-              (backend.name, CompiledModule.create(ctor, exported_names,
-                                                   backend))
-              for backend in backends
-          ])
-        finally:
-          # Disable crash reproducer (to avoid inadvertently overwriting this
-          # path on a subsequent interaction).
-          compiler.Context.default_crash_reproducer_path = None
-          global_debug_dir = None
+      finally:
+        # Disable crash reproducer (to avoid inadvertently overwriting this
+        # path on a subsequent interaction).
+        compiler.Context.default_crash_reproducer_path = None
 
   @classmethod
   def tearDownClass(cls):
@@ -675,4 +542,7 @@
 
   def setUp(self):
     super().setUp()
-    self.modules = _instantiate_modules(self.compiled_modules)
+    self.compiled_modules = _instantiate_backends(self._compiled_backends_dict)
+
+  def get_module(self):
+    return self.compiled_modules.all
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
new file mode 100644
index 0000000..cdf26d9
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -0,0 +1,125 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities interop with TensorFlow."""
+
+import os
+import random
+import re
+import tempfile
+
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree.tf import compiler
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+
+def set_random_seed(seed=0):
+  """Set random seed for tf, np and random."""
+  tf.random.set_seed(seed)
+  random.seed(seed)
+  np.random.seed(seed)
+
+
+def compile_tf_module(tf_module,
+                      target_backends=(),
+                      exported_names=(),
+                      artifacts_dir=None):
+  """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+
+  If artifacts_dir is provided then the following artifacts will be saved:
+    saved_model:
+      A TF SavedModel directory containing the files used to translate the
+      tf.Module into an IREE module.
+    tf_input__backends.mlir:
+      MLIR for the module in TF's input dialect.
+    iree_input__backends.mlir:
+      The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
+    compiled__backends.vmfb:
+      A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
+  Here 'backends' is a '__' delimited list of iree backends (e.g. vmla__llvm_ir)
+
+  Args:
+    tf_module: A tf.Module.
+    target_backends: Iterable of string backend names to compile for.
+    exported_names: Iterable of dotted function names to consider for
+      compilation.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved.
+
+  Returns:
+    A compiled IREE module blob.
+  """
+
+  def _compile_from_path(sm_path):
+    """Helper function for compile_tf_module."""
+    # We break up the compilation here so we can save intermediary artifacts.
+    compiler_context = compiler.Context()
+
+    if artifacts_dir is not None:
+      normalized_backends = []
+      for backend in target_backends:
+        # Remove unusual characters and ensure names don't end or start in "_".
+        backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
+        normalized_backends.append(backend.strip("_"))
+      backends_string = "__".join(normalized_backends)
+
+    # Convert the tf_module into raw TF input MLIR.
+    compiler_module = compiler.tf_load_saved_model(
+        sm_path,
+        exported_names=exported_names,
+        compiler_context=compiler_context,
+        pass_pipeline=())
+
+    if artifacts_dir is not None:
+      tf_mlir_path = os.path.join(artifacts_dir,
+                                  f"tf_input__{backends_string}.mlir")
+      logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+      with open(tf_mlir_path, "w") as f:
+        f.write(compiler_module.to_asm())
+
+    # Now run the passes manually that tf_load_saved_model would usually do.
+    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+    if artifacts_dir is not None:
+      iree_mlir_path = os.path.join(artifacts_dir,
+                                    f"iree_input__{backends_string}.mlir")
+      logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+      with open(iree_mlir_path, "w") as f:
+        f.write(compiler_module.to_asm())
+
+    compiled_module = compiler_module.compile(target_backends=target_backends)
+    if artifacts_dir is not None:
+      compiled_path = os.path.join(artifacts_dir,
+                                   f"compiled__{backends_string}.vmfb")
+      logging.info("Saving compiled IREE module to: %s", compiled_path)
+      with open(compiled_path, "wb") as f:
+        f.write(compiled_module)
+
+    return compiled_module
+
+  options = tf.saved_model.SaveOptions(save_debug_info=True)
+  if artifacts_dir is not None:
+    # Save the saved model alongside the other compilation artifacts.
+    sm_path = os.path.join(artifacts_dir, "saved_model")
+    tf.saved_model.save(tf_module, sm_path, options=options)
+    return _compile_from_path(sm_path)
+  else:
+    # Round-trip the saved model through a temporary directory.
+    with tempfile.TemporaryDirectory() as sm_path:
+      tf.saved_model.save(tf_module, sm_path, options=options)
+      return _compile_from_path(sm_path)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
new file mode 100644
index 0000000..89a0011
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -0,0 +1,63 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.tf_utils."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+
+
+class ConstantModule(tf.Module):
+
+  @tf.function(input_signature=[])
+  def meaning(self):
+    return tf.constant([42.])
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.named_parameters([
+      {
+          'testcase_name': 'single_backend',
+          'target_backends': ['vmla'],
+      },
+      {
+          'testcase_name': 'multiple_backends',
+          'target_backends': ['vmla', 'llvm'],
+      },
+  ])
+  def test_artifact_saving(self, target_backends):
+    with tempfile.TemporaryDirectory() as artifacts_dir:
+      tf_module = ConstantModule()
+      iree_compiled_module = tf_utils.compile_tf_module(
+          tf_module,
+          target_backends=target_backends,
+          artifacts_dir=artifacts_dir)
+
+      artifacts_to_check = [
+          'saved_model',
+          f'tf_input__{"__".join(target_backends)}.mlir',
+          f'iree_input__{"__".join(target_backends)}.mlir',
+          f'compiled__{"__".join(target_backends)}.vmfb',
+      ]
+      for artifact in artifacts_to_check:
+        self.assertTrue(os.path.exists(os.path.join(artifacts_dir, artifact)))
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index bcd7d88..57d8d8e 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -23,6 +23,25 @@
 The test suites can be run excluding Vulkan by specifying
 `--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation.
 
+
+## Compiling `tf.Module`s
+
+Compatible TensorFlow modules can be compiled to specific IREE backends using
+`tf_utils.compile_tf_module`. This also optionally saves compilation artifacts
+to a specified directory. These artifacts include: MLIR across various
+lowerings, a TensorFlow SavedModel, and the compiled VM FlatBuffer.
+
+When using Keras models or tf.Modules with functions that IREE can't compile,
+`exported_names` should be specified. For example:
+
+```python
+vmla_module_blob = tf_utils.compile_tf_module(
+    tf_module=SomeKerasModelModule(),
+    target_backends=["vmla"],
+    exported_names=['predict'])
+```
+
+
 ## Running tests
 
 For locally running tests and iterating on backend development, `bazel run` is
@@ -52,7 +71,7 @@
 you specify `tf` backend only, then we will also test `tf` vs `tf` to capture
 any model initialization/randomization issues (it is a special case for debug
 purpose). For reproducibility of the unit tests we set random seed of `tf` and
-`numpy` by calling `tf_test_utils.set_random_seed()` before model creation.
+`numpy` by calling `tf_utils.set_random_seed()` before model creation.
 
 ## Test Suites
 
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index b9fdc2a..f9f8d8c 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -38,7 +38,7 @@
         variance_epsilon=1e-4)
 
 
-@tf_test_utils.compile_modules(bn=BatchNormModule)
+@tf_test_utils.compile_module(BatchNormModule)
 class BatchNormTest(tf_test_utils.SavedModelTestCase):
 
   def test_batch_norm_inference(self):
@@ -49,8 +49,7 @@
     variance = np.random.random((16,)).astype(np.float32) * 1e-3
     offset = np.random.random((16,)).astype(np.float32) * 1e-3
     scale = np.random.random((16,)).astype(np.float32) * 1e-3
-    r = self.modules.bn.all.batch_norm_inference(x, mean, variance, offset,
-                                                 scale)
+    r = self.get_module().batch_norm_inference(x, mean, variance, offset, scale)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index 01c20a6..cde2fd6 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -28,23 +28,23 @@
     return lhs + rhs
 
 
-@tf_test_utils.compile_modules(m=BroadcastingModule)
+@tf_test_utils.compile_module(BroadcastingModule)
 class BroadcastingTest(tf_test_utils.SavedModelTestCase):
 
   def test_add_same_shape(self):
-    m = self.modules.m.all
+    m = self.get_module()
     dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
     dst.print().assert_all_close()
 
 
 # TODO(silvasean): Make these work.
 #   def test_add_broadcast_lhs(self):
-#     m = self.modules.m.all
+#     m = self.get_module()
 #     dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
 #     dst.print().assert_all_close()
 #
 #   def test_add_broadcast_rhs(self):
-#     m = self.modules.m.all
+#     m = self.get_module()
 #     dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
 #     dst.print().assert_all_close()
 
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 8b7a856..a9f9759 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -15,6 +15,7 @@
 """Test concat op."""
 
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -49,36 +50,36 @@
     return tf.concat([a, b], axis=2)
 
 
-@tf_test_utils.compile_modules(mat=ConcatOpsModule)
+@tf_test_utils.compile_module(ConcatOpsModule)
 class ConcatOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_concat_zero_dim(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 0], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat0axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat1axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat2axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index d579e9b..0223e8c 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -38,17 +38,17 @@
     return i
 
 
-@tf_test_utils.compile_modules(control_flow=ControlFlowModule)
+@tf_test_utils.compile_module(ControlFlowModule)
 class ControlFlowTest(tf_test_utils.SavedModelTestCase):
 
   def test_short_sequence(self):
     input_array = numpy.array(9., dtype=numpy.float32)
-    result = self.modules.control_flow.all.collatz(input_array)
+    result = self.get_module().collatz(input_array)
     result.print().assert_all_close()
 
   def test_long_sequence(self):
     input_array = numpy.array(178., dtype=numpy.float32)
-    result = self.modules.control_flow.all.collatz(input_array)
+    result = self.get_module().collatz(input_array)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index a4997e6..f72b11d 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -98,73 +98,73 @@
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
 
-@tf_test_utils.compile_modules(conv2d=Conv2dModule)
+@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.SavedModelTestCase):
 
   def test_id_batch_size_1(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_1451x1111_valid(i, k)
+    r = self.get_module().conv2d_1451x1111_valid(i, k)
     r.print().assert_all_close()
 
   def test_id_batch_size_2(self):
     i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
     k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_2451x1111_valid(i, k)
+    r = self.get_module().conv2d_2451x1111_valid(i, k)
     r.print().assert_all_close()
 
   def test_asym_kernel(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_1451x2311_valid(i, k)
+    r = self.get_module().conv2d_1451x2311_valid(i, k)
     r.print().assert_all_close()
 
   def test_padding(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_1451x2311_same(i, k)
+    r = self.get_module().conv2d_1451x2311_same(i, k)
     r.print().assert_all_close()
 
   def test_batched_padding(self):
     i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_2451x2311_same(i, k)
+    r = self.get_module().conv2d_2451x2311_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_reduce(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.ones([3, 2, 2, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_1452x3221_same(i, k)
+    r = self.get_module().conv2d_1452x3221_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_inflate(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.arange(2, dtype=np.float32).reshape([1, 1, 1, 2])
-    r = self.modules.conv2d.all.conv2d_1451x1112_same(i, k)
+    r = self.get_module().conv2d_1451x1112_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_mix(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(4, dtype=np.float32).reshape([1, 1, 2, 2])
-    r = self.modules.conv2d.all.conv2d_1452x1122_same(i, k)
+    r = self.get_module().conv2d_1452x1122_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_padded(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_1452x2223_same(i, k)
+    r = self.get_module().conv2d_1452x2223_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_unpadded(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_1452x2223_valid(i, k)
+    r = self.get_module().conv2d_1452x2223_valid(i, k)
     r.print().assert_all_close()
 
   def test_batched_feature_unpadded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_valid(i, k)
+    r = self.get_module().conv2d_2452x2223_valid(i, k)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 361b55f..cdf4d1e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -38,19 +38,19 @@
         img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
 
-@tf_test_utils.compile_modules(conv2d=Conv2dModule)
+@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.SavedModelTestCase):
 
   def test_batched_feature_unpadded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_valid(i, k)
+    r = self.get_module().conv2d_2452x2223_valid(i, k)
     r.print().assert_all_close()
 
   def test_batched_feature_unpadded_smae(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_same(i, k)
+    r = self.get_module().conv2d_2452x2223_same(i, k)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 5f9c667..04de603 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 HIDDEN_1_DIM = 256
@@ -35,7 +36,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
@@ -64,11 +65,11 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_modules(mlp=(Mlp, ["predict"]))
+@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
 
   def test_dynamic_batch(self):
-    m = self.modules.mlp.all
+    m = self.get_module()
     np.random.seed(12345)
     x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
     m.predict(x).print().assert_all_close()
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 17da1cd..66f3c06 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 HIDDEN_1_DIM = 256
@@ -31,7 +32,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
@@ -60,11 +61,11 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_modules(mlp=(Mlp, ["predict"]))
+@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
 
   def test_dynamic_batch(self):
-    m = self.modules.mlp.all
+    m = self.get_module()
     np.random.seed(12345)
     x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
     m.predict(x).print().assert_all_close()
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
index 1cccd8a..903b34c 100644
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ b/integrations/tensorflow/e2e/explicit_backend_test.py
@@ -29,7 +29,7 @@
     return a * b
 
 
-@tf_test_utils.compile_modules(simple_arithmetic=SimpleArithmeticModule)
+@tf_test_utils.compile_module(SimpleArithmeticModule)
 class ExplicitBackendTest(tf_test_utils.SavedModelTestCase):
 
   def test_explicit(self):
@@ -39,9 +39,9 @@
     # Demonstrates simple, one by one invocation of functions against
     # different explicit backends. Individual backends can be accessed off of
     # the module by name ('tf', 'iree_vmla' below).
-    tf_c = self.modules.simple_arithmetic.tf.simple_mul(a, b)
+    tf_c = self.compiled_modules.tf.simple_mul(a, b)
     print("TF Result:", tf_c)
-    iree_c = self.modules.simple_arithmetic.iree_vmla.simple_mul(a, b)
+    iree_c = self.compiled_modules.iree_vmla.simple_mul(a, b)
     print("IREE Result:", iree_c)
     self.assertAllClose(tf_c, iree_c)
 
@@ -53,18 +53,18 @@
     # which takes a regex string matching backend names. This also returns a
     # MultiResults tuple with actual results keyed by backend name. These also
     # have convenience methods like print() and assert_all_close().
-    vmod = self.modules.simple_arithmetic.multi("tf|iree")
+    vmod = self.compiled_modules.multi("tf|iree")
     r = vmod.simple_mul(a, b)
     r.print().assert_all_close()
 
-  def test_all(self):
+  def test_get_module(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
     b = np.array([400., 5., 6., 7.], dtype=np.float32)
 
-    # Evaluating against all backends can be done with the special 'all'
-    # backend name. This also returns a MultiResults tuple with actual results
-    # keyed by backend name.
-    r = self.modules.simple_arithmetic.all.simple_mul(a, b)
+    # Evaluating against all backends can be done with self.get_module(). This
+    # also returns a MultiResults tuple with actual results keyed by backend
+    # name.
+    r = self.get_module().simple_mul(a, b)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/exported_names_test.py b/integrations/tensorflow/e2e/exported_names_test.py
deleted file mode 100644
index 2d7e447..0000000
--- a/integrations/tensorflow/e2e/exported_names_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class DontExportEverything(tf.Module):
-
-  @tf.function(input_signature=[])
-  def exported_fn(self):
-    return tf.constant([42.])
-
-  # No input_signature, so it cannot be imported by the SavedModel importer.
-  # We need to ensure that
-  @tf.function
-  def unreachable_fn(self, x):
-    return x
-
-
-# To pass a set of exported names for the module, instead of passing just a
-# module ctor, instead pass a pair `(ctor, [list, of, exported, names])`.
-@tf_test_utils.compile_modules(
-    dont_export_everything=(DontExportEverything, ["exported_fn"]))
-class DontExportEverythingTest(tf_test_utils.SavedModelTestCase):
-
-  def test_dont_export_everything(self):
-    self.modules.dont_export_everything.all.exported_fn().assert_all_close()
-
-
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
-    tf.enable_v2_behavior()
-  tf.test.main()
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 82b2af5..8ef96a9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -30,14 +30,14 @@
     return tf.fill(dims, value)
 
 
-@tf_test_utils.compile_modules(fill=FillModule)
+@tf_test_utils.compile_module(FillModule)
 class FillTest(tf_test_utils.SavedModelTestCase):
 
   def test_fill(self):
     dims = np.array([2, 3], dtype=np.int32)
     value = np.array(9., dtype=np.float32)
 
-    result = self.modules.fill.all.fill(dims, value)
+    result = self.get_module().fill(dims, value)
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index de532cd..8d2a0ba 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -48,31 +48,31 @@
     return tf.gather(params, indices, axis=2, batch_dims=1)
 
 
-@tf_test_utils.compile_modules(gather=GatherModule)
+@tf_test_utils.compile_module(GatherModule)
 class GatherTest(tf_test_utils.SavedModelTestCase):
 
   def test_gather_axis0_scalar(self):
     indices = np.array(2, dtype=np.int32)
     params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.modules.gather.all.gather_axis0_scalar(params, indices)
+    result = self.get_module().gather_axis0_scalar(params, indices)
     result.print().assert_all_close()
 
   def test_gather_axis0_batch0(self):
     indices = np.array([2, 3], dtype=np.int32)
     params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.modules.gather.all.gather_axis0_batch0(params, indices)
+    result = self.get_module().gather_axis0_batch0(params, indices)
     result.print().assert_all_close()
 
-  def test_gahter_axis1_batch0(self):
+  def test_gather_axis1_batch0(self):
     indices = np.array([2, 3], dtype=np.int32)
     params = np.arange(4 * 7 * 8, dtype=np.float32).reshape(4, 7, 8)
-    result = self.modules.gather.all.gather_axis1_batch0(params, indices)
+    result = self.get_module().gather_axis1_batch0(params, indices)
     result.print().assert_all_close()
 
-  def test_gahter_axis2_batch1(self):
+  def test_gather_axis2_batch1(self):
     indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
     params = np.arange(4 * 7 * 8 * 2, dtype=np.float32).reshape(4, 7, 8, 2)
-    result = self.modules.gather.all.gather_axis2_batch1(params, indices)
+    result = self.get_module().gather_axis2_batch1(params, indices)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index 12f56d1..0d34d97 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
@@ -27,7 +28,7 @@
 
 
 def lstm_module():
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
   inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
   outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
   model = tf.keras.Model(inputs, outputs)
@@ -39,11 +40,11 @@
   return module
 
 
-@tf_test_utils.compile_modules(lstm=(lstm_module, ["predict"]))
+@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
-    m = self.modules.lstm.all
+    m = self.get_module()
     m.predict(
         tf.constant(
             np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 5def0e7..671c31b 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
@@ -24,7 +25,7 @@
 
 
 def lstm_module():
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
   inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
   outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
   model = tf.keras.Model(inputs, outputs)
@@ -36,11 +37,11 @@
   return module
 
 
-@tf_test_utils.compile_modules(lstm=(lstm_module, ["predict"]))
+@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
-    m = self.modules.lstm.all
+    m = self.get_module()
     m.predict(
         tf.constant(
             np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 68cfa73..6675956 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -14,13 +14,10 @@
 # limitations under the License.
 """Test keras Model training."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 from sklearn.preprocessing import PolynomialFeatures
 import tensorflow as tf
 
@@ -51,7 +48,7 @@
       model for linear regression
     """
 
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
 
     # build a single layer model
     inputs = tf.keras.layers.Input((input_dim))
@@ -78,8 +75,8 @@
     return loss_value
 
 
-@tf_test_utils.compile_modules(
-    train_module=(ModelTrain.CreateModule, ["TrainStep"]))
+@tf_test_utils.compile_module(
+    ModelTrain.CreateModule, exported_names=["TrainStep"])
 class ModelTrainTest(tf_test_utils.SavedModelTestCase):
 
   def generate_regression_data(self, size=8):
@@ -103,7 +100,7 @@
 
     targets = np.expand_dims(targets, axis=1)
     # run one iteration of training step
-    result = self.modules.train_module.all.TrainStep(inputs, targets)
+    result = self.get_module().TrainStep(inputs, targets)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 9164d9b..1804739 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -17,6 +17,7 @@
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 FLAGS = flags.FLAGS
@@ -88,7 +89,7 @@
 
 def models():
   tf.keras.backend.set_learning_phase(False)
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
 
   input_shape = get_input_shape(FLAGS.data, FLAGS.model)
   # keras model receives images size as input,
@@ -127,7 +128,7 @@
   return module
 
 
-@tf_test_utils.compile_modules(applications=(models, ['predict']))
+@tf_test_utils.compile_module(models, exported_names=['predict'])
 class AppTest(tf_test_utils.SavedModelTestCase):
 
   def test_application(self):
@@ -135,8 +136,7 @@
     input_data = np.random.rand(np.prod(np.array(input_shape))).astype(
         np.float32)
     input_data = input_data.reshape(input_shape)
-    self.modules.applications.all.predict(input_data).print().assert_all_close(
-        atol=1e-6)
+    self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
 
 
 if __name__ == '__main__':
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index 682df56..d326db5 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -33,14 +33,14 @@
     return tf.linspace(start, stop, num)
 
 
-@tf_test_utils.compile_modules(linspace=LinSpaceModule)
+@tf_test_utils.compile_module(LinSpaceModule)
 class LinspaceTest(tf_test_utils.SavedModelTestCase):
 
   def test_linspace(self):
     start = np.array(10., dtype=np.float32)
     stop = np.array(12., dtype=np.float32)
 
-    result = self.modules.linspace.all.linspace(start, stop)
+    result = self.get_module().linspace(start, stop)
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 0fa7205..4886b7a 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -94,11 +94,11 @@
     return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
 
 
-@tf_test_utils.compile_modules(mandelbrot=MandelbrotModule)
+@tf_test_utils.compile_module(MandelbrotModule)
 class MandelbrotTest(tf_test_utils.SavedModelTestCase):
 
   def test_mandelbrot(self):
-    mandelbrot = self.modules.mandelbrot.all
+    mandelbrot = self.get_module()
 
     # Basic view of the entire set.
     pixels = mandelbrot.calculate(-0.7, 0.0, 3.0, 400, 100)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index f5d3538..b27d1d1 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -38,27 +38,27 @@
     return tf.math.mod(x, 2.0)
 
 
-@tf_test_utils.compile_modules(math=MathModule)
+@tf_test_utils.compile_module(MathModule)
 class MathTest(tf_test_utils.SavedModelTestCase):
 
   def test_abs(self):
     a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.abs(a)
+    r = self.get_module().abs(a)
     r.print().assert_all_close()
 
   def test_cos(self):
     a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.cos(a)
+    r = self.get_module().cos(a)
     r.print().assert_all_close()
 
   def test_log(self):
     a = np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.log(a)
+    r = self.get_module().log(a)
     r.print().assert_all_close()
 
   def test_mod(self):
     a = np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)
-    r = self.modules.math.all.mod(a)
+    r = self.get_module().mod(a)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index a604fac..b29a198 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -70,58 +70,58 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_modules(mat=MatrixOpsModule)
+@tf_test_utils.compile_module(MatrixOpsModule)
 class MatrixOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_basic_matmul(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
     dst.assert_all_close()
 
   def test_matmul_lhs_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_lhs_batch(
         tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
     dst.assert_all_close()
 
   def test_matmul_rhs_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_rhs_batch(
         tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_broadcast_singleton_dimension(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_broadcast_singleton_dimension(
         tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_high_rank_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_high_rank_batch(
         tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_matching_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_broadcast_lhs(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_broadcast_rhs(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_rank_broadcasting(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic_lhs_batch(
         tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
     dst.assert_all_close()
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 342adff..1d703c0 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -28,12 +28,11 @@
     return self.counter.assign_add(value)
 
 
-@tf_test_utils.compile_modules(resource_ops=ResourcesOpsModule)
+@tf_test_utils.compile_module(ResourcesOpsModule)
 class ResourcesOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_add_assign(self):
-    result = self.modules.resource_ops.all.add_assign(
-        np.array(9., dtype=np.float32))
+    result = self.get_module().add_assign(np.array(9., dtype=np.float32))
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 32dac0d..ea48711 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -164,10 +164,10 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
-class StatefulRingBufferM(tf.Module):
+class StatefulRingBufferModule(tf.Module):
 
   def __init__(self):
-    super(StatefulRingBufferM, self).__init__()
+    super(StatefulRingBufferModule, self).__init__()
     state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE]
     self.rb = StatefulRingBuffer(state_shape=state_shape)
 
@@ -177,26 +177,27 @@
     return self.rb(x)
 
 
-@tf_test_utils.compile_modules(rb=(StatefulRingBufferM, ["predict"]))
+@tf_test_utils.compile_module(
+    StatefulRingBufferModule, exported_names=["predict"])
 class StatefulRingBufferTest(tf_test_utils.SavedModelTestCase):
 
-  def test_statefulringbuffer(self):
+  def test_stateful_ringbuffer(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.modules.rb.all.predict(input1)
+    result1 = self.get_module().predict(input1)
     output1 = np.array([[1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result1, output1)
 
     # ring buffer is not filled yet,
     # so data from first cycle will be returned
     input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.modules.rb.all.predict(input2)
+    result2 = self.get_module().predict(input2)
     output2 = np.array([[1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result2, output2)
 
     # on 3rd cycle we overwrite oldest data
     # and return data from 2nd cycle
     input3 = np.array([[5.0, 6.0]], dtype=np.float32)
-    result3 = self.modules.rb.all.predict(input3)
+    result3 = self.get_module().predict(input3)
     output3 = np.array([[3.0, 4.0]], dtype=np.float32)
     assert np.allclose(result3, output3)
 
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index d10e26d..66562b0 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -48,30 +48,28 @@
     return tf.tensor_scatter_nd_update(tensor, indices, updates)
 
 
-@tf_test_utils.compile_modules(scatter_update=ScatterUpdateModule)
+@tf_test_utils.compile_module(ScatterUpdateModule)
 class ScatterUpdateTest(tf_test_utils.SavedModelTestCase):
 
   def test_scatter_update_1D(self):
     tensor = tf.ones([8], dtype=tf.int32)
     indices = tf.constant([[4], [5], [6]])
     updates = tf.constant([9, 10, 11])
-    result = self.modules.scatter_update.all.scatter_update_1D(
-        tensor, indices, updates)
+    result = self.get_module().scatter_update_1D(tensor, indices, updates)
     result.assert_all_close()
 
   def test_scatter_update_2D(self):
     tensor = tf.ones([4, 3], dtype=tf.int32)
     indices = tf.constant([[1, 0], [2, 1], [3, 2]])
     updates = tf.constant([2, 5, 8])
-    result = self.modules.scatter_update.all.scatter_update_2D(
-        tensor, indices, updates)
+    result = self.get_module().scatter_update_2D(tensor, indices, updates)
     result.assert_all_close()
 
   def test_scatter_update_2D_slice(self):
     tensor = tf.ones([4, 3], dtype=tf.int32)
     indices = tf.constant([[1]])
     updates = tf.constant([[2, 3, 4]])
-    result = self.modules.scatter_update.all.scatter_update_2D_slice(
+    result = self.get_module().scatter_update_2D_slice(
         tensor, indices, updates)
     result.assert_all_close()
 
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index eaa1b30..0c5941d 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -36,13 +36,13 @@
     return tf.matmul(a, b)
 
 
-@tf_test_utils.compile_modules(simple_arithmetic=SimpleArithmeticModule)
+@tf_test_utils.compile_module(SimpleArithmeticModule)
 class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
 
   def test_simple_mul(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
     b = np.array([400., 5., 6., 7.], dtype=np.float32)
-    r = self.modules.simple_arithmetic.all.simple_mul(a, b)
+    r = self.get_module().simple_mul(a, b)
     r.print().assert_all_close()
 
   def test_simple_matmul(self):
@@ -50,7 +50,7 @@
     # Note: scaling by a small value to increase numerical stability.
     a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
     b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
-    r = self.modules.simple_arithmetic.all.simple_matmul(a, b)
+    r = self.get_module().simple_matmul(a, b)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index cc8bfd5..45eba4f 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -32,11 +32,11 @@
     return self.counter
 
 
-@tf_test_utils.compile_modules(stateful=Stateful)
+@tf_test_utils.compile_module(Stateful)
 class StatefulTest(tf_test_utils.SavedModelTestCase):
 
   def test_stateful(self):
-    m = self.modules.stateful.all
+    m = self.get_module()
     m.inc_by(tf.constant(1.))
     m.get_state().print().assert_all_close()
 
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index cae9b54..b663fc8 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -62,10 +62,10 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
-class SlidingWindowM(tf.Module):
+class SlidingWindowModule(tf.Module):
 
   def __init__(self):
-    super(SlidingWindowM, self).__init__()
+    super(SlidingWindowModule, self).__init__()
     state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE]
     self.sw = SlidingWindow(state_shape=state_shape)
 
@@ -75,17 +75,17 @@
     return self.sw(x)
 
 
-@tf_test_utils.compile_modules(sw=(SlidingWindowM, ["predict"]))
+@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
 class SlidingWindowTest(tf_test_utils.SavedModelTestCase):
 
   def test_slidingwindow(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.modules.sw.all.predict(input1)
+    result1 = self.get_module().predict(input1)
     output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result1, output1)
 
     input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.modules.sw.all.predict(input2)
+    result2 = self.get_module().predict(input2)
     output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
     assert np.allclose(result2, output2)
 
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index b4105a6..ac590ff 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -40,20 +40,20 @@
     return tf.strings.reduce_join(wps, 1)
 
 
-@tf_test_utils.compile_modules(strings=StringsModule)
+@tf_test_utils.compile_module(StringsModule)
 class StringsTest(tf_test_utils.SavedModelTestCase):
 
   def test_print_ids(self):
     input_ids = np.asarray(
         [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
          [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    self.modules.strings.all.print_ids(input_ids)
+    self.get_module().print_ids(input_ids)
 
   def test_strings_to_ids(self):
     input_ids = np.asarray(
         [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
          [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    result = self.modules.strings.all.strings_to_ids(input_ids)
+    result = self.get_module().strings_to_ids(input_ids)
     result.assert_all_equal()
 
 
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 83ae28f..f8ea811 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -68,27 +68,27 @@
     return ta.stack()
 
 
-@tf_test_utils.compile_modules(tensorlist=TensorListModule)
+@tf_test_utils.compile_module(TensorListModule)
 class TensorListTest(tf_test_utils.SavedModelTestCase):
 
   def test_identity_through_tensorlist(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.identity_through_tensorlist(tf.constant(42.))
     result.print().assert_all_close()
 
   def test_add_through_tensorlist(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.add_through_tensorlist(tf.constant(42.), tf.constant(43.))
     result.print().assert_all_close()
 
   def test_slice_first_element_with_from_tensor(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.slice_first_element_with_from_tensor(
         tf.range(STATIC_SIZE, dtype=tf.float32))
     result.print().assert_all_close()
 
   def test_slice_first_element_with_from_tensor_high_rank(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.slice_first_element_with_from_tensor_high_rank(
         tf.broadcast_to(
             tf.range(STATIC_SIZE, dtype=tf.float32),
@@ -96,7 +96,7 @@
     result.print().assert_all_close()
 
   def test_concat_with_tensorlist_stack(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.concat_with_tensorlist_stack(tf.constant(42.), tf.constant(43.))
     result.print().assert_all_close()